init commit
This commit is contained in:
1
ultralytics/engine/__init__.py
Normal file
1
ultralytics/engine/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
BIN
ultralytics/engine/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/engine/__pycache__/exporter.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/exporter.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/engine/__pycache__/model.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/model.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/engine/__pycache__/predictor.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/predictor.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/engine/__pycache__/results.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/results.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/engine/__pycache__/trainer.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/trainer.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/engine/__pycache__/validator.cpython-310.pyc
Normal file
BIN
ultralytics/engine/__pycache__/validator.cpython-310.pyc
Normal file
Binary file not shown.
1472
ultralytics/engine/exporter.py
Normal file
1472
ultralytics/engine/exporter.py
Normal file
File diff suppressed because it is too large
Load Diff
1164
ultralytics/engine/model.py
Normal file
1164
ultralytics/engine/model.py
Normal file
File diff suppressed because it is too large
Load Diff
517
ultralytics/engine/predictor.py
Normal file
517
ultralytics/engine/predictor.py
Normal file
@@ -0,0 +1,517 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
"""
|
||||
Run prediction on images, videos, directories, globs, YouTube, webcam, streams, etc.
|
||||
|
||||
Usage - sources:
|
||||
$ yolo mode=predict model=yolo11n.pt source=0 # webcam
|
||||
img.jpg # image
|
||||
vid.mp4 # video
|
||||
screen # screenshot
|
||||
path/ # directory
|
||||
list.txt # list of images
|
||||
list.streams # list of streams
|
||||
'path/*.jpg' # glob
|
||||
'https://youtu.be/LNwODJXcvt4' # YouTube
|
||||
'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP, TCP stream
|
||||
|
||||
Usage - formats:
|
||||
$ yolo mode=predict model=yolo11n.pt # PyTorch
|
||||
yolo11n.torchscript # TorchScript
|
||||
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||
yolo11n_openvino_model # OpenVINO
|
||||
yolo11n.engine # TensorRT
|
||||
yolo11n.mlpackage # CoreML (macOS-only)
|
||||
yolo11n_saved_model # TensorFlow SavedModel
|
||||
yolo11n.pb # TensorFlow GraphDef
|
||||
yolo11n.tflite # TensorFlow Lite
|
||||
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolo11n_paddle_model # PaddlePaddle
|
||||
yolo11n.mnn # MNN
|
||||
yolo11n_ncnn_model # NCNN
|
||||
yolo11n_imx_model # Sony IMX
|
||||
yolo11n_rknn_model # Rockchip RKNN
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import re
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.data import load_inference_source
|
||||
from ultralytics.data.augment import LetterBox
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, MACOS, WINDOWS, callbacks, colorstr, ops
|
||||
from ultralytics.utils.checks import check_imgsz, check_imshow
|
||||
from ultralytics.utils.files import increment_path
|
||||
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode
|
||||
|
||||
STREAM_WARNING = """
|
||||
inference results will accumulate in RAM unless `stream=True` is passed, causing potential out-of-memory
|
||||
errors for large sources or long-running streams and videos. See https://docs.ultralytics.com/modes/predict/ for help.
|
||||
|
||||
Example:
|
||||
results = model(source=..., stream=True) # generator of Results objects
|
||||
for r in results:
|
||||
boxes = r.boxes # Boxes object for bbox outputs
|
||||
masks = r.masks # Masks object for segment masks outputs
|
||||
probs = r.probs # Class probabilities for classification outputs
|
||||
"""
|
||||
|
||||
|
||||
class BasePredictor:
|
||||
"""
|
||||
A base class for creating predictors.
|
||||
|
||||
This class provides the foundation for prediction functionality, handling model setup, inference,
|
||||
and result processing across various input sources.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the predictor.
|
||||
save_dir (Path): Directory to save results.
|
||||
done_warmup (bool): Whether the predictor has finished setup.
|
||||
model (torch.nn.Module): Model used for prediction.
|
||||
data (dict): Data configuration.
|
||||
device (torch.device): Device used for prediction.
|
||||
dataset (Dataset): Dataset used for prediction.
|
||||
vid_writer (dict[str, cv2.VideoWriter]): Dictionary of {save_path: video_writer} for saving video output.
|
||||
plotted_img (np.ndarray): Last plotted image.
|
||||
source_type (SimpleNamespace): Type of input source.
|
||||
seen (int): Number of images processed.
|
||||
windows (list[str]): List of window names for visualization.
|
||||
batch (tuple): Current batch data.
|
||||
results (list[Any]): Current batch results.
|
||||
transforms (callable): Image transforms for classification.
|
||||
callbacks (dict[str, list[callable]]): Callback functions for different events.
|
||||
txt_path (Path): Path to save text results.
|
||||
_lock (threading.Lock): Lock for thread-safe inference.
|
||||
|
||||
Methods:
|
||||
preprocess: Prepare input image before inference.
|
||||
inference: Run inference on a given image.
|
||||
postprocess: Process raw predictions into structured results.
|
||||
predict_cli: Run prediction for command line interface.
|
||||
setup_source: Set up input source and inference mode.
|
||||
stream_inference: Stream inference on input source.
|
||||
setup_model: Initialize and configure the model.
|
||||
write_results: Write inference results to files.
|
||||
save_predicted_images: Save prediction visualizations.
|
||||
show: Display results in a window.
|
||||
run_callbacks: Execute registered callbacks for an event.
|
||||
add_callback: Register a new callback function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg=DEFAULT_CFG,
|
||||
overrides: dict[str, Any] | None = None,
|
||||
_callbacks: dict[str, list[callable]] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the BasePredictor class.
|
||||
|
||||
Args:
|
||||
cfg (str | dict): Path to a configuration file or a configuration dictionary.
|
||||
overrides (dict, optional): Configuration overrides.
|
||||
_callbacks (dict, optional): Dictionary of callback functions.
|
||||
"""
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.save_dir = get_save_dir(self.args)
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.25 # default conf=0.25
|
||||
self.done_warmup = False
|
||||
if self.args.show:
|
||||
self.args.show = check_imshow(warn=True)
|
||||
|
||||
# Usable if setup is done
|
||||
self.model = None
|
||||
self.data = self.args.data # data_dict
|
||||
self.imgsz = None
|
||||
self.device = None
|
||||
self.dataset = None
|
||||
self.vid_writer = {} # dict of {save_path: video_writer, ...}
|
||||
self.plotted_img = None
|
||||
self.source_type = None
|
||||
self.seen = 0
|
||||
self.windows = []
|
||||
self.batch = None
|
||||
self.results = None
|
||||
self.transforms = None
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
self.txt_path = None
|
||||
self._lock = threading.Lock() # for automatic thread-safe inference
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
def preprocess(self, im: torch.Tensor | list[np.ndarray]) -> torch.Tensor:
|
||||
"""
|
||||
Prepare input image before inference.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor | list[np.ndarray]): Images of shape (N, 3, H, W) for tensor, [(H, W, 3) x N] for list.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Preprocessed image tensor of shape (N, 3, H, W).
|
||||
"""
|
||||
not_tensor = not isinstance(im, torch.Tensor)
|
||||
if not_tensor:
|
||||
im = np.stack(self.pre_transform(im))
|
||||
if im.shape[-1] == 3:
|
||||
im = im[..., ::-1] # BGR to RGB
|
||||
im = im.transpose((0, 3, 1, 2)) # BHWC to BCHW, (n, 3, h, w)
|
||||
im = np.ascontiguousarray(im) # contiguous
|
||||
im = torch.from_numpy(im)
|
||||
|
||||
im = im.to(self.device)
|
||||
im = im.half() if self.model.fp16 else im.float() # uint8 to fp16/32
|
||||
if not_tensor:
|
||||
im /= 255 # 0 - 255 to 0.0 - 1.0
|
||||
return im
|
||||
|
||||
def inference(self, im: torch.Tensor, *args, **kwargs):
|
||||
"""Run inference on a given image using the specified model and arguments."""
|
||||
visualize = (
|
||||
increment_path(self.save_dir / Path(self.batch[0][0]).stem, mkdir=True)
|
||||
if self.args.visualize and (not self.source_type.tensor)
|
||||
else False
|
||||
)
|
||||
return self.model(im, augment=self.args.augment, visualize=visualize, embed=self.args.embed, *args, **kwargs)
|
||||
|
||||
def pre_transform(self, im: list[np.ndarray]) -> list[np.ndarray]:
|
||||
"""
|
||||
Pre-transform input image before inference.
|
||||
|
||||
Args:
|
||||
im (list[np.ndarray]): List of images with shape [(H, W, 3) x N].
|
||||
|
||||
Returns:
|
||||
(list[np.ndarray]): List of transformed images.
|
||||
"""
|
||||
same_shapes = len({x.shape for x in im}) == 1
|
||||
letterbox = LetterBox(
|
||||
self.imgsz,
|
||||
auto=same_shapes
|
||||
and self.args.rect
|
||||
and (self.model.pt or (getattr(self.model, "dynamic", False) and not self.model.imx)),
|
||||
stride=self.model.stride,
|
||||
)
|
||||
return [letterbox(image=x) for x in im]
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""Post-process predictions for an image and return them."""
|
||||
return preds
|
||||
|
||||
def __call__(self, source=None, model=None, stream: bool = False, *args, **kwargs):
|
||||
"""
|
||||
Perform inference on an image or stream.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
||||
Source for inference.
|
||||
model (str | Path | torch.nn.Module, optional): Model for inference.
|
||||
stream (bool): Whether to stream the inference results. If True, returns a generator.
|
||||
*args (Any): Additional arguments for the inference method.
|
||||
**kwargs (Any): Additional keyword arguments for the inference method.
|
||||
|
||||
Returns:
|
||||
(list[ultralytics.engine.results.Results] | generator): Results objects or generator of Results objects.
|
||||
"""
|
||||
self.stream = stream
|
||||
if stream:
|
||||
return self.stream_inference(source, model, *args, **kwargs)
|
||||
else:
|
||||
return list(self.stream_inference(source, model, *args, **kwargs)) # merge list of Result into one
|
||||
|
||||
def predict_cli(self, source=None, model=None):
|
||||
"""
|
||||
Method used for Command Line Interface (CLI) prediction.
|
||||
|
||||
This function is designed to run predictions using the CLI. It sets up the source and model, then processes
|
||||
the inputs in a streaming manner. This method ensures that no outputs accumulate in memory by consuming the
|
||||
generator without storing results.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
||||
Source for inference.
|
||||
model (str | Path | torch.nn.Module, optional): Model for inference.
|
||||
|
||||
Note:
|
||||
Do not modify this function or remove the generator. The generator ensures that no outputs are
|
||||
accumulated in memory, which is critical for preventing memory issues during long-running predictions.
|
||||
"""
|
||||
gen = self.stream_inference(source, model)
|
||||
for _ in gen: # sourcery skip: remove-empty-nested-block, noqa
|
||||
pass
|
||||
|
||||
def setup_source(self, source):
|
||||
"""
|
||||
Set up source and inference mode.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor):
|
||||
Source for inference.
|
||||
"""
|
||||
self.imgsz = check_imgsz(self.args.imgsz, stride=self.model.stride, min_dim=2) # check image size
|
||||
self.dataset = load_inference_source(
|
||||
source=source,
|
||||
batch=self.args.batch,
|
||||
vid_stride=self.args.vid_stride,
|
||||
buffer=self.args.stream_buffer,
|
||||
channels=getattr(self.model, "ch", 3),
|
||||
)
|
||||
self.source_type = self.dataset.source_type
|
||||
long_sequence = (
|
||||
self.source_type.stream
|
||||
or self.source_type.screenshot
|
||||
or len(self.dataset) > 1000 # many images
|
||||
or any(getattr(self.dataset, "video_flag", [False]))
|
||||
)
|
||||
if long_sequence:
|
||||
import torchvision # noqa (import here triggers torchvision NMS use in nms.py)
|
||||
|
||||
if not getattr(self, "stream", True): # videos
|
||||
LOGGER.warning(STREAM_WARNING)
|
||||
self.vid_writer = {}
|
||||
|
||||
@smart_inference_mode()
|
||||
def stream_inference(self, source=None, model=None, *args, **kwargs):
|
||||
"""
|
||||
Stream real-time inference on camera feed and save results to file.
|
||||
|
||||
Args:
|
||||
source (str | Path | list[str] | list[Path] | list[np.ndarray] | np.ndarray | torch.Tensor, optional):
|
||||
Source for inference.
|
||||
model (str | Path | torch.nn.Module, optional): Model for inference.
|
||||
*args (Any): Additional arguments for the inference method.
|
||||
**kwargs (Any): Additional keyword arguments for the inference method.
|
||||
|
||||
Yields:
|
||||
(ultralytics.engine.results.Results): Results objects.
|
||||
"""
|
||||
if self.args.verbose:
|
||||
LOGGER.info("")
|
||||
|
||||
# Setup model
|
||||
if not self.model:
|
||||
self.setup_model(model)
|
||||
|
||||
with self._lock: # for thread-safe inference
|
||||
# Setup source every time predict is called
|
||||
self.setup_source(source if source is not None else self.args.source)
|
||||
|
||||
# Check if save_dir/ label file exists
|
||||
if self.args.save or self.args.save_txt:
|
||||
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Warmup model
|
||||
if not self.done_warmup:
|
||||
self.model.warmup(
|
||||
imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, self.model.ch, *self.imgsz)
|
||||
)
|
||||
self.done_warmup = True
|
||||
|
||||
self.seen, self.windows, self.batch = 0, [], None
|
||||
profilers = (
|
||||
ops.Profile(device=self.device),
|
||||
ops.Profile(device=self.device),
|
||||
ops.Profile(device=self.device),
|
||||
)
|
||||
self.run_callbacks("on_predict_start")
|
||||
for self.batch in self.dataset:
|
||||
self.run_callbacks("on_predict_batch_start")
|
||||
paths, im0s, s = self.batch
|
||||
|
||||
# Preprocess
|
||||
with profilers[0]:
|
||||
im = self.preprocess(im0s)
|
||||
|
||||
# Inference
|
||||
with profilers[1]:
|
||||
preds = self.inference(im, *args, **kwargs)
|
||||
if self.args.embed:
|
||||
yield from [preds] if isinstance(preds, torch.Tensor) else preds # yield embedding tensors
|
||||
continue
|
||||
|
||||
# Postprocess
|
||||
with profilers[2]:
|
||||
self.results = self.postprocess(preds, im, im0s)
|
||||
self.run_callbacks("on_predict_postprocess_end")
|
||||
|
||||
# Visualize, save, write results
|
||||
n = len(im0s)
|
||||
try:
|
||||
for i in range(n):
|
||||
self.seen += 1
|
||||
self.results[i].speed = {
|
||||
"preprocess": profilers[0].dt * 1e3 / n,
|
||||
"inference": profilers[1].dt * 1e3 / n,
|
||||
"postprocess": profilers[2].dt * 1e3 / n,
|
||||
}
|
||||
if self.args.verbose or self.args.save or self.args.save_txt or self.args.show:
|
||||
s[i] += self.write_results(i, Path(paths[i]), im, s)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# Print batch results
|
||||
if self.args.verbose:
|
||||
LOGGER.info("\n".join(s))
|
||||
|
||||
self.run_callbacks("on_predict_batch_end")
|
||||
yield from self.results
|
||||
|
||||
# Release assets
|
||||
for v in self.vid_writer.values():
|
||||
if isinstance(v, cv2.VideoWriter):
|
||||
v.release()
|
||||
|
||||
if self.args.show:
|
||||
cv2.destroyAllWindows() # close any open windows
|
||||
|
||||
# Print final results
|
||||
if self.args.verbose and self.seen:
|
||||
t = tuple(x.t / self.seen * 1e3 for x in profilers) # speeds per image
|
||||
LOGGER.info(
|
||||
f"Speed: %.1fms preprocess, %.1fms inference, %.1fms postprocess per image at shape "
|
||||
f"{(min(self.args.batch, self.seen), getattr(self.model, 'ch', 3), *im.shape[2:])}" % t
|
||||
)
|
||||
if self.args.save or self.args.save_txt or self.args.save_crop:
|
||||
nl = len(list(self.save_dir.glob("labels/*.txt"))) # number of labels
|
||||
s = f"\n{nl} label{'s' * (nl > 1)} saved to {self.save_dir / 'labels'}" if self.args.save_txt else ""
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}{s}")
|
||||
self.run_callbacks("on_predict_end")
|
||||
|
||||
def setup_model(self, model, verbose: bool = True):
|
||||
"""
|
||||
Initialize YOLO model with given parameters and set it to evaluation mode.
|
||||
|
||||
Args:
|
||||
model (str | Path | torch.nn.Module, optional): Model to load or use.
|
||||
verbose (bool): Whether to print verbose output.
|
||||
"""
|
||||
self.model = AutoBackend(
|
||||
model=model or self.args.model,
|
||||
device=select_device(self.args.device, verbose=verbose),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
fuse=True,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
self.device = self.model.device # update device
|
||||
self.args.half = self.model.fp16 # update half
|
||||
if hasattr(self.model, "imgsz") and not getattr(self.model, "dynamic", False):
|
||||
self.args.imgsz = self.model.imgsz # reuse imgsz from export metadata
|
||||
self.model.eval()
|
||||
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
|
||||
|
||||
def write_results(self, i: int, p: Path, im: torch.Tensor, s: list[str]) -> str:
|
||||
"""
|
||||
Write inference results to a file or directory.
|
||||
|
||||
Args:
|
||||
i (int): Index of the current image in the batch.
|
||||
p (Path): Path to the current image.
|
||||
im (torch.Tensor): Preprocessed image tensor.
|
||||
s (list[str]): List of result strings.
|
||||
|
||||
Returns:
|
||||
(str): String with result information.
|
||||
"""
|
||||
string = "" # print string
|
||||
if len(im.shape) == 3:
|
||||
im = im[None] # expand for batch dim
|
||||
if self.source_type.stream or self.source_type.from_img or self.source_type.tensor: # batch_size >= 1
|
||||
string += f"{i}: "
|
||||
frame = self.dataset.count
|
||||
else:
|
||||
match = re.search(r"frame (\d+)/", s[i])
|
||||
frame = int(match[1]) if match else None # 0 if frame undetermined
|
||||
|
||||
self.txt_path = self.save_dir / "labels" / (p.stem + ("" if self.dataset.mode == "image" else f"_{frame}"))
|
||||
string += "{:g}x{:g} ".format(*im.shape[2:])
|
||||
result = self.results[i]
|
||||
result.save_dir = self.save_dir.__str__() # used in other locations
|
||||
string += f"{result.verbose()}{result.speed['inference']:.1f}ms"
|
||||
|
||||
# Add predictions to image
|
||||
if self.args.save or self.args.show:
|
||||
self.plotted_img = result.plot(
|
||||
line_width=self.args.line_width,
|
||||
boxes=self.args.show_boxes,
|
||||
conf=self.args.show_conf,
|
||||
labels=self.args.show_labels,
|
||||
im_gpu=None if self.args.retina_masks else im[i],
|
||||
)
|
||||
|
||||
# Save results
|
||||
if self.args.save_txt:
|
||||
result.save_txt(f"{self.txt_path}.txt", save_conf=self.args.save_conf)
|
||||
if self.args.save_crop:
|
||||
result.save_crop(save_dir=self.save_dir / "crops", file_name=self.txt_path.stem)
|
||||
if self.args.show:
|
||||
self.show(str(p))
|
||||
if self.args.save:
|
||||
self.save_predicted_images(self.save_dir / p.name, frame)
|
||||
|
||||
return string
|
||||
|
||||
def save_predicted_images(self, save_path: Path, frame: int = 0):
|
||||
"""
|
||||
Save video predictions as mp4 or images as jpg at specified path.
|
||||
|
||||
Args:
|
||||
save_path (Path): Path to save the results.
|
||||
frame (int): Frame number for video mode.
|
||||
"""
|
||||
im = self.plotted_img
|
||||
|
||||
# Save videos and streams
|
||||
if self.dataset.mode in {"stream", "video"}:
|
||||
fps = self.dataset.fps if self.dataset.mode == "video" else 30
|
||||
frames_path = self.save_dir / f"{save_path.stem}_frames" # save frames to a separate directory
|
||||
if save_path not in self.vid_writer: # new video
|
||||
if self.args.save_frames:
|
||||
Path(frames_path).mkdir(parents=True, exist_ok=True)
|
||||
suffix, fourcc = (".mp4", "avc1") if MACOS else (".avi", "WMV2") if WINDOWS else (".avi", "MJPG")
|
||||
self.vid_writer[save_path] = cv2.VideoWriter(
|
||||
filename=str(Path(save_path).with_suffix(suffix)),
|
||||
fourcc=cv2.VideoWriter_fourcc(*fourcc),
|
||||
fps=fps, # integer required, floats produce error in MP4 codec
|
||||
frameSize=(im.shape[1], im.shape[0]), # (width, height)
|
||||
)
|
||||
|
||||
# Save video
|
||||
self.vid_writer[save_path].write(im)
|
||||
if self.args.save_frames:
|
||||
cv2.imwrite(f"{frames_path}/{save_path.stem}_{frame}.jpg", im)
|
||||
|
||||
# Save images
|
||||
else:
|
||||
cv2.imwrite(str(save_path.with_suffix(".jpg")), im) # save to JPG for best support
|
||||
|
||||
def show(self, p: str = ""):
|
||||
"""Display an image in a window."""
|
||||
im = self.plotted_img
|
||||
if platform.system() == "Linux" and p not in self.windows:
|
||||
self.windows.append(p)
|
||||
cv2.namedWindow(p, cv2.WINDOW_NORMAL | cv2.WINDOW_KEEPRATIO) # allow window resize (Linux)
|
||||
cv2.resizeWindow(p, im.shape[1], im.shape[0]) # (width, height)
|
||||
cv2.imshow(p, im)
|
||||
if cv2.waitKey(300 if self.dataset.mode == "image" else 1) & 0xFF == ord("q"): # 300ms if image; else 1ms
|
||||
raise StopIteration
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Run all registered callbacks for a specific event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def add_callback(self, event: str, func: callable):
|
||||
"""Add a callback function for a specific event."""
|
||||
self.callbacks[event].append(func)
|
||||
1656
ultralytics/engine/results.py
Normal file
1656
ultralytics/engine/results.py
Normal file
File diff suppressed because it is too large
Load Diff
904
ultralytics/engine/trainer.py
Normal file
904
ultralytics/engine/trainer.py
Normal file
@@ -0,0 +1,904 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
"""
|
||||
Train a model on a dataset.
|
||||
|
||||
Usage:
|
||||
$ yolo mode=train model=yolo11n.pt data=coco8.yaml imgsz=640 epochs=100 batch=16
|
||||
"""
|
||||
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
import warnings
|
||||
from copy import copy, deepcopy
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch import nn, optim
|
||||
|
||||
from ultralytics import __version__
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.tasks import load_checkpoint
|
||||
from ultralytics.utils import (
|
||||
DEFAULT_CFG,
|
||||
GIT,
|
||||
LOCAL_RANK,
|
||||
LOGGER,
|
||||
RANK,
|
||||
TQDM,
|
||||
YAML,
|
||||
callbacks,
|
||||
clean_url,
|
||||
colorstr,
|
||||
emojis,
|
||||
)
|
||||
from ultralytics.utils.autobatch import check_train_batch_size
|
||||
from ultralytics.utils.checks import check_amp, check_file, check_imgsz, check_model_file_from_stem, print_args
|
||||
from ultralytics.utils.dist import ddp_cleanup, generate_ddp_command
|
||||
from ultralytics.utils.files import get_latest_run
|
||||
from ultralytics.utils.plotting import plot_results
|
||||
from ultralytics.utils.torch_utils import (
|
||||
TORCH_2_4,
|
||||
EarlyStopping,
|
||||
ModelEMA,
|
||||
attempt_compile,
|
||||
autocast,
|
||||
convert_optimizer_state_dict_to_fp16,
|
||||
init_seeds,
|
||||
one_cycle,
|
||||
select_device,
|
||||
strip_optimizer,
|
||||
torch_distributed_zero_first,
|
||||
unset_deterministic,
|
||||
unwrap_model,
|
||||
)
|
||||
|
||||
|
||||
class BaseTrainer:
|
||||
"""
|
||||
A base class for creating trainers.
|
||||
|
||||
This class provides the foundation for training YOLO models, handling the training loop, validation, checkpointing,
|
||||
and various training utilities. It supports both single-GPU and multi-GPU distributed training.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the trainer.
|
||||
validator (BaseValidator): Validator instance.
|
||||
model (nn.Module): Model instance.
|
||||
callbacks (defaultdict): Dictionary of callbacks.
|
||||
save_dir (Path): Directory to save results.
|
||||
wdir (Path): Directory to save weights.
|
||||
last (Path): Path to the last checkpoint.
|
||||
best (Path): Path to the best checkpoint.
|
||||
save_period (int): Save checkpoint every x epochs (disabled if < 1).
|
||||
batch_size (int): Batch size for training.
|
||||
epochs (int): Number of epochs to train for.
|
||||
start_epoch (int): Starting epoch for training.
|
||||
device (torch.device): Device to use for training.
|
||||
amp (bool): Flag to enable AMP (Automatic Mixed Precision).
|
||||
scaler (amp.GradScaler): Gradient scaler for AMP.
|
||||
data (str): Path to data.
|
||||
ema (nn.Module): EMA (Exponential Moving Average) of the model.
|
||||
resume (bool): Resume training from a checkpoint.
|
||||
lf (nn.Module): Loss function.
|
||||
scheduler (torch.optim.lr_scheduler._LRScheduler): Learning rate scheduler.
|
||||
best_fitness (float): The best fitness value achieved.
|
||||
fitness (float): Current fitness value.
|
||||
loss (float): Current loss value.
|
||||
tloss (float): Total loss value.
|
||||
loss_names (list): List of loss names.
|
||||
csv (Path): Path to results CSV file.
|
||||
metrics (dict): Dictionary of metrics.
|
||||
plots (dict): Dictionary of plots.
|
||||
|
||||
Methods:
|
||||
train: Execute the training process.
|
||||
validate: Run validation on the test set.
|
||||
save_model: Save model training checkpoints.
|
||||
get_dataset: Get train and validation datasets.
|
||||
setup_model: Load, create, or download model.
|
||||
build_optimizer: Construct an optimizer for the model.
|
||||
|
||||
Examples:
|
||||
Initialize a trainer and start training
|
||||
>>> trainer = BaseTrainer(cfg="config.yaml")
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the BaseTrainer class.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to a configuration file.
|
||||
overrides (dict, optional): Configuration overrides.
|
||||
_callbacks (list, optional): List of callback functions.
|
||||
"""
|
||||
self.hub_session = overrides.pop("session", None) # HUB
|
||||
self.args = get_cfg(cfg, overrides)
|
||||
self.check_resume(overrides)
|
||||
self.device = select_device(self.args.device)
|
||||
# Update "-1" devices so post-training val does not repeat search
|
||||
self.args.device = os.getenv("CUDA_VISIBLE_DEVICES") if "cuda" in str(self.device) else str(self.device)
|
||||
self.validator = None
|
||||
self.metrics = None
|
||||
self.plots = {}
|
||||
init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic)
|
||||
|
||||
# Dirs
|
||||
self.save_dir = get_save_dir(self.args)
|
||||
self.args.name = self.save_dir.name # update name for loggers
|
||||
self.wdir = self.save_dir / "weights" # weights dir
|
||||
if RANK in {-1, 0}:
|
||||
self.wdir.mkdir(parents=True, exist_ok=True) # make dir
|
||||
self.args.save_dir = str(self.save_dir)
|
||||
YAML.save(self.save_dir / "args.yaml", vars(self.args)) # save run args
|
||||
self.last, self.best = self.wdir / "last.pt", self.wdir / "best.pt" # checkpoint paths
|
||||
self.save_period = self.args.save_period
|
||||
|
||||
self.batch_size = self.args.batch
|
||||
self.epochs = self.args.epochs or 100 # in case users accidentally pass epochs=None with timed training
|
||||
self.start_epoch = 0
|
||||
if RANK == -1:
|
||||
print_args(vars(self.args))
|
||||
|
||||
# Device
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU training as time dominated by inference, not dataloading
|
||||
|
||||
# Model and Dataset
|
||||
self.model = check_model_file_from_stem(self.args.model) # add suffix, i.e. yolo11n -> yolo11n.pt
|
||||
with torch_distributed_zero_first(LOCAL_RANK): # avoid auto-downloading dataset multiple times
|
||||
self.data = self.get_dataset()
|
||||
|
||||
self.ema = None
|
||||
|
||||
# Optimization utils init
|
||||
self.lf = None
|
||||
self.scheduler = None
|
||||
|
||||
# Epoch level metrics
|
||||
self.best_fitness = None
|
||||
self.fitness = None
|
||||
self.loss = None
|
||||
self.tloss = None
|
||||
self.loss_names = ["Loss"]
|
||||
self.csv = self.save_dir / "results.csv"
|
||||
self.plot_idx = [0, 1, 2]
|
||||
|
||||
# Callbacks
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
|
||||
if isinstance(self.args.device, str) and len(self.args.device): # i.e. device='0' or device='0,1,2,3'
|
||||
world_size = len(self.args.device.split(","))
|
||||
elif isinstance(self.args.device, (tuple, list)): # i.e. device=[0, 1, 2, 3] (multi-GPU from CLI is list)
|
||||
world_size = len(self.args.device)
|
||||
elif self.args.device in {"cpu", "mps"}: # i.e. device='cpu' or 'mps'
|
||||
world_size = 0
|
||||
elif torch.cuda.is_available(): # i.e. device=None or device='' or device=number
|
||||
world_size = 1 # default to device 0
|
||||
else: # i.e. device=None or device=''
|
||||
world_size = 0
|
||||
|
||||
self.ddp = world_size > 1 and "LOCAL_RANK" not in os.environ
|
||||
self.world_size = world_size
|
||||
# Run subprocess if DDP training, else train normally
|
||||
if RANK in {-1, 0} and not self.ddp:
|
||||
callbacks.add_integration_callbacks(self)
|
||||
# Start console logging immediately at trainer initialization
|
||||
self.run_callbacks("on_pretrain_routine_start")
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""Append the given callback to the event's callback list."""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def set_callback(self, event: str, callback):
|
||||
"""Override the existing callbacks with the given callback for the specified event."""
|
||||
self.callbacks[event] = [callback]
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Run all existing callbacks associated with a particular event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def train(self):
|
||||
"""Allow device='', device=None on Multi-GPU systems to default to device=0."""
|
||||
# Run subprocess if DDP training, else train normally
|
||||
if self.ddp:
|
||||
# Argument checks
|
||||
if self.args.rect:
|
||||
LOGGER.warning("'rect=True' is incompatible with Multi-GPU training, setting 'rect=False'")
|
||||
self.args.rect = False
|
||||
if self.args.batch < 1.0:
|
||||
raise ValueError(
|
||||
"AutoBatch with batch<1 not supported for Multi-GPU training, "
|
||||
f"please specify a valid batch size multiple of GPU count {self.world_size}, i.e. batch={self.world_size * 8}."
|
||||
)
|
||||
|
||||
# Command
|
||||
cmd, file = generate_ddp_command(self)
|
||||
try:
|
||||
LOGGER.info(f"{colorstr('DDP:')} debug command {' '.join(cmd)}")
|
||||
subprocess.run(cmd, check=True)
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
ddp_cleanup(self, str(file))
|
||||
|
||||
else:
|
||||
self._do_train()
|
||||
|
||||
def _setup_scheduler(self):
|
||||
"""Initialize training learning rate scheduler."""
|
||||
if self.args.cos_lr:
|
||||
self.lf = one_cycle(1, self.args.lrf, self.epochs) # cosine 1->hyp['lrf']
|
||||
else:
|
||||
self.lf = lambda x: max(1 - x / self.epochs, 0) * (1.0 - self.args.lrf) + self.args.lrf # linear
|
||||
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=self.lf)
|
||||
|
||||
def _setup_ddp(self):
|
||||
"""Initialize and set the DistributedDataParallel parameters for training."""
|
||||
torch.cuda.set_device(RANK)
|
||||
self.device = torch.device("cuda", RANK)
|
||||
os.environ["TORCH_NCCL_BLOCKING_WAIT"] = "1" # set to enforce timeout
|
||||
dist.init_process_group(
|
||||
backend="nccl" if dist.is_nccl_available() else "gloo",
|
||||
timeout=timedelta(seconds=10800), # 3 hours
|
||||
rank=RANK,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
|
||||
def _setup_train(self):
|
||||
"""Build dataloaders and optimizer on correct rank process."""
|
||||
ckpt = self.setup_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.set_model_attributes()
|
||||
|
||||
# Compile model
|
||||
self.model = attempt_compile(self.model, device=self.device, mode=self.args.compile)
|
||||
|
||||
# Freeze layers
|
||||
freeze_list = (
|
||||
self.args.freeze
|
||||
if isinstance(self.args.freeze, list)
|
||||
else range(self.args.freeze)
|
||||
if isinstance(self.args.freeze, int)
|
||||
else []
|
||||
)
|
||||
always_freeze_names = [".dfl"] # always freeze these layers
|
||||
freeze_layer_names = [f"model.{x}." for x in freeze_list] + always_freeze_names
|
||||
self.freeze_layer_names = freeze_layer_names
|
||||
for k, v in self.model.named_parameters():
|
||||
# v.register_hook(lambda x: torch.nan_to_num(x)) # NaN to 0 (commented for erratic training results)
|
||||
if any(x in k for x in freeze_layer_names):
|
||||
LOGGER.info(f"Freezing layer '{k}'")
|
||||
v.requires_grad = False
|
||||
elif not v.requires_grad and v.dtype.is_floating_point: # only floating point Tensor can require gradients
|
||||
LOGGER.warning(
|
||||
f"setting 'requires_grad=True' for frozen layer '{k}'. "
|
||||
"See ultralytics.engine.trainer for customization of frozen layers."
|
||||
)
|
||||
v.requires_grad = True
|
||||
|
||||
# Check AMP
|
||||
self.amp = torch.tensor(self.args.amp).to(self.device) # True or False
|
||||
if self.amp and RANK in {-1, 0}: # Single-GPU and DDP
|
||||
callbacks_backup = callbacks.default_callbacks.copy() # backup callbacks as check_amp() resets them
|
||||
self.amp = torch.tensor(check_amp(self.model), device=self.device)
|
||||
callbacks.default_callbacks = callbacks_backup # restore callbacks
|
||||
if RANK > -1 and self.world_size > 1: # DDP
|
||||
dist.broadcast(self.amp.int(), src=0) # broadcast from rank 0 to all other ranks; gloo errors with boolean
|
||||
self.amp = bool(self.amp) # as boolean
|
||||
self.scaler = (
|
||||
torch.amp.GradScaler("cuda", enabled=self.amp) if TORCH_2_4 else torch.cuda.amp.GradScaler(enabled=self.amp)
|
||||
)
|
||||
if self.world_size > 1:
|
||||
self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[RANK], find_unused_parameters=True)
|
||||
|
||||
# Check imgsz
|
||||
gs = max(int(self.model.stride.max() if hasattr(self.model, "stride") else 32), 32) # grid size (max stride)
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, stride=gs, floor=gs, max_dim=1)
|
||||
self.stride = gs # for multiscale training
|
||||
|
||||
# Batch size
|
||||
if self.batch_size < 1 and RANK == -1: # single-GPU only, estimate best batch size
|
||||
self.args.batch = self.batch_size = self.auto_batch()
|
||||
|
||||
# Dataloaders
|
||||
batch_size = self.batch_size // max(self.world_size, 1)
|
||||
self.train_loader = self.get_dataloader(
|
||||
self.data["train"], batch_size=batch_size, rank=LOCAL_RANK, mode="train"
|
||||
)
|
||||
if RANK in {-1, 0}:
|
||||
# Note: When training DOTA dataset, double batch size could get OOM on images with >2000 objects.
|
||||
self.test_loader = self.get_dataloader(
|
||||
self.data.get("val") or self.data.get("test"),
|
||||
batch_size=batch_size if self.args.task == "obb" else batch_size * 2,
|
||||
rank=-1,
|
||||
mode="val",
|
||||
)
|
||||
self.validator = self.get_validator()
|
||||
metric_keys = self.validator.metrics.keys + self.label_loss_items(prefix="val")
|
||||
self.metrics = dict(zip(metric_keys, [0] * len(metric_keys)))
|
||||
self.ema = ModelEMA(self.model)
|
||||
if self.args.plots:
|
||||
self.plot_training_labels()
|
||||
|
||||
# Optimizer
|
||||
self.accumulate = max(round(self.args.nbs / self.batch_size), 1) # accumulate loss before optimizing
|
||||
weight_decay = self.args.weight_decay * self.batch_size * self.accumulate / self.args.nbs # scale weight_decay
|
||||
iterations = math.ceil(len(self.train_loader.dataset) / max(self.batch_size, self.args.nbs)) * self.epochs
|
||||
self.optimizer = self.build_optimizer(
|
||||
model=self.model,
|
||||
name=self.args.optimizer,
|
||||
lr=self.args.lr0,
|
||||
momentum=self.args.momentum,
|
||||
decay=weight_decay,
|
||||
iterations=iterations,
|
||||
)
|
||||
# Scheduler
|
||||
self._setup_scheduler()
|
||||
self.stopper, self.stop = EarlyStopping(patience=self.args.patience), False
|
||||
self.resume_training(ckpt)
|
||||
self.scheduler.last_epoch = self.start_epoch - 1 # do not move
|
||||
self.run_callbacks("on_pretrain_routine_end")
|
||||
|
||||
def _do_train(self):
|
||||
"""Train the model with the specified world size."""
|
||||
if self.world_size > 1:
|
||||
self._setup_ddp()
|
||||
self._setup_train()
|
||||
|
||||
nb = len(self.train_loader) # number of batches
|
||||
nw = max(round(self.args.warmup_epochs * nb), 100) if self.args.warmup_epochs > 0 else -1 # warmup iterations
|
||||
last_opt_step = -1
|
||||
self.epoch_time = None
|
||||
self.epoch_time_start = time.time()
|
||||
self.train_time_start = time.time()
|
||||
self.run_callbacks("on_train_start")
|
||||
LOGGER.info(
|
||||
f"Image sizes {self.args.imgsz} train, {self.args.imgsz} val\n"
|
||||
f"Using {self.train_loader.num_workers * (self.world_size or 1)} dataloader workers\n"
|
||||
f"Logging results to {colorstr('bold', self.save_dir)}\n"
|
||||
f"Starting training for " + (f"{self.args.time} hours..." if self.args.time else f"{self.epochs} epochs...")
|
||||
)
|
||||
if self.args.close_mosaic:
|
||||
base_idx = (self.epochs - self.args.close_mosaic) * nb
|
||||
self.plot_idx.extend([base_idx, base_idx + 1, base_idx + 2])
|
||||
epoch = self.start_epoch
|
||||
self.optimizer.zero_grad() # zero any resumed gradients to ensure stability on train start
|
||||
while True:
|
||||
self.epoch = epoch
|
||||
self.run_callbacks("on_train_epoch_start")
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore") # suppress 'Detected lr_scheduler.step() before optimizer.step()'
|
||||
self.scheduler.step()
|
||||
|
||||
self._model_train()
|
||||
if RANK != -1:
|
||||
self.train_loader.sampler.set_epoch(epoch)
|
||||
pbar = enumerate(self.train_loader)
|
||||
# Update dataloader attributes (optional)
|
||||
if epoch == (self.epochs - self.args.close_mosaic):
|
||||
self._close_dataloader_mosaic()
|
||||
self.train_loader.reset()
|
||||
|
||||
if RANK in {-1, 0}:
|
||||
LOGGER.info(self.progress_string())
|
||||
pbar = TQDM(enumerate(self.train_loader), total=nb)
|
||||
self.tloss = None
|
||||
for i, batch in pbar:
|
||||
self.run_callbacks("on_train_batch_start")
|
||||
# Warmup
|
||||
ni = i + nb * epoch
|
||||
if ni <= nw:
|
||||
xi = [0, nw] # x interp
|
||||
self.accumulate = max(1, int(np.interp(ni, xi, [1, self.args.nbs / self.batch_size]).round()))
|
||||
for j, x in enumerate(self.optimizer.param_groups):
|
||||
# Bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
|
||||
x["lr"] = np.interp(
|
||||
ni, xi, [self.args.warmup_bias_lr if j == 0 else 0.0, x["initial_lr"] * self.lf(epoch)]
|
||||
)
|
||||
if "momentum" in x:
|
||||
x["momentum"] = np.interp(ni, xi, [self.args.warmup_momentum, self.args.momentum])
|
||||
|
||||
# Forward
|
||||
with autocast(self.amp):
|
||||
batch = self.preprocess_batch(batch)
|
||||
if self.args.compile:
|
||||
# Decouple inference and loss calculations for improved compile performance
|
||||
preds = self.model(batch["img"])
|
||||
loss, self.loss_items = unwrap_model(self.model).loss(batch, preds)
|
||||
else:
|
||||
loss, self.loss_items = self.model(batch)
|
||||
self.loss = loss.sum()
|
||||
if RANK != -1:
|
||||
self.loss *= self.world_size
|
||||
self.tloss = (
|
||||
(self.tloss * i + self.loss_items) / (i + 1) if self.tloss is not None else self.loss_items
|
||||
)
|
||||
|
||||
# Backward
|
||||
self.scaler.scale(self.loss).backward()
|
||||
|
||||
# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
|
||||
if ni - last_opt_step >= self.accumulate:
|
||||
self.optimizer_step()
|
||||
last_opt_step = ni
|
||||
|
||||
# Timed stopping
|
||||
if self.args.time:
|
||||
self.stop = (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop: # training time exceeded
|
||||
break
|
||||
|
||||
# Log
|
||||
if RANK in {-1, 0}:
|
||||
loss_length = self.tloss.shape[0] if len(self.tloss.shape) else 1
|
||||
pbar.set_description(
|
||||
("%11s" * 2 + "%11.4g" * (2 + loss_length))
|
||||
% (
|
||||
f"{epoch + 1}/{self.epochs}",
|
||||
f"{self._get_memory():.3g}G", # (GB) GPU memory util
|
||||
*(self.tloss if loss_length > 1 else torch.unsqueeze(self.tloss, 0)), # losses
|
||||
batch["cls"].shape[0], # batch size, i.e. 8
|
||||
batch["img"].shape[-1], # imgsz, i.e 640
|
||||
)
|
||||
)
|
||||
self.run_callbacks("on_batch_end")
|
||||
if self.args.plots and ni in self.plot_idx:
|
||||
self.plot_training_samples(batch, ni)
|
||||
|
||||
self.run_callbacks("on_train_batch_end")
|
||||
|
||||
self.lr = {f"lr/pg{ir}": x["lr"] for ir, x in enumerate(self.optimizer.param_groups)} # for loggers
|
||||
self.run_callbacks("on_train_epoch_end")
|
||||
if RANK in {-1, 0}:
|
||||
final_epoch = epoch + 1 >= self.epochs
|
||||
self.ema.update_attr(self.model, include=["yaml", "nc", "args", "names", "stride", "class_weights"])
|
||||
|
||||
# Validation
|
||||
if self.args.val or final_epoch or self.stopper.possible_stop or self.stop:
|
||||
self._clear_memory(threshold=0.5) # prevent VRAM spike
|
||||
self.metrics, self.fitness = self.validate()
|
||||
self.save_metrics(metrics={**self.label_loss_items(self.tloss), **self.metrics, **self.lr})
|
||||
self.stop |= self.stopper(epoch + 1, self.fitness) or final_epoch
|
||||
if self.args.time:
|
||||
self.stop |= (time.time() - self.train_time_start) > (self.args.time * 3600)
|
||||
|
||||
# Save model
|
||||
if self.args.save or final_epoch:
|
||||
self.save_model()
|
||||
self.run_callbacks("on_model_save")
|
||||
|
||||
# Scheduler
|
||||
t = time.time()
|
||||
self.epoch_time = t - self.epoch_time_start
|
||||
self.epoch_time_start = t
|
||||
if self.args.time:
|
||||
mean_epoch_time = (t - self.train_time_start) / (epoch - self.start_epoch + 1)
|
||||
self.epochs = self.args.epochs = math.ceil(self.args.time * 3600 / mean_epoch_time)
|
||||
self._setup_scheduler()
|
||||
self.scheduler.last_epoch = self.epoch # do not move
|
||||
self.stop |= epoch >= self.epochs # stop if exceeded epochs
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
self._clear_memory(0.5) # clear if memory utilization > 50%
|
||||
|
||||
# Early Stopping
|
||||
if RANK != -1: # if DDP training
|
||||
broadcast_list = [self.stop if RANK == 0 else None]
|
||||
dist.broadcast_object_list(broadcast_list, 0) # broadcast 'stop' to all ranks
|
||||
self.stop = broadcast_list[0]
|
||||
if self.stop:
|
||||
break # must break all DDP ranks
|
||||
epoch += 1
|
||||
|
||||
if RANK in {-1, 0}:
|
||||
# Do final val with best.pt
|
||||
seconds = time.time() - self.train_time_start
|
||||
LOGGER.info(f"\n{epoch - self.start_epoch + 1} epochs completed in {seconds / 3600:.3f} hours.")
|
||||
self.final_eval()
|
||||
if self.args.plots:
|
||||
self.plot_metrics()
|
||||
self.run_callbacks("on_train_end")
|
||||
self._clear_memory()
|
||||
unset_deterministic()
|
||||
self.run_callbacks("teardown")
|
||||
|
||||
def auto_batch(self, max_num_obj=0):
|
||||
"""Calculate optimal batch size based on model and device memory constraints."""
|
||||
return check_train_batch_size(
|
||||
model=self.model,
|
||||
imgsz=self.args.imgsz,
|
||||
amp=self.amp,
|
||||
batch=self.batch_size,
|
||||
max_num_obj=max_num_obj,
|
||||
) # returns batch size
|
||||
|
||||
def _get_memory(self, fraction=False):
|
||||
"""Get accelerator memory utilization in GB or as a fraction of total memory."""
|
||||
memory, total = 0, 0
|
||||
if self.device.type == "mps":
|
||||
memory = torch.mps.driver_allocated_memory()
|
||||
if fraction:
|
||||
return __import__("psutil").virtual_memory().percent / 100
|
||||
elif self.device.type != "cpu":
|
||||
memory = torch.cuda.memory_reserved()
|
||||
if fraction:
|
||||
total = torch.cuda.get_device_properties(self.device).total_memory
|
||||
return ((memory / total) if total > 0 else 0) if fraction else (memory / 2**30)
|
||||
|
||||
def _clear_memory(self, threshold: float = None):
|
||||
"""Clear accelerator memory by calling garbage collector and emptying cache."""
|
||||
if threshold:
|
||||
assert 0 <= threshold <= 1, "Threshold must be between 0 and 1."
|
||||
if self._get_memory(fraction=True) <= threshold:
|
||||
return
|
||||
gc.collect()
|
||||
if self.device.type == "mps":
|
||||
torch.mps.empty_cache()
|
||||
elif self.device.type == "cpu":
|
||||
return
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def read_results_csv(self):
|
||||
"""Read results.csv into a dictionary using polars."""
|
||||
import polars as pl # scope for faster 'import ultralytics'
|
||||
|
||||
return pl.read_csv(self.csv, infer_schema_length=None).to_dict(as_series=False)
|
||||
|
||||
def _model_train(self):
|
||||
"""Set model in training mode."""
|
||||
self.model.train()
|
||||
# Freeze BN stat
|
||||
for n, m in self.model.named_modules():
|
||||
if any(filter(lambda f: f in n, self.freeze_layer_names)) and isinstance(m, nn.BatchNorm2d):
|
||||
m.eval()
|
||||
|
||||
def save_model(self):
|
||||
"""Save model training checkpoints with additional metadata."""
|
||||
import io
|
||||
|
||||
# Serialize ckpt to a byte buffer once (faster than repeated torch.save() calls)
|
||||
buffer = io.BytesIO()
|
||||
torch.save(
|
||||
{
|
||||
"epoch": self.epoch,
|
||||
"best_fitness": self.best_fitness,
|
||||
"model": None, # resume and final checkpoints derive from EMA
|
||||
"ema": deepcopy(unwrap_model(self.ema.ema)).half(),
|
||||
"updates": self.ema.updates,
|
||||
"optimizer": convert_optimizer_state_dict_to_fp16(deepcopy(self.optimizer.state_dict())),
|
||||
"scaler": self.scaler.state_dict(),
|
||||
"train_args": vars(self.args), # save as dict
|
||||
"train_metrics": {**self.metrics, **{"fitness": self.fitness}},
|
||||
"train_results": self.read_results_csv(),
|
||||
"date": datetime.now().isoformat(),
|
||||
"version": __version__,
|
||||
"git": {
|
||||
"root": str(GIT.root),
|
||||
"branch": GIT.branch,
|
||||
"commit": GIT.commit,
|
||||
"origin": GIT.origin,
|
||||
},
|
||||
"license": "AGPL-3.0 (https://ultralytics.com/license)",
|
||||
"docs": "https://docs.ultralytics.com",
|
||||
},
|
||||
buffer,
|
||||
)
|
||||
serialized_ckpt = buffer.getvalue() # get the serialized content to save
|
||||
|
||||
# Save checkpoints
|
||||
self.last.write_bytes(serialized_ckpt) # save last.pt
|
||||
if self.best_fitness == self.fitness:
|
||||
self.best.write_bytes(serialized_ckpt) # save best.pt
|
||||
if (self.save_period > 0) and (self.epoch % self.save_period == 0):
|
||||
(self.wdir / f"epoch{self.epoch}.pt").write_bytes(serialized_ckpt) # save epoch, i.e. 'epoch3.pt'
|
||||
|
||||
def get_dataset(self):
|
||||
"""
|
||||
Get train and validation datasets from data dictionary.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the training/validation/test dataset and category names.
|
||||
"""
|
||||
try:
|
||||
if self.args.task == "classify":
|
||||
data = check_cls_dataset(self.args.data)
|
||||
elif self.args.data.rsplit(".", 1)[-1] == "ndjson":
|
||||
# Convert NDJSON to YOLO format
|
||||
import asyncio
|
||||
|
||||
from ultralytics.data.converter import convert_ndjson_to_yolo
|
||||
|
||||
yaml_path = asyncio.run(convert_ndjson_to_yolo(self.args.data))
|
||||
self.args.data = str(yaml_path)
|
||||
data = check_det_dataset(self.args.data)
|
||||
elif self.args.data.rsplit(".", 1)[-1] in {"yaml", "yml"} or self.args.task in {
|
||||
"detect",
|
||||
"segment",
|
||||
"pose",
|
||||
"obb",
|
||||
}:
|
||||
data = check_det_dataset(self.args.data)
|
||||
if "yaml_file" in data:
|
||||
self.args.data = data["yaml_file"] # for validating 'yolo train data=url.zip' usage
|
||||
except Exception as e:
|
||||
raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from e
|
||||
if self.args.single_cls:
|
||||
LOGGER.info("Overriding class names with single class.")
|
||||
data["names"] = {0: "item"}
|
||||
data["nc"] = 1
|
||||
return data
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
Load, create, or download model for any task.
|
||||
|
||||
Returns:
|
||||
(dict): Optional checkpoint to resume training from.
|
||||
"""
|
||||
if isinstance(self.model, torch.nn.Module): # if model is loaded beforehand. No setup needed
|
||||
return
|
||||
|
||||
cfg, weights = self.model, None
|
||||
ckpt = None
|
||||
if str(self.model).endswith(".pt"):
|
||||
weights, ckpt = load_checkpoint(self.model)
|
||||
cfg = weights.yaml
|
||||
elif isinstance(self.args.pretrained, (str, Path)):
|
||||
weights, _ = load_checkpoint(self.args.pretrained)
|
||||
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1) # calls Model(cfg, weights)
|
||||
return ckpt
|
||||
|
||||
def optimizer_step(self):
|
||||
"""Perform a single step of the training optimizer with gradient clipping and EMA update."""
|
||||
self.scaler.unscale_(self.optimizer) # unscale gradients
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0) # clip gradients
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
self.optimizer.zero_grad()
|
||||
if self.ema:
|
||||
self.ema.update(self.model)
|
||||
|
||||
def preprocess_batch(self, batch):
|
||||
"""Allow custom preprocessing model inputs and ground truths depending on task type."""
|
||||
return batch
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
Run validation on val set using self.validator.
|
||||
|
||||
Returns:
|
||||
metrics (dict): Dictionary of validation metrics.
|
||||
fitness (float): Fitness score for the validation.
|
||||
"""
|
||||
metrics = self.validator(self)
|
||||
fitness = metrics.pop("fitness", -self.loss.detach().cpu().numpy()) # use loss as fitness measure if not found
|
||||
if not self.best_fitness or self.best_fitness < fitness:
|
||||
self.best_fitness = fitness
|
||||
return metrics, fitness
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose=True):
|
||||
"""Get model and raise NotImplementedError for loading cfg files."""
|
||||
raise NotImplementedError("This task trainer doesn't support loading cfg files")
|
||||
|
||||
def get_validator(self):
|
||||
"""Return a NotImplementedError when the get_validator function is called."""
|
||||
raise NotImplementedError("get_validator function not implemented in trainer")
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size=16, rank=0, mode="train"):
|
||||
"""Return dataloader derived from torch.data.Dataloader."""
|
||||
raise NotImplementedError("get_dataloader function not implemented in trainer")
|
||||
|
||||
def build_dataset(self, img_path, mode="train", batch=None):
|
||||
"""Build dataset."""
|
||||
raise NotImplementedError("build_dataset function not implemented in trainer")
|
||||
|
||||
def label_loss_items(self, loss_items=None, prefix="train"):
|
||||
"""
|
||||
Return a loss dict with labelled training loss items tensor.
|
||||
|
||||
Note:
|
||||
This is not needed for classification but necessary for segmentation & detection
|
||||
"""
|
||||
return {"loss": loss_items} if loss_items is not None else ["loss"]
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set or update model parameters before training."""
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def build_targets(self, preds, targets):
|
||||
"""Build target tensors for training YOLO model."""
|
||||
pass
|
||||
|
||||
def progress_string(self):
|
||||
"""Return a string describing training progress."""
|
||||
return ""
|
||||
|
||||
# TODO: may need to put these following functions into callback
|
||||
def plot_training_samples(self, batch, ni):
|
||||
"""Plot training samples during YOLO training."""
|
||||
pass
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Plot training labels for YOLO model."""
|
||||
pass
|
||||
|
||||
def save_metrics(self, metrics):
|
||||
"""Save training metrics to a CSV file."""
|
||||
keys, vals = list(metrics.keys()), list(metrics.values())
|
||||
n = len(metrics) + 2 # number of cols
|
||||
s = "" if self.csv.exists() else (("%s," * n % tuple(["epoch", "time"] + keys)).rstrip(",") + "\n") # header
|
||||
t = time.time() - self.train_time_start
|
||||
with open(self.csv, "a", encoding="utf-8") as f:
|
||||
f.write(s + ("%.6g," * n % tuple([self.epoch + 1, t] + vals)).rstrip(",") + "\n")
|
||||
|
||||
def plot_metrics(self):
|
||||
"""Plot metrics from a CSV file."""
|
||||
plot_results(file=self.csv, on_plot=self.on_plot) # save results.png
|
||||
|
||||
def on_plot(self, name, data=None):
|
||||
"""Register plots (e.g. to be consumed in callbacks)."""
|
||||
path = Path(name)
|
||||
self.plots[path] = {"data": data, "timestamp": time.time()}
|
||||
|
||||
def final_eval(self):
|
||||
"""Perform final evaluation and validation for object detection YOLO model."""
|
||||
ckpt = {}
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
if f is self.last:
|
||||
ckpt = strip_optimizer(f)
|
||||
elif f is self.best:
|
||||
k = "train_results" # update best.pt train_metrics from last.pt
|
||||
strip_optimizer(f, updates={k: ckpt[k]} if k in ckpt else None)
|
||||
LOGGER.info(f"\nValidating {f}...")
|
||||
self.validator.args.plots = self.args.plots
|
||||
self.validator.args.compile = False # disable final val compile as too slow
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop("fitness", None)
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
|
||||
def check_resume(self, overrides):
|
||||
"""Check if resume checkpoint exists and update arguments accordingly."""
|
||||
resume = self.args.resume
|
||||
if resume:
|
||||
try:
|
||||
exists = isinstance(resume, (str, Path)) and Path(resume).exists()
|
||||
last = Path(check_file(resume) if exists else get_latest_run())
|
||||
|
||||
# Check that resume data YAML exists, otherwise strip to force re-download of dataset
|
||||
ckpt_args = load_checkpoint(last)[0].args
|
||||
if not isinstance(ckpt_args["data"], dict) and not Path(ckpt_args["data"]).exists():
|
||||
ckpt_args["data"] = self.args.data
|
||||
|
||||
resume = True
|
||||
self.args = get_cfg(ckpt_args)
|
||||
self.args.model = self.args.resume = str(last) # reinstate model
|
||||
for k in (
|
||||
"imgsz",
|
||||
"batch",
|
||||
"device",
|
||||
"close_mosaic",
|
||||
): # allow arg updates to reduce memory or update device on resume
|
||||
if k in overrides:
|
||||
setattr(self.args, k, overrides[k])
|
||||
|
||||
except Exception as e:
|
||||
raise FileNotFoundError(
|
||||
"Resume checkpoint not found. Please pass a valid checkpoint to resume from, "
|
||||
"i.e. 'yolo train resume model=path/to/last.pt'"
|
||||
) from e
|
||||
self.resume = resume
|
||||
|
||||
def resume_training(self, ckpt):
|
||||
"""Resume YOLO training from given epoch and best fitness."""
|
||||
if ckpt is None or not self.resume:
|
||||
return
|
||||
best_fitness = 0.0
|
||||
start_epoch = ckpt.get("epoch", -1) + 1
|
||||
if ckpt.get("optimizer") is not None:
|
||||
self.optimizer.load_state_dict(ckpt["optimizer"]) # optimizer
|
||||
best_fitness = ckpt["best_fitness"]
|
||||
if ckpt.get("scaler") is not None:
|
||||
self.scaler.load_state_dict(ckpt["scaler"])
|
||||
if self.ema and ckpt.get("ema"):
|
||||
self.ema.ema.load_state_dict(ckpt["ema"].float().state_dict()) # EMA
|
||||
self.ema.updates = ckpt["updates"]
|
||||
assert start_epoch > 0, (
|
||||
f"{self.args.model} training to {self.epochs} epochs is finished, nothing to resume.\n"
|
||||
f"Start a new training without resuming, i.e. 'yolo train model={self.args.model}'"
|
||||
)
|
||||
LOGGER.info(f"Resuming training {self.args.model} from epoch {start_epoch + 1} to {self.epochs} total epochs")
|
||||
if self.epochs < start_epoch:
|
||||
LOGGER.info(
|
||||
f"{self.model} has been trained for {ckpt['epoch']} epochs. Fine-tuning for {self.epochs} more epochs."
|
||||
)
|
||||
self.epochs += ckpt["epoch"] # finetune additional epochs
|
||||
self.best_fitness = best_fitness
|
||||
self.start_epoch = start_epoch
|
||||
if start_epoch > (self.epochs - self.args.close_mosaic):
|
||||
self._close_dataloader_mosaic()
|
||||
|
||||
def _close_dataloader_mosaic(self):
|
||||
"""Update dataloaders to stop using mosaic augmentation."""
|
||||
if hasattr(self.train_loader.dataset, "mosaic"):
|
||||
self.train_loader.dataset.mosaic = False
|
||||
if hasattr(self.train_loader.dataset, "close_mosaic"):
|
||||
LOGGER.info("Closing dataloader mosaic")
|
||||
self.train_loader.dataset.close_mosaic(hyp=copy(self.args))
|
||||
|
||||
def build_optimizer(self, model, name="auto", lr=0.001, momentum=0.9, decay=1e-5, iterations=1e5):
|
||||
"""
|
||||
Construct an optimizer for the given model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model for which to build an optimizer.
|
||||
name (str, optional): The name of the optimizer to use. If 'auto', the optimizer is selected
|
||||
based on the number of iterations.
|
||||
lr (float, optional): The learning rate for the optimizer.
|
||||
momentum (float, optional): The momentum factor for the optimizer.
|
||||
decay (float, optional): The weight decay for the optimizer.
|
||||
iterations (float, optional): The number of iterations, which determines the optimizer if
|
||||
name is 'auto'.
|
||||
|
||||
Returns:
|
||||
(torch.optim.Optimizer): The constructed optimizer.
|
||||
"""
|
||||
g = [], [], [] # optimizer parameter groups
|
||||
bn = tuple(v for k, v in nn.__dict__.items() if "Norm" in k) # normalization layers, i.e. BatchNorm2d()
|
||||
if name == "auto":
|
||||
LOGGER.info(
|
||||
f"{colorstr('optimizer:')} 'optimizer=auto' found, "
|
||||
f"ignoring 'lr0={self.args.lr0}' and 'momentum={self.args.momentum}' and "
|
||||
f"determining best 'optimizer', 'lr0' and 'momentum' automatically... "
|
||||
)
|
||||
nc = self.data.get("nc", 10) # number of classes
|
||||
lr_fit = round(0.002 * 5 / (4 + nc), 6) # lr0 fit equation to 6 decimal places
|
||||
name, lr, momentum = ("SGD", 0.01, 0.9) if iterations > 10000 else ("AdamW", lr_fit, 0.9)
|
||||
self.args.warmup_bias_lr = 0.0 # no higher than 0.01 for Adam
|
||||
|
||||
for module_name, module in model.named_modules():
|
||||
for param_name, param in module.named_parameters(recurse=False):
|
||||
fullname = f"{module_name}.{param_name}" if module_name else param_name
|
||||
if "bias" in fullname: # bias (no decay)
|
||||
g[2].append(param)
|
||||
elif isinstance(module, bn) or "logit_scale" in fullname: # weight (no decay)
|
||||
# ContrastiveHead and BNContrastiveHead included here with 'logit_scale'
|
||||
g[1].append(param)
|
||||
else: # weight (with decay)
|
||||
g[0].append(param)
|
||||
|
||||
optimizers = {"Adam", "Adamax", "AdamW", "NAdam", "RAdam", "RMSProp", "SGD", "auto"}
|
||||
name = {x.lower(): x for x in optimizers}.get(name.lower())
|
||||
if name in {"Adam", "Adamax", "AdamW", "NAdam", "RAdam"}:
|
||||
optimizer = getattr(optim, name, optim.Adam)(g[2], lr=lr, betas=(momentum, 0.999), weight_decay=0.0)
|
||||
elif name == "RMSProp":
|
||||
optimizer = optim.RMSprop(g[2], lr=lr, momentum=momentum)
|
||||
elif name == "SGD":
|
||||
optimizer = optim.SGD(g[2], lr=lr, momentum=momentum, nesterov=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Optimizer '{name}' not found in list of available optimizers {optimizers}. "
|
||||
"Request support for addition optimizers at https://github.com/ultralytics/ultralytics."
|
||||
)
|
||||
|
||||
optimizer.add_param_group({"params": g[0], "weight_decay": decay}) # add g0 with weight_decay
|
||||
optimizer.add_param_group({"params": g[1], "weight_decay": 0.0}) # add g1 (BatchNorm2d weights)
|
||||
LOGGER.info(
|
||||
f"{colorstr('optimizer:')} {type(optimizer).__name__}(lr={lr}, momentum={momentum}) with parameter groups "
|
||||
f"{len(g[1])} weight(decay=0.0), {len(g[0])} weight(decay={decay}), {len(g[2])} bias(decay=0.0)"
|
||||
)
|
||||
return optimizer
|
||||
459
ultralytics/engine/tuner.py
Normal file
459
ultralytics/engine/tuner.py
Normal file
@@ -0,0 +1,459 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
"""
|
||||
Module provides functionalities for hyperparameter tuning of the Ultralytics YOLO models for object detection, instance
|
||||
segmentation, image classification, pose estimation, and multi-object tracking.
|
||||
|
||||
Hyperparameter tuning is the process of systematically searching for the optimal set of hyperparameters
|
||||
that yield the best model performance. This is particularly crucial in deep learning models like YOLO,
|
||||
where small changes in hyperparameters can lead to significant differences in model accuracy and efficiency.
|
||||
|
||||
Examples:
|
||||
Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
|
||||
>>> from ultralytics import YOLO
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
>>> model.tune(data="coco8.yaml", epochs=10, iterations=300, optimizer="AdamW", plots=False, save=False, val=False)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import random
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.patches import torch_load
|
||||
from ultralytics.utils.plotting import plot_tune_results
|
||||
|
||||
|
||||
class Tuner:
|
||||
"""
|
||||
A class for hyperparameter tuning of YOLO models.
|
||||
|
||||
The class evolves YOLO model hyperparameters over a given number of iterations by mutating them according to the
|
||||
search space and retraining the model to evaluate their performance. Supports both local CSV storage and
|
||||
distributed MongoDB Atlas coordination for multi-machine hyperparameter optimization.
|
||||
|
||||
Attributes:
|
||||
space (dict[str, tuple]): Hyperparameter search space containing bounds and scaling factors for mutation.
|
||||
tune_dir (Path): Directory where evolution logs and results will be saved.
|
||||
tune_csv (Path): Path to the CSV file where evolution logs are saved.
|
||||
args (dict): Configuration arguments for the tuning process.
|
||||
callbacks (list): Callback functions to be executed during tuning.
|
||||
prefix (str): Prefix string for logging messages.
|
||||
mongodb (MongoClient): Optional MongoDB client for distributed tuning.
|
||||
collection (Collection): MongoDB collection for storing tuning results.
|
||||
|
||||
Methods:
|
||||
_mutate: Mutate hyperparameters based on bounds and scaling factors.
|
||||
__call__: Execute the hyperparameter evolution across multiple iterations.
|
||||
|
||||
Examples:
|
||||
Tune hyperparameters for YOLO11n on COCO8 at imgsz=640 and epochs=30 for 300 tuning iterations.
|
||||
>>> from ultralytics import YOLO
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
>>> model.tune(
|
||||
>>> data="coco8.yaml",
|
||||
>>> epochs=10,
|
||||
>>> iterations=300,
|
||||
>>> plots=False,
|
||||
>>> save=False,
|
||||
>>> val=False
|
||||
>>> )
|
||||
|
||||
Tune with distributed MongoDB Atlas coordination across multiple machines:
|
||||
>>> model.tune(
|
||||
>>> data="coco8.yaml",
|
||||
>>> epochs=10,
|
||||
>>> iterations=300,
|
||||
>>> mongodb_uri="mongodb+srv://user:pass@cluster.mongodb.net/",
|
||||
>>> mongodb_db="ultralytics",
|
||||
>>> mongodb_collection="tune_results"
|
||||
>>> )
|
||||
|
||||
Tune with custom search space:
|
||||
>>> model.tune(space={"lr0": (1e-5, 1e-1), "momentum": (0.6, 0.98)})
|
||||
"""
|
||||
|
||||
def __init__(self, args=DEFAULT_CFG, _callbacks: list | None = None):
|
||||
"""
|
||||
Initialize the Tuner with configurations.
|
||||
|
||||
Args:
|
||||
args (dict): Configuration for hyperparameter evolution.
|
||||
_callbacks (list | None, optional): Callback functions to be executed during tuning.
|
||||
"""
|
||||
self.space = args.pop("space", None) or { # key: (min, max, gain(optional))
|
||||
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
|
||||
"lr0": (1e-5, 1e-1), # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
|
||||
"lrf": (0.0001, 0.1), # final OneCycleLR learning rate (lr0 * lrf)
|
||||
"momentum": (0.7, 0.98, 0.3), # SGD momentum/Adam beta1
|
||||
"weight_decay": (0.0, 0.001), # optimizer weight decay 5e-4
|
||||
"warmup_epochs": (0.0, 5.0), # warmup epochs (fractions ok)
|
||||
"warmup_momentum": (0.0, 0.95), # warmup initial momentum
|
||||
"box": (1.0, 20.0), # box loss gain
|
||||
"cls": (0.1, 4.0), # cls loss gain (scale with pixels)
|
||||
"dfl": (0.4, 6.0), # dfl loss gain
|
||||
"hsv_h": (0.0, 0.1), # image HSV-Hue augmentation (fraction)
|
||||
"hsv_s": (0.0, 0.9), # image HSV-Saturation augmentation (fraction)
|
||||
"hsv_v": (0.0, 0.9), # image HSV-Value augmentation (fraction)
|
||||
"degrees": (0.0, 45.0), # image rotation (+/- deg)
|
||||
"translate": (0.0, 0.9), # image translation (+/- fraction)
|
||||
"scale": (0.0, 0.95), # image scale (+/- gain)
|
||||
"shear": (0.0, 10.0), # image shear (+/- deg)
|
||||
"perspective": (0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
|
||||
"flipud": (0.0, 1.0), # image flip up-down (probability)
|
||||
"fliplr": (0.0, 1.0), # image flip left-right (probability)
|
||||
"bgr": (0.0, 1.0), # image channel bgr (probability)
|
||||
"mosaic": (0.0, 1.0), # image mosaic (probability)
|
||||
"mixup": (0.0, 1.0), # image mixup (probability)
|
||||
"cutmix": (0.0, 1.0), # image cutmix (probability)
|
||||
"copy_paste": (0.0, 1.0), # segment copy-paste (probability)
|
||||
"close_mosaic": (0.0, 10.0), # close dataloader mosaic (epochs)
|
||||
}
|
||||
mongodb_uri = args.pop("mongodb_uri", None)
|
||||
mongodb_db = args.pop("mongodb_db", "ultralytics")
|
||||
mongodb_collection = args.pop("mongodb_collection", "tuner_results")
|
||||
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.args.exist_ok = self.args.resume # resume w/ same tune_dir
|
||||
self.tune_dir = get_save_dir(self.args, name=self.args.name or "tune")
|
||||
self.args.name, self.args.exist_ok, self.args.resume = (None, False, False) # reset to not affect training
|
||||
self.tune_csv = self.tune_dir / "tune_results.csv"
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
self.prefix = colorstr("Tuner: ")
|
||||
callbacks.add_integration_callbacks(self)
|
||||
|
||||
# MongoDB Atlas support (optional)
|
||||
self.mongodb = None
|
||||
if mongodb_uri:
|
||||
self._init_mongodb(mongodb_uri, mongodb_db, mongodb_collection)
|
||||
|
||||
LOGGER.info(
|
||||
f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
|
||||
f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
|
||||
)
|
||||
|
||||
def _connect(self, uri: str = "mongodb+srv://username:password@cluster.mongodb.net/", max_retries: int = 3):
|
||||
"""
|
||||
Create MongoDB client with exponential backoff retry on connection failures.
|
||||
|
||||
Args:
|
||||
uri (str): MongoDB connection string with credentials and cluster information.
|
||||
max_retries (int): Maximum number of connection attempts before giving up.
|
||||
|
||||
Returns:
|
||||
(MongoClient): Connected MongoDB client instance.
|
||||
"""
|
||||
check_requirements("pymongo")
|
||||
|
||||
from pymongo import MongoClient
|
||||
from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
client = MongoClient(
|
||||
uri,
|
||||
serverSelectionTimeoutMS=30000,
|
||||
connectTimeoutMS=20000,
|
||||
socketTimeoutMS=40000,
|
||||
retryWrites=True,
|
||||
retryReads=True,
|
||||
maxPoolSize=30,
|
||||
minPoolSize=3,
|
||||
maxIdleTimeMS=60000,
|
||||
)
|
||||
client.admin.command("ping") # Test connection
|
||||
LOGGER.info(f"{self.prefix}Connected to MongoDB Atlas (attempt {attempt + 1})")
|
||||
return client
|
||||
except (ConnectionFailure, ServerSelectionTimeoutError):
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
wait_time = 2**attempt
|
||||
LOGGER.warning(
|
||||
f"{self.prefix}MongoDB connection failed (attempt {attempt + 1}), retrying in {wait_time}s..."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
def _init_mongodb(self, mongodb_uri="", mongodb_db="", mongodb_collection=""):
|
||||
"""
|
||||
Initialize MongoDB connection for distributed tuning.
|
||||
|
||||
Connects to MongoDB Atlas for distributed hyperparameter optimization across multiple machines.
|
||||
Each worker saves results to a shared collection and reads the latest best hyperparameters
|
||||
from all workers for evolution.
|
||||
|
||||
Args:
|
||||
mongodb_uri (str): MongoDB connection string, e.g. 'mongodb+srv://username:password@cluster.mongodb.net/'.
|
||||
mongodb_db (str, optional): Database name.
|
||||
mongodb_collection (str, optional): Collection name.
|
||||
|
||||
Notes:
|
||||
- Creates a fitness index for fast queries of top results
|
||||
- Falls back to CSV-only mode if connection fails
|
||||
- Uses connection pooling and retry logic for production reliability
|
||||
"""
|
||||
self.mongodb = self._connect(mongodb_uri)
|
||||
self.collection = self.mongodb[mongodb_db][mongodb_collection]
|
||||
self.collection.create_index([("fitness", -1)], background=True)
|
||||
LOGGER.info(f"{self.prefix}Using MongoDB Atlas for distributed tuning")
|
||||
|
||||
def _get_mongodb_results(self, n: int = 5) -> list:
|
||||
"""
|
||||
Get top N results from MongoDB sorted by fitness.
|
||||
|
||||
Args:
|
||||
n (int): Number of top results to retrieve.
|
||||
|
||||
Returns:
|
||||
(list[dict]): List of result documents with fitness scores and hyperparameters.
|
||||
"""
|
||||
try:
|
||||
return list(self.collection.find().sort("fitness", -1).limit(n))
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _save_to_mongodb(self, fitness: float, hyperparameters: dict[str, float], metrics: dict, iteration: int):
|
||||
"""
|
||||
Save results to MongoDB with proper type conversion.
|
||||
|
||||
Args:
|
||||
fitness (float): Fitness score achieved with these hyperparameters.
|
||||
hyperparameters (dict[str, float]): Dictionary of hyperparameter values.
|
||||
metrics (dict): Complete training metrics dictionary (mAP, precision, recall, losses, etc.).
|
||||
iteration (int): Current iteration number.
|
||||
"""
|
||||
try:
|
||||
self.collection.insert_one(
|
||||
{
|
||||
"fitness": float(fitness),
|
||||
"hyperparameters": {k: (v.item() if hasattr(v, "item") else v) for k, v in hyperparameters.items()},
|
||||
"metrics": metrics,
|
||||
"timestamp": datetime.now(),
|
||||
"iteration": iteration,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{self.prefix}MongoDB save failed: {e}")
|
||||
|
||||
def _sync_mongodb_to_csv(self):
|
||||
"""
|
||||
Sync MongoDB results to CSV for plotting compatibility.
|
||||
|
||||
Downloads all results from MongoDB and writes them to the local CSV file in chronological order. This enables
|
||||
the existing plotting functions to work seamlessly with distributed MongoDB data.
|
||||
"""
|
||||
try:
|
||||
# Get all results from MongoDB
|
||||
all_results = list(self.collection.find().sort("iteration", 1))
|
||||
if not all_results:
|
||||
return
|
||||
|
||||
# Write to CSV
|
||||
headers = ",".join(["fitness"] + list(self.space.keys())) + "\n"
|
||||
with open(self.tune_csv, "w", encoding="utf-8") as f:
|
||||
f.write(headers)
|
||||
for result in all_results:
|
||||
fitness = result["fitness"]
|
||||
hyp_values = [result["hyperparameters"][k] for k in self.space.keys()]
|
||||
log_row = [round(fitness, 5)] + hyp_values
|
||||
f.write(",".join(map(str, log_row)) + "\n")
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
|
||||
|
||||
def _crossover(self, x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
|
||||
"""BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
|
||||
k = min(k, len(x))
|
||||
# fitness weights (shifted to >0); fallback to uniform if degenerate
|
||||
weights = x[:, 0] - x[:, 0].min() + 1e-6
|
||||
if not np.isfinite(weights).all() or weights.sum() == 0:
|
||||
weights = np.ones_like(weights)
|
||||
idxs = random.choices(range(len(x)), weights=weights, k=k)
|
||||
parents_mat = np.stack([x[i][1:] for i in idxs], 0) # (k, ng) strip fitness
|
||||
lo, hi = parents_mat.min(0), parents_mat.max(0)
|
||||
span = hi - lo
|
||||
return np.random.uniform(lo - alpha * span, hi + alpha * span)
|
||||
|
||||
def _mutate(
|
||||
self,
|
||||
n: int = 9,
|
||||
mutation: float = 0.5,
|
||||
sigma: float = 0.2,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
|
||||
|
||||
Args:
|
||||
parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
|
||||
n (int): Number of top parents to consider.
|
||||
mutation (float): Probability of a parameter mutation in any given iteration.
|
||||
sigma (float): Standard deviation for Gaussian random number generator.
|
||||
|
||||
Returns:
|
||||
(dict[str, float]): A dictionary containing mutated hyperparameters.
|
||||
"""
|
||||
x = None
|
||||
|
||||
# Try MongoDB first if available
|
||||
if self.mongodb:
|
||||
results = self._get_mongodb_results(n)
|
||||
if results:
|
||||
# MongoDB already sorted by fitness DESC, so results[0] is best
|
||||
x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
|
||||
elif self.collection.name in self.collection.database.list_collection_names(): # Tuner started elsewhere
|
||||
x = np.array([[0.0] + [getattr(self.args, k) for k in self.space.keys()]])
|
||||
|
||||
# Fall back to CSV if MongoDB unavailable or empty
|
||||
if x is None and self.tune_csv.exists():
|
||||
csv_data = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
||||
if len(csv_data) > 0:
|
||||
fitness = csv_data[:, 0] # first column
|
||||
order = np.argsort(-fitness)
|
||||
x = csv_data[order][:n] # top-n sorted by fitness DESC
|
||||
|
||||
# Mutate if we have data, otherwise use defaults
|
||||
if x is not None:
|
||||
np.random.seed(int(time.time()))
|
||||
ng = len(self.space)
|
||||
|
||||
# Crossover
|
||||
genes = self._crossover(x)
|
||||
|
||||
# Mutation
|
||||
gains = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()]) # gains 0-1
|
||||
factors = np.ones(ng)
|
||||
while np.all(factors == 1): # mutate until a change occurs (prevent duplicates)
|
||||
mask = np.random.random(ng) < mutation
|
||||
step = np.random.randn(ng) * (sigma * gains)
|
||||
factors = np.where(mask, np.exp(step), 1.0).clip(0.25, 4.0)
|
||||
hyp = {k: float(genes[i] * factors[i]) for i, k in enumerate(self.space.keys())}
|
||||
else:
|
||||
hyp = {k: getattr(self.args, k) for k in self.space.keys()}
|
||||
|
||||
# Constrain to limits
|
||||
for k, bounds in self.space.items():
|
||||
hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)
|
||||
|
||||
# Update types
|
||||
if "close_mosaic" in hyp:
|
||||
hyp["close_mosaic"] = int(round(hyp["close_mosaic"]))
|
||||
|
||||
return hyp
|
||||
|
||||
def __call__(self, model=None, iterations: int = 10, cleanup: bool = True):
|
||||
"""
|
||||
Execute the hyperparameter evolution process when the Tuner instance is called.
|
||||
|
||||
This method iterates through the specified number of iterations, performing the following steps:
|
||||
1. Sync MongoDB results to CSV (if using distributed mode)
|
||||
2. Mutate hyperparameters using the best previous results or defaults
|
||||
3. Train a YOLO model with the mutated hyperparameters
|
||||
4. Log fitness scores and hyperparameters to MongoDB and/or CSV
|
||||
5. Track the best performing configuration across all iterations
|
||||
|
||||
Args:
|
||||
model (Model | None, optional): A pre-initialized YOLO model to be used for training.
|
||||
iterations (int): The number of generations to run the evolution for.
|
||||
cleanup (bool): Whether to delete iteration weights to reduce storage space during tuning.
|
||||
"""
|
||||
t0 = time.time()
|
||||
best_save_dir, best_metrics = None, None
|
||||
(self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Sync MongoDB to CSV at startup for proper resume logic
|
||||
if self.mongodb:
|
||||
self._sync_mongodb_to_csv()
|
||||
|
||||
start = 0
|
||||
if self.tune_csv.exists():
|
||||
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
||||
start = x.shape[0]
|
||||
LOGGER.info(f"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...")
|
||||
for i in range(start, iterations):
|
||||
# Linearly decay sigma from 0.2 → 0.1 over first 300 iterations
|
||||
frac = min(i / 300.0, 1.0)
|
||||
sigma_i = 0.2 - 0.1 * frac
|
||||
|
||||
# Mutate hyperparameters
|
||||
mutated_hyp = self._mutate(sigma=sigma_i)
|
||||
LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
|
||||
|
||||
metrics = {}
|
||||
train_args = {**vars(self.args), **mutated_hyp}
|
||||
save_dir = get_save_dir(get_cfg(train_args))
|
||||
weights_dir = save_dir / "weights"
|
||||
try:
|
||||
# Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
|
||||
launch = [__import__("sys").executable, "-m", "ultralytics.cfg.__init__"] # workaround yolo not found
|
||||
cmd = [*launch, "train", *(f"{k}={v}" for k, v in train_args.items())]
|
||||
return_code = subprocess.run(cmd, check=True).returncode
|
||||
ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
|
||||
metrics = torch_load(ckpt_file)["train_metrics"]
|
||||
assert return_code == 0, "training failed"
|
||||
|
||||
# Cleanup
|
||||
time.sleep(1)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
except Exception as e:
|
||||
LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")
|
||||
|
||||
# Save results - MongoDB takes precedence
|
||||
fitness = metrics.get("fitness", 0.0)
|
||||
if self.mongodb:
|
||||
self._save_to_mongodb(fitness, mutated_hyp, metrics, i + 1)
|
||||
self._sync_mongodb_to_csv()
|
||||
total_mongo_iterations = self.collection.count_documents({})
|
||||
if total_mongo_iterations >= iterations:
|
||||
LOGGER.info(
|
||||
f"{self.prefix}Target iterations ({iterations}) reached in MongoDB ({total_mongo_iterations}). Stopping."
|
||||
)
|
||||
break
|
||||
else:
|
||||
# Save to CSV only if no MongoDB
|
||||
log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
|
||||
headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
|
||||
with open(self.tune_csv, "a", encoding="utf-8") as f:
|
||||
f.write(headers + ",".join(map(str, log_row)) + "\n")
|
||||
|
||||
# Get best results
|
||||
x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
|
||||
fitness = x[:, 0] # first column
|
||||
best_idx = fitness.argmax()
|
||||
best_is_current = best_idx == (i - start)
|
||||
if best_is_current:
|
||||
best_save_dir = str(save_dir)
|
||||
best_metrics = {k: round(v, 5) for k, v in metrics.items()}
|
||||
for ckpt in weights_dir.glob("*.pt"):
|
||||
shutil.copy2(ckpt, self.tune_dir / "weights")
|
||||
elif cleanup and best_save_dir:
|
||||
shutil.rmtree(best_save_dir, ignore_errors=True) # remove iteration dirs to reduce storage space
|
||||
|
||||
# Plot tune results
|
||||
plot_tune_results(str(self.tune_csv))
|
||||
|
||||
# Save and print tune results
|
||||
header = (
|
||||
f"{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n"
|
||||
f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
|
||||
f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
|
||||
f"{self.prefix}Best fitness metrics are {best_metrics}\n"
|
||||
f"{self.prefix}Best fitness model is {best_save_dir}"
|
||||
)
|
||||
LOGGER.info("\n" + header)
|
||||
data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
|
||||
YAML.save(
|
||||
self.tune_dir / "best_hyperparameters.yaml",
|
||||
data=data,
|
||||
header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
|
||||
)
|
||||
YAML.print(self.tune_dir / "best_hyperparameters.yaml")
|
||||
370
ultralytics/engine/validator.py
Normal file
370
ultralytics/engine/validator.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
"""
|
||||
Check a model's accuracy on a test or val split of a dataset.
|
||||
|
||||
Usage:
|
||||
$ yolo mode=val model=yolo11n.pt data=coco8.yaml imgsz=640
|
||||
|
||||
Usage - formats:
|
||||
$ yolo mode=val model=yolo11n.pt # PyTorch
|
||||
yolo11n.torchscript # TorchScript
|
||||
yolo11n.onnx # ONNX Runtime or OpenCV DNN with dnn=True
|
||||
yolo11n_openvino_model # OpenVINO
|
||||
yolo11n.engine # TensorRT
|
||||
yolo11n.mlpackage # CoreML (macOS-only)
|
||||
yolo11n_saved_model # TensorFlow SavedModel
|
||||
yolo11n.pb # TensorFlow GraphDef
|
||||
yolo11n.tflite # TensorFlow Lite
|
||||
yolo11n_edgetpu.tflite # TensorFlow Edge TPU
|
||||
yolo11n_paddle_model # PaddlePaddle
|
||||
yolo11n.mnn # MNN
|
||||
yolo11n_ncnn_model # NCNN
|
||||
yolo11n_imx_model # Sony IMX
|
||||
yolo11n_rknn_model # Rockchip RKNN
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.cfg import get_cfg, get_save_dir
|
||||
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
|
||||
from ultralytics.nn.autobackend import AutoBackend
|
||||
from ultralytics.utils import LOGGER, TQDM, callbacks, colorstr, emojis
|
||||
from ultralytics.utils.checks import check_imgsz
|
||||
from ultralytics.utils.ops import Profile
|
||||
from ultralytics.utils.torch_utils import attempt_compile, select_device, smart_inference_mode, unwrap_model
|
||||
|
||||
|
||||
class BaseValidator:
|
||||
"""
|
||||
A base class for creating validators.
|
||||
|
||||
This class provides the foundation for validation processes, including model evaluation, metric computation, and
|
||||
result visualization.
|
||||
|
||||
Attributes:
|
||||
args (SimpleNamespace): Configuration for the validator.
|
||||
dataloader (DataLoader): Dataloader to use for validation.
|
||||
model (nn.Module): Model to validate.
|
||||
data (dict): Data dictionary containing dataset information.
|
||||
device (torch.device): Device to use for validation.
|
||||
batch_i (int): Current batch index.
|
||||
training (bool): Whether the model is in training mode.
|
||||
names (dict): Class names mapping.
|
||||
seen (int): Number of images seen so far during validation.
|
||||
stats (dict): Statistics collected during validation.
|
||||
confusion_matrix: Confusion matrix for classification evaluation.
|
||||
nc (int): Number of classes.
|
||||
iouv (torch.Tensor): IoU thresholds from 0.50 to 0.95 in spaces of 0.05.
|
||||
jdict (list): List to store JSON validation results.
|
||||
speed (dict): Dictionary with keys 'preprocess', 'inference', 'loss', 'postprocess' and their respective
|
||||
batch processing times in milliseconds.
|
||||
save_dir (Path): Directory to save results.
|
||||
plots (dict): Dictionary to store plots for visualization.
|
||||
callbacks (dict): Dictionary to store various callback functions.
|
||||
stride (int): Model stride for padding calculations.
|
||||
loss (torch.Tensor): Accumulated loss during training validation.
|
||||
|
||||
Methods:
|
||||
__call__: Execute validation process, running inference on dataloader and computing performance metrics.
|
||||
match_predictions: Match predictions to ground truth objects using IoU.
|
||||
add_callback: Append the given callback to the specified event.
|
||||
run_callbacks: Run all callbacks associated with a specified event.
|
||||
get_dataloader: Get data loader from dataset path and batch size.
|
||||
build_dataset: Build dataset from image path.
|
||||
preprocess: Preprocess an input batch.
|
||||
postprocess: Postprocess the predictions.
|
||||
init_metrics: Initialize performance metrics for the YOLO model.
|
||||
update_metrics: Update metrics based on predictions and batch.
|
||||
finalize_metrics: Finalize and return all metrics.
|
||||
get_stats: Return statistics about the model's performance.
|
||||
print_results: Print the results of the model's predictions.
|
||||
get_desc: Get description of the YOLO model.
|
||||
on_plot: Register plots for visualization.
|
||||
plot_val_samples: Plot validation samples during training.
|
||||
plot_predictions: Plot YOLO model predictions on batch images.
|
||||
pred_to_json: Convert predictions to JSON format.
|
||||
eval_json: Evaluate and return JSON format of prediction statistics.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
|
||||
"""
|
||||
Initialize a BaseValidator instance.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
save_dir (Path, optional): Directory to save results.
|
||||
args (SimpleNamespace, optional): Configuration for the validator.
|
||||
_callbacks (dict, optional): Dictionary to store various callback functions.
|
||||
"""
|
||||
import torchvision # noqa (import here so torchvision import time not recorded in postprocess time)
|
||||
|
||||
self.args = get_cfg(overrides=args)
|
||||
self.dataloader = dataloader
|
||||
self.stride = None
|
||||
self.data = None
|
||||
self.device = None
|
||||
self.batch_i = None
|
||||
self.training = True
|
||||
self.names = None
|
||||
self.seen = None
|
||||
self.stats = None
|
||||
self.confusion_matrix = None
|
||||
self.nc = None
|
||||
self.iouv = None
|
||||
self.jdict = None
|
||||
self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
|
||||
|
||||
self.save_dir = save_dir or get_save_dir(self.args)
|
||||
(self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
if self.args.conf is None:
|
||||
self.args.conf = 0.01 if self.args.task == "obb" else 0.001 # reduce OBB val memory usage
|
||||
self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
|
||||
|
||||
self.plots = {}
|
||||
self.callbacks = _callbacks or callbacks.get_default_callbacks()
|
||||
|
||||
@smart_inference_mode()
|
||||
def __call__(self, trainer=None, model=None):
|
||||
"""
|
||||
Execute validation process, running inference on dataloader and computing performance metrics.
|
||||
|
||||
Args:
|
||||
trainer (object, optional): Trainer object that contains the model to validate.
|
||||
model (nn.Module, optional): Model to validate if not using a trainer.
|
||||
|
||||
Returns:
|
||||
(dict): Dictionary containing validation statistics.
|
||||
"""
|
||||
self.training = trainer is not None
|
||||
augment = self.args.augment and (not self.training)
|
||||
if self.training:
|
||||
self.device = trainer.device
|
||||
self.data = trainer.data
|
||||
# Force FP16 val during training
|
||||
self.args.half = self.device.type != "cpu" and trainer.amp
|
||||
model = trainer.ema.ema or trainer.model
|
||||
if trainer.args.compile and hasattr(model, "_orig_mod"):
|
||||
model = model._orig_mod # validate non-compiled original model to avoid issues
|
||||
model = model.half() if self.args.half else model.float()
|
||||
self.loss = torch.zeros_like(trainer.loss_items, device=trainer.device)
|
||||
self.args.plots &= trainer.stopper.possible_stop or (trainer.epoch == trainer.epochs - 1)
|
||||
model.eval()
|
||||
else:
|
||||
if str(self.args.model).endswith(".yaml") and model is None:
|
||||
LOGGER.warning("validating an untrained model YAML will result in 0 mAP.")
|
||||
callbacks.add_integration_callbacks(self)
|
||||
model = AutoBackend(
|
||||
model=model or self.args.model,
|
||||
device=select_device(self.args.device),
|
||||
dnn=self.args.dnn,
|
||||
data=self.args.data,
|
||||
fp16=self.args.half,
|
||||
)
|
||||
self.device = model.device # update device
|
||||
self.args.half = model.fp16 # update half
|
||||
stride, pt, jit = model.stride, model.pt, model.jit
|
||||
imgsz = check_imgsz(self.args.imgsz, stride=stride)
|
||||
if not (pt or jit or getattr(model, "dynamic", False)):
|
||||
self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1
|
||||
LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})")
|
||||
|
||||
if str(self.args.data).rsplit(".", 1)[-1] in {"yaml", "yml"}:
|
||||
self.data = check_det_dataset(self.args.data)
|
||||
elif self.args.task == "classify":
|
||||
self.data = check_cls_dataset(self.args.data, split=self.args.split)
|
||||
else:
|
||||
raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))
|
||||
|
||||
if self.device.type in {"cpu", "mps"}:
|
||||
self.args.workers = 0 # faster CPU val as time dominated by inference, not dataloading
|
||||
if not (pt or (getattr(model, "dynamic", False) and not model.imx)):
|
||||
self.args.rect = False
|
||||
self.stride = model.stride # used in get_dataloader() for padding
|
||||
self.dataloader = self.dataloader or self.get_dataloader(self.data.get(self.args.split), self.args.batch)
|
||||
|
||||
model.eval()
|
||||
if self.args.compile:
|
||||
model = attempt_compile(model, device=self.device)
|
||||
model.warmup(imgsz=(1 if pt else self.args.batch, self.data["channels"], imgsz, imgsz)) # warmup
|
||||
|
||||
self.run_callbacks("on_val_start")
|
||||
dt = (
|
||||
Profile(device=self.device),
|
||||
Profile(device=self.device),
|
||||
Profile(device=self.device),
|
||||
Profile(device=self.device),
|
||||
)
|
||||
bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
|
||||
self.init_metrics(unwrap_model(model))
|
||||
self.jdict = [] # empty before each val
|
||||
for batch_i, batch in enumerate(bar):
|
||||
self.run_callbacks("on_val_batch_start")
|
||||
self.batch_i = batch_i
|
||||
# Preprocess
|
||||
with dt[0]:
|
||||
batch = self.preprocess(batch)
|
||||
|
||||
# Inference
|
||||
with dt[1]:
|
||||
preds = model(batch["img"], augment=augment)
|
||||
|
||||
# Loss
|
||||
with dt[2]:
|
||||
if self.training:
|
||||
self.loss += model.loss(batch, preds)[1]
|
||||
|
||||
# Postprocess
|
||||
with dt[3]:
|
||||
preds = self.postprocess(preds)
|
||||
|
||||
self.update_metrics(preds, batch)
|
||||
if self.args.plots and batch_i < 3:
|
||||
self.plot_val_samples(batch, batch_i)
|
||||
self.plot_predictions(batch, preds, batch_i)
|
||||
|
||||
self.run_callbacks("on_val_batch_end")
|
||||
stats = self.get_stats()
|
||||
self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
|
||||
self.finalize_metrics()
|
||||
self.print_results()
|
||||
self.run_callbacks("on_val_end")
|
||||
if self.training:
|
||||
model.float()
|
||||
results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
|
||||
return {k: round(float(v), 5) for k, v in results.items()} # return results as 5 decimal place floats
|
||||
else:
|
||||
LOGGER.info(
|
||||
"Speed: {:.1f}ms preprocess, {:.1f}ms inference, {:.1f}ms loss, {:.1f}ms postprocess per image".format(
|
||||
*tuple(self.speed.values())
|
||||
)
|
||||
)
|
||||
if self.args.save_json and self.jdict:
|
||||
with open(str(self.save_dir / "predictions.json"), "w", encoding="utf-8") as f:
|
||||
LOGGER.info(f"Saving {f.name}...")
|
||||
json.dump(self.jdict, f) # flatten and save
|
||||
stats = self.eval_json(stats) # update stats
|
||||
if self.args.plots or self.args.save_json:
|
||||
LOGGER.info(f"Results saved to {colorstr('bold', self.save_dir)}")
|
||||
return stats
|
||||
|
||||
def match_predictions(
|
||||
self, pred_classes: torch.Tensor, true_classes: torch.Tensor, iou: torch.Tensor, use_scipy: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Match predictions to ground truth objects using IoU.
|
||||
|
||||
Args:
|
||||
pred_classes (torch.Tensor): Predicted class indices of shape (N,).
|
||||
true_classes (torch.Tensor): Target class indices of shape (M,).
|
||||
iou (torch.Tensor): An NxM tensor containing the pairwise IoU values for predictions and ground truth.
|
||||
use_scipy (bool, optional): Whether to use scipy for matching (more precise).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Correct tensor of shape (N, 10) for 10 IoU thresholds.
|
||||
"""
|
||||
# Dx10 matrix, where D - detections, 10 - IoU thresholds
|
||||
correct = np.zeros((pred_classes.shape[0], self.iouv.shape[0])).astype(bool)
|
||||
# LxD matrix where L - labels (rows), D - detections (columns)
|
||||
correct_class = true_classes[:, None] == pred_classes
|
||||
iou = iou * correct_class # zero out the wrong classes
|
||||
iou = iou.cpu().numpy()
|
||||
for i, threshold in enumerate(self.iouv.cpu().tolist()):
|
||||
if use_scipy:
|
||||
# WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
|
||||
import scipy # scope import to avoid importing for all commands
|
||||
|
||||
cost_matrix = iou * (iou >= threshold)
|
||||
if cost_matrix.any():
|
||||
labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix)
|
||||
valid = cost_matrix[labels_idx, detections_idx] > 0
|
||||
if valid.any():
|
||||
correct[detections_idx[valid], i] = True
|
||||
else:
|
||||
matches = np.nonzero(iou >= threshold) # IoU > threshold and classes match
|
||||
matches = np.array(matches).T
|
||||
if matches.shape[0]:
|
||||
if matches.shape[0] > 1:
|
||||
matches = matches[iou[matches[:, 0], matches[:, 1]].argsort()[::-1]]
|
||||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
|
||||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
|
||||
correct[matches[:, 1].astype(int), i] = True
|
||||
return torch.tensor(correct, dtype=torch.bool, device=pred_classes.device)
|
||||
|
||||
def add_callback(self, event: str, callback):
|
||||
"""Append the given callback to the specified event."""
|
||||
self.callbacks[event].append(callback)
|
||||
|
||||
def run_callbacks(self, event: str):
|
||||
"""Run all callbacks associated with a specified event."""
|
||||
for callback in self.callbacks.get(event, []):
|
||||
callback(self)
|
||||
|
||||
def get_dataloader(self, dataset_path, batch_size):
|
||||
"""Get data loader from dataset path and batch size."""
|
||||
raise NotImplementedError("get_dataloader function not implemented for this validator")
|
||||
|
||||
def build_dataset(self, img_path):
|
||||
"""Build dataset from image path."""
|
||||
raise NotImplementedError("build_dataset function not implemented in validator")
|
||||
|
||||
def preprocess(self, batch):
|
||||
"""Preprocess an input batch."""
|
||||
return batch
|
||||
|
||||
def postprocess(self, preds):
|
||||
"""Postprocess the predictions."""
|
||||
return preds
|
||||
|
||||
def init_metrics(self, model):
|
||||
"""Initialize performance metrics for the YOLO model."""
|
||||
pass
|
||||
|
||||
def update_metrics(self, preds, batch):
|
||||
"""Update metrics based on predictions and batch."""
|
||||
pass
|
||||
|
||||
def finalize_metrics(self):
|
||||
"""Finalize and return all metrics."""
|
||||
pass
|
||||
|
||||
def get_stats(self):
|
||||
"""Return statistics about the model's performance."""
|
||||
return {}
|
||||
|
||||
def print_results(self):
|
||||
"""Print the results of the model's predictions."""
|
||||
pass
|
||||
|
||||
def get_desc(self):
|
||||
"""Get description of the YOLO model."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def metric_keys(self):
|
||||
"""Return the metric keys used in YOLO training/validation."""
|
||||
return []
|
||||
|
||||
def on_plot(self, name, data=None):
|
||||
"""Register plots for visualization."""
|
||||
self.plots[Path(name)] = {"data": data, "timestamp": time.time()}
|
||||
|
||||
def plot_val_samples(self, batch, ni):
|
||||
"""Plot validation samples during training."""
|
||||
pass
|
||||
|
||||
def plot_predictions(self, batch, preds, ni):
|
||||
"""Plot YOLO model predictions on batch images."""
|
||||
pass
|
||||
|
||||
def pred_to_json(self, preds, batch):
|
||||
"""Convert predictions to JSON format."""
|
||||
pass
|
||||
|
||||
def eval_json(self, stats):
|
||||
"""Evaluate and return JSON format of prediction statistics."""
|
||||
pass
|
||||
Reference in New Issue
Block a user