init commit

2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions

ultralytics/models/yolo/classify/__init__.py

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"

ultralytics/models/yolo/classify/predict.py

@@ -0,0 +1,93 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import cv2
import torch
from PIL import Image
from ultralytics.data.augment import classify_transforms
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
class ClassificationPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a classification model.
This predictor handles the specific requirements of classification models, including preprocessing images
and postprocessing predictions to generate classification results.
Attributes:
args (dict): Configuration arguments for the predictor.
Methods:
preprocess: Convert input images to model-compatible format.
postprocess: Process model predictions into Results objects.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
>>> predictor = ClassificationPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
tasks. It ensures the task is set to 'classify' regardless of input configuration.
Args:
cfg (dict): Default configuration dictionary containing prediction settings.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be executed during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "classify"
def setup_source(self, source):
"""Set up source and inference mode and classify transforms."""
super().setup_source(source)
updated = (
self.model.model.transforms.transforms[0].size != max(self.imgsz)
if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
else False
)
self.transforms = (
classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
)
def preprocess(self, img):
"""Convert input images to model-compatible tensor format with appropriate normalization."""
if not isinstance(img, torch.Tensor):
img = torch.stack(
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""
Process predictions to return Results objects with classification probabilities.
Args:
preds (torch.Tensor): Raw predictions from the model.
img (torch.Tensor): Input images after preprocessing.
orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
Returns:
(list[Results]): List of Results objects containing classification results for each image.
"""
if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
return [
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]
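For context, a rough end-to-end sketch of how this predictor is typically driven. It assumes the standard callable BasePredictor interface and the Probs helper exposed on Results (top1, top1conf); treat those attribute names as assumptions rather than something this diff guarantees:

from ultralytics.utils import ASSETS
from ultralytics.models.yolo.classify import ClassificationPredictor

# Construction mirrors the class docstring above
predictor = ClassificationPredictor(overrides=dict(model="yolo11n-cls.pt"))

# Calling the predictor runs preprocess -> inference -> postprocess; each Results object
# carries class scores because postprocess() passes them in as probs=pred.
for result in predictor(source=ASSETS):
    top1 = result.probs.top1  # assumed Probs API: index of the highest-scoring class
    print(result.names[top1], float(result.probs.top1conf))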

ultralytics/models/yolo/classify/train.py

@@ -0,0 +1,223 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from typing import Any
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
"""
A trainer class extending BaseTrainer for training image classification models.
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
and torchvision models with comprehensive dataset handling and validation.
Attributes:
model (ClassificationModel): The classification model to be trained.
data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
loss_names (list[str]): Names of the loss functions used during training.
validator (ClassificationValidator): Validator instance for model evaluation.
Methods:
set_model_attributes: Set the model's class names from the loaded dataset.
get_model: Return a modified PyTorch model configured for training.
setup_model: Load, create or download model for classification.
build_dataset: Create a ClassificationDataset instance.
get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
preprocess_batch: Preprocess a batch of images and classes.
progress_string: Return a formatted string showing training progress.
get_validator: Return an instance of ClassificationValidator.
label_loss_items: Return a loss dict with labelled training loss items.
final_eval: Evaluate trained model and save validation results.
plot_training_samples: Plot training samples with their annotations.
Examples:
Initialize and train a classification model
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a ClassificationTrainer object.
Args:
cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list[Any], optional): List of callback functions to be executed during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "classify"
if overrides.get("imgsz") is None:
overrides["imgsz"] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return a modified PyTorch model configured for training YOLO classification.
Args:
cfg (Any, optional): Model configuration.
weights (Any, optional): Pre-trained model weights.
verbose (bool, optional): Whether to display model information.
Returns:
(ClassificationModel): Configured PyTorch model for classification.
"""
model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, "reset_parameters"):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
"""
Load, create or download model for classification tasks.
Returns:
(Any): Model checkpoint if applicable, otherwise None.
"""
import torchvision # scope for faster 'import ultralytics'
if str(self.model) in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[self.model](
weights="IMAGENET1K_V1" if self.args.pretrained else None
)
ckpt = None
else:
ckpt = super().setup_model()
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
"""
Create a ClassificationDataset instance given an image path and mode.
Args:
img_path (str): Path to the dataset images.
mode (str, optional): Dataset mode ('train', 'val', or 'test').
batch (Any, optional): Batch information (unused in this implementation).
Returns:
(ClassificationDataset): Dataset for the specified mode.
"""
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Return PyTorch DataLoader with transforms to preprocess images.
Args:
dataset_path (str): Path to the dataset.
batch_size (int, optional): Number of images per batch.
rank (int, optional): Process rank for distributed training.
mode (str, optional): 'train', 'val', or 'test' mode.
Returns:
(torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
"""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
# Attach inference transforms
if mode != "train":
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader
def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Preprocess a batch of images and classes."""
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
return batch
def progress_string(self) -> str:
"""Return a formatted string showing training progress."""
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def get_validator(self):
"""Return an instance of ClassificationValidator for validation."""
self.loss_names = ["loss"]
return yolo.classify.ClassificationValidator(
self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
"""
Return a loss dict with labelled training loss items tensor.
Args:
loss_items (torch.Tensor, optional): Loss tensor items.
prefix (str, optional): Prefix to prepend to loss names.
Returns:
keys (list[str]): List of loss keys if loss_items is None.
loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
"""
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f"\nValidating {f}...")
self.validator.args.data = self.args.data
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
"""
Plot training samples with their annotations.
Args:
batch (dict[str, torch.Tensor]): Batch containing images and class labels.
ni (int): Iteration index used for naming the saved plot file.
"""
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
plot_images(
labels=batch,
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
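setup_model() above resolves either a YOLO checkpoint/YAML (via super().setup_model()) or a bare torchvision architecture name, then reshapes the classifier head to the dataset's class count. A short sketch of both entry points, assuming the imagenet10 toy dataset named in the class docstring and that weights and data can be downloaded in your environment:

from ultralytics.models.yolo.classify import ClassificationTrainer

# 1) YOLO classification checkpoint, exactly as in the class docstring
trainer = ClassificationTrainer(overrides=dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3))
trainer.train()

# 2) Bare torchvision architecture name: setup_model() looks it up in torchvision.models
#    and ClassificationModel.reshape_outputs() adapts the head to data["nc"] classes.
trainer = ClassificationTrainer(overrides=dict(model="resnet18", data="imagenet10", epochs=3, pretrained=True))
trainer.train()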

ultralytics/models/yolo/classify/val.py

@@ -0,0 +1,214 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a classification model.
This validator handles the validation process for classification models, including metrics calculation,
confusion matrix generation, and visualization of results.
Attributes:
targets (list[torch.Tensor]): Ground truth class labels.
pred (list[torch.Tensor]): Model predictions.
metrics (ClassifyMetrics): Object to calculate and store classification metrics.
names (dict): Mapping of class indices to class names.
nc (int): Number of classes.
confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
Methods:
get_desc: Return a formatted string summarizing classification metrics.
init_metrics: Initialize confusion matrix, class names, and tracking containers.
preprocess: Preprocess input batch by moving data to device.
update_metrics: Update running metrics with model predictions and batch targets.
finalize_metrics: Finalize metrics including confusion matrix and processing speed.
postprocess: Extract the primary prediction from model output.
get_stats: Calculate and return a dictionary of metrics.
build_dataset: Create a ClassificationDataset instance for validation.
get_dataloader: Build and return a data loader for classification validation.
print_results: Print evaluation metrics for the classification model.
plot_val_samples: Plot validation image samples with their ground truth labels.
plot_predictions: Plot images with their predicted class labels.
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
Notes:
Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (str | Path, optional): Directory to save results.
args (dict, optional): Arguments containing model and validation configuration.
_callbacks (list, optional): List of callback functions to be called during validation.
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = "classify"
self.metrics = ClassifyMetrics()
def get_desc(self) -> str:
"""Return a formatted string summarizing classification metrics."""
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
def init_metrics(self, model: torch.nn.Module) -> None:
"""Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
self.names = model.names
self.nc = len(model.names)
self.pred = []
self.targets = []
self.confusion_matrix = ConfusionMatrix(names=model.names)
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
return batch
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
"""
Update running metrics with model predictions and batch targets.
Args:
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
batch (dict): Batch data containing images and class labels.
Notes:
This method appends the top-N predictions (sorted by confidence in descending order) to the
prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
"""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
self.targets.append(batch["cls"].type(torch.int32).cpu())
def finalize_metrics(self) -> None:
"""
Finalize metrics including confusion matrix and processing speed.
Notes:
This method processes the accumulated predictions and targets to generate the confusion matrix,
optionally plots it, and updates the metrics object with speed information.
Examples:
>>> validator = ClassificationValidator()
>>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
>>> validator.targets = [torch.tensor([0])] # Ground truth class
>>> validator.finalize_metrics()
>>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
"""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.save_dir = self.save_dir
self.metrics.confusion_matrix = self.confusion_matrix
def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
"""Extract the primary prediction from model output if it's in a list or tuple format."""
return preds[0] if isinstance(preds, (list, tuple)) else preds
def get_stats(self) -> dict[str, float]:
"""Calculate and return a dictionary of metrics by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path: str) -> ClassificationDataset:
"""Create a ClassificationDataset instance for validation."""
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Build and return a data loader for classification validation.
Args:
dataset_path (str | Path): Path to the dataset directory.
batch_size (int): Number of samples per batch.
Returns:
(torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
"""
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self) -> None:
"""Print evaluation metrics for the classification model."""
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples with their ground truth labels.
Args:
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = ClassificationValidator()
>>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
>>> validator.plot_val_samples(batch, 0)
"""
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
plot_images(
labels=batch,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
"""
Plot images with their predicted class labels and save the visualization.
Args:
batch (dict[str, Any]): Batch data containing images and other information.
preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = ClassificationValidator()
>>> batch = {"img": torch.rand(16, 3, 224, 224)}
>>> preds = torch.rand(16, 10) # 16 images, 10 classes
>>> validator.plot_predictions(batch, preds, 0)
"""
batched_preds = dict(
img=batch["img"],
batch_idx=torch.arange(batch["img"].shape[0]),
cls=torch.argmax(preds, dim=1),
)
plot_images(
batched_preds,
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
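Finally, a minimal standalone-validation sketch. Construction follows the class docstring; after the run, the top-1/top-5 accuracies printed by print_results() remain available on validator.metrics:

from ultralytics.models.yolo.classify import ClassificationValidator

# The validator builds its own dataloader from args and writes plots/metrics to its save_dir
validator = ClassificationValidator(args=dict(model="yolo11n-cls.pt", data="imagenet10"))
validator()  # runs preprocess -> update_metrics -> finalize_metrics -> get_stats

print(validator.metrics.top1, validator.metrics.top5)  # same values that print_results() logs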