init commit

This commit is contained in:
2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because it is too large Load Diff

1164
ultralytics/engine/model.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,517 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
Usage - sources:
$ yolo mode=predict model=yolo11n.pt source=0 # webcam
img.jpg # image
vid.mp4 # video
screen # screenshot
path/ # directory
list.txt # list of images
list.streams # list of streams
'path/*.jpg' # glob
'https://youtu.be/LNwODJXcvt4' # YouTube
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
Usage - formats:
$ yolo mode=predict model=yolo11n.pt # PyTorch
yolo11n.torchscript # TorchScript
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolo11n_openvino_model # OpenVINO
yolo11n.engine # TensorRT
yolo11n.mlpackage # CoreML (macOS-only)
yolo11n_saved_model # TensorFlow SavedModel
yolo11n.pb # TensorFlow GraphDef
yolo11n.tflite # TensorFlow Lite
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
yolo11n_paddle_model # PaddlePaddle
yolo11n.mnn # MNN
yolo11n_ncnn_model # NCNN
yolo11n_imx_model # Sony IMX
yolo11n_rknn_model # Rockchip RKNN
"""
from __future__ import annotations
import platform
import re
import threading
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data import load_inference_source
from ultralytics.data.augment import LetterBox
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
from ultralytics.utils.checks import check_imgsz, check_imshow
from ultralytics.utils.files import increment_path
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode
# Warning text logged by BasePredictor.setup_source() when a long-running source (stream, screenshot, video or
# more than 1000 images) is predicted without `stream=True`.
STREAM_WARNING = """
inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
Example:
results = model(source=..., stream=True) # generator of Results objects
for r in results:
boxes = r.boxes # Boxes object for bbox outputs
masks = r.masks # Masks object for segment masks outputs
probs = r.probs # Class probabilities for classification outputs
"""
class BasePredictor:
"""
A base class for creating predictors.
This class provides the foundation for prediction functionality, handling model setup, inference,
and result processing across various input sources.
Attributes:
args (SimpleNamespace): Configuration for the predictor.
save_dir (Path): Directory to save results.
done_warmup (bool): Whether the predictor has finished setup.
model (torch.nn.Module): Model used for prediction.
data (dict): Data configuration.
device (torch.device): Device used for prediction.
dataset (Dataset): Dataset used for prediction.
vid_writer (dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
plotted_img (np.ndarray): Last plotted image.
source_type (SimpleNamespace): Type of input source.
seen (int): Number of images processed.
windows (list[str]): List of window names for visualization.
batch (tuple): Current batch data.
results (list[Any]): Current batch results.
transforms (callable): Image transforms for classification.
callbacks (dict[str, list[callable]]): Callback functions for different events.
txt_path (Path): Path to save text results.
_lock (threading.Lock): Lock for thread-safe inference.
Methods:
preprocess: Prepare input image before inference.
inference: Run inference on a given image.
postprocess: Process raw predictions into structured results.
predict_cli: Run prediction for command line interface.
setup_source: Set up input source and inference mode.
stream_inference: Stream inference on input source.
setup_model: Initialize and configure the model.
write_results: Write inference results to files.
save_predicted_images: Save prediction visualizations.
show: Display results in a window.
run_callbacks: Execute registered callbacks for an event.
add_callback: Register a new callback function.
"""
def __init__(
    self,
    cfg=DEFAULT_CFG,
    overrides: dict[str, Any] | None = None,
    _callbacks: dict[str, list[callable]] | None = None,
):
    """
    Build a predictor from a configuration and leave all runtime state unset.

    Args:
        cfg (str | dict): Path to a configuration file or a configuration dictionary.
        overrides (dict, optional): Configuration overrides.
        _callbacks (dict, optional): Dictionary of callback functions. When falsy, the default callbacks are used.
    """
    args = get_cfg(cfg, overrides)
    self.args = args
    self.save_dir = get_save_dir(args)
    if args.conf is None:
        args.conf = 0.25  # fall back to the default confidence threshold
    self.done_warmup = False
    if args.show:
        args.show = check_imshow(warn=True)

    # State that only becomes usable once setup_model() / setup_source() have run
    self.model = None
    self.imgsz = None
    self.device = None
    self.dataset = None
    self.source_type = None
    self.data = args.data  # data_dict

    # Per-run bookkeeping
    self.seen = 0
    self.batch = None
    self.results = None
    self.plotted_img = None
    self.txt_path = None
    self.transforms = None
    self.windows = []
    self.vid_writer = {}  # {save_path: video_writer, ...}

    # Callbacks and the lock that serialises stream_inference() across threads
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self._lock = threading.Lock()
    callbacks.add_integration_callbacks(self)
def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
    """
    Convert raw input into a model-ready tensor on the predictor's device.

    Args:
        im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.

    Returns:
        (torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
    """
    from_arrays = not isinstance(im, torch.Tensor)
    if from_arrays:
        stacked = np.stack(self.pre_transform(im))
        if stacked.shape[-1] == 3:
            stacked = stacked[..., ::-1]  # BGR to RGB
        # BHWC to BCHW, then force a contiguous layout before handing the buffer to torch
        stacked = np.ascontiguousarray(stacked.transpose((0, 3, 1, 2)))
        im = torch.from_numpy(stacked)

    im = im.to(self.device)
    if self.model.fp16:
        im = im.half()
    else:
        im = im.float()

    # Only inputs that arrived as arrays are rescaled from 0-255 to 0.0-1.0; tensors are passed through as given
    if from_arrays:
        im /= 255
    return im
def inference(self, im: torch.Tensor, *args, **kwargs):
    """
    Forward an image tensor through the model with the configured augment/visualize/embed options.

    Args:
        im (torch.Tensor): Preprocessed image tensor.
        *args (Any): Extra positional arguments forwarded to the model.
        **kwargs (Any): Extra keyword arguments forwarded to the model.

    Returns:
        Whatever the model call returns.
    """
    visualize = False
    if self.args.visualize and not self.source_type.tensor:
        # Output directory is derived from the stem of the first path in the current batch
        first_stem = Path(self.batch[0][0]).stem
        visualize = increment_path(self.save_dir / first_stem, mkdir=True)
    return self.model(
        im,
        *args,
        augment=self.args.augment,
        visualize=visualize,
        embed=self.args.embed,
        **kwargs,
    )
def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
    """
    Letterbox every input image to the predictor's image size.

    Args:
        im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].

    Returns:
        (list[np.ndarray]): List of transformed images.
    """
    unique_shapes = {x.shape for x in im}
    same_shapes = len(unique_shapes) == 1
    # `auto` is only enabled when all images share one shape, rect is requested, and the backend is either a
    # PyTorch model or a dynamic non-IMX model
    auto = (
        same_shapes
        and self.args.rect
        and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx))
    )
    letterbox = LetterBox(self.imgsz, auto=auto, stride=self.model.stride)
    transformed = []
    for x in im:
        transformed.append(letterbox(image=x))
    return transformed
def postprocess(self, preds, img, orig_imgs):
    """
    Hook for turning raw predictions into results; this base version hands `preds` back untouched.

    Args:
        preds: Raw model predictions.
        img: Preprocessed input (unused here).
        orig_imgs: Original input images (unused here).

    Returns:
        The `preds` argument, unchanged.
    """
    return preds
def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
"""
Perform inference on an image or stream.
Args:
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
Source for inference.
model (str | Path | torch.nn.Module, optional): Model for inference.
stream (bool): Whether to stream the inference results. If True, returns a generator.
*args (Any): Additional arguments for the inference method.
**kwargs (Any): Additional keyword arguments for the inference method.
Returns:
(list[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
"""
self.stream = stream
if stream:
return self.stream_inference(source, model, *args, **kwargs)
else:
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
def predict_cli(self, source=None, model=None):
    """
    Method used for Command Line Interface (CLI) prediction.

    This function is designed to run predictions using the CLI. It sets up the source and model, then processes
    the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
    generator without storing results.

    Args:
        source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
            Source for inference.
        model (str | Path | torch.nn.Module, optional): Model for inference.

    Returns:
        None. Results are discarded as soon as they are produced.

    Note:
        Do not modify this function or remove the generator. The generator ensures that no outputs are
        accumulated in memory, which is critical for preventing memory issues during long-running predictions.
    """
    gen = self.stream_inference(source, model)
    # Drain the generator one item at a time; nothing is kept, so memory use stays flat
    for _ in gen:  # sourcery skip: remove-empty-nested-block, noqa
        pass
def setup_source(self, source):
    """
    Resolve the image size, build the inference dataset and reset the video writers.

    Args:
        source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor):
            Source for inference.
    """
    self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2)  # check image size
    dataset = load_inference_source(
        source=source,
        batch=self.args.batch,
        vid_stride=self.args.vid_stride,
        buffer=self.args.stream_buffer,
        channels=getattr(self.model, "ch", 3),
    )
    self.dataset = dataset
    source_type = dataset.source_type
    self.source_type = source_type

    # Streams, screenshots, videos and very large image sets count as long sequences
    long_sequence = (
        source_type.stream
        or source_type.screenshot
        or len(dataset) > 1000  # many images
        or any(getattr(dataset, "video_flag", [False]))
    )
    if long_sequence:
        import torchvision  # noqa (import here triggers torchvision NMS use in nms.py)

        streaming = getattr(self, "stream", True)  # treated as streaming when `stream` was never set
        if not streaming:
            LOGGER.warning(STREAM_WARNING)
    self.vid_writer = {}
@smart_inference_mode()
def stream_inference(self, source=None, model=None, *args, **kwargs):
    """
    Stream real-time inference on camera feed and save results to file.

    This is a generator: the model is set up on first use, the source is set up on every call, and each batch
    is preprocessed, inferred and postprocessed before its results are yielded one by one. After the dataset
    is exhausted (or display is interrupted with 'q'), video writers and windows are released and a summary is
    logged.

    Args:
        source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
            Source for inference. Falls back to `self.args.source` when None.
        model (str | Path | torch.nn.Module, optional): Model for inference. Only used if no model is set yet.
        *args (Any): Additional arguments for the inference method.
        **kwargs (Any): Additional keyword arguments for the inference method.

    Yields:
        (ultralytics.engine.results.Results): Results objects, or the raw embedding tensors when
            `self.args.embed` is set.
    """
    if self.args.verbose:
        LOGGER.info("")

    # Setup model
    if not self.model:
        self.setup_model(model)

    with self._lock:  # for thread-safe inference
        # Setup source every time predict is called
        self.setup_source(source if source is not None else self.args.source)

        # Check if save_dir/ label file exists
        if self.args.save or self.args.save_txt:
            (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)

        # Warmup model (only once per predictor instance)
        if not self.done_warmup:
            self.model.warmup(
                imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)
            )
            self.done_warmup = True

        self.seen, self.windows, self.batch = 0, [], None
        # One profiler per stage: [0] preprocess, [1] inference, [2] postprocess
        profilers = (
            ops.Profile(device=self.device),
            ops.Profile(device=self.device),
            ops.Profile(device=self.device),
        )
        self.run_callbacks("on_predict_start")
        for self.batch in self.dataset:
            self.run_callbacks("on_predict_batch_start")
            paths, im0s, s = self.batch

            # Preprocess
            with profilers[0]:
                im = self.preprocess(im0s)

            # Inference
            with profilers[1]:
                preds = self.inference(im, *args, **kwargs)
                if self.args.embed:
                    # Embedding mode skips postprocessing, callbacks and `seen` accounting for this batch
                    yield from [preds] if isinstance(preds, torch.Tensor) else preds  # yield embedding tensors
                    continue

            # Postprocess
            with profilers[2]:
                self.results = self.postprocess(preds, im, im0s)
            self.run_callbacks("on_predict_postprocess_end")

            # Visualize, save, write results
            n = len(im0s)
            try:
                for i in range(n):
                    self.seen += 1
                    # Per-image speed in milliseconds: batch stage time divided evenly across the batch
                    self.results[i].speed = {
                        "preprocess": profilers[0].dt * 1e3 / n,
                        "inference": profilers[1].dt * 1e3 / n,
                        "postprocess": profilers[2].dt * 1e3 / n,
                    }
                    if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
                        s[i] += self.write_results(i, Path(paths[i]), im, s)
            except StopIteration:
                # Raised by show() when the user presses 'q'; stop consuming the dataset
                break

            # Print batch results
            if self.args.verbose:
                LOGGER.info("\n".join(s))

            self.run_callbacks("on_predict_batch_end")
            yield from self.results

    # Release assets
    for v in self.vid_writer.values():
        if isinstance(v, cv2.VideoWriter):
            v.release()

    if self.args.show:
        cv2.destroyAllWindows()  # close any open windows

    # Print final results (`im` is guaranteed to exist here because `self.seen` is non-zero)
    if self.args.verbose and self.seen:
        t = tuple(x.t / self.seen * 1e3 for x in profilers)  # speeds per image
        LOGGER.info(
            f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
            f"{(min(self.args.batch, self.seen), getattr(self.model, 'ch', 3), *im.shape[2:])}" % t
        )
    if self.args.save or self.args.save_txt or self.args.save_crop:
        nl = len(list(self.save_dir.glob("labels/*.txt")))  # number of labels
        s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
    self.run_callbacks("on_predict_end")
def setup_model(self, model, verbose: bool = True):
    """
    Load the model through AutoBackend, sync device/precision settings and switch it to evaluation mode.

    Args:
        model (str | Path | torch.nn.Module, optional): Model to load or use. Falls back to `self.args.model`.
        verbose (bool): Whether to print verbose output.
    """
    device = select_device(self.args.device, verbose=verbose)
    self.model = backend = AutoBackend(
        model=model or self.args.model,
        device=device,
        dnn=self.args.dnn,
        data=self.args.data,
        fp16=self.args.half,
        fuse=True,
        verbose=verbose,
    )

    # Mirror what the backend actually ended up with
    self.device = backend.device
    self.args.half = backend.fp16

    # Non-dynamic backends that expose an image size dictate the inference size (from export metadata)
    is_dynamic = getattr(backend, "dynamic", False)
    if hasattr(backend, "imgsz") and not is_dynamic:
        self.args.imgsz = backend.imgsz

    backend.eval()
    self.model = attempt_compile(backend, device=self.device, mode=self.args.compile)
def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
    """
    Build the log line for one image and trigger the plotting/saving/showing that the arguments request.

    Args:
        i (int): Index of the current image in the batch.
        p (Path): Path to the current image.
        im (torch.Tensor): Preprocessed image tensor.
        s (list[str]): List of result strings.

    Returns:
        (str): String with result information.
    """
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim

    pieces = []  # fragments of the returned log string
    src = self.source_type
    if src.stream or src.from_img or src.tensor:  # batch_size >= 1
        pieces.append(f"{i}: ")
        frame = self.dataset.count
    else:
        found = re.search(r"frame (\d+)/", s[i])
        frame = int(found[1]) if found else None  # None when no frame number can be parsed

    # Label files for non-image sources carry the frame number as a suffix
    stem_suffix = "" if self.dataset.mode == "image" else f"_{frame}"
    self.txt_path = self.save_dir / "labels" / (p.stem + stem_suffix)
    pieces.append("{:g}x{:g} ".format(*im.shape[2:]))

    result = self.results[i]
    result.save_dir = str(self.save_dir)  # used in other locations
    pieces.append(f"{result.verbose()}{result.speed['inference']:.1f}ms")

    opts = self.args
    # Add predictions to image
    if opts.save or opts.show:
        self.plotted_img = result.plot(
            line_width=opts.line_width,
            boxes=opts.show_boxes,
            conf=opts.show_conf,
            labels=opts.show_labels,
            im_gpu=None if opts.retina_masks else im[i],
        )

    # Save results
    if opts.save_txt:
        result.save_txt(f"{self.txt_path}.txt", save_conf=opts.save_conf)
    if opts.save_crop:
        result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
    if opts.show:
        self.show(str(p))
    if opts.save:
        self.save_predicted_images(self.save_dir / p.name, frame)

    return "".join(pieces)
def save_predicted_images(self, save_path: Path, frame: int = 0):
    """
    Write the last plotted image either as a JPG or as the next frame of a video file.

    Args:
        save_path (Path): Path to save the results.
        frame (int): Frame number for video mode.
    """
    im = self.plotted_img
    mode = self.dataset.mode

    # Save images
    if mode not in {"stream", "video"}:
        cv2.imwrite(str(save_path.with_suffix(".jpg")), im)  # save to JPG for best support
        return

    # Save videos and streams
    fps = self.dataset.fps if mode == "video" else 30
    frames_path = self.save_dir / f"{save_path.stem}_frames"  # save frames to a separate directory
    if save_path not in self.vid_writer:  # new video
        if self.args.save_frames:
            Path(frames_path).mkdir(parents=True, exist_ok=True)
        # Container and codec are chosen per operating system
        if MACOS:
            suffix, fourcc = ".mp4", "avc1"
        elif WINDOWS:
            suffix, fourcc = ".avi", "WMV2"
        else:
            suffix, fourcc = ".avi", "MJPG"
        self.vid_writer[save_path] = cv2.VideoWriter(
            filename=str(Path(save_path).with_suffix(suffix)),
            fourcc=cv2.VideoWriter_fourcc(*fourcc),
            fps=fps,  # NOTE(review): original comment says an integer is required for the MP4 codec; no cast here
            frameSize=(im.shape[1], im.shape[0]),  # (width, height)
        )

    writer = self.vid_writer[save_path]
    writer.write(im)
    if self.args.save_frames:
        cv2.imwrite(f"{frames_path}/{save_path.stem}_{frame}.jpg", im)
def show(self, p: str = ""):
    """
    Display the last plotted image in a window named `p`.

    Raises:
        StopIteration: When the user presses 'q' while the window is waiting for a key.
    """
    im = self.plotted_img
    first_time_on_linux = platform.system() == "Linux" and p not in self.windows
    if first_time_on_linux:
        self.windows.append(p)
        cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO)  # allow window resize (Linux)
        height, width = im.shape[0], im.shape[1]
        cv2.resizeWindow(p, width, height)
    cv2.imshow(p, im)

    delay_ms = 300 if self.dataset.mode == "image" else 1  # 300ms if image; else 1ms
    key = cv2.waitKey(delay_ms) & 0xFF
    if key == ord("q"):
        raise StopIteration
def run_callbacks(self, event: str):
    """
    Invoke every callback registered under `event`, passing this predictor as the only argument.

    Args:
        event (str): Name of the event. Events with no registered callbacks are silently ignored.
    """
    handlers = self.callbacks.get(event, [])
    for handler in handlers:
        handler(self)
def add_callback(self, event: str, func: callable):
    """
    Register a callback function for a specific event.

    Events that have no callback list yet get one created on the fly, so registration also works when the
    predictor was constructed with a plain dict of callbacks (run_callbacks() already tolerates unknown events).

    Args:
        event (str): Name of the event to hook into.
        func (callable): Callback invoked with the predictor instance when the event fires.
    """
    self.callbacks.setdefault(event, []).append(func)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,904 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Train a model on a dataset.
Usage:
$ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
"""
import gc
import math
import os
import subprocess
import time
import warnings
from copy import copy, deepcopy
from datetime import datetime, timedelta
from pathlib import Path
import numpy as np
import torch
from torch import distributed as dist
from torch import nn, optim
from ultralytics import __version__
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.tasks import load_checkpoint
from ultralytics.utils import (
DEFAULT_CFG,
GIT,
LOCAL_RANK,
LOGGER,
RANK,
TQDM,
YAML,
callbacks,
clean_url,
colorstr,
emojis,
)
from ultralytics.utils.autobatch import check_train_batch_size
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
from ultralytics.utils.files import get_latest_run
from ultralytics.utils.plotting import plot_results
from ultralytics.utils.torch_utils import (
TORCH_2_4,
EarlyStopping,
ModelEMA,
attempt_compile,
autocast,
convert_optimizer_state_dict_to_fp16,
init_seeds,
one_cycle,
select_device,
strip_optimizer,
torch_distributed_zero_first,
unset_deterministic,
unwrap_model,
)
class BaseTrainer:
"""
A base class for creating trainers.
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
and various training utilities. It supports both single-GPU and multi-GPU distributed training.
Attributes:
args (SimpleNamespace): Configuration for the trainer.
validator (BaseValidator): Validator instance.
model (nn.Module): Model instance.
callbacks (defaultdict): Dictionary of callbacks.
save_dir (Path): Directory to save results.
wdir (Path): Directory to save weights.
last (Path): Path to the last checkpoint.
best (Path): Path to the best checkpoint.
save_period (int): Save checkpoint every x epochs (disabled if < 1).
batch_size (int): Batch size for training.
epochs (int): Number of epochs to train for.
start_epoch (int): Starting epoch for training.
device (torch.device): Device to use for training.
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
scaler (amp.GradScaler): Gradient scaler for AMP.
data (str): Path to data.
ema (nn.Module): EMA (Exponential Moving Average) of the model.
resume (bool): Resume training from a checkpoint.
lf (nn.Module): Loss function.
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
best_fitness (float): The best fitness value achieved.
fitness (float): Current fitness value.
loss (float): Current loss value.
tloss (float): Total loss value.
loss_names (list): List of loss names.
csv (Path): Path to results CSV file.
metrics (dict): Dictionary of metrics.
plots (dict): Dictionary of plots.
Methods:
train: Execute the training process.
validate: Run validation on the test set.
save_model: Save model training checkpoints.
get_dataset: Get train and validation datasets.
setup_model: Load, create, or download model.
build_optimizer: Construct an optimizer for the model.
Examples:
Initialize a trainer and start training
>>> trainer = BaseTrainer(cfg="config.yaml")
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
    """
    Initialize the BaseTrainer class.

    Args:
        cfg (str, optional): Path to a configuration file.
        overrides (dict, optional): Configuration overrides. A "session" key, if present, is popped from the
            dict and stored as `hub_session`. May be None.
        _callbacks (dict, optional): Mapping of event name to a list of callback functions. When falsy, the
            default callbacks are used.
    """
    if overrides is None:
        overrides = {}  # the signature allows None, but the lines below need a real dict
    self.hub_session = overrides.pop("session", None)  # HUB
    self.args = get_cfg(cfg, overrides)
    self.check_resume(overrides)
    self.device = select_device(self.args.device)
    # Update "-1" devices so post-training val does not repeat search
    self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
    self.validator = None
    self.metrics = None
    self.plots = {}
    init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)

    # Dirs
    self.save_dir = get_save_dir(self.args)
    self.args.name = self.save_dir.name  # update name for loggers
    self.wdir = self.save_dir / "weights"  # weights dir
    if RANK in {-1, 0}:
        self.wdir.mkdir(parents=True, exist_ok=True)  # make dir
        self.args.save_dir = str(self.save_dir)
        YAML.save(self.save_dir / "args.yaml", vars(self.args))  # save run args
    self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt"  # checkpoint paths
    self.save_period = self.args.save_period

    self.batch_size = self.args.batch
    self.epochs = self.args.epochs or 100  # in case users accidentally pass epochs=None with timed training
    self.start_epoch = 0
    if RANK == -1:
        print_args(vars(self.args))

    # Device
    if self.device.type in {"cpu", "mps"}:
        self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading

    # Model and Dataset
    self.model = check_model_file_from_stem(self.args.model)  # add suffix, i.e. yolo11n -> yolo11n.pt
    with torch_distributed_zero_first(LOCAL_RANK):  # avoid auto-downloading dataset multiple times
        self.data = self.get_dataset()

    self.ema = None

    # Optimization utils init
    self.lf = None
    self.scheduler = None

    # Epoch level metrics
    self.best_fitness = None
    self.fitness = None
    self.loss = None
    self.tloss = None
    self.loss_names = ["Loss"]
    self.csv = self.save_dir / "results.csv"
    self.plot_idx = [0, 1, 2]

    # Callbacks
    self.callbacks = _callbacks or callbacks.get_default_callbacks()

    # World size: number of devices requested, used to decide whether to launch DDP
    # NOTE(review): `self.args.device` was rewritten to a string above, so "cpu"/"mps" match the first branch
    # (world_size == 1) and the dedicated {"cpu", "mps"} branch looks unreachable — confirm before reordering.
    if isinstance(self.args.device, str) and len(self.args.device):  # i.e. device='0' or device='0,1,2,3'
        world_size = len(self.args.device.split(","))
    elif isinstance(self.args.device, (tuple, list)):  # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
        world_size = len(self.args.device)
    elif self.args.device in {"cpu", "mps"}:  # i.e. device='cpu' or 'mps'
        world_size = 0
    elif torch.cuda.is_available():  # i.e. device=None or device='' or device=number
        world_size = 1  # default to device 0
    else:  # i.e. device=None or device=''
        world_size = 0

    # DDP is launched from the parent process only; workers already have LOCAL_RANK in their environment
    self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
    self.world_size = world_size
    # Run subprocess if DDP training, else train normally
    if RANK in {-1, 0} and not self.ddp:
        callbacks.add_integration_callbacks(self)
        # Start console logging immediately at trainer initialization
        self.run_callbacks("on_pretrain_routine_start")
def add_callback(self, event: str, callback):
    """
    Append the given callback to the event's callback list.

    An event without a list yet gets one created, so this also works when the trainer was built with a plain
    dict of callbacks (run_callbacks() already tolerates unknown events).

    Args:
        event (str): Name of the event to hook into.
        callback (callable): Function invoked with the trainer instance when the event fires.
    """
    self.callbacks.setdefault(event, []).append(callback)
def set_callback(self, event: str, callback):
    """
    Replace whatever is registered for `event` with a single callback.

    Args:
        event (str): Name of the event whose callbacks are replaced.
        callback (callable): The only callback that remains registered for the event.
    """
    replacement = [callback]
    self.callbacks[event] = replacement
def run_callbacks(self, event: str):
    """
    Call each callback registered for `event`, handing it this trainer.

    Args:
        event (str): Name of the event. Events with no registered callbacks are silently ignored.
    """
    registered = self.callbacks.get(event, [])
    for fn in registered:
        fn(self)
def train(self):
    """
    Execute the training process.

    When `self.ddp` is set, the DDP command is generated and run as a subprocess, and the temporary file
    returned by generate_ddp_command() is cleaned up afterwards. Otherwise training runs in this process.

    Raises:
        ValueError: If DDP training is requested with a fractional (AutoBatch) batch size.
        subprocess.CalledProcessError: If the DDP subprocess exits with a non-zero status.
    """
    # Run subprocess if DDP training, else train normally
    if self.ddp:
        # Argument checks
        if self.args.rect:
            LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
            self.args.rect = False
        if self.args.batch < 1.0:
            raise ValueError(
                "AutoBatch with batch<1 not supported for Multi-GPU training, "
                f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
            )

        # Command
        cmd, file = generate_ddp_command(self)
        try:
            LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
            subprocess.run(cmd, check=True)
        finally:
            # Always clean up, whether the subprocess succeeded or raised; errors propagate unchanged
            ddp_cleanup(self, str(file))
    else:
        self._do_train()
def _setup_scheduler(self):
"""Initialize training learning rate scheduler."""
if self.args.cos_lr:
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
else:
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
def _setup_ddp(self):
    """Bind this process to its CUDA device and join the distributed process group."""
    torch.cuda.set_device(RANK)
    self.device = torch.device("cuda", RANK)
    os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1"  # set to enforce timeout

    backend = "nccl" if dist.is_nccl_available() else "gloo"
    three_hours = timedelta(hours=3)
    dist.init_process_group(
        backend=backend,
        timeout=three_hours,
        rank=RANK,
        world_size=self.world_size,
    )
def _setup_train(self):
    """
    Build dataloaders and optimizer on correct rank process.

    In order: loads and moves the model, optionally compiles it, freezes the requested layers, resolves the
    AMP setting (decided on rank 0 and shared with all ranks under DDP), wraps the model for DDP, validates
    the image size, optionally estimates the batch size, builds the dataloaders/validator/EMA, and finally
    creates the optimizer and scheduler and restores any checkpoint state.
    """
    ckpt = self.setup_model()
    self.model = self.model.to(self.device)
    self.set_model_attributes()

    # Compile model
    self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)

    # Freeze layers
    freeze_list = (
        self.args.freeze
        if isinstance(self.args.freeze, list)
        else range(self.args.freeze)
        if isinstance(self.args.freeze, int)
        else []
    )
    always_freeze_names = [".dfl"]  # always freeze these layers
    freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
    self.freeze_layer_names = freeze_layer_names
    for k, v in self.model.named_parameters():
        # v.register_hook(lambda x: torch.nan_to_num(x))  # NaN to 0 (commented for erratic training results)
        if any(x in k for x in freeze_layer_names):
            LOGGER.info(f"Freezing layer '{k}'")
            v.requires_grad = False
        elif not v.requires_grad and v.dtype.is_floating_point:  # only floating point Tensor can require gradients
            LOGGER.warning(
                f"setting 'requires_grad=True' for frozen layer '{k}'. "
                "See ultralytics.engine.trainer for customization of frozen layers."
            )
            v.requires_grad = True

    # Check AMP
    self.amp = torch.tensor(self.args.amp).to(self.device)  # True or False
    if self.amp and RANK in {-1, 0}:  # Single-GPU and DDP
        callbacks_backup = callbacks.default_callbacks.copy()  # backup callbacks as check_amp() resets them
        self.amp = torch.tensor(check_amp(self.model), device=self.device)
        callbacks.default_callbacks = callbacks_backup  # restore callbacks
    if RANK > -1 and self.world_size > 1:  # DDP
        # Broadcast writes into the tensor it is given. `.int()` returns a new tensor, so it must be kept
        # and read back; otherwise non-zero ranks never see rank 0's decision. (gloo errors with boolean.)
        amp_flag = self.amp.int()
        dist.broadcast(amp_flag, src=0)  # broadcast from rank 0 to all other ranks
        self.amp = amp_flag
    self.amp = bool(self.amp)  # as boolean
    self.scaler = (
        torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
    )
    if self.world_size > 1:
        self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)

    # Check imgsz
    gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32)  # grid size (max stride)
    self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
    self.stride = gs  # for multiscale training

    # Batch size
    if self.batch_size < 1 and RANK == -1:  # single-GPU only, estimate best batch size
        self.args.batch = self.batch_size = self.auto_batch()

    # Dataloaders
    batch_size = self.batch_size // max(self.world_size, 1)  # per-process batch size
    self.train_loader = self.get_dataloader(
        self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
    )
    if RANK in {-1, 0}:
        # Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
        self.test_loader = self.get_dataloader(
            self.data.get("val") or self.data.get("test"),
            batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
            rank=-1,
            mode="val",
        )
        self.validator = self.get_validator()
        metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
        self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
        self.ema = ModelEMA(self.model)
        if self.args.plots:
            self.plot_training_labels()

    # Optimizer
    self.accumulate = max(round(self.args.nbs / self.batch_size), 1)  # accumulate loss before optimizing
    weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs  # scale weight_decay
    iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
    self.optimizer = self.build_optimizer(
        model=self.model,
        name=self.args.optimizer,
        lr=self.args.lr0,
        momentum=self.args.momentum,
        decay=weight_decay,
        iterations=iterations,
    )
    # Scheduler
    self._setup_scheduler()
    self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
    self.resume_training(ckpt)
    self.scheduler.last_epoch = self.start_epoch - 1  # do not move
    self.run_callbacks("on_pretrain_routine_end")
def _do_train(self):
    """
    Run the complete training loop for the configured world size.

    Sets up DDP when `self.world_size > 1`, builds the training state via `_setup_train()`, then iterates epochs
    until `self.stop` becomes truthy. Each epoch steps the scheduler, iterates the train loader with warmup,
    forward/backward and accumulated optimizer steps, and (on ranks -1/0) validates, logs metrics and saves
    checkpoints. After the loop, ranks -1/0 run the final evaluation and plotting, then memory is cleared and the
    'teardown' callbacks are run.
    """
    # NOTE(review): the nesting below was reconstructed from a source view without indentation — confirm
    # against the repository, in particular that the timed-stopping check sits inside the optimizer-step branch.
    if self.world_size > 1:
        self._setup_ddp()
    self._setup_train()

    nb = len(self.train_loader)  # number of batches
    nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1  # warmup iterations
    last_opt_step = -1
    self.epoch_time = None
    self.epoch_time_start = time.time()
    self.train_time_start = time.time()
    self.run_callbacks("on_train_start")
    LOGGER.info(
        f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
        f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
        f"Logging results to {colorstr('bold', self.save_dir)}\n"
        f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
    )
    if self.args.close_mosaic:
        # Also plot the first three batches after mosaic augmentation is closed
        base_idx = (self.epochs - self.args.close_mosaic) * nb
        self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
    epoch = self.start_epoch
    self.optimizer.zero_grad()  # zero any resumed gradients to ensure stability on train start
    while True:
        self.epoch = epoch
        self.run_callbacks("on_train_epoch_start")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")  # suppress 'Detected lr_scheduler.step() before optimizer.step()'
            self.scheduler.step()

        self._model_train()
        if RANK != -1:
            self.train_loader.sampler.set_epoch(epoch)
        pbar = enumerate(self.train_loader)
        # Update dataloader attributes (optional)
        if epoch == (self.epochs - self.args.close_mosaic):
            self._close_dataloader_mosaic()
            self.train_loader.reset()

        if RANK in {-1, 0}:
            LOGGER.info(self.progress_string())
            pbar = TQDM(enumerate(self.train_loader), total=nb)
        self.tloss = None  # running mean of loss items, reset every epoch
        for i, batch in pbar:
            self.run_callbacks("on_train_batch_start")
            # Warmup
            ni = i + nb * epoch  # global batch index since epoch 0
            if ni <= nw:
                xi = [0, nw]  # x interp
                self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
                for j, x in enumerate(self.optimizer.param_groups):
                    # Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x["lr"] = np.interp(
                        ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
                    )
                    if "momentum" in x:
                        x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])

            # Forward
            with autocast(self.amp):
                batch = self.preprocess_batch(batch)
                if self.args.compile:
                    # Decouple inference and loss calculations for improved compile performance
                    preds = self.model(batch["img"])
                    loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
                else:
                    loss, self.loss_items = self.model(batch)
                self.loss = loss.sum()
                if RANK != -1:
                    self.loss *= self.world_size
                # Incremental mean of the per-batch loss items over the epoch so far
                self.tloss = (
                    (self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
                )

            # Backward
            self.scaler.scale(self.loss).backward()

            # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
            if ni - last_opt_step >= self.accumulate:
                self.optimizer_step()
                last_opt_step = ni

                # Timed stopping
                if self.args.time:
                    self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
                    if RANK != -1:  # if DDP training
                        broadcast_list = [self.stop if RANK == 0 else None]
                        dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
                        self.stop = broadcast_list[0]
                    if self.stop:  # training time exceeded
                        break

            # Log
            if RANK in {-1, 0}:
                loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
                pbar.set_description(
                    ("%11s" * 2 + "%11.4g" * (2 + loss_length))
                    % (
                        f"{epoch + 1}/{self.epochs}",
                        f"{self._get_memory():.3g}G",  # (GB) GPU memory util
                        *(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)),  # losses
                        batch["cls"].shape[0],  # batch size, i.e. 8
                        batch["img"].shape[-1],  # imgsz, i.e 640
                    )
                )
                self.run_callbacks("on_batch_end")
                if self.args.plots and ni in self.plot_idx:
                    self.plot_training_samples(batch, ni)

            self.run_callbacks("on_train_batch_end")

        self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)}  # for loggers
        self.run_callbacks("on_train_epoch_end")
        if RANK in {-1, 0}:
            final_epoch = epoch + 1 >= self.epochs
            self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])

            # Validation
            if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
                self._clear_memory(threshold=0.5)  # prevent VRAM spike
                self.metrics, self.fitness = self.validate()
            self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
            self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
            if self.args.time:
                self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)

            # Save model
            if self.args.save or final_epoch:
                self.save_model()
                self.run_callbacks("on_model_save")

        # Scheduler
        t = time.time()
        self.epoch_time = t - self.epoch_time_start
        self.epoch_time_start = t
        if self.args.time:
            # Re-estimate how many epochs fit into the time budget and rebuild the scheduler accordingly
            mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
            self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
            self._setup_scheduler()
            self.scheduler.last_epoch = self.epoch  # do not move
            self.stop |= epoch >= self.epochs  # stop if exceeded epochs
        self.run_callbacks("on_fit_epoch_end")
        self._clear_memory(0.5)  # clear if memory utilization > 50%

        # Early Stopping
        if RANK != -1:  # if DDP training
            broadcast_list = [self.stop if RANK == 0 else None]
            dist.broadcast_object_list(broadcast_list, 0)  # broadcast 'stop' to all ranks
            self.stop = broadcast_list[0]
        if self.stop:
            break  # must break all DDP ranks
        epoch += 1

    if RANK in {-1, 0}:
        # Do final val with best.pt
        seconds = time.time() - self.train_time_start
        LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
        self.final_eval()
        if self.args.plots:
            self.plot_metrics()
        self.run_callbacks("on_train_end")
    self._clear_memory()
    unset_deterministic()
    self.run_callbacks("teardown")
def auto_batch(self, max_num_obj=0):
    """
    Estimate a training batch size by delegating to `check_train_batch_size`.

    Args:
        max_num_obj (int, optional): Forwarded unchanged to `check_train_batch_size`.

    Returns:
        The value returned by `check_train_batch_size` (the batch size).
    """
    estimate_kwargs = dict(
        model=self.model,
        imgsz=self.args.imgsz,
        amp=self.amp,
        batch=self.batch_size,
        max_num_obj=max_num_obj,
    )
    return check_train_batch_size(**estimate_kwargs)  # returns batch size
def _get_memory(self, fraction=False):
    """
    Report accelerator memory use, in GiB or as a fraction.

    Args:
        fraction (bool, optional): If True return a utilization fraction instead of GiB.

    Returns:
        (float | int): For 'mps' devices with `fraction=True`, the psutil virtual-memory percentage divided by 100.
            Otherwise reserved/allocated bytes divided by 2**30, or (with `fraction=True`) divided by the device's
            total memory; 0 when the total is unknown (e.g. on CPU).
    """
    kind = self.device.type
    used, capacity = 0, 0
    if kind == "mps":
        used = torch.mps.driver_allocated_memory()
        if fraction:
            return __import__("psutil").virtual_memory().percent / 100
    elif kind != "cpu":
        # Any non-MPS, non-CPU device is queried through the CUDA API
        used = torch.cuda.memory_reserved()
        if fraction:
            capacity = torch.cuda.get_device_properties(self.device).total_memory
    if not fraction:
        return used / 2**30
    return (used / capacity) if capacity > 0 else 0
def _clear_memory(self, threshold: float = None):
    """
    Run the garbage collector and empty the accelerator cache.

    Args:
        threshold (float, optional): When truthy, must lie in [0, 1]; clearing is skipped if the utilization
            fraction from `_get_memory(fraction=True)` does not exceed it.
    """
    if threshold:
        assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
        below_limit = self._get_memory(fraction=True) <= threshold
        if below_limit:
            return
    gc.collect()
    kind = self.device.type
    if kind == "cpu":
        return  # no accelerator cache to empty
    if kind == "mps":
        torch.mps.empty_cache()
    else:
        torch.cuda.empty_cache()
def read_results_csv(self):
    """Load `self.csv` with polars and return it as a dict of plain Python lists keyed by column name."""
    import polars as pl  # scope for faster 'import ultralytics'

    frame = pl.read_csv(self.csv, infer_schema_length=None)
    return frame.to_dict(as_series=False)
def _model_train(self):
    """Switch the model to training mode, then put BatchNorm2d layers matching a frozen-layer name back in eval."""
    self.model.train()
    # Freeze BN stat
    for layer_name, layer in self.model.named_modules():
        if not isinstance(layer, nn.BatchNorm2d):
            continue
        # A layer is frozen when any (non-empty) freeze name is a substring of its module name
        if any(frozen for frozen in self.freeze_layer_names if frozen in layer_name):
            layer.eval()
def save_model(self):
    """
    Serialize the current training state once and write it to the checkpoint files.

    Always writes `self.last`; also writes `self.best` when the current fitness equals the best fitness, and an
    `epoch{N}.pt` file in `self.wdir` when `save_period` is positive and divides the current epoch.
    """
    import io

    git_info = {
        "root": str(GIT.root),
        "branch": GIT.branch,
        "commit": GIT.commit,
        "origin": GIT.origin,
    }
    state = {
        "epoch": self.epoch,
        "best_fitness": self.best_fitness,
        "model": None,  # resume and final checkpoints derive from EMA
        "ema": deepcopy(unwrap_model(self.ema.ema)).half(),
        "updates": self.ema.updates,
        "optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
        "scaler": self.scaler.state_dict(),
        "train_args": vars(self.args),  # save as dict
        "train_metrics": {**self.metrics, **{"fitness": self.fitness}},
        "train_results": self.read_results_csv(),
        "date": datetime.now().isoformat(),
        "version": __version__,
        "git": git_info,
        "license": "AGPL-3.0 (https://ultralytics.com/license)",
        "docs": "https://docs.ultralytics.com",
    }

    # Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
    stream = io.BytesIO()
    torch.save(state, stream)
    payload = stream.getvalue()  # get the serialized content to save

    # Save checkpoints
    self.last.write_bytes(payload)  # save last.pt
    if self.best_fitness == self.fitness:
        self.best.write_bytes(payload)  # save best.pt
    periodic = (self.save_period > 0) and (self.epoch % self.save_period == 0)
    if periodic:
        (self.wdir / f"epoch{self.epoch}.pt").write_bytes(payload)  # save epoch, i.e. 'epoch3.pt'
def get_dataset(self):
    """
    Resolve `self.args.data` into a dataset dictionary.

    Classification tasks go through `check_cls_dataset`; '.ndjson' inputs are first converted to YOLO format;
    '.yaml'/'.yml' inputs or detect/segment/pose/obb tasks go through `check_det_dataset`.

    Returns:
        (dict): The dataset dictionary; 'names' and 'nc' are overridden when `single_cls` is set.

    Raises:
        RuntimeError: If any step of the dataset resolution fails.
    """
    # NOTE(review): nesting reconstructed from a source view without indentation — confirm the 'yaml_file'
    # update belongs to the final detection branch.
    det_tasks = {
        "detect",
        "segment",
        "pose",
        "obb",
    }
    try:
        if self.args.task == "classify":
            data = check_cls_dataset(self.args.data)
        else:
            extension = self.args.data.rsplit(".", 1)[-1]
            if extension == "ndjson":
                # Convert NDJSON to YOLO format
                import asyncio

                from ultralytics.data.converter import convert_ndjson_to_yolo

                yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
                self.args.data = str(yaml_path)
                data = check_det_dataset(self.args.data)
            elif extension in {"yaml", "yml"} or self.args.task in det_tasks:
                data = check_det_dataset(self.args.data)
                if "yaml_file" in data:
                    self.args.data = data["yaml_file"]  # for validating 'yolo train data=url.zip' usage
    except Exception as e:
        raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
    if self.args.single_cls:
        LOGGER.info("Overriding class names with single class.")
        data["names"] = {0: "item"}
        data["nc"] = 1
    return data
def setup_model(self):
    """
    Turn `self.model` into a module instance unless it already is one.

    Returns:
        (dict | None): The checkpoint loaded from a '.pt' model path, otherwise None.
    """
    if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
        return

    cfg, weights, ckpt = self.model, None, None
    if str(self.model).endswith(".pt"):
        weights, ckpt = load_checkpoint(self.model)
        cfg = weights.yaml
    elif isinstance(self.args.pretrained, (str, Path)):
        weights, _unused_ckpt = load_checkpoint(self.args.pretrained)
    self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)
    return ckpt
def optimizer_step(self):
    """Unscale and clip gradients (max norm 10.0), step the optimizer via the scaler, zero grads, update EMA."""
    scaler, optimizer = self.scaler, self.optimizer
    scaler.unscale_(optimizer)  # unscale gradients
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)  # clip gradients
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    if not self.ema:
        return
    self.ema.update(self.model)
def preprocess_batch(self, batch):
    """Hook for task-specific preprocessing of inputs and ground truths; the base version returns `batch` as is."""
    return batch
def validate(self):
    """
    Run `self.validator` on this trainer and track the best fitness seen.

    Returns:
        metrics (dict): Validation metrics with the 'fitness' entry removed.
        fitness (float): The 'fitness' entry, or the negated training loss when the validator reports none.
    """
    results = self.validator(self)
    fallback = -self.loss.detach().cpu().numpy()  # use loss as fitness measure if not found
    score = results.pop("fitness", fallback)
    if not self.best_fitness or self.best_fitness < score:
        self.best_fitness = score
    return results, score
def get_model(self, cfg=None, weights=None, verbose=True):
    """Build the task model; the base trainer always raises NotImplementedError and subclasses must override."""
    message = "This task trainer doesn't support loading cfg files"
    raise NotImplementedError(message)
def get_validator(self):
    """Build the task validator; the base trainer always raises NotImplementedError and subclasses must override."""
    message = "get_validator function not implemented in trainer"
    raise NotImplementedError(message)
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
    """Build a dataloader for `dataset_path`; the base trainer always raises NotImplementedError."""
    message = "get_dataloader function not implemented in trainer"
    raise NotImplementedError(message)
def build_dataset(self, img_path, mode="train", batch=None):
    """Build a dataset for `img_path`; the base trainer always raises NotImplementedError."""
    message = "build_dataset function not implemented in trainer"
    raise NotImplementedError(message)
def label_loss_items(self, loss_items=None, prefix="train"):
    """
    Label loss items, or list the loss keys when no items are given.

    Args:
        loss_items (Any, optional): Loss values to label.
        prefix (str, optional): Unused by the base implementation.

    Returns:
        (dict | list): `{"loss": loss_items}` when `loss_items` is not None, otherwise `["loss"]`.

    Note:
        This is not needed for classification but necessary for segmentation & detection
    """
    if loss_items is None:
        return ["loss"]
    return {"loss": loss_items}
def set_model_attributes(self):
    """Copy the dataset class names (`self.data["names"]`) onto the model before training."""
    class_names = self.data["names"]
    self.model.names = class_names
def build_targets(self, preds, targets):
    """Hook for building training target tensors; the base implementation does nothing."""
    return None
def progress_string(self):
    """Return the progress header logged before each epoch; empty in the base trainer."""
    header = ""
    return header
# TODO: may need to put these following functions into callback
def plot_training_samples(self, batch, ni):
    """Hook for plotting a training batch at global batch index `ni`; the base implementation does nothing."""
    return None
def plot_training_labels(self):
    """Hook for plotting the training labels; the base implementation does nothing."""
    return None
def save_metrics(self, metrics):
    """
    Append one row of metrics to `self.csv`, writing the header first when the file does not exist yet.

    Args:
        metrics (dict): Metric names mapped to numeric values; the row is prefixed with the 1-based epoch and
            the seconds elapsed since `self.train_time_start`.
    """
    names = list(metrics.keys())
    values = list(metrics.values())
    ncols = len(metrics) + 2  # number of cols
    if self.csv.exists():
        header = ""
    else:
        header = (("%s," * ncols) % tuple(["epoch", "time"] + names)).rstrip(",") + "\n"  # header
    elapsed = time.time() - self.train_time_start
    row = (("%.6g," * ncols) % tuple([self.epoch + 1, elapsed] + values)).rstrip(",") + "\n"
    with open(self.csv, "a", encoding="utf-8") as f:
        f.write(header + row)
def plot_metrics(self):
    """Plot the metrics stored in `self.csv` via `plot_results`, registering outputs through `self.on_plot`."""
    plot_results(on_plot=self.on_plot, file=self.csv)  # save results.png
def on_plot(self, name, data=None):
    """Record a plot in `self.plots`, keyed by its path, with its data and the registration time."""
    record = {"data": data, "timestamp": time.time()}
    self.plots[Path(name)] = record
def final_eval(self):
    """
    Strip optimizer state from the saved checkpoints and validate the best one.

    `last.pt` is stripped first and its contents kept; `best.pt` is then stripped (receiving 'train_results'
    from `last.pt` when available), validated with compile disabled, and 'on_fit_epoch_end' callbacks are run.
    """
    # NOTE(review): nesting reconstructed from a source view without indentation — confirm validation and the
    # callback only run in the best.pt branch.
    last_ckpt = {}
    for weights_file in (self.last, self.best):
        if not weights_file.exists():
            continue
        if weights_file is self.last:
            last_ckpt = strip_optimizer(weights_file)
        elif weights_file is self.best:
            key = "train_results"  # update best.pt train_metrics from last.pt
            extra = {key: last_ckpt[key]} if key in last_ckpt else None
            strip_optimizer(weights_file, updates=extra)
            LOGGER.info(f"\nValidating {weights_file}...")
            self.validator.args.plots = self.args.plots
            self.validator.args.compile = False  # disable final val compile as too slow
            self.metrics = self.validator(model=weights_file)
            self.metrics.pop("fitness", None)
            self.run_callbacks("on_fit_epoch_end")
def check_resume(self, overrides):
    """
    Resolve the resume checkpoint and rebuild `self.args` from it.

    Args:
        overrides (dict): User overrides; 'imgsz', 'batch', 'device' and 'close_mosaic' are re-applied on top of
            the checkpoint arguments.

    Raises:
        FileNotFoundError: If anything goes wrong while locating or loading the checkpoint.
    """
    resume = self.args.resume
    if not resume:
        self.resume = resume
        return

    updatable = (
        "imgsz",
        "batch",
        "device",
        "close_mosaic",
    )  # allow arg updates to reduce memory or update device on resume
    try:
        given_path_exists = isinstance(resume, (str, Path)) and Path(resume).exists()
        last = Path(check_file(resume) if given_path_exists else get_latest_run())

        # Check that resume data YAML exists, otherwise strip to force re-download of dataset
        ckpt_args = load_checkpoint(last)[0].args
        if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
            ckpt_args["data"] = self.args.data

        resume = True
        self.args = get_cfg(ckpt_args)
        self.args.model = self.args.resume = str(last)  # reinstate model
        for key in updatable:
            if key in overrides:
                setattr(self.args, key, overrides[key])
    except Exception as e:
        raise FileNotFoundError(
            "Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
            "i.e. 'yolo train resume model=path/to/last.pt'"
        ) from e
    self.resume = resume
def resume_training(self, ckpt):
    """
    Restore optimizer, scaler, EMA, best fitness and start epoch from a checkpoint.

    Args:
        ckpt (dict | None): Checkpoint to resume from; nothing happens when it is None or resuming is off.
    """
    if ckpt is None or not self.resume:
        return

    restored_fitness = 0.0
    first_epoch = ckpt.get("epoch", -1) + 1
    if ckpt.get("optimizer") is not None:
        self.optimizer.load_state_dict(ckpt["optimizer"])  # optimizer
        restored_fitness = ckpt["best_fitness"]
    if ckpt.get("scaler") is not None:
        self.scaler.load_state_dict(ckpt["scaler"])
    if self.ema and ckpt.get("ema"):
        self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict())  # EMA
        self.ema.updates = ckpt["updates"]
    assert first_epoch > 0, (
        f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
        f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
    )
    LOGGER.info(f"Resuming training {self.args.model} from epoch {first_epoch + 1} to {self.epochs} total epochs")
    if self.epochs < first_epoch:
        LOGGER.info(
            f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
        )
        self.epochs += ckpt["epoch"]  # finetune additional epochs
    self.best_fitness = restored_fitness
    self.start_epoch = first_epoch
    # Resuming past the close-mosaic point: disable mosaic right away
    if first_epoch > (self.epochs - self.args.close_mosaic):
        self._close_dataloader_mosaic()
def _close_dataloader_mosaic(self):
    """Turn off mosaic augmentation on the training dataset, using whichever hooks the dataset exposes."""
    dataset = self.train_loader.dataset
    if hasattr(dataset, "mosaic"):
        dataset.mosaic = False
    if hasattr(dataset, "close_mosaic"):
        LOGGER.info("Closing dataloader mosaic")
        dataset.close_mosaic(hyp=copy(self.args))
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
    """
    Construct an optimizer for the given model.

    Parameters are split into three groups: biases (no decay), normalization-layer weights and 'logit_scale'
    parameters (no decay), and all remaining weights (with `decay`).

    Args:
        model (torch.nn.Module): The model for which to build an optimizer.
        name (str, optional): The name of the optimizer to use (case-insensitive). If 'auto', the optimizer,
            learning rate and momentum are selected based on the number of iterations and classes.
        lr (float, optional): The learning rate for the optimizer.
        momentum (float, optional): The momentum factor for the optimizer.
        decay (float, optional): The weight decay for the optimizer.
        iterations (float, optional): The number of iterations, which determines the optimizer if
            name is 'auto'.

    Returns:
        (torch.optim.Optimizer): The constructed optimizer.

    Raises:
        NotImplementedError: If `name` is not one of the supported optimizers.
    """
    g = [], [], []  # optimizer parameter groups
    bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k)  # normalization layers, i.e. BatchNorm2d()
    if name == "auto":
        LOGGER.info(
            f"{colorstr('optimizer:')} 'optimizer=auto' found, "
            f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
            f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
        )
        nc = self.data.get("nc", 10)  # number of classes
        lr_fit = round(0.002 * 5 / (4 + nc), 6)  # lr0 fit equation to 6 decimal places
        name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
        self.args.warmup_bias_lr = 0.0  # no higher than 0.01 for Adam

    for module_name, module in model.named_modules():
        for param_name, param in module.named_parameters(recurse=False):
            fullname = f"{module_name}.{param_name}" if module_name else param_name
            if "bias" in fullname:  # bias (no decay)
                g[2].append(param)
            elif isinstance(module, bn) or "logit_scale" in fullname:  # weight (no decay)
                # ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
                g[1].append(param)
            else:  # weight (with decay)
                g[0].append(param)

    optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
    requested = name  # keep the caller's spelling: the lookup below yields None for unknown names
    name = {x.lower(): x for x in optimizers}.get(name.lower())
    if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
        optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
    elif name == "RMSProp":
        optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
    elif name == "SGD":
        optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
    else:
        raise NotImplementedError(
            f"Optimizer '{requested}' not found in list of available optimizers {optimizers}. "
            "Request support for additional optimizers at https://github.com/ultralytics/ultralytics."
        )

    optimizer.add_param_group({"params": g[0], "weight_decay": decay})  # add g0 with weight_decay
    optimizer.add_param_group({"params": g[1], "weight_decay": 0.0})  # add g1 (BatchNorm2d weights)
    LOGGER.info(
        f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
        f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
    )
    return optimizer

459
ultralytics/engine/tuner.py Normal file
View File

@@ -0,0 +1,459 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
segmentation, image classification, pose estimation, and multi-object tracking.
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
Examples:
Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt")
>>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
"""
from __future__ import annotations
import gc
import random
import shutil
import subprocess
import time
from datetime import datetime
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.patches import torch_load
from ultralytics.utils.plotting import plot_tune_results
class Tuner:
"""
A class for hyperparameter tuning of YOLO models.
The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
search space and retraining the model to evaluate their performance. Supports both local CSV storage and
distributed MongoDB Atlas coordination for multi-machine hyperparameter optimization.
Attributes:
space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
tune_dir (Path): Directory where evolution logs and results will be saved.
tune_csv (Path): Path to the CSV file where evolution logs are saved.
args (dict): Configuration arguments for the tuning process.
callbacks (list): Callback functions to be executed during tuning.
prefix (str): Prefix string for logging messages.
mongodb (MongoClient): Optional MongoDB client for distributed tuning.
collection (Collection): MongoDB collection for storing tuning results.
Methods:
_mutate: Mutate hyperparameters based on bounds and scaling factors.
__call__: Execute the hyperparameter evolution across multiple iterations.
Examples:
Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=10 for 300 tuning iterations.
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt")
>>> model.tune(
>>> data="coco8.yaml",
>>> epochs=10,
>>> iterations=300,
>>> plots=False,
>>> save=False,
>>> val=False
>>> )
Tune with distributed MongoDB Atlas coordination across multiple machines:
>>> model.tune(
>>> data="coco8.yaml",
>>> epochs=10,
>>> iterations=300,
>>> mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",
>>> mongodb_db="ultralytics",
>>> mongodb_collection="tune_results"
>>> )
Tune with custom search space:
>>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)})
"""
def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
    """
    Initialize the Tuner with configurations.

    Args:
        args (dict): Configuration for hyperparameter evolution. The optional keys 'space', 'mongodb_uri',
            'mongodb_db' and 'mongodb_collection' are popped before the rest is passed to `get_cfg`.
        _callbacks (list | None, optional): Callback functions to be executed during tuning.
    """
    default_space = {  # key: (min, max, gain(optional))
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
        "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
        "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
        "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
        "box": (1.0, 20.0),  # box loss gain
        "cls": (0.1, 4.0),  # cls loss gain (scale with pixels)
        "dfl": (0.4, 6.0),  # dfl loss gain
        "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": (0.0, 45.0),  # image rotation (+/- deg)
        "translate": (0.0, 0.9),  # image translation (+/- fraction)
        "scale": (0.0, 0.95),  # image scale (+/- gain)
        "shear": (0.0, 10.0),  # image shear (+/- deg)
        "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": (0.0, 1.0),  # image flip up-down (probability)
        "fliplr": (0.0, 1.0),  # image flip left-right (probability)
        "bgr": (0.0, 1.0),  # image channel bgr (probability)
        "mosaic": (0.0, 1.0),  # image mosaic (probability)
        "mixup": (0.0, 1.0),  # image mixup (probability)
        "cutmix": (0.0, 1.0),  # image cutmix (probability)
        "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
        "close_mosaic": (0.0, 10.0),  # close dataloader mosaic (epochs)
    }
    self.space = args.pop("space", None) or default_space

    # Tuner-only keys are removed before the remaining overrides reach get_cfg
    uri = args.pop("mongodb_uri", None)
    db_name = args.pop("mongodb_db", "ultralytics")
    collection_name = args.pop("mongodb_collection", "tuner_results")

    self.args = get_cfg(overrides=args)
    self.args.exist_ok = self.args.resume  # resume w/ same tune_dir
    self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
    # reset to not affect training
    self.args.name = None
    self.args.exist_ok = False
    self.args.resume = False
    self.tune_csv = self.tune_dir / "tune_results.csv"
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
    self.prefix = colorstr("Tuner: ")
    callbacks.add_integration_callbacks(self)

    # MongoDB Atlas support (optional)
    self.mongodb = None
    if uri:
        self._init_mongodb(uri, db_name, collection_name)

    LOGGER.info(
        f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
        f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
    )
def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
    """
    Open a MongoDB client, retrying failed connections with exponentially growing waits.

    Args:
        uri (str): MongoDB connection string with credentials and cluster information.
        max_retries (int): Maximum number of connection attempts before giving up.

    Returns:
        (MongoClient): Client that answered a 'ping' command.

    Raises:
        ConnectionFailure | ServerSelectionTimeoutError: Re-raised when the last attempt fails.
    """
    check_requirements("pymongo")

    from pymongo import MongoClient
    from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError

    client_options = dict(
        serverSelectionTimeoutMS=30000,
        connectTimeoutMS=20000,
        socketTimeoutMS=40000,
        retryWrites=True,
        retryReads=True,
        maxPoolSize=30,
        minPoolSize=3,
        maxIdleTimeMS=60000,
    )
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            client = MongoClient(uri, **client_options)
            client.admin.command("ping")  # Test connection
            LOGGER.info(f"{self.prefix}Connected to MongoDB Atlas (attempt {attempt + 1})")
            return client
        except (ConnectionFailure, ServerSelectionTimeoutError):
            if attempt == final_attempt:
                raise
            wait_time = 2**attempt  # 1s, 2s, 4s, ...
            LOGGER.warning(
                f"{self.prefix}MongoDB connection failed (attempt {attempt + 1}), retrying in {wait_time}s..."
            )
            time.sleep(wait_time)
def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
    """
    Connect to MongoDB and select the collection used for distributed tuning.

    Args:
        mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
        mongodb_db (str, optional): Database name.
        mongodb_collection (str, optional): Collection name.

    Notes:
        - Creates a descending index on 'fitness' in the background.
        - Connection errors raised by `_connect` are not caught here.
    """
    client = self._connect(mongodb_uri)
    self.mongodb = client
    self.collection = client[mongodb_db][mongodb_collection]
    self.collection.create_index([("fitness", -1)], background=True)
    LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
def _get_mongodb_results(self, n: int = 5) -> list:
    """
    Get top N results from MongoDB sorted by fitness.

    This is a best-effort read: any failure is logged as a warning and reported as "no results" so that the
    caller can fall back to other data sources.

    Args:
        n (int): Number of top results to retrieve.

    Returns:
        (list[dict]): Result documents ordered by descending fitness, or an empty list if the query failed.
    """
    try:
        return list(self.collection.find().sort("fitness", -1).limit(n))
    except Exception as e:
        # Keep the best-effort contract, but surface the failure like the other MongoDB helpers do
        LOGGER.warning(f"{self.prefix}MongoDB read failed: {e}")
        return []
def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
    """
    Insert one tuning result into the MongoDB collection, logging a warning on failure.

    Args:
        fitness (float): Fitness score achieved with these hyperparameters; stored as a Python float.
        hyperparameters (dict[str, float]): Hyperparameter values; values exposing `.item()` are unwrapped.
        metrics (dict): Training metrics dictionary, stored as given.
        iteration (int): Current iteration number.
    """
    try:
        plain_hyp = {}
        for key, value in hyperparameters.items():
            plain_hyp[key] = value.item() if hasattr(value, "item") else value
        document = {
            "fitness": float(fitness),
            "hyperparameters": plain_hyp,
            "metrics": metrics,
            "timestamp": datetime.now(),
            "iteration": iteration,
        }
        self.collection.insert_one(document)
    except Exception as e:
        LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
def _sync_mongodb_to_csv(self):
    """
    Rewrite the local tuning CSV from the MongoDB collection.

    Fetches every document ordered by ascending 'iteration' and writes a header plus one row per document
    (fitness rounded to 5 decimals, then the hyperparameters in `self.space` order). Nothing is written when
    the collection is empty; failures are logged as warnings.
    """
    try:
        # Get all results from MongoDB
        records = list(self.collection.find().sort("iteration", 1))
        if not records:
            return

        # Write to CSV
        columns = list(self.space.keys())
        header_line = ",".join(["fitness"] + columns) + "\n"
        with open(self.tune_csv, "w", encoding="utf-8") as f:
            f.write(header_line)
            for record in records:
                score = record["fitness"]
                row = [round(score, 5)] + [record["hyperparameters"][name] for name in columns]
                f.write(",".join(str(cell) for cell in row) + "\n")
    except Exception as e:
        LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
def _crossover(self, x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
    """
    Blend up to `k` fitness-weighted parents into one child (BLX-α crossover).

    Args:
        x (np.ndarray): Parent table; column 0 is fitness, the remaining columns are genes.
        alpha (float): Fraction of the per-gene parent span by which the sampling range is widened on each side.
        k (int): Maximum number of parents to draw (with replacement).

    Returns:
        (np.ndarray): One gene vector sampled uniformly from the widened per-gene parent range.
    """
    n_rows = len(x)
    k = min(k, n_rows)
    fitness = x[:, 0]
    # fitness weights (shifted to >0); fallback to uniform if degenerate
    weights = fitness - fitness.min() + 1e-6
    if not np.isfinite(weights).all() or weights.sum() == 0:
        weights = np.ones_like(weights)
    picks = random.choices(range(n_rows), weights=weights, k=k)
    parents = np.stack([x[p][1:] for p in picks], 0)  # (k, ng) strip fitness
    lower = parents.min(0)
    upper = parents.max(0)
    margin = alpha * (upper - lower)
    return np.random.uniform(lower - margin, upper + margin)
def _mutate(
    self,
    n: int = 9,
    mutation: float = 0.5,
    sigma: float = 0.2,
) -> dict[str, float]:
    """
    Produce the next hyperparameter set by crossover and mutation of the best previous results.

    Parents come from MongoDB when it is configured and has data, otherwise from the local CSV. Without any
    previous results the current values from `self.args` are used unchanged (apart from clamping).

    Args:
        n (int): Number of top parents to consider.
        mutation (float): Probability of a parameter mutation in any given iteration.
        sigma (float): Standard deviation for Gaussian random number generator.

    Returns:
        (dict[str, float]): Hyperparameters clamped to the bounds in `self.space` and rounded to 5 decimals;
            'close_mosaic' is converted to int.
    """
    names = list(self.space.keys())
    parents = None

    # Try MongoDB first if available
    if self.mongodb:
        docs = self._get_mongodb_results(n)
        if docs:
            # MongoDB already sorted by fitness DESC, so docs[0] is best
            parents = np.array([[d["fitness"]] + [d["hyperparameters"][k] for k in names] for d in docs])
        elif self.collection.name in self.collection.database.list_collection_names():  # Tuner started elsewhere
            parents = np.array([[0.0] + [getattr(self.args, k) for k in names]])

    # Fall back to CSV if MongoDB unavailable or empty
    if parents is None and self.tune_csv.exists():
        table = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
        if len(table) > 0:
            ranking = np.argsort(-table[:, 0])  # first column is fitness
            parents = table[ranking][:n]  # top-n sorted by fitness DESC

    if parents is None:
        # No history yet: start from the configured values
        hyp = {k: getattr(self.args, k) for k in names}
    else:
        np.random.seed(int(time.time()))
        ng = len(self.space)

        # Crossover
        genes = self._crossover(parents)

        # Mutation
        gains = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()])  # gains 0-1
        factors = np.ones(ng)
        while np.all(factors == 1):  # mutate until a change occurs (prevent duplicates)
            mask = np.random.random(ng) < mutation
            step = np.random.randn(ng) * (sigma * gains)
            factors = np.where(mask, np.exp(step), 1.0).clip(0.25, 4.0)
        hyp = {k: float(genes[i] * factors[i]) for i, k in enumerate(names)}

    # Constrain to limits
    for k, bounds in self.space.items():
        lower, upper = bounds[0], bounds[1]
        hyp[k] = round(min(max(hyp[k], lower), upper), 5)

    # Update types
    if "close_mosaic" in hyp:
        hyp["close_mosaic"] = int(round(hyp["close_mosaic"]))

    return hyp
def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
    """
    Execute the hyperparameter evolution process when the Tuner instance is called.

    This method iterates through the specified number of iterations, performing the following steps:
        1. Sync MongoDB results to CSV (if using distributed mode)
        2. Mutate hyperparameters using the best previous results or defaults
        3. Train a YOLO model with the mutated hyperparameters (in a subprocess)
        4. Log fitness scores and hyperparameters to MongoDB and/or CSV
        5. Track the best performing configuration across all iterations

    If the tuning CSV already exists, its row count is used as the starting iteration so the run resumes.
    A training failure is logged and recorded with fitness 0.0 rather than aborting the loop.

    Args:
        model (Model | None, optional): A pre-initialized YOLO model to be used for training. Not referenced
            in this method body.
        iterations (int): The number of generations to run the evolution for.
        cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
    """
    t0 = time.time()
    best_save_dir, best_metrics = None, None
    (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)

    # Sync MongoDB to CSV at startup for proper resume logic
    if self.mongodb:
        self._sync_mongodb_to_csv()

    start = 0
    if self.tune_csv.exists():
        x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
        start = x.shape[0]
        LOGGER.info(f"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...")

    for i in range(start, iterations):
        # Linearly decay sigma from 0.2 → 0.1 over first 300 iterations
        frac = min(i / 300.0, 1.0)
        sigma_i = 0.2 - 0.1 * frac

        # Mutate hyperparameters
        mutated_hyp = self._mutate(sigma=sigma_i)
        LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")

        metrics = {}
        train_args = {**vars(self.args), **mutated_hyp}
        save_dir = get_save_dir(get_cfg(train_args))
        weights_dir = save_dir / "weights"
        try:
            # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
            launch = [__import__("sys").executable, "-m", "ultralytics.cfg.__init__"]  # workaround yolo not found
            cmd = [*launch, "train", *(f"{k}={v}" for k, v in train_args.items())]
            return_code = subprocess.run(cmd, check=True).returncode
            ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
            metrics = torch_load(ckpt_file)["train_metrics"]
            assert return_code == 0, "training failed"

            # Cleanup
            time.sleep(1)
            gc.collect()
            torch.cuda.empty_cache()

        except Exception as e:
            # Best-effort: a failed iteration keeps metrics == {} and is logged with fitness 0.0 below
            LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")

        # Save results - MongoDB takes precedence
        fitness = metrics.get("fitness", 0.0)
        if self.mongodb:
            self._save_to_mongodb(fitness, mutated_hyp, metrics, i + 1)
            self._sync_mongodb_to_csv()
            total_mongo_iterations = self.collection.count_documents({})
            if total_mongo_iterations >= iterations:
                LOGGER.info(
                    f"{self.prefix}Target iterations ({iterations}) reached in MongoDB ({total_mongo_iterations}). Stopping."
                )
                break
        else:
            # Save to CSV only if no MongoDB
            log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
            with open(self.tune_csv, "a", encoding="utf-8") as f:
                f.write(headers + ",".join(map(str, log_row)) + "\n")

        # Get best results
        x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
        fitness = x[:, 0]  # first column
        best_idx = fitness.argmax()
        # The row written for this iteration sits at 0-based index i: `start` rows precede the loop and every
        # iteration adds one. Comparing against (i - start) pointed at an older row whenever a run was resumed.
        # NOTE(review): with several concurrent MongoDB workers the synced row order may differ - confirm.
        best_is_current = best_idx == i
        if best_is_current:
            best_save_dir = str(save_dir)
            best_metrics = {k: round(v, 5) for k, v in metrics.items()}
            for ckpt in weights_dir.glob("*.pt"):
                shutil.copy2(ckpt, self.tune_dir / "weights")
        elif cleanup and best_save_dir:
            shutil.rmtree(best_save_dir, ignore_errors=True)  # remove iteration dirs to reduce storage space

        # Plot tune results
        plot_tune_results(str(self.tune_csv))

        # Save and print tune results
        header = (
            f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n"
            f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
            f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
            f"{self.prefix}Best fitness metrics are {best_metrics}\n"
            f"{self.prefix}Best fitness model is {best_save_dir}"
        )
        LOGGER.info("\n" + header)
        # Column 0 is fitness, so hyperparameter j lives in column j + 1
        data = {k: float(x[best_idx, j + 1]) for j, k in enumerate(self.space.keys())}
        YAML.save(
            self.tune_dir / "best_hyperparameters.yaml",
            data=data,
            header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
        )
        YAML.print(self.tune_dir / "best_hyperparameters.yaml")

View File

@@ -0,0 +1,370 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Check a model's accuracy on a test or val split of a dataset.
Usage:
$ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640
Usage - formats:
$ yolo mode=val model=yolo11n.pt # PyTorch
yolo11n.torchscript # TorchScript
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
yolo11n_openvino_model # OpenVINO
yolo11n.engine # TensorRT
yolo11n.mlpackage # CoreML (macOS-only)
yolo11n_saved_model # TensorFlow SavedModel
yolo11n.pb # TensorFlow GraphDef
yolo11n.tflite # TensorFlow Lite
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
yolo11n_paddle_model # PaddlePaddle
yolo11n.mnn # MNN
yolo11n_ncnn_model # NCNN
yolo11n_imx_model # Sony IMX
yolo11n_rknn_model # Rockchip RKNN
"""
import json
import time
from pathlib import Path
import numpy as np
import torch
from ultralytics.cfg import get_cfg, get_save_dir
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.ops import Profile
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
class BaseValidator:
"""
A base class for creating validators.
This class provides the foundation for validation processes, including model evaluation, metric computation, and
result visualization.
Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
model (nn.Module): Model to validate.
data (dict): Data dictionary containing dataset information.
device (torch.device): Device to use for validation.
batch_i (int): Current batch index.
training (bool): Whether the model is in training mode.
names (dict): Class names mapping.
seen (int): Number of images seen so far during validation.
stats (dict): Statistics collected during validation.
confusion_matrix: Confusion matrix for classification evaluation.
nc (int): Number of classes.
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in steps of 0.05.
jdict (list): List to store JSON validation results.
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
batch processing times in milliseconds.
save_dir (Path): Directory to save results.
plots (dict): Dictionary to store plots for visualization.
callbacks (dict): Dictionary to store various callback functions.
stride (int): Model stride for padding calculations.
loss (torch.Tensor): Accumulated loss during training validation.
Methods:
__call__: Execute validation process, running inference on dataloader and computing performance metrics.
match_predictions: Match predictions to ground truth objects using IoU.
add_callback: Append the given callback to the specified event.
run_callbacks: Run all callbacks associated with a specified event.
get_dataloader: Get data loader from dataset path and batch size.
build_dataset: Build dataset from image path.
preprocess: Preprocess an input batch.
postprocess: Postprocess the predictions.
init_metrics: Initialize performance metrics for the YOLO model.
update_metrics: Update metrics based on predictions and batch.
finalize_metrics: Finalize and return all metrics.
get_stats: Return statistics about the model's performance.
print_results: Print the results of the model's predictions.
get_desc: Get description of the YOLO model.
on_plot: Register plots for visualization.
plot_val_samples: Plot validation samples during training.
plot_predictions: Plot YOLO model predictions on batch images.
pred_to_json: Convert predictions to JSON format.
eval_json: Evaluate and return JSON format of prediction statistics.
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
    """
    Set up validator state, resolve the configuration and create the output directory.

    Args:
        dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
        save_dir (Path, optional): Directory to save results; derived from the config when omitted.
        args (SimpleNamespace, optional): Configuration overrides for the validator.
        _callbacks (dict, optional): Callback functions keyed by event; defaults are loaded when omitted.
    """
    import torchvision  # noqa (import here so torchvision import time not recorded in postprocess time)

    self.args = get_cfg(overrides=args)
    self.dataloader = dataloader

    # Placeholders that are filled in once validation actually runs
    self.stride = None
    self.data = None
    self.device = None
    self.batch_i = None
    self.training = True
    self.names = None
    self.seen = None
    self.stats = None
    self.confusion_matrix = None
    self.nc = None
    self.iouv = None
    self.jdict = None
    self.speed = dict.fromkeys(("preprocess", "inference", "loss", "postprocess"), 0.0)

    # Output location: a "labels" sub-directory is created only when text labels are saved
    self.save_dir = save_dir or get_save_dir(self.args)
    output_dir = self.save_dir / "labels" if self.args.save_txt else self.save_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    if self.args.conf is None:
        # reduce OBB val memory usage
        default_conf = 0.01 if self.args.task == "obb" else 0.001
        self.args.conf = default_conf
    self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)

    self.plots = {}
    self.callbacks = _callbacks or callbacks.get_default_callbacks()
