init commit
ultralytics/models/yolo/classify/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator

__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
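Because of these re-exports, all three classes can be imported directly from the package; a minimal sketch (the names come straight from __all__ above):

from ultralytics.models.yolo.classify import (
    ClassificationPredictor,
    ClassificationTrainer,
    ClassificationValidator,
)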
Binary file not shown.
Binary file not shown.
Binary file not shown.
ultralytics/models/yolo/classify/__pycache__/val.cpython-310.pyc (new binary file, not shown)
ultralytics/models/yolo/classify/predict.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import cv2
import torch
from PIL import Image

from ultralytics.data.augment import classify_transforms
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops


class ClassificationPredictor(BasePredictor):
    """
    A class extending the BasePredictor class for prediction based on a classification model.

    This predictor handles the specific requirements of classification models, including preprocessing images
    and postprocessing predictions to generate classification results.

    Attributes:
        args (dict): Configuration arguments for the predictor.

    Methods:
        preprocess: Convert input images to model-compatible format.
        postprocess: Process model predictions into Results objects.

    Notes:
        - Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.classify import ClassificationPredictor
        >>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
        >>> predictor = ClassificationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.

        This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for
        classification tasks. It ensures the task is set to 'classify' regardless of input configuration.

        Args:
            cfg (dict): Default configuration dictionary containing prediction settings.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be executed during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "classify"

    def setup_source(self, source):
        """Set up the source, inference mode, and classification transforms."""
        super().setup_source(source)
        updated = (
            self.model.model.transforms.transforms[0].size != max(self.imgsz)
            if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
            else False
        )
        self.transforms = (
            classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
        )

    def preprocess(self, img):
        """Convert input images to model-compatible tensor format with appropriate normalization."""
        if not isinstance(img, torch.Tensor):
            img = torch.stack(
                [self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
            )
        img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
        return img.half() if self.model.fp16 else img.float()  # Convert uint8 to fp16/32

    def postprocess(self, preds, img, orig_imgs):
        """
        Process predictions to return Results objects with classification probabilities.

        Args:
            preds (torch.Tensor): Raw predictions from the model.
            img (torch.Tensor): Input images after preprocessing.
            orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.

        Returns:
            (list[Results]): List of Results objects containing classification results for each image.
        """
        if not isinstance(orig_imgs, list):  # Input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        preds = preds[0] if isinstance(preds, (list, tuple)) else preds
        return [
            Results(orig_img, path=img_path, names=self.model.names, probs=pred)
            for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
        ]
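For orientation, a minimal standalone sketch of the predictor, assembled from the docstring example above; it assumes the "yolo11n-cls.pt" weights and the bundled ASSETS sample images are available locally:

from ultralytics.models.yolo.classify import ClassificationPredictor
from ultralytics.utils import ASSETS

# Task is forced to "classify" in __init__, so only model and source need to be supplied
args = dict(model="yolo11n-cls.pt", source=ASSETS)
predictor = ClassificationPredictor(overrides=args)
predictor.predict_cli()  # per the methods above: setup_source, then preprocess/postprocess each batch into Results with probs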
ultralytics/models/yolo/classify/train.py (new file, 223 lines)
@@ -0,0 +1,223 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from typing import Any

import torch

from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first


class ClassificationTrainer(BaseTrainer):
    """
    A trainer class extending BaseTrainer for training image classification models.

    This trainer handles the training process for image classification tasks, supporting both YOLO classification
    models and torchvision models with comprehensive dataset handling and validation.

    Attributes:
        model (ClassificationModel): The classification model to be trained.
        data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
        loss_names (list[str]): Names of the loss functions used during training.
        validator (ClassificationValidator): Validator instance for model evaluation.

    Methods:
        set_model_attributes: Set the model's class names from the loaded dataset.
        get_model: Return a modified PyTorch model configured for training.
        setup_model: Load, create or download model for classification.
        build_dataset: Create a ClassificationDataset instance.
        get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
        preprocess_batch: Preprocess a batch of images and classes.
        progress_string: Return a formatted string showing training progress.
        get_validator: Return an instance of ClassificationValidator.
        label_loss_items: Return a loss dict with labelled training loss items.
        final_eval: Evaluate trained model and save validation results.
        plot_training_samples: Plot training samples with their annotations.

    Examples:
        Initialize and train a classification model
        >>> from ultralytics.models.yolo.classify import ClassificationTrainer
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
        >>> trainer = ClassificationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a ClassificationTrainer object.

        Args:
            cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
            overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list[Any], optional): List of callback functions to be executed during training.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "classify"
        if overrides.get("imgsz") is None:
            overrides["imgsz"] = 224
        super().__init__(cfg, overrides, _callbacks)

    def set_model_attributes(self):
        """Set the YOLO model's class names from the loaded dataset."""
        self.model.names = self.data["names"]

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Return a modified PyTorch model configured for training YOLO classification.

        Args:
            cfg (Any, optional): Model configuration.
            weights (Any, optional): Pre-trained model weights.
            verbose (bool, optional): Whether to display model information.

        Returns:
            (ClassificationModel): Configured PyTorch model for classification.
        """
        model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        for m in model.modules():
            if not self.args.pretrained and hasattr(m, "reset_parameters"):
                m.reset_parameters()
            if isinstance(m, torch.nn.Dropout) and self.args.dropout:
                m.p = self.args.dropout  # set dropout
        for p in model.parameters():
            p.requires_grad = True  # for training
        return model

    def setup_model(self):
        """
        Load, create or download model for classification tasks.

        Returns:
            (Any): Model checkpoint if applicable, otherwise None.
        """
        import torchvision  # scope for faster 'import ultralytics'

        if str(self.model) in torchvision.models.__dict__:
            self.model = torchvision.models.__dict__[self.model](
                weights="IMAGENET1K_V1" if self.args.pretrained else None
            )
            ckpt = None
        else:
            ckpt = super().setup_model()
        ClassificationModel.reshape_outputs(self.model, self.data["nc"])
        return ckpt

    def build_dataset(self, img_path: str, mode: str = "train", batch=None):
        """
        Create a ClassificationDataset instance given an image path and mode.

        Args:
            img_path (str): Path to the dataset images.
            mode (str, optional): Dataset mode ('train', 'val', or 'test').
            batch (Any, optional): Batch information (unused in this implementation).

        Returns:
            (ClassificationDataset): Dataset for the specified mode.
        """
        return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)

    def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
        """
        Return PyTorch DataLoader with transforms to preprocess images.

        Args:
            dataset_path (str): Path to the dataset.
            batch_size (int, optional): Number of images per batch.
            rank (int, optional): Process rank for distributed training.
            mode (str, optional): 'train', 'val', or 'test' mode.

        Returns:
            (torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
        """
        with torch_distributed_zero_first(rank):  # init dataset *.cache only once if DDP
            dataset = self.build_dataset(dataset_path, mode)

        loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
        # Attach inference transforms
        if mode != "train":
            if is_parallel(self.model):
                self.model.module.transforms = loader.dataset.torch_transforms
            else:
                self.model.transforms = loader.dataset.torch_transforms
        return loader

    def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Preprocess a batch of images and classes."""
        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
        return batch

    def progress_string(self) -> str:
        """Return a formatted string showing training progress."""
        return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
            "Epoch",
            "GPU_mem",
            *self.loss_names,
            "Instances",
            "Size",
        )

    def get_validator(self):
        """Return an instance of ClassificationValidator for validation."""
        self.loss_names = ["loss"]
        return yolo.classify.ClassificationValidator(
            self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
        """
        Return a loss dict with labelled training loss items tensor.

        Args:
            loss_items (torch.Tensor, optional): Loss tensor items.
            prefix (str, optional): Prefix to prepend to loss names.

        Returns:
            keys (list[str]): List of loss keys if loss_items is None.
            loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
        """
        keys = [f"{prefix}/{x}" for x in self.loss_names]
        if loss_items is None:
            return keys
        loss_items = [round(float(loss_items), 5)]
        return dict(zip(keys, loss_items))

    def final_eval(self):
        """Evaluate trained model and save validation results."""
        for f in self.last, self.best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is self.best:
                    LOGGER.info(f"\nValidating {f}...")
                    self.validator.args.data = self.args.data
                    self.validator.args.plots = self.args.plots
                    self.metrics = self.validator(model=f)
                    self.metrics.pop("fitness", None)
                    self.run_callbacks("on_fit_epoch_end")

    def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
        """
        Plot training samples with their annotations.

        Args:
            batch (dict[str, torch.Tensor]): Batch containing images and class labels.
            ni (int): Number of iterations.
        """
        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
        plot_images(
            labels=batch,
            fname=self.save_dir / f"train_batch{ni}.jpg",
            on_plot=self.on_plot,
        )
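As a usage note, the trainer is driven exactly as in its class docstring example; the sketch below assumes the "yolo11n-cls.pt" weights and the "imagenet10" dataset from that example are available locally:

from ultralytics.models.yolo.classify import ClassificationTrainer

# __init__ forces task="classify" and defaults imgsz to 224 when not overridden
args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
trainer = ClassificationTrainer(overrides=args)
trainer.train()  # builds datasets/dataloaders via the methods above and validates with ClassificationValidator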
ultralytics/models/yolo/classify/val.py (new file, 214 lines)
@@ -0,0 +1,214 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import torch

from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images


class ClassificationValidator(BaseValidator):
    """
    A class extending the BaseValidator class for validation based on a classification model.

    This validator handles the validation process for classification models, including metrics calculation,
    confusion matrix generation, and visualization of results.

    Attributes:
        targets (list[torch.Tensor]): Ground truth class labels.
        pred (list[torch.Tensor]): Model predictions.
        metrics (ClassifyMetrics): Object to calculate and store classification metrics.
        names (dict): Mapping of class indices to class names.
        nc (int): Number of classes.
        confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.

    Methods:
        get_desc: Return a formatted string summarizing classification metrics.
        init_metrics: Initialize confusion matrix, class names, and tracking containers.
        preprocess: Preprocess input batch by moving data to device.
        update_metrics: Update running metrics with model predictions and batch targets.
        finalize_metrics: Finalize metrics including confusion matrix and processing speed.
        postprocess: Extract the primary prediction from model output.
        get_stats: Calculate and return a dictionary of metrics.
        build_dataset: Create a ClassificationDataset instance for validation.
        get_dataloader: Build and return a data loader for classification validation.
        print_results: Print evaluation metrics for the classification model.
        plot_val_samples: Plot validation image samples with their ground truth labels.
        plot_predictions: Plot images with their predicted class labels.

    Examples:
        >>> from ultralytics.models.yolo.classify import ClassificationValidator
        >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
        >>> validator = ClassificationValidator(args=args)
        >>> validator()

    Notes:
        Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize ClassificationValidator with dataloader, save directory, and other parameters.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (str | Path, optional): Directory to save results.
            args (dict, optional): Arguments containing model and validation configuration.
            _callbacks (list, optional): List of callback functions to be called during validation.

        Examples:
            >>> from ultralytics.models.yolo.classify import ClassificationValidator
            >>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
            >>> validator = ClassificationValidator(args=args)
            >>> validator()
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.targets = None
        self.pred = None
        self.args.task = "classify"
        self.metrics = ClassifyMetrics()

    def get_desc(self) -> str:
        """Return a formatted string summarizing classification metrics."""
        return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")

    def init_metrics(self, model: torch.nn.Module) -> None:
        """Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
        self.names = model.names
        self.nc = len(model.names)
        self.pred = []
        self.targets = []
        self.confusion_matrix = ConfusionMatrix(names=model.names)

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Preprocess input batch by moving data to device and converting to appropriate dtype."""
        batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
        batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
        batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
        return batch

    def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
        """
        Update running metrics with model predictions and batch targets.

        Args:
            preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
            batch (dict): Batch data containing images and class labels.

        Notes:
            This method appends the top-N predictions (sorted by confidence in descending order) to the
            prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
        """
        n5 = min(len(self.names), 5)
        self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
        self.targets.append(batch["cls"].type(torch.int32).cpu())

    def finalize_metrics(self) -> None:
        """
        Finalize metrics including confusion matrix and processing speed.

        Notes:
            This method processes the accumulated predictions and targets to generate the confusion matrix,
            optionally plots it, and updates the metrics object with speed information.

        Examples:
            >>> validator = ClassificationValidator()
            >>> validator.pred = [torch.tensor([[0, 1, 2]])]  # Top-3 predictions for one sample
            >>> validator.targets = [torch.tensor([0])]  # Ground truth class
            >>> validator.finalize_metrics()
            >>> print(validator.metrics.confusion_matrix)  # Access the confusion matrix
        """
        self.confusion_matrix.process_cls_preds(self.pred, self.targets)
        if self.args.plots:
            for normalize in True, False:
                self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
        self.metrics.speed = self.speed
        self.metrics.save_dir = self.save_dir
        self.metrics.confusion_matrix = self.confusion_matrix

    def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
        """Extract the primary prediction from model output if it's in a list or tuple format."""
        return preds[0] if isinstance(preds, (list, tuple)) else preds

    def get_stats(self) -> dict[str, float]:
        """Calculate and return a dictionary of metrics by processing targets and predictions."""
        self.metrics.process(self.targets, self.pred)
        return self.metrics.results_dict

    def build_dataset(self, img_path: str) -> ClassificationDataset:
        """Create a ClassificationDataset instance for validation."""
        return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)

    def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
        """
        Build and return a data loader for classification validation.

        Args:
            dataset_path (str | Path): Path to the dataset directory.
            batch_size (int): Number of samples per batch.

        Returns:
            (torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
        """
        dataset = self.build_dataset(dataset_path)
        return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)

    def print_results(self) -> None:
        """Print evaluation metrics for the classification model."""
        pf = "%22s" + "%11.3g" * len(self.metrics.keys)  # print format
        LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))

    def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
        """
        Plot validation image samples with their ground truth labels.

        Args:
            batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
            ni (int): Batch index used for naming the output file.

        Examples:
            >>> validator = ClassificationValidator()
            >>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
            >>> validator.plot_val_samples(batch, 0)
        """
        batch["batch_idx"] = torch.arange(batch["img"].shape[0])  # add batch index for plotting
        plot_images(
            labels=batch,
            fname=self.save_dir / f"val_batch{ni}_labels.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )

    def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
        """
        Plot images with their predicted class labels and save the visualization.

        Args:
            batch (dict[str, Any]): Batch data containing images and other information.
            preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
            ni (int): Batch index used for naming the output file.

        Examples:
            >>> validator = ClassificationValidator()
            >>> batch = {"img": torch.rand(16, 3, 224, 224)}
            >>> preds = torch.rand(16, 10)  # 16 images, 10 classes
            >>> validator.plot_predictions(batch, preds, 0)
        """
        batched_preds = dict(
            img=batch["img"],
            batch_idx=torch.arange(batch["img"].shape[0]),
            cls=torch.argmax(preds, dim=1),
        )
        plot_images(
            batched_preds,
            fname=self.save_dir / f"val_batch{ni}_pred.jpg",
            names=self.names,
            on_plot=self.on_plot,
        )  # pred
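To round out the file, a minimal validation sketch taken from the class docstring; it assumes the same "yolo11n-cls.pt" weights and "imagenet10" dataset used in that example are available locally:

from ultralytics.models.yolo.classify import ClassificationValidator

args = dict(model="yolo11n-cls.pt", data="imagenet10")
validator = ClassificationValidator(args=args)
validator()  # per the methods above: accumulates top-1/top-5 stats in update_metrics, then reports via print_results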