init commit
This commit is contained in:
214
ultralytics/models/yolo/classify/val.py
Normal file
214
ultralytics/models/yolo/classify/val.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.validator import BaseValidator
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
class ClassificationValidator(BaseValidator):
|
||||
"""
|
||||
A class extending the BaseValidator class for validation based on a classification model.
|
||||
|
||||
This validator handles the validation process for classification models, including metrics calculation,
|
||||
confusion matrix generation, and visualization of results.
|
||||
|
||||
Attributes:
|
||||
targets (list[torch.Tensor]): Ground truth class labels.
|
||||
pred (list[torch.Tensor]): Model predictions.
|
||||
metrics (ClassifyMetrics): Object to calculate and store classification metrics.
|
||||
names (dict): Mapping of class indices to class names.
|
||||
nc (int): Number of classes.
|
||||
confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
|
||||
|
||||
Methods:
|
||||
get_desc: Return a formatted string summarizing classification metrics.
|
||||
init_metrics: Initialize confusion matrix, class names, and tracking containers.
|
||||
preprocess: Preprocess input batch by moving data to device.
|
||||
update_metrics: Update running metrics with model predictions and batch targets.
|
||||
finalize_metrics: Finalize metrics including confusion matrix and processing speed.
|
||||
postprocess: Extract the primary prediction from model output.
|
||||
get_stats: Calculate and return a dictionary of metrics.
|
||||
build_dataset: Create a ClassificationDataset instance for validation.
|
||||
get_dataloader: Build and return a data loader for classification validation.
|
||||
print_results: Print evaluation metrics for the classification model.
|
||||
plot_val_samples: Plot validation image samples with their ground truth labels.
|
||||
plot_predictions: Plot images with their predicted class labels.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.classify import ClassificationValidator
|
||||
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
|
||||
>>> validator = ClassificationValidator(args=args)
|
||||
>>> validator()
|
||||
|
||||
Notes:
|
||||
Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
save_dir (str | Path, optional): Directory to save results.
|
||||
args (dict, optional): Arguments containing model and validation configuration.
|
||||
_callbacks (list, optional): List of callback functions to be called during validation.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.classify import ClassificationValidator
|
||||
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
|
||||
>>> validator = ClassificationValidator(args=args)
|
||||
>>> validator()
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.targets = None
|
||||
self.pred = None
|
||||
self.args.task = "classify"
|
||||
self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return a formatted string summarizing classification metrics."""
|
||||
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.pred = []
|
||||
self.targets = []
|
||||
self.confusion_matrix = ConfusionMatrix(names=model.names)
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
||||
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
return batch
|
||||
|
||||
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update running metrics with model predictions and batch targets.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
|
||||
batch (dict): Batch data containing images and class labels.
|
||||
|
||||
Notes:
|
||||
This method appends the top-N predictions (sorted by confidence in descending order) to the
|
||||
prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
|
||||
"""
|
||||
n5 = min(len(self.names), 5)
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
|
||||
self.targets.append(batch["cls"].type(torch.int32).cpu())
|
||||
|
||||
def finalize_metrics(self) -> None:
|
||||
"""
|
||||
Finalize metrics including confusion matrix and processing speed.
|
||||
|
||||
Notes:
|
||||
This method processes the accumulated predictions and targets to generate the confusion matrix,
|
||||
optionally plots it, and updates the metrics object with speed information.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
>>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
|
||||
>>> validator.targets = [torch.tensor([0])] # Ground truth class
|
||||
>>> validator.finalize_metrics()
|
||||
>>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
|
||||
"""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.save_dir = self.save_dir
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""Extract the primary prediction from model output if it's in a list or tuple format."""
|
||||
return preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||
|
||||
def get_stats(self) -> dict[str, float]:
|
||||
"""Calculate and return a dictionary of metrics by processing targets and predictions."""
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path: str) -> ClassificationDataset:
|
||||
"""Create a ClassificationDataset instance for validation."""
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
|
||||
|
||||
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Build and return a data loader for classification validation.
|
||||
|
||||
Args:
|
||||
dataset_path (str | Path): Path to the dataset directory.
|
||||
batch_size (int): Number of samples per batch.
|
||||
|
||||
Returns:
|
||||
(torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
|
||||
"""
|
||||
dataset = self.build_dataset(dataset_path)
|
||||
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
|
||||
|
||||
def print_results(self) -> None:
|
||||
"""Print evaluation metrics for the classification model."""
|
||||
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||
|
||||
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot validation image samples with their ground truth labels.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
|
||||
ni (int): Batch index used for naming the output file.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
>>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
|
||||
>>> validator.plot_val_samples(batch, 0)
|
||||
"""
|
||||
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
|
||||
plot_images(
|
||||
labels=batch,
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
|
||||
"""
|
||||
Plot images with their predicted class labels and save the visualization.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch data containing images and other information.
|
||||
preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
|
||||
ni (int): Batch index used for naming the output file.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
>>> batch = {"img": torch.rand(16, 3, 224, 224)}
|
||||
>>> preds = torch.rand(16, 10) # 16 images, 10 classes
|
||||
>>> validator.plot_predictions(batch, preds, 0)
|
||||
"""
|
||||
batched_preds = dict(
|
||||
img=batch["img"],
|
||||
batch_idx=torch.arange(batch["img"].shape[0]),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
)
|
||||
plot_images(
|
||||
batched_preds,
|
||||
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
) # pred
|
||||
Reference in New Issue
Block a user