init commit
7
ultralytics/models/yolo/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe

from .model import YOLO, YOLOE, YOLOWorld

__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
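For orientation, a minimal usage sketch of the classes re-exported above, assuming the standard Ultralytics Model API; the image path is hypothetical:

# Minimal sketch, assuming yolo11n.pt weights are available (auto-downloaded by the library)
from ultralytics.models.yolo import YOLO

model = YOLO("yolo11n.pt")
results = model.predict("bus.jpg")  # "bus.jpg" is a hypothetical local image
print(results[0].boxes)  # detections for the first image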
BIN
ultralytics/models/yolo/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/__pycache__/model.cpython-310.pyc
Normal file
Binary file not shown.
7
ultralytics/models/yolo/classify/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator

__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
Binary files not shown.
BIN
ultralytics/models/yolo/classify/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
93
ultralytics/models/yolo/classify/predict.py
Normal file
@@ -0,0 +1,93 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import cv2
import torch
from PIL import Image

from ultralytics.data.augment import classify_transforms
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops


class ClassificationPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a classification model.

    This predictor handles the specific requirements of classification models, including preprocessing images
    and postprocessing predictions to generate classification results.

    Attributes:
        args (dict): Configuration arguments for the predictor.

    Methods:
        preprocess: Convert input images to model-compatible format.
        postprocess: Process model predictions into Results objects.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.classify import ClassificationPredictor
        >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
        >>> predictor = ClassificationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.

        This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for
        classification tasks. It ensures the task is set to 'classify' regardless of input configuration.

        Args:
            cfg (dict): Default configuration dictionary containing prediction settings.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be executed during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"

    def setup_source(self, source):
        """Set up source, inference mode, and classify transforms."""
        super().setup_source(source)
        updated = (
            self.model.model.transforms.transforms[0].size != max(self.imgsz)
            if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
            else False
        )
        self.transforms = (
            classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
        )

    def preprocess(self, img):
        """Convert input images to model-compatible tensor format with appropriate normalization."""
        if not isinstance(img, torch.Tensor):
            img = torch.stack(
                [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
            )
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32

    def postprocess(self, preds, img, orig_imgs):
        """
        Process predictions to return Results objects with classification probabilities.

        Args:
            preds (torch.Tensor): Raw predictions from the model.
            img (torch.Tensor): Input images after preprocessing.
            orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.

        Returns:
            (list[Results]): List of Results objects containing classification results for each image.
        """
        if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        preds = preds[0] if isinstance(preds, (list, tuple)) else preds
        return [
            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]
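A minimal sketch of consuming the Results objects produced above, assuming the probs accessor exposed by ultralytics.engine.results and a hypothetical local image:

from ultralytics.models.yolo.classify import ClassificationPredictor

# Sketch only: BasePredictor instances are callable with a source
predictor = ClassificationPredictor(overrides=dict(model="yolo11n-cls.pt"))
for result in predictor("bus.jpg"):  # "bus.jpg" is a hypothetical local image
    print(result.probs.top1, result.probs.top1conf)  # top class index and its confidence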
223
ultralytics/models/yolo/classify/train.py
Normal file
@@ -0,0 +1,223 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from typing import Any

import torch

from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first


class ClassificationTrainer(BaseTrainer):
    """
    A trainer class extending BaseTrainer for training image classification models.

    This trainer handles the training process for image classification tasks, supporting both YOLO
    classification models and torchvision models with comprehensive dataset handling and validation.

    Attributes:
        model (ClassificationModel): The classification model to be trained.
        data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
        loss_names (list[str]): Names of the loss functions used during training.
        validator (ClassificationValidator): Validator instance for model evaluation.

    Methods:
        set_model_attributes: Set the model's class names from the loaded dataset.
        get_model: Return a modified PyTorch model configured for training.
        setup_model: Load, create or download model for classification.
        build_dataset: Create a ClassificationDataset instance.
        get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
        preprocess_batch: Preprocess a batch of images and classes.
        progress_string: Return a formatted string showing training progress.
        get_validator: Return an instance of ClassificationValidator.
        label_loss_items: Return a loss dict with labelled training loss items.
        final_eval: Evaluate trained model and save validation results.
        plot_training_samples: Plot training samples with their annotations.

    Examples:
        Initialize and train a classification model
        >>> from ultralytics.models.yolo.classify import ClassificationTrainer
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
        >>> trainer = ClassificationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a ClassificationTrainer object.

        Args:
            cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
            overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list[Any], optional): List of callback functions to be executed during training.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Return a modified PyTorch model configured for training YOLO classification.

        Args:
            cfg (Any, optional): Model configuration.
            weights (Any, optional): Pre-trained model weights.
            verbose (bool, optional): Whether to display model information.

        Returns:
            (ClassificationModel): Configured PyTorch model for classification.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """
        Load, create or download model for classification tasks.

        Returns:
            (Any): Model checkpoint if applicable, otherwise None.
        """
        import torchvision  # scope for faster 'import ultralytics'

        if str(self.model) in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[self.model](
                weights="IMAGENET1K_V1" if self.args.pretrained else None
            )
            ckpt = None
        else:
            ckpt = super().setup_model()
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
        return ckpt

    def build_dataset(self, img_path: str, mode: str = "train", batch=None):
        """
        Create a ClassificationDataset instance given an image path and mode.

        Args:
            img_path (str): Path to the dataset images.
            mode (str, optional): Dataset mode ('train', 'val', or 'test').
            batch (Any, optional): Batch information (unused in this implementation).

        Returns:
            (ClassificationDataset): Dataset for the specified mode.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
        """
        Return PyTorch DataLoader with transforms to preprocess images.

        Args:
            dataset_path (str): Path to the dataset.
            batch_size (int, optional): Number of images per batch.
            rank (int, optional): Process rank for distributed training.
            mode (str, optional): 'train', 'val', or 'test' mode.

        Returns:
            (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
        """
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Preprocess a batch of images and classes."""
        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
        return batch

    def progress_string(self) -> str:
        """Return a formatted string showing training progress."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Return an instance of ClassificationValidator for validation."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(
            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
        """
        Return a loss dict with labelled training loss items tensor.

        Args:
            loss_items (torch.Tensor, optional): Loss tensor items.
            prefix (str, optional): Prefix to prepend to loss names.

        Returns:
            keys (list[str]): List of loss keys if loss_items is None.
            loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def final_eval(self):
        """Evaluate trained model and save validation results."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")

    def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
        """
        Plot training samples with their annotations.

        Args:
            batch (dict[str, torch.Tensor]): Batch containing images and class labels.
            ni (int): Number of iterations.
        """
        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
        plot_images(
            labels=batch,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )
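Because setup_model above falls back to torchvision.models when the model name matches, a plain torchvision backbone can be trained through the same trainer; a sketch, assuming an imagenet10-style classification dataset:

from ultralytics.models.yolo.classify import ClassificationTrainer

# "resnet18" resolves via torchvision.models.__dict__ in setup_model
trainer = ClassificationTrainer(overrides=dict(model="resnet18", data="imagenet10", epochs=3))
trainer.train()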
214
ultralytics/models/yolo/classify/val.py
Normal file
@@ -0,0 +1,214 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import torch

from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images


class ClassificationValidator(BaseValidator):
    """
    A class extending the BaseValidator class for validation based on a classification model.

    This validator handles the validation process for classification models, including metrics calculation,
    confusion matrix generation, and visualization of results.

    Attributes:
        targets (list[torch.Tensor]): Ground truth class labels.
        pred (list[torch.Tensor]): Model predictions.
        metrics (ClassifyMetrics): Object to calculate and store classification metrics.
        names (dict): Mapping of class indices to class names.
        nc (int): Number of classes.
        confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.

    Methods:
        get_desc: Return a formatted string summarizing classification metrics.
        init_metrics: Initialize confusion matrix, class names, and tracking containers.
        preprocess: Preprocess input batch by moving data to device.
        update_metrics: Update running metrics with model predictions and batch targets.
        finalize_metrics: Finalize metrics including confusion matrix and processing speed.
        postprocess: Extract the primary prediction from model output.
        get_stats: Calculate and return a dictionary of metrics.
        build_dataset: Create a ClassificationDataset instance for validation.
        get_dataloader: Build and return a data loader for classification validation.
        print_results: Print evaluation metrics for the classification model.
        plot_val_samples: Plot validation image samples with their ground truth labels.
        plot_predictions: Plot images with their predicted class labels.

    Examples:
        >>> from ultralytics.models.yolo.classify import ClassificationValidator
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
        >>> validator = ClassificationValidator(args=args)
        >>> validator()

    Notes:
        Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize ClassificationValidator with dataloader, save directory, and other parameters.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (str | Path, optional): Directory to save results.
            args (dict, optional): Arguments containing model and validation configuration.
            _callbacks (list, optional): List of callback functions to be called during validation.

        Examples:
            >>> from ultralytics.models.yolo.classify import ClassificationValidator
            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
            >>> validator = ClassificationValidator(args=args)
            >>> validator()
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.targets = None
        self.pred = None
        self.args.task = "classify"
        self.metrics = ClassifyMetrics()

    def get_desc(self) -> str:
        """Return a formatted string summarizing classification metrics."""
        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")

    def init_metrics(self, model: torch.nn.Module) -> None:
        """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
        self.names = model.names
        self.nc = len(model.names)
        self.pred = []
        self.targets = []
        self.confusion_matrix = ConfusionMatrix(names=model.names)

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Preprocess input batch by moving data to device and converting to appropriate dtype."""
        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
        return batch

    def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
        """
        Update running metrics with model predictions and batch targets.

        Args:
            preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
            batch (dict): Batch data containing images and class labels.

        Notes:
            This method appends the top-N predictions (sorted by confidence in descending order) to the
            prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
        """
        n5 = min(len(self.names), 5)
        self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
        self.targets.append(batch["cls"].type(torch.int32).cpu())

    def finalize_metrics(self) -> None:
        """
        Finalize metrics including confusion matrix and processing speed.

        Notes:
            This method processes the accumulated predictions and targets to generate the confusion matrix,
            optionally plots it, and updates the metrics object with speed information.

        Examples:
            >>> validator = ClassificationValidator()
            >>> validator.pred = [torch.tensor([[0, 1, 2]])]  # Top-3 predictions for one sample
            >>> validator.targets = [torch.tensor([0])]  # Ground truth class
            >>> validator.finalize_metrics()
            >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
        """
        self.confusion_matrix.process_cls_preds(self.pred, self.targets)
        if self.args.plots:
            for normalize in True, False:
                self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
        self.metrics.speed = self.speed
        self.metrics.save_dir = self.save_dir
        self.metrics.confusion_matrix = self.confusion_matrix

    def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
        """Extract the primary prediction from model output if it's in a list or tuple format."""
        return preds[0] if isinstance(preds, (list, tuple)) else preds

    def get_stats(self) -> dict[str, float]:
        """Calculate and return a dictionary of metrics by processing targets and predictions."""
        self.metrics.process(self.targets, self.pred)
        return self.metrics.results_dict

    def build_dataset(self, img_path: str) -> ClassificationDataset:
        """Create a ClassificationDataset instance for validation."""
        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)

    def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
        """
        Build and return a data loader for classification validation.

        Args:
            dataset_path (str | Path): Path to the dataset directory.
            batch_size (int): Number of samples per batch.

        Returns:
            (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
        """
        dataset = self.build_dataset(dataset_path)
        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

    def print_results(self) -> None:
        """Print evaluation metrics for the classification model."""
        pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
        LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

    def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
        """
        Plot validation image samples with their ground truth labels.

        Args:
            batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
            ni (int): Batch index used for naming the output file.

        Examples:
            >>> validator = ClassificationValidator()
            >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
            >>> validator.plot_val_samples(batch, 0)
        """
        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
        plot_images(
            labels=batch,
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
        """
        Plot images with their predicted class labels and save the visualization.

        Args:
            batch (dict[str, Any]): Batch data containing images and other information.
            preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
            ni (int): Batch index used for naming the output file.

        Examples:
            >>> validator = ClassificationValidator()
            >>> batch = {"img": torch.rand(16, 3, 224, 224)}
            >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
            >>> validator.plot_predictions(batch, preds, 0)
        """
        batched_preds = dict(
            img=batch["img"],
            batch_idx=torch.arange(batch["img"].shape[0]),
            cls=torch.argmax(preds, dim=1),
        )
        plot_images(
            batched_preds,
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred
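The top-1/top-5 bookkeeping in update_metrics reduces to a membership test over the ranked class indices; a standalone sketch of the same arithmetic, independent of the validator:

import torch

preds = torch.rand(8, 100)  # logits for 8 samples, 100 classes
targets = torch.randint(0, 100, (8,))
n5 = min(preds.shape[1], 5)
top5 = preds.argsort(1, descending=True)[:, :n5]  # same ranking as update_metrics
top1_acc = (top5[:, 0] == targets).float().mean().item()
top5_acc = (top5 == targets[:, None]).any(1).float().mean().item()
print(top1_acc, top5_acc)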
7
ultralytics/models/yolo/detect/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator

__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
Binary files not shown.
BIN
ultralytics/models/yolo/detect/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/detect/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
125
ultralytics/models/yolo/detect/predict.py
Normal file
@@ -0,0 +1,125 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import nms, ops


class DetectionPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a detection model.

    This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
    with bounding boxes and class predictions.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (nn.Module): The detection model used for inference.
        batch (list): Batch of images and metadata for processing.

    Methods:
        postprocess: Process raw model predictions into detection results.
        construct_results: Build Results objects from processed predictions.
        construct_result: Create a single Result object from a prediction.
        get_obj_feats: Extract object features from the feature maps.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.detect import DetectionPredictor
        >>> args = dict(model="yolo11n.pt", source=ASSETS)
        >>> predictor = DetectionPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def postprocess(self, preds, img, orig_imgs, **kwargs):
        """
        Post-process predictions and return a list of Results objects.

        This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
        further analysis.

        Args:
            preds (torch.Tensor): Raw predictions from the model.
            img (torch.Tensor): Processed input image tensor in model input format.
            orig_imgs (torch.Tensor | list): Original input images before preprocessing.
            **kwargs (Any): Additional keyword arguments.

        Returns:
            (list): List of Results objects containing the post-processed predictions.

        Examples:
            >>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
            >>> results = predictor.predict("path/to/image.jpg")
            >>> processed_results = predictor.postprocess(preds, img, orig_imgs)
        """
        save_feats = getattr(self, "_feats", None) is not None
        preds = nms.non_max_suppression(
            preds,
            self.args.conf,
            self.args.iou,
            self.args.classes,
            self.args.agnostic_nms,
            max_det=self.args.max_det,
            nc=0 if self.args.task == "detect" else len(self.model.names),
            end2end=getattr(self.model, "end2end", False),
            rotated=self.args.task == "obb",
            return_idxs=save_feats,
        )

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        if save_feats:
            obj_feats = self.get_obj_feats(self._feats, preds[1])
            preds = preds[0]

        results = self.construct_results(preds, img, orig_imgs, **kwargs)

        if save_feats:
            for r, f in zip(results, obj_feats):
                r.feats = f  # add object features to results

        return results

    def get_obj_feats(self, feat_maps, idxs):
        """Extract object features from the feature maps."""
        import torch

        s = min(x.shape[1] for x in feat_maps)  # find shortest vector length
        obj_feats = torch.cat(
            [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
        )  # mean reduce all vectors to same length
        return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)]  # for each img in batch

    def construct_results(self, preds, img, orig_imgs):
        """
        Construct a list of Results objects from model predictions.

        Args:
            preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
            img (torch.Tensor): Batch of preprocessed images used for inference.
            orig_imgs (list[np.ndarray]): List of original images before preprocessing.

        Returns:
            (list[Results]): List of Results objects containing detection information for each image.
        """
        return [
            self.construct_result(pred, img, orig_img, img_path)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]

    def construct_result(self, pred, img, orig_img, img_path):
        """
        Construct a single Results object from one image prediction.

        Args:
            pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
            img (torch.Tensor): Preprocessed image tensor used for inference.
            orig_img (np.ndarray): Original image before preprocessing.
            img_path (str): Path to the original image file.

        Returns:
            (Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
        """
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
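get_obj_feats equalizes per-level channel widths by folding each channel vector into (s, C // s) groups and averaging; a standalone sketch of that reduction with two hypothetical feature maps:

import torch

feat_maps = [torch.rand(1, 64, 40, 40), torch.rand(1, 128, 20, 20)]  # hypothetical P3/P4 maps
s = min(x.shape[1] for x in feat_maps)  # shortest channel width (64 here)
pooled = torch.cat(
    [x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
)
print(pooled.shape)  # torch.Size([1, 2000, 64]): one 64-dim vector per anchor location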
236
ultralytics/models/yolo/detect/train.py
Normal file
@@ -0,0 +1,236 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import math
import random
from copy import copy
from typing import Any

import numpy as np
import torch
import torch.nn as nn

from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.patches import override_configs
from ultralytics.utils.plotting import plot_images, plot_labels
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model


class DetectionTrainer(BaseTrainer):
    """
    A class extending the BaseTrainer class for training based on a detection model.

    This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
    for object detection including dataset building, data loading, preprocessing, and model configuration.

    Attributes:
        model (DetectionModel): The YOLO detection model being trained.
        data (dict): Dictionary containing dataset information including class names and number of classes.
        loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).

    Methods:
        build_dataset: Build YOLO dataset for training or validation.
        get_dataloader: Construct and return dataloader for the specified mode.
        preprocess_batch: Preprocess a batch of images by scaling and converting to float.
        set_model_attributes: Set model attributes based on dataset information.
        get_model: Return a YOLO detection model.
        get_validator: Return a validator for model evaluation.
        label_loss_items: Return a loss dictionary with labeled training loss items.
        progress_string: Return a formatted string of training progress.
        plot_training_samples: Plot training samples with their annotations.
        plot_training_labels: Create a labeled training plot of the YOLO model.
        auto_batch: Calculate optimal batch size based on model memory requirements.

    Examples:
        >>> from ultralytics.models.yolo.detect import DetectionTrainer
        >>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
        >>> trainer = DetectionTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a DetectionTrainer object for training a YOLO object detection model.

        Args:
            cfg (dict, optional): Default configuration dictionary containing training parameters.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.
        """
        super().__init__(cfg, overrides, _callbacks)

    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for 'rect' mode.

        Returns:
            (Dataset): YOLO dataset object configured for the specified mode.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)

    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
        """
        Construct and return dataloader for the specified mode.

        Args:
            dataset_path (str): Path to the dataset.
            batch_size (int): Number of images per batch.
            rank (int): Process rank for distributed training.
            mode (str): 'train' for training dataloader, 'val' for validation dataloader.

        Returns:
            (DataLoader): PyTorch dataloader object.
        """
        assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode, batch_size)
        shuffle = mode == "train"
        if getattr(dataset, "rect", False) and shuffle:
            LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
            shuffle = False
        return build_dataloader(
            dataset,
            batch=batch_size,
            workers=self.args.workers if mode == "train" else self.args.workers * 2,
            shuffle=shuffle,
            rank=rank,
            drop_last=self.args.compile and mode == "train",
        )

    def preprocess_batch(self, batch: dict) -> dict:
        """
        Preprocess a batch of images by scaling and converting to float.

        Args:
            batch (dict): Dictionary containing batch data with 'img' tensor.

        Returns:
            (dict): Preprocessed batch with normalized images.
        """
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
        batch["img"] = batch["img"].float() / 255
        if self.args.multi_scale:
            imgs = batch["img"]
            sz = (
                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
                // self.stride
                * self.stride
            )  # size
            sf = sz / max(imgs.shape[2:])  # scale factor
            if sf != 1:
                ns = [
                    math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
                ]  # new shape (stretched to gs-multiple)
                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
            batch["img"] = imgs
        return batch

    def set_model_attributes(self):
        """Set model attributes based on dataset information."""
        # nl = de_parallel(self.model).model[-1].nl  # number of detection layers (to scale hyps)
        # self.args.box *= 3 / nl  # scale to layers
        # self.args.cls *= self.data["nc"] / 80 * 3 / nl  # scale to classes and layers
        # self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl  # scale to image size and layers
        self.model.nc = self.data["nc"]  # attach number of classes to model
        self.model.names = self.data["names"]  # attach class names to model
        self.model.args = self.args  # attach hyperparameters to model
        # TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc

    def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
        """
        Return a YOLO detection model.

        Args:
            cfg (str, optional): Path to model configuration file.
            weights (str, optional): Path to model weights.
            verbose (bool): Whether to display model information.

        Returns:
            (DetectionModel): YOLO detection model.
        """
        model = DetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def get_validator(self):
        """Return a DetectionValidator for YOLO model validation."""
        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
        return yolo.detect.DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
        """
        Return a loss dict with labeled training loss items tensor.

        Args:
            loss_items (list[float], optional): List of loss values.
            prefix (str): Prefix for keys in the returned dictionary.

        Returns:
            (dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is not None:
            loss_items = [round(float(x), 5) for x in loss_items]  # convert tensors to 5 decimal place floats
            return dict(zip(keys, loss_items))
        else:
            return keys

    def progress_string(self):
        """Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
        """
        Plot training samples with their annotations.

        Args:
            batch (dict[str, Any]): Dictionary containing batch data.
            ni (int): Number of iterations.
        """
        plot_images(
            labels=batch,
            paths=batch["im_file"],
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )

    def plot_training_labels(self):
        """Create a labeled training plot of the YOLO model."""
        boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
        cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
        plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)

    def auto_batch(self):
        """
        Get optimal batch size by calculating memory occupation of model.

        Returns:
            (int): Optimal batch size.
        """
        with override_configs(self.args, overrides={"cache": False}) as self.args:
            train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
        max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4  # 4 for mosaic augmentation
        del train_dataset  # free memory
        return super().auto_batch(max_num_obj)
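The multi_scale branch of preprocess_batch draws a random stride-aligned size in [0.5, 1.5] x imgsz and rescales the whole batch; the same arithmetic in isolation:

import math
import random

import torch
import torch.nn as nn

imgsz, stride = 640, 32
imgs = torch.rand(2, 3, imgsz, imgsz)
sz = random.randrange(int(imgsz * 0.5), int(imgsz * 1.5 + stride)) // stride * stride  # stride-aligned target
sf = sz / max(imgs.shape[2:])  # scale factor
if sf != 1:
    ns = [math.ceil(x * sf / stride) * stride for x in imgs.shape[2:]]  # new stride-multiple shape
    imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
print(imgs.shape)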
495
ultralytics/models/yolo/detect/val.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
|
||||
from ultralytics.engine.validator import BaseValidator
|
||||
from ultralytics.utils import LOGGER, nms, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
class DetectionValidator(BaseValidator):
|
||||
"""
|
||||
A class extending the BaseValidator class for validation based on a detection model.
|
||||
|
||||
This class implements validation functionality specific to object detection tasks, including metrics calculation,
|
||||
prediction processing, and visualization of results.
|
||||
|
||||
Attributes:
|
||||
is_coco (bool): Whether the dataset is COCO.
|
||||
is_lvis (bool): Whether the dataset is LVIS.
|
||||
class_map (list[int]): Mapping from model class indices to dataset class indices.
|
||||
metrics (DetMetrics): Object detection metrics calculator.
|
||||
iouv (torch.Tensor): IoU thresholds for mAP calculation.
|
||||
niou (int): Number of IoU thresholds.
|
||||
lb (list[Any]): List for storing ground truth labels for hybrid saving.
|
||||
jdict (list[dict[str, Any]]): List for storing JSON detection results.
|
||||
stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.detect import DetectionValidator
|
||||
>>> args = dict(model="yolo11n.pt", data="coco8.yaml")
|
||||
>>> validator = DetectionValidator(args=args)
|
||||
>>> validator()
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize detection validator with necessary variables and settings.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
save_dir (Path, optional): Directory to save results.
|
||||
args (dict[str, Any], optional): Arguments for the validator.
|
||||
_callbacks (list[Any], optional): List of callback functions.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.is_coco = False
|
||||
self.is_lvis = False
|
||||
self.class_map = None
|
||||
self.args.task = "detect"
|
||||
self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
|
||||
self.niou = self.iouv.numel()
|
||||
self.metrics = DetMetrics()
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Preprocess batch of images for YOLO validation.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Preprocessed batch.
|
||||
"""
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO detection validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
val = self.data.get(self.args.split, "") # validation path
|
||||
self.is_coco = (
|
||||
isinstance(val, str)
|
||||
and "coco" in val
|
||||
and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
|
||||
) # is COCO
|
||||
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
|
||||
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
|
||||
self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.end2end = getattr(model, "end2end", False)
|
||||
self.seen = 0
|
||||
self.jdict = []
|
||||
self.metrics.names = model.names
|
||||
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return a formatted string summarizing class metrics of YOLO model."""
|
||||
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Apply Non-maximum suppression to prediction outputs.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
||||
Returns:
|
||||
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
|
||||
'bboxes', 'conf', 'cls', and 'extra' tensors.
|
||||
"""
|
||||
outputs = nms.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
nc=0 if self.args.task == "detect" else self.nc,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
end2end=self.end2end,
|
||||
rotated=self.args.task == "obb",
|
||||
)
|
||||
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch of images and annotations for validation.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
batch (dict[str, Any]): Batch data containing images and annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch with processed annotations.
|
||||
"""
|
||||
idx = batch["batch_idx"] == si
|
||||
cls = batch["cls"][idx].squeeze(-1)
|
||||
bbox = batch["bboxes"][idx]
|
||||
ori_shape = batch["ori_shape"][si]
|
||||
imgsz = batch["img"].shape[2:]
|
||||
ratio_pad = batch["ratio_pad"][si]
|
||||
if cls.shape[0]:
|
||||
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
|
||||
return {
|
||||
"cls": cls,
|
||||
"bboxes": bbox,
|
||||
"ori_shape": ori_shape,
|
||||
"imgsz": imgsz,
|
||||
"ratio_pad": ratio_pad,
|
||||
"im_file": batch["im_file"][si],
|
||||
}
|
||||
|
||||
def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Prepare predictions for evaluation against ground truth.
|
||||
|
||||
Args:
|
||||
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Prepared predictions in native space.
|
||||
"""
|
||||
if self.args.single_cls:
|
||||
pred["cls"] *= 0
|
||||
return pred
|
||||
|
||||
def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update metrics with new predictions and ground truth.
|
||||
|
||||
Args:
|
||||
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
|
||||
batch (dict[str, Any]): Batch data containing ground truth.
|
||||
"""
|
||||
for si, pred in enumerate(preds):
|
||||
self.seen += 1
|
||||
pbatch = self._prepare_batch(si, batch)
|
||||
predn = self._prepare_pred(pred)
|
||||
|
||||
cls = pbatch["cls"].cpu().numpy()
|
||||
no_pred = predn["cls"].shape[0] == 0
|
||||
self.metrics.update_stats(
|
||||
{
|
||||
**self._process_batch(predn, pbatch),
|
||||
"target_cls": cls,
|
||||
"target_img": np.unique(cls),
|
||||
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
|
||||
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
|
||||
}
|
||||
)
|
||||
# Evaluate
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
|
||||
if self.args.visualize:
|
||||
self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)
|
||||
|
||||
if no_pred:
|
||||
continue
|
||||
|
||||
# Save
|
||||
if self.args.save_json or self.args.save_txt:
|
||||
predn_scaled = self.scale_preds(predn, pbatch)
|
||||
if self.args.save_json:
|
||||
self.pred_to_json(predn_scaled, pbatch)
|
||||
if self.args.save_txt:
|
||||
self.save_one_txt(
|
||||
predn_scaled,
|
||||
self.args.save_conf,
|
||||
pbatch["ori_shape"],
|
||||
self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
|
||||
)
|
||||
|
||||
def finalize_metrics(self) -> None:
|
||||
"""Set final values for metrics speed and confusion matrix."""
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
self.metrics.save_dir = self.save_dir
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate and return metrics statistics.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Dictionary containing metrics results.
|
||||
"""
|
||||
self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
|
||||
self.metrics.clear_stats()
|
||||
return self.metrics.results_dict
|
||||
|
||||
def print_results(self) -> None:
|
||||
"""Print training/validation set metrics per class."""
|
||||
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||
if self.metrics.nt_per_class.sum() == 0:
|
||||
LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
|
||||
|
||||
# Print results per class
|
||||
if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
|
||||
for i, c in enumerate(self.metrics.ap_class_index):
|
||||
LOGGER.info(
|
||||
pf
|
||||
% (
|
||||
self.names[c],
|
||||
self.metrics.nt_per_image[c],
|
||||
self.metrics.nt_per_class[c],
|
||||
*self.metrics.class_result(i),
|
||||
)
|
||||
)
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
|
||||
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
|
||||
"""
|
||||
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
|
||||
iou = box_iou(batch["bboxes"], preds["bboxes"])
|
||||
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
|
||||
"""
|
||||
Build YOLO Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`.
|
||||
|
||||
Returns:
|
||||
(Dataset): YOLO dataset.
|
||||
"""
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Construct and return dataloader.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
batch_size (int): Size of each batch.
|
||||
|
||||
Returns:
|
||||
(torch.utils.data.DataLoader): Dataloader for validation.
|
||||
"""
|
||||
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
|
||||
return build_dataloader(
|
||||
dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
|
||||
)
|
||||
|
||||
    def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
        """
        Plot validation image samples.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.
            ni (int): Batch index.
        """
        plot_images(
            labels=batch,
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(
        self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
    ) -> None:
        """
        Plot predicted bounding boxes on input images and save the result.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.
            preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
            ni (int): Batch index.
            max_det (int, optional): Maximum number of detections to plot.
        """
        # TODO: optimize this
        for i, pred in enumerate(preds):
            pred["batch_idx"] = torch.ones_like(pred["conf"]) * i  # add batch index to predictions
        keys = preds[0].keys()
        max_det = max_det or self.args.max_det
        batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
        # TODO: fix this
        batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4])  # convert to xywh format
        plot_images(
            images=batch["img"],
            labels=batched_preds,
            paths=batch["im_file"],
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred
    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """
        Save YOLO detections to a txt file in normalized coordinates in a specific format.

        Args:
            predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple[int, int]): Shape of the original image (height, width).
            file (Path): File path to save the detections.
        """
        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
        ).save_txt(file, save_conf=save_conf)
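
    # A sketch of the expected txt output (assuming Results.save_txt writes normalized xywh rows; values hypothetical):
    #   <class> <x_center> <y_center> <width> <height> [<conf>]
    #   e.g. "0 0.5 0.5 0.25 0.40 0.92" when save_conf=True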
    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """
        Serialize YOLO predictions to COCO JSON format.

        Args:
            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
                with bounding box coordinates, confidence scores, and class predictions.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

        Examples:
            >>> result = {
            ...     "image_id": 42,
            ...     "file_name": "42.jpg",
            ...     "category_id": 18,
            ...     "bbox": [258.15, 41.29, 348.26, 243.78],
            ...     "score": 0.236,
            ... }
        """
        path = Path(pbatch["im_file"])
        stem = path.stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = ops.xyxy2xywh(predn["bboxes"])  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "file_name": path.name,
                    "category_id": self.class_map[int(c)],
                    "bbox": [round(x, 3) for x in b],
                    "score": round(s, 5),
                }
            )
    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Scale predictions to the original image size."""
        return {
            **predn,
            "bboxes": ops.scale_boxes(
                pbatch["imgsz"],
                predn["bboxes"].clone(),
                pbatch["ori_shape"],
                ratio_pad=pbatch["ratio_pad"],
            ),
        }
    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
        """
        Evaluate YOLO output in JSON format and return performance statistics.

        Args:
            stats (dict[str, Any]): Current statistics dictionary.

        Returns:
            (dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
        """
        pred_json = self.save_dir / "predictions.json"  # predictions
        anno_json = (
            self.data["path"]
            / "annotations"
            / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
        )  # annotations
        return self.coco_evaluate(stats, pred_json, anno_json)
    def coco_evaluate(
        self,
        stats: dict[str, Any],
        pred_json: str,
        anno_json: str,
        iou_types: str | list[str] = "bbox",
        suffix: str | list[str] = "Box",
    ) -> dict[str, Any]:
        """
        Evaluate COCO/LVIS metrics using the faster-coco-eval library.

        Performs evaluation using the faster-coco-eval library to compute mAP metrics for object detection.
        Updates the provided stats dictionary with computed metrics including mAP50, mAP50-95, and
        LVIS-specific metrics if applicable.

        Args:
            stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
            pred_json (str | Path): Path to JSON file containing predictions in COCO format.
            anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
            iou_types (str | list[str]): IoU type(s) for evaluation. Can be a single string or a list of strings.
                Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
            suffix (str | list[str]): Suffix to append to metric names in the stats dictionary. Should correspond
                to iou_types if multiple types are provided. Defaults to "Box".

        Returns:
            (dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
        """
        if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
            LOGGER.info(f"\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...")
            try:
                for x in pred_json, anno_json:
                    assert x.is_file(), f"{x} file not found"
                iou_types = [iou_types] if isinstance(iou_types, str) else iou_types
                suffix = [suffix] if isinstance(suffix, str) else suffix
                check_requirements("faster-coco-eval>=1.6.7")
                from faster_coco_eval import COCO, COCOeval_faster

                anno = COCO(anno_json)
                pred = anno.loadRes(pred_json)
                for i, iou_type in enumerate(iou_types):
                    val = COCOeval_faster(
                        anno, pred, iouType=iou_type, lvis_style=self.is_lvis, print_function=LOGGER.info
                    )
                    val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files]  # images to eval
                    val.evaluate()
                    val.accumulate()
                    val.summarize()

                    # Update mAP50-95 and mAP50
                    stats[f"metrics/mAP50({suffix[i][0]})"] = val.stats_as_dict["AP_50"]
                    stats[f"metrics/mAP50-95({suffix[i][0]})"] = val.stats_as_dict["AP_all"]

                    if self.is_lvis:
                        stats[f"metrics/APr({suffix[i][0]})"] = val.stats_as_dict["APr"]
                        stats[f"metrics/APc({suffix[i][0]})"] = val.stats_as_dict["APc"]
                        stats[f"metrics/APf({suffix[i][0]})"] = val.stats_as_dict["APf"]

                if self.is_lvis:
                    stats["fitness"] = stats["metrics/mAP50-95(B)"]  # always use box mAP50-95 for fitness
            except Exception as e:
                LOGGER.warning(f"faster-coco-eval unable to run: {e}")
        return stats
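
# A minimal end-to-end usage sketch for this validator (assumes COCO-format data and a local checkpoint):
#   >>> from ultralytics.models.yolo.detect import DetectionValidator
#   >>> args = dict(model="yolo11n.pt", data="coco8.yaml", save_json=True)
#   >>> validator = DetectionValidator(args=args)
#   >>> stats = validator(model=args["model"])  # runs eval_json()/coco_evaluate() when save_json=True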
447
ultralytics/models/yolo/model.py
Normal file
@@ -0,0 +1,447 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import torch

from ultralytics.data.build import load_inference_source
from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import (
    ClassificationModel,
    DetectionModel,
    OBBModel,
    PoseModel,
    SegmentationModel,
    WorldModel,
    YOLOEModel,
    YOLOESegModel,
)
from ultralytics.utils import ROOT, YAML

class YOLO(Model):
    """
    YOLO (You Only Look Once) object detection model.

    This class provides a unified interface for YOLO models, automatically switching to specialized model types
    (YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
    detection, segmentation, classification, pose estimation, and oriented bounding box detection.

    Attributes:
        model: The loaded YOLO model instance.
        task: The task type (detect, segment, classify, pose, obb).
        overrides: Configuration overrides for the model.

    Methods:
        __init__: Initialize a YOLO model with automatic type detection.
        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.

    Examples:
        Load a pretrained YOLO11n detection model
        >>> model = YOLO("yolo11n.pt")

        Load a pretrained YOLO11n segmentation model
        >>> model = YOLO("yolo11n-seg.pt")

        Initialize from a YAML configuration
        >>> model = YOLO("yolo11n.yaml")
    """

    def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
        """
        Initialize a YOLO model.

        This constructor initializes a YOLO model, automatically switching to specialized model types
        (YOLOWorld or YOLOE) based on the model filename.

        Args:
            model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
            task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
                Defaults to auto-detection based on model.
            verbose (bool): Display model info on load.

        Examples:
            >>> from ultralytics import YOLO
            >>> model = YOLO("yolo11n.pt")  # load a pretrained YOLO11n detection model
            >>> model = YOLO("yolo11n-seg.pt")  # load a pretrained YOLO11n segmentation model
        """
        path = Path(model if isinstance(model, (str, Path)) else "")
        if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOWorld PyTorch model
            new_instance = YOLOWorld(path, verbose=verbose)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}:  # if YOLOE PyTorch model
            new_instance = YOLOE(path, task=task, verbose=verbose)
            self.__class__ = type(new_instance)
            self.__dict__ = new_instance.__dict__
        else:
            # Continue with default YOLO initialization
            super().__init__(model=model, task=task, verbose=verbose)
            if hasattr(self.model, "model") and "RTDETR" in self.model.model[-1]._get_name():  # if RTDETR head
                from ultralytics import RTDETR

                new_instance = RTDETR(self)
                self.__class__ = type(new_instance)
                self.__dict__ = new_instance.__dict__
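
    # Note: the branches above swap this instance's class in place (`self.__class__` / `self.__dict__`) so that,
    # e.g., YOLO("yolov8s-world.pt") transparently behaves as a YOLOWorld/YOLOE/RTDETR object while still being
    # constructed through the single YOLO entry point.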
    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Map head to model, trainer, validator, and predictor classes."""
        return {
            "classify": {
                "model": ClassificationModel,
                "trainer": yolo.classify.ClassificationTrainer,
                "validator": yolo.classify.ClassificationValidator,
                "predictor": yolo.classify.ClassificationPredictor,
            },
            "detect": {
                "model": DetectionModel,
                "trainer": yolo.detect.DetectionTrainer,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
            },
            "segment": {
                "model": SegmentationModel,
                "trainer": yolo.segment.SegmentationTrainer,
                "validator": yolo.segment.SegmentationValidator,
                "predictor": yolo.segment.SegmentationPredictor,
            },
            "pose": {
                "model": PoseModel,
                "trainer": yolo.pose.PoseTrainer,
                "validator": yolo.pose.PoseValidator,
                "predictor": yolo.pose.PosePredictor,
            },
            "obb": {
                "model": OBBModel,
                "trainer": yolo.obb.OBBTrainer,
                "validator": yolo.obb.OBBValidator,
                "predictor": yolo.obb.OBBPredictor,
            },
        }
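
# A small sketch of how the task map is consumed (hypothetical lookup, mirroring the `_smart_load` pattern
# used elsewhere in this commit):
#   >>> model = YOLO("yolo11n.pt")
#   >>> validator_cls = model.task_map[model.task]["validator"]  # -> yolo.detect.DetectionValidator
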
class YOLOWorld(Model):
    """
    YOLO-World object detection model.

    YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
    without requiring training on specific classes. It extends the YOLO architecture to support real-time
    open-vocabulary detection.

    Attributes:
        model: The loaded YOLO-World model instance.
        task: Always set to 'detect' for object detection.
        overrides: Configuration overrides for the model.

    Methods:
        __init__: Initialize YOLOv8-World model with a pre-trained model file.
        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
        set_classes: Set the model's class names for detection.

    Examples:
        Load a YOLOv8-World model
        >>> model = YOLOWorld("yolov8s-world.pt")

        Set custom classes for detection
        >>> model.set_classes(["person", "car", "bicycle"])
    """

    def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
        """
        Initialize YOLOv8-World model with a pre-trained model file.

        Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
        COCO class names.

        Args:
            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
            verbose (bool): If True, prints additional information during initialization.
        """
        super().__init__(model=model, task="detect", verbose=verbose)

        # Assign default COCO class names when there are no custom names
        if not hasattr(self.model, "names"):
            self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Map head to model, validator, and predictor classes."""
        return {
            "detect": {
                "model": WorldModel,
                "validator": yolo.detect.DetectionValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.world.WorldTrainer,
            }
        }
    def set_classes(self, classes: list[str]) -> None:
        """
        Set the model's class names for detection.

        Args:
            classes (list[str]): A list of categories, i.e. ["person"].
        """
        self.model.set_classes(classes)
        # Remove background if it's given
        background = " "
        if background in classes:
            classes.remove(background)
        self.model.names = classes

        # Reset method class names
        if self.predictor:
            self.predictor.model.names = classes
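
# A minimal open-vocabulary usage sketch (assumes the yolov8s-world.pt checkpoint is available locally):
#   >>> model = YOLOWorld("yolov8s-world.pt")
#   >>> model.set_classes(["person", "bus"])
#   >>> results = model.predict("https://ultralytics.com/images/bus.jpg")
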
class YOLOE(Model):
    """
    YOLOE object detection and segmentation model.

    YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
    improved performance and additional features like visual and text positional embeddings.

    Attributes:
        model: The loaded YOLOE model instance.
        task: The task type (detect or segment).
        overrides: Configuration overrides for the model.

    Methods:
        __init__: Initialize YOLOE model with a pre-trained model file.
        task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
        get_text_pe: Get text positional embeddings for the given texts.
        get_visual_pe: Get visual positional embeddings for the given image and visual features.
        set_vocab: Set vocabulary and class names for the YOLOE model.
        get_vocab: Get vocabulary for the given class names.
        set_classes: Set the model's class names and embeddings for detection.
        val: Validate the model using text or visual prompts.
        predict: Run prediction on images, videos, directories, streams, etc.

    Examples:
        Load a YOLOE detection model
        >>> model = YOLOE("yoloe-11s-seg.pt")

        Set vocabulary and class names
        >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])

        Predict with visual prompts
        >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
        >>> results = model.predict("image.jpg", visual_prompts=prompts)
    """

    def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
        """
        Initialize YOLOE model with a pre-trained model file.

        Args:
            model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
            task (str, optional): Task type for the model. Auto-detected if None.
            verbose (bool): If True, prints additional information during initialization.
        """
        super().__init__(model=model, task=task, verbose=verbose)
    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Map head to model, validator, and predictor classes."""
        return {
            "detect": {
                "model": YOLOEModel,
                "validator": yolo.yoloe.YOLOEDetectValidator,
                "predictor": yolo.detect.DetectionPredictor,
                "trainer": yolo.yoloe.YOLOETrainer,
            },
            "segment": {
                "model": YOLOESegModel,
                "validator": yolo.yoloe.YOLOESegValidator,
                "predictor": yolo.segment.SegmentationPredictor,
                "trainer": yolo.yoloe.YOLOESegTrainer,
            },
        }
    def get_text_pe(self, texts):
        """Get text positional embeddings for the given texts."""
        assert isinstance(self.model, YOLOEModel)
        return self.model.get_text_pe(texts)

    def get_visual_pe(self, img, visual):
        """
        Get visual positional embeddings for the given image and visual features.

        This method extracts positional embeddings from visual features based on the input image. It requires
        that the model is an instance of YOLOEModel.

        Args:
            img (torch.Tensor): Input image tensor.
            visual (torch.Tensor): Visual features extracted from the image.

        Returns:
            (torch.Tensor): Visual positional embeddings.

        Examples:
            >>> model = YOLOE("yoloe-11s-seg.pt")
            >>> img = torch.rand(1, 3, 640, 640)
            >>> visual_features = torch.rand(1, 1, 80, 80)
            >>> pe = model.get_visual_pe(img, visual_features)
        """
        assert isinstance(self.model, YOLOEModel)
        return self.model.get_visual_pe(img, visual)
    def set_vocab(self, vocab: list[str], names: list[str]) -> None:
        """
        Set vocabulary and class names for the YOLOE model.

        This method configures the vocabulary and class names used by the model for text processing and
        classification tasks. The model must be an instance of YOLOEModel.

        Args:
            vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
            names (list[str]): List of class names that the model can detect or classify.

        Raises:
            AssertionError: If the model is not an instance of YOLOEModel.

        Examples:
            >>> model = YOLOE("yoloe-11s-seg.pt")
            >>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
        """
        assert isinstance(self.model, YOLOEModel)
        self.model.set_vocab(vocab, names=names)

    def get_vocab(self, names):
        """Get vocabulary for the given class names."""
        assert isinstance(self.model, YOLOEModel)
        return self.model.get_vocab(names)
    def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
        """
        Set the model's class names and embeddings for detection.

        Args:
            classes (list[str]): A list of categories, i.e. ["person"].
            embeddings (torch.Tensor, optional): Embeddings corresponding to the classes. Generated from the
                class names via get_text_pe() when not provided.
        """
        assert isinstance(self.model, YOLOEModel)
        if embeddings is None:
            embeddings = self.get_text_pe(classes)  # generate text embeddings if not provided
        self.model.set_classes(classes, embeddings)
        # Verify no background class is present
        assert " " not in classes
        self.model.names = classes

        # Reset method class names
        if self.predictor:
            self.predictor.model.names = classes
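
    # A short text-prompt sketch for the method above (class names and checkpoint are illustrative):
    #   >>> model = YOLOE("yoloe-11s-seg.pt")
    #   >>> model.set_classes(["person", "car"])  # embeddings generated via get_text_pe() when omitted
    #   >>> results = model.predict("image.jpg")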
    def val(
        self,
        validator=None,
        load_vp: bool = False,
        refer_data: str | None = None,
        **kwargs,
    ):
        """
        Validate the model using text or visual prompts.

        Args:
            validator (callable, optional): A callable validator function. If None, a default validator is loaded.
            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
            refer_data (str, optional): Path to the reference data for visual prompts.
            **kwargs (Any): Additional keyword arguments to override default settings.

        Returns:
            (dict): Validation statistics containing metrics computed during validation.
        """
        custom = {"rect": not load_vp}  # method defaults
        args = {**self.overrides, **custom, **kwargs, "mode": "val"}  # highest priority args on the right

        validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
        validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
        self.metrics = validator.metrics
        return validator.metrics
    def predict(
        self,
        source=None,
        stream: bool = False,
        visual_prompts: dict[str, list] = {},
        refer_image=None,
        predictor=yolo.yoloe.YOLOEVPDetectPredictor,
        **kwargs,
    ):
        """
        Run prediction on images, videos, directories, streams, etc.

        Args:
            source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
                directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
            stream (bool): Whether to stream the prediction results. If True, results are yielded as a
                generator as they are computed.
            visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
                'bboxes' and 'cls' keys when non-empty.
            refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
            predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
                loaded based on the task.
            **kwargs (Any): Additional keyword arguments passed to the predictor.

        Returns:
            (list | generator): List of Results objects or generator of Results objects if stream=True.

        Examples:
            >>> model = YOLOE("yoloe-11s-seg.pt")
            >>> results = model.predict("path/to/image.jpg")
            >>> # With visual prompts
            >>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
            >>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
        """
        if len(visual_prompts):
            assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
                f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
            )
            assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
                f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
                f"{len(visual_prompts['cls'])} respectively"
            )
            if type(self.predictor) is not predictor:
                self.predictor = predictor(
                    overrides={
                        "task": self.model.task,
                        "mode": "predict",
                        "save": False,
                        "verbose": refer_image is None,
                        "batch": 1,
                        "device": kwargs.get("device", None),
                        "half": kwargs.get("half", False),
                        "imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
                    },
                    _callbacks=self.callbacks,
                )
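
            # Note: for a list source with per-image prompts, use the largest per-image unique-class count so the
            # detection head's `nc` covers every image; otherwise count the unique classes of the single prompt set.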
            num_cls = (
                max(len(set(c)) for c in visual_prompts["cls"])
                if isinstance(source, list) and refer_image is None  # means multiple images
                else len(set(visual_prompts["cls"]))
            )
            self.model.model[-1].nc = num_cls
            self.model.names = [f"object{i}" for i in range(num_cls)]
            self.predictor.set_prompts(visual_prompts.copy())
            self.predictor.setup_model(model=self.model)

            if refer_image is None and source is not None:
                dataset = load_inference_source(source)
                if dataset.mode in {"video", "stream"}:
                    # NOTE: set the first frame as refer image for videos/streams inference
                    refer_image = next(iter(dataset))[1][0]
            if refer_image is not None:
                vpe = self.predictor.get_vpe(refer_image)
                self.model.set_classes(self.model.names, vpe)
                self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
                self.predictor = None  # reset predictor
        elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
            self.predictor = None  # reset predictor if no visual prompts

        return super().predict(source, stream, **kwargs)
7
ultralytics/models/yolo/obb/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator

__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
BIN
ultralytics/models/yolo/obb/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
65
ultralytics/models/yolo/obb/predict.py
Normal file
@@ -0,0 +1,65 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops


class OBBPredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.

    This predictor handles oriented bounding box detection tasks, processing images and returning results with
    rotated bounding boxes.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO OBB model.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.obb import OBBPredictor
        >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
        >>> predictor = OBBPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize OBBPredictor with optional model and data configuration overrides.

        Args:
            cfg (dict, optional): Default configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over the default config.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.

        Examples:
            >>> from ultralytics.utils import ASSETS
            >>> from ultralytics.models.yolo.obb import OBBPredictor
            >>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
            >>> predictor = OBBPredictor(overrides=args)
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "obb"
    def construct_result(self, pred, img, orig_img, img_path):
        """
        Construct the result object from the prediction.

        Args:
            pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
                the last dimension contains [x, y, w, h, confidence, class_id, angle].
            img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
            orig_img (np.ndarray): The original image before preprocessing.
            img_path (str): The path to the original image.

        Returns:
            (Results): The result object containing the original image, image path, class names, and oriented
                bounding boxes.
        """
        rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
        rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
        obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
        return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
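
# Result layout note (a sketch; values hypothetical): each row of Results.obb above is
# (x_center, y_center, width, height, angle, conf, cls), and ops.xywhr2xyxyxyxy() can convert the first
# five values to a 4-point polygon when corner coordinates are needed.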
82
ultralytics/models/yolo/obb/train.py
Normal file
@@ -0,0 +1,82 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from pathlib import Path
from typing import Any

from ultralytics.models import yolo
from ultralytics.nn.tasks import OBBModel
from ultralytics.utils import DEFAULT_CFG, RANK


class OBBTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.

    This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
    detecting objects at arbitrary angles rather than just axis-aligned rectangles.

    Attributes:
        loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
            and dfl_loss.

    Methods:
        get_model: Return OBBModel initialized with specified config and weights.
        get_validator: Return an instance of OBBValidator for validation of YOLO model.

    Examples:
        >>> from ultralytics.models.yolo.obb import OBBTrainer
        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
        >>> trainer = OBBTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
        """
        Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.

        Args:
            cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
                model configuration.
            overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
                will take precedence over those in cfg.
            _callbacks (list[Any], optional): List of callback functions to be invoked during training.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "obb"
        super().__init__(cfg, overrides, _callbacks)
    def get_model(
        self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
    ) -> OBBModel:
        """
        Return OBBModel initialized with specified config and weights.

        Args:
            cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
                containing configuration parameters, or None to use default configuration.
            weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (OBBModel): Initialized OBBModel with the specified configuration and weights.

        Examples:
            >>> trainer = OBBTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-obb.yaml", weights="yolo11n-obb.pt")
        """
        model = OBBModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model
    def get_validator(self):
        """Return an instance of OBBValidator for validation of YOLO model."""
        self.loss_names = "box_loss", "cls_loss", "dfl_loss"
        return yolo.obb.OBBValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
299
ultralytics/models/yolo/obb/val.py
Normal file
@@ -0,0 +1,299 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import torch

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
from ultralytics.utils.nms import TorchNMS


class OBBValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.

    This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
    satellite imagery where objects can appear at various orientations.

    Attributes:
        args (dict): Configuration arguments for the validator.
        metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
        is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.

    Methods:
        init_metrics: Initialize evaluation metrics for YOLO.
        _process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
        _prepare_batch: Prepare batch data for OBB validation.
        _prepare_pred: Prepare predictions with scaled and padded bounding boxes.
        plot_predictions: Plot predicted bounding boxes on input images.
        pred_to_json: Serialize YOLO predictions to COCO json format.
        save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
        eval_json: Evaluate YOLO output in JSON format and return performance statistics.

    Examples:
        >>> from ultralytics.models.yolo.obb import OBBValidator
        >>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
        >>> validator = OBBValidator(args=args)
        >>> validator(model=args["model"])
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.

        This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
        It extends the DetectionValidator class and configures it specifically for the OBB task.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (str | Path, optional): Directory to save results.
            args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
            _callbacks (list, optional): List of callback functions to be called during validation.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.args.task = "obb"
        self.metrics = OBBMetrics()
    def init_metrics(self, model: torch.nn.Module) -> None:
        """
        Initialize evaluation metrics for YOLO OBB validation.

        Args:
            model (torch.nn.Module): Model to validate.
        """
        super().init_metrics(model)
        val = self.data.get(self.args.split, "")  # validation path
        self.is_dota = isinstance(val, str) and "DOTA" in val  # check if dataset is DOTA format
        self.confusion_matrix.task = "obb"  # set confusion matrix task to 'obb'
    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
        """
        Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.

        Args:
            preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
                class labels and bounding boxes.
            batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
                class labels and bounding boxes.

        Returns:
            (dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
                array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
                of predictions compared to the ground truth.

        Examples:
            >>> preds = {"cls": torch.randint(0, 5, (100,)), "bboxes": torch.rand(100, 5)}  # 100 sample detections
            >>> batch = {"cls": torch.randint(0, 5, (50,)), "bboxes": torch.rand(50, 5)}  # 50 ground truth boxes
            >>> correct_matrix = validator._process_batch(preds, batch)
        """
        if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
            return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
        iou = batch_probiou(batch["bboxes"], preds["bboxes"])
        return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
    def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
        """
        Postprocess raw predictions, appending the rotation angle to each bounding box.

        Args:
            preds (torch.Tensor): Raw predictions from the model.

        Returns:
            (list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
        """
        preds = super().postprocess(preds)
        for pred in preds:
            pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1)  # concatenate angle
        return preds
    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
        """
        Prepare batch data for OBB validation with proper scaling and formatting.

        Args:
            si (int): Batch index to process.
            batch (dict[str, Any]): Dictionary containing batch data with keys:
                - batch_idx: Tensor of batch indices
                - cls: Tensor of class labels
                - bboxes: Tensor of bounding boxes
                - ori_shape: Original image shapes
                - img: Batch of images
                - ratio_pad: Ratio and padding information

        Returns:
            (dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
        """
        idx = batch["batch_idx"] == si
        cls = batch["cls"][idx].squeeze(-1)
        bbox = batch["bboxes"][idx]
        ori_shape = batch["ori_shape"][si]
        imgsz = batch["img"].shape[2:]
        ratio_pad = batch["ratio_pad"][si]
        if cls.shape[0]:
            bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]])  # target boxes
        return {
            "cls": cls,
            "bboxes": bbox,
            "ori_shape": ori_shape,
            "imgsz": imgsz,
            "ratio_pad": ratio_pad,
            "im_file": batch["im_file"][si],
        }
    def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
        """
        Plot predicted bounding boxes on input images and save the result.

        Args:
            batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
            preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
            ni (int): Batch index used for naming the output file.

        Examples:
            >>> validator = OBBValidator()
            >>> batch = {"img": images, "im_file": paths}
            >>> preds = [torch.rand(10, 7)]  # Example predictions for one image
            >>> validator.plot_predictions(batch, preds, 0)
        """
        for p in preds:
            # TODO: fix this duplicated `xywh2xyxy`
            p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4])  # convert to xyxy format for plotting
        super().plot_predictions(batch, preds, ni)  # plot bboxes
    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """
        Convert YOLO predictions to COCO JSON format with rotated bounding box information.

        Args:
            predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
                with bounding box coordinates, confidence scores, and class predictions.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.

        Notes:
            This method processes rotated bounding box predictions and converts them to both rbox format
            (x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
            to the JSON dictionary.
        """
        path = Path(pbatch["im_file"])
        stem = path.stem
        image_id = int(stem) if stem.isnumeric() else stem
        rbox = predn["bboxes"]
        poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
        for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "file_name": path.name,
                    "category_id": self.class_map[int(c)],
                    "score": round(s, 5),
                    "rbox": [round(x, 3) for x in r],
                    "poly": [round(x, 3) for x in b],
                }
            )
    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """
        Save YOLO OBB detections to a text file in normalized coordinates.

        Args:
            predn (dict[str, torch.Tensor]): Predicted detections containing 'bboxes' with shape (N, 5) in
                (x, y, w, h, angle) format, plus 'conf' and 'cls' tensors.
            save_conf (bool): Whether to save confidence scores in the text file.
            shape (tuple[int, int]): Original image shape in format (height, width).
            file (Path): Output file path to save detections.

        Examples:
            >>> validator = OBBValidator()
            >>> predn = {"bboxes": torch.tensor([[100.0, 100.0, 50.0, 30.0, 0.5]]), "conf": torch.tensor([0.9]), "cls": torch.tensor([0.0])}
            >>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
        """
        import numpy as np

        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
        ).save_txt(file, save_conf=save_conf)
    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Scale predictions to the original image size."""
        return {
            **predn,
            "bboxes": ops.scale_boxes(
                pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
            ),
        }
    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
        """
        Evaluate YOLO output in JSON format and save predictions in DOTA format.

        Args:
            stats (dict[str, Any]): Performance statistics dictionary.

        Returns:
            (dict[str, Any]): Updated performance statistics.
        """
        if self.args.save_json and self.is_dota and len(self.jdict):
            import json
            import re
            from collections import defaultdict

            pred_json = self.save_dir / "predictions.json"  # predictions
            pred_txt = self.save_dir / "predictions_txt"  # predictions
            pred_txt.mkdir(parents=True, exist_ok=True)
            with open(pred_json) as f:
                data = json.load(f)
            # Save split results
            LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
            for d in data:
                image_id = d["image_id"]
                score = d["score"]
                classname = self.names[d["category_id"] - 1].replace(" ", "-")
                p = d["poly"]

                with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                    f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
            # Save merged results. This may produce a slightly lower mAP than the official merging script
            # because of the probiou calculation.
            pred_merged_txt = self.save_dir / "predictions_merged_txt"  # predictions
            pred_merged_txt.mkdir(parents=True, exist_ok=True)
            merged_results = defaultdict(list)
            LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
            for d in data:
                image_id = d["image_id"].split("__", 1)[0]
                pattern = re.compile(r"\d+___\d+")
                x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
                bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
                bbox[0] += x
                bbox[1] += y
                bbox.extend([score, cls])
                merged_results[image_id].append(bbox)
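            # Note: the per-class center offset below (class index * max_wh) is the classic batched-NMS trick:
            # boxes of different classes are pushed far apart so a single class-agnostic NMS pass never
            # suppresses detections across classes.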
            for image_id, bbox in merged_results.items():
                bbox = torch.tensor(bbox)
                max_wh = torch.max(bbox[:, :2]).item() * 2
                c = bbox[:, 6:7] * max_wh  # classes
                scores = bbox[:, 5]  # scores
                b = bbox[:, :5].clone()
                b[:, :2] += c
                # An IoU threshold of 0.3 gives results close to (sometimes slightly better than) the official
                # merging script.
                i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
                bbox = bbox[i]

                b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
                for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
                    classname = self.names[int(x[-1])].replace(" ", "-")
                    p = [round(i, 3) for i in x[:-2]]  # poly
                    score = round(x[-2], 3)

                    with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
                        f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")

        return stats
7
ultralytics/models/yolo/pose/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator

__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
80
ultralytics/models/yolo/pose/predict.py
Normal file
@@ -0,0 +1,80 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops


class PosePredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on a pose model.

    This class specializes in pose estimation, handling keypoints detection alongside standard object detection
    capabilities inherited from DetectionPredictor.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.

    Methods:
        construct_result: Construct the result object from the prediction, including keypoints.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.pose import PosePredictor
        >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
        >>> predictor = PosePredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize PosePredictor for pose estimation tasks.

        Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
        warnings for Apple MPS.

        Args:
            cfg (Any): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.

        Examples:
            >>> from ultralytics.utils import ASSETS
            >>> from ultralytics.models.yolo.pose import PosePredictor
            >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
            >>> predictor = PosePredictor(overrides=args)
            >>> predictor.predict_cli()
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "pose"
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )
    def construct_result(self, pred, img, orig_img, img_path):
        """
        Construct the result object from the prediction, including keypoints.

        Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
        result object.

        Args:
            pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
                the number of detections, K is the number of keypoints, and D is the keypoint dimension.
            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
            orig_img (np.ndarray): The original unprocessed image as a numpy array.
            img_path (str): The path to the original image file.

        Returns:
            (Results): The result object containing the original image, image path, class names, bounding boxes, and
                keypoints.
        """
        result = super().construct_result(pred, img, orig_img, img_path)
        # Extract keypoints from prediction and reshape according to model's keypoint shape
        pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
        # Scale keypoints coordinates to match the original image dimensions
        pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
        result.update(keypoints=pred_kpts)
        return result
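
# Keypoint layout note (a sketch, assuming the model exposes the default COCO kpt_shape of [17, 3] loaded from
# its checkpoint): the reshape above yields keypoints of shape (N, 17, 3) per image, where each triplet is
# (x, y, visibility/confidence).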
115
ultralytics/models/yolo/pose/train.py
Normal file
@@ -0,0 +1,115 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from pathlib import Path
from typing import Any

from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER


class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training YOLO pose estimation models.

    This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
    of pose keypoints alongside bounding boxes.

    Attributes:
        args (dict): Configuration arguments for training.
        model (PoseModel): The pose estimation model being trained.
        data (dict): Dataset configuration including keypoint shape information.
        loss_names (tuple): Names of the loss components used in training.

    Methods:
        get_model: Retrieve a pose estimation model with specified configuration.
        set_model_attributes: Set keypoints shape attribute on the model.
        get_validator: Create a validator instance for model evaluation.
        plot_training_samples: Visualize training samples with keypoints.
        get_dataset: Retrieve the dataset and ensure it contains the required kpt_shape key.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """
    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a PoseTrainer object for training YOLO pose estimation models.

        Args:
            cfg (dict, optional): Default configuration dictionary containing training parameters.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.

        Notes:
            This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
            A warning is issued when using Apple MPS device due to known bugs with pose models.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)

        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )
    def get_model(
        self,
        cfg: str | Path | dict[str, Any] | None = None,
        weights: str | Path | None = None,
        verbose: bool = True,
    ) -> PoseModel:
        """
        Get pose estimation model with specified configuration and weights.

        Args:
            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
            weights (str | Path, optional): Path to the model weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (PoseModel): Initialized pose estimation model.
        """
        model = PoseModel(
            cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
        )
        if weights:
            model.load(weights)

        return model
    def set_model_attributes(self):
        """Set keypoints shape attribute of PoseModel."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """Return an instance of the PoseValidator class for validation."""
        self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
    def get_dataset(self) -> dict[str, Any]:
        """
        Retrieve the dataset and ensure it contains the required `kpt_shape` key.

        Returns:
            (dict): A dictionary containing the training/validation/test dataset and category names.

        Raises:
            KeyError: If the `kpt_shape` key is not present in the dataset.
        """
        data = super().get_dataset()
        if "kpt_shape" not in data:
            raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
        return data
267
ultralytics/models/yolo/pose/val.py
Normal file
@@ -0,0 +1,267 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import torch

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou


class PoseValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on a pose model.

    This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
    specialized metrics for pose evaluation.

    Attributes:
        sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
        kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
        args (dict): Arguments for the validator including task set to "pose".
        metrics (PoseMetrics): Metrics object for pose evaluation.

    Methods:
        preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
        get_desc: Return description of evaluation metrics in string format.
        init_metrics: Initialize pose estimation metrics for YOLO model.
        _prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
            dimensions.
        _prepare_pred: Prepare and scale keypoints in predictions for pose processing.
        _process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
            detections and ground truth.
        plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
        plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
        save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
        pred_to_json: Convert YOLO predictions to COCO JSON format.
        eval_json: Evaluate object detection model using COCO JSON format.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseValidator
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
        >>> validator = PoseValidator(args=args)
        >>> validator()
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize a PoseValidator object for pose estimation validation.

        This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
        specialized metrics for pose evaluation.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path | str, optional): Directory to save results.
            args (dict, optional): Arguments for the validator including task set to "pose".
            _callbacks (list, optional): List of callback functions to be executed during validation.

        Examples:
            >>> from ultralytics.models.yolo.pose import PoseValidator
            >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
            >>> validator = PoseValidator(args=args)
            >>> validator()

        Notes:
            This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
            for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
            due to a known bug with pose models.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.sigma = None
        self.kpt_shape = None
        self.args.task = "pose"
        self.metrics = PoseMetrics()
        if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
            LOGGER.warning(
                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch["keypoints"] = batch["keypoints"].float()
|
||||
return batch
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return description of evaluation metrics in string format."""
|
||||
return ("%22s" + "%11s" * 10) % (
|
||||
"Class",
|
||||
"Images",
|
||||
"Instances",
|
||||
"Box(P",
|
||||
"R",
|
||||
"mAP50",
|
||||
"mAP50-95)",
|
||||
"Pose(P",
|
||||
"R",
|
||||
"mAP50",
|
||||
"mAP50-95)",
|
||||
)
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO pose validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
super().init_metrics(model)
|
||||
self.kpt_shape = self.data["kpt_shape"]
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
nkpt = self.kpt_shape[0]
|
||||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
||||
|
||||
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
|
||||
field of predictions and reshaping them according to the keypoint shape configuration.
|
||||
The keypoints are reshaped from a flattened format to the proper dimensional structure
|
||||
(typically [N, 17, 3] for COCO pose format).
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
||||
bounding boxes, confidence scores, class predictions, and keypoint data.
|
||||
|
||||
Returns:
|
||||
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
|
||||
- 'bboxes': Bounding box coordinates
|
||||
- 'conf': Confidence scores
|
||||
- 'cls': Class predictions
|
||||
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
|
||||
|
||||
Note:
|
||||
If no keypoints are present in a prediction (empty keypoints), that prediction
|
||||
is skipped and continues to the next one. The keypoints are extracted from the
|
||||
'extra' field which contains additional task-specific data beyond basic detection.
|
||||
"""
|
||||
preds = super().postprocess(preds)
|
||||
for pred in preds:
|
||||
pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape) # remove extra if exists
|
||||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
|
||||
|
||||
Notes:
|
||||
This method extends the parent class's _prepare_batch method by adding keypoint processing.
|
||||
Keypoints are scaled from normalized coordinates to original image dimensions.
|
||||
"""
|
||||
pbatch = super()._prepare_batch(si, batch)
|
||||
kpts = batch["keypoints"][batch["batch_idx"] == si]
|
||||
h, w = pbatch["imgsz"]
|
||||
kpts = kpts.clone()
|
||||
kpts[..., 0] *= w
|
||||
kpts[..., 1] *= h
|
||||
pbatch["keypoints"] = kpts
|
||||
return pbatch
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
||||
and 'keypoints' for keypoint predictions.
|
||||
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
||||
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
||||
true positives across 10 IoU levels.
|
||||
|
||||
Notes:
|
||||
`0.53` scale factor used in area computation is referenced from
|
||||
https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
|
||||
"""
|
||||
tp = super()._process_batch(preds, batch)
|
||||
gt_cls = batch["cls"]
|
||||
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
|
||||
else:
|
||||
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
||||
area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
|
||||
iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
|
||||
tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
|
||||
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
|
||||
return tp
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO pose detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
|
||||
save_conf (bool): Whether to save confidence scores.
|
||||
shape (tuple[int, int]): Shape of the original image (height, width).
|
||||
file (Path): Output file path to save detections.
|
||||
|
||||
Notes:
|
||||
The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
|
||||
normalized (x, y, visibility) values for each point.
|
||||
"""
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
Results(
|
||||
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
||||
path=None,
|
||||
names=self.names,
|
||||
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
||||
keypoints=predn["keypoints"],
|
||||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format.
|
||||
|
||||
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
||||
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
||||
and 'keypoints' tensors.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
|
||||
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
|
||||
before saving to the JSON dictionary.
|
||||
"""
|
||||
super().pred_to_json(predn, pbatch)
|
||||
kpts = predn["kpts"]
|
||||
for i, k in enumerate(kpts.flatten(1, 2).tolist()):
|
||||
self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints
|
||||
|
||||
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
"""Scales predictions to the original image size."""
|
||||
return {
|
||||
**super().scale_preds(predn, pbatch),
|
||||
"kpts": ops.scale_coords(
|
||||
pbatch["imgsz"],
|
||||
predn["keypoints"].clone(),
|
||||
pbatch["ori_shape"],
|
||||
ratio_pad=pbatch["ratio_pad"],
|
||||
),
|
||||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Evaluate object detection model using COCO JSON format."""
|
||||
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
|
||||
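For readers unfamiliar with OKS, the following self-contained sketch (ours, not part of the commit) reproduces the quantity that `kpt_iou` computes for a single ground-truth/prediction pair, including the `0.53` area rescaling referenced above:

# Illustrative only: the OKS similarity behind `kpt_iou`, for one GT/pred pair.
import numpy as np

def oks_single(gt, pred, sigma, box_wh):
    """gt/pred: (K, 3) arrays of (x, y, visibility); sigma: (K,) per-keypoint constants."""
    area = box_wh[0] * box_wh[1] * 0.53  # same scale factor as xtcocotools cocoeval
    d2 = (gt[:, 0] - pred[:, 0]) ** 2 + (gt[:, 1] - pred[:, 1]) ** 2
    e = d2 / (2 * (2 * sigma) ** 2 * (area + 1e-7))  # normalized squared distance
    vis = gt[:, 2] > 0  # only keypoints labeled visible in the GT contribute
    return float(np.exp(-e[vis]).mean()) if vis.any() else 0.0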
7
ultralytics/models/yolo/segment/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator

__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
ultralytics/models/yolo/segment/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
113
ultralytics/models/yolo/segment/predict.py
Normal file
@@ -0,0 +1,113 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops


class SegmentationPredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on a segmentation model.

    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
    prediction results.

    Attributes:
        args (dict): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO segmentation model.
        batch (list): Current batch of images being processed.

    Methods:
        postprocess: Apply non-max suppression and process segmentation detections.
        construct_results: Construct a list of result objects from predictions.
        construct_result: Construct a single result object from a prediction.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
        >>> predictor = SegmentationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

        This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in
        the prediction results.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply non-max suppression and process segmentation detections for each image in the input batch.

        Args:
            preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

        Returns:
            (list): List of Results objects containing the segmentation predictions for each image in the batch.
                Each Results object includes both bounding boxes and segmentation masks.

        Examples:
            >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
            >>> results = predictor.postprocess(preds, img, orig_img)
        """
        # Extract protos - tuple if PyTorch model or array if exported
        protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
        return super().postprocess(preds[0], img, orig_imgs, protos=protos)

    def construct_results(self, preds, img, orig_imgs, protos):
        """
        Construct a list of result objects from the predictions.

        Args:
            preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_imgs (list[np.ndarray]): List of original images before preprocessing.
            protos (list[torch.Tensor]): List of prototype masks.

        Returns:
            (list[Results]): List of result objects containing the original images, image paths, class names,
                bounding boxes, and masks.
        """
        return [
            self.construct_result(pred, img, orig_img, img_path, proto)
            for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
        ]

    def construct_result(self, pred, img, orig_img, img_path, proto):
        """
        Construct a single result object from the prediction.

        Args:
            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_img (np.ndarray): The original image before preprocessing.
            img_path (str): The path to the original image.
            proto (torch.Tensor): The prototype masks.

        Returns:
            (Results): Result object containing the original image, image path, class names, bounding boxes, and
                masks.
        """
        if pred.shape[0] == 0:  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        if masks is not None:
            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
            pred, masks = pred[keep], masks[keep]
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
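A usage sketch (illustrative, not part of the commit): the `retina_masks` flag above trades speed for mask fidelity, selecting masks at native image resolution (`process_mask_native`) instead of masks upsampled from the prototype resolution (`process_mask`):

# Illustrative only - assumes the ultralytics package and its bundled asset images.
from ultralytics.utils import ASSETS
from ultralytics.models.yolo.segment import SegmentationPredictor

predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt", retina_masks=True))
results = predictor(source=ASSETS / "bus.jpg")  # masks computed at full image resolution
print(results[0].masks.data.shape)  # (N, H, W) mask stack, one mask per detection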
72
ultralytics/models/yolo/segment/train.py
Normal file
@@ -0,0 +1,72 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from pathlib import Path

from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK


class SegmentationTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    This trainer specializes in handling segmentation tasks, extending the detection trainer with
    segmentation-specific functionality including model initialization, validation, and visualization.

    Attributes:
        loss_names (tuple[str]): Names of the loss components used during training.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
        >>> trainer = SegmentationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """
        Initialize a SegmentationTrainer object.

        Args:
            cfg (dict): Configuration dictionary with default training settings.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
        """
        Initialize and return a SegmentationModel with specified configuration and weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
            weights (str | Path, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (SegmentationModel): Initialized segmentation model with loaded weights if specified.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
        """
        model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return an instance of SegmentationValidator for validation of YOLO model."""
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
259
ultralytics/models/yolo/segment/val.py
Normal file
@@ -0,0 +1,259 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, mask_iou


class SegmentationValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on a segmentation model.

    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
    to compute metrics such as mAP for both detection and segmentation tasks.

    Attributes:
        plot_masks (list): List to store masks for plotting.
        process (callable): Function to process masks based on save_json and save_txt flags.
        args (namespace): Arguments for the validator.
        metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
        stats (dict): Dictionary to store statistics during validation.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationValidator
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
        >>> validator = SegmentationValidator(args=args)
        >>> validator()
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (Path, optional): Directory to save results.
            args (namespace, optional): Arguments for the validator.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.process = None
        self.args.task = "segment"
        self.metrics = SegmentMetrics()

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """
        Preprocess batch of images for YOLO segmentation validation.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.

        Returns:
            (dict[str, Any]): Preprocessed batch.
        """
        batch = super().preprocess(batch)
        batch["masks"] = batch["masks"].float()
        return batch

    def init_metrics(self, model: torch.nn.Module) -> None:
        """
        Initialize metrics and select mask processing function based on save_json flag.

        Args:
            model (torch.nn.Module): Model to validate.
        """
        super().init_metrics(model)
        if self.args.save_json:
            check_requirements("faster-coco-eval>=1.6.7")
        # More accurate vs faster
        self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask

    def get_desc(self) -> str:
        """Return a formatted description of evaluation metrics."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Mask(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
        """
        Post-process YOLO predictions and return output detections with proto.

        Args:
            preds (list[torch.Tensor]): Raw predictions from the model.

        Returns:
            (list[dict[str, torch.Tensor]]): Processed detection predictions with masks.
        """
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        preds = super().postprocess(preds[0])
        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
        for i, pred in enumerate(preds):
            coefficient = pred.pop("extra")
            pred["masks"] = (
                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
                if coefficient.shape[0]
                else torch.zeros(
                    (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
                    dtype=torch.uint8,
                    device=pred["bboxes"].device,
                )
            )
        return preds

    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
        """
        Prepare a batch for training or inference by processing images and targets.

        Args:
            si (int): Batch index.
            batch (dict[str, Any]): Batch data containing images and annotations.

        Returns:
            (dict[str, Any]): Prepared batch with processed annotations.
        """
        prepared_batch = super()._prepare_batch(si, batch)
        nl = prepared_batch["cls"].shape[0]
        if self.args.overlap_mask:
            masks = batch["masks"][si]
            index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
            masks = (masks == index).float()
        else:
            masks = batch["masks"][batch["batch_idx"] == si]
        if nl:
            mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
            if masks.shape[1:] != mask_size:
                masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
                masks = masks.gt_(0.5)
        prepared_batch["masks"] = masks
        return prepared_batch
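    # Illustrative sketch (not part of the commit) of the `overlap_mask` decoding above:
    # instances packed into one index map are recovered by comparing against per-instance ids.
    # >>> import torch
    # >>> index_map = torch.tensor([[0, 1, 1], [0, 2, 2], [0, 0, 2]])  # 0 = background, 1..N = ids
    # >>> nl = 2
    # >>> index = torch.arange(1, nl + 1).view(nl, 1, 1)  # (N, 1, 1)
    # >>> (index_map == index).float().shape  # one binary mask per instance
    # torch.Size([2, 3, 3])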

    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
        """
        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.

        Args:
            preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
            batch (dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.

        Returns:
            (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask
                IoU.

        Notes:
            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.

        Examples:
            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> correct_preds = validator._process_batch(preds, batch)
        """
        tp = super()._process_batch(preds, batch)
        gt_cls = batch["cls"]
        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
            tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
        else:
            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
        tp.update({"tp_m": tp_m})  # update tp with mask IoU
        return tp

    def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
        """
        Plot batch predictions with masks and bounding boxes.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.
            preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
            ni (int): Batch index.
        """
        for p in preds:
            masks = p["masks"]
            if masks.shape[0] > self.args.max_det:
                LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
            p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
        super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes

    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """
        Save YOLO detections to a txt file in normalized coordinates in a specific format.

        Args:
            predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls'
                and 'masks'.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple[int, int]): Shape of the original image.
            file (Path): File path to save the detections.
        """
        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
        ).save_txt(file, save_conf=save_conf)

    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """
        Save one JSON result for COCO evaluation.

        Args:
            predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
        """
        from faster_coco_eval.core.mask import encode  # noqa

        def single_encode(x):
            """Encode one predicted mask as RLE."""
            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            return rle

        pred_masks = np.transpose(predn["masks"], (2, 0, 1))
        with ThreadPool(NUM_THREADS) as pool:
            rles = pool.map(single_encode, pred_masks)
        super().pred_to_json(predn, pbatch)
        for i, r in enumerate(rles):
            self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation
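    # Illustrative sketch (not part of the commit): COCO RLE stores a binary mask as
    # alternating run lengths over a Fortran-ordered (column-major) flattening, which is
    # why `single_encode` above passes order="F". A hand computation for a 2x2 mask:
    # >>> import numpy as np
    # >>> mask = np.array([[0, 1], [0, 1]], dtype=np.uint8)
    # >>> mask.flatten(order="F")  # columns first
    # array([0, 0, 1, 1], dtype=uint8)
    # >>> # run lengths, starting with the initial run of zeros -> counts [2, 2]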

    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Scale predictions to the original image size."""
        return {
            **super().scale_preds(predn, pbatch),
            "masks": ops.scale_image(
                torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
                pbatch["ori_shape"],
                ratio_pad=pbatch["ratio_pad"],
            ),
        }

    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
        """Return COCO-style instance segmentation evaluation metrics."""
        pred_json = self.save_dir / "predictions.json"  # predictions
        anno_json = (
            self.data["path"]
            / "annotations"
            / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
        )  # annotations
        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "segm"], suffix=["Box", "Mask"])
5
ultralytics/models/yolo/world/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .train import WorldTrainer

__all__ = ["WorldTrainer"]
Binary file not shown.
BIN
ultralytics/models/yolo/world/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
179
ultralytics/models/yolo/world/train.py
Normal file
@@ -0,0 +1,179 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import itertools
from pathlib import Path
from typing import Any

import torch

from ultralytics.data import build_yolo_dataset
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model


def on_pretrain_routine_end(trainer) -> None:
    """Set up model classes and text encoder at the end of the pretrain routine."""
    if RANK in {-1, 0}:
        # Set class names for evaluation
        names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
        unwrap_model(trainer.ema.ema).set_classes(names, cache_clip_model=False)


class WorldTrainer(DetectionTrainer):
    """
    A trainer class for fine-tuning YOLO World models on close-set datasets.

    This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and
    textual features for improved object detection and understanding. It handles text embedding generation and
    caching to accelerate training with multi-modal data.

    Attributes:
        text_embeddings (dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate
            training.
        model (WorldModel): The YOLO World model being trained.
        data (dict[str, Any]): Dataset configuration containing class information.
        args (Any): Training arguments and configuration.

    Methods:
        get_model: Return WorldModel initialized with specified config and weights.
        build_dataset: Build YOLO Dataset for training or validation.
        set_text_embeddings: Set text embeddings for datasets to accelerate training.
        generate_text_embeddings: Generate text embeddings for a list of text samples.
        preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.

    Examples:
        Initialize and train a YOLO World model
        >>> from ultralytics.models.yolo.world import WorldTrainer
        >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
        >>> trainer = WorldTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a WorldTrainer object with given arguments.

        Args:
            cfg (dict[str, Any]): Configuration for the trainer.
            overrides (dict[str, Any], optional): Configuration overrides.
            _callbacks (list[Any], optional): List of callback functions.
        """
        if overrides is None:
            overrides = {}
        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
        super().__init__(cfg, overrides, _callbacks)
        self.text_embeddings = None

    def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
        """
        Return WorldModel initialized with specified config and weights.

        Args:
            cfg (dict[str, Any] | str, optional): Model configuration.
            weights (str, optional): Path to pretrained weights.
            verbose (bool): Whether to display model info.

        Returns:
            (WorldModel): Initialized WorldModel.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

        return model

    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`.

        Returns:
            (Any): YOLO dataset configured for training or validation.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        dataset = build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )
        if mode == "train":
            self.set_text_embeddings([dataset], batch)  # cache text embeddings to accelerate training
        return dataset

    def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
        """
        Set text embeddings for datasets to accelerate training by caching category names.

        This method collects unique category names from all datasets, then generates and caches text embeddings
        for these categories to improve training efficiency.

        Args:
            datasets (list[Any]): List of datasets from which to extract category names.
            batch (int | None): Batch size used for processing.

        Notes:
            This method collects category names from datasets that have the 'category_names' attribute,
            then uses each dataset's image path to determine where to cache the generated text embeddings.
        """
        text_embeddings = {}
        for dataset in datasets:
            if not hasattr(dataset, "category_names"):
                continue
            text_embeddings.update(
                self.generate_text_embeddings(
                    list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent
                )
            )
        self.text_embeddings = text_embeddings

    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
        """
        Generate text embeddings for a list of text samples.

        Args:
            texts (list[str]): List of text samples to encode.
            batch (int): Batch size for processing.
            cache_dir (Path): Directory to save/load cached embeddings.

        Returns:
            (dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
        """
        model = "clip:ViT-B/32"
        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existing cache from '{cache_path}'")
            txt_map = torch.load(cache_path, map_location=self.device)
            if sorted(txt_map.keys()) == sorted(texts):
                return txt_map
        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map
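    # Illustrative sketch (not part of the commit) of the cache naming above: the CLIP
    # model id is sanitized into a filename stored next to the dataset images.
    # >>> model = "clip:ViT-B/32"
    # >>> f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
    # 'text_embeddings_clip_ViT-B_32.pt'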

    def preprocess_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Preprocess a batch of images and text for YOLOWorld training."""
        batch = DetectionTrainer.preprocess_batch(self, batch)

        # Add text features
        texts = list(itertools.chain(*batch["texts"]))
        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
            self.device, non_blocking=self.device.type == "cuda"
        )
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch
201
ultralytics/models/yolo/world/train_world.py
Normal file
@@ -0,0 +1,201 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from pathlib import Path

from ultralytics.data import YOLOConcatDataset, build_grounding, build_yolo_dataset
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.world import WorldTrainer
from ultralytics.utils import DATASETS_DIR, DEFAULT_CFG, LOGGER
from ultralytics.utils.torch_utils import unwrap_model


class WorldTrainerFromScratch(WorldTrainer):
    """
    A class extending the WorldTrainer for training a world model from scratch on open-set datasets.

    This trainer specializes in handling mixed datasets including both object detection and grounding datasets,
    supporting training YOLO-World models with combined vision-language capabilities.

    Attributes:
        cfg (dict): Configuration dictionary with default parameters for model training.
        overrides (dict): Dictionary of parameter overrides to customize the configuration.
        _callbacks (list): List of callback functions to be executed during different stages of training.
        data (dict): Final processed data configuration containing train/val paths and metadata.
        training_data (dict): Dictionary mapping training dataset paths to their configurations.

    Methods:
        build_dataset: Build YOLO Dataset for training or validation with mixed dataset support.
        get_dataset: Get train and validation paths from data dictionary.
        plot_training_labels: Skip label plotting for YOLO-World training.
        final_eval: Perform final evaluation and validation for the YOLO-World model.

    Examples:
        >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
        >>> from ultralytics import YOLOWorld
        >>> data = dict(
        ...     train=dict(
        ...         yolo_data=["Objects365.yaml"],
        ...         grounding_data=[
        ...             dict(
        ...                 img_path="flickr30k/images",
        ...                 json_file="flickr30k/final_flickr_separateGT_train.json",
        ...             ),
        ...             dict(
        ...                 img_path="GQA/images",
        ...                 json_file="GQA/final_mixed_train_no_coco.json",
        ...             ),
        ...         ],
        ...     ),
        ...     val=dict(yolo_data=["lvis.yaml"]),
        ... )
        >>> model = YOLOWorld("yolov8s-worldv2.yaml")
        >>> model.train(data=data, trainer=WorldTrainerFromScratch)
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize a WorldTrainerFromScratch object.

        This initializes a trainer for YOLO-World models from scratch, supporting mixed datasets including both
        object detection and grounding datasets for vision-language capabilities.

        Args:
            cfg (dict): Configuration dictionary with default parameters for model training.
            overrides (dict, optional): Dictionary of parameter overrides to customize the configuration.
            _callbacks (list, optional): List of callback functions to be executed during different stages of
                training.

        Examples:
            >>> from ultralytics.models.yolo.world.train_world import WorldTrainerFromScratch
            >>> from ultralytics import YOLOWorld
            >>> data = dict(
            ...     train=dict(
            ...         yolo_data=["Objects365.yaml"],
            ...         grounding_data=[
            ...             dict(
            ...                 img_path="flickr30k/images",
            ...                 json_file="flickr30k/final_flickr_separateGT_train.json",
            ...             ),
            ...         ],
            ...     ),
            ...     val=dict(yolo_data=["lvis.yaml"]),
            ... )
            >>> model = YOLOWorld("yolov8s-worldv2.yaml")
            >>> model.train(data=data, trainer=WorldTrainerFromScratch)
        """
        if overrides is None:
            overrides = {}
        super().__init__(cfg, overrides, _callbacks)

    def build_dataset(self, img_path, mode="train", batch=None):
        """
        Build YOLO Dataset for training or validation.

        This method constructs appropriate datasets based on the mode and input paths, handling both
        standard YOLO datasets and grounding datasets with different formats.

        Args:
            img_path (list[str] | str): Path to the folder containing images or list of paths.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
            batch (int, optional): Size of batches, used for rectangular training/validation.

        Returns:
            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        if mode != "train":
            return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)
        datasets = [
            build_yolo_dataset(self.args, im_path, batch, self.training_data[im_path], stride=gs, multi_modal=True)
            if isinstance(im_path, str)
            else build_grounding(
                # assign `nc` from validation set to max number of text samples for training consistency
                self.args,
                im_path["img_path"],
                im_path["json_file"],
                batch,
                stride=gs,
                max_samples=self.data["nc"],
            )
            for im_path in img_path
        ]
        self.set_text_embeddings(datasets, batch)  # cache text embeddings to accelerate training
        return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

    def get_dataset(self):
        """
        Get train and validation paths from the data dictionary.

        Processes the data configuration to extract paths for training and validation datasets,
        handling both YOLO detection datasets and grounding datasets.

        Returns:
            (dict): Final processed data configuration containing train/val paths, class names, `nc`, and channel
                count.

        Raises:
            AssertionError: If train or validation datasets are not found, or if validation has multiple datasets.
        """
        final_data = {}
        data_yaml = self.args.data
        assert data_yaml.get("train", False), "train dataset not found"  # object365.yaml
        assert data_yaml.get("val", False), "validation dataset not found"  # lvis.yaml
        data = {k: [check_det_dataset(d) for d in v.get("yolo_data", [])] for k, v in data_yaml.items()}
        assert len(data["val"]) == 1, f"Only support validating on 1 dataset for now, but got {len(data['val'])}."
        val_split = "minival" if "lvis" in data["val"][0]["val"] else "val"
        for d in data["val"]:
            if d.get("minival") is None:  # for lvis dataset
                continue
            d["minival"] = str(d["path"] / d["minival"])
        for s in {"train", "val"}:
            final_data[s] = [d["train" if s == "train" else val_split] for d in data[s]]
            # save grounding data if there's one
            grounding_data = data_yaml[s].get("grounding_data")
            if grounding_data is None:
                continue
            grounding_data = grounding_data if isinstance(grounding_data, list) else [grounding_data]
            for g in grounding_data:
                assert isinstance(g, dict), f"Grounding data should be provided in dict format, but got {type(g)}"
                for k in {"img_path", "json_file"}:
                    path = Path(g[k])
                    if not path.exists() and not path.is_absolute():
                        g[k] = str((DATASETS_DIR / g[k]).resolve())  # path relative to DATASETS_DIR
            final_data[s] += grounding_data
        # assign the first val dataset as currently only one validation set is supported
        data["val"] = data["val"][0]
        final_data["val"] = final_data["val"][0]
        # NOTE: to make training work properly, set `nc` and `names`
        final_data["nc"] = data["val"]["nc"]
        final_data["names"] = data["val"]["names"]
        # NOTE: add path with lvis path
        final_data["path"] = data["val"]["path"]
        final_data["channels"] = data["val"]["channels"]
        self.data = final_data
        if self.args.single_cls:  # consistent with base trainer
            LOGGER.info("Overriding class names with single class.")
            self.data["names"] = {0: "object"}
            self.data["nc"] = 1
        self.training_data = {}
        for d in data["train"]:
            if self.args.single_cls:
                d["names"] = {0: "object"}
                d["nc"] = 1
            self.training_data[d["train"]] = d
        return final_data
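    # Illustrative sketch (not part of the commit) of the relative-path handling above:
    # grounding entries that are neither absolute nor already on disk resolve against
    # DATASETS_DIR (shown here with a hypothetical /data/datasets value).
    # >>> from pathlib import Path
    # >>> g = {"img_path": "flickr30k/images"}
    # >>> path = Path(g["img_path"])
    # >>> if not path.exists() and not path.is_absolute():
    # ...     g["img_path"] = str(Path("/data/datasets") / g["img_path"])
    # >>> g["img_path"]
    # '/data/datasets/flickr30k/images'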

    def plot_training_labels(self):
        """Skip label plotting for YOLO-World training."""
        pass

    def final_eval(self):
        """
        Perform final evaluation and validation for the YOLO-World model.

        Configures the validator with appropriate dataset and split information before running evaluation.

        Returns:
            (dict): Dictionary containing evaluation metrics and results.
        """
        val = self.args.data["val"]["yolo_data"][0]
        self.validator.args.data = val
        self.validator.args.split = "minival" if isinstance(val, str) and "lvis" in val else "val"
        return super().final_eval()
22
ultralytics/models/yolo/yoloe/__init__.py
Normal file
@@ -0,0 +1,22 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
from .val import YOLOEDetectValidator, YOLOESegValidator

__all__ = [
    "YOLOETrainer",
    "YOLOEPETrainer",
    "YOLOESegTrainer",
    "YOLOEDetectValidator",
    "YOLOESegValidator",
    "YOLOEPESegTrainer",
    "YOLOESegTrainerFromScratch",
    "YOLOESegVPTrainer",
    "YOLOEVPTrainer",
    "YOLOEPEFreeTrainer",
    "YOLOEVPDetectPredictor",
    "YOLOEVPSegPredictor",
    "YOLOETrainerFromScratch",
]
Binary file not shown.
Binary file not shown.
BIN
ultralytics/models/yolo/yoloe/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
ultralytics/models/yolo/yoloe/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
169
ultralytics/models/yolo/yoloe/predict.py
Normal file
@@ -0,0 +1,169 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import numpy as np
import torch

from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.models.yolo.detect import DetectionPredictor
from ultralytics.models.yolo.segment import SegmentationPredictor


class YOLOEVPDetectPredictor(DetectionPredictor):
    """
    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.

    This mixin provides common functionality for YOLO models that use visual prompting, including
    model setup, prompt handling, and preprocessing transformations.

    Attributes:
        model (torch.nn.Module): The YOLO model for inference.
        device (torch.device): Device to run the model on (CPU or CUDA).
        prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.

    Methods:
        setup_model: Initialize the YOLO model and set it to evaluation mode.
        set_prompts: Set the visual prompts for the model.
        pre_transform: Preprocess images and prompts before inference.
        inference: Run inference with visual prompts.
        get_vpe: Process source to get visual prompt embeddings.
    """

    def setup_model(self, model, verbose: bool = True):
        """
        Set up the model for prediction.

        Args:
            model (torch.nn.Module): Model to load or use.
            verbose (bool, optional): If True, provides detailed logging.
        """
        super().setup_model(model, verbose=verbose)
        self.done_warmup = True

    def set_prompts(self, prompts):
        """
        Set the visual prompts for the model.

        Args:
            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
                Must include a 'cls' key with class indices.
        """
        self.prompts = prompts
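    # Illustrative sketch (not part of the commit): a visual-prompt dictionary for a
    # single image, shaped to satisfy the assertions in `pre_transform` below.
    # >>> import numpy as np
    # >>> prompts = dict(
    # ...     cls=np.array([0, 1]),  # one class index per prompt box
    # ...     bboxes=np.array([[50.0, 60.0, 200.0, 240.0], [300.0, 80.0, 420.0, 300.0]]),  # x1, y1, x2, y2 pixels
    # ... )
    # >>> predictor.set_prompts(prompts)  # consumed (and popped) during pre_transform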

    def pre_transform(self, im):
        """
        Preprocess images and prompts before inference.

        This method applies letterboxing to the input image and transforms the visual prompts
        (bounding boxes or masks) accordingly.

        Args:
            im (list): List containing a single input image.

        Returns:
            (list): Preprocessed image ready for model inference.

        Raises:
            ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
        """
        img = super().pre_transform(im)
        bboxes = self.prompts.pop("bboxes", None)
        masks = self.prompts.pop("masks", None)
        category = self.prompts["cls"]
        if len(img) == 1:
            visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
            prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
        else:
            # NOTE: only supports bboxes as prompts for now
            assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
            # NOTE: needs list[np.ndarray]
            assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
                f"Expected list[np.ndarray], but got {bboxes}!"
            )
            assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
                f"Expected list[np.ndarray], but got {category}!"
            )
            assert len(im) == len(category) == len(bboxes), (
                f"Expected same length for all inputs, but got {len(im)} vs {len(category)} vs {len(bboxes)}!"
            )
            visuals = [
                self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
                for i in range(len(img))
            ]
            prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)  # (B, N, H, W)
        self.prompts = prompts.half() if self.model.fp16 else prompts.float()
        return img

    def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
        """
        Process a single image by resizing bounding boxes or masks and generating visuals.

        Args:
            dst_shape (tuple): The target shape (height, width) of the image.
            src_shape (tuple): The original shape (height, width) of the image.
            category (str): The category of the image for visual prompts.
            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].
            masks (np.ndarray, optional): A list of masks corresponding to the image.

        Returns:
            (torch.Tensor): The processed visuals for the image.

        Raises:
            ValueError: If neither `bboxes` nor `masks` are provided.
        """
        if bboxes is not None and len(bboxes):
            bboxes = np.array(bboxes, dtype=np.float32)
            if bboxes.ndim == 1:
                bboxes = bboxes[None, :]
            # Calculate scaling factor and adjust bounding boxes
            gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = new / old
            bboxes *= gain
            bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
            bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
        elif masks is not None:
            # Resize and process masks
            resized_masks = super().pre_transform(masks)
            masks = np.stack(resized_masks)  # (N, H, W)
            masks[masks == 114] = 0  # Reset padding values to 0
        else:
            raise ValueError("Please provide valid bboxes or masks")

        # Generate visuals using the visual prompt loader
        return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)
|
||||
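
    # Worked example of the letterbox transform above: a (480, 640) source scaled into a
    # (640, 640) destination gives gain = min(640/480, 640/640) = 1.0 and a vertical
    # padding offset of (640 - 480 * 1.0) / 2 = 80 px, so a box [100, 100, 200, 200]
    # becomes [100, 180, 200, 280].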

    def inference(self, im, *args, **kwargs):
        """
        Run inference with visual prompts.

        Args:
            im (torch.Tensor): Input image tensor.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
            (torch.Tensor): Model prediction results.
        """
        return super().inference(im, vpe=self.prompts, *args, **kwargs)

    def get_vpe(self, source):
        """
        Process the source to get the visual prompt embeddings (VPE).

        Args:
            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
                images, numpy arrays, and torch tensors.

        Returns:
            (torch.Tensor): The visual prompt embeddings (VPE) from the model.
        """
        self.setup_source(source)
        assert len(self.dataset) == 1, "get_vpe only supports one image!"
        for _, im0s, _ in self.dataset:
            im = self.preprocess(im0s)
            return self.model(im, vpe=self.prompts, return_vpe=True)
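
    # Example (illustrative sketch): extracting embeddings for a single prompted image
    # after set_prompts(); the image path is a placeholder.
    # >>> vpe = predictor.get_vpe("path/to/image.jpg")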


class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
    """Predictor for YOLOE visual-prompt segmentation tasks, combining detection and segmentation capabilities."""

    pass
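
# Example (illustrative sketch): running the segmentation variant directly, mirroring
# the detection predictor usage; weight and image names are placeholders.
# >>> predictor = YOLOEVPSegPredictor(overrides=dict(model="yoloe-11s-seg.pt", source="image.jpg"))
# >>> predictor.set_prompts(dict(cls=np.array([0]), bboxes=np.array([[50, 60, 200, 240]])))
# >>> predictor.predict_cli()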
300
ultralytics/models/yolo/yoloe/train.py
Normal file
@@ -0,0 +1,300 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy, deepcopy
from pathlib import Path

import torch

from ultralytics.data import YOLOConcatDataset, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model

from ..world.train_world import WorldTrainerFromScratch
from .val import YOLOEDetectValidator


class YOLOETrainer(DetectionTrainer):
    """
    A trainer class for YOLOE object detection models.

    This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
    including custom model initialization, validation, and dataset building with multi-modal support.

    Attributes:
        loss_names (tuple): Names of loss components used during training.

    Methods:
        get_model: Initialize and return a YOLOEModel with specified configuration.
        get_validator: Return a YOLOEDetectValidator for model validation.
        build_dataset: Build YOLO dataset with multi-modal support for training.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """
        Initialize the YOLOE trainer with the specified configurations.

        Args:
            cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be applied during training.
        """
        if overrides is None:
            overrides = {}
        # Use .get() so a missing 'model' key cannot raise while formatting the assertion message
        assert not overrides.get("compile"), f"Training with 'model={overrides.get('model')}' requires 'compile=False'"
        overrides["overlap_mask"] = False
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Return a YOLOEModel initialized with the specified configuration and weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
                a direct path to a YAML file, or None to use the default configuration.
            weights (str | Path, optional): Path to pretrained weights file to load into the model.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (YOLOEModel): The initialized YOLOE model.

        Notes:
            - The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
            - The nc parameter here represents the maximum number of different text samples in one image,
              rather than the actual number of classes.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOEModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return a YOLOEDetectValidator for YOLOE model validation."""
        self.loss_names = "box", "cls", "dfl"
        return YOLOEDetectValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
        """
        Build a YOLO dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Either 'train' or 'val'; each mode may apply different augmentations.
            batch (int, optional): Size of batches; used for rectangular training.

        Returns:
            (Dataset): YOLO dataset configured for training or validation.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )
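
# Example (illustrative sketch): launching YOLOE training through the trainer directly.
# The model and data YAML names are placeholders for demonstration only.
# >>> trainer = YOLOETrainer(overrides=dict(model="yoloe-11s.yaml", data="coco8.yaml", epochs=1))
# >>> trainer.train()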


class YOLOEPETrainer(DetectionTrainer):
    """
    Fine-tune a YOLOE model using a linear probing approach.

    This trainer freezes most model layers and only trains specific projection layers for efficient
    fine-tuning on new datasets while preserving pretrained features.

    Methods:
        get_model: Initialize YOLOEModel with frozen layers except projection layers.
    """

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Return a YOLOEModel initialized with the specified config and weights.

        Args:
            cfg (dict | str, optional): Model configuration.
            weights (str, optional): Path to pretrained weights.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOEModel): Initialized model with frozen layers except for specific projection layers.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOEModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=self.data["nc"],
            verbose=verbose and RANK == -1,
        )

        del model.model[-1].savpe

        assert weights is not None, "Pretrained weights must be provided for linear probing."
        if weights:
            model.load(weights)

        model.eval()
        names = list(self.data["names"].values())
        # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
        # it yields correct results as long as proper pretrained weights are loaded.
        tpe = model.get_text_pe(names)
        model.set_classes(names, tpe)
        model.model[-1].fuse(model.pe)  # fuse text embeddings into the classification head
        # Re-enable gradients only for the final projection layer of each detection branch
        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
        del model.pe
        model.train()

        return model
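
# Example (illustrative sketch): linear probing from pretrained YOLOE weights.
# File names are placeholders; pretrained weights are required by get_model().
# >>> trainer = YOLOEPETrainer(overrides=dict(model="yoloe-11s.pt", data="coco8.yaml", epochs=1))
# >>> trainer.train()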


class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
    """
    Train YOLOE models from scratch with text embedding support.

    This trainer combines YOLOE training capabilities with world training features, enabling
    training from scratch with text embeddings and grounding datasets.

    Methods:
        build_dataset: Build datasets for training with grounding support.
        generate_text_embeddings: Generate and cache text embeddings for training.
    """

    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
        """
        Build a YOLO dataset for training or validation.

        This method constructs appropriate datasets based on the mode and input paths, handling both
        standard YOLO datasets and grounding datasets with different formats.

        Args:
            img_path (list[str] | str): Path to the folder containing images, or a list of paths.
            mode (str): Either 'train' or 'val'; each mode may apply different augmentations.
            batch (int, optional): Size of batches; used for rectangular training/validation.

        Returns:
            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
        """
        return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)

    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
        """
        Generate text embeddings for a list of text samples.

        Args:
            texts (list[str]): List of text samples to encode.
            batch (int): Batch size for processing.
            cache_dir (Path): Directory to save/load cached embeddings.

        Returns:
            (dict): Dictionary mapping text samples to their embeddings.
        """
        model = "mobileclip:blt"
        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existing cache from '{cache_path}'")
            txt_map = torch.load(cache_path, map_location=self.device)
            if sorted(txt_map.keys()) == sorted(texts):
                return txt_map
        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map
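
    # For the default "mobileclip:blt" text model, the cache above resolves to
    # cache_dir / "text_embeddings_mobileclip_blt.pt". The cache is reused only when
    # its keys exactly match the requested texts; otherwise the embeddings are
    # regenerated and the file is overwritten.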


class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
    """
    Train a prompt-free YOLOE model.

    This trainer combines linear probing capabilities with from-scratch training for prompt-free
    YOLOE models that don't require text prompts during inference.

    Methods:
        get_validator: Return a standard DetectionValidator for validation.
        preprocess_batch: Preprocess batches without text features.
        set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free training).
    """

    def get_validator(self):
        """Return a DetectionValidator for YOLO model validation."""
        self.loss_names = "box", "cls", "dfl"
        return DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def preprocess_batch(self, batch):
        """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
        return DetectionTrainer.preprocess_batch(self, batch)

    def set_text_embeddings(self, datasets, batch: int):
        """
        Set text embeddings for datasets to accelerate training by caching category names.

        This is intentionally a no-op for prompt-free training: the parent implementation collects unique
        category names from all datasets and caches their text embeddings, but a prompt-free model does not
        consume text embeddings at inference, so nothing needs to be prepared here.

        Args:
            datasets (list[Dataset]): List of datasets containing category names to process.
            batch (int): Batch size for processing text embeddings.
        """
        pass


class YOLOEVPTrainer(YOLOETrainerFromScratch):
    """
    Train a YOLOE model with visual prompts.

    This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
    where visual cues are provided alongside images to guide the detection process.

    Methods:
        build_dataset: Build dataset with visual prompt loading transforms.
    """

    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
        """
        Build a YOLO dataset for training or validation with visual prompts.

        Args:
            img_path (list[str] | str): Path to the folder containing images, or a list of paths.
            mode (str): Either 'train' or 'val'; each mode may apply different augmentations.
            batch (int, optional): Size of batches; used for rectangular training/validation.

        Returns:
            (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.
        """
        dataset = super().build_dataset(img_path, mode, batch)
        if isinstance(dataset, YOLOConcatDataset):
            for d in dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            dataset.transforms.append(LoadVisualPrompt())
        return dataset

    def _close_dataloader_mosaic(self):
        """Close mosaic augmentation and add visual prompt loading to the training dataset."""
        super()._close_dataloader_mosaic()
        if isinstance(self.train_loader.dataset, YOLOConcatDataset):
            for d in self.train_loader.dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            self.train_loader.dataset.transforms.append(LoadVisualPrompt())
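
# Example (illustrative sketch): visual-prompt training. From-scratch YOLOE trainers
# expect a multi-source data spec rather than a single YAML; all values below are
# placeholders only.
# >>> data = dict(train=dict(yolo_data=["Objects365.yaml"]), val=dict(yolo_data=["lvis.yaml"]))
# >>> trainer = YOLOEVPTrainer(overrides=dict(model="yoloe-11s-seg.pt", data=data, epochs=1))
# >>> trainer.train()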
127
ultralytics/models/yolo/yoloe/train_seg.py
Normal file
@@ -0,0 +1,127 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy, deepcopy

from ultralytics.models.yolo.segment import SegmentationTrainer
from ultralytics.nn.tasks import YOLOESegModel
from ultralytics.utils import RANK

from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .val import YOLOESegValidator


class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
    """
    Trainer class for YOLOE segmentation models.

    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
    segmentation models, enabling both object detection and instance segmentation capabilities.

    Attributes:
        cfg (dict): Configuration dictionary with training parameters.
        overrides (dict): Dictionary with parameter overrides.
        _callbacks (list): List of callback functions for training events.
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return a YOLOESegModel initialized with the specified config and weights.

        Args:
            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOESegModel): Initialized YOLOE segmentation model.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOESegModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """
        Create and return a validator for YOLOE segmentation model evaluation.

        Returns:
            (YOLOESegValidator): Validator for YOLOE segmentation models.
        """
        self.loss_names = "box", "seg", "cls", "dfl"
        return YOLOESegValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
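
    # Example (illustrative sketch): YOLOE segmentation training; model and data
    # names are placeholders.
    # >>> trainer = YOLOESegTrainer(overrides=dict(model="yoloe-11s-seg.yaml", data="coco8-seg.yaml", epochs=1))
    # >>> trainer.train()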


class YOLOEPESegTrainer(SegmentationTrainer):
    """
    Fine-tune a YOLOESeg model via linear probing.

    This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
    most of the model and only training specific layers for efficient adaptation to new tasks.

    Attributes:
        data (dict): Dataset configuration containing channels, class names, and number of classes.
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return a YOLOESegModel initialized with the specified config and weights for linear probing.

        Args:
            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOESegModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=self.data["nc"],
            verbose=verbose and RANK == -1,
        )

        del model.model[-1].savpe

        assert weights is not None, "Pretrained weights must be provided for linear probing."
        if weights:
            model.load(weights)

        model.eval()
        names = list(self.data["names"].values())
        # NOTE: `get_text_pe` depends on the text model and YOLOEDetect.reprta;
        # it yields correct results as long as proper pretrained weights are loaded.
        tpe = model.get_text_pe(names)
        model.set_classes(names, tpe)
        model.model[-1].fuse(model.pe)  # fuse text embeddings into the classification head
        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
        del model.pe
        model.train()

        return model


class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
    """Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""

    pass


class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
    """Trainer for YOLOE segmentation models with visual prompt (VP) capabilities."""

    pass
211
ultralytics/models/yolo/yoloe/val.py
Normal file
@@ -0,0 +1,211 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import deepcopy
from pathlib import Path
from typing import Any

import torch
from torch.nn import functional as F

from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.nn.modules.head import YOLOEDetect
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import LOGGER, TQDM
from ultralytics.utils.torch_utils import select_device, smart_inference_mode


class YOLOEDetectValidator(DetectionValidator):
    """
    A validator class for YOLOE detection models that handles both text and visual prompt embeddings.

    This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
    It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
    enabling flexible evaluation strategies for prompt-based object detection.

    Attributes:
        device (torch.device): The device on which validation is performed.
        args (namespace): Configuration arguments for validation.
        dataloader (DataLoader): DataLoader for validation data.

    Methods:
        get_visual_pe: Extract visual prompt embeddings from training samples.
        preprocess: Preprocess batch data ensuring visuals are on the same device as images.
        get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
        __call__: Run validation using either text or visual prompt embeddings.

    Examples:
        Validate with text prompts
        >>> validator = YOLOEDetectValidator()
        >>> stats = validator(model=model, load_vp=False)

        Validate with visual prompts
        >>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
    """

    @smart_inference_mode()
    def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
        """
        Extract visual prompt embeddings from training samples.

        This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
        It normalizes the embeddings and handles cases where no samples exist for a class by setting their
        embeddings to zero.

        Args:
            dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
            model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

        Returns:
            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
        """
        assert isinstance(model, YOLOEModel)
        names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
        cls_visual_num = torch.zeros(len(names))

        desc = "Get visual prompt embeddings from samples"

        # Count samples per class
        for batch in dataloader:
            cls = batch["cls"].squeeze(-1).to(torch.int).unique()
            count = torch.bincount(cls, minlength=len(names))
            cls_visual_num += count

        cls_visual_num = cls_visual_num.to(self.device)

        # Extract visual prompt embeddings
        pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
        for batch in pbar:
            batch = self.preprocess(batch)
            preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

            batch_idx = batch["batch_idx"]
            for i in range(preds.shape[0]):
                cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
                pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
                pad_cls[: cls.shape[0]] = cls
                for c in cls:
                    visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

        # Normalize embeddings for classes with samples, set others to zero
        visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
        visual_pe[cls_visual_num == 0] = 0
        return visual_pe.unsqueeze(0)
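
    # Summary of the computation above: each class accumulates its prompt embeddings
    # over the reference data, scaled by the occurrence count collected in the first
    # loop, and the accumulated vector is then L2-normalized to unit length; classes
    # never seen in the reference data keep an all-zero embedding.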

    def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
        """
        Create a dataloader for LVIS training visual prompt samples.

        This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
        It applies the necessary transformations, including LoadVisualPrompt, and configures the dataset
        for validation purposes.

        Args:
            data (dict): Dataset configuration dictionary containing paths and settings.

        Returns:
            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
        """
        dataset = build_yolo_dataset(
            self.args,
            data.get(self.args.split, data.get("val")),
            self.args.batch,
            data,
            mode="val",
            rect=False,
        )
        if isinstance(dataset, YOLOConcatDataset):
            for d in dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            dataset.transforms.append(LoadVisualPrompt())
        return build_dataloader(
            dataset,
            self.args.batch,
            self.args.workers,
            shuffle=False,
            rank=-1,
        )

    @smart_inference_mode()
    def __call__(
        self,
        trainer: Any | None = None,
        model: YOLOEModel | str | None = None,
        refer_data: str | None = None,
        load_vp: bool = False,
    ) -> dict[str, Any]:
        """
        Run validation on the model using either text or visual prompt embeddings.

        This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
        It supports validation during training (using a trainer object) or standalone validation with a provided
        model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.

        Args:
            trainer (object, optional): Trainer object containing the model and device.
            model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
            refer_data (str, optional): Path to reference data for visual prompts.
            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

        Returns:
            (dict): Validation statistics containing metrics computed during validation.
        """
        if trainer is not None:
            self.device = trainer.device
            model = trainer.ema.ema
            names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
                self.args.half = False
                # Directly use the same dataloader for visual embeddings extracted during training
                vpe = self.get_visual_pe(self.dataloader, model)
                model.set_classes(names, vpe)
            else:
                LOGGER.info("Validate using the text prompt.")
                tpe = model.get_text_pe(names)
                model.set_classes(names, tpe)
            stats = super().__call__(trainer, model)
        else:
            if refer_data is not None:
                assert load_vp, "Refer data is only used for visual prompt validation."
            self.device = select_device(self.args.device, verbose=False)

            if isinstance(model, (str, Path)):
                from ultralytics.nn.tasks import load_checkpoint

                model, _ = load_checkpoint(model, device=self.device)  # model, ckpt
            model.eval().to(self.device)
            data = check_det_dataset(refer_data or self.args.data)
            names = [name.split("/", 1)[0] for name in list(data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
                self.args.half = False
                # TODO: check whether the names from the refer data are consistent with the evaluated dataset;
                # either the same dataset or the refer data can be used to extract visual prompt embeddings
                dataloader = self.get_vpe_dataloader(data)
                vpe = self.get_visual_pe(dataloader, model)
                model.set_classes(names, vpe)
                stats = super().__call__(model=deepcopy(model))
            elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
                return super().__call__(trainer, model)
            else:
                LOGGER.info("Validate using the text prompt.")
                tpe = model.get_text_pe(names)
                model.set_classes(names, tpe)
                stats = super().__call__(model=deepcopy(model))
        return stats


class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
    """YOLOE segmentation validator that supports both text and visual prompt embeddings."""

    pass