init commit

26  ultralytics/data/__init__.py  Normal file
@@ -0,0 +1,26 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .base import BaseDataset
from .build import build_dataloader, build_grounding, build_yolo_dataset, load_inference_source
from .dataset import (
    ClassificationDataset,
    GroundingDataset,
    SemanticDataset,
    YOLOConcatDataset,
    YOLODataset,
    YOLOMultiModalDataset,
)

__all__ = (
    "BaseDataset",
    "ClassificationDataset",
    "SemanticDataset",
    "YOLODataset",
    "YOLOMultiModalDataset",
    "YOLOConcatDataset",
    "GroundingDataset",
    "build_yolo_dataset",
    "build_grounding",
    "build_dataloader",
    "load_inference_source",
)
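Note: this module is only the package's public surface; everything in `__all__` re-exports from the sibling `base`, `build`, and `dataset` modules. A minimal consumption sketch, assuming the `dataset` module committed alongside imports cleanly and the path is a placeholder:

from ultralytics.data import YOLODataset, build_dataloader, load_inference_source

dataset = load_inference_source("bus.jpg")  # "bus.jpg" is a placeholder path
print(type(dataset).__name__)  # e.g. LoadImagesAndVideos for a plain image file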
BIN  ultralytics/data/__pycache__/__init__.cpython-310.pyc   Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/augment.cpython-310.pyc    Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/base.cpython-310.pyc       Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/build.cpython-310.pyc      Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/converter.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/dataset.cpython-310.pyc    Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/loaders.cpython-310.pyc    Normal file (binary file not shown)
BIN  ultralytics/data/__pycache__/utils.cpython-310.pyc      Normal file (binary file not shown)
67  ultralytics/data/annotator.py  Normal file
@@ -0,0 +1,67 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path

from ultralytics import SAM, YOLO


def auto_annotate(
    data: str | Path,
    det_model: str = "yolo11x.pt",
    sam_model: str = "sam_b.pt",
    device: str = "",
    conf: float = 0.25,
    iou: float = 0.45,
    imgsz: int = 640,
    max_det: int = 300,
    classes: list[int] | None = None,
    output_dir: str | Path | None = None,
) -> None:
    """
    Automatically annotate images using a YOLO object detection model and a SAM segmentation model.

    This function processes images in a specified directory, detects objects using a YOLO model, and then generates
    segmentation masks using a SAM model. The resulting annotations are saved as text files in YOLO format.

    Args:
        data (str | Path): Path to a folder containing images to be annotated.
        det_model (str): Path or name of the pre-trained YOLO detection model.
        sam_model (str): Path or name of the pre-trained SAM segmentation model.
        device (str): Device to run the models on (e.g., 'cpu', 'cuda', '0'). Empty string for auto-selection.
        conf (float): Confidence threshold for the detection model.
        iou (float): IoU threshold for filtering overlapping boxes in detection results.
        imgsz (int): Input image resize dimension.
        max_det (int): Maximum number of detections per image.
        classes (list[int], optional): Filter predictions to specified class IDs, returning only relevant detections.
        output_dir (str | Path, optional): Directory to save the annotated results. If None, creates a default
            directory based on the input data path.

    Examples:
        >>> from ultralytics.data.annotator import auto_annotate
        >>> auto_annotate(data="ultralytics/assets", det_model="yolo11n.pt", sam_model="mobile_sam.pt")
    """
    det_model = YOLO(det_model)
    sam_model = SAM(sam_model)

    data = Path(data)
    if not output_dir:
        output_dir = data.parent / f"{data.stem}_auto_annotate_labels"
    Path(output_dir).mkdir(exist_ok=True, parents=True)

    det_results = det_model(
        data, stream=True, device=device, conf=conf, iou=iou, imgsz=imgsz, max_det=max_det, classes=classes
    )

    for result in det_results:
        if class_ids := result.boxes.cls.int().tolist():  # Extract class IDs from detection results
            boxes = result.boxes.xyxy  # Boxes object for bbox outputs
            sam_results = sam_model(result.orig_img, bboxes=boxes, verbose=False, save=False, device=device)
            segments = sam_results[0].masks.xyn

            with open(f"{Path(output_dir) / Path(result.path).stem}.txt", "w", encoding="utf-8") as f:
                for i, s in enumerate(segments):
                    if s.any():
                        segment = map(str, s.reshape(-1).tolist())
                        f.write(f"{class_ids[i]} " + " ".join(segment) + "\n")
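Note: each output line is `class_id x1 y1 x2 y2 ...` with coordinates normalized to [0, 1] (the `masks.xyn` convention). A hypothetical read-back helper, for illustration only and not part of this commit:

from pathlib import Path

def read_segment_labels(txt_file):  # hypothetical helper, not in the Ultralytics API
    """Parse one auto_annotate output file into (class_id, [(x, y), ...]) tuples."""
    labels = []
    for line in Path(txt_file).read_text(encoding="utf-8").splitlines():
        parts = line.split()
        cls, coords = int(parts[0]), list(map(float, parts[1:]))
        labels.append((cls, list(zip(coords[0::2], coords[1::2]))))  # normalized xy pairs
    return labels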
2991  ultralytics/data/augment.py  Normal file
File diff suppressed because it is too large.
443  ultralytics/data/base.py  Normal file
@@ -0,0 +1,443 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import glob
import math
import os
import random
from copy import deepcopy
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any

import cv2
import numpy as np
from torch.utils.data import Dataset

from ultralytics.data.utils import FORMATS_HELP_MSG, HELP_URL, IMG_FORMATS, check_file_speeds
from ultralytics.utils import DEFAULT_CFG, LOCAL_RANK, LOGGER, NUM_THREADS, TQDM
from ultralytics.utils.patches import imread


class BaseDataset(Dataset):
    """
    Base dataset class for loading and processing image data.

    This class provides core functionality for loading images, caching, and preparing data for training and inference
    in object detection tasks.

    Attributes:
        img_path (str): Path to the folder containing images.
        imgsz (int): Target image size for resizing.
        augment (bool): Whether to apply data augmentation.
        single_cls (bool): Whether to treat all objects as a single class.
        prefix (str): Prefix to print in log messages.
        fraction (float): Fraction of dataset to utilize.
        channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).
        cv2_flag (int): OpenCV flag for reading images.
        im_files (list[str]): List of image file paths.
        labels (list[dict]): List of label data dictionaries.
        ni (int): Number of images in the dataset.
        rect (bool): Whether to use rectangular training.
        batch_size (int): Size of batches.
        stride (int): Stride used in the model.
        pad (float): Padding value.
        buffer (list): Buffer for mosaic images.
        max_buffer_length (int): Maximum buffer size.
        ims (list): List of loaded images.
        im_hw0 (list): List of original image dimensions (h, w).
        im_hw (list): List of resized image dimensions (h, w).
        npy_files (list[Path]): List of numpy file paths.
        cache (str): Cache images to RAM or disk during training.
        transforms (callable): Image transformation function.
        batch_shapes (np.ndarray): Batch shapes for rectangular training.
        batch (np.ndarray): Batch index of each image.

    Methods:
        get_img_files: Read image files from the specified path.
        update_labels: Update labels to include only specified classes.
        load_image: Load an image from the dataset.
        cache_images: Cache images to memory or disk.
        cache_images_to_disk: Save an image as an *.npy file for faster loading.
        check_cache_disk: Check image caching requirements vs available disk space.
        check_cache_ram: Check image caching requirements vs available memory.
        set_rectangle: Set the shape of bounding boxes as rectangles.
        get_image_and_label: Get and return label information from the dataset.
        update_labels_info: Custom label format method to be implemented by subclasses.
        build_transforms: Build transformation pipeline to be implemented by subclasses.
        get_labels: Get labels method to be implemented by subclasses.
    """

    def __init__(
        self,
        img_path: str | list[str],
        imgsz: int = 640,
        cache: bool | str = False,
        augment: bool = True,
        hyp: dict[str, Any] = DEFAULT_CFG,
        prefix: str = "",
        rect: bool = False,
        batch_size: int = 16,
        stride: int = 32,
        pad: float = 0.5,
        single_cls: bool = False,
        classes: list[int] | None = None,
        fraction: float = 1.0,
        channels: int = 3,
    ):
        """
        Initialize BaseDataset with given configuration and options.

        Args:
            img_path (str | list[str]): Path to the folder containing images or list of image paths.
            imgsz (int): Image size for resizing.
            cache (bool | str): Cache images to RAM or disk during training.
            augment (bool): If True, data augmentation is applied.
            hyp (dict[str, Any]): Hyperparameters to apply data augmentation.
            prefix (str): Prefix to print in log messages.
            rect (bool): If True, rectangular training is used.
            batch_size (int): Size of batches.
            stride (int): Stride used in the model.
            pad (float): Padding value.
            single_cls (bool): If True, single class training is used.
            classes (list[int], optional): List of included classes.
            fraction (float): Fraction of dataset to utilize.
            channels (int): Number of channels in the images (1 for grayscale, 3 for RGB).
        """
        super().__init__()
        self.img_path = img_path
        self.imgsz = imgsz
        self.augment = augment
        self.single_cls = single_cls
        self.prefix = prefix
        self.fraction = fraction
        self.channels = channels
        self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR
        self.im_files = self.get_img_files(self.img_path)
        self.labels = self.get_labels()
        self.update_labels(include_class=classes)  # single_cls and include_class
        self.ni = len(self.labels)  # number of images
        self.rect = rect
        self.batch_size = batch_size
        self.stride = stride
        self.pad = pad
        if self.rect:
            assert self.batch_size is not None
            self.set_rectangle()

        # Buffer thread for mosaic images
        self.buffer = []  # buffer size = batch size
        self.max_buffer_length = min((self.ni, self.batch_size * 8, 1000)) if self.augment else 0

        # Cache images (options are cache = True, False, None, "ram", "disk")
        self.ims, self.im_hw0, self.im_hw = [None] * self.ni, [None] * self.ni, [None] * self.ni
        self.npy_files = [Path(f).with_suffix(".npy") for f in self.im_files]
        self.cache = cache.lower() if isinstance(cache, str) else "ram" if cache is True else None
        if self.cache == "ram" and self.check_cache_ram():
            if hyp.deterministic:
                LOGGER.warning(
                    "cache='ram' may produce non-deterministic training results. "
                    "Consider cache='disk' as a deterministic alternative if your disk space allows."
                )
            self.cache_images()
        elif self.cache == "disk" and self.check_cache_disk():
            self.cache_images()

        # Transforms
        self.transforms = self.build_transforms(hyp=hyp)

    def get_img_files(self, img_path: str | list[str]) -> list[str]:
        """
        Read image files from the specified path.

        Args:
            img_path (str | list[str]): Path or list of paths to image directories or files.

        Returns:
            (list[str]): List of image file paths.

        Raises:
            FileNotFoundError: If no images are found or the path doesn't exist.
        """
        try:
            f = []  # image files
            for p in img_path if isinstance(img_path, list) else [img_path]:
                p = Path(p)  # os-agnostic
                if p.is_dir():  # dir
                    f += glob.glob(str(p / "**" / "*.*"), recursive=True)
                    # F = list(p.rglob('*.*'))  # pathlib
                elif p.is_file():  # file
                    with open(p, encoding="utf-8") as t:
                        t = t.read().strip().splitlines()
                        parent = str(p.parent) + os.sep
                        f += [x.replace("./", parent) if x.startswith("./") else x for x in t]  # local to global path
                        # F += [p.parent / x.lstrip(os.sep) for x in t]  # local to global path (pathlib)
                else:
                    raise FileNotFoundError(f"{self.prefix}{p} does not exist")
            im_files = sorted(x.replace("/", os.sep) for x in f if x.rpartition(".")[-1].lower() in IMG_FORMATS)
            # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS])  # pathlib
            assert im_files, f"{self.prefix}No images found in {img_path}. {FORMATS_HELP_MSG}"
        except Exception as e:
            raise FileNotFoundError(f"{self.prefix}Error loading data from {img_path}\n{HELP_URL}") from e
        if self.fraction < 1:
            im_files = im_files[: round(len(im_files) * self.fraction)]  # retain a fraction of the dataset
        check_file_speeds(im_files, prefix=self.prefix)  # check image read speeds
        return im_files

    def update_labels(self, include_class: list[int] | None) -> None:
        """
        Update labels to include only specified classes.

        Args:
            include_class (list[int], optional): List of classes to include. If None, all classes are included.
        """
        include_class_array = np.array(include_class).reshape(1, -1)
        for i in range(len(self.labels)):
            if include_class is not None:
                cls = self.labels[i]["cls"]
                bboxes = self.labels[i]["bboxes"]
                segments = self.labels[i]["segments"]
                keypoints = self.labels[i]["keypoints"]
                j = (cls == include_class_array).any(1)
                self.labels[i]["cls"] = cls[j]
                self.labels[i]["bboxes"] = bboxes[j]
                if segments:
                    self.labels[i]["segments"] = [segments[si] for si, idx in enumerate(j) if idx]
                if keypoints is not None:
                    self.labels[i]["keypoints"] = keypoints[j]
            if self.single_cls:
                self.labels[i]["cls"][:, 0] = 0

    def load_image(self, i: int, rect_mode: bool = True) -> tuple[np.ndarray, tuple[int, int], tuple[int, int]]:
        """
        Load an image from dataset index 'i'.

        Args:
            i (int): Index of the image to load.
            rect_mode (bool): Whether to use rectangular resizing.

        Returns:
            im (np.ndarray): Loaded image as a NumPy array.
            hw_original (tuple[int, int]): Original image dimensions in (height, width) format.
            hw_resized (tuple[int, int]): Resized image dimensions in (height, width) format.

        Raises:
            FileNotFoundError: If the image file is not found.
        """
        im, f, fn = self.ims[i], self.im_files[i], self.npy_files[i]
        if im is None:  # not cached in RAM
            if fn.exists():  # load npy
                try:
                    im = np.load(fn)
                except Exception as e:
                    LOGGER.warning(f"{self.prefix}Removing corrupt *.npy image file {fn} due to: {e}")
                    Path(fn).unlink(missing_ok=True)
                    im = imread(f, flags=self.cv2_flag)  # BGR
            else:  # read image
                im = imread(f, flags=self.cv2_flag)  # BGR
            if im is None:
                raise FileNotFoundError(f"Image Not Found {f}")

            h0, w0 = im.shape[:2]  # orig hw
            if rect_mode:  # resize long side to imgsz while maintaining aspect ratio
                r = self.imgsz / max(h0, w0)  # ratio
                if r != 1:  # if sizes are not equal
                    w, h = (min(math.ceil(w0 * r), self.imgsz), min(math.ceil(h0 * r), self.imgsz))
                    im = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
            elif not (h0 == w0 == self.imgsz):  # resize by stretching image to square imgsz
                im = cv2.resize(im, (self.imgsz, self.imgsz), interpolation=cv2.INTER_LINEAR)
            if im.ndim == 2:
                im = im[..., None]

            # Add to buffer if training with augmentations
            if self.augment:
                self.ims[i], self.im_hw0[i], self.im_hw[i] = im, (h0, w0), im.shape[:2]  # im, hw_original, hw_resized
                self.buffer.append(i)
                if 1 < len(self.buffer) >= self.max_buffer_length:  # prevent empty buffer
                    j = self.buffer.pop(0)
                    if self.cache != "ram":
                        self.ims[j], self.im_hw0[j], self.im_hw[j] = None, None, None

            return im, (h0, w0), im.shape[:2]

        return self.ims[i], self.im_hw0[i], self.im_hw[i]
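Note: in rect_mode the long side is scaled to imgsz and the short side keeps the aspect ratio, rounded up. A quick check of the arithmetic, with values chosen for illustration:

import math

imgsz, (h0, w0) = 640, (720, 1280)  # hypothetical 1280x720 source image
r = imgsz / max(h0, w0)  # 0.5
w, h = min(math.ceil(w0 * r), imgsz), min(math.ceil(h0 * r), imgsz)
assert (w, h) == (640, 360)  # long side pinned to imgsz, aspect ratio preserved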
    def cache_images(self) -> None:
        """Cache images to memory or disk for faster training."""
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        fcn, storage = (self.cache_images_to_disk, "Disk") if self.cache == "disk" else (self.load_image, "RAM")
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(fcn, range(self.ni))
            pbar = TQDM(enumerate(results), total=self.ni, disable=LOCAL_RANK > 0)
            for i, x in pbar:
                if self.cache == "disk":
                    b += self.npy_files[i].stat().st_size
                else:  # 'ram'
                    self.ims[i], self.im_hw0[i], self.im_hw[i] = x  # im, hw_orig, hw_resized = load_image(self, i)
                    b += self.ims[i].nbytes
                pbar.desc = f"{self.prefix}Caching images ({b / gb:.1f}GB {storage})"
            pbar.close()

    def cache_images_to_disk(self, i: int) -> None:
        """Save an image as an *.npy file for faster loading."""
        f = self.npy_files[i]
        if not f.exists():
            np.save(f.as_posix(), imread(self.im_files[i]), allow_pickle=False)

    def check_cache_disk(self, safety_margin: float = 0.5) -> bool:
        """
        Check if there's enough disk space for caching images.

        Args:
            safety_margin (float): Safety margin factor for disk space calculation.

        Returns:
            (bool): True if there's enough disk space, False otherwise.
        """
        import shutil

        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        n = min(self.ni, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im_file = random.choice(self.im_files)
            im = imread(im_file)
            if im is None:
                continue
            b += im.nbytes
            if not os.access(Path(im_file).parent, os.W_OK):
                self.cache = None
                LOGGER.warning(f"{self.prefix}Skipping caching images to disk, directory not writeable")
                return False
        disk_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset to disk
        total, used, free = shutil.disk_usage(Path(self.im_files[0]).parent)
        if disk_required > free:
            self.cache = None
            LOGGER.warning(
                f"{self.prefix}{disk_required / gb:.1f}GB disk space required, "
                f"with {int(safety_margin * 100)}% safety margin but only "
                f"{free / gb:.1f}/{total / gb:.1f}GB free, not caching images to disk"
            )
            return False
        return True

    def check_cache_ram(self, safety_margin: float = 0.5) -> bool:
        """
        Check if there's enough RAM for caching images.

        Args:
            safety_margin (float): Safety margin factor for RAM calculation.

        Returns:
            (bool): True if there's enough RAM, False otherwise.
        """
        b, gb = 0, 1 << 30  # bytes of cached images, bytes per gigabyte
        n = min(self.ni, 30)  # extrapolate from 30 random images
        for _ in range(n):
            im = imread(random.choice(self.im_files))  # sample image
            if im is None:
                continue
            ratio = self.imgsz / max(im.shape[0], im.shape[1])  # imgsz / max(h, w) ratio
            b += im.nbytes * ratio**2
        mem_required = b * self.ni / n * (1 + safety_margin)  # bytes required to cache dataset into RAM
        mem = __import__("psutil").virtual_memory()
        if mem_required > mem.available:
            self.cache = None
            LOGGER.warning(
                f"{self.prefix}{mem_required / gb:.1f}GB RAM required to cache images "
                f"with {int(safety_margin * 100)}% safety margin but only "
                f"{mem.available / gb:.1f}/{mem.total / gb:.1f}GB available, not caching images"
            )
            return False
        return True

    def set_rectangle(self) -> None:
        """Set the shape of bounding boxes for YOLO detections as rectangles."""
        bi = np.floor(np.arange(self.ni) / self.batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches

        s = np.array([x.pop("shape") for x in self.labels])  # hw
        ar = s[:, 0] / s[:, 1]  # aspect ratio
        irect = ar.argsort()
        self.im_files = [self.im_files[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * nb
        for i in range(nb):
            ari = ar[bi == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]

        self.batch_shapes = np.ceil(np.array(shapes) * self.imgsz / self.stride + self.pad).astype(int) * self.stride
        self.batch = bi  # batch index of image
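Note: `set_rectangle` sorts images by aspect ratio so each batch shares one near-optimal shape, then snaps that shape up to a stride multiple. A worked example of the final quantization step, with illustrative numbers:

import numpy as np

imgsz, stride, pad = 640, 32, 0.5
shapes = np.array([[0.75, 1.0]])  # one batch of landscape images, ar = h/w = 0.75
batch_shapes = np.ceil(shapes * imgsz / stride + pad).astype(int) * stride
assert (batch_shapes == [[512, 672]]).all()  # h snapped up to 512, w to 672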
    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return transformed label information for a given index."""
        return self.transforms(self.get_image_and_label(index))

    def get_image_and_label(self, index: int) -> dict[str, Any]:
        """
        Get and return label information from the dataset.

        Args:
            index (int): Index of the image to retrieve.

        Returns:
            (dict[str, Any]): Label dictionary with image and metadata.
        """
        label = deepcopy(self.labels[index])  # requires deepcopy() https://github.com/ultralytics/ultralytics/pull/1948
        label.pop("shape", None)  # shape is for rect, remove it
        label["img"], label["ori_shape"], label["resized_shape"] = self.load_image(index)
        label["ratio_pad"] = (
            label["resized_shape"][0] / label["ori_shape"][0],
            label["resized_shape"][1] / label["ori_shape"][1],
        )  # for evaluation
        if self.rect:
            label["rect_shape"] = self.batch_shapes[self.batch[index]]
        return self.update_labels_info(label)

    def __len__(self) -> int:
        """Return the length of the labels list for the dataset."""
        return len(self.labels)

    def update_labels_info(self, label: dict[str, Any]) -> dict[str, Any]:
        """Customize your label format here."""
        return label

    def build_transforms(self, hyp: dict[str, Any] | None = None):
        """
        Users can customize augmentations here.

        Examples:
            >>> if self.augment:
            ...     # Training transforms
            ...     return Compose([])
            >>> else:
            ...     # Val transforms
            ...     return Compose([])
        """
        raise NotImplementedError

    def get_labels(self) -> list[dict[str, Any]]:
        """
        Users can customize their own format here.

        Examples:
            Ensure output is a dictionary with the following keys:
            >>> dict(
            ...     im_file=im_file,
            ...     shape=shape,  # format: (height, width)
            ...     cls=cls,
            ...     bboxes=bboxes,  # xywh
            ...     segments=segments,  # xy
            ...     keypoints=keypoints,  # xy
            ...     normalized=True,  # or False
            ...     bbox_format="xyxy",  # or xywh, ltwh
            ... )
        """
        raise NotImplementedError
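Note: `update_labels_info`, `build_transforms`, and `get_labels` form the subclass contract. A minimal, hypothetical subclass sketch showing the shape of an implementation (this is not the actual YOLODataset):

import numpy as np

class MinimalDataset(BaseDataset):
    """Toy subclass illustrating the hooks BaseDataset leaves abstract."""

    def get_labels(self):
        # One dict per image, matching the keys documented in BaseDataset.get_labels.
        return [
            dict(
                im_file=f, shape=(640, 640), cls=np.zeros((0, 1)), bboxes=np.zeros((0, 4)),
                segments=[], keypoints=None, normalized=True, bbox_format="xywh",
            )
            for f in self.im_files
        ]

    def build_transforms(self, hyp=None):
        return lambda label: label  # identity transform

    # update_labels_info already defaults to a pass-through in BaseDataset.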
315  ultralytics/data/build.py  Normal file
@@ -0,0 +1,315 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import os
import random
from collections.abc import Iterator
from pathlib import Path
from typing import Any
from urllib.parse import urlsplit

import numpy as np
import torch
from PIL import Image
from torch.utils.data import dataloader, distributed

from ultralytics.cfg import IterableSimpleNamespace
from ultralytics.data.dataset import GroundingDataset, YOLODataset, YOLOMultiModalDataset
from ultralytics.data.loaders import (
    LOADERS,
    LoadImagesAndVideos,
    LoadPilAndNumpy,
    LoadScreenshots,
    LoadStreams,
    LoadTensor,
    SourceTypes,
    autocast_list,
)
from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS
from ultralytics.utils import RANK, colorstr
from ultralytics.utils.checks import check_file
from ultralytics.utils.torch_utils import TORCH_2_0


class InfiniteDataLoader(dataloader.DataLoader):
    """
    Dataloader that reuses workers for infinite iteration.

    This dataloader extends the PyTorch DataLoader to provide infinite recycling of workers, which improves efficiency
    for training loops that need to iterate through the dataset multiple times without recreating workers.

    Attributes:
        batch_sampler (_RepeatSampler): A sampler that repeats indefinitely.
        iterator (Iterator): The iterator from the parent DataLoader.

    Methods:
        __len__: Return the length of the batch sampler's sampler.
        __iter__: Create a sampler that repeats indefinitely.
        __del__: Ensure workers are properly terminated.
        reset: Reset the iterator, useful when modifying dataset settings during training.

    Examples:
        Create an infinite dataloader for training
        >>> dataset = YOLODataset(...)
        >>> dataloader = InfiniteDataLoader(dataset, batch_size=16, shuffle=True)
        >>> for batch in dataloader:  # Infinite iteration
        ...     train_step(batch)
    """

    def __init__(self, *args: Any, **kwargs: Any):
        """Initialize the InfiniteDataLoader with the same arguments as DataLoader."""
        if not TORCH_2_0:
            kwargs.pop("prefetch_factor", None)  # not supported by earlier versions
        super().__init__(*args, **kwargs)
        object.__setattr__(self, "batch_sampler", _RepeatSampler(self.batch_sampler))
        self.iterator = super().__iter__()

    def __len__(self) -> int:
        """Return the length of the batch sampler's sampler."""
        return len(self.batch_sampler.sampler)

    def __iter__(self) -> Iterator:
        """Create an iterator that yields indefinitely from the underlying iterator."""
        for _ in range(len(self)):
            yield next(self.iterator)

    def __del__(self):
        """Ensure that workers are properly terminated when the dataloader is deleted."""
        try:
            if not hasattr(self.iterator, "_workers"):
                return
            for w in self.iterator._workers:  # force terminate
                if w.is_alive():
                    w.terminate()
            self.iterator._shutdown_workers()  # cleanup
        except Exception:
            pass

    def reset(self):
        """Reset the iterator to allow modifications to the dataset during training."""
        self.iterator = self._get_iterator()


class _RepeatSampler:
    """
    Sampler that repeats forever for infinite iteration.

    This sampler wraps another sampler and yields its contents indefinitely, allowing for infinite iteration
    over a dataset without recreating the sampler.

    Attributes:
        sampler (Dataset.sampler): The sampler to repeat.
    """

    def __init__(self, sampler: Any):
        """Initialize the _RepeatSampler with a sampler to repeat indefinitely."""
        self.sampler = sampler

    def __iter__(self) -> Iterator:
        """Iterate over the sampler indefinitely, yielding its contents."""
        while True:
            yield from iter(self.sampler)


def seed_worker(worker_id: int):  # noqa
    """Set dataloader worker seed for reproducibility across worker processes."""
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
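Note: PyTorch only seeds its own RNG per worker; NumPy and `random` must be re-seeded by hand, which is exactly what `seed_worker` does. A standalone sketch of the wiring, the same pattern `build_dataloader` below uses:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(8))
g = torch.Generator()
g.manual_seed(0)  # fixed base seed, so the shuffle order repeats across runs
loader = DataLoader(ds, batch_size=2, shuffle=True, num_workers=2,
                    worker_init_fn=seed_worker, generator=g)  # seed_worker defined above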
def build_yolo_dataset(
    cfg: IterableSimpleNamespace,
    img_path: str,
    batch: int,
    data: dict[str, Any],
    mode: str = "train",
    rect: bool = False,
    stride: int = 32,
    multi_modal: bool = False,
):
    """Build and return a YOLO dataset based on configuration parameters."""
    dataset = YOLOMultiModalDataset if multi_modal else YOLODataset
    return dataset(
        img_path=img_path,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",  # augmentation
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        rect=cfg.rect or rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=stride,
        pad=0.0 if mode == "train" else 0.5,
        prefix=colorstr(f"{mode}: "),
        task=cfg.task,
        classes=cfg.classes,
        data=data,
        fraction=cfg.fraction if mode == "train" else 1.0,
    )


def build_grounding(
    cfg: IterableSimpleNamespace,
    img_path: str,
    json_file: str,
    batch: int,
    mode: str = "train",
    rect: bool = False,
    stride: int = 32,
    max_samples: int = 80,
):
    """Build and return a GroundingDataset based on configuration parameters."""
    return GroundingDataset(
        img_path=img_path,
        json_file=json_file,
        max_samples=max_samples,
        imgsz=cfg.imgsz,
        batch_size=batch,
        augment=mode == "train",  # augmentation
        hyp=cfg,  # TODO: probably add a get_hyps_from_cfg function
        rect=cfg.rect or rect,  # rectangular batches
        cache=cfg.cache or None,
        single_cls=cfg.single_cls or False,
        stride=stride,
        pad=0.0 if mode == "train" else 0.5,
        prefix=colorstr(f"{mode}: "),
        task=cfg.task,
        classes=cfg.classes,
        fraction=cfg.fraction if mode == "train" else 1.0,
    )


def build_dataloader(dataset, batch: int, workers: int, shuffle: bool = True, rank: int = -1, drop_last: bool = False):
    """
    Create and return an InfiniteDataLoader or DataLoader for training or validation.

    Args:
        dataset (Dataset): Dataset to load data from.
        batch (int): Batch size for the dataloader.
        workers (int): Number of worker threads for loading data.
        shuffle (bool, optional): Whether to shuffle the dataset.
        rank (int, optional): Process rank in distributed training. -1 for single-GPU training.
        drop_last (bool, optional): Whether to drop the last incomplete batch.

    Returns:
        (InfiniteDataLoader): A dataloader that can be used for training or validation.

    Examples:
        Create a dataloader for training
        >>> dataset = YOLODataset(...)
        >>> dataloader = build_dataloader(dataset, batch=16, workers=4, shuffle=True)
    """
    batch = min(batch, len(dataset))
    nd = torch.cuda.device_count()  # number of CUDA devices
    nw = min(os.cpu_count() // max(nd, 1), workers)  # number of workers
    sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
    generator = torch.Generator()
    generator.manual_seed(6148914691236517205 + RANK)
    return InfiniteDataLoader(
        dataset=dataset,
        batch_size=batch,
        shuffle=shuffle and sampler is None,
        num_workers=nw,
        sampler=sampler,
        prefetch_factor=4 if nw > 0 else None,  # increase over default 2
        pin_memory=nd > 0,
        collate_fn=getattr(dataset, "collate_fn", None),
        worker_init_fn=seed_worker,
        generator=generator,
        drop_last=drop_last and len(dataset) % batch != 0,
    )


def check_source(source):
    """
    Check the type of input source and return corresponding flag values.

    Args:
        source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The input source to check.

    Returns:
        source (str | int | Path | list | tuple | np.ndarray | PIL.Image | torch.Tensor): The processed source.
        webcam (bool): Whether the source is a webcam.
        screenshot (bool): Whether the source is a screenshot.
        from_img (bool): Whether the source is an image or list of images.
        in_memory (bool): Whether the source is an in-memory object.
        tensor (bool): Whether the source is a torch.Tensor.

    Examples:
        Check a file path source
        >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source("image.jpg")

        Check a webcam source
        >>> source, webcam, screenshot, from_img, in_memory, tensor = check_source(0)
    """
    webcam, screenshot, from_img, in_memory, tensor = False, False, False, False, False
    if isinstance(source, (str, int, Path)):  # int for local usb camera
        source = str(source)
        source_lower = source.lower()
        is_url = source_lower.startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://"))
        is_file = (urlsplit(source_lower).path if is_url else source_lower).rpartition(".")[-1] in (
            IMG_FORMATS | VID_FORMATS
        )
        webcam = source.isnumeric() or source.endswith(".streams") or (is_url and not is_file)
        screenshot = source_lower == "screen"
        if is_url and is_file:
            source = check_file(source)  # download
    elif isinstance(source, LOADERS):
        in_memory = True
    elif isinstance(source, (list, tuple)):
        source = autocast_list(source)  # convert all list elements to PIL or np arrays
        from_img = True
    elif isinstance(source, (Image.Image, np.ndarray)):
        from_img = True
    elif isinstance(source, torch.Tensor):
        tensor = True
    else:
        raise TypeError("Unsupported image type. For supported types see https://docs.ultralytics.com/modes/predict")

    return source, webcam, screenshot, from_img, in_memory, tensor


def load_inference_source(source=None, batch: int = 1, vid_stride: int = 1, buffer: bool = False, channels: int = 3):
    """
    Load an inference source for object detection and apply necessary transformations.

    Args:
        source (str | Path | torch.Tensor | PIL.Image | np.ndarray, optional): The input source for inference.
        batch (int, optional): Batch size for dataloaders.
        vid_stride (int, optional): The frame interval for video sources.
        buffer (bool, optional): Whether stream frames will be buffered.
        channels (int, optional): The number of input channels for the model.

    Returns:
        (Dataset): A dataset object for the specified input source with attached source_type attribute.

    Examples:
        Load an image source for inference
        >>> dataset = load_inference_source("image.jpg", batch=1)

        Load a video stream source
        >>> dataset = load_inference_source("rtsp://example.com/stream", vid_stride=2)
    """
    source, stream, screenshot, from_img, in_memory, tensor = check_source(source)
    source_type = source.source_type if in_memory else SourceTypes(stream, screenshot, from_img, tensor)

    # Dataloader
    if tensor:
        dataset = LoadTensor(source)
    elif in_memory:
        dataset = source
    elif stream:
        dataset = LoadStreams(source, vid_stride=vid_stride, buffer=buffer, channels=channels)
    elif screenshot:
        dataset = LoadScreenshots(source, channels=channels)
    elif from_img:
        dataset = LoadPilAndNumpy(source, channels=channels)
    else:
        dataset = LoadImagesAndVideos(source, batch=batch, vid_stride=vid_stride, channels=channels)

    # Attach source types to the dataset
    setattr(dataset, "source_type", source_type)

    return dataset
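Note: `load_inference_source` is the single entry point a predictor would use; `check_source` picks the loader and `source_type` records the classification flags. A hedged usage sketch — the path is a placeholder, and the 3-tuple yield shape is an assumption about what `LoadImagesAndVideos` produces here:

from ultralytics.data.build import load_inference_source

dataset = load_inference_source("bus.jpg", batch=1)  # placeholder path
print(dataset.source_type)  # SourceTypes flags attached above
for paths, imgs, info in dataset:  # assumed (paths, images, metadata) batches
    print(paths, [im.shape for im in imgs])
    break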
867  ultralytics/data/converter.py  Normal file
@@ -0,0 +1,867 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import asyncio
import json
import random
import shutil
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

import cv2
import numpy as np
from PIL import Image

from ultralytics.utils import DATASETS_DIR, LOGGER, NUM_THREADS, TQDM, YAML
from ultralytics.utils.checks import check_file, check_requirements
from ultralytics.utils.downloads import download, zip_directory
from ultralytics.utils.files import increment_path


def coco91_to_coco80_class() -> list[int]:
    """
    Convert 91-index COCO class IDs to 80-index COCO class IDs.

    Returns:
        (list[int]): A list of 91 entries where the index is the 91-index class ID minus 1 and the value is the
            corresponding 80-index class ID, or None for categories absent from the 80-class set.
    """
    # fmt: off
    return [
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, None, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, None, 24, 25,
        None, None, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, None, 40, 41, 42, 43, 44, 45, 46, 47,
        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, None, 60, None, None, 61, None, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, None, 73, 74, 75, 76, 77, 78, 79, None,
    ]
    # fmt: on


def coco80_to_coco91_class() -> list[int]:
    r"""
    Convert 80-index (val2014) to 91-index (paper).

    Returns:
        (list[int]): A list of 80 class IDs where each value is the corresponding 91-index class ID.

    References:
        https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/

    Examples:
        >>> import numpy as np
        >>> a = np.loadtxt("data/coco.names", dtype="str", delimiter="\n")
        >>> b = np.loadtxt("data/coco_paper.names", dtype="str", delimiter="\n")

        Convert the darknet to COCO format
        >>> x1 = [list(a[i] == b).index(True) + 1 for i in range(80)]

        Convert the COCO to darknet format
        >>> x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)]
    """
    # fmt: off
    return [
        1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
        62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90,
    ]
    # fmt: on
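Note: the two tables are inverses over the 80 retained categories; 11 of the 91 paper IDs were never released with instance annotations and map to None. A quick round-trip check, consistent with how `convert_coco` below indexes the table:

coco80 = coco91_to_coco80_class()
coco91 = coco80_to_coco91_class()

category_id = 13  # COCO 91-index "stop sign"
cls = coco80[category_id - 1]  # -> 11, the 80-index training class
assert coco91[cls] == category_id  # round-trip back to the 91-index ID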
def convert_coco(
    labels_dir: str = "../coco/annotations/",
    save_dir: str = "coco_converted/",
    use_segments: bool = False,
    use_keypoints: bool = False,
    cls91to80: bool = True,
    lvis: bool = False,
):
    """
    Convert COCO dataset annotations to a YOLO annotation format suitable for training YOLO models.

    Args:
        labels_dir (str, optional): Path to directory containing COCO dataset annotation files.
        save_dir (str, optional): Path to directory to save results to.
        use_segments (bool, optional): Whether to include segmentation masks in the output.
        use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
        cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
        lvis (bool, optional): Whether to convert the data in the LVIS dataset format.

    Examples:
        >>> from ultralytics.data.converter import convert_coco

        Convert COCO annotations to YOLO format
        >>> convert_coco("coco/annotations/", use_segments=True, use_keypoints=False, cls91to80=False)

        Convert LVIS annotations to YOLO format
        >>> convert_coco("lvis/annotations/", use_segments=True, use_keypoints=False, cls91to80=False, lvis=True)
    """
    # Create dataset directory
    save_dir = increment_path(save_dir)  # increment if save directory already exists
    for p in save_dir / "labels", save_dir / "images":
        p.mkdir(parents=True, exist_ok=True)  # make dir

    # Convert classes
    coco80 = coco91_to_coco80_class()

    # Import json
    for json_file in sorted(Path(labels_dir).resolve().glob("*.json")):
        lname = "" if lvis else json_file.stem.replace("instances_", "")
        fn = Path(save_dir) / "labels" / lname  # folder name
        fn.mkdir(parents=True, exist_ok=True)
        if lvis:
            # NOTE: create folders for both train and val in advance,
            # since the LVIS val set contains images from COCO 2017 train in addition to the COCO 2017 val split.
            (fn / "train2017").mkdir(parents=True, exist_ok=True)
            (fn / "val2017").mkdir(parents=True, exist_ok=True)
        with open(json_file, encoding="utf-8") as f:
            data = json.load(f)

        # Create image dict
        images = {f"{x['id']:d}": x for x in data["images"]}
        # Create image-annotations dict
        annotations = defaultdict(list)
        for ann in data["annotations"]:
            annotations[ann["image_id"]].append(ann)

        image_txt = []
        # Write labels file
        for img_id, anns in TQDM(annotations.items(), desc=f"Annotations {json_file}"):
            img = images[f"{img_id:d}"]
            h, w = img["height"], img["width"]
            f = str(Path(img["coco_url"]).relative_to("http://images.cocodataset.org")) if lvis else img["file_name"]
            if lvis:
                image_txt.append(str(Path("./images") / f))

            bboxes = []
            segments = []
            keypoints = []
            for ann in anns:
                if ann.get("iscrowd", False):
                    continue
                # The COCO box format is [top left x, top left y, width, height]
                box = np.array(ann["bbox"], dtype=np.float64)
                box[:2] += box[2:] / 2  # xy top-left corner to center
                box[[0, 2]] /= w  # normalize x
                box[[1, 3]] /= h  # normalize y
                if box[2] <= 0 or box[3] <= 0:  # if w <= 0 or h <= 0
                    continue

                cls = coco80[ann["category_id"] - 1] if cls91to80 else ann["category_id"] - 1  # class
                box = [cls] + box.tolist()
                if box not in bboxes:
                    bboxes.append(box)
                    if use_segments and ann.get("segmentation") is not None:
                        if len(ann["segmentation"]) == 0:
                            segments.append([])
                            continue
                        elif len(ann["segmentation"]) > 1:
                            s = merge_multi_segment(ann["segmentation"])
                            s = (np.concatenate(s, axis=0) / np.array([w, h])).reshape(-1).tolist()
                        else:
                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
                            s = (np.array(s).reshape(-1, 2) / np.array([w, h])).reshape(-1).tolist()
                        s = [cls] + s
                        segments.append(s)
                    if use_keypoints and ann.get("keypoints") is not None:
                        keypoints.append(
                            box + (np.array(ann["keypoints"]).reshape(-1, 3) / np.array([w, h, 1])).reshape(-1).tolist()
                        )

            # Write
            with open((fn / f).with_suffix(".txt"), "a", encoding="utf-8") as file:
                for i in range(len(bboxes)):
                    if use_keypoints:
                        line = (*(keypoints[i]),)  # cls, box, keypoints
                    else:
                        line = (
                            *(segments[i] if use_segments and len(segments[i]) > 0 else bboxes[i]),
                        )  # cls, box or segments
                    file.write(("%g " * len(line)).rstrip() % line + "\n")

        if lvis:
            filename = Path(save_dir) / json_file.name.replace("lvis_v1_", "").replace(".json", ".txt")
            with open(filename, "a", encoding="utf-8") as f:
                f.writelines(f"{line}\n" for line in image_txt)

    LOGGER.info(f"{'LVIS' if lvis else 'COCO'} data converted successfully.\nResults saved to {save_dir.resolve()}")
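Note: the `("%g " * len(line)).rstrip() % line` write trims trailing zeros so normalized coordinates stay compact. A one-line check of the label format this produces, with illustrative values:

line = (0, 0.5, 0.5, 0.25, 0.125)  # cls, cx, cy, w, h (normalized xywh)
assert ("%g " * len(line)).rstrip() % line == "0 0.5 0.5 0.25 0.125"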
def convert_segment_masks_to_yolo_seg(masks_dir: str, output_dir: str, classes: int):
    """
    Convert a dataset of segmentation mask images to the YOLO segmentation format.

    This function takes the directory containing the binary format mask images and converts them into YOLO segmentation
    format. The converted masks are saved in the specified output directory.

    Args:
        masks_dir (str): The path to the directory where all mask images (png, jpg) are stored.
        output_dir (str): The path to the directory where the converted YOLO segmentation masks will be stored.
        classes (int): Total classes in the dataset, i.e. for COCO classes=80.

    Examples:
        >>> from ultralytics.data.converter import convert_segment_masks_to_yolo_seg

        The classes here is the total classes in the dataset; for the COCO dataset we have 80 classes
        >>> convert_segment_masks_to_yolo_seg("path/to/masks_directory", "path/to/output/directory", classes=80)

    Notes:
        The expected directory structure for the masks is:

        - masks
            ├─ mask_image_01.png or mask_image_01.jpg
            ├─ mask_image_02.png or mask_image_02.jpg
            ├─ mask_image_03.png or mask_image_03.jpg
            └─ mask_image_04.png or mask_image_04.jpg

        After execution, the labels will be organized in the following structure:

        - output_dir
            ├─ mask_yolo_01.txt
            ├─ mask_yolo_02.txt
            ├─ mask_yolo_03.txt
            └─ mask_yolo_04.txt
    """
    pixel_to_class_mapping = {i + 1: i for i in range(classes)}
    for mask_path in Path(masks_dir).iterdir():
        if mask_path.suffix in {".png", ".jpg"}:
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)  # Read the mask image in grayscale
            img_height, img_width = mask.shape  # Get image dimensions
            LOGGER.info(f"Processing {mask_path} imgsz = {img_height} x {img_width}")

            unique_values = np.unique(mask)  # Get unique pixel values representing different classes
            yolo_format_data = []

            for value in unique_values:
                if value == 0:
                    continue  # Skip background
                class_index = pixel_to_class_mapping.get(value, -1)
                if class_index == -1:
                    LOGGER.warning(f"Unknown class for pixel value {value} in file {mask_path}, skipping.")
                    continue

                # Create a binary mask for the current class and find contours
                contours, _ = cv2.findContours(
                    (mask == value).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
                )  # Find contours

                for contour in contours:
                    if len(contour) >= 3:  # YOLO requires at least 3 points for a valid segmentation
                        contour = contour.squeeze()  # Remove single-dimensional entries
                        yolo_format = [class_index]
                        for point in contour:
                            # Normalize the coordinates
                            yolo_format.append(round(point[0] / img_width, 6))  # Rounding to 6 decimal places
                            yolo_format.append(round(point[1] / img_height, 6))
                        yolo_format_data.append(yolo_format)
            # Save Ultralytics YOLO format data to file
            output_path = Path(output_dir) / f"{mask_path.stem}.txt"
            with open(output_path, "w", encoding="utf-8") as file:
                for item in yolo_format_data:
                    line = " ".join(map(str, item))
                    file.write(line + "\n")
            LOGGER.info(f"Processed and stored at {output_path} imgsz = {img_height} x {img_width}")


def convert_dota_to_yolo_obb(dota_root_path: str):
    """
    Convert DOTA dataset annotations to YOLO OBB (Oriented Bounding Box) format.

    The function processes images in the 'train' and 'val' folders of the DOTA dataset. For each image, it reads the
    associated label from the original labels directory and writes new labels in YOLO OBB format to a new directory.

    Args:
        dota_root_path (str): The root directory path of the DOTA dataset.

    Examples:
        >>> from ultralytics.data.converter import convert_dota_to_yolo_obb
        >>> convert_dota_to_yolo_obb("path/to/DOTA")

    Notes:
        The directory structure assumed for the DOTA dataset:

        - DOTA
            ├─ images
            │   ├─ train
            │   └─ val
            └─ labels
                ├─ train_original
                └─ val_original

        After execution, the function will organize the labels into:

        - DOTA
            └─ labels
                ├─ train
                └─ val
    """
    dota_root_path = Path(dota_root_path)

    # Class names to indices mapping
    class_mapping = {
        "plane": 0,
        "ship": 1,
        "storage-tank": 2,
        "baseball-diamond": 3,
        "tennis-court": 4,
        "basketball-court": 5,
        "ground-track-field": 6,
        "harbor": 7,
        "bridge": 8,
        "large-vehicle": 9,
        "small-vehicle": 10,
        "helicopter": 11,
        "roundabout": 12,
        "soccer-ball-field": 13,
        "swimming-pool": 14,
        "container-crane": 15,
        "airport": 16,
        "helipad": 17,
    }

    def convert_label(image_name: str, image_width: int, image_height: int, orig_label_dir: Path, save_dir: Path):
        """Convert a single image's DOTA annotation to YOLO OBB format and save it to a specified directory."""
        orig_label_path = orig_label_dir / f"{image_name}.txt"
        save_path = save_dir / f"{image_name}.txt"

        with orig_label_path.open("r") as f, save_path.open("w") as g:
            lines = f.readlines()
            for line in lines:
                parts = line.strip().split()
                if len(parts) < 9:
                    continue
                class_name = parts[8]
                class_idx = class_mapping[class_name]
                coords = [float(p) for p in parts[:8]]
                normalized_coords = [
                    coords[i] / image_width if i % 2 == 0 else coords[i] / image_height for i in range(8)
                ]
                formatted_coords = [f"{coord:.6g}" for coord in normalized_coords]
                g.write(f"{class_idx} {' '.join(formatted_coords)}\n")

    for phase in {"train", "val"}:
        image_dir = dota_root_path / "images" / phase
        orig_label_dir = dota_root_path / "labels" / f"{phase}_original"
        save_dir = dota_root_path / "labels" / phase

        save_dir.mkdir(parents=True, exist_ok=True)

        image_paths = list(image_dir.iterdir())
        for image_path in TQDM(image_paths, desc=f"Processing {phase} images"):
            if image_path.suffix != ".png":
                continue
            image_name_without_ext = image_path.stem
            img = cv2.imread(str(image_path))
            h, w = img.shape[:2]
            convert_label(image_name_without_ext, w, h, orig_label_dir, save_dir)


def min_index(arr1: np.ndarray, arr2: np.ndarray):
    """
    Find a pair of indexes with the shortest distance between two arrays of 2D points.

    Args:
        arr1 (np.ndarray): A NumPy array of shape (N, 2) representing N 2D points.
        arr2 (np.ndarray): A NumPy array of shape (M, 2) representing M 2D points.

    Returns:
        idx1 (int): Index of the point in arr1 with the shortest distance.
        idx2 (int): Index of the point in arr2 with the shortest distance.
    """
    dis = ((arr1[:, None, :] - arr2[None, :, :]) ** 2).sum(-1)
    return np.unravel_index(np.argmin(dis, axis=None), dis.shape)


def merge_multi_segment(segments: list[list]):
    """
    Merge multiple segments into one list by connecting the coordinates with the minimum distance between each segment.

    This function connects these coordinates with a thin line to merge all segments into one.

    Args:
        segments (list[list]): Original segmentations in COCO's JSON file.
            Each element is a list of coordinates, like [segmentation1, segmentation2,...].

    Returns:
        s (list[np.ndarray]): A list of connected segments represented as NumPy arrays.
    """
    s = []
    segments = [np.array(i).reshape(-1, 2) for i in segments]
    idx_list = [[] for _ in range(len(segments))]

    # Record the indexes with min distance between each segment
    for i in range(1, len(segments)):
        idx1, idx2 = min_index(segments[i - 1], segments[i])
        idx_list[i - 1].append(idx1)
        idx_list[i].append(idx2)

    # Use two rounds to connect all the segments
    for k in range(2):
        # Forward connection
        if k == 0:
            for i, idx in enumerate(idx_list):
                # Middle segments have two indexes, reverse the index of middle segments
                if len(idx) == 2 and idx[0] > idx[1]:
                    idx = idx[::-1]
                    segments[i] = segments[i][::-1, :]

                segments[i] = np.roll(segments[i], -idx[0], axis=0)
                segments[i] = np.concatenate([segments[i], segments[i][:1]])
                # Deal with the first segment and the last one
                if i in {0, len(idx_list) - 1}:
                    s.append(segments[i])
                else:
                    idx = [0, idx[1] - idx[0]]
                    s.append(segments[i][idx[0] : idx[1] + 1])

        else:
            for i in range(len(idx_list) - 1, -1, -1):
                if i not in {0, len(idx_list) - 1}:
                    idx = idx_list[i]
                    nidx = abs(idx[1] - idx[0])
                    s.append(segments[i][nidx:])
    return s
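Note: `merge_multi_segment` stitches disjoint COCO polygons into one traversable contour by rolling each part so the closest points line up. A tiny demo on two separated squares, with toy coordinates:

import numpy as np

squares = [
    [0, 0, 1, 0, 1, 1, 0, 1],  # unit square at the origin (flat xy list, COCO-style)
    [3, 0, 4, 0, 4, 1, 3, 1],  # second square to the right
]
merged = np.concatenate(merge_multi_segment(squares), axis=0)
print(merged.shape)  # single (N, 2) polyline visiting both squares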
def yolo_bbox2segment(im_dir: str | Path, save_dir: str | Path | None = None, sam_model: str = "sam_b.pt", device=None):
|
||||
"""
|
||||
Convert existing object detection dataset (bounding boxes) to segmentation dataset or oriented bounding box (OBB) in
|
||||
YOLO format. Generate segmentation data using SAM auto-annotator as needed.
|
||||
|
||||
Args:
|
||||
im_dir (str | Path): Path to image directory to convert.
|
||||
save_dir (str | Path, optional): Path to save the generated labels, labels will be saved
|
||||
into `labels-segment` in the same directory level of `im_dir` if save_dir is None.
|
||||
sam_model (str): Segmentation model to use for intermediate segmentation data.
|
||||
device (int | str, optional): The specific device to run SAM models.
|
||||
|
||||
Notes:
|
||||
The input directory structure assumed for dataset:
|
||||
|
||||
- im_dir
|
||||
├─ 001.jpg
|
||||
├─ ...
|
||||
└─ NNN.jpg
|
||||
- labels
|
||||
├─ 001.txt
|
||||
├─ ...
|
||||
└─ NNN.txt
|
||||
"""
|
||||
from ultralytics import SAM
|
||||
from ultralytics.data import YOLODataset
|
||||
from ultralytics.utils.ops import xywh2xyxy
|
||||
|
||||
# NOTE: add placeholder to pass class index check
|
||||
dataset = YOLODataset(im_dir, data=dict(names=list(range(1000)), channels=3))
|
||||
if len(dataset.labels[0]["segments"]) > 0: # if it's segment data
|
||||
LOGGER.info("Segmentation labels detected, no need to generate new ones!")
|
||||
return
|
||||
|
||||
LOGGER.info("Detection labels detected, generating segment labels by SAM model!")
|
||||
sam_model = SAM(sam_model)
|
||||
for label in TQDM(dataset.labels, total=len(dataset.labels), desc="Generating segment labels"):
|
||||
h, w = label["shape"]
|
||||
boxes = label["bboxes"]
|
||||
if len(boxes) == 0: # skip empty labels
|
||||
continue
|
||||
boxes[:, [0, 2]] *= w
|
||||
boxes[:, [1, 3]] *= h
|
||||
im = cv2.imread(label["im_file"])
|
||||
sam_results = sam_model(im, bboxes=xywh2xyxy(boxes), verbose=False, save=False, device=device)
|
||||
label["segments"] = sam_results[0].masks.xyn
|
||||
|
||||
save_dir = Path(save_dir) if save_dir else Path(im_dir).parent / "labels-segment"
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for label in dataset.labels:
|
||||
texts = []
|
||||
lb_name = Path(label["im_file"]).with_suffix(".txt").name
|
||||
txt_file = save_dir / lb_name
|
||||
cls = label["cls"]
|
||||
for i, s in enumerate(label["segments"]):
|
||||
if len(s) == 0:
|
||||
continue
|
||||
line = (int(cls[i]), *s.reshape(-1))
|
||||
texts.append(("%g " * len(line)).rstrip() % line)
|
||||
with open(txt_file, "a", encoding="utf-8") as f:
|
||||
f.writelines(text + "\n" for text in texts)
|
||||
LOGGER.info(f"Generated segment labels saved in {save_dir}")


def create_synthetic_coco_dataset():
    """
    Create a synthetic COCO dataset with random images based on filenames from label lists.

    This function downloads COCO labels, reads image filenames from label list files, creates synthetic images for
    train2017 and val2017 subsets, and organizes them in the COCO dataset structure. It uses multithreading to
    generate images efficiently.

    Examples:
        >>> from ultralytics.data.converter import create_synthetic_coco_dataset
        >>> create_synthetic_coco_dataset()

    Notes:
        - Requires an internet connection to download label files.
        - Generates random RGB images of varying sizes (480x480 to 640x640 pixels).
        - The existing test2017 directory is removed as it is not needed.
        - Reads image filenames from the train2017.txt and val2017.txt files.
    """

    def create_synthetic_image(image_file: Path):
        """Generate a synthetic image with random size and color for dataset augmentation or testing purposes."""
        if not image_file.exists():
            size = (random.randint(480, 640), random.randint(480, 640))
            Image.new(
                "RGB",
                size=size,
                color=(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)),
            ).save(image_file)

    # Download labels
    dir = DATASETS_DIR / "coco"
    url = "https://github.com/ultralytics/assets/releases/download/v0.0.0/"
    label_zip = "coco2017labels-segments.zip"
    download([url + label_zip], dir=dir.parent)

    # Create synthetic images
    shutil.rmtree(dir / "labels" / "test2017", ignore_errors=True)  # remove test2017 directory as not needed
    with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
        for subset in {"train2017", "val2017"}:
            subset_dir = dir / "images" / subset
            subset_dir.mkdir(parents=True, exist_ok=True)

            # Read image filenames from label list file
            label_list_file = dir / f"{subset}.txt"
            if label_list_file.exists():
                with open(label_list_file, encoding="utf-8") as f:
                    image_files = [dir / line.strip() for line in f]

                # Submit all tasks
                futures = [executor.submit(create_synthetic_image, image_file) for image_file in image_files]
                for _ in TQDM(as_completed(futures), total=len(futures), desc=f"Generating images for {subset}"):
                    pass  # the actual work is done in the background
            else:
                LOGGER.warning(f"Labels file {label_list_file} does not exist. Skipping image creation for {subset}.")

    LOGGER.info("Synthetic COCO dataset created successfully.")


def convert_to_multispectral(path: str | Path, n_channels: int = 10, replace: bool = False, zip: bool = False):
    """
    Convert RGB images to multispectral images by interpolating across wavelength bands.

    This function takes RGB images and interpolates them to create multispectral images with a specified number
    of channels. It can process either a single image or a directory of images.

    Args:
        path (str | Path): Path to an image file or directory containing images to convert.
        n_channels (int): Number of spectral channels to generate in the output image.
        replace (bool): Whether to replace the original image file with the converted one.
        zip (bool): Whether to zip the converted images into a zip file.

    Examples:
        Convert a single image
        >>> convert_to_multispectral("path/to/image.jpg", n_channels=10)

        Convert a dataset
        >>> convert_to_multispectral("coco8", n_channels=10)
    """
    from scipy.interpolate import interp1d

    from ultralytics.data.utils import IMG_FORMATS

    path = Path(path)
    if path.is_dir():
        # Process directory
        im_files = sum((list(path.rglob(f"*.{ext}")) for ext in (IMG_FORMATS - {"tif", "tiff"})), [])
        for im_path in im_files:
            try:
                convert_to_multispectral(im_path, n_channels)
                if replace:
                    im_path.unlink()
            except Exception as e:
                LOGGER.info(f"Error converting {im_path}: {e}")

        if zip:
            zip_directory(path)
    else:
        # Process a single image
        output_path = path.with_suffix(".tiff")
        img = cv2.cvtColor(cv2.imread(str(path)), cv2.COLOR_BGR2RGB)

        # Interpolate all pixels at once
        rgb_wavelengths = np.array([650, 510, 475])  # R, G, B wavelengths (nm)
        target_wavelengths = np.linspace(450, 700, n_channels)
        f = interp1d(rgb_wavelengths.T, img, kind="linear", bounds_error=False, fill_value="extrapolate")
        multispectral = f(target_wavelengths)
        cv2.imwritemulti(str(output_path), np.clip(multispectral, 0, 255).astype(np.uint8).transpose(2, 0, 1))
        LOGGER.info(f"Converted {output_path}")


async def convert_ndjson_to_yolo(ndjson_path: str | Path, output_path: str | Path | None = None) -> Path:
    """
    Convert NDJSON dataset format to the Ultralytics YOLO11 dataset structure.

    This function converts datasets stored in NDJSON (Newline Delimited JSON) format to the standard YOLO
    format with separate directories for images and labels. It supports parallel processing for efficient
    conversion of large datasets and can download images from URLs if they don't exist locally.

    The NDJSON format consists of:
    - First line: Dataset metadata with class names and configuration
    - Subsequent lines: Individual image records with annotations and optional URLs

    Args:
        ndjson_path (str | Path): Path to the input NDJSON file containing dataset information.
        output_path (str | Path, optional): Directory where the converted YOLO dataset will be saved.
            If None, the default datasets directory is used.

    Returns:
        (Path): Path to the generated data.yaml file that can be used for YOLO training.

    Examples:
        Convert a local NDJSON file:
        >>> yaml_path = asyncio.run(convert_ndjson_to_yolo("dataset.ndjson"))
        >>> print(f"Dataset converted to: {yaml_path}")

        Convert with a custom output directory:
        >>> yaml_path = asyncio.run(convert_ndjson_to_yolo("dataset.ndjson", output_path="./converted_datasets"))

        Use with YOLO training:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")
        >>> model.train(data="https://github.com/ultralytics/assets/releases/download/v0.0.0/coco8-ndjson.ndjson")
    """
    check_requirements("aiohttp")
    import aiohttp

    ndjson_path = Path(check_file(ndjson_path))
    output_path = Path(output_path or DATASETS_DIR)
    with open(ndjson_path) as f:
        lines = [json.loads(line.strip()) for line in f if line.strip()]

    dataset_record, image_records = lines[0], lines[1:]
    dataset_dir = output_path / ndjson_path.stem
    splits = {record["split"] for record in image_records}

    # Create directories and prepare YAML structure
    dataset_dir.mkdir(parents=True, exist_ok=True)
    data_yaml = dict(dataset_record)
    data_yaml["names"] = {int(k): v for k, v in dataset_record.get("class_names", {}).items()}
    data_yaml.pop("class_names")

    for split in sorted(splits):
        (dataset_dir / "images" / split).mkdir(parents=True, exist_ok=True)
        (dataset_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
        data_yaml[split] = f"images/{split}"

    async def process_record(session, semaphore, record):
        """Process a single image record with an async session."""
        async with semaphore:
            split, original_name = record["split"], record["file"]
            label_path = dataset_dir / "labels" / split / f"{Path(original_name).stem}.txt"
            image_path = dataset_dir / "images" / split / original_name

            annotations = record.get("annotations", {})
            lines_to_write = []
            for key in annotations.keys():  # use the first annotation task found
                lines_to_write = [" ".join(map(str, item)) for item in annotations[key]]
                break
            if "classification" in annotations:
                lines_to_write = [str(cls) for cls in annotations["classification"]]

            label_path.write_text("\n".join(lines_to_write) + "\n" if lines_to_write else "")

            if http_url := record.get("url"):
                if not image_path.exists():
                    try:
                        async with session.get(http_url, timeout=aiohttp.ClientTimeout(total=30)) as response:
                            response.raise_for_status()
                            with open(image_path, "wb") as f:
                                async for chunk in response.content.iter_chunked(8192):
                                    f.write(chunk)
                        return True
                    except Exception as e:
                        LOGGER.warning(f"Failed to download {http_url}: {e}")
                        return False
            return True

    # Process all images with async downloads
    semaphore = asyncio.Semaphore(64)
    async with aiohttp.ClientSession() as session:
        pbar = TQDM(
            total=len(image_records),
            desc=f"Converting {ndjson_path.name} → {dataset_dir} ({len(image_records)} images)",
        )

        async def tracked_process(record):
            result = await process_record(session, semaphore, record)
            pbar.update(1)
            return result

        await asyncio.gather(*[tracked_process(record) for record in image_records])
        pbar.close()

    # Write data.yaml
    yaml_path = dataset_dir / "data.yaml"
    YAML.save(yaml_path, data_yaml)

    return yaml_path
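
# Example (sketch): the function is a coroutine, so synchronous callers must drive it with
# asyncio; 'dataset.ndjson' is a hypothetical input file.
# >>> import asyncio
# >>> yaml_path = asyncio.run(convert_ndjson_to_yolo("dataset.ndjson"))
# >>> yaml_path.name
# 'data.yaml'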
862
ultralytics/data/dataset.py
Normal file
@@ -0,0 +1,862 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import json
from collections import defaultdict
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import ConcatDataset

from ultralytics.utils import LOCAL_RANK, LOGGER, NUM_THREADS, TQDM, colorstr
from ultralytics.utils.instance import Instances
from ultralytics.utils.ops import resample_segments, segments2boxes
from ultralytics.utils.torch_utils import TORCHVISION_0_18

from .augment import (
    Compose,
    Format,
    LetterBox,
    RandomLoadText,
    classify_augmentations,
    classify_transforms,
    v8_transforms,
)
from .base import BaseDataset
from .converter import merge_multi_segment
from .utils import (
    HELP_URL,
    check_file_speeds,
    get_hash,
    img2label_paths,
    load_dataset_cache_file,
    save_dataset_cache_file,
    verify_image,
    verify_image_label,
)

# Ultralytics dataset *.cache version, >= 1.0.0 for Ultralytics YOLO models
DATASET_CACHE_VERSION = "1.0.3"

class YOLODataset(BaseDataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format.

    This class supports loading data for object detection, segmentation, pose estimation, and oriented bounding box
    (OBB) tasks using the YOLO format.

    Attributes:
        use_segments (bool): Indicates if segmentation masks should be used.
        use_keypoints (bool): Indicates if keypoints should be used for pose estimation.
        use_obb (bool): Indicates if oriented bounding boxes should be used.
        data (dict): Dataset configuration dictionary.

    Methods:
        cache_labels: Cache dataset labels, check images and read shapes.
        get_labels: Return a list of label dictionaries for YOLO training.
        build_transforms: Build and append transforms to the list.
        close_mosaic: Set mosaic, copy_paste and mixup options to 0.0 and build transformations.
        update_labels_info: Update label format for different tasks.
        collate_fn: Collate data samples into batches.

    Examples:
        >>> dataset = YOLODataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> dataset.get_labels()
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """
        Initialize the YOLODataset.

        Args:
            data (dict, optional): Dataset configuration dictionary.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        self.use_segments = task == "segment"
        self.use_keypoints = task == "pose"
        self.use_obb = task == "obb"
        self.data = data
        assert not (self.use_segments and self.use_keypoints), "Can not use both segments and keypoints."
        super().__init__(*args, channels=self.data.get("channels", 3), **kwargs)

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict:
        """
        Cache dataset labels, check images and read shapes.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict): Dictionary containing cached labels and related information.
        """
        x = {"labels": []}
        nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt, messages
        desc = f"{self.prefix}Scanning {path.parent / path.stem}..."
        total = len(self.im_files)
        nkpt, ndim = self.data.get("kpt_shape", (0, 0))
        if self.use_keypoints and (nkpt <= 0 or ndim not in {2, 3}):
            raise ValueError(
                "'kpt_shape' in data.yaml missing or incorrect. Should be a list with [number of "
                "keypoints, number of dims (2 for x,y or 3 for x,y,visible)], i.e. 'kpt_shape: [17, 3]'"
            )
        with ThreadPool(NUM_THREADS) as pool:
            results = pool.imap(
                func=verify_image_label,
                iterable=zip(
                    self.im_files,
                    self.label_files,
                    repeat(self.prefix),
                    repeat(self.use_keypoints),
                    repeat(len(self.data["names"])),
                    repeat(nkpt),
                    repeat(ndim),
                    repeat(self.single_cls),
                ),
            )
            pbar = TQDM(results, desc=desc, total=total)
            for im_file, lb, shape, segments, keypoint, nm_f, nf_f, ne_f, nc_f, msg in pbar:
                nm += nm_f
                nf += nf_f
                ne += ne_f
                nc += nc_f
                if im_file:
                    x["labels"].append(
                        {
                            "im_file": im_file,
                            "shape": shape,
                            "cls": lb[:, 0:1],  # n, 1
                            "bboxes": lb[:, 1:],  # n, 4
                            "segments": segments,
                            "keypoints": keypoint,
                            "normalized": True,
                            "bbox_format": "xywh",
                        }
                    )
                if msg:
                    msgs.append(msg)
                pbar.desc = f"{desc} {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            pbar.close()

        if msgs:
            LOGGER.info("\n".join(msgs))
        if nf == 0:
            LOGGER.warning(f"{self.prefix}No labels found in {path}. {HELP_URL}")
        x["hash"] = get_hash(self.label_files + self.im_files)
        x["results"] = nf, nm, ne, nc, len(self.im_files)
        x["msgs"] = msgs  # warnings
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """
        Return a list of label dictionaries for YOLO training.

        This method loads labels from disk or cache, verifies their integrity, and prepares them for training.

        Returns:
            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
        """
        self.label_files = img2label_paths(self.im_files)
        cache_path = Path(self.label_files[0]).parent.with_suffix(".cache")
        try:
            cache, exists = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.label_files + self.im_files)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            cache, exists = self.cache_labels(cache_path), False  # run cache ops

        # Display cache
        nf, nm, ne, nc, n = cache.pop("results")  # found, missing, empty, corrupt, total
        if exists and LOCAL_RANK in {-1, 0}:
            d = f"Scanning {cache_path}... {nf} images, {nm + ne} backgrounds, {nc} corrupt"
            TQDM(None, desc=self.prefix + d, total=n, initial=n)  # display results
            if cache["msgs"]:
                LOGGER.info("\n".join(cache["msgs"]))  # display warnings

        # Read cache
        [cache.pop(k) for k in ("hash", "version", "msgs")]  # remove items
        labels = cache["labels"]
        if not labels:
            raise RuntimeError(
                f"No valid images found in {cache_path}. Images with incorrectly formatted labels are ignored. {HELP_URL}"
            )
        self.im_files = [lb["im_file"] for lb in labels]  # update im_files

        # Check if the dataset is all boxes or all segments
        lengths = ((len(lb["cls"]), len(lb["bboxes"]), len(lb["segments"])) for lb in labels)
        len_cls, len_boxes, len_segments = (sum(x) for x in zip(*lengths))
        if len_segments and len_boxes != len_segments:
            LOGGER.warning(
                f"Box and segment counts should be equal, but got len(segments) = {len_segments}, "
                f"len(boxes) = {len_boxes}. To resolve this only boxes will be used and all segments will be removed. "
                "To avoid this please supply either a detect or segment dataset, not a detect-segment mixed dataset."
            )
            for lb in labels:
                lb["segments"] = []
        if len_cls == 0:
            LOGGER.warning(f"Labels are missing or empty in {cache_path}, training may not work correctly. {HELP_URL}")
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """
        Build and append transforms to the list.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms.
        """
        if self.augment:
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp)
        else:
            transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), scaleup=False)])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                return_obb=self.use_obb,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
                bgr=hyp.bgr if self.augment else 0.0,  # only affects training
            )
        )
        return transforms

    def close_mosaic(self, hyp: dict) -> None:
        """
        Disable mosaic, copy_paste, mixup and cutmix augmentations by setting their probabilities to 0.0.

        Args:
            hyp (dict): Hyperparameters for transforms.
        """
        hyp.mosaic = 0.0
        hyp.copy_paste = 0.0
        hyp.mixup = 0.0
        hyp.cutmix = 0.0
        self.transforms = self.build_transforms(hyp)

    def update_labels_info(self, label: dict) -> dict:
        """
        Update label format for different tasks.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances.

        Note:
            cls is no longer stored with bboxes; classification and semantic segmentation need an independent cls
            label. This method can also support classification and semantic segmentation by adding or removing
            dict keys here.
        """
        bboxes = label.pop("bboxes")
        segments = label.pop("segments", [])
        keypoints = label.pop("keypoints", None)
        bbox_format = label.pop("bbox_format")
        normalized = label.pop("normalized")

        # NOTE: do NOT resample oriented boxes
        segment_resamples = 100 if self.use_obb else 1000
        if len(segments) > 0:
            # Make sure segments interpolate correctly if original length is greater than segment_resamples
            max_len = max(len(s) for s in segments)
            segment_resamples = (max_len + 1) if segment_resamples < max_len else segment_resamples
            # list[np.array(segment_resamples, 2)] * num_samples
            segments = np.stack(resample_segments(segments, n=segment_resamples), axis=0)
        else:
            segments = np.zeros((0, segment_resamples, 2), dtype=np.float32)
        label["instances"] = Instances(bboxes, segments, keypoints, bbox_format=bbox_format, normalized=normalized)
        return label

    @staticmethod
    def collate_fn(batch: list[dict]) -> dict:
        """
        Collate data samples into batches.

        Args:
            batch (list[dict]): List of dictionaries containing sample data.

        Returns:
            (dict): Collated batch with stacked tensors.
        """
        new_batch = {}
        batch = [dict(sorted(b.items())) for b in batch]  # make sure the keys are in the same order
        keys = batch[0].keys()
        values = list(zip(*[list(b.values()) for b in batch]))
        for i, k in enumerate(keys):
            value = values[i]
            if k in {"img", "text_feats"}:
                value = torch.stack(value, 0)
            elif k == "visuals":
                value = torch.nn.utils.rnn.pad_sequence(value, batch_first=True)
            if k in {"masks", "keypoints", "bboxes", "cls", "segments", "obb"}:
                value = torch.cat(value, 0)
            new_batch[k] = value
        new_batch["batch_idx"] = list(new_batch["batch_idx"])
        for i in range(len(new_batch["batch_idx"])):
            new_batch["batch_idx"][i] += i  # add target image index for build_targets()
        new_batch["batch_idx"] = torch.cat(new_batch["batch_idx"], 0)
        return new_batch
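
    # Example (sketch): collating two samples with differing instance counts. 'img' tensors
    # are stacked, per-instance keys like 'cls'/'bboxes' are concatenated, and 'batch_idx'
    # records which image in the batch each instance belongs to.
    # >>> a = {"batch_idx": torch.zeros(2), "bboxes": torch.rand(2, 4), "cls": torch.zeros(2, 1), "img": torch.zeros(3, 32, 32)}
    # >>> b = {"batch_idx": torch.zeros(1), "bboxes": torch.rand(1, 4), "cls": torch.zeros(1, 1), "img": torch.zeros(3, 32, 32)}
    # >>> out = YOLODataset.collate_fn([a, b])
    # >>> out["img"].shape, out["batch_idx"].tolist()
    # (torch.Size([2, 3, 32, 32]), [0.0, 0.0, 1.0])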


class YOLOMultiModalDataset(YOLODataset):
    """
    Dataset class for loading object detection and/or segmentation labels in YOLO format with multi-modal support.

    This class extends YOLODataset to add text information for multi-modal model training, enabling models to
    process both image and text data.

    Methods:
        update_labels_info: Add text information for multi-modal model training.
        build_transforms: Enhance data transformations with text augmentation.

    Examples:
        >>> dataset = YOLOMultiModalDataset(img_path="path/to/images", data={"names": {0: "person"}}, task="detect")
        >>> batch = next(iter(dataset))
        >>> print(batch.keys())  # should include 'texts'
    """

    def __init__(self, *args, data: dict | None = None, task: str = "detect", **kwargs):
        """
        Initialize a YOLOMultiModalDataset.

        Args:
            data (dict, optional): Dataset configuration dictionary.
            task (str): Task type, one of 'detect', 'segment', 'pose', or 'obb'.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        super().__init__(*args, data=data, task=task, **kwargs)

    def update_labels_info(self, label: dict) -> dict:
        """
        Add text information for multi-modal model training.

        Args:
            label (dict): Label dictionary containing bboxes, segments, keypoints, etc.

        Returns:
            (dict): Updated label dictionary with instances and texts.
        """
        labels = super().update_labels_info(label)
        # NOTE: some categories are concatenated with their synonyms by `/`,
        # and `RandomLoadText` randomly selects one of them if there are multiple words.
        labels["texts"] = [v.split("/") for _, v in self.data["names"].items()]

        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """
        Enhance data transformations with optional text augmentation for multi-modal training.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms including text augmentation if applicable.
        """
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: the args are hard-coded for now. This implementation differs from official YOLOE:
            # the strategy of selecting negatives is restricted to one dataset here, while the
            # official version pre-saved negative embeddings from all datasets at once.
            transform = RandomLoadText(
                max_samples=min(self.data["nc"], 80),
                padding=True,
                padding_value=self._get_neg_texts(self.category_freq),
            )
            transforms.insert(-1, transform)
        return transforms

    @property
    def category_names(self):
        """
        Return category names for the dataset.

        Returns:
            (set[str]): Set of unique class names, with synonyms split on `/`.
        """
        names = self.data["names"].values()
        return {n.strip() for name in names for n in name.split("/")}  # category names

    @property
    def category_freq(self):
        """Return frequency of each category in the dataset."""
        texts = [v.split("/") for v in self.data["names"].values()]
        category_freq = defaultdict(int)
        for label in self.labels:
            for c in label["cls"].squeeze(-1):  # to check
                text = texts[int(c)]
                for t in text:
                    t = t.strip()
                    category_freq[t] += 1
        return category_freq

    @staticmethod
    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
        """Get negative text samples based on frequency threshold."""
        threshold = min(max(category_freq.values()), 100)  # NOTE: the `threshold` argument is overridden here
        return [k for k, v in category_freq.items() if v >= threshold]


class GroundingDataset(YOLODataset):
    """
    Dataset class for object detection tasks using annotations from a JSON file in grounding format.

    This dataset is designed for grounding tasks where annotations are provided in a JSON file rather than
    the standard YOLO format text files.

    Attributes:
        json_file (str): Path to the JSON file containing annotations.

    Methods:
        get_img_files: Return an empty list, as image files are read in get_labels.
        get_labels: Load annotations from a JSON file and prepare them for training.
        build_transforms: Configure augmentations for training with optional text loading.

    Examples:
        >>> dataset = GroundingDataset(img_path="path/to/images", json_file="annotations.json", task="detect")
        >>> len(dataset)  # number of valid images with annotations
    """

    def __init__(self, *args, task: str = "detect", json_file: str = "", max_samples: int = 80, **kwargs):
        """
        Initialize a GroundingDataset for object detection.

        Args:
            json_file (str): Path to the JSON file containing annotations.
            task (str): Must be 'detect' or 'segment' for GroundingDataset.
            max_samples (int): Maximum number of samples to load for text augmentation.
            *args (Any): Additional positional arguments for the parent class.
            **kwargs (Any): Additional keyword arguments for the parent class.
        """
        assert task in {"detect", "segment"}, "GroundingDataset currently only supports `detect` and `segment` tasks"
        self.json_file = json_file
        self.max_samples = max_samples
        super().__init__(*args, task=task, data={"channels": 3}, **kwargs)

    def get_img_files(self, img_path: str) -> list:
        """
        Return an empty list; image files are read in the `get_labels` function instead.

        Args:
            img_path (str): Path to the directory containing images.

        Returns:
            (list): Empty list, as image files are read in get_labels.
        """
        return []

    def verify_labels(self, labels: list[dict[str, Any]]) -> None:
        """
        Verify that the number of instances in the dataset matches the expected count.

        This method checks if the total number of bounding box instances in the provided
        labels matches the expected count for known datasets. It performs validation
        against a predefined set of datasets with known instance counts.

        Args:
            labels (list[dict[str, Any]]): List of label dictionaries, where each dictionary
                contains dataset annotations. Each label dict must have a 'bboxes' key with
                a numpy array or tensor containing bounding box coordinates.

        Raises:
            AssertionError: If the actual instance count doesn't match the expected count
                for a recognized dataset.

        Note:
            For unrecognized datasets (those not in the predefined expected_counts),
            a warning is logged and verification is skipped.
        """
        expected_counts = {
            "final_mixed_train_no_coco_segm": 3662412,
            "final_mixed_train_no_coco": 3681235,
            "final_flickr_separateGT_train_segm": 638214,
            "final_flickr_separateGT_train": 640704,
        }

        instance_count = sum(label["bboxes"].shape[0] for label in labels)
        for data_name, count in expected_counts.items():
            if data_name in self.json_file:
                assert instance_count == count, f"'{self.json_file}' has {instance_count} instances, expected {count}."
                return
        LOGGER.warning(f"Skipping instance count verification for unrecognized dataset '{self.json_file}'")

    def cache_labels(self, path: Path = Path("./labels.cache")) -> dict[str, Any]:
        """
        Load annotations from a JSON file, filter, and normalize bounding boxes for each image.

        Args:
            path (Path): Path where to save the cache file.

        Returns:
            (dict[str, Any]): Dictionary containing cached labels and related information.
        """
        x = {"labels": []}
        LOGGER.info("Loading annotation file...")
        with open(self.json_file) as f:
            annotations = json.load(f)
        images = {f"{x['id']:d}": x for x in annotations["images"]}
        img_to_anns = defaultdict(list)
        for ann in annotations["annotations"]:
            img_to_anns[ann["image_id"]].append(ann)
        for img_id, anns in TQDM(img_to_anns.items(), desc=f"Reading annotations {self.json_file}"):
            img = images[f"{img_id:d}"]
            h, w, f = img["height"], img["width"], img["file_name"]
            im_file = Path(self.img_path) / f
            if not im_file.exists():
                continue
            self.im_files.append(str(im_file))
            bboxes = []
            segments = []
            cat2id = {}
            texts = []
            for ann in anns:
                if ann["iscrowd"]:
                    continue
                box = np.array(ann["bbox"], dtype=np.float32)
                box[:2] += box[2:] / 2  # top-left xy to center xy
                box[[0, 2]] /= float(w)  # normalize x
                box[[1, 3]] /= float(h)  # normalize y
                if box[2] <= 0 or box[3] <= 0:
                    continue

                caption = img["caption"]
                cat_name = " ".join([caption[t[0] : t[1]] for t in ann["tokens_positive"]]).lower().strip()
                if not cat_name:
                    continue

                if cat_name not in cat2id:
                    cat2id[cat_name] = len(cat2id)
                    texts.append([cat_name])
                cls = cat2id[cat_name]  # class
                box = [cls] + box.tolist()
                if box not in bboxes:
                    bboxes.append(box)
                    if ann.get("segmentation") is not None:
                        if len(ann["segmentation"]) == 0:
                            segments.append(box)
                            continue
                        elif len(ann["segmentation"]) > 1:
                            s = merge_multi_segment(ann["segmentation"])
                            s = (np.concatenate(s, axis=0) / np.array([w, h], dtype=np.float32)).reshape(-1).tolist()
                        else:
                            s = [j for i in ann["segmentation"] for j in i]  # all segments concatenated
                            s = (
                                (np.array(s, dtype=np.float32).reshape(-1, 2) / np.array([w, h], dtype=np.float32))
                                .reshape(-1)
                                .tolist()
                            )
                        s = [cls] + s
                        segments.append(s)
            lb = np.array(bboxes, dtype=np.float32) if len(bboxes) else np.zeros((0, 5), dtype=np.float32)

            if segments:
                classes = np.array([x[0] for x in segments], dtype=np.float32)
                segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in segments]  # (cls, xy1...)
                lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
            lb = np.array(lb, dtype=np.float32)

            x["labels"].append(
                {
                    "im_file": im_file,
                    "shape": (h, w),
                    "cls": lb[:, 0:1],  # n, 1
                    "bboxes": lb[:, 1:],  # n, 4
                    "segments": segments,
                    "normalized": True,
                    "bbox_format": "xywh",
                    "texts": texts,
                }
            )
        x["hash"] = get_hash(self.json_file)
        save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
        return x

    def get_labels(self) -> list[dict]:
        """
        Load labels from cache or generate them from the JSON file.

        Returns:
            (list[dict]): List of label dictionaries, each containing information about an image and its annotations.
        """
        cache_path = Path(self.json_file).with_suffix(".cache")
        try:
            cache, _ = load_dataset_cache_file(cache_path), True  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash(self.json_file)  # identical hash
        except (FileNotFoundError, AssertionError, AttributeError, ModuleNotFoundError):
            cache, _ = self.cache_labels(cache_path), False  # run cache ops
        [cache.pop(k) for k in ("hash", "version")]  # remove items
        labels = cache["labels"]
        self.verify_labels(labels)
        self.im_files = [str(label["im_file"]) for label in labels]
        if LOCAL_RANK in {-1, 0}:
            LOGGER.info(f"Load {self.json_file} from cache file {cache_path}")
        return labels

    def build_transforms(self, hyp: dict | None = None) -> Compose:
        """
        Configure augmentations for training with optional text loading.

        Args:
            hyp (dict, optional): Hyperparameters for transforms.

        Returns:
            (Compose): Composed transforms including text augmentation if applicable.
        """
        transforms = super().build_transforms(hyp)
        if self.augment:
            # NOTE: the args are hard-coded for now. This implementation differs from official YOLOE:
            # the strategy of selecting negatives is restricted to one dataset here, while the
            # official version pre-saved negative embeddings from all datasets at once.
            transform = RandomLoadText(
                max_samples=min(self.max_samples, 80),
                padding=True,
                padding_value=self._get_neg_texts(self.category_freq),
            )
            transforms.insert(-1, transform)
        return transforms

    @property
    def category_names(self):
        """Return unique category names from the dataset."""
        return {t.strip() for label in self.labels for text in label["texts"] for t in text}

    @property
    def category_freq(self):
        """Return frequency of each category in the dataset."""
        category_freq = defaultdict(int)
        for label in self.labels:
            for text in label["texts"]:
                for t in text:
                    t = t.strip()
                    category_freq[t] += 1
        return category_freq

    @staticmethod
    def _get_neg_texts(category_freq: dict, threshold: int = 100) -> list[str]:
        """Get negative text samples based on frequency threshold."""
        threshold = min(max(category_freq.values()), 100)  # NOTE: the `threshold` argument is overridden here
        return [k for k, v in category_freq.items() if v >= threshold]


class YOLOConcatDataset(ConcatDataset):
    """
    Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets for YOLO training, ensuring they use the same
    collation function.

    Methods:
        collate_fn: Static method that collates data samples into batches using YOLODataset's collation function.

    Examples:
        >>> dataset1 = YOLODataset(...)
        >>> dataset2 = YOLODataset(...)
        >>> combined_dataset = YOLOConcatDataset([dataset1, dataset2])
    """

    @staticmethod
    def collate_fn(batch: list[dict]) -> dict:
        """
        Collate data samples into batches.

        Args:
            batch (list[dict]): List of dictionaries containing sample data.

        Returns:
            (dict): Collated batch with stacked tensors.
        """
        return YOLODataset.collate_fn(batch)

    def close_mosaic(self, hyp: dict) -> None:
        """
        Set mosaic, copy_paste and mixup options to 0.0 and build transformations.

        Args:
            hyp (dict): Hyperparameters for transforms.
        """
        for dataset in self.datasets:
            if not hasattr(dataset, "close_mosaic"):
                continue
            dataset.close_mosaic(hyp)
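
# Example (sketch): train across two datasets at once; pass the shared collate_fn through
# to a DataLoader so batches from either dataset are collated identically.
# >>> from torch.utils.data import DataLoader
# >>> combined = YOLOConcatDataset([dataset1, dataset2])  # dataset1/dataset2 are YOLODataset instances
# >>> loader = DataLoader(combined, batch_size=16, shuffle=True, collate_fn=YOLOConcatDataset.collate_fn)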


# TODO: support semantic segmentation
class SemanticDataset(BaseDataset):
    """Semantic Segmentation Dataset."""

    def __init__(self):
        """Initialize a SemanticDataset object."""
        super().__init__()


class ClassificationDataset:
    """
    Dataset class for image classification tasks extending torchvision ImageFolder functionality.

    This class offers functionalities like image augmentation, caching, and verification. It's designed to
    efficiently handle large datasets for training deep learning models, with optional image transformations and
    caching mechanisms to speed up training.

    Attributes:
        cache_ram (bool): Indicates if caching in RAM is enabled.
        cache_disk (bool): Indicates if caching on disk is enabled.
        samples (list): A list of tuples, each containing the path to an image, its class index, path to its .npy
            cache file (if caching on disk), and optionally the loaded image array (if caching in RAM).
        torch_transforms (callable): PyTorch transforms to be applied to the images.
        root (str): Root directory of the dataset.
        prefix (str): Prefix for logging and cache filenames.

    Methods:
        __getitem__: Return subset of data and targets corresponding to given indices.
        __len__: Return the total number of samples in the dataset.
        verify_images: Verify all images in dataset.
    """

    def __init__(self, root: str, args, augment: bool = False, prefix: str = ""):
        """
        Initialize YOLO classification dataset with root directory, arguments, augmentations, and cache settings.

        Args:
            root (str): Path to the dataset directory where images are stored in a class-specific folder structure.
            args (Namespace): Configuration containing dataset-related settings such as image size, augmentation
                parameters, and cache settings.
            augment (bool, optional): Whether to apply augmentations to the dataset.
            prefix (str, optional): Prefix for logging and cache filenames, aiding in dataset identification.
        """
        import torchvision  # scope for faster 'import ultralytics'

        # Base class assigned as attribute rather than used as base class to allow for scoping slow torchvision import
        if TORCHVISION_0_18:  # 'allow_empty' argument first introduced in torchvision 0.18
            self.base = torchvision.datasets.ImageFolder(root=root, allow_empty=True)
        else:
            self.base = torchvision.datasets.ImageFolder(root=root)
        self.samples = self.base.samples
        self.root = self.base.root

        # Initialize attributes
        if augment and args.fraction < 1.0:  # reduce training fraction
            self.samples = self.samples[: round(len(self.samples) * args.fraction)]
        self.prefix = colorstr(f"{prefix}: ") if prefix else ""
        self.cache_ram = args.cache is True or str(args.cache).lower() == "ram"  # cache images into RAM
        if self.cache_ram:
            LOGGER.warning(
                "Classification `cache_ram` training has known memory leak in "
                "https://github.com/ultralytics/ultralytics/issues/9824, setting `cache_ram=False`."
            )
            self.cache_ram = False
        self.cache_disk = str(args.cache).lower() == "disk"  # cache images on hard drive as uncompressed *.npy files
        self.samples = self.verify_images()  # filter out bad images
        self.samples = [list(x) + [Path(x[0]).with_suffix(".npy"), None] for x in self.samples]  # file, index, npy, im
        scale = (1.0 - args.scale, 1.0)  # (0.08, 1.0)
        self.torch_transforms = (
            classify_augmentations(
                size=args.imgsz,
                scale=scale,
                hflip=args.fliplr,
                vflip=args.flipud,
                erasing=args.erasing,
                auto_augment=args.auto_augment,
                hsv_h=args.hsv_h,
                hsv_s=args.hsv_s,
                hsv_v=args.hsv_v,
            )
            if augment
            else classify_transforms(size=args.imgsz)
        )

    def __getitem__(self, i: int) -> dict:
        """
        Return subset of data and targets corresponding to given indices.

        Args:
            i (int): Index of the sample to retrieve.

        Returns:
            (dict): Dictionary containing the image and its class index.
        """
        f, j, fn, im = self.samples[i]  # filename, index, filename.with_suffix('.npy'), image
        if self.cache_ram:
            if im is None:  # Warning: two separate if statements required here, do not combine this with previous line
                im = self.samples[i][3] = cv2.imread(f)
        elif self.cache_disk:
            if not fn.exists():  # save *.npy cache file if it does not exist
                np.save(fn.as_posix(), cv2.imread(f), allow_pickle=False)
            im = np.load(fn)
        else:  # read image
            im = cv2.imread(f)  # BGR
        # Convert NumPy array to PIL image
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        sample = self.torch_transforms(im)
        return {"img": sample, "cls": j}

    def __len__(self) -> int:
        """Return the total number of samples in the dataset."""
        return len(self.samples)

    def verify_images(self) -> list[tuple]:
        """
        Verify all images in the dataset.

        Returns:
            (list): List of valid samples after verification.
        """
        desc = f"{self.prefix}Scanning {self.root}..."
        path = Path(self.root).with_suffix(".cache")  # *.cache file path

        try:
            check_file_speeds([file for (file, _) in self.samples[:5]], prefix=self.prefix)  # check image read speeds
            cache = load_dataset_cache_file(path)  # attempt to load a *.cache file
            assert cache["version"] == DATASET_CACHE_VERSION  # matches current version
            assert cache["hash"] == get_hash([x[0] for x in self.samples])  # identical hash
            nf, nc, n, samples = cache.pop("results")  # found, corrupt, total, samples
            if LOCAL_RANK in {-1, 0}:
                d = f"{desc} {nf} images, {nc} corrupt"
                TQDM(None, desc=d, total=n, initial=n)
                if cache["msgs"]:
                    LOGGER.info("\n".join(cache["msgs"]))  # display warnings
            return samples

        except (FileNotFoundError, AssertionError, AttributeError):
            # Run scan if *.cache retrieval failed
            nf, nc, msgs, samples, x = 0, 0, [], [], {}
            with ThreadPool(NUM_THREADS) as pool:
                results = pool.imap(func=verify_image, iterable=zip(self.samples, repeat(self.prefix)))
                pbar = TQDM(results, desc=desc, total=len(self.samples))
                for sample, nf_f, nc_f, msg in pbar:
                    if nf_f:
                        samples.append(sample)
                    if msg:
                        msgs.append(msg)
                    nf += nf_f
                    nc += nc_f
                    pbar.desc = f"{desc} {nf} images, {nc} corrupt"
                pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))
            x["hash"] = get_hash([x[0] for x in self.samples])
            x["results"] = nf, nc, len(samples), samples
            x["msgs"] = msgs  # warnings
            save_dataset_cache_file(self.prefix, path, x, DATASET_CACHE_VERSION)
            return samples
711
ultralytics/data/loaders.py
Normal file
@@ -0,0 +1,711 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import glob
import math
import os
import time
import urllib
from dataclasses import dataclass
from pathlib import Path
from threading import Thread
from typing import Any

import cv2
import numpy as np
import torch
from PIL import Image

from ultralytics.data.utils import FORMATS_HELP_MSG, IMG_FORMATS, VID_FORMATS
from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.patches import imread

@dataclass
class SourceTypes:
    """
    Class to represent various types of input sources for predictions.

    This class uses dataclass to define boolean flags for different types of input sources that can be used for
    making predictions with YOLO models.

    Attributes:
        stream (bool): Flag indicating if the input source is a video stream.
        screenshot (bool): Flag indicating if the input source is a screenshot.
        from_img (bool): Flag indicating if the input source is an image file.
        tensor (bool): Flag indicating if the input source is a tensor.

    Examples:
        >>> source_types = SourceTypes(stream=True, screenshot=False, from_img=False)
        >>> print(source_types.stream)
        True
        >>> print(source_types.from_img)
        False
    """

    stream: bool = False
    screenshot: bool = False
    from_img: bool = False
    tensor: bool = False


class LoadStreams:
    """
    Stream Loader for various types of video streams.

    Supports RTSP, RTMP, HTTP, and TCP streams. This class handles the loading and processing of multiple video
    streams simultaneously, making it suitable for real-time video analysis tasks.

    Attributes:
        sources (list[str]): The source input paths or URLs for the video streams.
        vid_stride (int): Video frame-rate stride.
        buffer (bool): Whether to buffer input streams.
        running (bool): Flag to indicate if the streaming thread is running.
        mode (str): Set to 'stream' indicating real-time capture.
        imgs (list[list[np.ndarray]]): List of image frames for each stream.
        fps (list[float]): List of FPS for each stream.
        frames (list[int]): List of total frames for each stream.
        threads (list[Thread]): List of threads for each stream.
        shape (list[tuple[int, int, int]]): List of shapes for each stream.
        caps (list[cv2.VideoCapture]): List of cv2.VideoCapture objects for each stream.
        bs (int): Batch size for processing.
        cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).

    Methods:
        update: Read stream frames in daemon thread.
        close: Close stream loader and release resources.
        __iter__: Returns an iterator object for the class.
        __next__: Returns source paths, transformed, and original images for processing.
        __len__: Return the length of the sources object.

    Examples:
        >>> stream_loader = LoadStreams("rtsp://example.com/stream1.mp4")
        >>> for sources, imgs, _ in stream_loader:
        ...     # Process the images
        ...     pass
        >>> stream_loader.close()

    Notes:
        - The class uses threading to efficiently load frames from multiple streams simultaneously.
        - It automatically handles YouTube links, converting them to the best available stream URL.
        - The class implements a buffer system to manage frame storage and retrieval.
    """

    def __init__(self, sources: str = "file.streams", vid_stride: int = 1, buffer: bool = False, channels: int = 3):
        """
        Initialize stream loader for multiple video sources, supporting various stream types.

        Args:
            sources (str): Path to streams file or single stream URL.
            vid_stride (int): Video frame-rate stride.
            buffer (bool): Whether to buffer input streams.
            channels (int): Number of image channels (1 for grayscale, 3 for RGB).
        """
        torch.backends.cudnn.benchmark = True  # faster for fixed-size inference
        self.buffer = buffer  # buffer input streams
        self.running = True  # running flag for Thread
        self.mode = "stream"
        self.vid_stride = vid_stride  # video frame-rate stride
        self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR  # grayscale or RGB

        sources = Path(sources).read_text().rsplit() if os.path.isfile(sources) else [sources]
        n = len(sources)
        self.bs = n
        self.fps = [0] * n  # frames per second
        self.frames = [0] * n
        self.threads = [None] * n
        self.caps = [None] * n  # video capture objects
        self.imgs = [[] for _ in range(n)]  # images
        self.shape = [[] for _ in range(n)]  # image shapes
        self.sources = [ops.clean_str(x).replace(os.sep, "_") for x in sources]  # clean source names for later
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f"{i + 1}/{n}: {s}... "
            if urllib.parse.urlparse(s).hostname in {"www.youtube.com", "youtube.com", "youtu.be"}:  # YouTube video
                # YouTube format i.e. 'https://www.youtube.com/watch?v=Jsn8D3aC840' or 'https://youtu.be/Jsn8D3aC840'
                s = get_best_youtube_url(s)
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0 and (IS_COLAB or IS_KAGGLE):
                raise NotImplementedError(
                    "'source=0' webcam not supported in Colab and Kaggle notebooks. "
                    "Try running 'source=0' in a local environment."
                )
            self.caps[i] = cv2.VideoCapture(s)  # store video capture object
            if not self.caps[i].isOpened():
                raise ConnectionError(f"{st}Failed to open {s}")
            w = int(self.caps[i].get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(self.caps[i].get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = self.caps[i].get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(self.caps[i].get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float(
                "inf"
            )  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            success, im = self.caps[i].read()  # guarantee first frame
            im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im
            if not success or im is None:
                raise ConnectionError(f"{st}Failed to read images from {s}")
            self.imgs[i].append(im)
            self.shape[i] = im.shape
            self.threads[i] = Thread(target=self.update, args=([i, self.caps[i], s]), daemon=True)
            LOGGER.info(f"{st}Success ✅ ({self.frames[i]} frames of shape {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info("")  # newline

    def update(self, i: int, cap: cv2.VideoCapture, stream: str):
        """Read stream frames in daemon thread and update image buffer."""
        n, f = 0, self.frames[i]  # frame number, frame array
        while self.running and cap.isOpened() and n < (f - 1):
            if len(self.imgs[i]) < 30:  # keep a <=30-image buffer
                n += 1
                cap.grab()  # .read() = .grab() followed by .retrieve()
                if n % self.vid_stride == 0:
                    success, im = cap.retrieve()
                    im = (
                        cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im
                    )
                    if not success:
                        im = np.zeros(self.shape[i], dtype=np.uint8)
                        LOGGER.warning("Video stream unresponsive, please check your IP camera connection.")
                        cap.open(stream)  # re-open stream if signal was lost
                    if self.buffer:
                        self.imgs[i].append(im)
                    else:
                        self.imgs[i] = [im]
            else:
                time.sleep(0.01)  # wait until the buffer is empty

    def close(self):
        """Terminate stream loader, stop threads, and release video capture resources."""
        self.running = False  # stop flag for Thread
        for thread in self.threads:
            if thread.is_alive():
                thread.join(timeout=5)  # add timeout
        for cap in self.caps:  # iterate through the stored VideoCapture objects
            try:
                cap.release()  # release video capture
            except Exception as e:
                LOGGER.warning(f"Could not release VideoCapture object: {e}")

    def __iter__(self):
        """Iterate through YOLO image feed and re-open unresponsive streams."""
        self.count = -1
        return self

    def __next__(self) -> tuple[list[str], list[np.ndarray], list[str]]:
        """Return the next batch of frames from multiple video streams for processing."""
        self.count += 1

        images = []
        for i, x in enumerate(self.imgs):
            # Wait until a frame is available in each buffer
            while not x:
                if not self.threads[i].is_alive():
                    self.close()
                    raise StopIteration
                time.sleep(1 / min(self.fps))
                x = self.imgs[i]
                if not x:
                    LOGGER.warning(f"Waiting for stream {i}")

            # Get and remove the first frame from imgs buffer
            if self.buffer:
                images.append(x.pop(0))

            # Get the last frame, and clear the rest from the imgs buffer
            else:
                images.append(x.pop(-1) if x else np.zeros(self.shape[i], dtype=np.uint8))
                x.clear()

        return self.sources, images, [""] * self.bs

    def __len__(self) -> int:
        """Return the number of video streams in the LoadStreams object."""
        return self.bs  # 1E12 frames = 32 streams at 30 FPS for 30 years
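
# Example (sketch): iterate a single stream without buffering; each batch holds one frame
# per source, so len(imgs) equals loader.bs. The URL below is hypothetical.
# >>> loader = LoadStreams("rtsp://example.com/stream1.mp4", vid_stride=1, buffer=False)
# >>> paths, imgs, info = next(iter(loader))
# >>> len(imgs) == loader.bs
# True
# >>> loader.close()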


class LoadScreenshots:
    """
    Ultralytics screenshot dataloader for capturing and processing screen images.

    This class manages the loading of screenshot images for processing with YOLO. It is suitable for use with
    `yolo predict source=screen`.

    Attributes:
        source (str): The source input indicating which screen to capture.
        screen (int): The screen number to capture.
        left (int): The left coordinate for screen capture area.
        top (int): The top coordinate for screen capture area.
        width (int): The width of the screen capture area.
        height (int): The height of the screen capture area.
        mode (str): Set to 'stream' indicating real-time capture.
        frame (int): Counter for captured frames.
        sct (mss.mss): Screen capture object from `mss` library.
        bs (int): Batch size, set to 1.
        fps (int): Frames per second, set to 30.
        monitor (dict[str, int]): Monitor configuration details.
        cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).

    Methods:
        __iter__: Returns an iterator object.
        __next__: Captures the next screenshot and returns it.

    Examples:
        >>> loader = LoadScreenshots("0 100 100 640 480")  # screen 0, top-left (100,100), 640x480
        >>> for source, im0s, s in loader:
        ...     print(f"Captured frame: {im0s[0].shape}")
    """

    def __init__(self, source: str, channels: int = 3):
        """
        Initialize screenshot capture with specified screen and region parameters.

        Args:
            source (str): Screen capture source string in format "screen_num left top width height".
            channels (int): Number of image channels (1 for grayscale, 3 for RGB).
        """
        check_requirements("mss")
        import mss  # noqa

        source, *params = source.split()
        self.screen, left, top, width, height = 0, None, None, None, None  # default to full screen 0
        if len(params) == 1:
            self.screen = int(params[0])
        elif len(params) == 4:
            left, top, width, height = (int(x) for x in params)
        elif len(params) == 5:
            self.screen, left, top, width, height = (int(x) for x in params)
        self.mode = "stream"
        self.frame = 0
        self.sct = mss.mss()
        self.bs = 1
        self.fps = 30
        self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR  # grayscale or RGB

        # Parse monitor shape
        monitor = self.sct.monitors[self.screen]
        self.top = monitor["top"] if top is None else (monitor["top"] + top)
        self.left = monitor["left"] if left is None else (monitor["left"] + left)
        self.width = width or monitor["width"]
        self.height = height or monitor["height"]
        self.monitor = {"left": self.left, "top": self.top, "width": self.width, "height": self.height}

    def __iter__(self):
        """Yield the next screenshot image from the specified screen or region for processing."""
        return self

    def __next__(self) -> tuple[list[str], list[np.ndarray], list[str]]:
        """Capture and return the next screenshot as a numpy array using the mss library."""
        im0 = np.asarray(self.sct.grab(self.monitor))[:, :, :3]  # BGRA to BGR
        im0 = cv2.cvtColor(im0, cv2.COLOR_BGR2GRAY)[..., None] if self.cv2_flag == cv2.IMREAD_GRAYSCALE else im0
        s = f"screen {self.screen} (LTWH): {self.left},{self.top},{self.width},{self.height}: "

        self.frame += 1
        return [str(self.screen)], [im0], [s]  # screen, img, string
|
||||
|
||||
|
||||
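

# Illustrative sketch (editor's addition): grab a single region from screen 0 via
# LoadScreenshots. Assumes a desktop session where the `mss` backend can capture.
def _demo_load_screenshots() -> None:
    """Capture one 640x480 screenshot batch from screen 0, offset (100, 100)."""
    loader = LoadScreenshots("0 100 100 640 480")
    paths, images, info = next(iter(loader))  # single-image batch
    print(info[0], images[0].shape)  # e.g. 'screen 0 (LTWH): ...' (480, 640, 3)

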
class LoadImagesAndVideos:
    """
    A class for loading and processing images and videos for YOLO object detection.

    This class manages the loading and pre-processing of image and video data from various sources, including
    single image files, video files, and lists of image and video paths.

    Attributes:
        files (list[str]): List of image and video file paths.
        nf (int): Total number of files (images and videos).
        video_flag (list[bool]): Flags indicating whether a file is a video (True) or an image (False).
        mode (str): Current mode, 'image' or 'video'.
        vid_stride (int): Video frame-rate stride.
        bs (int): Batch size.
        cap (cv2.VideoCapture): Video capture object for OpenCV.
        frame (int): Frame counter for video.
        frames (int): Total number of frames in the video.
        count (int): Counter for iteration, initialized at 0 during __iter__().
        ni (int): Number of images.
        cv2_flag (int): OpenCV flag for image reading (grayscale or RGB).

    Methods:
        __init__: Initialize the LoadImagesAndVideos object.
        __iter__: Returns an iterator object for VideoStream or ImageFolder.
        __next__: Returns the next batch of images or video frames along with their paths and metadata.
        _new_video: Creates a new video capture object for the given path.
        __len__: Returns the number of batches in the object.

    Examples:
        >>> loader = LoadImagesAndVideos("path/to/data", batch=32, vid_stride=1)
        >>> for paths, imgs, info in loader:
        ...     # Process batch of images or video frames
        ...     pass

    Notes:
        - Supports various image formats including HEIC.
        - Handles both local files and directories.
        - Can read from a text file containing paths to images and videos.
    """

    def __init__(self, path: str | Path | list, batch: int = 1, vid_stride: int = 1, channels: int = 3):
        """
        Initialize dataloader for images and videos, supporting various input formats.

        Args:
            path (str | Path | list): Path to images/videos, directory, or list of paths.
            batch (int): Batch size for processing.
            vid_stride (int): Video frame-rate stride.
            channels (int): Number of image channels (1 for grayscale, 3 for RGB).
        """
        parent = None
        if isinstance(path, str) and Path(path).suffix in {".txt", ".csv"}:  # txt/csv file with source paths
            parent, content = Path(path).parent, Path(path).read_text()
            path = content.splitlines() if Path(path).suffix == ".txt" else content.split(",")  # list of sources
            path = [p.strip() for p in path]
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            a = str(Path(p).absolute())  # do not use .resolve() https://github.com/ultralytics/ultralytics/issues/2912
            if "*" in a:
                files.extend(sorted(glob.glob(a, recursive=True)))  # glob
            elif os.path.isdir(a):
                files.extend(sorted(glob.glob(os.path.join(a, "*.*"))))  # dir
            elif os.path.isfile(a):
                files.append(a)  # files (absolute or relative to CWD)
            elif parent and (parent / p).is_file():
                files.append(str((parent / p).absolute()))  # files (relative to *.txt file parent)
            else:
                raise FileNotFoundError(f"{p} does not exist")

        # Define files as images or videos
        images, videos = [], []
        for f in files:
            suffix = f.rpartition(".")[-1].lower()  # file extension without the dot, lowercased
            if suffix in IMG_FORMATS:
                images.append(f)
            elif suffix in VID_FORMATS:
                videos.append(f)
        ni, nv = len(images), len(videos)

        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.ni = ni  # number of images
        self.video_flag = [False] * ni + [True] * nv
        self.mode = "video" if ni == 0 else "image"  # default to video if no images
        self.vid_stride = vid_stride  # video frame-rate stride
        self.bs = batch
        self.cv2_flag = cv2.IMREAD_GRAYSCALE if channels == 1 else cv2.IMREAD_COLOR  # grayscale or RGB
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        if self.nf == 0:
            raise FileNotFoundError(f"No images or videos found in {path}. {FORMATS_HELP_MSG}")

    def __iter__(self):
        """Iterate through image/video files, yielding source paths, images, and metadata."""
        self.count = 0
        return self

    def __next__(self) -> tuple[list[str], list[np.ndarray], list[str]]:
        """Return the next batch of images or video frames with their paths and metadata."""
        paths, imgs, info = [], [], []
        while len(imgs) < self.bs:
            if self.count >= self.nf:  # end of file list
                if imgs:
                    return paths, imgs, info  # return last partial batch
                else:
                    raise StopIteration

            path = self.files[self.count]
            if self.video_flag[self.count]:
                self.mode = "video"
                if not self.cap or not self.cap.isOpened():
                    self._new_video(path)

                success = False
                for _ in range(self.vid_stride):
                    success = self.cap.grab()
                    if not success:
                        break  # end of video or failure

                if success:
                    success, im0 = self.cap.retrieve()
                    im0 = (
                        cv2.cvtColor(im0, cv2.COLOR_BGR2GRAY)[..., None]
                        if success and self.cv2_flag == cv2.IMREAD_GRAYSCALE
                        else im0
                    )
                    if success:
                        self.frame += 1
                        paths.append(path)
                        imgs.append(im0)
                        info.append(f"video {self.count + 1}/{self.nf} (frame {self.frame}/{self.frames}) {path}: ")
                        if self.frame == self.frames:  # end of video
                            self.count += 1
                            self.cap.release()
                else:
                    # Move to the next file if the current video ended or failed to open
                    self.count += 1
                    if self.cap:
                        self.cap.release()
                    if self.count < self.nf:
                        self._new_video(self.files[self.count])
            else:
                # Handle image files (including HEIC)
                self.mode = "image"
                if path.rpartition(".")[-1].lower() == "heic":
                    # Load HEIC image using Pillow with pillow-heif
                    check_requirements("pi-heif")

                    from pi_heif import register_heif_opener

                    register_heif_opener()  # register HEIF opener with Pillow
                    with Image.open(path) as img:
                        im0 = cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)  # convert image to BGR nparray
                else:
                    im0 = imread(path, flags=self.cv2_flag)  # BGR
                if im0 is None:
                    LOGGER.warning(f"Image Read Error {path}")
                else:
                    paths.append(path)
                    imgs.append(im0)
                    info.append(f"image {self.count + 1}/{self.nf} {path}: ")
                self.count += 1  # move to the next file
                if self.count >= self.ni:  # end of image list
                    break

        return paths, imgs, info

    def _new_video(self, path: str):
        """Create a new video capture object for the given path and initialize video-related attributes."""
        self.frame = 0
        self.cap = cv2.VideoCapture(path)
        self.fps = int(self.cap.get(cv2.CAP_PROP_FPS))
        if not self.cap.isOpened():
            raise FileNotFoundError(f"Failed to open video {path}")
        self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)

    def __len__(self) -> int:
        """Return the number of batches over the files (images and videos) in the dataset."""
        return math.ceil(self.nf / self.bs)  # number of batches
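

# Illustrative sketch (editor's addition): batch-iterate a folder of images/videos.
# The "assets" directory is a placeholder; point it at any folder of images.
def _demo_load_images_and_videos(path: str = "assets") -> None:
    """Iterate a source folder in batches of 2 and print each batch size."""
    loader = LoadImagesAndVideos(path, batch=2, vid_stride=1)
    for paths, imgs, info in loader:
        print(f"batch of {len(imgs)}: {[Path(p).name for p in paths]}")

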
class LoadPilAndNumpy:
    """
    Load images from PIL and Numpy arrays for batch processing.

    This class manages loading and pre-processing of image data from both PIL and Numpy formats. It performs basic
    validation and format conversion to ensure that the images are in the required format for downstream processing.

    Attributes:
        paths (list[str]): List of image paths or autogenerated filenames.
        im0 (list[np.ndarray]): List of images stored as Numpy arrays.
        mode (str): Type of data being processed, set to 'image'.
        bs (int): Batch size, equivalent to the length of `im0`.

    Methods:
        _single_check: Validate and format a single image to a Numpy array.

    Examples:
        >>> from PIL import Image
        >>> import numpy as np
        >>> pil_img = Image.new("RGB", (100, 100))
        >>> np_img = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
        >>> loader = LoadPilAndNumpy([pil_img, np_img])
        >>> paths, images, _ = next(iter(loader))
        >>> print(f"Loaded {len(images)} images")
        Loaded 2 images
    """

    def __init__(self, im0: Image.Image | np.ndarray | list, channels: int = 3):
        """
        Initialize a loader for PIL and Numpy images, converting inputs to a standardized format.

        Args:
            im0 (PIL.Image.Image | np.ndarray | list): Single image or list of images in PIL or numpy format.
            channels (int): Number of image channels (1 for grayscale, 3 for RGB).
        """
        if not isinstance(im0, list):
            im0 = [im0]
        # Use `image{i}.jpg` when Image.filename returns an empty path
        self.paths = [getattr(im, "filename", "") or f"image{i}.jpg" for i, im in enumerate(im0)]
        pil_flag = "L" if channels == 1 else "RGB"  # grayscale or RGB
        self.im0 = [self._single_check(im, pil_flag) for im in im0]
        self.mode = "image"
        self.bs = len(self.im0)

    @staticmethod
    def _single_check(im: Image.Image | np.ndarray, flag: str = "RGB") -> np.ndarray:
        """Validate and format an image to a numpy array, ensuring BGR channel order and contiguous memory."""
        assert isinstance(im, (Image.Image, np.ndarray)), f"Expected PIL/np.ndarray image type, but got {type(im)}"
        if isinstance(im, Image.Image):
            im = np.asarray(im.convert(flag))
            # Add a new axis if grayscale, and convert RGB to BGR otherwise
            im = im[..., None] if flag == "L" else im[..., ::-1]
            im = np.ascontiguousarray(im)  # contiguous
        elif im.ndim == 2:  # grayscale in numpy form
            im = im[..., None]
        return im

    def __len__(self) -> int:
        """Return the length of the 'im0' attribute, representing the number of loaded images."""
        return len(self.im0)

    def __next__(self) -> tuple[list[str], list[np.ndarray], list[str]]:
        """Return the next batch of images, paths, and metadata for processing."""
        if self.count == 1:  # loop only once as it's batch inference
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, [""] * self.bs

    def __iter__(self):
        """Iterate through PIL/numpy images, yielding paths, raw images, and metadata for processing."""
        self.count = 0
        return self
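

# Illustrative sketch (editor's addition): mixed PIL/numpy inputs are normalized to a
# list of contiguous BGR numpy arrays, with filenames autogenerated where missing.
def _demo_load_pil_and_numpy() -> None:
    """Load one PIL image and one numpy array through LoadPilAndNumpy."""
    pil_img = Image.new("RGB", (64, 64))
    np_img = np.zeros((64, 64, 3), dtype=np.uint8)
    loader = LoadPilAndNumpy([pil_img, np_img])
    paths, images, _ = next(iter(loader))
    print(paths, [im.shape for im in images])  # ['image0.jpg', 'image1.jpg'] [(64, 64, 3), (64, 64, 3)]

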
class LoadTensor:
    """
    A class for loading and processing tensor data for object detection tasks.

    This class handles the loading and pre-processing of image data from PyTorch tensors, preparing them for
    further processing in object detection pipelines.

    Attributes:
        im0 (torch.Tensor): The input tensor containing the image(s) with shape (B, C, H, W).
        bs (int): Batch size, inferred from the shape of `im0`.
        mode (str): Current processing mode, set to 'image'.
        paths (list[str]): List of image paths or auto-generated filenames.

    Methods:
        _single_check: Validates and formats an input tensor.

    Examples:
        >>> import torch
        >>> tensor = torch.rand(1, 3, 640, 640)
        >>> loader = LoadTensor(tensor)
        >>> paths, images, info = next(iter(loader))
        >>> print(f"Processed {len(images)} images")
    """

    def __init__(self, im0: torch.Tensor) -> None:
        """
        Initialize LoadTensor object for processing torch.Tensor image data.

        Args:
            im0 (torch.Tensor): Input tensor with shape (B, C, H, W).
        """
        self.im0 = self._single_check(im0)
        self.bs = self.im0.shape[0]
        self.mode = "image"
        self.paths = [getattr(im, "filename", f"image{i}.jpg") for i, im in enumerate(im0)]

    @staticmethod
    def _single_check(im: torch.Tensor, stride: int = 32) -> torch.Tensor:
        """Validate and format a single image tensor, ensuring correct shape and normalization."""
        s = (
            f"torch.Tensor inputs should be BCHW i.e. shape(1, 3, 640, 640) "
            f"divisible by stride {stride}. Input shape{tuple(im.shape)} is incompatible."
        )
        if len(im.shape) != 4:
            if len(im.shape) != 3:
                raise ValueError(s)
            LOGGER.warning(s)
            im = im.unsqueeze(0)
        if im.shape[2] % stride or im.shape[3] % stride:
            raise ValueError(s)
        if im.max() > 1.0 + torch.finfo(im.dtype).eps:  # torch.float32 eps is 1.2e-07
            LOGGER.warning(
                f"torch.Tensor inputs should be normalized 0.0-1.0 but max value is {im.max()}. Dividing input by 255."
            )
            im = im.float() / 255.0

        return im

    def __iter__(self):
        """Yield an iterator object for iterating through tensor image data."""
        self.count = 0
        return self

    def __next__(self) -> tuple[list[str], torch.Tensor, list[str]]:
        """Yield the next batch of tensor images and metadata for processing."""
        if self.count == 1:
            raise StopIteration
        self.count += 1
        return self.paths, self.im0, [""] * self.bs

    def __len__(self) -> int:
        """Return the batch size of the tensor input."""
        return self.bs
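

# Illustrative sketch (editor's addition): tensors must be BCHW, stride-divisible, and
# normalized to 0-1; a CHW tensor is auto-expanded to BCHW with a warning.
def _demo_load_tensor() -> None:
    """Wrap a random BCHW tensor in LoadTensor and fetch the single batch."""
    tensor = torch.rand(1, 3, 640, 640)  # already normalized 0-1
    loader = LoadTensor(tensor)
    paths, images, _ = next(iter(loader))
    print(len(paths), images.shape)  # 1 torch.Size([1, 3, 640, 640])

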
def autocast_list(source: list[Any]) -> list[Image.Image | np.ndarray]:
    """Merge a list of sources into a list of numpy arrays or PIL images for Ultralytics prediction."""
    files = []
    for im in source:
        if isinstance(im, (str, Path)):  # filename or uri
            files.append(Image.open(urllib.request.urlopen(im) if str(im).startswith("http") else im))
        elif isinstance(im, (Image.Image, np.ndarray)):  # PIL or np Image
            files.append(im)
        else:
            raise TypeError(
                f"type {type(im).__name__} is not a supported Ultralytics prediction source type. \n"
                f"See https://docs.ultralytics.com/modes/predict for supported source types."
            )

    return files
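

# Illustrative sketch (editor's addition): autocast_list accepts a mixed list of paths,
# URLs, PIL images, and numpy arrays. "bus.jpg" is a placeholder filename and must exist
# locally for this sketch to run.
def _demo_autocast_list() -> None:
    """Merge mixed prediction sources into PIL/numpy form."""
    sources = ["bus.jpg", Image.new("RGB", (32, 32)), np.zeros((32, 32, 3), dtype=np.uint8)]
    files = autocast_list(sources)
    print([type(f).__name__ for f in files])  # e.g. ['JpegImageFile', 'Image', 'ndarray']

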
def get_best_youtube_url(url: str, method: str = "pytube") -> str | None:
    """
    Retrieve the URL of the best quality MP4 video stream from a given YouTube video.

    Args:
        url (str): The URL of the YouTube video.
        method (str): The method to use for extracting video info. Options are "pytube", "pafy", and "yt-dlp".

    Returns:
        (str | None): The URL of the best quality MP4 video stream, or None if no suitable stream is found.

    Examples:
        >>> url = "https://www.youtube.com/watch?v=dQw4w9WgXcQ"
        >>> best_url = get_best_youtube_url(url)
        >>> print(best_url)
        https://rr4---sn-q4flrnek.googlevideo.com/videoplayback?expire=...

    Notes:
        - Requires additional libraries based on the chosen method: pytubefix, pafy, or yt-dlp.
        - The function prioritizes streams with at least 1080p resolution when available.
        - For the "yt-dlp" method, it looks for formats with video codec, no audio, and *.mp4 extension.
    """
    if method == "pytube":
        # Switched from pytube to pytubefix to resolve https://github.com/pytube/pytube/issues/1954
        check_requirements("pytubefix>=6.5.2")
        from pytubefix import YouTube

        streams = YouTube(url).streams.filter(file_extension="mp4", only_video=True)
        streams = sorted(streams, key=lambda s: s.resolution, reverse=True)  # sort streams by resolution
        for stream in streams:
            if stream.resolution and int(stream.resolution[:-1]) >= 1080:  # check if resolution is at least 1080p
                return stream.url

    elif method == "pafy":
        check_requirements(("pafy", "youtube_dl==2020.12.2"))
        import pafy  # noqa

        return pafy.new(url).getbestvideo(preftype="mp4").url

    elif method == "yt-dlp":
        check_requirements("yt-dlp")
        import yt_dlp

        with yt_dlp.YoutubeDL({"quiet": True}) as ydl:
            info_dict = ydl.extract_info(url, download=False)  # extract info
            for f in reversed(info_dict.get("formats", [])):  # reversed because best is usually last
                # Find a format with video codec, no audio, *.mp4 extension, at least 1920x1080 in size
                good_size = (f.get("width") or 0) >= 1920 or (f.get("height") or 0) >= 1080
                if good_size and f["vcodec"] != "none" and f["acodec"] == "none" and f["ext"] == "mp4":
                    return f.get("url")


# Define constants
LOADERS = (LoadStreams, LoadPilAndNumpy, LoadImagesAndVideos, LoadScreenshots)
18
ultralytics/data/scripts/download_weights.sh
Executable file
@@ -0,0 +1,18 @@
#!/bin/bash
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Download latest models from https://github.com/ultralytics/assets/releases
# Example usage: bash ultralytics/data/scripts/download_weights.sh
# parent
# └── weights
#     ├── yolov8n.pt  ← downloads here
#     ├── yolov8s.pt
#     └── ...

python << EOF
from ultralytics.utils.downloads import attempt_download_asset

assets = [f"yolov8{size}{suffix}.pt" for size in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose")]
for x in assets:
    attempt_download_asset(f"weights/{x}")
EOF
61
ultralytics/data/scripts/get_coco.sh
Executable file
@@ -0,0 +1,61 @@
#!/bin/bash
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Download COCO 2017 dataset https://cocodataset.org
# Example usage: bash data/scripts/get_coco.sh
# parent
# ├── ultralytics
# └── datasets
#     └── coco  ← downloads here

# Arguments (optional) Usage: bash data/scripts/get_coco.sh --train --val --test --segments --sama
if [ "$#" -gt 0 ]; then
  for opt in "$@"; do
    case "${opt}" in
      --train) train=true ;;
      --val) val=true ;;
      --test) test=true ;;
      --segments) segments=true ;;
      --sama) sama=true ;;
    esac
  done
else
  train=true
  val=true
  test=false
  segments=false
  sama=false
fi

# Download/unzip labels
d='../datasets' # unzip directory
url=https://github.com/ultralytics/assets/releases/download/v0.0.0/
if [ "$segments" == "true" ]; then
  f='coco2017labels-segments.zip' # 169 MB
elif [ "$sama" == "true" ]; then
  f='coco2017labels-segments-sama.zip' # 199 MB https://www.sama.com/sama-coco-dataset/
else
  f='coco2017labels.zip' # 46 MB
fi
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &

# Download/unzip images
d='../datasets/coco/images' # unzip directory
url=http://images.cocodataset.org/zips/
if [ "$train" == "true" ]; then
  f='train2017.zip' # 19G, 118k images
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
if [ "$val" == "true" ]; then
  f='val2017.zip' # 1G, 5k images
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
if [ "$test" == "true" ]; then
  f='test2017.zip' # 7G, 41k images (optional)
  echo 'Downloading' $url$f '...'
  curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &
fi
wait # finish background tasks
18
ultralytics/data/scripts/get_coco128.sh
Executable file
@@ -0,0 +1,18 @@
#!/bin/bash
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Download COCO128 dataset https://www.kaggle.com/ultralytics/coco128 (first 128 images from COCO train2017)
# Example usage: bash data/scripts/get_coco128.sh
# parent
# ├── ultralytics
# └── datasets
#     └── coco128  ← downloads here

# Download/unzip images and labels
d='../datasets' # unzip directory
url=https://github.com/ultralytics/assets/releases/download/v0.0.0/
f='coco128.zip' # or 'coco128-segments.zip', 68 MB
echo 'Downloading' $url$f ' ...'
curl -L $url$f -o $f -# && unzip -q $f -d $d && rm $f &

wait # finish background tasks
52
ultralytics/data/scripts/get_imagenet.sh
Executable file
@@ -0,0 +1,52 @@
#!/bin/bash
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

# Download ILSVRC2012 ImageNet dataset https://image-net.org
# Example usage: bash data/scripts/get_imagenet.sh
# parent
# ├── ultralytics
# └── datasets
#     └── imagenet  ← downloads here

# Arguments (optional) Usage: bash data/scripts/get_imagenet.sh --train --val
if [ "$#" -gt 0 ]; then
  for opt in "$@"; do
    case "${opt}" in
      --train) train=true ;;
      --val) val=true ;;
    esac
  done
else
  train=true
  val=true
fi

# Make dir
d='../datasets/imagenet' # unzip directory
mkdir -p $d && cd $d

# Download/unzip train
if [ "$train" == "true" ]; then
  wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar # download 138G, 1281167 images
  mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
  tar -xf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
  find . -name "*.tar" | while read NAME; do
    mkdir -p "${NAME%.tar}"
    tar -xf "${NAME}" -C "${NAME%.tar}"
    rm -f "${NAME}"
  done
  cd ..
fi

# Download/unzip val
if [ "$val" == "true" ]; then
  wget https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar # download 6.3G, 50000 images
  mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xf ILSVRC2012_img_val.tar
  wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash # move into subdirs
fi

# Delete corrupted image (optional: PNG under JPEG name that may cause dataloaders to fail)
# rm train/n04266014/n04266014_10835.JPEG

# TFRecords (optional)
# wget https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_lsvrc_2015_synsets.txt
139
ultralytics/data/split.py
Normal file
@@ -0,0 +1,139 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import random
import shutil
from pathlib import Path

from ultralytics.data.utils import IMG_FORMATS, img2label_paths
from ultralytics.utils import DATASETS_DIR, LOGGER, TQDM


def split_classify_dataset(source_dir: str | Path, train_ratio: float = 0.8) -> Path:
    """
    Split classification dataset into train and val directories in a new directory.

    Creates a new directory '{source_dir}_split' with train/val subdirectories, preserving the original class
    structure with an 80/20 split by default.

    Directory structure:
        Before:
            caltech/
            ├── class1/
            │   ├── img1.jpg
            │   ├── img2.jpg
            │   └── ...
            ├── class2/
            │   ├── img1.jpg
            │   └── ...
            └── ...

        After:
            caltech_split/
            ├── train/
            │   ├── class1/
            │   │   ├── img1.jpg
            │   │   └── ...
            │   ├── class2/
            │   │   ├── img1.jpg
            │   │   └── ...
            │   └── ...
            └── val/
                ├── class1/
                │   ├── img2.jpg
                │   └── ...
                ├── class2/
                │   └── ...
                └── ...

    Args:
        source_dir (str | Path): Path to classification dataset root directory.
        train_ratio (float): Ratio for train split, between 0 and 1.

    Returns:
        (Path): Path to the created split directory.

    Examples:
        Split dataset with default 80/20 ratio
        >>> split_classify_dataset("path/to/caltech")

        Split with custom ratio
        >>> split_classify_dataset("path/to/caltech", 0.75)
    """
    source_path = Path(source_dir)
    split_path = Path(f"{source_path}_split")
    train_path, val_path = split_path / "train", split_path / "val"

    # Create directory structure
    split_path.mkdir(exist_ok=True)
    train_path.mkdir(exist_ok=True)
    val_path.mkdir(exist_ok=True)

    # Process class directories
    class_dirs = [d for d in source_path.iterdir() if d.is_dir()]
    total_images = sum(len(list(d.glob("*.*"))) for d in class_dirs)
    stats = f"{len(class_dirs)} classes, {total_images} images"
    LOGGER.info(f"Splitting {source_path} ({stats}) into {train_ratio:.0%} train, {1 - train_ratio:.0%} val...")

    for class_dir in class_dirs:
        # Create class directories
        (train_path / class_dir.name).mkdir(exist_ok=True)
        (val_path / class_dir.name).mkdir(exist_ok=True)

        # Split and copy files
        image_files = list(class_dir.glob("*.*"))
        random.shuffle(image_files)
        split_idx = int(len(image_files) * train_ratio)

        for img in image_files[:split_idx]:
            shutil.copy2(img, train_path / class_dir.name / img.name)

        for img in image_files[split_idx:]:
            shutil.copy2(img, val_path / class_dir.name / img.name)

    LOGGER.info(f"Split complete in {split_path} ✅")
    return split_path


def autosplit(
    path: Path = DATASETS_DIR / "coco8/images",
    weights: tuple[float, float, float] = (0.9, 0.1, 0.0),
    annotated_only: bool = False,
) -> None:
    """
    Automatically split a dataset into train/val/test splits and save the resulting splits into autosplit_*.txt files.

    Args:
        path (Path): Path to images directory.
        weights (tuple): Train, validation, and test split fractions.
        annotated_only (bool): If True, only images with an associated txt file are used.

    Examples:
        Split images with default weights
        >>> from ultralytics.data.split import autosplit
        >>> autosplit()

        Split with custom weights and annotated images only
        >>> autosplit(path="path/to/images", weights=(0.8, 0.15, 0.05), annotated_only=True)
    """
    path = Path(path)  # images dir
    files = sorted(x for x in path.rglob("*.*") if x.suffix[1:].lower() in IMG_FORMATS)  # image files only
    n = len(files)  # number of files
    random.seed(0)  # for reproducibility
    indices = random.choices([0, 1, 2], weights=weights, k=n)  # assign each image to a split

    txt = ["autosplit_train.txt", "autosplit_val.txt", "autosplit_test.txt"]  # 3 txt files
    for x in txt:
        if (path.parent / x).exists():
            (path.parent / x).unlink()  # remove existing

    LOGGER.info(f"Autosplitting images from {path}" + ", using *.txt labeled images only" * annotated_only)
    for i, img in TQDM(zip(indices, files), total=n):
        if not annotated_only or Path(img2label_paths([str(img)])[0]).exists():  # check label
            with open(path.parent / txt[i], "a", encoding="utf-8") as f:
                f.write(f"./{img.relative_to(path.parent).as_posix()}" + "\n")  # add image to txt file


if __name__ == "__main__":
    split_classify_dataset("caltech101")
351
ultralytics/data/split_dota.py
Normal file
@@ -0,0 +1,351 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import itertools
from glob import glob
from math import ceil
from pathlib import Path
from typing import Any

import cv2
import numpy as np
from PIL import Image

from ultralytics.data.utils import exif_size, img2label_paths
from ultralytics.utils import TQDM
from ultralytics.utils.checks import check_requirements


def bbox_iof(polygon1: np.ndarray, bbox2: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Calculate Intersection over Foreground (IoF) between polygons and bounding boxes.

    Args:
        polygon1 (np.ndarray): Polygon coordinates with shape (N, 8).
        bbox2 (np.ndarray): Bounding boxes with shape (M, 4).
        eps (float, optional): Small value to prevent division by zero.

    Returns:
        (np.ndarray): IoF scores with shape (N, 1) or (N, M) if bbox2 is (M, 4).

    Notes:
        Polygon format: [x1, y1, x2, y2, x3, y3, x4, y4].
        Bounding box format: [x_min, y_min, x_max, y_max].
    """
    check_requirements("shapely>=2.0.0")
    from shapely.geometry import Polygon

    polygon1 = polygon1.reshape(-1, 4, 2)
    lt_point = np.min(polygon1, axis=-2)  # left-top
    rb_point = np.max(polygon1, axis=-2)  # right-bottom
    bbox1 = np.concatenate([lt_point, rb_point], axis=-1)

    lt = np.maximum(bbox1[:, None, :2], bbox2[..., :2])
    rb = np.minimum(bbox1[:, None, 2:], bbox2[..., 2:])
    wh = np.clip(rb - lt, 0, np.inf)
    h_overlaps = wh[..., 0] * wh[..., 1]

    left, top, right, bottom = (bbox2[..., i] for i in range(4))
    polygon2 = np.stack([left, top, right, top, right, bottom, left, bottom], axis=-1).reshape(-1, 4, 2)

    sg_polys1 = [Polygon(p) for p in polygon1]
    sg_polys2 = [Polygon(p) for p in polygon2]
    overlaps = np.zeros(h_overlaps.shape)
    for p in zip(*np.nonzero(h_overlaps)):
        overlaps[p] = sg_polys1[p[0]].intersection(sg_polys2[p[-1]]).area
    unions = np.array([p.area for p in sg_polys1], dtype=np.float32)
    unions = unions[..., None]

    unions = np.clip(unions, eps, np.inf)
    outputs = overlaps / unions
    if outputs.ndim == 1:
        outputs = outputs[..., None]
    return outputs
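

# Worked example (editor's addition): IoF is the intersection area divided by the
# polygon's own area, so a polygon fully inside a box scores 1.0 and a quarter overlap
# scores 0.25.
def _demo_bbox_iof() -> None:
    """Compute IoF of one 10x10 square polygon against two boxes."""
    polygon = np.array([[0, 0, 10, 0, 10, 10, 0, 10]], dtype=np.float32)  # x1y1...x4y4
    boxes = np.array([[0, 0, 10, 10], [5, 5, 15, 15]], dtype=np.float32)  # xyxy
    print(bbox_iof(polygon, boxes))  # [[1.0, 0.25]]

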
def load_yolo_dota(data_root: str, split: str = "train") -> list[dict[str, Any]]:
    """
    Load DOTA dataset annotations and image information.

    Args:
        data_root (str): Data root directory.
        split (str, optional): The split data set, could be 'train' or 'val'.

    Returns:
        (list[dict[str, Any]]): List of annotation dictionaries containing image information.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    assert split in {"train", "val"}, f"Split must be 'train' or 'val', not {split}."
    im_dir = Path(data_root) / "images" / split
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(Path(data_root) / "images" / split / "*"))
    lb_files = img2label_paths(im_files)
    annos = []
    for im_file, lb_file in zip(im_files, lb_files):
        w, h = exif_size(Image.open(im_file))
        with open(lb_file, encoding="utf-8") as f:
            lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
        lb = np.array(lb, dtype=np.float32)
        annos.append(dict(ori_size=(h, w), label=lb, filepath=im_file))
    return annos


def get_windows(
    im_size: tuple[int, int],
    crop_sizes: tuple[int, ...] = (1024,),
    gaps: tuple[int, ...] = (200,),
    im_rate_thr: float = 0.6,
    eps: float = 0.01,
) -> np.ndarray:
    """
    Get the coordinates of sliding windows for image cropping.

    Args:
        im_size (tuple[int, int]): Original image size, (H, W).
        crop_sizes (tuple[int, ...], optional): Crop size of windows.
        gaps (tuple[int, ...], optional): Gap between crops.
        im_rate_thr (float, optional): Threshold of windows areas divided by image areas.
        eps (float, optional): Epsilon value for math operations.

    Returns:
        (np.ndarray): Array of window coordinates with shape (N, 4) where each row is [x_start, y_start, x_stop, y_stop].
    """
    h, w = im_size
    windows = []
    for crop_size, gap in zip(crop_sizes, gaps):
        assert crop_size > gap, f"invalid crop_size gap pair [{crop_size} {gap}]"
        step = crop_size - gap

        xn = 1 if w <= crop_size else ceil((w - crop_size) / step + 1)
        xs = [step * i for i in range(xn)]
        if len(xs) > 1 and xs[-1] + crop_size > w:
            xs[-1] = w - crop_size

        yn = 1 if h <= crop_size else ceil((h - crop_size) / step + 1)
        ys = [step * i for i in range(yn)]
        if len(ys) > 1 and ys[-1] + crop_size > h:
            ys[-1] = h - crop_size

        start = np.array(list(itertools.product(xs, ys)), dtype=np.int64)
        stop = start + crop_size
        windows.append(np.concatenate([start, stop], axis=1))
    windows = np.concatenate(windows, axis=0)

    im_in_wins = windows.copy()
    im_in_wins[:, 0::2] = np.clip(im_in_wins[:, 0::2], 0, w)
    im_in_wins[:, 1::2] = np.clip(im_in_wins[:, 1::2], 0, h)
    im_areas = (im_in_wins[:, 2] - im_in_wins[:, 0]) * (im_in_wins[:, 3] - im_in_wins[:, 1])
    win_areas = (windows[:, 2] - windows[:, 0]) * (windows[:, 3] - windows[:, 1])
    im_rates = im_areas / win_areas
    if not (im_rates > im_rate_thr).any():
        max_rate = im_rates.max()
        im_rates[abs(im_rates - max_rate) < eps] = 1
    return windows[im_rates > im_rate_thr]
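

# Worked example (editor's addition): with the default 1024 crop and 200 gap the window
# step is 824 px, so a 1500x1500 image yields a 2x2 grid whose last row/column is shifted
# back to keep every window fully inside the image.
def _demo_get_windows() -> None:
    """Print the sliding windows computed for a 1500x1500 image."""
    windows = get_windows((1500, 1500))
    print(windows)  # 4 windows: x/y starts at 0 and 476, each 1024 px wide

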
def get_window_obj(anno: dict[str, Any], windows: np.ndarray, iof_thr: float = 0.7) -> list[np.ndarray]:
    """Get objects for each window based on IoF threshold."""
    h, w = anno["ori_size"]
    label = anno["label"]
    if len(label):
        label[:, 1::2] *= w
        label[:, 2::2] *= h
        iofs = bbox_iof(label[:, 1:], windows)
        # Unnormalized and misaligned coordinates
        return [(label[iofs[:, i] >= iof_thr]) for i in range(len(windows))]  # window_anns
    else:
        return [np.zeros((0, 9), dtype=np.float32) for _ in range(len(windows))]  # window_anns


def crop_and_save(
    anno: dict[str, Any],
    windows: np.ndarray,
    window_objs: list[np.ndarray],
    im_dir: str,
    lb_dir: str,
    allow_background_images: bool = True,
) -> None:
    """
    Crop images and save new labels for each window.

    Args:
        anno (dict[str, Any]): Annotation dict, including 'filepath', 'label', 'ori_size' as its keys.
        windows (np.ndarray): Array of window coordinates with shape (N, 4).
        window_objs (list[np.ndarray]): A list of labels inside each window.
        im_dir (str): The output directory path of images.
        lb_dir (str): The output directory path of labels.
        allow_background_images (bool, optional): Whether to include background images without labels.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    im = cv2.imread(anno["filepath"])
    name = Path(anno["filepath"]).stem
    for i, window in enumerate(windows):
        x_start, y_start, x_stop, y_stop = window.tolist()
        new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
        patch_im = im[y_start:y_stop, x_start:x_stop]
        ph, pw = patch_im.shape[:2]

        label = window_objs[i]
        if len(label) or allow_background_images:
            cv2.imwrite(str(Path(im_dir) / f"{new_name}.jpg"), patch_im)
        if len(label):
            label[:, 1::2] -= x_start
            label[:, 2::2] -= y_start
            label[:, 1::2] /= pw
            label[:, 2::2] /= ph

            with open(Path(lb_dir) / f"{new_name}.txt", "w", encoding="utf-8") as f:
                for lb in label:
                    formatted_coords = [f"{coord:.6g}" for coord in lb[1:]]
                    f.write(f"{int(lb[0])} {' '.join(formatted_coords)}\n")


def split_images_and_labels(
    data_root: str,
    save_dir: str,
    split: str = "train",
    crop_sizes: tuple[int, ...] = (1024,),
    gaps: tuple[int, ...] = (200,),
) -> None:
    """
    Split both images and labels for a given dataset split.

    Args:
        data_root (str): Root directory of the dataset.
        save_dir (str): Directory to save the split dataset.
        split (str, optional): The split data set, could be 'train' or 'val'.
        crop_sizes (tuple[int, ...], optional): Tuple of crop sizes.
        gaps (tuple[int, ...], optional): Tuple of gaps between crops.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - split
                - labels
                    - split
        and the output directory structure is:
            - save_dir
                - images
                    - split
                - labels
                    - split
    """
    im_dir = Path(save_dir) / "images" / split
    im_dir.mkdir(parents=True, exist_ok=True)
    lb_dir = Path(save_dir) / "labels" / split
    lb_dir.mkdir(parents=True, exist_ok=True)

    annos = load_yolo_dota(data_root, split=split)
    for anno in TQDM(annos, total=len(annos), desc=split):
        windows = get_windows(anno["ori_size"], crop_sizes, gaps)
        window_objs = get_window_obj(anno, windows)
        crop_and_save(anno, windows, window_objs, str(im_dir), str(lb_dir))


def split_trainval(
    data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: tuple[float, ...] = (1.0,)
) -> None:
    """
    Split train and val sets of DOTA dataset with multiple scaling rates.

    Args:
        data_root (str): Root directory of the dataset.
        save_dir (str): Directory to save the split dataset.
        crop_size (int, optional): Base crop size.
        gap (int, optional): Base gap between crops.
        rates (tuple[float, ...], optional): Scaling rates for crop_size and gap.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
        and the output directory structure is:
            - save_dir
                - images
                    - train
                    - val
                - labels
                    - train
                    - val
    """
    crop_sizes, gaps = [], []
    for r in rates:
        crop_sizes.append(int(crop_size / r))
        gaps.append(int(gap / r))
    for split in {"train", "val"}:
        split_images_and_labels(data_root, save_dir, split, crop_sizes, gaps)


def split_test(
    data_root: str, save_dir: str, crop_size: int = 1024, gap: int = 200, rates: tuple[float, ...] = (1.0,)
) -> None:
    """
    Split test set of DOTA dataset, labels are not included within this set.

    Args:
        data_root (str): Root directory of the dataset.
        save_dir (str): Directory to save the split dataset.
        crop_size (int, optional): Base crop size.
        gap (int, optional): Base gap between crops.
        rates (tuple[float, ...], optional): Scaling rates for crop_size and gap.

    Notes:
        The directory structure assumed for the DOTA dataset:
            - data_root
                - images
                    - test
        and the output directory structure is:
            - save_dir
                - images
                    - test
    """
    crop_sizes, gaps = [], []
    for r in rates:
        crop_sizes.append(int(crop_size / r))
        gaps.append(int(gap / r))
    save_dir = Path(save_dir) / "images" / "test"
    save_dir.mkdir(parents=True, exist_ok=True)

    im_dir = Path(data_root) / "images" / "test"
    assert im_dir.exists(), f"Can't find {im_dir}, please check your data root."
    im_files = glob(str(im_dir / "*"))
    for im_file in TQDM(im_files, total=len(im_files), desc="test"):
        w, h = exif_size(Image.open(im_file))
        windows = get_windows((h, w), crop_sizes=crop_sizes, gaps=gaps)
        im = cv2.imread(im_file)
        name = Path(im_file).stem
        for window in windows:
            x_start, y_start, x_stop, y_stop = window.tolist()
            new_name = f"{name}__{x_stop - x_start}__{x_start}___{y_start}"
            patch_im = im[y_start:y_stop, x_start:x_stop]
            cv2.imwrite(str(save_dir / f"{new_name}.jpg"), patch_im)


if __name__ == "__main__":
    split_trainval(data_root="DOTAv2", save_dir="DOTAv2-split")
    split_test(data_root="DOTAv2", save_dir="DOTAv2-split")
807
ultralytics/data/utils.py
Normal file
@@ -0,0 +1,807 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import json
import os
import random
import subprocess
import time
import zipfile
from multiprocessing.pool import ThreadPool
from pathlib import Path
from tarfile import is_tarfile
from typing import Any

import cv2
import numpy as np
from PIL import Image, ImageOps

from ultralytics.nn.autobackend import check_class_names
from ultralytics.utils import (
    DATASETS_DIR,
    LOGGER,
    NUM_THREADS,
    ROOT,
    SETTINGS_FILE,
    TQDM,
    YAML,
    clean_url,
    colorstr,
    emojis,
    is_dir_writeable,
)
from ultralytics.utils.checks import check_file, check_font, is_ascii
from ultralytics.utils.downloads import download, safe_download, unzip_file
from ultralytics.utils.ops import segments2boxes

HELP_URL = "See https://docs.ultralytics.com/datasets for dataset formatting guidance."
IMG_FORMATS = {"bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "pfm", "heic"}  # image suffixes
VID_FORMATS = {"asf", "avi", "gif", "m4v", "mkv", "mov", "mp4", "mpeg", "mpg", "ts", "wmv", "webm"}  # video suffixes
FORMATS_HELP_MSG = f"Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"


def img2label_paths(img_paths: list[str]) -> list[str]:
    """Convert image paths to label paths by replacing 'images' with 'labels' and extension with '.txt'."""
    sa, sb = f"{os.sep}images{os.sep}", f"{os.sep}labels{os.sep}"  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit(".", 1)[0] + ".txt" for x in img_paths]
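

# Worked example (editor's addition): the last '/images/' path component is swapped for
# '/labels/' and the file suffix for '.txt' (separator shown POSIX-style for readability).
def _demo_img2label_paths() -> None:
    """Map an image path to its YOLO label path."""
    print(img2label_paths([f"datasets{os.sep}coco8{os.sep}images{os.sep}train{os.sep}0001.jpg"]))
    # ['datasets/coco8/labels/train/0001.txt'] on POSIX systems

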
def check_file_speeds(
    files: list[str], threshold_ms: float = 10, threshold_mb: float = 50, max_files: int = 5, prefix: str = ""
):
    """
    Check dataset file access speed and provide performance feedback.

    This function tests the access speed of dataset files by measuring ping (stat call) time and read speed.
    It samples up to 5 files from the provided list and warns if access times exceed the threshold.

    Args:
        files (list[str]): List of file paths to check for access speed.
        threshold_ms (float, optional): Threshold in milliseconds for ping time warnings.
        threshold_mb (float, optional): Threshold in megabytes per second for read speed warnings.
        max_files (int, optional): The maximum number of files to check.
        prefix (str, optional): Prefix string to add to log messages.

    Examples:
        >>> from pathlib import Path
        >>> image_files = list(Path("dataset/images").glob("*.jpg"))
        >>> check_file_speeds(image_files, threshold_ms=15)
    """
    if not files:
        LOGGER.warning(f"{prefix}Image speed checks: No files to check")
        return

    # Sample files (max 5)
    files = random.sample(files, min(max_files, len(files)))

    # Test ping (stat time)
    ping_times = []
    file_sizes = []
    read_speeds = []

    for f in files:
        try:
            # Measure ping (stat call)
            start = time.perf_counter()
            file_size = os.stat(f).st_size
            ping_times.append((time.perf_counter() - start) * 1000)  # ms
            file_sizes.append(file_size)

            # Measure read speed
            start = time.perf_counter()
            with open(f, "rb") as file_obj:
                _ = file_obj.read()
            read_time = time.perf_counter() - start
            if read_time > 0:  # avoid division by zero
                read_speeds.append(file_size / (1 << 20) / read_time)  # MB/s
        except Exception:
            pass

    if not ping_times:
        LOGGER.warning(f"{prefix}Image speed checks: failed to access files")
        return

    # Calculate stats with uncertainties
    avg_ping = np.mean(ping_times)
    std_ping = np.std(ping_times, ddof=1) if len(ping_times) > 1 else 0
    size_msg = f", size: {np.mean(file_sizes) / (1 << 10):.1f} KB"
    ping_msg = f"ping: {avg_ping:.1f}±{std_ping:.1f} ms"

    if read_speeds:
        avg_speed = np.mean(read_speeds)
        std_speed = np.std(read_speeds, ddof=1) if len(read_speeds) > 1 else 0
        speed_msg = f", read: {avg_speed:.1f}±{std_speed:.1f} MB/s"
    else:
        avg_speed = float("inf")  # no read-speed samples; fall back to the ping check alone
        speed_msg = ""

    if avg_ping < threshold_ms or avg_speed < threshold_mb:
        LOGGER.info(f"{prefix}Fast image access ✅ ({ping_msg}{speed_msg}{size_msg})")
    else:
        LOGGER.warning(
            f"{prefix}Slow image access detected ({ping_msg}{speed_msg}{size_msg}). "
            f"Use local storage instead of remote/mounted storage for better performance. "
            f"See https://docs.ultralytics.com/guides/model-training-tips/"
        )


def get_hash(paths: list[str]) -> str:
    """Return a single hash value of a list of paths (files or dirs)."""
    size = 0
    for p in paths:
        try:
            size += os.stat(p).st_size
        except OSError:
            continue
    h = __import__("hashlib").sha256(str(size).encode())  # hash sizes
    h.update("".join(paths).encode())  # hash paths
    return h.hexdigest()  # return hash


def exif_size(img: Image.Image) -> tuple[int, int]:
    """Return exif-corrected PIL size."""
    s = img.size  # (width, height)
    if img.format == "JPEG":  # only support JPEG images
        try:
            if exif := img.getexif():
                rotation = exif.get(274, None)  # the EXIF key for the orientation tag is 274
                if rotation in {6, 8}:  # rotation 270 or 90
                    s = s[1], s[0]
        except Exception:
            pass
    return s


def verify_image(args: tuple) -> tuple:
    """Verify one image."""
    (im_file, cls), prefix = args
    # Number (found, corrupt), message
    nf, nc, msg = 0, 0, ""
    try:
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"Invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}{im_file}: corrupt JPEG restored and saved"
        nf = 1
    except Exception as e:
        nc = 1
        msg = f"{prefix}{im_file}: ignoring corrupt image/label: {e}"
    return (im_file, cls), nf, nc, msg


def verify_image_label(args: tuple) -> list:
    """Verify one image-label pair."""
    im_file, lb_file, prefix, keypoint, num_cls, nkpt, ndim, single_cls = args
    # Number (missing, found, empty, corrupt), message, segments, keypoints
    nm, nf, ne, nc, msg, segments, keypoints = 0, 0, 0, 0, "", [], None
    try:
        # Verify images
        im = Image.open(im_file)
        im.verify()  # PIL verify
        shape = exif_size(im)  # image size
        shape = (shape[1], shape[0])  # hw
        assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
        assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}. {FORMATS_HELP_MSG}"
        if im.format.lower() in {"jpg", "jpeg"}:
            with open(im_file, "rb") as f:
                f.seek(-2, 2)
                if f.read() != b"\xff\xd9":  # corrupt JPEG
                    ImageOps.exif_transpose(Image.open(im_file)).save(im_file, "JPEG", subsampling=0, quality=100)
                    msg = f"{prefix}{im_file}: corrupt JPEG restored and saved"

        # Verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file, encoding="utf-8") as f:
                lb = [x.split() for x in f.read().strip().splitlines() if len(x)]
                if any(len(x) > 6 for x in lb) and (not keypoint):  # is segment
                    classes = np.array([x[0] for x in lb], dtype=np.float32)
                    segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in lb]  # (cls, xy1...)
                    lb = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
                lb = np.array(lb, dtype=np.float32)
            if nl := len(lb):
                if keypoint:
                    assert lb.shape[1] == (5 + nkpt * ndim), f"labels require {(5 + nkpt * ndim)} columns each"
                    points = lb[:, 5:].reshape(-1, ndim)[:, :2]
                else:
                    assert lb.shape[1] == 5, f"labels require 5 columns, {lb.shape[1]} columns detected"
                    points = lb[:, 1:]
                # Coordinate points check with 1% tolerance
                assert points.max() <= 1.01, f"non-normalized or out of bounds coordinates {points[points > 1.01]}"
                assert lb.min() >= -0.01, f"negative class labels or coordinate {lb[lb < -0.01]}"

                # All labels
                max_cls = 0 if single_cls else lb[:, 0].max()  # max label count
                assert max_cls < num_cls, (
                    f"Label class {int(max_cls)} exceeds dataset class count {num_cls}. "
                    f"Possible class labels are 0-{num_cls - 1}"
                )
                _, i = np.unique(lb, axis=0, return_index=True)
                if len(i) < nl:  # duplicate row check
                    lb = lb[i]  # remove duplicates
                    if segments:
                        segments = [segments[x] for x in i]
                    msg = f"{prefix}{im_file}: {nl - len(i)} duplicate labels removed"
            else:
                ne = 1  # label empty
                lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            lb = np.zeros((0, (5 + nkpt * ndim) if keypoint else 5), dtype=np.float32)
        if keypoint:
            keypoints = lb[:, 5:].reshape(-1, nkpt, ndim)
            if ndim == 2:
                kpt_mask = np.where((keypoints[..., 0] < 0) | (keypoints[..., 1] < 0), 0.0, 1.0).astype(np.float32)
                keypoints = np.concatenate([keypoints, kpt_mask[..., None]], axis=-1)  # (nl, nkpt, 3)
        lb = lb[:, :5]
        return im_file, lb, shape, segments, keypoints, nm, nf, ne, nc, msg
    except Exception as e:
        nc = 1
        msg = f"{prefix}{im_file}: ignoring corrupt image/label: {e}"
        return [None, None, None, None, None, nm, nf, ne, nc, msg]


def visualize_image_annotations(image_path: str, txt_path: str, label_map: dict[int, str]):
    """
    Visualize YOLO annotations (bounding boxes and class labels) on an image.

    This function reads an image and its corresponding annotation file in YOLO format, then draws bounding boxes
    around detected objects and labels them with their respective class names. The bounding box colors are assigned
    based on the class ID, and the text color is dynamically adjusted for readability, depending on the background
    color's luminance.

    Args:
        image_path (str): The path to the image file to annotate; any format supported by PIL works.
        txt_path (str): The path to the annotation file in YOLO format, with one line per object.
        label_map (dict[int, str]): A dictionary that maps class IDs (integers) to class labels (strings).

    Examples:
        >>> label_map = {0: "cat", 1: "dog", 2: "bird"}  # must cover every class ID present in the annotations
        >>> visualize_image_annotations("path/to/image.jpg", "path/to/annotations.txt", label_map)
    """
    import matplotlib.pyplot as plt

    from ultralytics.utils.plotting import colors

    img = np.array(Image.open(image_path))
    img_height, img_width = img.shape[:2]
    annotations = []
    with open(txt_path, encoding="utf-8") as file:
        for line in file:
            class_id, x_center, y_center, width, height = map(float, line.split())
            x = (x_center - width / 2) * img_width
            y = (y_center - height / 2) * img_height
            w = width * img_width
            h = height * img_height
            annotations.append((x, y, w, h, int(class_id)))
    _, ax = plt.subplots(1)  # Plot the image and annotations
    for x, y, w, h, label in annotations:
        color = tuple(c / 255 for c in colors(label, True))  # Get and normalize the RGB color
        rect = plt.Rectangle((x, y), w, h, linewidth=2, edgecolor=color, facecolor="none")  # Create a rectangle
        ax.add_patch(rect)
        luminance = 0.2126 * color[0] + 0.7152 * color[1] + 0.0722 * color[2]  # Relative luminance of the box color
        ax.text(x, y - 5, label_map[label], color="white" if luminance < 0.5 else "black", backgroundcolor=color)
    ax.imshow(img)
    plt.show()

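
# Worked example of the YOLO-to-pixel conversion performed above (values are illustrative):
# for a 640x480 image and the annotation line "0 0.50 0.50 0.25 0.50",
#   x = (0.50 - 0.25 / 2) * 640 = 240  # left edge
#   y = (0.50 - 0.50 / 2) * 480 = 120  # top edge
#   w = 0.25 * 640 = 160  # box width in pixels
#   h = 0.50 * 480 = 240  # box height in pixels
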
def polygon2mask(
    imgsz: tuple[int, int], polygons: list[np.ndarray], color: int = 1, downsample_ratio: int = 1
) -> np.ndarray:
    """
    Convert a list of polygons to a binary mask of the specified image size.

    Args:
        imgsz (tuple[int, int]): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygon coordinate arrays. Stacked, they form an array of shape
            (N, M), where N is the number of polygons and M is the (even) number of flattened x, y pixel
            coordinates per polygon.
        color (int, optional): The color value to fill in the polygons on the mask.
        downsample_ratio (int, optional): Factor by which to downsample the mask.

    Returns:
        (np.ndarray): A binary mask of the specified image size with the polygons filled in.
    """
    mask = np.zeros(imgsz, dtype=np.uint8)
    polygons = np.asarray(polygons, dtype=np.int32)
    polygons = polygons.reshape((polygons.shape[0], -1, 2))
    cv2.fillPoly(mask, polygons, color=color)
    nh, nw = (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio)
    # Note: filling with fillPoly at full resolution and only then resizing keeps the loss calculation
    # consistent with the mask_ratio=1 case
    return cv2.resize(mask, (nw, nh))

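
# A small sketch of polygon2mask on a synthetic square (coordinates are in pixels, flattened as
# x1, y1, x2, y2, ...); shapes shown follow from the function defined above.
#
#   square = np.array([10, 10, 50, 10, 50, 50, 10, 50], dtype=np.float32)
#   mask = polygon2mask((64, 64), [square], color=1, downsample_ratio=2)
#   assert mask.shape == (32, 32)  # (height // 2, width // 2) after the final resize
#   assert mask.max() == 1  # the filled region carries the requested color value
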
def polygons2masks(
    imgsz: tuple[int, int], polygons: list[np.ndarray], color: int, downsample_ratio: int = 1
) -> np.ndarray:
    """
    Convert a list of polygons to a set of binary masks of the specified image size.

    Args:
        imgsz (tuple[int, int]): The size of the image as (height, width).
        polygons (list[np.ndarray]): A list of polygon coordinate arrays. Stacked, they form an array of shape
            (N, M), where N is the number of polygons and M is the (even) number of flattened x, y pixel
            coordinates per polygon.
        color (int): The color value to fill in the polygons on the masks.
        downsample_ratio (int, optional): Factor by which to downsample each mask.

    Returns:
        (np.ndarray): A set of binary masks of the specified image size with the polygons filled in.
    """
    return np.array([polygon2mask(imgsz, [x.reshape(-1)], color, downsample_ratio) for x in polygons])

def polygons2masks_overlap(
    imgsz: tuple[int, int], segments: list[np.ndarray], downsample_ratio: int = 1
) -> tuple[np.ndarray, np.ndarray]:
    """
    Convert polygons to a single overlap mask where each instance is encoded as a distinct integer value.

    Args:
        imgsz (tuple[int, int]): The size of the image as (height, width).
        segments (list[np.ndarray]): A list of flattened polygon coordinate arrays, one per instance.
        downsample_ratio (int, optional): Factor by which to downsample the mask.

    Returns:
        masks (np.ndarray): A (height // downsample_ratio, width // downsample_ratio) mask in which pixel value
            i + 1 marks the i-th largest instance.
        index (np.ndarray): Indices that sort the input segments by descending area.
    """
    masks = np.zeros(
        (imgsz[0] // downsample_ratio, imgsz[1] // downsample_ratio),
        dtype=np.int32 if len(segments) > 255 else np.uint8,
    )
    areas = []
    ms = []
    for segment in segments:
        mask = polygon2mask(
            imgsz,
            [segment.reshape(-1)],
            downsample_ratio=downsample_ratio,
            color=1,
        )
        ms.append(mask.astype(masks.dtype))
        areas.append(mask.sum())
    areas = np.asarray(areas)
    index = np.argsort(-areas)
    ms = np.array(ms)[index]
    for i in range(len(segments)):
        mask = ms[i] * (i + 1)
        masks = masks + mask
        masks = np.clip(masks, a_min=0, a_max=i + 1)
    return masks, index

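
# Sketch of the overlap encoding implemented above: instances are sorted by area, largest first,
# and pixel values 1..n index into that sorted order. Coordinates are illustrative.
#
#   big = np.array([0, 0, 60, 0, 60, 60, 0, 60], dtype=np.float32)  # large square
#   small = np.array([20, 20, 40, 20, 40, 40, 20, 40], dtype=np.float32)  # overlaps it
#   masks, index = polygons2masks_overlap((64, 64), [small, big])
#   # masks is a single (64, 64) array: the larger square is drawn first (value 1) and the
#   # smaller one overwrites the overlap region (value 2); index is array([1, 0]), mapping the
#   # sorted order back to the original segment order.
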
def find_dataset_yaml(path: Path) -> Path:
    """
    Find and return the YAML file associated with a Detect, Segment or Pose dataset.

    This function searches for a YAML file at the root level of the provided directory first, and if not found, it
    performs a recursive search. It prefers YAML files that have the same stem as the provided path.

    Args:
        path (Path): The directory path to search for the YAML file.

    Returns:
        (Path): The path of the found YAML file.
    """
    files = list(path.glob("*.yaml")) or list(path.rglob("*.yaml"))  # try root level first and then recursive
    assert files, f"No YAML file found in '{path.resolve()}'"
    if len(files) > 1:
        files = [f for f in files if f.stem == path.stem]  # prefer *.yaml files that match
    assert len(files) == 1, f"Expected 1 YAML file in '{path.resolve()}', but found {len(files)}.\n{files}"
    return files[0]

def check_det_dataset(dataset: str, autodownload: bool = True) -> dict[str, Any]:
    """
    Download, verify, and/or unzip a dataset if not found locally.

    This function checks the availability of a specified dataset, and if not found, it has the option to download and
    unzip the dataset. It then reads and parses the accompanying YAML data, ensuring key requirements are met and also
    resolves paths related to the dataset.

    Args:
        dataset (str): Path to the dataset or dataset descriptor (like a YAML file).
        autodownload (bool, optional): Whether to automatically download the dataset if not found.

    Returns:
        (dict[str, Any]): Parsed dataset information and paths.
    """
    file = check_file(dataset)

    # Download (optional)
    extract_dir = ""
    if zipfile.is_zipfile(file) or is_tarfile(file):
        new_dir = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)
        file = find_dataset_yaml(DATASETS_DIR / new_dir)
        extract_dir, autodownload = file.parent, False

    # Read YAML
    data = YAML.load(file, append_filename=True)  # dictionary

    # Checks
    for k in "train", "val":
        if k not in data:
            if k != "val" or "validation" not in data:
                raise SyntaxError(
                    emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs.")
                )
            LOGGER.warning("Renaming data YAML 'validation' key to 'val' to match YOLO format.")
            data["val"] = data.pop("validation")  # replace 'validation' key with 'val' key
    if "names" not in data and "nc" not in data:
        raise SyntaxError(
            emojis(f"{dataset} 'names' and 'nc' keys missing ❌.\nEither 'names' or 'nc' is required in all data YAMLs.")
        )
    if "names" in data and "nc" in data and len(data["names"]) != data["nc"]:
        raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))
    if "names" not in data:
        data["names"] = [f"class_{i}" for i in range(data["nc"])]
    else:
        data["nc"] = len(data["names"])

    data["names"] = check_class_names(data["names"])
    data["channels"] = data.get("channels", 3)  # get image channels, default to 3

    # Resolve paths
    path = Path(extract_dir or data.get("path") or Path(data.get("yaml_file", "")).parent)  # dataset root
    if not path.exists() and not path.is_absolute():
        path = (DATASETS_DIR / path).resolve()  # path relative to DATASETS_DIR

    # Set paths
    data["path"] = path  # dataset root dir, also used by download scripts
    for k in "train", "val", "test", "minival":
        if data.get(k):  # prepend path
            if isinstance(data[k], str):
                x = (path / data[k]).resolve()
                if not x.exists() and data[k].startswith("../"):
                    x = (path / data[k][3:]).resolve()
                data[k] = str(x)
            else:
                data[k] = [str((path / x).resolve()) for x in data[k]]

    # Parse YAML
    val, s = (data.get(x) for x in ("val", "download"))
    if val:
        val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path
        if not all(x.exists() for x in val):
            name = clean_url(dataset)  # dataset name with URL auth stripped
            LOGGER.info("")
            m = f"Dataset '{name}' images not found, missing path '{[x for x in val if not x.exists()][0]}'"
            if s and autodownload:
                LOGGER.warning(m)
            else:
                m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_FILE}'"
                raise FileNotFoundError(m)
            t = time.time()
            r = None  # download return code, None or 0 indicates success
            if s.startswith("http") and s.endswith(".zip"):  # URL
                safe_download(url=s, dir=DATASETS_DIR, delete=True)
            elif s.startswith("bash "):  # bash script
                LOGGER.info(f"Running {s} ...")
                subprocess.run(s.split(), check=True)
            else:  # python script
                exec(s, {"yaml": data})
            dt = f"({round(time.time() - t, 1)}s)"
            s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in {0, None} else f"failure {dt} ❌"
            LOGGER.info(f"Dataset download {s}\n")
    check_font("Arial.ttf" if is_ascii(data["names"]) else "Arial.Unicode.ttf")  # download fonts

    return data  # dictionary

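
# Illustrative sketch of what check_det_dataset above consumes: a minimal data YAML with the
# required 'train'/'val' keys plus 'names' (or 'nc'). File name, paths and classes below are
# examples only, not a shipped dataset.
#
#   # data.yaml
#   # path: ../datasets/mydata
#   # train: images/train
#   # val: images/val
#   # names:
#   #   0: person
#   #   1: car
#
#   data = check_det_dataset("data.yaml", autodownload=False)
#   print(data["nc"], data["names"], data["channels"])  # 2 {0: 'person', 1: 'car'} 3
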
def check_cls_dataset(dataset: str | Path, split: str = "") -> dict[str, Any]:
    """
    Check a classification dataset such as ImageNet.

    This function accepts a `dataset` name and attempts to retrieve the corresponding dataset information.
    If the dataset is not found locally, it attempts to download the dataset from the internet and save it locally.

    Args:
        dataset (str | Path): The name of the dataset.
        split (str, optional): The split of the dataset. Either 'val', 'test', or ''.

    Returns:
        (dict[str, Any]): A dictionary containing the following keys:

            - 'train' (Path): The directory path containing the training set of the dataset.
            - 'val' (Path): The directory path containing the validation set of the dataset.
            - 'test' (Path): The directory path containing the test set of the dataset.
            - 'nc' (int): The number of classes in the dataset.
            - 'names' (dict[int, str]): A dictionary of class names in the dataset.
    """
    # Download (optional if dataset=https://file.zip is passed directly)
    if str(dataset).startswith(("http:/", "https:/")):
        dataset = safe_download(dataset, dir=DATASETS_DIR, unzip=True, delete=False)
    elif str(dataset).endswith((".zip", ".tar", ".gz")):
        file = check_file(dataset)
        dataset = safe_download(file, dir=DATASETS_DIR, unzip=True, delete=False)

    dataset = Path(dataset)
    data_dir = (dataset if dataset.is_dir() else (DATASETS_DIR / dataset)).resolve()
    if not data_dir.is_dir():
        if data_dir.suffix != "":
            raise ValueError(
                f'Classification datasets must be a directory (data="path/to/dir") not a file (data="{dataset}"), '
                "See https://docs.ultralytics.com/datasets/classify/"
            )
        LOGGER.info("")
        LOGGER.warning(f"Dataset not found, missing path {data_dir}, attempting download...")
        t = time.time()
        if str(dataset) == "imagenet":
            subprocess.run(["bash", str(ROOT / "data/scripts/get_imagenet.sh")], check=True)
        else:
            url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{dataset}.zip"
            download(url, dir=data_dir.parent)
        LOGGER.info(f"Dataset download success ✅ ({time.time() - t:.1f}s), saved to {colorstr('bold', data_dir)}\n")
    train_set = data_dir / "train"
    if not train_set.is_dir():
        LOGGER.warning(f"Dataset 'split=train' not found at {train_set}")
        if image_files := list(data_dir.rglob("*.jpg")) + list(data_dir.rglob("*.png")):
            from ultralytics.data.split import split_classify_dataset

            LOGGER.info(f"Found {len(image_files)} images in subdirectories. Attempting to split...")
            data_dir = split_classify_dataset(data_dir, train_ratio=0.8)
            train_set = data_dir / "train"
        else:
            LOGGER.error(f"No images found in {data_dir} or its subdirectories.")
    val_set = (
        data_dir / "val"
        if (data_dir / "val").exists()
        else data_dir / "validation"
        if (data_dir / "validation").exists()
        else data_dir / "valid"
        if (data_dir / "valid").exists()
        else None
    )  # data/val, data/validation or data/valid
    test_set = data_dir / "test" if (data_dir / "test").exists() else None  # data/test
    if split == "val" and not val_set:
        LOGGER.warning("Dataset 'split=val' not found, using 'split=test' instead.")
        val_set = test_set
    elif split == "test" and not test_set:
        LOGGER.warning("Dataset 'split=test' not found, using 'split=val' instead.")
        test_set = val_set

    nc = len([x for x in (data_dir / "train").glob("*") if x.is_dir()])  # number of classes
    names = [x.name for x in (data_dir / "train").iterdir() if x.is_dir()]  # class names list
    names = dict(enumerate(sorted(names)))

    # Print to console
    for k, v in {"train": train_set, "val": val_set, "test": test_set}.items():
        prefix = f"{colorstr(f'{k}:')} {v}..."
        if v is None:
            LOGGER.info(prefix)
        else:
            files = [path for path in v.rglob("*.*") if path.suffix[1:].lower() in IMG_FORMATS]
            nf = len(files)  # number of files
            nd = len({file.parent for file in files})  # number of directories
            if nf == 0:
                if k == "train":
                    raise FileNotFoundError(f"{dataset} '{k}:' no training images found")
                else:
                    LOGGER.warning(f"{prefix} found {nf} images in {nd} classes (no images found)")
            elif nd != nc:
                LOGGER.error(f"{prefix} found {nf} images in {nd} classes (requires {nc} classes, not {nd})")
            else:
                LOGGER.info(f"{prefix} found {nf} images in {nd} classes ✅ ")

    return {"train": train_set, "val": val_set, "test": test_set, "nc": nc, "names": names, "channels": 3}

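
# Folder layout expected by check_cls_dataset (sketch; names are illustrative): one
# sub-directory per class under each split, from which 'nc' and 'names' are derived.
#
#   mydata/
#   ├── train/
#   │   ├── cat/  # cat images
#   │   └── dog/  # dog images
#   └── val/      # 'validation' and 'valid' are also recognized, see above
#       ├── cat/
#       └── dog/
#
#   data = check_cls_dataset("mydata")  # -> {'train': ..., 'val': ..., 'nc': 2, 'names': {0: 'cat', 1: 'dog'}, ...}
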
class HUBDatasetStats:
    """
    A class for generating HUB dataset JSON and the `-hub` dataset directory.

    Args:
        path (str): Path to data.yaml or data.zip (with data.yaml inside data.zip).
        task (str): Dataset task. Options are 'detect', 'segment', 'pose', 'classify', 'obb'.
        autodownload (bool): Attempt to download dataset if not found locally.

    Attributes:
        task (str): Dataset task type.
        hub_dir (Path): Directory path for HUB dataset files.
        im_dir (Path): Directory path for compressed images.
        stats (dict): Statistics dictionary containing dataset information.
        data (dict): Dataset configuration data.

    Methods:
        get_json: Return dataset JSON for Ultralytics HUB.
        process_images: Compress images for Ultralytics HUB.

    Note:
        Download *.zip files from https://github.com/ultralytics/hub/tree/main/example_datasets,
        e.g. https://github.com/ultralytics/hub/raw/main/example_datasets/coco8.zip for coco8.zip.

    Examples:
        >>> from ultralytics.data.utils import HUBDatasetStats
        >>> stats = HUBDatasetStats("path/to/coco8.zip", task="detect")  # detect dataset
        >>> stats = HUBDatasetStats("path/to/coco8-seg.zip", task="segment")  # segment dataset
        >>> stats = HUBDatasetStats("path/to/coco8-pose.zip", task="pose")  # pose dataset
        >>> stats = HUBDatasetStats("path/to/dota8.zip", task="obb")  # OBB dataset
        >>> stats = HUBDatasetStats("path/to/imagenet10.zip", task="classify")  # classification dataset
        >>> stats.get_json(save=True)
        >>> stats.process_images()
    """

    def __init__(self, path: str = "coco8.yaml", task: str = "detect", autodownload: bool = False):
        """Initialize HUBDatasetStats with a dataset path, task type, and optional autodownload."""
        path = Path(path).resolve()
        LOGGER.info(f"Starting HUB dataset checks for {path}...")

        self.task = task  # detect, segment, pose, classify, obb
        if self.task == "classify":
            unzip_dir = unzip_file(path)
            data = check_cls_dataset(unzip_dir)
            data["path"] = unzip_dir
        else:  # detect, segment, pose, obb
            _, data_dir, yaml_path = self._unzip(Path(path))
            try:
                # Load YAML with checks
                data = YAML.load(yaml_path)
                data["path"] = ""  # strip path since YAML should be in dataset root for all HUB datasets
                YAML.save(yaml_path, data)
                data = check_det_dataset(yaml_path, autodownload)  # dict
                data["path"] = data_dir  # YAML path should be set to '' (relative) or parent (absolute)
            except Exception as e:
                raise Exception("error/HUB/dataset_stats/init") from e

        self.hub_dir = Path(f"{data['path']}-hub")
        self.im_dir = self.hub_dir / "images"
        self.stats = {"nc": len(data["names"]), "names": list(data["names"].values())}  # statistics dictionary
        self.data = data

    @staticmethod
    def _unzip(path: Path) -> tuple[bool, str | None, Path]:
        """Unzip data.zip and return (zipped, data_dir, yaml_path)."""
        if not str(path).endswith(".zip"):  # path is data.yaml
            return False, None, path
        unzip_dir = unzip_file(path, path=path.parent)
        assert unzip_dir.is_dir(), (
            f"Error unzipping {path}, {unzip_dir} not found. path/to/abc.zip MUST unzip to path/to/abc/"
        )
        return True, str(unzip_dir), find_dataset_yaml(unzip_dir)  # zipped, data_dir, yaml_path

    def _hub_ops(self, f: str):
        """Save a compressed image for HUB previews."""
        compress_one_image(f, self.im_dir / Path(f).name)  # save to dataset-hub

    def get_json(self, save: bool = False, verbose: bool = False) -> dict:
        """Return dataset JSON for Ultralytics HUB."""

        def _round(labels):
            """Update labels to integer class and 4 decimal place floats."""
            if self.task == "detect":
                coordinates = labels["bboxes"]
            elif self.task in {"segment", "obb"}:  # Segment and OBB use segments. OBB segments are normalized xyxyxyxy
                coordinates = [x.flatten() for x in labels["segments"]]
            elif self.task == "pose":
                n, nk, nd = labels["keypoints"].shape
                coordinates = np.concatenate((labels["bboxes"], labels["keypoints"].reshape(n, nk * nd)), 1)
            else:
                raise ValueError(f"Undefined dataset task={self.task}.")
            zipped = zip(labels["cls"], coordinates)
            return [[int(c[0]), *(round(float(x), 4) for x in points)] for c, points in zipped]

        for split in "train", "val", "test":
            self.stats[split] = None  # predefine
            path = self.data.get(split)

            # Check split
            if path is None:  # no split
                continue
            files = [f for f in Path(path).rglob("*.*") if f.suffix[1:].lower() in IMG_FORMATS]  # image files in split
            if not files:  # no images
                continue

            # Get dataset statistics
            if self.task == "classify":
                from torchvision.datasets import ImageFolder  # scope for faster 'import ultralytics'

                dataset = ImageFolder(self.data[split])

                x = np.zeros(len(dataset.classes)).astype(int)
                for im in dataset.imgs:
                    x[im[1]] += 1

                self.stats[split] = {
                    "instance_stats": {"total": len(dataset), "per_class": x.tolist()},
                    "image_stats": {"total": len(dataset), "unlabelled": 0, "per_class": x.tolist()},
                    "labels": [{Path(k).name: v} for k, v in dataset.imgs],
                }
            else:
                from ultralytics.data import YOLODataset

                dataset = YOLODataset(img_path=self.data[split], data=self.data, task=self.task)
                x = np.array(
                    [
                        np.bincount(label["cls"].astype(int).flatten(), minlength=self.data["nc"])
                        for label in TQDM(dataset.labels, total=len(dataset), desc="Statistics")
                    ]
                )  # shape (num_images, nc)
                self.stats[split] = {
                    "instance_stats": {"total": int(x.sum()), "per_class": x.sum(0).tolist()},
                    "image_stats": {
                        "total": len(dataset),
                        "unlabelled": int(np.all(x == 0, 1).sum()),
                        "per_class": (x > 0).sum(0).tolist(),
                    },
                    "labels": [{Path(k).name: _round(v)} for k, v in zip(dataset.im_files, dataset.labels)],
                }

        # Save, print and return
        if save:
            self.hub_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/
            stats_path = self.hub_dir / "stats.json"
            LOGGER.info(f"Saving {stats_path.resolve()}...")
            with open(stats_path, "w", encoding="utf-8") as f:
                json.dump(self.stats, f)  # save stats.json
        if verbose:
            LOGGER.info(json.dumps(self.stats, indent=2, sort_keys=False))
        return self.stats

    def process_images(self) -> Path:
        """Compress images for Ultralytics HUB."""
        from ultralytics.data import YOLODataset

        self.im_dir.mkdir(parents=True, exist_ok=True)  # makes dataset-hub/images/
        for split in "train", "val", "test":
            if self.data.get(split) is None:
                continue
            dataset = YOLODataset(img_path=self.data[split], data=self.data)
            with ThreadPool(NUM_THREADS) as pool:
                for _ in TQDM(pool.imap(self._hub_ops, dataset.im_files), total=len(dataset), desc=f"{split} images"):
                    pass
        LOGGER.info(f"Done. All images saved to {self.im_dir}")
        return self.im_dir

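
# Sketch of the stats dictionary produced by HUBDatasetStats.get_json above for a detect
# dataset (all values are illustrative):
#
#   {
#       "nc": 2,
#       "names": ["person", "car"],
#       "train": {
#           "instance_stats": {"total": 8, "per_class": [5, 3]},
#           "image_stats": {"total": 4, "unlabelled": 0, "per_class": [4, 2]},
#           "labels": [{"im1.jpg": [[0, 0.5, 0.5, 0.2, 0.2]]}, ...],
#       },
#       "val": {...},
#       "test": None,  # splits without images remain None
#   }
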
def compress_one_image(f: str, f_new: str | None = None, max_dim: int = 1920, quality: int = 50):
    """
    Compress a single image file to a reduced size while preserving its aspect ratio and quality using either the
    Python Imaging Library (PIL) or OpenCV. If the input image is smaller than the maximum dimension, it is not
    resized.

    Args:
        f (str): The path to the input image file.
        f_new (str, optional): The path to the output image file. If not specified, the input file will be overwritten.
        max_dim (int, optional): The maximum dimension (width or height) of the output image.
        quality (int, optional): The image compression quality as a percentage.

    Examples:
        >>> from pathlib import Path
        >>> from ultralytics.data.utils import compress_one_image
        >>> for f in Path("path/to/dataset").rglob("*.jpg"):
        ...     compress_one_image(f)
    """
    try:  # use PIL
        Image.MAX_IMAGE_PIXELS = None  # fix DecompressionBombError, allow images > ~178.9 million pixels
        im = Image.open(f)
        if im.mode in {"RGBA", "LA"}:  # convert to RGB if needed (for JPEG)
            im = im.convert("RGB")
        r = max_dim / max(im.height, im.width)  # ratio
        if r < 1.0:  # image too large
            im = im.resize((int(im.width * r), int(im.height * r)))
        im.save(f_new or f, "JPEG", quality=quality, optimize=True)  # save
    except Exception as e:  # use OpenCV
        LOGGER.warning(f"HUB ops PIL failure {f}: {e}")
        im = cv2.imread(f)
        im_height, im_width = im.shape[:2]
        r = max_dim / max(im_height, im_width)  # ratio
        if r < 1.0:  # image too large
            im = cv2.resize(im, (int(im_width * r), int(im_height * r)), interpolation=cv2.INTER_AREA)
        cv2.imwrite(str(f_new or f), im)

def load_dataset_cache_file(path: Path) -> dict:
    """Load an Ultralytics *.cache dictionary from path."""
    import gc

    gc.disable()  # reduce pickle load time https://github.com/ultralytics/ultralytics/pull/1585
    cache = np.load(str(path), allow_pickle=True).item()  # load dict
    gc.enable()
    return cache

def save_dataset_cache_file(prefix: str, path: Path, x: dict, version: str):
    """Save an Ultralytics dataset *.cache dictionary x to path."""
    x["version"] = version  # add cache version
    if is_dir_writeable(path.parent):
        if path.exists():
            path.unlink()  # remove *.cache file if exists
        with open(str(path), "wb") as file:  # context manager here fixes windows async np.save bug
            np.save(file, x)
        LOGGER.info(f"{prefix}New cache created: {path}")
    else:
        LOGGER.warning(f"{prefix}Cache directory {path.parent} is not writeable, cache not saved.")
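
# Round-trip sketch for the two cache helpers above (path and contents are illustrative):
#
#   cache_path = Path("labels.cache")
#   save_dataset_cache_file("prefix: ", cache_path, {"labels": [], "hash": "abc123"}, version="1.0")
#   cache = load_dataset_cache_file(cache_path)
#   assert cache["version"] == "1.0"  # version stamp added on save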