@smart_inference_mode()
def __call__(self, trainer=None, model=None):
    """
    Run the full validation loop: prepare the model, iterate the dataloader and gather metrics.

    When a trainer is given, its EMA (or raw) model, device and data are reused and the validation loss is
    accumulated. Otherwise the model is loaded through AutoBackend, the dataset is resolved from the config and a
    dataloader is built if none was supplied.

    Args:
        trainer (object, optional): Trainer object that contains the model to validate.
        model (nn.Module, optional): Model to validate if not using a trainer.

    Returns:
        (dict): Validation statistics. In training mode the labelled validation losses are merged in and every
            value is rounded to 5 decimal places.

    Raises:
        FileNotFoundError: In standalone mode, if the dataset is neither a YAML file nor a classification dataset.
    """
    self.training = trainer is not None
    augment = self.args.augment and not self.training

    if self.training:
        self.device = trainer.device
        self.data = trainer.data
        # Force FP16 val during training
        self.args.half = self.device.type != "cpu" and trainer.amp
        model = trainer.ema.ema or trainer.model
        if trainer.args.compile and hasattr(model, "_orig_mod"):
            model = model._orig_mod  # validate non-compiled original model to avoid issues
        if self.args.half:
            model = model.half()
        else:
            model = model.float()
        self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
        self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
        model.eval()
    else:
        if str(self.args.model).endswith(".yaml") and model is None:
            LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
        callbacks.add_integration_callbacks(self)
        model = AutoBackend(
            model=model or self.args.model,
            device=select_device(self.args.device),
            dnn=self.args.dnn,
            data=self.args.data,
            fp16=self.args.half,
        )
        self.device = model.device  # update device
        self.args.half = model.fp16  # update half
        stride, pt, jit = model.stride, model.pt, model.jit
        imgsz = check_imgsz(self.args.imgsz, stride=stride)
        fixed_batch = not (pt or jit or getattr(model, "dynamic", False))
        if fixed_batch:
            self.args.batch = model.metadata.get("batch", 1)  # export.py models default to batch-size 1
            LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")

        # Resolve the dataset description from the configured data argument
        data_suffix = str(self.args.data).rsplit(".", 1)[-1]
        if data_suffix in {"yaml", "yml"}:
            self.data = check_det_dataset(self.args.data)
        elif self.args.task == "classify":
            self.data = check_cls_dataset(self.args.data, split=self.args.split)
        else:
            raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

        if self.device.type in {"cpu", "mps"}:
            self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
        rect_supported = pt or (getattr(model, "dynamic", False) and not model.imx)
        if not rect_supported:
            self.args.rect = False
        self.stride = model.stride  # used in get_dataloader() for padding
        self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)

        model.eval()
        if self.args.compile:
            model = attempt_compile(model, device=self.device)
        warmup_batch = 1 if pt else self.args.batch
        model.warmup(imgsz=(warmup_batch, self.data["channels"], imgsz, imgsz))  # warmup

    self.run_callbacks("on_val_start")
    # One timer per stage, in the same order as the keys of self.speed
    timers = tuple(Profile(device=self.device) for _ in range(4))
    t_preprocess, t_inference, t_loss, t_postprocess = timers
    progress = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
    self.init_metrics(unwrap_model(model))
    self.jdict = []  # empty before each val

    for batch_i, batch in enumerate(progress):
        self.run_callbacks("on_val_batch_start")
        self.batch_i = batch_i

        # Preprocess
        with t_preprocess:
            batch = self.preprocess(batch)

        # Inference
        with t_inference:
            preds = model(batch["img"], augment=augment)

        # Loss
        with t_loss:
            if self.training:
                self.loss += model.loss(batch, preds)[1]

        # Postprocess
        with t_postprocess:
            preds = self.postprocess(preds)

        self.update_metrics(preds, batch)
        if self.args.plots and batch_i < 3:
            self.plot_val_samples(batch, batch_i)
            self.plot_predictions(batch, preds, batch_i)

        self.run_callbacks("on_val_batch_end")

    stats = self.get_stats()
    # Convert accumulated stage times to milliseconds per image
    self.speed = dict(zip(self.speed.keys(), (timer.t / len(self.dataloader.dataset) * 1e3 for timer in timers)))
    self.finalize_metrics()
    self.print_results()
    self.run_callbacks("on_val_end")

    if self.training:
        model.float()
        val_losses = trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")
        results = {**stats, **val_losses}
        return {name: round(float(value), 5) for name, value in results.items()}  # 5 decimal place floats

    LOGGER.info(
        "Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
            *tuple(self.speed.values())
        )
    )
    if self.args.save_json and self.jdict:
        with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
            LOGGER.info(f"Saving {f.name}...")
            json.dump(self.jdict, f)  # flatten and save
        stats = self.eval_json(stats)  # update stats
    if self.args.plots or self.args.save_json:
        LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
    return stats
