215 lines
9.9 KiB
Python
215 lines
9.9 KiB
Python
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import torch
|
|
|
|
from ultralytics.data import ClassificationDataset, build_dataloader
|
|
from ultralytics.engine.validator import BaseValidator
|
|
from ultralytics.utils import LOGGER
|
|
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
|
from ultralytics.utils.plotting import plot_images
|
|
|
|
|
|
class ClassificationValidator(BaseValidator):
    """
    Validator for image classification models, built on top of BaseValidator.

    During validation the top-ranked class indices and the ground-truth labels of every batch are collected on the
    CPU. Once all batches are seen, they are turned into a confusion matrix and top-1 / top-5 accuracy metrics, and
    optional plots are written to the save directory.

    Attributes:
        targets (list[torch.Tensor] | None): Ground-truth class labels, one tensor per batch (None until
            `init_metrics` runs).
        pred (list[torch.Tensor] | None): Top-N predicted class indices, one tensor per batch (None until
            `init_metrics` runs).
        metrics (ClassifyMetrics): Object that computes and stores the classification metrics.
        names (dict): Mapping of class indices to class names, taken from the model.
        nc (int): Number of classes.
        confusion_matrix (ConfusionMatrix): Confusion matrix filled in `finalize_metrics`.

    Methods:
        get_desc: Return the header row used when printing metrics.
        init_metrics: Reset class names, the confusion matrix and the prediction/target buffers.
        preprocess: Move a batch to the validation device and cast images to the working dtype.
        update_metrics: Record top-N predictions and targets for one batch.
        finalize_metrics: Build the confusion matrix, plot it if requested and attach results to the metrics object.
        postprocess: Unwrap the primary prediction from list/tuple model outputs.
        get_stats: Compute and return the metrics dictionary.
        build_dataset: Create the ClassificationDataset used for validation.
        get_dataloader: Build the validation dataloader.
        print_results: Log the overall top-1 / top-5 accuracy.
        plot_val_samples: Save a plot of validation images with ground-truth labels.
        plot_predictions: Save a plot of validation images with predicted labels.

    Examples:
        >>> from ultralytics.models.yolo.classify import ClassificationValidator
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
        >>> validator = ClassificationValidator(args=args)
        >>> validator()

    Notes:
        Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Set up the validator and force the task to "classify".

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (str | Path, optional): Directory to save results.
            args (dict, optional): Arguments containing model and validation configuration.
            _callbacks (list, optional): Callback functions to be called during validation.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        # Buffers stay None until init_metrics() replaces them with fresh lists.
        self.targets = None
        self.pred = None
        self.args.task = "classify"
        self.metrics = ClassifyMetrics()

    def get_desc(self) -> str:
        """Return the column header row (classes, top-1 accuracy, top-5 accuracy) for the results table."""
        header_format = "%22s" + "%11s" * 2
        return header_format % ("classes", "top1_acc", "top5_acc")

    def init_metrics(self, model: torch.nn.Module) -> None:
        """Read class names from the model and reset the confusion matrix and prediction/target buffers."""
        class_names = model.names
        self.names = class_names
        self.nc = len(class_names)
        self.pred = []
        self.targets = []
        self.confusion_matrix = ConfusionMatrix(names=class_names)

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Move images and labels to the validation device and cast images to half or float precision."""
        # Non-blocking transfers are only requested when the target device is CUDA.
        use_non_blocking = self.device.type == "cuda"
        images = batch["img"].to(self.device, non_blocking=use_non_blocking)
        batch["img"] = images.half() if self.args.half else images.float()
        batch["cls"] = batch["cls"].to(self.device, non_blocking=use_non_blocking)
        return batch

    def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
        """
        Record the top-N predicted class indices and the targets of one batch.

        Args:
            preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
            batch (dict): Batch data containing images and class labels.

        Notes:
            Predictions are ranked along dimension 1 in descending order and the first N indices are kept, where N
            is the smaller of 5 and the number of classes. Both predictions and targets are stored as int32 CPU
            tensors.
        """
        top_n = min(len(self.names), 5)
        ranked_classes = preds.argsort(1, descending=True)
        self.pred.append(ranked_classes[:, :top_n].type(torch.int32).cpu())
        self.targets.append(batch["cls"].type(torch.int32).cpu())

    def finalize_metrics(self) -> None:
        """
        Build the confusion matrix from the collected results and attach summary data to the metrics object.

        Notes:
            When `args.plots` is set, the confusion matrix is plotted twice (normalized and raw). The validator's
            speed, save directory and confusion matrix are then copied onto `self.metrics`.
        """
        matrix = self.confusion_matrix
        matrix.process_cls_preds(self.pred, self.targets)
        if self.args.plots:
            for normalize in (True, False):
                matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
        metrics = self.metrics
        metrics.speed = self.speed
        metrics.save_dir = self.save_dir
        metrics.confusion_matrix = matrix

    def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
        """Return the first element when the model output is a list or tuple, otherwise return it unchanged."""
        if isinstance(preds, (list, tuple)):
            return preds[0]
        return preds

    def get_stats(self) -> dict[str, float]:
        """Process the collected targets and predictions and return the resulting metrics dictionary."""
        metrics = self.metrics
        metrics.process(self.targets, self.pred)
        return metrics.results_dict

    def build_dataset(self, img_path: str) -> ClassificationDataset:
        """Create a non-augmented ClassificationDataset rooted at `img_path`, prefixed with the current split."""
        return ClassificationDataset(
            root=img_path,
            args=self.args,
            augment=False,
            prefix=self.args.split,
        )

    def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
        """
        Build the dataloader used for classification validation.

        Args:
            dataset_path (str | Path): Path to the dataset directory.
            batch_size (int): Number of samples per batch.

        Returns:
            (torch.utils.data.DataLoader): Dataloader over the classification validation dataset.
        """
        return build_dataloader(self.build_dataset(dataset_path), batch_size, self.args.workers, rank=-1)

    def print_results(self) -> None:
        """Log the overall top-1 and top-5 accuracy in a single formatted row."""
        # One float field per metric key; the tuple below supplies exactly two values (top1, top5).
        row_format = "%22s" + "%11.3g" * len(self.metrics.keys)
        LOGGER.info(row_format % ("all", self.metrics.top1, self.metrics.top5))

    def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
        """
        Save a plot of validation images annotated with their ground-truth labels.

        Args:
            batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
                A 'batch_idx' entry is added to it in place.
            ni (int): Batch index used for naming the output file.
        """
        num_images = batch["img"].shape[0]
        batch["batch_idx"] = torch.arange(num_images)  # one index per image, needed by the plotting helper
        output_file = self.save_dir / f"val_batch{ni}_labels.jpg"
        plot_images(labels=batch, fname=output_file, names=self.names, on_plot=self.on_plot)

    def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
        """
        Save a plot of validation images annotated with their predicted class.

        Args:
            batch (dict[str, Any]): Batch data containing images and other information.
            preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
            ni (int): Batch index used for naming the output file.
        """
        images = batch["img"]
        # The predicted label of each image is the class with the highest score.
        batched_preds = {
            "img": images,
            "batch_idx": torch.arange(images.shape[0]),
            "cls": torch.argmax(preds, dim=1),
        }
        output_file = self.save_dir / f"val_batch{ni}_pred.jpg"
        plot_images(batched_preds, fname=output_file, names=self.names, on_plot=self.on_plot)
|