init commit

2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
from .model import YOLO, YOLOE, YOLOWorld
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
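
A quick smoke test of the exports above (a sketch; assumes the package is installed and that the pretrained yolo11n.pt weights can be auto-downloaded):

from ultralytics.models.yolo import YOLO
model = YOLO("yolo11n.pt")  # pretrained detection model, fetched on first use
results = model("https://ultralytics.com/images/bus.jpg")  # single-image inference
print(results[0].boxes.xyxy)  # detected boxes in xyxy pixel coordinates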

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
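
These three classes back the predict/train/val modes of the high-level wrapper; a minimal sketch of that correspondence, using the imagenet10 toy dataset referenced in the docstrings below:

from ultralytics import YOLO
model = YOLO("yolo11n-cls.pt")
model.train(data="imagenet10", epochs=1)  # dispatched to ClassificationTrainer
model.val(data="imagenet10")  # dispatched to ClassificationValidator
model.predict("https://ultralytics.com/images/bus.jpg")  # dispatched to ClassificationPredictor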

View File

@@ -0,0 +1,93 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import cv2
import torch
from PIL import Image
from ultralytics.data.augment import classify_transforms
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
class ClassificationPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a classification model.
This predictor handles the specific requirements of classification models, including preprocessing images
and postprocessing predictions to generate classification results.
Attributes:
args (dict): Configuration arguments for the predictor.
Methods:
preprocess: Convert input images to model-compatible format.
postprocess: Process model predictions into Results objects.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
>>> predictor = ClassificationPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
tasks. It ensures the task is set to 'classify' regardless of input configuration.
Args:
cfg (dict): Default configuration dictionary containing prediction settings.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be executed during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "classify"
def setup_source(self, source):
"""Set up source and inference mode and classify transforms."""
super().setup_source(source)
updated = (
self.model.model.transforms.transforms[0].size != max(self.imgsz)
if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
else False
)
self.transforms = (
classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
)
def preprocess(self, img):
"""Convert input images to model-compatible tensor format with appropriate normalization."""
if not isinstance(img, torch.Tensor):
img = torch.stack(
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""
Process predictions to return Results objects with classification probabilities.
Args:
preds (torch.Tensor): Raw predictions from the model.
img (torch.Tensor): Input images after preprocessing.
orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
Returns:
(list[Results]): List of Results objects containing classification results for each image.
"""
if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
return [
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]
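
A usage sketch for the predictor above, assuming a local bus.jpg: preprocess() handles the BGR-to-RGB conversion and classification transforms, and postprocess() wraps the raw scores into Results with probs:

import cv2
from ultralytics.models.yolo.classify import ClassificationPredictor
predictor = ClassificationPredictor(overrides=dict(model="yolo11n-cls.pt"))
results = predictor(cv2.imread("bus.jpg"))  # accepts BGR ndarrays directly
print(results[0].probs.top5, results[0].probs.top5conf)  # top-5 class indices and confidences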

View File

@@ -0,0 +1,223 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from typing import Any
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
"""
A trainer class extending BaseTrainer for training image classification models.
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
and torchvision models with comprehensive dataset handling and validation.
Attributes:
model (ClassificationModel): The classification model to be trained.
data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
loss_names (list[str]): Names of the loss functions used during training.
validator (ClassificationValidator): Validator instance for model evaluation.
Methods:
set_model_attributes: Set the model's class names from the loaded dataset.
get_model: Return a modified PyTorch model configured for training.
setup_model: Load, create or download model for classification.
build_dataset: Create a ClassificationDataset instance.
get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
preprocess_batch: Preprocess a batch of images and classes.
progress_string: Return a formatted string showing training progress.
get_validator: Return an instance of ClassificationValidator.
label_loss_items: Return a loss dict with labelled training loss items.
final_eval: Evaluate trained model and save validation results.
plot_training_samples: Plot training samples with their annotations.
Examples:
Initialize and train a classification model
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a ClassificationTrainer object.
Args:
cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list[Any], optional): List of callback functions to be executed during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "classify"
if overrides.get("imgsz") is None:
overrides["imgsz"] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return a modified PyTorch model configured for training YOLO classification.
Args:
cfg (Any, optional): Model configuration.
weights (Any, optional): Pre-trained model weights.
verbose (bool, optional): Whether to display model information.
Returns:
(ClassificationModel): Configured PyTorch model for classification.
"""
model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, "reset_parameters"):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
"""
Load, create or download model for classification tasks.
Returns:
(Any): Model checkpoint if applicable, otherwise None.
"""
import torchvision # scope for faster 'import ultralytics'
if str(self.model) in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[self.model](
weights="IMAGENET1K_V1" if self.args.pretrained else None
)
ckpt = None
else:
ckpt = super().setup_model()
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
"""
Create a ClassificationDataset instance given an image path and mode.
Args:
img_path (str): Path to the dataset images.
mode (str, optional): Dataset mode ('train', 'val', or 'test').
batch (Any, optional): Batch information (unused in this implementation).
Returns:
(ClassificationDataset): Dataset for the specified mode.
"""
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Return PyTorch DataLoader with transforms to preprocess images.
Args:
dataset_path (str): Path to the dataset.
batch_size (int, optional): Number of images per batch.
rank (int, optional): Process rank for distributed training.
mode (str, optional): 'train', 'val', or 'test' mode.
Returns:
(torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
"""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
# Attach inference transforms
if mode != "train":
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader
def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Preprocess a batch of images and classes."""
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
return batch
def progress_string(self) -> str:
"""Return a formatted string showing training progress."""
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def get_validator(self):
"""Return an instance of ClassificationValidator for validation."""
self.loss_names = ["loss"]
return yolo.classify.ClassificationValidator(
self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
"""
Return a loss dict with labelled training loss items tensor.
Args:
loss_items (torch.Tensor, optional): Loss tensor items.
prefix (str, optional): Prefix to prepend to loss names.
Returns:
keys (list[str]): List of loss keys if loss_items is None.
loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
"""
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f"\nValidating {f}...")
self.validator.args.data = self.args.data
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
"""
Plot training samples with their annotations.
Args:
batch (dict[str, torch.Tensor]): Batch containing images and class labels.
ni (int): Number of iterations.
"""
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
plot_images(
labels=batch,
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
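
Because setup_model() falls back to torchvision.models when the model string is not a YOLO checkpoint, a torchvision backbone can be trained directly; a sketch assuming the imagenet10 dataset:

from ultralytics.models.yolo.classify import ClassificationTrainer
args = dict(model="resnet18", data="imagenet10", epochs=1, imgsz=224)
trainer = ClassificationTrainer(overrides=args)  # setup_model() resolves "resnet18" via torchvision
trainer.train()  # reshape_outputs() adapts the classifier head to data["nc"] classes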

View File

@@ -0,0 +1,214 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a classification model.
This validator handles the validation process for classification models, including metrics calculation,
confusion matrix generation, and visualization of results.
Attributes:
targets (list[torch.Tensor]): Ground truth class labels.
pred (list[torch.Tensor]): Model predictions.
metrics (ClassifyMetrics): Object to calculate and store classification metrics.
names (dict): Mapping of class indices to class names.
nc (int): Number of classes.
confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
Methods:
get_desc: Return a formatted string summarizing classification metrics.
init_metrics: Initialize confusion matrix, class names, and tracking containers.
preprocess: Preprocess input batch by moving data to device.
update_metrics: Update running metrics with model predictions and batch targets.
finalize_metrics: Finalize metrics including confusion matrix and processing speed.
postprocess: Extract the primary prediction from model output.
get_stats: Calculate and return a dictionary of metrics.
build_dataset: Create a ClassificationDataset instance for validation.
get_dataloader: Build and return a data loader for classification validation.
print_results: Print evaluation metrics for the classification model.
plot_val_samples: Plot validation image samples with their ground truth labels.
plot_predictions: Plot images with their predicted class labels.
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
Notes:
Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (str | Path, optional): Directory to save results.
args (dict, optional): Arguments containing model and validation configuration.
_callbacks (list, optional): List of callback functions to be called during validation.
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = "classify"
self.metrics = ClassifyMetrics()
def get_desc(self) -> str:
"""Return a formatted string summarizing classification metrics."""
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
def init_metrics(self, model: torch.nn.Module) -> None:
"""Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
self.names = model.names
self.nc = len(model.names)
self.pred = []
self.targets = []
self.confusion_matrix = ConfusionMatrix(names=model.names)
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
return batch
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
"""
Update running metrics with model predictions and batch targets.
Args:
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
batch (dict): Batch data containing images and class labels.
Notes:
This method appends the top-N predictions (sorted by confidence in descending order) to the
prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
"""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
self.targets.append(batch["cls"].type(torch.int32).cpu())
def finalize_metrics(self) -> None:
"""
Finalize metrics including confusion matrix and processing speed.
Notes:
This method processes the accumulated predictions and targets to generate the confusion matrix,
optionally plots it, and updates the metrics object with speed information.
Examples:
>>> validator = ClassificationValidator()
>>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
>>> validator.targets = [torch.tensor([0])] # Ground truth class
>>> validator.finalize_metrics()
>>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
"""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.save_dir = self.save_dir
self.metrics.confusion_matrix = self.confusion_matrix
def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
"""Extract the primary prediction from model output if it's in a list or tuple format."""
return preds[0] if isinstance(preds, (list, tuple)) else preds
def get_stats(self) -> dict[str, float]:
"""Calculate and return a dictionary of metrics by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path: str) -> ClassificationDataset:
"""Create a ClassificationDataset instance for validation."""
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Build and return a data loader for classification validation.
Args:
dataset_path (str | Path): Path to the dataset directory.
batch_size (int): Number of samples per batch.
Returns:
(torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
"""
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self) -> None:
"""Print evaluation metrics for the classification model."""
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples with their ground truth labels.
Args:
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = ClassificationValidator()
>>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
>>> validator.plot_val_samples(batch, 0)
"""
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
plot_images(
labels=batch,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
"""
Plot images with their predicted class labels and save the visualization.
Args:
batch (dict[str, Any]): Batch data containing images and other information.
preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = ClassificationValidator()
>>> batch = {"img": torch.rand(16, 3, 224, 224)}
>>> preds = torch.rand(16, 10) # 16 images, 10 classes
>>> validator.plot_predictions(batch, preds, 0)
"""
batched_preds = dict(
img=batch["img"],
batch_idx=torch.arange(batch["img"].shape[0]),
cls=torch.argmax(preds, dim=1),
)
plot_images(
batched_preds,
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
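
The validator also runs standalone, as in the class docstring; a sketch showing the metrics it produces:

from ultralytics.models.yolo.classify import ClassificationValidator
validator = ClassificationValidator(args=dict(model="yolo11n-cls.pt", data="imagenet10"))
stats = validator()  # preprocess -> update_metrics -> finalize_metrics -> get_stats
print(validator.metrics.top1, validator.metrics.top5)  # top-1 / top-5 accuracy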

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"

View File

@@ -0,0 +1,125 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import nms, ops
class DetectionPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a detection model.
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
with bounding boxes and class predictions.
Attributes:
args (namespace): Configuration arguments for the predictor.
model (nn.Module): The detection model used for inference.
batch (list): Batch of images and metadata for processing.
Methods:
postprocess: Process raw model predictions into detection results.
construct_results: Build Results objects from processed predictions.
construct_result: Create a single Result object from a prediction.
get_obj_feats: Extract object features from the feature maps.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.detect import DetectionPredictor
>>> args = dict(model="yolo11n.pt", source=ASSETS)
>>> predictor = DetectionPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def postprocess(self, preds, img, orig_imgs, **kwargs):
"""
Post-process predictions and return a list of Results objects.
This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
further analysis.
Args:
preds (torch.Tensor): Raw predictions from the model.
img (torch.Tensor): Processed input image tensor in model input format.
orig_imgs (torch.Tensor | list): Original input images before preprocessing.
**kwargs (Any): Additional keyword arguments.
Returns:
(list): List of Results objects containing the post-processed predictions.
Examples:
>>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
>>> results = predictor.predict("path/to/image.jpg")
>>> processed_results = predictor.postprocess(preds, img, orig_imgs)
"""
save_feats = getattr(self, "_feats", None) is not None
preds = nms.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
self.args.classes,
self.args.agnostic_nms,
max_det=self.args.max_det,
nc=0 if self.args.task == "detect" else len(self.model.names),
end2end=getattr(self.model, "end2end", False),
rotated=self.args.task == "obb",
return_idxs=save_feats,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
if save_feats:
obj_feats = self.get_obj_feats(self._feats, preds[1])
preds = preds[0]
results = self.construct_results(preds, img, orig_imgs, **kwargs)
if save_feats:
for r, f in zip(results, obj_feats):
r.feats = f # add object features to results
return results
def get_obj_feats(self, feat_maps, idxs):
"""Extract object features from the feature maps."""
import torch
s = min(x.shape[1] for x in feat_maps) # find shortest vector length
obj_feats = torch.cat(
[x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
) # mean reduce all vectors to same length
return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
def construct_results(self, preds, img, orig_imgs):
"""
Construct a list of Results objects from model predictions.
Args:
preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
img (torch.Tensor): Batch of preprocessed images used for inference.
orig_imgs (list[np.ndarray]): List of original images before preprocessing.
Returns:
(list[Results]): List of Results objects containing detection information for each image.
"""
return [
self.construct_result(pred, img, orig_img, img_path)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct a single Results object from one image prediction.
Args:
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
img (torch.Tensor): Preprocessed image tensor used for inference.
orig_img (np.ndarray): Original image before preprocessing.
img_path (str): Path to the original image file.
Returns:
(Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
"""
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
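
A usage sketch for the detection predictor, using the bundled ASSETS images; postprocess() applies NMS with the conf/iou/classes settings passed through overrides:

from ultralytics.models.yolo.detect import DetectionPredictor
from ultralytics.utils import ASSETS
predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt", conf=0.5, classes=[0]))  # keep class 0 (person) only
for result in predictor(ASSETS / "bus.jpg"):
    for box in result.boxes:
        print(int(box.cls), float(box.conf), box.xyxy.tolist())  # class id, confidence, xyxy box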

View File

@@ -0,0 +1,236 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import math
import random
from copy import copy
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.patches import override_configs
from ultralytics.utils.plotting import plot_images, plot_labels
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
for object detection including dataset building, data loading, preprocessing, and model configuration.
Attributes:
model (DetectionModel): The YOLO detection model being trained.
data (dict): Dictionary containing dataset information including class names and number of classes.
loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
Methods:
build_dataset: Build YOLO dataset for training or validation.
get_dataloader: Construct and return dataloader for the specified mode.
preprocess_batch: Preprocess a batch of images by scaling and converting to float.
set_model_attributes: Set model attributes based on dataset information.
get_model: Return a YOLO detection model.
get_validator: Return a validator for model evaluation.
label_loss_items: Return a loss dictionary with labeled training loss items.
progress_string: Return a formatted string of training progress.
plot_training_samples: Plot training samples with their annotations.
plot_training_labels: Create a labeled training plot of the YOLO model.
auto_batch: Calculate optimal batch size based on model memory requirements.
Examples:
>>> from ultralytics.models.yolo.detect import DetectionTrainer
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
>>> trainer = DetectionTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a DetectionTrainer object for training a YOLO object detection model.
Args:
cfg (dict, optional): Default configuration dictionary containing training parameters.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be executed during training.
"""
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches; this is for 'rect' mode.
Returns:
(Dataset): YOLO dataset object configured for the specified mode.
"""
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Construct and return dataloader for the specified mode.
Args:
dataset_path (str): Path to the dataset.
batch_size (int): Number of images per batch.
rank (int): Process rank for distributed training.
mode (str): 'train' for training dataloader, 'val' for validation dataloader.
Returns:
(DataLoader): PyTorch dataloader object.
"""
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == "train"
if getattr(dataset, "rect", False) and shuffle:
LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
return build_dataloader(
dataset,
batch=batch_size,
workers=self.args.workers if mode == "train" else self.args.workers * 2,
shuffle=shuffle,
rank=rank,
drop_last=self.args.compile and mode == "train",
)
def preprocess_batch(self, batch: dict) -> dict:
"""
Preprocess a batch of images by scaling and converting to float.
Args:
batch (dict): Dictionary containing batch data with 'img' tensor.
Returns:
(dict): Preprocessed batch with normalized images.
"""
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = batch["img"].float() / 255
if self.args.multi_scale:
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
) # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
def set_model_attributes(self):
"""Set model attributes based on dataset information."""
# nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data["nc"] # attach number of classes to model
self.model.names = self.data["names"] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
"""
Return a YOLO detection model.
Args:
cfg (str, optional): Path to model configuration file.
weights (str, optional): Path to model weights.
verbose (bool): Whether to display model information.
Returns:
(DetectionModel): YOLO detection model.
"""
model = DetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return a DetectionValidator for YOLO model validation."""
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.detect.DetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
"""
Return a loss dict with labeled training loss items tensor.
Args:
loss_items (list[float], optional): List of loss values.
prefix (str): Prefix for keys in the returned dictionary.
Returns:
(dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
"""
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
else:
return keys
def progress_string(self):
"""Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot training samples with their annotations.
Args:
batch (dict[str, Any]): Dictionary containing batch data.
ni (int): Number of iterations.
"""
plot_images(
labels=batch,
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
def auto_batch(self):
"""
Get optimal batch size by calculating memory occupation of model.
Returns:
(int): Optimal batch size.
"""
with override_configs(self.args, overrides={"cache": False}) as self.args:
train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
del train_dataset # free memory
return super().auto_batch(max_num_obj)
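
A training sketch mirroring the class docstring; with multi_scale enabled, preprocess_batch() randomly rescales each batch between 0.5x and 1.5x of imgsz in stride-sized steps:

from ultralytics.models.yolo.detect import DetectionTrainer
args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3, multi_scale=True)
trainer = DetectionTrainer(overrides=args)
trainer.train()  # final_eval() strips optimizers and re-validates the best checkpoint
print(trainer.metrics)  # metrics dict from the DetectionValidator returned by get_validator()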

View File

@@ -0,0 +1,495 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER, nms, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import plot_images
class DetectionValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a detection model.
This class implements validation functionality specific to object detection tasks, including metrics calculation,
prediction processing, and visualization of results.
Attributes:
is_coco (bool): Whether the dataset is COCO.
is_lvis (bool): Whether the dataset is LVIS.
class_map (list[int]): Mapping from model class indices to dataset class indices.
metrics (DetMetrics): Object detection metrics calculator.
iouv (torch.Tensor): IoU thresholds for mAP calculation.
niou (int): Number of IoU thresholds.
lb (list[Any]): List for storing ground truth labels for hybrid saving.
jdict (list[dict[str, Any]]): List for storing JSON detection results.
stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.
Examples:
>>> from ultralytics.models.yolo.detect import DetectionValidator
>>> args = dict(model="yolo11n.pt", data="coco8.yaml")
>>> validator = DetectionValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize detection validator with necessary variables and settings.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (Path, optional): Directory to save results.
args (dict[str, Any], optional): Arguments for the validator.
_callbacks (list[Any], optional): List of callback functions.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.is_coco = False
self.is_lvis = False
self.class_map = None
self.args.task = "detect"
self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
self.metrics = DetMetrics()
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""
Preprocess batch of images for YOLO validation.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
Returns:
(dict[str, Any]): Preprocessed batch.
"""
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
return batch
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO detection validation.
Args:
model (torch.nn.Module): Model to validate.
"""
val = self.data.get(self.args.split, "") # validation path
self.is_coco = (
isinstance(val, str)
and "coco" in val
and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
) # is COCO
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
self.names = model.names
self.nc = len(model.names)
self.end2end = getattr(model, "end2end", False)
self.seen = 0
self.jdict = []
self.metrics.names = model.names
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
def get_desc(self) -> str:
"""Return a formatted string summarizing class metrics of YOLO model."""
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Apply Non-maximum suppression to prediction outputs.
Args:
preds (torch.Tensor): Raw predictions from the model.
Returns:
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
'bboxes', 'conf', 'cls', and 'extra' tensors.
"""
outputs = nms.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
nc=0 if self.args.task == "detect" else self.nc,
multi_label=True,
agnostic=self.args.single_cls or self.args.agnostic_nms,
max_det=self.args.max_det,
end2end=self.end2end,
rotated=self.args.task == "obb",
)
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch of images and annotations for validation.
Args:
si (int): Batch index.
batch (dict[str, Any]): Batch data containing images and annotations.
Returns:
(dict[str, Any]): Prepared batch with processed annotations.
"""
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if cls.shape[0]:
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
return {
"cls": cls,
"bboxes": bbox,
"ori_shape": ori_shape,
"imgsz": imgsz,
"ratio_pad": ratio_pad,
"im_file": batch["im_file"][si],
}
def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Prepare predictions for evaluation against ground truth.
Args:
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
Returns:
(dict[str, torch.Tensor]): Prepared predictions in native space.
"""
if self.args.single_cls:
pred["cls"] *= 0
return pred
def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
"""
Update metrics with new predictions and ground truth.
Args:
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
batch (dict[str, Any]): Batch data containing ground truth.
"""
for si, pred in enumerate(preds):
self.seen += 1
pbatch = self._prepare_batch(si, batch)
predn = self._prepare_pred(pred)
cls = pbatch["cls"].cpu().numpy()
no_pred = predn["cls"].shape[0] == 0
self.metrics.update_stats(
{
**self._process_batch(predn, pbatch),
"target_cls": cls,
"target_img": np.unique(cls),
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
}
)
# Evaluate
if self.args.plots:
self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
if self.args.visualize:
self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)
if no_pred:
continue
# Save
if self.args.save_json or self.args.save_txt:
predn_scaled = self.scale_preds(predn, pbatch)
if self.args.save_json:
self.pred_to_json(predn_scaled, pbatch)
if self.args.save_txt:
self.save_one_txt(
predn_scaled,
self.args.save_conf,
pbatch["ori_shape"],
self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
)
def finalize_metrics(self) -> None:
"""Set final values for metrics speed and confusion matrix."""
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir
def get_stats(self) -> dict[str, Any]:
"""
Calculate and return metrics statistics.
Returns:
(dict[str, Any]): Dictionary containing metrics results.
"""
self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
self.metrics.clear_stats()
return self.metrics.results_dict
def print_results(self) -> None:
"""Print training/validation set metrics per class."""
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
if self.metrics.nt_per_class.sum() == 0:
LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(
pf
% (
self.names[c],
self.metrics.nt_per_image[c],
self.metrics.nt_per_class[c],
*self.metrics.class_result(i),
)
)
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
Returns:
(dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
"""
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
iou = box_iou(batch["bboxes"], preds["bboxes"])
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches; this is for `rect`.
Returns:
(Dataset): YOLO dataset.
"""
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Construct and return dataloader.
Args:
dataset_path (str): Path to the dataset.
batch_size (int): Size of each batch.
Returns:
(torch.utils.data.DataLoader): Dataloader for validation.
"""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
return build_dataloader(
dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
)
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
ni (int): Batch index.
"""
plot_images(
labels=batch,
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(
self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
) -> None:
"""
Plot predicted bounding boxes on input images and save the result.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
ni (int): Batch index.
max_det (int, optional): Maximum number of detections to plot.
"""
# TODO: optimize this
for i, pred in enumerate(preds):
pred["batch_idx"] = torch.ones_like(pred["conf"]) * i # add batch index to predictions
keys = preds[0].keys()
max_det = max_det or self.args.max_det
batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
# TODO: fix this
batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4]) # convert to xywh format
plot_images(
images=batch["img"],
labels=batched_preds,
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO detections to a txt file in normalized coordinates in a specific format.
Args:
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
save_conf (bool): Whether to save confidence scores.
shape (tuple[int, int]): Shape of the original image (height, width).
file (Path): File path to save the detections.
"""
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Serialize YOLO predictions to COCO json format.
Args:
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Examples:
>>> result = {
... "image_id": 42,
... "file_name": "42.jpg",
... "category_id": 18,
... "bbox": [258.15, 41.29, 348.26, 243.78],
... "score": 0.236,
... }
"""
path = Path(pbatch["im_file"])
stem = path.stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn["bboxes"]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
self.jdict.append(
{
"image_id": image_id,
"file_name": path.name,
"category_id": self.class_map[int(c)],
"bbox": [round(x, 3) for x in b],
"score": round(s, 5),
}
)
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**predn,
"bboxes": ops.scale_boxes(
pbatch["imgsz"],
predn["bboxes"].clone(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""
Evaluate YOLO output in JSON format and return performance statistics.
Args:
stats (dict[str, Any]): Current statistics dictionary.
Returns:
(dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
"""
pred_json = self.save_dir / "predictions.json" # predictions
anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
return self.coco_evaluate(stats, pred_json, anno_json)
def coco_evaluate(
self,
stats: dict[str, Any],
pred_json: str,
anno_json: str,
iou_types: str | list[str] = "bbox",
suffix: str | list[str] = "Box",
) -> dict[str, Any]:
"""
Evaluate COCO/LVIS metrics using faster-coco-eval library.
Performs evaluation using the faster-coco-eval library to compute mAP metrics
for object detection. Updates the provided stats dictionary with computed metrics
including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
Args:
stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
pred_json (str | Path): Path to JSON file containing predictions in COCO format.
anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
iou_types (str | list[str]): IoU type(s) for evaluation. Can be a single string or a list of strings.
Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
suffix (str | list[str]): Suffix to append to metric names in the stats dictionary. Should correspond
to iou_types if multiple types are provided. Defaults to "Box".
Returns:
(dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
"""
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
LOGGER.info(f"\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...")
try:
for x in pred_json, anno_json:
assert x.is_file(), f"{x} file not found"
iou_types = [iou_types] if isinstance(iou_types, str) else iou_types
suffix = [suffix] if isinstance(suffix, str) else suffix
check_requirements("faster-coco-eval>=1.6.7")
from faster_coco_eval import COCO, COCOeval_faster
anno = COCO(anno_json)
pred = anno.loadRes(pred_json)
for i, iou_type in enumerate(iou_types):
val = COCOeval_faster(
anno, pred, iouType=iou_type, lvis_style=self.is_lvis, print_function=LOGGER.info
)
val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
val.evaluate()
val.accumulate()
val.summarize()
# update mAP50-95 and mAP50
stats[f"metrics/mAP50({suffix[i][0]})"] = val.stats_as_dict["AP_50"]
stats[f"metrics/mAP50-95({suffix[i][0]})"] = val.stats_as_dict["AP_all"]
if self.is_lvis:
stats[f"metrics/APr({suffix[i][0]})"] = val.stats_as_dict["APr"]
stats[f"metrics/APc({suffix[i][0]})"] = val.stats_as_dict["APc"]
stats[f"metrics/APf({suffix[i][0]})"] = val.stats_as_dict["APf"]
if self.is_lvis:
stats["fitness"] = stats["metrics/mAP50-95(B)"] # always use box mAP50-95 for fitness
except Exception as e:
LOGGER.warning(f"faster-coco-eval unable to run: {e}")
return stats
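
A standalone validation sketch; on a recognized COCO/LVIS split with save_json=True, eval_json() additionally scores the serialized predictions.json with faster-coco-eval:

from ultralytics.models.yolo.detect import DetectionValidator
validator = DetectionValidator(args=dict(model="yolo11n.pt", data="coco8.yaml", save_json=True))
stats = validator()  # postprocess -> update_metrics -> finalize_metrics -> get_stats
print(stats["metrics/mAP50-95(B)"])  # box mAP50-95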

View File

@@ -0,0 +1,447 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.data.build import load_inference_source
from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import (
ClassificationModel,
DetectionModel,
OBBModel,
PoseModel,
SegmentationModel,
WorldModel,
YOLOEModel,
YOLOESegModel,
)
from ultralytics.utils import ROOT, YAML
class YOLO(Model):
"""
YOLO (You Only Look Once) object detection model.
This class provides a unified interface for YOLO models, automatically switching to specialized model types
(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
detection, segmentation, classification, pose estimation, and oriented bounding box detection.
Attributes:
model: The loaded YOLO model instance.
task: The task type (detect, segment, classify, pose, obb).
overrides: Configuration overrides for the model.
Methods:
__init__: Initialize a YOLO model with automatic type detection.
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
Examples:
Load a pretrained YOLO11n detection model
>>> model = YOLO("yolo11n.pt")
Load a pretrained YOLO11n segmentation model
>>> model = YOLO("yolo11n-seg.pt")
Initialize from a YAML configuration
>>> model = YOLO("yolo11n.yaml")
"""
def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
"""
Initialize a YOLO model.
This constructor initializes a YOLO model, automatically switching to specialized model types
(YOLOWorld or YOLOE) based on the model filename.
Args:
model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
Defaults to auto-detection based on model.
verbose (bool): Display model info on load.
Examples:
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt")  # load a pretrained YOLO11n detection model
>>> model = YOLO("yolo11n-seg.pt") # load a pretrained YOLO11n segmentation model
"""
path = Path(model if isinstance(model, (str, Path)) else "")
if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
new_instance = YOLOWorld(path, verbose=verbose)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOE PyTorch model
new_instance = YOLOE(path, task=task, verbose=verbose)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
else:
# Continue with default YOLO initialization
super().__init__(model=model, task=task, verbose=verbose)
if hasattr(self.model, "model") and "RTDETR" in self.model.model[-1]._get_name(): # if RTDETR head
from ultralytics import RTDETR
new_instance = RTDETR(self)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Map head to model, trainer, validator, and predictor classes."""
return {
"classify": {
"model": ClassificationModel,
"trainer": yolo.classify.ClassificationTrainer,
"validator": yolo.classify.ClassificationValidator,
"predictor": yolo.classify.ClassificationPredictor,
},
"detect": {
"model": DetectionModel,
"trainer": yolo.detect.DetectionTrainer,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
},
"segment": {
"model": SegmentationModel,
"trainer": yolo.segment.SegmentationTrainer,
"validator": yolo.segment.SegmentationValidator,
"predictor": yolo.segment.SegmentationPredictor,
},
"pose": {
"model": PoseModel,
"trainer": yolo.pose.PoseTrainer,
"validator": yolo.pose.PoseValidator,
"predictor": yolo.pose.PosePredictor,
},
"obb": {
"model": OBBModel,
"trainer": yolo.obb.OBBTrainer,
"validator": yolo.obb.OBBValidator,
"predictor": yolo.obb.OBBPredictor,
},
}
class YOLOWorld(Model):
"""
YOLO-World object detection model.
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
without requiring training on specific classes. It extends the YOLO architecture to support real-time
open-vocabulary detection.
Attributes:
model: The loaded YOLO-World model instance.
task: Always set to 'detect' for object detection.
overrides: Configuration overrides for the model.
Methods:
__init__: Initialize YOLOv8-World model with a pre-trained model file.
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
set_classes: Set the model's class names for detection.
Examples:
Load a YOLOv8-World model
>>> model = YOLOWorld("yolov8s-world.pt")
Set custom classes for detection
>>> model.set_classes(["person", "car", "bicycle"])
"""
def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
"""
Initialize YOLOv8-World model with a pre-trained model file.
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
COCO class names.
Args:
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
verbose (bool): If True, prints additional information during initialization.
"""
super().__init__(model=model, task="detect", verbose=verbose)
# Assign default COCO class names when there are no custom names
if not hasattr(self.model, "names"):
self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Map head to model, validator, and predictor classes."""
return {
"detect": {
"model": WorldModel,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.world.WorldTrainer,
}
}
def set_classes(self, classes: list[str]) -> None:
"""
Set the model's class names for detection.
Args:
classes (list[str]): A list of categories, e.g. ["person"].
"""
self.model.set_classes(classes)
# Remove background if it's given
background = " "
if background in classes:
classes.remove(background)
self.model.names = classes
# Reset method class names
if self.predictor:
self.predictor.model.names = classes
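# Illustrative usage sketch (not part of the original module), assuming the
# "yolov8s-world.pt" weights and the sample image URL resolve. set_classes()
# updates the detection vocabulary on both the model and any live predictor.
if __name__ == "__main__":
    world_model = YOLOWorld("yolov8s-world.pt")
    world_model.set_classes(["person", "bus", "bicycle"])
    world_results = world_model.predict("https://ultralytics.com/images/bus.jpg")
    print(world_results[0].boxes.cls)  # indices into the custom class list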
class YOLOE(Model):
"""
YOLOE object detection and segmentation model.
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
improved performance and additional features like visual and text positional embeddings.
Attributes:
model: The loaded YOLOE model instance.
task: The task type (detect or segment).
overrides: Configuration overrides for the model.
Methods:
__init__: Initialize YOLOE model with a pre-trained model file.
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
get_text_pe: Get text positional embeddings for the given texts.
get_visual_pe: Get visual positional embeddings for the given image and visual features.
set_vocab: Set vocabulary and class names for the YOLOE model.
get_vocab: Get vocabulary for the given class names.
set_classes: Set the model's class names and embeddings for detection.
val: Validate the model using text or visual prompts.
predict: Run prediction on images, videos, directories, streams, etc.
Examples:
Load a YOLOE detection model
>>> model = YOLOE("yoloe-11s-seg.pt")
Set vocabulary and class names
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
Predict with visual prompts
>>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
>>> results = model.predict("image.jpg", visual_prompts=prompts)
"""
def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
"""
Initialize YOLOE model with a pre-trained model file.
Args:
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
task (str, optional): Task type for the model. Auto-detected if None.
verbose (bool): If True, prints additional information during initialization.
"""
super().__init__(model=model, task=task, verbose=verbose)
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Map head to model, validator, and predictor classes."""
return {
"detect": {
"model": YOLOEModel,
"validator": yolo.yoloe.YOLOEDetectValidator,
"predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.yoloe.YOLOETrainer,
},
"segment": {
"model": YOLOESegModel,
"validator": yolo.yoloe.YOLOESegValidator,
"predictor": yolo.segment.SegmentationPredictor,
"trainer": yolo.yoloe.YOLOESegTrainer,
},
}
def get_text_pe(self, texts):
"""Get text positional embeddings for the given texts."""
assert isinstance(self.model, YOLOEModel)
return self.model.get_text_pe(texts)
def get_visual_pe(self, img, visual):
"""
Get visual positional embeddings for the given image and visual features.
This method extracts positional embeddings from visual features based on the input image. It requires
that the model is an instance of YOLOEModel.
Args:
img (torch.Tensor): Input image tensor.
visual (torch.Tensor): Visual features extracted from the image.
Returns:
(torch.Tensor): Visual positional embeddings.
Examples:
>>> model = YOLOE("yoloe-11s-seg.pt")
>>> img = torch.rand(1, 3, 640, 640)
>>> visual_features = torch.rand(1, 1, 80, 80)
>>> pe = model.get_visual_pe(img, visual_features)
"""
assert isinstance(self.model, YOLOEModel)
return self.model.get_visual_pe(img, visual)
def set_vocab(self, vocab: list[str], names: list[str]) -> None:
"""
Set vocabulary and class names for the YOLOE model.
This method configures the vocabulary and class names used by the model for text processing and
classification tasks. The model must be an instance of YOLOEModel.
Args:
vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
names (list[str]): List of class names that the model can detect or classify.
Raises:
AssertionError: If the model is not an instance of YOLOEModel.
Examples:
>>> model = YOLOE("yoloe-11s-seg.pt")
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
"""
assert isinstance(self.model, YOLOEModel)
self.model.set_vocab(vocab, names=names)
def get_vocab(self, names):
"""Get vocabulary for the given class names."""
assert isinstance(self.model, YOLOEModel)
return self.model.get_vocab(names)
def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
"""
Set the model's class names and embeddings for detection.
Args:
classes (list[str]): A list of categories, e.g. ["person"].
embeddings (torch.Tensor, optional): Embeddings corresponding to the classes. Generated from the class
text if None.
"""
assert isinstance(self.model, YOLOEModel)
if embeddings is None:
embeddings = self.get_text_pe(classes) # generate text embeddings if not provided
self.model.set_classes(classes, embeddings)
# Verify no background class is present
assert " " not in classes
self.model.names = classes
# Reset method class names
if self.predictor:
self.predictor.model.names = classes
def val(
self,
validator=None,
load_vp: bool = False,
refer_data: str | None = None,
**kwargs,
):
"""
Validate the model using text or visual prompts.
Args:
validator (callable, optional): A callable validator function. If None, a default validator is loaded.
load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
refer_data (str, optional): Path to the reference data for visual prompts.
**kwargs (Any): Additional keyword arguments to override default settings.
Returns:
(dict): Validation statistics containing metrics computed during validation.
"""
custom = {"rect": not load_vp} # method defaults
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
self.metrics = validator.metrics
return validator.metrics
def predict(
self,
source=None,
stream: bool = False,
visual_prompts: dict[str, list] = {},
refer_image=None,
predictor=yolo.yoloe.YOLOEVPDetectPredictor,
**kwargs,
):
"""
Run prediction on images, videos, directories, streams, etc.
Args:
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
generator as they are computed.
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
'bboxes' and 'cls' keys when non-empty.
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
loaded based on the task.
**kwargs (Any): Additional keyword arguments passed to the predictor.
Returns:
(list | generator): List of Results objects or generator of Results objects if stream=True.
Examples:
>>> model = YOLOE("yoloe-11s-seg.pt")
>>> results = model.predict("path/to/image.jpg")
>>> # With visual prompts
>>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
>>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
"""
if len(visual_prompts):
assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
)
assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
f"{len(visual_prompts['cls'])} respectively"
)
if type(self.predictor) is not predictor:
self.predictor = predictor(
overrides={
"task": self.model.task,
"mode": "predict",
"save": False,
"verbose": refer_image is None,
"batch": 1,
"device": kwargs.get("device", None),
"half": kwargs.get("half", False),
"imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
},
_callbacks=self.callbacks,
)
num_cls = (
max(len(set(c)) for c in visual_prompts["cls"])
if isinstance(source, list) and refer_image is None # means multiple images
else len(set(visual_prompts["cls"]))
)
self.model.model[-1].nc = num_cls
self.model.names = [f"object{i}" for i in range(num_cls)]
self.predictor.set_prompts(visual_prompts.copy())
self.predictor.setup_model(model=self.model)
if refer_image is None and source is not None:
dataset = load_inference_source(source)
if dataset.mode in {"video", "stream"}:
# NOTE: set the first frame as refer image for videos/streams inference
refer_image = next(iter(dataset))[1][0]
if refer_image is not None:
vpe = self.predictor.get_vpe(refer_image)
self.model.set_classes(self.model.names, vpe)
self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
self.predictor = None # reset predictor
elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
self.predictor = None # reset predictor if no visual prompts
return super().predict(source, stream, **kwargs)
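# Illustrative usage sketch (not part of the original module), mirroring the
# docstring example above. The bbox coordinates and class label are placeholders;
# 'bboxes' and 'cls' must be the same length per the asserts in predict().
if __name__ == "__main__":
    yoloe_model = YOLOE("yoloe-11s-seg.pt")
    prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
    yoloe_results = yoloe_model.predict("https://ultralytics.com/images/bus.jpg", visual_prompts=prompts)
    print(len(yoloe_results))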

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"

View File

@@ -0,0 +1,65 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class OBBPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
bounding boxes.
Attributes:
args (namespace): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded YOLO OBB model.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.obb import OBBPredictor
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
>>> predictor = OBBPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize OBBPredictor with optional model and data configuration overrides.
Args:
cfg (dict, optional): Default configuration for the predictor.
overrides (dict, optional): Configuration overrides that take precedence over the default config.
_callbacks (list, optional): List of callback functions to be invoked during prediction.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.obb import OBBPredictor
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
>>> predictor = OBBPredictor(overrides=args)
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "obb"
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct the result object from the prediction.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
the last dimension contains [x, y, w, h, confidence, class_id, angle].
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
orig_img (np.ndarray): The original image before preprocessing.
img_path (str): The path to the original image.
Returns:
(Results): The result object containing the original image, image path, class names, and oriented bounding
boxes.
"""
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
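# Illustrative usage sketch (not part of the original module), assuming the
# "yolo11n-obb.pt" weights and the sample image URL resolve. Rotated boxes are
# exposed on each result as (x, y, w, h, angle) via the .obb attribute.
if __name__ == "__main__":
    from ultralytics import YOLO

    model = YOLO("yolo11n-obb.pt")
    results = model("https://ultralytics.com/images/boats.jpg")
    print(results[0].obb.xywhr.shape)  # (N, 5) rotated boxes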

View File

@@ -0,0 +1,82 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from typing import Any
from ultralytics.models import yolo
from ultralytics.nn.tasks import OBBModel
from ultralytics.utils import DEFAULT_CFG, RANK
class OBBTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
detecting objects at arbitrary angles rather than just axis-aligned rectangles.
Attributes:
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
and dfl_loss.
Methods:
get_model: Return OBBModel initialized with specified config and weights.
get_validator: Return an instance of OBBValidator for validation of YOLO model.
Examples:
>>> from ultralytics.models.yolo.obb import OBBTrainer
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
>>> trainer = OBBTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
"""
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
Args:
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
model configuration.
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
will take precedence over those in cfg.
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "obb"
super().__init__(cfg, overrides, _callbacks)
def get_model(
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
) -> OBBModel:
"""
Return OBBModel initialized with specified config and weights.
Args:
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
containing configuration parameters, or None to use default configuration.
weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
verbose (bool): Whether to display model information during initialization.
Returns:
(OBBModel): Initialized OBBModel with the specified configuration and weights.
Examples:
>>> trainer = OBBTrainer()
>>> model = trainer.get_model(cfg="yolo11n-obb.yaml", weights="yolo11n-obb.pt")
"""
model = OBBModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of OBBValidator for validation of YOLO model."""
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.obb.OBBValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
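# Illustrative usage sketch (not part of the original module), matching the
# class docstring above and assuming the DOTA8 demo dataset resolves.
if __name__ == "__main__":
    trainer = OBBTrainer(overrides=dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3))
    trainer.train()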

View File

@@ -0,0 +1,299 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
from ultralytics.utils.nms import TorchNMS
class OBBValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
satellite imagery where objects can appear at various orientations.
Attributes:
args (dict): Configuration arguments for the validator.
metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
Methods:
init_metrics: Initialize evaluation metrics for YOLO.
_process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
_prepare_batch: Prepare batch data for OBB validation.
_prepare_pred: Prepare predictions with scaled and padded bounding boxes.
plot_predictions: Plot predicted bounding boxes on input images.
pred_to_json: Serialize YOLO predictions to COCO json format.
save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
eval_json: Evaluate YOLO output in JSON format and return performance statistics.
Examples:
>>> from ultralytics.models.yolo.obb import OBBValidator
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
>>> validator = OBBValidator(args=args)
>>> validator(model=args["model"])
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
It extends the DetectionValidator class and configures it specifically for the OBB task.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
save_dir (str | Path, optional): Directory to save results.
args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
_callbacks (list, optional): List of callback functions to be called during validation.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.args.task = "obb"
self.metrics = OBBMetrics()
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO obb validation.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
val = self.data.get(self.args.split, "") # validation path
self.is_dota = isinstance(val, str) and "DOTA" in val # check if dataset is DOTA format
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
"""
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
Args:
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
class labels and bounding boxes.
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
class labels and bounding boxes.
Returns:
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
of predictions compared to the ground truth.
Examples:
>>> preds = {"cls": torch.randint(0, 5, (100,)), "bboxes": torch.rand(100, 5)}  # 100 sample detections
>>> batch = {"cls": torch.randint(0, 5, (50,)), "bboxes": torch.rand(50, 5)}  # 50 ground truth boxes
>>> correct_matrix = validator._process_batch(preds, batch)
"""
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
iou = batch_probiou(batch["bboxes"], preds["bboxes"])
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Args:
preds (torch.Tensor): Raw predictions from the model.
Returns:
(list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
"""
preds = super().postprocess(preds)
for pred in preds:
pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1) # concatenate angle
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare batch data for OBB validation with proper scaling and formatting.
Args:
si (int): Batch index to process.
batch (dict[str, Any]): Dictionary containing batch data with keys:
- batch_idx: Tensor of batch indices
- cls: Tensor of class labels
- bboxes: Tensor of bounding boxes
- ori_shape: Original image shapes
- img: Batch of images
- ratio_pad: Ratio and padding information
Returns:
(dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
"""
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if cls.shape[0]:
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
return {
"cls": cls,
"bboxes": bbox,
"ori_shape": ori_shape,
"imgsz": imgsz,
"ratio_pad": ratio_pad,
"im_file": batch["im_file"][si],
}
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
"""
Plot predicted bounding boxes on input images and save the result.
Args:
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = OBBValidator()
>>> batch = {"img": images, "im_file": paths}
>>> preds = [torch.rand(10, 7)] # Example predictions for one image
>>> validator.plot_predictions(batch, preds, 0)
"""
for p in preds:
# TODO: fix this duplicated `xywh2xyxy`
p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4]) # convert to xyxy format for plotting
super().plot_predictions(batch, preds, ni) # plot bboxes
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:
This method processes rotated bounding box predictions and converts them to both rbox format
(x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
to the JSON dictionary.
"""
path = Path(pbatch["im_file"])
stem = path.stem
image_id = int(stem) if stem.isnumeric() else stem
rbox = predn["bboxes"]
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
self.jdict.append(
{
"image_id": image_id,
"file_name": path.name,
"category_id": self.class_map[int(c)],
"score": round(s, 5),
"rbox": [round(x, 3) for x in r],
"poly": [round(x, 3) for x in b],
}
)
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO OBB detections to a text file in normalized coordinates.
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes' (x, y, w, h, angle), 'conf',
and 'cls' tensors for the oriented bounding box detections.
save_conf (bool): Whether to save confidence scores in the text file.
shape (tuple[int, int]): Original image shape in format (height, width).
file (Path): Output file path to save detections.
Examples:
>>> validator = OBBValidator()
>>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]]) # One detection: x,y,w,h,conf,cls,angle
>>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
"""
import numpy as np
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
).save_txt(file, save_conf=save_conf)
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**predn,
"bboxes": ops.scale_boxes(
pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""
Evaluate YOLO output in JSON format and save predictions in DOTA format.
Args:
stats (dict[str, Any]): Performance statistics dictionary.
Returns:
(dict[str, Any]): Updated performance statistics.
"""
if self.args.save_json and self.is_dota and len(self.jdict):
import json
import re
from collections import defaultdict
pred_json = self.save_dir / "predictions.json" # predictions
pred_txt = self.save_dir / "predictions_txt" # predictions
pred_txt.mkdir(parents=True, exist_ok=True)
data = json.load(open(pred_json))
# Save split results
LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
for d in data:
image_id = d["image_id"]
score = d["score"]
classname = self.names[d["category_id"] - 1].replace(" ", "-")
p = d["poly"]
with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
# Save merged results. This could produce a slightly lower mAP than the official
# merging script, because of the probiou calculation.
pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
pred_merged_txt.mkdir(parents=True, exist_ok=True)
merged_results = defaultdict(list)
LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
for d in data:
image_id = d["image_id"].split("__", 1)[0]
pattern = re.compile(r"\d+___\d+")
x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
bbox[0] += x
bbox[1] += y
bbox.extend([score, cls])
merged_results[image_id].append(bbox)
for image_id, bbox in merged_results.items():
bbox = torch.tensor(bbox)
max_wh = torch.max(bbox[:, :2]).item() * 2
c = bbox[:, 6:7] * max_wh # classes
scores = bbox[:, 5] # scores
b = bbox[:, :5].clone()
b[:, :2] += c
# 0.3 could get results close to the ones from official merging script, even slightly better.
i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
bbox = bbox[i]
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
classname = self.names[int(x[-1])].replace(" ", "-")
p = [round(i, 3) for i in x[:-2]] # poly
score = round(x[-2], 3)
with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
return stats
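# Illustrative sketch (not part of the original module): how the merge step in
# eval_json above recovers patch offsets from DOTA-style split image ids. An id
# such as "P0032__1024__0___512" (hypothetical) splits into the source image
# "P0032" and the (x, y) = (0, 512) offset added back onto each rbox center.
if __name__ == "__main__":
    import re

    image_id = "P0032__1024__0___512"
    source = image_id.split("__", 1)[0]
    x, y = (int(c) for c in re.findall(r"\d+___\d+", image_id)[0].split("___"))
    print(source, x, y)  # P0032 0 512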

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"

View File

@@ -0,0 +1,80 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a pose model.
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
capabilities inherited from DetectionPredictor.
Attributes:
args (namespace): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
Methods:
construct_result: Construct the result object from the prediction, including keypoints.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.pose import PosePredictor
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
>>> predictor = PosePredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize PosePredictor for pose estimation tasks.
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
warnings for Apple MPS.
Args:
cfg (Any): Configuration for the predictor.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be invoked during prediction.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.pose import PosePredictor
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
>>> predictor = PosePredictor(overrides=args)
>>> predictor.predict_cli()
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "pose"
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct the result object from the prediction, including keypoints.
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
result object.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
the number of detections, K is the number of keypoints, and D is the keypoint dimension.
img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
orig_img (np.ndarray): The original unprocessed image as a numpy array.
img_path (str): The path to the original image file.
Returns:
(Results): The result object containing the original image, image path, class names, bounding boxes, and
keypoints.
"""
result = super().construct_result(pred, img, orig_img, img_path)
# Extract keypoints from prediction and reshape according to model's keypoint shape
pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
# Scale keypoints coordinates to match the original image dimensions
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
result.update(keypoints=pred_kpts)
return result
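# Illustrative usage sketch (not part of the original module), assuming the
# "yolo11n-pose.pt" weights and the sample image URL resolve. Keypoints are
# returned already scaled to the original image, per construct_result above.
if __name__ == "__main__":
    from ultralytics import YOLO

    model = YOLO("yolo11n-pose.pt")
    results = model("https://ultralytics.com/images/bus.jpg")
    print(results[0].keypoints.xy.shape)  # (N, 17, 2) for COCO-style keypoints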

View File

@@ -0,0 +1,115 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from typing import Any
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
class PoseTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training YOLO pose estimation models.
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
of pose keypoints alongside bounding boxes.
Attributes:
args (dict): Configuration arguments for training.
model (PoseModel): The pose estimation model being trained.
data (dict): Dataset configuration including keypoint shape information.
loss_names (tuple): Names of the loss components used in training.
Methods:
get_model: Retrieve a pose estimation model with specified configuration.
set_model_attributes: Set keypoints shape attribute on the model.
get_validator: Create a validator instance for model evaluation.
plot_training_samples: Visualize training samples with keypoints.
get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
Examples:
>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a PoseTrainer object for training YOLO pose estimation models.
Args:
cfg (dict, optional): Default configuration dictionary containing training parameters.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be executed during training.
Notes:
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
A warning is issued when using Apple MPS device due to known bugs with pose models.
"""
if overrides is None:
overrides = {}
overrides["task"] = "pose"
super().__init__(cfg, overrides, _callbacks)
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def get_model(
self,
cfg: str | Path | dict[str, Any] | None = None,
weights: str | Path | None = None,
verbose: bool = True,
) -> PoseModel:
"""
Get pose estimation model with specified configuration and weights.
Args:
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
weights (str | Path, optional): Path to the model weights file.
verbose (bool): Whether to display model information.
Returns:
(PoseModel): Initialized pose estimation model.
"""
model = PoseModel(
cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
)
if weights:
model.load(weights)
return model
def set_model_attributes(self):
"""Set keypoints shape attribute of PoseModel."""
super().set_model_attributes()
self.model.kpt_shape = self.data["kpt_shape"]
def get_validator(self):
"""Return an instance of the PoseValidator class for validation."""
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
return yolo.pose.PoseValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def get_dataset(self) -> dict[str, Any]:
"""
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
Returns:
(dict): A dictionary containing the training/validation/test dataset and category names.
Raises:
KeyError: If the `kpt_shape` key is not present in the dataset.
"""
data = super().get_dataset()
if "kpt_shape" not in data:
raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
return data
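# Illustrative usage sketch (not part of the original module), matching the
# class docstring above and assuming the COCO8-pose demo dataset resolves.
if __name__ == "__main__":
    trainer = PoseTrainer(overrides=dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3))
    trainer.train()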

View File

@@ -0,0 +1,267 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
class PoseValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a pose model.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
Attributes:
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
args (dict): Arguments for the validator including task set to "pose".
metrics (PoseMetrics): Metrics object for pose evaluation.
Methods:
preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
get_desc: Return description of evaluation metrics in string format.
init_metrics: Initialize pose estimation metrics for YOLO model.
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
dimensions.
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
detections and ground truth.
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
pred_to_json: Convert YOLO predictions to COCO JSON format.
eval_json: Evaluate object detection model using COCO JSON format.
Examples:
>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize a PoseValidator object for pose estimation validation.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
save_dir (Path | str, optional): Directory to save results.
args (dict, optional): Arguments for the validator including task set to "pose".
_callbacks (list, optional): List of callback functions to be executed during validation.
Examples:
>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
Notes:
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
due to a known bug with pose models.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = "pose"
self.metrics = PoseMetrics()
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
batch = super().preprocess(batch)
batch["keypoints"] = batch["keypoints"].float()
return batch
def get_desc(self) -> str:
"""Return description of evaluation metrics in string format."""
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Pose(P",
"R",
"mAP50",
"mAP50-95)",
)
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO pose validation.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
self.kpt_shape = self.data["kpt_shape"]
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
field of predictions and reshaping them according to the keypoint shape configuration.
The keypoints are reshaped from a flattened format to the proper dimensional structure
(typically [N, 17, 3] for COCO pose format).
Args:
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
bounding boxes, confidence scores, class predictions, and keypoint data.
Returns:
(list[dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
- 'bboxes': Bounding box coordinates
- 'conf': Confidence scores
- 'cls': Class predictions
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
Note:
The keypoints are extracted from the 'extra' field, which carries additional
task-specific data beyond the basic detection outputs, and are reshaped in
place for every prediction in the batch.
"""
preds = super().postprocess(preds)
for pred in preds:
pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape) # remove extra if exists
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
Args:
si (int): Batch index.
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
Returns:
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
Notes:
This method extends the parent class's _prepare_batch method by adding keypoint processing.
Keypoints are scaled from normalized coordinates to original image dimensions.
"""
pbatch = super()._prepare_batch(si, batch)
kpts = batch["keypoints"][batch["batch_idx"] == si]
h, w = pbatch["imgsz"]
kpts = kpts.clone()
kpts[..., 0] *= w
kpts[..., 1] *= h
pbatch["keypoints"] = kpts
return pbatch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
and 'keypoints' for keypoint predictions.
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
Returns:
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
true positives across 10 IoU levels.
Notes:
`0.53` scale factor used in area computation is referenced from
https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
"""
tp = super()._process_batch(preds, batch)
gt_cls = batch["cls"]
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
else:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
return tp
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO pose detections to a text file in normalized coordinates.
Args:
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls', and 'keypoints'.
save_conf (bool): Whether to save confidence scores.
shape (tuple[int, int]): Shape of the original image (height, width).
file (Path): Output file path to save detections.
Notes:
The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
normalized (x, y, visibility) values for each point.
"""
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
keypoints=predn["keypoints"],
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format.
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
and scaled 'kpts' tensors, as produced by scale_preds.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
before saving to the JSON dictionary.
"""
super().pred_to_json(predn, pbatch)
kpts = predn["kpts"]
for i, k in enumerate(kpts.flatten(1, 2).tolist()):
self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**super().scale_preds(predn, pbatch),
"kpts": ops.scale_coords(
pbatch["imgsz"],
predn["keypoints"].clone(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""Evaluate object detection model using COCO JSON format."""
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
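# Illustrative usage sketch (not part of the original module), assuming the
# weights and the COCO8-pose demo dataset resolve; mirrors the class docstring.
if __name__ == "__main__":
    args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
    validator = PoseValidator(args=args)
    validator(model=args["model"])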

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"

View File

@@ -0,0 +1,113 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a segmentation model.
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
prediction results.
Attributes:
args (dict): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded YOLO segmentation model.
batch (list): Current batch of images being processed.
Methods:
postprocess: Apply non-max suppression and process segmentation detections.
construct_results: Construct a list of result objects from predictions.
construct_result: Construct a single result object from a prediction.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.segment import SegmentationPredictor
>>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
>>> predictor = SegmentationPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
prediction results.
Args:
cfg (dict): Configuration for the predictor.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be invoked during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
Apply non-max suppression and process segmentation detections for each image in the input batch.
Args:
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
Returns:
(list): List of Results objects containing the segmentation predictions for each image in the batch.
Each Results object includes both bounding boxes and segmentation masks.
Examples:
>>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
>>> results = predictor.postprocess(preds, img, orig_img)
"""
# Extract protos - tuple if PyTorch model or array if exported
protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
def construct_results(self, preds, img, orig_imgs, protos):
"""
Construct a list of result objects from the predictions.
Args:
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
img (torch.Tensor): The image after preprocessing.
orig_imgs (list[np.ndarray]): List of original images before preprocessing.
protos (list[torch.Tensor]): List of prototype masks.
Returns:
(list[Results]): List of result objects containing the original images, image paths, class names,
bounding boxes, and masks.
"""
return [
self.construct_result(pred, img, orig_img, img_path, proto)
for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
]
def construct_result(self, pred, img, orig_img, img_path, proto):
"""
Construct a single result object from the prediction.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
img (torch.Tensor): The image after preprocessing.
orig_img (np.ndarray): The original image before preprocessing.
img_path (str): The path to the original image.
proto (torch.Tensor): The prototype masks.
Returns:
(Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
"""
if pred.shape[0] == 0: # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
if masks is not None:
keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks
pred, masks = pred[keep], masks[keep]
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
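# Illustrative usage sketch (not part of the original module), assuming the
# "yolo11n-seg.pt" weights and the sample image URL resolve. With
# retina_masks=True, masks are computed at native image resolution.
if __name__ == "__main__":
    from ultralytics import YOLO

    model = YOLO("yolo11n-seg.pt")
    results = model("https://ultralytics.com/images/bus.jpg", retina_masks=True)
    if results[0].masks is not None:
        print(results[0].masks.data.shape)  # (N, H, W) binary masks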

View File

@@ -0,0 +1,72 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK
class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
functionality including model initialization, validation, and visualization.
Attributes:
loss_names (tuple[str]): Names of the loss components used during training.
Examples:
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
>>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
>>> trainer = SegmentationTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
"""
Initialize a SegmentationTrainer object.
Args:
cfg (dict): Configuration dictionary with default training settings.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be executed during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "segment"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
"""
Initialize and return a SegmentationModel with specified configuration and weights.
Args:
cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
weights (str | Path, optional): Path to pretrained weights file.
verbose (bool): Whether to display model information during initialization.
Returns:
(SegmentationModel): Initialized segmentation model with loaded weights if specified.
Examples:
>>> trainer = SegmentationTrainer()
>>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
>>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
"""
model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
return yolo.segment.SegmentationValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
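# Illustrative usage sketch (not part of the original module), matching the
# class docstring above and assuming the COCO8-seg demo dataset resolves.
if __name__ == "__main__":
    trainer = SegmentationTrainer(overrides=dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3))
    trainer.train()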

View File

@@ -0,0 +1,259 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, mask_iou
class SegmentationValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a segmentation model.
This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
to compute metrics such as mAP for both detection and segmentation tasks.
Attributes:
plot_masks (list): List to store masks for plotting.
process (callable): Function to process masks based on save_json and save_txt flags.
args (namespace): Arguments for the validator.
metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
stats (dict): Dictionary to store statistics during validation.
Examples:
>>> from ultralytics.models.yolo.segment import SegmentationValidator
>>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
>>> validator = SegmentationValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (Path, optional): Directory to save results.
args (namespace, optional): Arguments for the validator.
_callbacks (list, optional): List of callback functions.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.process = None
self.args.task = "segment"
self.metrics = SegmentMetrics()
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""
Preprocess batch of images for YOLO segmentation validation.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
Returns:
(dict[str, Any]): Preprocessed batch.
"""
batch = super().preprocess(batch)
batch["masks"] = batch["masks"].float()
return batch
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize metrics and select mask processing function based on save_json flag.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
if self.args.save_json:
check_requirements("faster-coco-eval>=1.6.7")
# process_mask_native is more accurate; process_mask is faster
self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
def get_desc(self) -> str:
"""Return a formatted description of evaluation metrics."""
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Mask(P",
"R",
"mAP50",
"mAP50-95)",
)
def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
"""
Post-process YOLO predictions and return output detections with proto.
Args:
preds (list[torch.Tensor]): Raw predictions from the model.
Returns:
list[dict[str, torch.Tensor]]: Processed detection predictions with masks.
"""
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is a 3-tuple for PyTorch models, a single tensor if exported
preds = super().postprocess(preds[0])
imgsz = [4 * x for x in proto.shape[2:]] # get image size from proto
for i, pred in enumerate(preds):
coefficient = pred.pop("extra")
pred["masks"] = (
self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
if coefficient.shape[0]
else torch.zeros(
(0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
dtype=torch.uint8,
device=pred["bboxes"].device,
)
)
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for training or inference by processing images and targets.
Args:
si (int): Batch index.
batch (dict[str, Any]): Batch data containing images and annotations.
Returns:
(dict[str, Any]): Prepared batch with processed annotations.
"""
prepared_batch = super()._prepare_batch(si, batch)
nl = prepared_batch["cls"].shape[0]
if self.args.overlap_mask:
masks = batch["masks"][si]
index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
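            # comparing against a per-instance index turns the overlap-encoded mask into nl binary
            # channels, e.g. with nl=3 a pixel labeled 2 expands to the one-hot column (0, 1, 0)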
masks = (masks == index).float()
else:
masks = batch["masks"][batch["batch_idx"] == si]
if nl:
mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
if masks.shape[1:] != mask_size:
masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
masks = masks.gt_(0.5)
prepared_batch["masks"] = masks
return prepared_batch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
batch (dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.
Returns:
(dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
Notes:
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
Examples:
>>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
>>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
>>> correct_preds = validator._process_batch(preds, batch)
"""
tp = super()._process_batch(preds, batch)
gt_cls = batch["cls"]
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
else:
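            # flatten(1) turns (N, H, W) masks into (N, H*W) rows, so mask_iou computes pairwise
            # pixel-level intersection-over-union between ground-truth and predicted masks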
iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
tp.update({"tp_m": tp_m}) # update tp with mask IoU
return tp
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
"""
Plot batch predictions with masks and bounding boxes.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
ni (int): Batch index.
"""
for p in preds:
masks = p["masks"]
if masks.shape[0] > self.args.max_det:
LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
super().plot_predictions(batch, preds, ni, max_det=self.args.max_det) # plot bboxes
def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
        Save YOLO detections to a txt file in normalized coordinates, one detection per line.
Args:
predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
save_conf (bool): Whether to save confidence scores.
shape (tuple[int, int]): Shape of the original image.
file (Path): File path to save the detections.
"""
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Save one JSON result for COCO evaluation.
Args:
predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
"""
from faster_coco_eval.core.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
pred_masks = np.transpose(predn["masks"], (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
super().pred_to_json(predn, pbatch)
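        # the parent call above appended len(rles) detection records to jdict; index from the end
        # so each mask RLE attaches to its matching detection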
for i, r in enumerate(rles):
self.jdict[-len(rles) + i]["segmentation"] = r # segmentation
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**super().scale_preds(predn, pbatch),
"masks": ops.scale_image(
torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""Return COCO-style instance segmentation evaluation metrics."""
pred_json = self.save_dir / "predictions.json" # predictions
anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "segm"], suffix=["Box", "Mask"])

View File

@@ -0,0 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .train import WorldTrainer
__all__ = ["WorldTrainer"]

View File

@@ -0,0 +1,179 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import itertools
from pathlib import Path
from typing import Any
import torch
from ultralytics.data import build_yolo_dataset
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model
def on_pretrain_routine_end(trainer) -> None:
"""Set up model classes and text encoder at the end of the pretrain routine."""
if RANK in {-1, 0}:
# Set class names for evaluation
names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
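        # e.g. a multi-synonym category name like "person/human" reduces to its first form "person"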
unwrap_model(trainer.ema.ema).set_classes(names, cache_clip_model=False)
class WorldTrainer(DetectionTrainer):
"""
    A trainer class for fine-tuning YOLO World models on closed-set datasets.
This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
features for improved object detection and understanding. It handles text embedding generation and caching to
accelerate training with multi-modal data.
Attributes:
text_embeddings (dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate
training.
model (WorldModel): The YOLO World model being trained.
data (dict[str, Any]): Dataset configuration containing class information.
args (Any): Training arguments and configuration.
Methods:
get_model: Return WorldModel initialized with specified config and weights.
build_dataset: Build YOLO Dataset for training or validation.
set_text_embeddings: Set text embeddings for datasets to accelerate training.
generate_text_embeddings: Generate text embeddings for a list of text samples.
preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.
Examples:
Initialize and train a YOLO World model
>>> from ultralytics.models.yolo.world import WorldTrainer
>>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
>>> trainer = WorldTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a WorldTrainer object with given arguments.
Args:
cfg (dict[str, Any]): Configuration for the trainer.
overrides (dict[str, Any], optional): Configuration overrides.
_callbacks (list[Any], optional): List of callback functions.
"""
if overrides is None:
overrides = {}
        assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
super().__init__(cfg, overrides, _callbacks)
self.text_embeddings = None
def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
"""
Return WorldModel initialized with specified config and weights.
Args:
cfg (dict[str, Any] | str, optional): Model configuration.
weights (str, optional): Path to pretrained weights.
verbose (bool): Whether to display model info.
Returns:
(WorldModel): Initialized WorldModel.
"""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = WorldModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=self.data["channels"],
nc=min(self.data["nc"], 80),
verbose=verbose and RANK == -1,
)
if weights:
model.load(weights)
self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
return model
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, allowing customized augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`.
Returns:
(Any): YOLO dataset configured for training or validation.
"""
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
dataset = build_yolo_dataset(
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
)
if mode == "train":
self.set_text_embeddings([dataset], batch) # cache text embeddings to accelerate training
return dataset
def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
"""
Set text embeddings for datasets to accelerate training by caching category names.
This method collects unique category names from all datasets, then generates and caches text embeddings
for these categories to improve training efficiency.
Args:
datasets (list[Any]): List of datasets from which to extract category names.
batch (int | None): Batch size used for processing.
Notes:
This method collects category names from datasets that have the 'category_names' attribute,
then uses the first dataset's image path to determine where to cache the generated text embeddings.
"""
text_embeddings = {}
for dataset in datasets:
if not hasattr(dataset, "category_names"):
continue
text_embeddings.update(
self.generate_text_embeddings(
list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent
)
)
self.text_embeddings = text_embeddings
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
"""
Generate text embeddings for a list of text samples.
Args:
texts (list[str]): List of text samples to encode.
batch (int): Batch size for processing.
cache_dir (Path): Directory to save/load cached embeddings.
Returns:
(dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
"""
model = "clip:ViT-B/32"
cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
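        # e.g. model "clip:ViT-B/32" maps to the cache file "text_embeddings_clip_ViT-B_32.pt"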
if cache_path.exists():
LOGGER.info(f"Reading existed cache from '{cache_path}'")
txt_map = torch.load(cache_path, map_location=self.device)
if sorted(txt_map.keys()) == sorted(texts):
return txt_map
LOGGER.info(f"Caching text embeddings to '{cache_path}'")
assert self.model is not None
txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, cache_clip_model=False)
txt_map = dict(zip(texts, txt_feats.squeeze(0)))
torch.save(txt_map, cache_path)
return txt_map
def preprocess_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess a batch of images and text for YOLOWorld training."""
batch = DetectionTrainer.preprocess_batch(self, batch)
# Add text features
texts = list(itertools.chain(*batch["texts"]))
txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
self.device, non_blocking=self.device.type == "cuda"
)
batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
return batch

View File

@@ -0,0 +1,201 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from pathlib import Path
from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
from ultralytics.utils.torch_utils import unwrap_model
class WorldTrainerFromScratch(WorldTrainer):
"""
A class extending the WorldTrainer for training a world model from scratch on open-set datasets.
This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
supporting training YOLO-World models with combined vision-language capabilities.
Attributes:
cfg (dict): Configuration dictionary with default parameters for model training.
overrides (dict): Dictionary of parameter overrides to customize the configuration.
_callbacks (list): List of callback functions to be executed during different stages of training.
data (dict): Final processed data configuration containing train/val paths and metadata.
training_data (dict): Dictionary mapping training dataset paths to their configurations.
Methods:
build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.
get_dataset: Get train and validation paths from data dictionary.
plot_training_labels: Skip label plotting for YOLO-World training.
final_eval: Perform final evaluation and validation for the YOLO-World model.
Examples:
>>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
>>> from ultralytics import YOLOWorld
>>> data = dict(
... train=dict(
... yolo_data=["Objects365.yaml"],
... grounding_data=[
... dict(
... img_path="flickr30k/images",
... json_file="flickr30k/final_flickr_separateGT_train.json",
... ),
... dict(
... img_path="GQA/images",
... json_file="GQA/final_mixed_train_no_coco.json",
... ),
... ],
... ),
... val=dict(yolo_data=["lvis.yaml"]),
... )
>>> model = YOLOWorld("yolov8s-worldv2.yaml")
>>> model.train(data=data, trainer=WorldTrainerFromScratch)
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize a WorldTrainerFromScratch object.
This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
object detection and grounding datasets for vision-language capabilities.
Args:
cfg (dict): Configuration dictionary with default parameters for model training.
overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
_callbacks (list, optional): List of callback functions to be executed during different stages of training.
Examples:
>>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
>>> from ultralytics import YOLOWorld
>>> data = dict(
... train=dict(
... yolo_data=["Objects365.yaml"],
... grounding_data=[
... dict(
... img_path="flickr30k/images",
... json_file="flickr30k/final_flickr_separateGT_train.json",
... ),
... ],
... ),
... val=dict(yolo_data=["lvis.yaml"]),
... )
>>> model = YOLOWorld("yolov8s-worldv2.yaml")
>>> model.train(data=data, trainer=WorldTrainerFromScratch)
"""
if overrides is None:
overrides = {}
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path, mode="train", batch=None):
"""
Build YOLO Dataset for training or validation.
This method constructs appropriate datasets based on the mode and input paths, handling both
standard YOLO datasets and grounding datasets with different formats.
Args:
img_path (list[str] | str): Path to the folder containing images or list of paths.
mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
batch (int, optional): Size of batches, used for rectangular training/validation.
Returns:
(YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
"""
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
if mode != "train":
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)
datasets = [
build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)
if isinstance(im_path, str)
else build_grounding(
# assign `nc` from validation set to max number of text samples for training consistency
self.args,
im_path["img_path"],
im_path["json_file"],
batch,
stride=gs,
max_samples=self.data["nc"],
)
for im_path in img_path
]
self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training
return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
def get_dataset(self):
"""
Get train and validation paths from data dictionary.
Processes the data configuration to extract paths for training and validation datasets,
handling both YOLO detection datasets and grounding datasets.
        Returns:
            (dict): Final processed data configuration containing train/val paths and dataset metadata.
Raises:
AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
"""
final_data = {}
data_yaml = self.args.data
assert data_yaml.get("train", False), "train dataset not found" # object365.yaml
assert data_yaml.get("val", False), "validation dataset not found" # lvis.yaml
data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
for d in data["val"]:
if d.get("minival") is None: # for lvis dataset
continue
d["minival"] = str(d["path"] / d["minival"])
for s in {"train", "val"}:
final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
# save grounding data if there's one
grounding_data = data_yaml[s].get("grounding_data")
if grounding_data is None:
continue
grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
for g in grounding_data:
assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
for k in {"img_path", "json_file"}:
path = Path(g[k])
if not path.exists() and not path.is_absolute():
g[k] = str((DATASETS_DIR / g[k]).resolve()) # path relative to DATASETS_DIR
final_data[s] += grounding_data
# assign the first val dataset as currently only one validation set is supported
data["val"] = data["val"][0]
final_data["val"] = final_data["val"][0]
# NOTE: to make training work properly, set `nc` and `names`
final_data["nc"] = data["val"]["nc"]
final_data["names"] = data["val"]["names"]
        # NOTE: use the validation dataset (e.g. LVIS) path as the global dataset path
final_data["path"] = data["val"]["path"]
final_data["channels"] = data["val"]["channels"]
self.data = final_data
if self.args.single_cls: # consistent with base trainer
LOGGER.info("Overriding class names with single class.")
self.data["names"] = {0: "object"}
self.data["nc"] = 1
self.training_data = {}
for d in data["train"]:
if self.args.single_cls:
d["names"] = {0: "object"}
d["nc"] = 1
self.training_data[d["train"]] = d
return final_data
def plot_training_labels(self):
"""Skip label plotting for YOLO-World training."""
pass
def final_eval(self):
"""
Perform final evaluation and validation for the YOLO-World model.
Configures the validator with appropriate dataset and split information before running evaluation.
Returns:
(dict): Dictionary containing evaluation metrics and results.
"""
val = self.args.data["val"]["yolo_data"][0]
self.validator.args.data = val
self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
return super().final_eval()

View File

@@ -0,0 +1,22 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
from .val import YOLOEDetectValidator, YOLOESegValidator
__all__ = [
"YOLOETrainer",
"YOLOEPETrainer",
"YOLOESegTrainer",
"YOLOEDetectValidator",
"YOLOESegValidator",
"YOLOEPESegTrainer",
"YOLOESegTrainerFromScratch",
"YOLOESegVPTrainer",
"YOLOEVPTrainer",
"YOLOEPEFreeTrainer",
"YOLOEVPDetectPredictor",
"YOLOEVPSegPredictor",
"YOLOETrainerFromScratch",
]

View File

@@ -0,0 +1,169 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import numpy as np
import torch
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.models.yolo.detect import DetectionPredictor
from ultralytics.models.yolo.segment import SegmentationPredictor
class YOLOEVPDetectPredictor(DetectionPredictor):
"""
A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.
This mixin provides common functionality for YOLO models that use visual prompting, including
model setup, prompt handling, and preprocessing transformations.
Attributes:
model (torch.nn.Module): The YOLO model for inference.
device (torch.device): Device to run the model on (CPU or CUDA).
prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.
Methods:
setup_model: Initialize the YOLO model and set it to evaluation mode.
set_prompts: Set the visual prompts for the model.
pre_transform: Preprocess images and prompts before inference.
inference: Run inference with visual prompts.
get_vpe: Process source to get visual prompt embeddings.
"""
def setup_model(self, model, verbose: bool = True):
"""
Set up the model for prediction.
Args:
model (torch.nn.Module): Model to load or use.
verbose (bool, optional): If True, provides detailed logging.
"""
super().setup_model(model, verbose=verbose)
self.done_warmup = True
def set_prompts(self, prompts):
"""
Set the visual prompts for the model.
Args:
prompts (dict): Dictionary containing class indices and bounding boxes or masks.
Must include a 'cls' key with class indices.
"""
self.prompts = prompts
def pre_transform(self, im):
"""
Preprocess images and prompts before inference.
This method applies letterboxing to the input image and transforms the visual prompts
(bounding boxes or masks) accordingly.
Args:
im (list): List containing a single input image.
Returns:
(list): Preprocessed image ready for model inference.
Raises:
ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
"""
img = super().pre_transform(im)
bboxes = self.prompts.pop("bboxes", None)
masks = self.prompts.pop("masks", None)
category = self.prompts["cls"]
if len(img) == 1:
visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
prompts = visuals.unsqueeze(0).to(self.device) # (1, N, H, W)
else:
# NOTE: only supports bboxes as prompts for now
assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
# NOTE: needs list[np.ndarray]
assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
f"Expected list[np.ndarray], but got {bboxes}!"
)
assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
f"Expected list[np.ndarray], but got {category}!"
)
assert len(im) == len(category) == len(bboxes), (
f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
)
visuals = [
self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
for i in range(len(img))
]
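            # pad_sequence zero-pads each image's (N_i, H, W) visuals so every image in the
            # batch carries the same number of prompt channels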
prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device) # (B, N, H, W)
self.prompts = prompts.half() if self.model.fp16 else prompts.float()
return img
def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
"""
Process a single image by resizing bounding boxes or masks and generating visuals.
Args:
dst_shape (tuple): The target shape (height, width) of the image.
src_shape (tuple): The original shape (height, width) of the image.
            category (np.ndarray | torch.Tensor): Class indices associated with the visual prompts.
            bboxes (list | np.ndarray, optional): Bounding boxes in the format [x1, y1, x2, y2].
            masks (np.ndarray, optional): Masks corresponding to the image.
Returns:
(torch.Tensor): The processed visuals for the image.
Raises:
ValueError: If neither `bboxes` nor `masks` are provided.
"""
if bboxes is not None and len(bboxes):
bboxes = np.array(bboxes, dtype=np.float32)
if bboxes.ndim == 1:
bboxes = bboxes[None, :]
# Calculate scaling factor and adjust bounding boxes
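            # e.g. a hypothetical 480x640 source letterboxed into 640x640 gives gain = 1.0,
            # x-offset round(-0.1) = 0 and y-offset round(79.9) = 80, shifting boxes 80 px down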
gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1]) # gain = old / new
bboxes *= gain
bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
elif masks is not None:
# Resize and process masks
resized_masks = super().pre_transform(masks)
masks = np.stack(resized_masks) # (N, H, W)
masks[masks == 114] = 0 # Reset padding values to 0
else:
raise ValueError("Please provide valid bboxes or masks")
# Generate visuals using the visual prompt loader
return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
def inference(self, im, *args, **kwargs):
"""
Run inference with visual prompts.
Args:
im (torch.Tensor): Input image tensor.
*args (Any): Variable length argument list.
**kwargs (Any): Arbitrary keyword arguments.
Returns:
(torch.Tensor): Model prediction results.
"""
return super().inference(im, vpe=self.prompts, *args, **kwargs)
def get_vpe(self, source):
"""
Process the source to get the visual prompt embeddings (VPE).
Args:
source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
of the image to make predictions on. Accepts various types including file paths, URLs, PIL
images, numpy arrays, and torch tensors.
Returns:
(torch.Tensor): The visual prompt embeddings (VPE) from the model.
"""
self.setup_source(source)
assert len(self.dataset) == 1, "get_vpe only supports one image!"
for _, im0s, _ in self.dataset:
im = self.preprocess(im0s)
return self.model(im, vpe=self.prompts, return_vpe=True)
class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
"""Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities."""
pass

View File

@@ -0,0 +1,300 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy, deepcopy
from pathlib import Path
import torch
from ultralytics.data import YOLOConcatDataset, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model
from ..world.train_world import WorldTrainerFromScratch
from .val import YOLOEDetectValidator
class YOLOETrainer(DetectionTrainer):
"""
A trainer class for YOLOE object detection models.
This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
including custom model initialization, validation, and dataset building with multi-modal support.
Attributes:
loss_names (tuple): Names of loss components used during training.
Methods:
get_model: Initialize and return a YOLOEModel with specified configuration.
get_validator: Return a YOLOEDetectValidator for model validation.
build_dataset: Build YOLO dataset with multi-modal support for training.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
"""
Initialize the YOLOE Trainer with specified configurations.
Args:
cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be applied during training.
"""
if overrides is None:
overrides = {}
        assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
overrides["overlap_mask"] = False
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return a YOLOEModel initialized with the specified configuration and weights.
Args:
cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
a direct path to a YAML file, or None to use default configuration.
weights (str | Path, optional): Path to pretrained weights file to load into the model.
verbose (bool): Whether to display model information during initialization.
Returns:
(YOLOEModel): The initialized YOLOE model.
Notes:
- The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
- The nc parameter here represents the maximum number of different text samples in one image,
rather than the actual number of classes.
"""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = YOLOEModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=self.data["channels"],
nc=min(self.data["nc"], 80),
verbose=verbose and RANK == -1,
)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return a YOLOEDetectValidator for YOLOE model validation."""
self.loss_names = "box", "cls", "dfl"
return YOLOEDetectValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
batch (int, optional): Size of batches, this is for rectangular training.
Returns:
(Dataset): YOLO dataset configured for training or validation.
"""
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(
self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
)
class YOLOEPETrainer(DetectionTrainer):
"""
Fine-tune YOLOE model using linear probing approach.
This trainer freezes most model layers and only trains specific projection layers for efficient
fine-tuning on new datasets while preserving pretrained features.
Methods:
get_model: Initialize YOLOEModel with frozen layers except projection layers.
"""
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return YOLOEModel initialized with specified config and weights.
Args:
cfg (dict | str, optional): Model configuration.
weights (str, optional): Path to pretrained weights.
verbose (bool): Whether to display model information.
Returns:
(YOLOEModel): Initialized model with frozen layers except for specific projection layers.
"""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = YOLOEModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=self.data["channels"],
nc=self.data["nc"],
verbose=verbose and RANK == -1,
)
del model.model[-1].savpe
assert weights is not None, "Pretrained weights must be provided for linear probing."
if weights:
model.load(weights)
model.eval()
names = list(self.data["names"].values())
        # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
        # it produces correct results as long as proper pretrained weights are loaded.
tpe = model.get_text_pe(names)
model.set_classes(names, tpe)
model.model[-1].fuse(model.pe) # fuse text embeddings to classify head
model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
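        # replace the last conv of each classification branch (cv3[i][2]) with an independent,
        # trainable copy; the rest of the network stays frozen for linear probing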
del model.pe
model.train()
return model
class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
"""
Train YOLOE models from scratch with text embedding support.
This trainer combines YOLOE training capabilities with world training features, enabling
training from scratch with text embeddings and grounding datasets.
Methods:
build_dataset: Build datasets for training with grounding support.
generate_text_embeddings: Generate and cache text embeddings for training.
"""
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
This method constructs appropriate datasets based on the mode and input paths, handling both
standard YOLO datasets and grounding datasets with different formats.
Args:
img_path (list[str] | str): Path to the folder containing images or list of paths.
mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
batch (int, optional): Size of batches, used for rectangular training/validation.
Returns:
(YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
"""
return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
"""
Generate text embeddings for a list of text samples.
Args:
texts (list[str]): List of text samples to encode.
batch (int): Batch size for processing.
cache_dir (Path): Directory to save/load cached embeddings.
Returns:
(dict): Dictionary mapping text samples to their embeddings.
"""
model = "mobileclip:blt"
cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
if cache_path.exists():
LOGGER.info(f"Reading existed cache from '{cache_path}'")
txt_map = torch.load(cache_path, map_location=self.device)
if sorted(txt_map.keys()) == sorted(texts):
return txt_map
LOGGER.info(f"Caching text embeddings to '{cache_path}'")
assert self.model is not None
txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
txt_map = dict(zip(texts, txt_feats.squeeze(0)))
torch.save(txt_map, cache_path)
return txt_map
class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
"""
Train prompt-free YOLOE model.
This trainer combines linear probing capabilities with from-scratch training for prompt-free
YOLOE models that don't require text prompts during inference.
Methods:
get_validator: Return standard DetectionValidator for validation.
preprocess_batch: Preprocess batches without text features.
set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).
"""
def get_validator(self):
"""Return a DetectionValidator for YOLO model validation."""
self.loss_names = "box", "cls", "dfl"
return DetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def preprocess_batch(self, batch):
"""Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
return DetectionTrainer.preprocess_batch(self, batch)
def set_text_embeddings(self, datasets, batch: int):
"""
Set text embeddings for datasets to accelerate training by caching category names.
This method collects unique category names from all datasets, generates text embeddings for them,
and caches these embeddings to improve training efficiency. The embeddings are stored in a file
in the parent directory of the first dataset's image path.
Args:
datasets (list[Dataset]): List of datasets containing category names to process.
batch (int): Batch size for processing text embeddings.
Notes:
The method creates a dictionary mapping text samples to their embeddings and stores it
at the path specified by 'cache_path'. If the cache file already exists, it will be loaded
instead of regenerating the embeddings.
"""
pass
class YOLOEVPTrainer(YOLOETrainerFromScratch):
"""
Train YOLOE model with visual prompts.
This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
where visual cues are provided alongside images to guide the detection process.
Methods:
build_dataset: Build dataset with visual prompt loading transforms.
"""
def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation with visual prompts.
Args:
img_path (list[str] | str): Path to the folder containing images or list of paths.
mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
batch (int, optional): Size of batches, used for rectangular training/validation.
Returns:
(Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.
"""
dataset = super().build_dataset(img_path, mode, batch)
if isinstance(dataset, YOLOConcatDataset):
for d in dataset.datasets:
d.transforms.append(LoadVisualPrompt())
else:
dataset.transforms.append(LoadVisualPrompt())
return dataset
def _close_dataloader_mosaic(self):
"""Close mosaic augmentation and add visual prompt loading to the training dataset."""
super()._close_dataloader_mosaic()
if isinstance(self.train_loader.dataset, YOLOConcatDataset):
for d in self.train_loader.dataset.datasets:
d.transforms.append(LoadVisualPrompt())
else:
self.train_loader.dataset.transforms.append(LoadVisualPrompt())

View File

@@ -0,0 +1,127 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from copy import copy, deepcopy
from ultralytics.models.yolo.segment import SegmentationTrainer
from ultralytics.nn.tasks import YOLOESegModel
from ultralytics.utils import RANK
from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .val import YOLOESegValidator
class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
"""
Trainer class for YOLOE segmentation models.
This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
segmentation models, enabling both object detection and instance segmentation capabilities.
Attributes:
cfg (dict): Configuration dictionary with training parameters.
overrides (dict): Dictionary with parameter overrides.
_callbacks (list): List of callback functions for training events.
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""
Return YOLOESegModel initialized with specified config and weights.
Args:
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
weights (str, optional): Path to pretrained weights file.
verbose (bool): Whether to display model information.
Returns:
(YOLOESegModel): Initialized YOLOE segmentation model.
"""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = YOLOESegModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=self.data["channels"],
nc=min(self.data["nc"], 80),
verbose=verbose and RANK == -1,
)
if weights:
model.load(weights)
return model
def get_validator(self):
"""
Create and return a validator for YOLOE segmentation model evaluation.
Returns:
(YOLOESegValidator): Validator for YOLOE segmentation models.
"""
self.loss_names = "box", "seg", "cls", "dfl"
return YOLOESegValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
class YOLOEPESegTrainer(SegmentationTrainer):
"""
    Fine-tune a YOLOESeg model using a linear probing approach.
This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
most of the model and only training specific layers for efficient adaptation to new tasks.
Attributes:
data (dict): Dataset configuration containing channels, class names, and number of classes.
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""
Return YOLOESegModel initialized with specified config and weights for linear probing.
Args:
cfg (dict | str, optional): Model configuration dictionary or YAML file path.
weights (str, optional): Path to pretrained weights file.
verbose (bool): Whether to display model information.
Returns:
(YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.
"""
# NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
# NOTE: Following the official config, nc hard-coded to 80 for now.
model = YOLOESegModel(
cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
ch=self.data["channels"],
nc=self.data["nc"],
verbose=verbose and RANK == -1,
)
del model.model[-1].savpe
assert weights is not None, "Pretrained weights must be provided for linear probing."
if weights:
model.load(weights)
model.eval()
names = list(self.data["names"].values())
        # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
        # it produces correct results as long as proper pretrained weights are loaded.
tpe = model.get_text_pe(names)
model.set_classes(names, tpe)
model.model[-1].fuse(model.pe)
model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
del model.pe
model.train()
return model
class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
"""Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""
pass
class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
"""Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities."""
pass

View File

@@ -0,0 +1,211 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import deepcopy
from pathlib import Path
from typing import Any
import torch
from torch.nn import functional as F
from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.nn.modules.head import YOLOEDetect
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import LOGGER, TQDM
from ultralytics.utils.torch_utils import select_device, smart_inference_mode
class YOLOEDetectValidator(DetectionValidator):
"""
A validator class for YOLOE detection models that handles both text and visual prompt embeddings.
This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
enabling flexible evaluation strategies for prompt-based object detection.
Attributes:
device (torch.device): The device on which validation is performed.
args (namespace): Configuration arguments for validation.
dataloader (DataLoader): DataLoader for validation data.
Methods:
get_visual_pe: Extract visual prompt embeddings from training samples.
preprocess: Preprocess batch data ensuring visuals are on the same device as images.
get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
__call__: Run validation using either text or visual prompt embeddings.
Examples:
Validate with text prompts
>>> validator = YOLOEDetectValidator()
>>> stats = validator(model=model, load_vp=False)
Validate with visual prompts
>>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
"""
@smart_inference_mode()
def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
"""
Extract visual prompt embeddings from training samples.
This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
It normalizes the embeddings and handles cases where no samples exist for a class by setting their
embeddings to zero.
Args:
dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.
Returns:
(torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
"""
assert isinstance(model, YOLOEModel)
names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
cls_visual_num = torch.zeros(len(names))
desc = "Get visual prompt embeddings from samples"
# Count samples per class
for batch in dataloader:
cls = batch["cls"].squeeze(-1).to(torch.int).unique()
count = torch.bincount(cls, minlength=len(names))
cls_visual_num += count
cls_visual_num = cls_visual_num.to(self.device)
# Extract visual prompt embeddings
pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
for batch in pbar:
batch = self.preprocess(batch)
preds = model.get_visual_pe(batch["img"], visual=batch["visuals"]) # (B, max_n, embed_dim)
batch_idx = batch["batch_idx"]
for i in range(preds.shape[0]):
cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
pad_cls[: cls.shape[0]] = cls
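                # dividing each per-image sum by the class's total sample count accumulates an
                # approximate running-mean embedding for each class across the dataloader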
for c in cls:
visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]
# Normalize embeddings for classes with samples, set others to zero
visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
visual_pe[cls_visual_num == 0] = 0
return visual_pe.unsqueeze(0)
def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
"""
Create a dataloader for LVIS training visual prompt samples.
This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
for validation purposes.
Args:
data (dict): Dataset configuration dictionary containing paths and settings.
Returns:
(torch.utils.data.DataLoader): The dataloader for visual prompt samples.
"""
dataset = build_yolo_dataset(
self.args,
data.get(self.args.split, data.get("val")),
self.args.batch,
data,
mode="val",
rect=False,
)
if isinstance(dataset, YOLOConcatDataset):
for d in dataset.datasets:
d.transforms.append(LoadVisualPrompt())
else:
dataset.transforms.append(LoadVisualPrompt())
return build_dataloader(
dataset,
self.args.batch,
self.args.workers,
shuffle=False,
rank=-1,
)
@smart_inference_mode()
def __call__(
self,
trainer: Any | None = None,
model: YOLOEModel | str | None = None,
refer_data: str | None = None,
load_vp: bool = False,
) -> dict[str, Any]:
"""
Run validation on the model using either text or visual prompt embeddings.
This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
It supports validation during training (using a trainer object) or standalone validation with a provided
model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.
Args:
trainer (object, optional): Trainer object containing the model and device.
model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
refer_data (str, optional): Path to reference data for visual prompts.
load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
Returns:
(dict): Validation statistics containing metrics computed during validation.
"""
if trainer is not None:
self.device = trainer.device
model = trainer.ema.ema
names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]
if load_vp:
LOGGER.info("Validate using the visual prompt.")
self.args.half = False
# Directly use the same dataloader for visual embeddings extracted during training
vpe = self.get_visual_pe(self.dataloader, model)
model.set_classes(names, vpe)
else:
LOGGER.info("Validate using the text prompt.")
tpe = model.get_text_pe(names)
model.set_classes(names, tpe)
stats = super().__call__(trainer, model)
else:
if refer_data is not None:
assert load_vp, "Refer data is only used for visual prompt validation."
self.device = select_device(self.args.device, verbose=False)
if isinstance(model, (str, Path)):
from ultralytics.nn.tasks import load_checkpoint
model, _ = load_checkpoint(model, device=self.device) # model, ckpt
model.eval().to(self.device)
data = check_det_dataset(refer_data or self.args.data)
names = [name.split("/", 1)[0] for name in list(data["names"].values())]
if load_vp:
LOGGER.info("Validate using the visual prompt.")
self.args.half = False
                # TODO: check that the names from the refer data are consistent with the evaluated dataset
                # either the same dataset or a separate refer dataset can be used to extract visual prompt embeddings
dataloader = self.get_vpe_dataloader(data)
vpe = self.get_visual_pe(dataloader, model)
model.set_classes(names, vpe)
stats = super().__call__(model=deepcopy(model))
elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"): # prompt-free
return super().__call__(trainer, model)
else:
LOGGER.info("Validate using the text prompt.")
tpe = model.get_text_pe(names)
model.set_classes(names, tpe)
stats = super().__call__(model=deepcopy(model))
return stats
class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
"""YOLOE segmentation validator that supports both text and visual prompt embeddings."""
pass