def match_predictions(
    self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
) -> torch.Tensor:
    """
    Match predictions to ground truth objects using IoU.

    A detection counts as correct at a threshold when it is paired with a ground-truth object of the same class
    whose IoU is at least that threshold. Each label and each detection is used in at most one pair per threshold.

    Args:
        pred_classes (torch.Tensor): Predicted class indices of shape (N,).
        true_classes (torch.Tensor): Target class indices of shape (M,).
        iou (torch.Tensor): Pairwise IoU values of shape (M, N): rows are ground-truth labels, columns are
            predictions.
        use_scipy (bool, optional): Whether to use scipy's linear sum assignment for the matching instead of the
            greedy numpy matching.

    Returns:
        (torch.Tensor): Boolean tensor of shape (N, T) on the device of `pred_classes`, where T is the number
            of IoU thresholds in `self.iouv`.
    """
    # DxT matrix, where D - detections, T - IoU thresholds
    correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0]), dtype=bool)
    # LxD matrix where L - labels (rows), D - detections (columns)
    correct_class = true_classes[:, None] == pred_classes
    iou = iou * correct_class  # zero out the wrong classes
    iou = iou.cpu().numpy()

    if use_scipy:
        # Import the submodule explicitly: a bare `import scipy` does not guarantee `scipy.optimize` is loaded
        from scipy.optimize import linear_sum_assignment  # scope import to avoid importing for all commands

    for i, threshold in enumerate(self.iouv.cpu().tolist()):
        if use_scipy:
            cost_matrix = iou * (iou >= threshold)
            if cost_matrix.any():
                # The matrix holds IoU (a reward, not a cost), so the assignment must be maximized; the default
                # minimization would prefer zero-IoU pairings and discard valid matches.
                labels_idx, detections_idx = linear_sum_assignment(cost_matrix, maximize=True)
                valid = cost_matrix[labels_idx, detections_idx] > 0
                if valid.any():
                    correct[detections_idx[valid], i] = True
        else:
            matches = np.nonzero(iou >= threshold)  # IoU >= threshold and classes match
            matches = np.array(matches).T  # rows of (label index, detection index)
            if matches.shape[0]:
                if matches.shape[0] > 1:
                    # Order candidates by IoU (descending), then keep one row per detection and one per label
                    matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
                    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
                correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
def add_callback(self, event: str, callback):
    """
    Register a callback for an event by appending it to that event's list in `self.callbacks`.

    Args:
        event (str): Name of the event; it is looked up with `self.callbacks[event]`.
        callback (callable): Function to append to the event's callback list.
    """
    registered = self.callbacks[event]
    registered.append(callback)
def run_callbacks(self, event: str):
    """
    Invoke every callback registered for an event, in order, passing this validator as the only argument.

    Args:
        event (str): Name of the event; events with no registered callbacks are silently ignored.
    """
    hooks = self.callbacks.get(event, [])
    for hook in hooks:
        hook(self)
def get_dataloader(self, dataset_path, batch_size):
    """
    Build a dataloader for the given dataset path and batch size.

    The base class provides no implementation. `__call__` uses this method when no dataloader was supplied.

    Args:
        dataset_path: Dataset location; `__call__` passes `self.data.get(self.args.split)`.
        batch_size: Batch size; `__call__` passes `self.args.batch`.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "get_dataloader function not implemented for this validator"
    raise NotImplementedError(message)
def build_dataset(self, img_path):
    """
    Build a dataset from an image path.

    The base class provides no implementation.

    Args:
        img_path: Location of the images to build the dataset from.

    Raises:
        NotImplementedError: Always, in the base class.
    """
    message = "build_dataset function not implemented in validator"
    raise NotImplementedError(message)
def preprocess(self, batch):
    """
    Preprocess an input batch before inference.

    The base implementation is an identity function: the very same object is handed back.

    Args:
        batch: Batch produced by the dataloader.

    Returns:
        The unchanged `batch`.
    """
    processed = batch
    return processed
def postprocess(self, preds):
    """
    Postprocess raw model predictions.

    The base implementation is an identity function: the very same object is handed back.

    Args:
        preds: Predictions returned by the model.

    Returns:
        The unchanged `preds`.
    """
    processed = preds
    return processed
def init_metrics(self, model):
    """
    Initialize performance metrics for the YOLO model.

    No-op in the base class. `__call__` invokes it once before the batch loop with the unwrapped model.

    Args:
        model: Model about to be validated.
    """
    return None
def update_metrics(self, preds, batch):
    """
    Update metrics based on predictions and batch.

    No-op in the base class. `__call__` invokes it once per batch with the postprocessed predictions and the
    preprocessed batch.

    Args:
        preds: Postprocessed predictions for the batch.
        batch: Preprocessed batch the predictions belong to.
    """
    return None
def finalize_metrics(self):
    """
    Finalize metrics after all batches have been processed.

    No-op in the base class (returns None). `__call__` invokes it after the per-stage speeds are computed.
    """
    return None
def get_stats(self):
    """
    Return statistics about the model's performance.

    Returns:
        (dict): A new empty dictionary in the base class.
    """
    stats = {}
    return stats
def print_results(self):
    """
    Print the results of the model's predictions.

    No-op in the base class. `__call__` invokes it after `finalize_metrics`.
    """
    return None
def get_desc(self):
    """
    Get the description of the YOLO model.

    `__call__` passes the result to the progress bar as its `desc`. The base class returns None.
    """
    return None
@property
def metric_keys(self):
    """
    Return the metric keys used in YOLO training/validation.

    Returns:
        (list): A new empty list in the base class.
    """
    keys = []
    return keys
def on_plot(self, name, data=None):
    """
    Register a plot for visualization.

    The plot is stored in `self.plots` under `Path(name)` together with the current wall-clock time.

    Args:
        name: Plot file name or path; converted to a `Path` and used as the dictionary key.
        data (optional): Data associated with the plot.
    """
    record = {"data": data, "timestamp": time.time()}
    self.plots[Path(name)] = record
def plot_val_samples(self, batch, ni):
    """
    Plot validation samples during training.

    No-op in the base class. `__call__` invokes it for the first three batches when plotting is enabled.

    Args:
        batch: Preprocessed batch to plot.
        ni: Index of the batch; `__call__` passes the batch index.
    """
    return None
def plot_predictions(self, batch, preds, ni):
    """
    Plot YOLO model predictions on batch images.

    No-op in the base class. `__call__` invokes it for the first three batches when plotting is enabled.

    Args:
        batch: Preprocessed batch the predictions belong to.
        preds: Postprocessed predictions for the batch.
        ni: Index of the batch; `__call__` passes the batch index.
    """
    return None
def pred_to_json(self, preds, batch):
    """
    Convert predictions to JSON format.

    No-op in the base class (returns None).

    Args:
        preds: Predictions to convert.
        batch: Batch the predictions belong to.
    """
    return None
def eval_json(self, stats):
    """
    Evaluate the saved JSON predictions and return the (possibly updated) statistics.

    `__call__` assigns the return value back to `stats` and returns it to the caller, so the base
    implementation hands `stats` back unchanged instead of returning None, which would discard the
    statistics already collected.

    Args:
        stats (dict): Statistics gathered so far.

    Returns:
        (dict): The same `stats` object, unmodified in the base class.
    """
    return stats