init commit

2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions

View File

@@ -0,0 +1,120 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
from __future__ import annotations
import os
from copy import deepcopy
import numpy as np
import torch
from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import autocast, profile_ops
def check_train_batch_size(
model: torch.nn.Module,
imgsz: int = 640,
amp: bool = True,
batch: int | float = -1,
max_num_obj: int = 1,
) -> int:
"""
Compute optimal YOLO training batch size using the autobatch() function.
Args:
model (torch.nn.Module): YOLO model to check batch size for.
imgsz (int, optional): Image size used for training.
amp (bool, optional): Use automatic mixed precision if True.
batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.
max_num_obj (int, optional): The maximum number of objects from dataset.
Returns:
(int): Optimal batch size computed using the autobatch() function.
Notes:
If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
Otherwise, a default fraction of 0.6 is used.
"""
with autocast(enabled=amp):
return autobatch(
deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj
)
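# Example (sketch, not part of the module): a fractional `batch` value targets that
# fraction of CUDA memory instead of the 0.6 default; assumes a CUDA device and an
# installed ultralytics package (weights name is illustrative).
# >>> from ultralytics import YOLO
# >>> model = YOLO("yolo11n.pt").model.cuda()
# >>> best_bs = check_train_batch_size(model, imgsz=640, amp=True, batch=0.55)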
def autobatch(
model: torch.nn.Module,
imgsz: int = 640,
fraction: float = 0.60,
batch_size: int = DEFAULT_CFG.batch,
max_num_obj: int = 1,
) -> int:
"""
Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
Args:
model (torch.nn.Module): YOLO model to compute batch size for.
imgsz (int, optional): The image size used as input for the YOLO model.
fraction (float, optional): The fraction of available CUDA memory to use.
batch_size (int, optional): The default batch size to use if an error is detected.
max_num_obj (int, optional): The maximum number of objects from dataset.
Returns:
(int): The optimal batch size.
"""
# Check device
prefix = colorstr("AutoBatch: ")
LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
device = next(model.parameters()).device # get model device
if device.type in {"cpu", "mps"}:
LOGGER.warning(f"{prefix}intended for CUDA devices, using default batch-size {batch_size}")
return batch_size
if torch.backends.cudnn.benchmark:
LOGGER.warning(f"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
return batch_size
# Inspect CUDA memory
gb = 1 << 30 # bytes to GiB (1024 ** 3)
d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}" # 'CUDA:0'
properties = torch.cuda.get_device_properties(device) # device properties
t = properties.total_memory / gb # GiB total
r = torch.cuda.memory_reserved(device) / gb # GiB reserved
a = torch.cuda.memory_allocated(device) / gb # GiB allocated
f = t - (r + a) # GiB free
LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
# Profile batch sizes
batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
try:
img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)
# Fit a solution
xy = [
[x, y[2]]
for i, (x, y) in enumerate(zip(batch_sizes, results))
if y # valid result
and isinstance(y[2], (int, float)) # is numeric
and 0 < y[2] < t # between 0 and GPU limit
and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory
]
fit_x, fit_y = zip(*xy) if xy else ([], [])
        p = np.polyfit(fit_x, fit_y, deg=1)  # first-degree polynomial fit of memory usage vs batch size
        b = int((round(f * fraction) - p[1]) / p[0])  # solve the fit for the batch size at the target memory fraction
if None in results: # some sizes failed
i = results.index(None) # first fail index
if b >= batch_sizes[i]: # y intercept above failure point
b = batch_sizes[max(i - 1, 0)] # select prior safe point
if b < 1 or b > 1024: # b outside of safe range
LOGGER.warning(f"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.")
b = batch_size
fraction = (np.polyval(p, b) + r + a) / t # predicted fraction
LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
return b
except Exception as e:
LOGGER.warning(f"{prefix}error detected: {e}, using default batch-size {batch_size}.")
return batch_size
finally:
torch.cuda.empty_cache()
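
A minimal end-to-end sketch of the module above, assuming the ultralytics package is installed and a CUDA device is present (the weights name and module path are illustrative, inferred from this diff):

import torch
from ultralytics import YOLO
from ultralytics.utils.autobatch import autobatch  # module path assumed from this diff

if torch.cuda.is_available():
    model = YOLO("yolo11n.pt").model.train().cuda()  # nn.Module in train mode, as autobatch expects
    bs = autobatch(model, imgsz=640, fraction=0.60)  # estimate batch size for ~60% of CUDA memory
    print(f"Estimated batch size: {bs}")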

View File

@@ -0,0 +1,206 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from typing import Any
from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_requirements
class GPUInfo:
"""
Manages NVIDIA GPU information via pynvml with robust error handling.
Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle
GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml
library by logging warnings and disabling related features, preventing application crashes.
Includes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU
selection. Manages NVML initialization and shutdown internally.
Attributes:
pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.
nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,
False otherwise.
gpu_stats (list[dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on
initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),
'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),
'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.
Methods:
refresh_stats: Refresh the internal gpu_stats list by querying NVML.
print_status: Print GPU status in a compact table format using current stats.
select_idle_gpu: Select the most idle GPUs based on utilization and free memory.
shutdown: Shut down NVML if it was initialized.
Examples:
Initialize GPUInfo and print status
>>> gpu_info = GPUInfo()
>>> gpu_info.print_status()
Select idle GPUs with minimum memory requirements
>>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)
>>> print(f"Selected GPU indices: {selected}")
"""
def __init__(self):
"""Initialize GPUInfo, attempting to import and initialize pynvml."""
self.pynvml: Any | None = None
self.nvml_available: bool = False
self.gpu_stats: list[dict[str, Any]] = []
try:
check_requirements("nvidia-ml-py>=12.0.0")
self.pynvml = __import__("pynvml")
self.pynvml.nvmlInit()
self.nvml_available = True
self.refresh_stats()
except Exception as e:
LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}")
def __del__(self):
"""Ensure NVML is shut down when the object is garbage collected."""
self.shutdown()
def shutdown(self):
"""Shut down NVML if it was initialized."""
if self.nvml_available and self.pynvml:
try:
self.pynvml.nvmlShutdown()
except Exception:
pass
self.nvml_available = False
def refresh_stats(self):
"""Refresh the internal gpu_stats list by querying NVML."""
self.gpu_stats = []
if not self.nvml_available or not self.pynvml:
return
try:
device_count = self.pynvml.nvmlDeviceGetCount()
self.gpu_stats.extend(self._get_device_stats(i) for i in range(device_count))
except Exception as e:
LOGGER.warning(f"Error during device query: {e}")
self.gpu_stats = []
def _get_device_stats(self, index: int) -> dict[str, Any]:
"""Get stats for a single GPU device."""
handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)
memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)
def safe_get(func, *args, default=-1, divisor=1):
try:
val = func(*args)
return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val
except Exception:
return default
temp_type = getattr(self.pynvml, "NVML_TEMPERATURE_GPU", -1)
return {
"index": index,
"name": self.pynvml.nvmlDeviceGetName(handle),
"utilization": util.gpu if util else -1,
"memory_used": memory.used >> 20 if memory else -1, # Convert bytes to MiB
"memory_total": memory.total >> 20 if memory else -1,
"memory_free": memory.free >> 20 if memory else -1,
"temperature": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),
"power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000), # Convert mW to W
"power_limit": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),
}
def print_status(self):
"""Print GPU status in a compact table format using current stats."""
self.refresh_stats()
if not self.gpu_stats:
LOGGER.warning("No GPU stats available.")
return
stats = self.gpu_stats
name_len = max(len(gpu.get("name", "N/A")) for gpu in stats)
hdr = f"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}"
LOGGER.info(f"\n--- GPU Status ---\n{hdr}\n{'-' * len(hdr)}")
for gpu in stats:
u = f"{gpu['utilization']:>5}%" if gpu["utilization"] >= 0 else " N/A "
m = f"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}" if gpu["memory_used"] >= 0 else " N/A / N/A "
t = f"{gpu['temperature']}C" if gpu["temperature"] >= 0 else " N/A "
p = f"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}" if gpu["power_draw"] >= 0 else " N/A "
LOGGER.info(f"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}")
LOGGER.info(f"{'-' * len(hdr)}\n")
def select_idle_gpu(
self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0
) -> list[int]:
"""
Select the most idle GPUs based on utilization and free memory.
Args:
count (int): The number of idle GPUs to select.
min_memory_fraction (float): Minimum free memory required as a fraction of total memory.
min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.
Returns:
(list[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first).
Notes:
Returns fewer than 'count' if not enough qualify or exist.
Returns basic CUDA indices if NVML fails. Empty list if no GPUs found.
"""
assert min_memory_fraction <= 1.0, f"min_memory_fraction must be <= 1.0, got {min_memory_fraction}"
assert min_util_fraction <= 1.0, f"min_util_fraction must be <= 1.0, got {min_util_fraction}"
LOGGER.info(
f"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%..."
)
if count <= 0:
return []
self.refresh_stats()
if not self.gpu_stats:
LOGGER.warning("NVML stats unavailable.")
return []
# Filter and sort eligible GPUs
eligible_gpus = [
gpu
for gpu in self.gpu_stats
if gpu.get("memory_free", 0) / gpu.get("memory_total", 1) >= min_memory_fraction
and (100 - gpu.get("utilization", 100)) >= min_util_fraction * 100
]
eligible_gpus.sort(key=lambda x: (x.get("utilization", 101), -x.get("memory_free", 0)))
# Select top 'count' indices
selected = [gpu["index"] for gpu in eligible_gpus[:count]]
if selected:
LOGGER.info(f"Selected idle CUDA devices {selected}")
else:
LOGGER.warning(
f"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%)."
)
return selected
if __name__ == "__main__":
required_free_mem_fraction = 0.2 # Require 20% free VRAM
required_free_util_fraction = 0.2 # Require 20% free utilization
num_gpus_to_select = 1
gpu_info = GPUInfo()
gpu_info.print_status()
if selected := gpu_info.select_idle_gpu(
count=num_gpus_to_select,
min_memory_fraction=required_free_mem_fraction,
min_util_fraction=required_free_util_fraction,
):
print(f"\n==> Using selected GPU indices: {selected}")
devices = [f"cuda:{idx}" for idx in selected]
print(f" Target devices: {devices}")
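
A short sketch tying GPU selection into training, assuming the ultralytics package is installed (YOLO.train accepts a list of CUDA indices as `device`; the module path is assumed from this diff):

from ultralytics import YOLO
from ultralytics.utils.autodevice import GPUInfo  # module path assumed from this diff

info = GPUInfo()
info.print_status()
idle = info.select_idle_gpu(count=2, min_memory_fraction=0.2, min_util_fraction=0.2)
if idle:
    YOLO("yolo11n.pt").train(data="coco8.yaml", epochs=1, device=idle)  # list of CUDA indices
info.shutdown()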

View File

@@ -0,0 +1,728 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Benchmark a YOLO model across different formats for speed and accuracy.
Usage:
from ultralytics.utils.benchmarks import ProfileModels, benchmark
ProfileModels(['yolo11n.yaml', 'yolov8s.yaml']).run()
benchmark(model='yolo11n.pt', imgsz=160)
Format | `format=argument` | Model
--- | --- | ---
PyTorch | - | yolo11n.pt
TorchScript | `torchscript` | yolo11n.torchscript
ONNX | `onnx` | yolo11n.onnx
OpenVINO | `openvino` | yolo11n_openvino_model/
TensorRT | `engine` | yolo11n.engine
CoreML | `coreml` | yolo11n.mlpackage
TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
TensorFlow GraphDef | `pb` | yolo11n.pb
TensorFlow Lite | `tflite` | yolo11n.tflite
TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolo11n_web_model/
PaddlePaddle | `paddle` | yolo11n_paddle_model/
MNN | `mnn` | yolo11n.mnn
NCNN | `ncnn` | yolo11n_ncnn_model/
IMX | `imx` | yolo11n_imx_model/
RKNN | `rknn` | yolo11n_rknn_model/
"""
from __future__ import annotations
import glob
import os
import platform
import re
import shutil
import time
from pathlib import Path
import numpy as np
import torch.cuda
from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML
from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip
from ultralytics.utils.downloads import safe_download
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import get_cpu_info, select_device
def benchmark(
model=WEIGHTS_DIR / "yolo11n.pt",
data=None,
imgsz=160,
half=False,
int8=False,
device="cpu",
verbose=False,
eps=1e-3,
format="",
**kwargs,
):
"""
Benchmark a YOLO model across different formats for speed and accuracy.
Args:
model (str | Path): Path to the model file or directory.
data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.
imgsz (int): Image size for the benchmark.
half (bool): Use half-precision for the model if True.
int8 (bool): Use int8-precision for the model if True.
device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.
verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
eps (float): Epsilon value for divide by zero prevention.
format (str): Export format for benchmarking. If not supplied all formats are benchmarked.
**kwargs (Any): Additional keyword arguments for exporter.
Returns:
(polars.DataFrame): A polars DataFrame with benchmark results for each format, including file size, metric,
and inference time.
Examples:
Benchmark a YOLO model with default settings:
>>> from ultralytics.utils.benchmarks import benchmark
>>> benchmark(model="yolo11n.pt", imgsz=640)
"""
imgsz = check_imgsz(imgsz)
assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."
import polars as pl # scope for faster 'import ultralytics'
pl.Config.set_tbl_cols(-1) # Show all columns
pl.Config.set_tbl_rows(-1) # Show all rows
pl.Config.set_tbl_width_chars(-1) # No width limit
pl.Config.set_tbl_hide_column_data_types(True) # Hide data types
pl.Config.set_tbl_hide_dataframe_shape(True) # Hide shape info
pl.Config.set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED")
device = select_device(device, verbose=False)
if isinstance(model, (str, Path)):
model = YOLO(model)
is_end2end = getattr(model.model.model[-1], "end2end", False)
data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
y = []
t0 = time.time()
format_arg = format.lower()
if format_arg:
formats = frozenset(export_formats()["Argument"])
        assert format_arg in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
emoji, filename = "", None # export defaults
try:
if format_arg and format_arg != format:
continue
# Checks
if format == "pb":
assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
elif format == "edgetpu":
assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
elif format in {"coreml", "tfjs"}:
assert MACOS or (LINUX and not ARM64), (
"CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
)
if format == "coreml":
assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
# assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
if format == "paddle":
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
if format == "mnn":
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
if format == "ncnn":
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
if format == "imx":
assert not is_end2end
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
assert model.task == "detect", "IMX only supported for detection task"
assert "C2f" in model.__str__(), "IMX only supported for YOLOv8n and YOLO11n"
if format == "rknn":
assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
assert not is_end2end, "End-to-end models not supported by RKNN yet"
assert LINUX, "RKNN only supported on Linux"
assert not is_rockchip(), "RKNN Inference only supported on Rockchip devices"
if "cpu" in device.type:
assert cpu, "inference not supported on CPU"
if "cuda" in device.type:
assert gpu, "inference not supported on GPU"
# Export
if format == "-":
filename = model.pt_path or model.ckpt_path or model.model_name
exported_model = model # PyTorch format
else:
filename = model.export(
imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs
)
exported_model = YOLO(filename, task=model.task)
assert suffix in str(filename), "export failed"
            emoji = "❎"  # indicates export succeeded
# Predict
assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
assert format not in {"edgetpu", "tfjs"}, "inference not supported"
assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
if format == "ncnn":
assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)
# Validate
results = exported_model.val(
data=data,
batch=1,
imgsz=imgsz,
plots=False,
device=device,
half=half,
int8=int8,
verbose=False,
conf=0.001, # all the pre-set benchmark mAP values are based on conf=0.001
)
metric, speed = results.results_dict[key], results.speed["inference"]
fps = round(1000 / (speed + eps), 2) # frames per second
            y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])
except Exception as e:
if verbose:
assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
LOGGER.error(f"Benchmark failure for {name}: {e}")
y.append([name, emoji, round(file_size(filename), 1), None, None, None]) # mAP, t_inference
# Print results
check_yolo(device=device) # print system info
df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"], orient="row")
df = df.with_row_index(" ", offset=1) # add index info
df_display = df.with_columns(pl.all().cast(pl.String).fill_null("-"))
name = model.model_name
dt = time.time() - t0
legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df_display}\n"
LOGGER.info(s)
with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
f.write(s)
if verbose and isinstance(verbose, float):
metrics = df[key].to_numpy() # values to compare to floor
floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"
return df_display
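# Example (sketch): the `format` argument documented above limits the loop to a single
# export format instead of benchmarking all of them (illustrative only).
# >>> benchmark(model="yolo11n.pt", imgsz=160, format="onnx", device="cpu")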
class RF100Benchmark:
"""
    Benchmark YOLO model performance on the Roboflow 100 (RF100) dataset collection.
This class provides functionality to benchmark YOLO models on the RF100 dataset collection.
Attributes:
ds_names (list[str]): Names of datasets used for benchmarking.
ds_cfg_list (list[Path]): List of paths to dataset configuration files.
rf (Roboflow): Roboflow instance for accessing datasets.
val_metrics (list[str]): Metrics used for validation.
Methods:
set_key: Set Roboflow API key for accessing datasets.
parse_dataset: Parse dataset links and download datasets.
fix_yaml: Fix train and validation paths in YAML files.
evaluate: Evaluate model performance on validation results.
"""
def __init__(self):
"""Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats."""
self.ds_names = []
self.ds_cfg_list = []
self.rf = None
self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]
def set_key(self, api_key: str):
"""
Set Roboflow API key for processing.
Args:
api_key (str): The API key.
Examples:
Set the Roboflow API key for accessing datasets:
>>> benchmark = RF100Benchmark()
>>> benchmark.set_key("your_roboflow_api_key")
"""
check_requirements("roboflow")
from roboflow import Roboflow
self.rf = Roboflow(api_key=api_key)
def parse_dataset(self, ds_link_txt: str = "datasets_links.txt"):
"""
Parse dataset links and download datasets.
Args:
ds_link_txt (str): Path to the file containing dataset links.
Returns:
ds_names (list[str]): List of dataset names.
ds_cfg_list (list[Path]): List of paths to dataset configuration files.
Examples:
>>> benchmark = RF100Benchmark()
>>> benchmark.set_key("api_key")
>>> benchmark.parse_dataset("datasets_links.txt")
"""
        if os.path.exists("rf-100"):
            shutil.rmtree("rf-100")
        os.mkdir("rf-100")
os.chdir("rf-100")
os.mkdir("ultralytics-benchmarks")
safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")
with open(ds_link_txt, encoding="utf-8") as file:
for line in file:
try:
_, url, workspace, project, version = re.split("/+", line.strip())
self.ds_names.append(project)
proj_version = f"{project}-{version}"
if not Path(proj_version).exists():
self.rf.workspace(workspace).project(project).version(version).download("yolov8")
else:
LOGGER.info("Dataset already downloaded.")
self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
except Exception:
continue
return self.ds_names, self.ds_cfg_list
@staticmethod
def fix_yaml(path: Path):
"""Fix the train and validation paths in a given YAML file."""
yaml_data = YAML.load(path)
yaml_data["train"] = "train/images"
yaml_data["val"] = "valid/images"
YAML.dump(yaml_data, path)
def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int):
"""
Evaluate model performance on validation results.
Args:
yaml_path (str): Path to the YAML configuration file.
val_log_file (str): Path to the validation log file.
eval_log_file (str): Path to the evaluation log file.
list_ind (int): Index of the current dataset in the list.
Returns:
(float): The mean average precision (mAP) value for the evaluated model.
Examples:
Evaluate a model on a specific dataset
>>> benchmark = RF100Benchmark()
>>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0)
"""
        skip_symbols = ["🚀", "⚠️", "💡", "❌"]
class_names = YAML.load(yaml_path)["names"]
with open(val_log_file, encoding="utf-8") as f:
lines = f.readlines()
eval_lines = []
for line in lines:
if any(symbol in line for symbol in skip_symbols):
continue
entries = line.split(" ")
entries = list(filter(lambda val: val != "", entries))
entries = [e.strip("\n") for e in entries]
eval_lines.extend(
{
"class": entries[0],
"images": entries[1],
"targets": entries[2],
"precision": entries[3],
"recall": entries[4],
"map50": entries[5],
"map95": entries[6],
}
for e in entries
if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
)
map_val = 0.0
if len(eval_lines) > 1:
LOGGER.info("Multiple dicts found")
for lst in eval_lines:
if lst["class"] == "all":
map_val = lst["map50"]
else:
LOGGER.info("Single dict found")
map_val = [res["map50"] for res in eval_lines][0]
with open(eval_log_file, "a", encoding="utf-8") as f:
f.write(f"{self.ds_names[list_ind]}: {map_val}\n")
return float(map_val)
class ProfileModels:
"""
ProfileModels class for profiling different models on ONNX and TensorRT.
This class profiles the performance of different models, returning results such as model speed and FLOPs.
Attributes:
paths (list[str]): Paths of the models to profile.
num_timed_runs (int): Number of timed runs for the profiling.
num_warmup_runs (int): Number of warmup runs before profiling.
min_time (float): Minimum number of seconds to profile for.
imgsz (int): Image size used in the models.
half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
trt (bool): Flag to indicate whether to profile using TensorRT.
device (torch.device): Device used for profiling.
Methods:
run: Profile YOLO models for speed and accuracy across various formats.
get_files: Get all relevant model files.
get_onnx_model_info: Extract metadata from an ONNX model.
iterative_sigma_clipping: Apply sigma clipping to remove outliers.
profile_tensorrt_model: Profile a TensorRT model.
profile_onnx_model: Profile an ONNX model.
generate_table_row: Generate a table row with model metrics.
generate_results_dict: Generate a dictionary of profiling results.
print_table: Print a formatted table of results.
Examples:
Profile models and print results
>>> from ultralytics.utils.benchmarks import ProfileModels
>>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
>>> profiler.run()
"""
def __init__(
self,
paths: list[str],
num_timed_runs: int = 100,
num_warmup_runs: int = 10,
min_time: float = 60,
imgsz: int = 640,
half: bool = True,
trt: bool = True,
device: torch.device | str | None = None,
):
"""
Initialize the ProfileModels class for profiling models.
Args:
paths (list[str]): List of paths of the models to be profiled.
num_timed_runs (int): Number of timed runs for the profiling.
num_warmup_runs (int): Number of warmup runs before the actual profiling starts.
min_time (float): Minimum time in seconds for profiling a model.
imgsz (int): Size of the image used during profiling.
half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
trt (bool): Flag to indicate whether to profile using TensorRT.
device (torch.device | str | None): Device used for profiling. If None, it is determined automatically.
Notes:
            The FP16 'half' option is not applied to ONNX profiling because FP16 is slower than FP32 on CPU.
Examples:
Initialize and profile models
>>> from ultralytics.utils.benchmarks import ProfileModels
>>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
>>> profiler.run()
"""
self.paths = paths
self.num_timed_runs = num_timed_runs
self.num_warmup_runs = num_warmup_runs
self.min_time = min_time
self.imgsz = imgsz
self.half = half
self.trt = trt # run TensorRT profiling
self.device = device if isinstance(device, torch.device) else select_device(device)
def run(self):
"""
Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.
Returns:
(list[dict]): List of dictionaries containing profiling results for each model.
Examples:
Profile models and print results
>>> from ultralytics.utils.benchmarks import ProfileModels
>>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"])
>>> results = profiler.run()
"""
files = self.get_files()
if not files:
LOGGER.warning("No matching *.pt or *.onnx files found.")
return []
table_rows = []
output = []
for file in files:
engine_file = file.with_suffix(".engine")
if file.suffix in {".pt", ".yaml", ".yml"}:
model = YOLO(str(file))
model.fuse() # to report correct params and GFLOPs in model.info()
model_info = model.info()
if self.trt and self.device.type != "cpu" and not engine_file.is_file():
engine_file = model.export(
format="engine",
half=self.half,
imgsz=self.imgsz,
device=self.device,
verbose=False,
)
onnx_file = model.export(
format="onnx",
imgsz=self.imgsz,
device=self.device,
verbose=False,
)
elif file.suffix == ".onnx":
model_info = self.get_onnx_model_info(file)
onnx_file = file
else:
continue
t_engine = self.profile_tensorrt_model(str(engine_file))
t_onnx = self.profile_onnx_model(str(onnx_file))
table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))
self.print_table(table_rows)
return output
def get_files(self):
"""
Return a list of paths for all relevant model files given by the user.
Returns:
(list[Path]): List of Path objects for the model files.
"""
files = []
for path in self.paths:
path = Path(path)
if path.is_dir():
extensions = ["*.pt", "*.onnx", "*.yaml"]
files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing
files.append(str(path))
else:
files.extend(glob.glob(str(path)))
LOGGER.info(f"Profiling: {sorted(files)}")
return [Path(file) for file in sorted(files)]
@staticmethod
def get_onnx_model_info(onnx_file: str):
"""Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
@staticmethod
def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3):
"""
Apply iterative sigma clipping to data to remove outliers.
Args:
data (np.ndarray): Input data array.
sigma (float): Number of standard deviations to use for clipping.
max_iters (int): Maximum number of iterations for the clipping process.
Returns:
(np.ndarray): Clipped data array with outliers removed.
"""
data = np.array(data)
for _ in range(max_iters):
mean, std = np.mean(data), np.std(data)
clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
if len(clipped_data) == len(data):
break
data = clipped_data
return data
def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
"""
Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.
Args:
engine_file (str): Path to the TensorRT engine file.
eps (float): Small epsilon value to prevent division by zero.
Returns:
mean_time (float): Mean inference time in milliseconds.
std_time (float): Standard deviation of inference time in milliseconds.
"""
if not self.trt or not Path(engine_file).is_file():
return 0.0, 0.0
# Model and input
model = YOLO(engine_file)
input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8) # use uint8 for Classify
# Warmup runs
elapsed = 0.0
for _ in range(3):
start_time = time.time()
for _ in range(self.num_warmup_runs):
model(input_data, imgsz=self.imgsz, verbose=False)
elapsed = time.time() - start_time
# Compute number of runs as higher of min_time or num_timed_runs
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)
# Timed runs
run_times = []
for _ in TQDM(range(num_runs), desc=engine_file):
results = model(input_data, imgsz=self.imgsz, verbose=False)
            run_times.append(results[0].speed["inference"])  # inference time, already in milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
return np.mean(run_times), np.std(run_times)
def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
"""
Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.
Args:
onnx_file (str): Path to the ONNX model file.
eps (float): Small epsilon value to prevent division by zero.
Returns:
mean_time (float): Mean inference time in milliseconds.
std_time (float): Standard deviation of inference time in milliseconds.
"""
check_requirements("onnxruntime")
import onnxruntime as ort
# Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
sess_options.intra_op_num_threads = 8 # Limit the number of threads
sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
input_tensor = sess.get_inputs()[0]
input_type = input_tensor.type
dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape
input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape
# Mapping ONNX datatype to numpy datatype
if "float16" in input_type:
input_dtype = np.float16
elif "float" in input_type:
input_dtype = np.float32
elif "double" in input_type:
input_dtype = np.float64
elif "int64" in input_type:
input_dtype = np.int64
elif "int32" in input_type:
input_dtype = np.int32
else:
raise ValueError(f"Unsupported ONNX datatype {input_type}")
input_data = np.random.rand(*input_shape).astype(input_dtype)
input_name = input_tensor.name
output_name = sess.get_outputs()[0].name
# Warmup runs
elapsed = 0.0
for _ in range(3):
start_time = time.time()
for _ in range(self.num_warmup_runs):
sess.run([output_name], {input_name: input_data})
elapsed = time.time() - start_time
# Compute number of runs as higher of min_time or num_timed_runs
num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)
# Timed runs
run_times = []
for _ in TQDM(range(num_runs), desc=onnx_file):
start_time = time.time()
sess.run([output_name], {input_name: input_data})
run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds
run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping
return np.mean(run_times), np.std(run_times)
def generate_table_row(
self,
model_name: str,
t_onnx: tuple[float, float],
t_engine: tuple[float, float],
model_info: tuple[float, float, float, float],
):
"""
Generate a table row string with model performance metrics.
Args:
model_name (str): Name of the model.
t_onnx (tuple): ONNX model inference time statistics (mean, std).
t_engine (tuple): TensorRT engine inference time statistics (mean, std).
model_info (tuple): Model information (layers, params, gradients, flops).
Returns:
(str): Formatted table row string with model metrics.
"""
layers, params, gradients, flops = model_info
return (
f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
)
@staticmethod
def generate_results_dict(
model_name: str,
t_onnx: tuple[float, float],
t_engine: tuple[float, float],
model_info: tuple[float, float, float, float],
):
"""
Generate a dictionary of profiling results.
Args:
model_name (str): Name of the model.
t_onnx (tuple): ONNX model inference time statistics (mean, std).
t_engine (tuple): TensorRT engine inference time statistics (mean, std).
model_info (tuple): Model information (layers, params, gradients, flops).
Returns:
(dict): Dictionary containing profiling results.
"""
layers, params, gradients, flops = model_info
return {
"model/name": model_name,
"model/parameters": params,
"model/GFLOPs": round(flops, 3),
"model/speed_ONNX(ms)": round(t_onnx[0], 3),
"model/speed_TensorRT(ms)": round(t_engine[0], 3),
}
@staticmethod
def print_table(table_rows: list[str]):
"""
Print a formatted table of model profiling results.
Args:
table_rows (list[str]): List of formatted table row strings.
"""
gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
headers = [
"Model",
"size<br><sup>(pixels)",
"mAP<sup>val<br>50-95",
f"Speed<br><sup>CPU ({get_cpu_info()}) ONNX<br>(ms)",
f"Speed<br><sup>{gpu} TensorRT<br>(ms)",
"params<br><sup>(M)",
"FLOPs<br><sup>(B)",
]
header = "|" + "|".join(f" {h} " for h in headers) + "|"
separator = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|"
LOGGER.info(f"\n\n{header}")
LOGGER.info(separator)
for row in table_rows:
LOGGER.info(row)
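
A combined usage sketch for the benchmarking utilities above (weights and sizes are illustrative; TensorRT profiling is skipped automatically on CPU-only machines):

from ultralytics.utils.benchmarks import ProfileModels, benchmark

# Quick smoke-test benchmark of every export format at a small image size
df = benchmark(model="yolo11n.pt", imgsz=160, device="cpu")

# ONNX (and TensorRT, when CUDA is available) latency profiling
results = ProfileModels(["yolo11n.pt"], imgsz=320, num_timed_runs=50).run()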

View File

@@ -0,0 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"

View File

@@ -0,0 +1,235 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Base callbacks for Ultralytics training, validation, prediction, and export processes."""
from collections import defaultdict
from copy import deepcopy
# Trainer callbacks ----------------------------------------------------------------------------------------------------
def on_pretrain_routine_start(trainer):
"""Called before the pretraining routine starts."""
pass
def on_pretrain_routine_end(trainer):
"""Called after the pretraining routine ends."""
pass
def on_train_start(trainer):
"""Called when the training starts."""
pass
def on_train_epoch_start(trainer):
"""Called at the start of each training epoch."""
pass
def on_train_batch_start(trainer):
"""Called at the start of each training batch."""
pass
def optimizer_step(trainer):
"""Called when the optimizer takes a step."""
pass
def on_before_zero_grad(trainer):
"""Called before the gradients are set to zero."""
pass
def on_train_batch_end(trainer):
"""Called at the end of each training batch."""
pass
def on_train_epoch_end(trainer):
"""Called at the end of each training epoch."""
pass
def on_fit_epoch_end(trainer):
"""Called at the end of each fit epoch (train + val)."""
pass
def on_model_save(trainer):
"""Called when the model is saved."""
pass
def on_train_end(trainer):
"""Called when the training ends."""
pass
def on_params_update(trainer):
"""Called when the model parameters are updated."""
pass
def teardown(trainer):
"""Called during the teardown of the training process."""
pass
# Validator callbacks --------------------------------------------------------------------------------------------------
def on_val_start(validator):
"""Called when the validation starts."""
pass
def on_val_batch_start(validator):
"""Called at the start of each validation batch."""
pass
def on_val_batch_end(validator):
"""Called at the end of each validation batch."""
pass
def on_val_end(validator):
"""Called when the validation ends."""
pass
# Predictor callbacks --------------------------------------------------------------------------------------------------
def on_predict_start(predictor):
"""Called when the prediction starts."""
pass
def on_predict_batch_start(predictor):
"""Called at the start of each prediction batch."""
pass
def on_predict_batch_end(predictor):
"""Called at the end of each prediction batch."""
pass
def on_predict_postprocess_end(predictor):
"""Called after the post-processing of the prediction ends."""
pass
def on_predict_end(predictor):
"""Called when the prediction ends."""
pass
# Exporter callbacks ---------------------------------------------------------------------------------------------------
def on_export_start(exporter):
"""Called when the model export starts."""
pass
def on_export_end(exporter):
"""Called when the model export ends."""
pass
default_callbacks = {
# Run in trainer
"on_pretrain_routine_start": [on_pretrain_routine_start],
"on_pretrain_routine_end": [on_pretrain_routine_end],
"on_train_start": [on_train_start],
"on_train_epoch_start": [on_train_epoch_start],
"on_train_batch_start": [on_train_batch_start],
"optimizer_step": [optimizer_step],
"on_before_zero_grad": [on_before_zero_grad],
"on_train_batch_end": [on_train_batch_end],
"on_train_epoch_end": [on_train_epoch_end],
"on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
"on_model_save": [on_model_save],
"on_train_end": [on_train_end],
"on_params_update": [on_params_update],
"teardown": [teardown],
# Run in validator
"on_val_start": [on_val_start],
"on_val_batch_start": [on_val_batch_start],
"on_val_batch_end": [on_val_batch_end],
"on_val_end": [on_val_end],
# Run in predictor
"on_predict_start": [on_predict_start],
"on_predict_batch_start": [on_predict_batch_start],
"on_predict_postprocess_end": [on_predict_postprocess_end],
"on_predict_batch_end": [on_predict_batch_end],
"on_predict_end": [on_predict_end],
# Run in exporter
"on_export_start": [on_export_start],
"on_export_end": [on_export_end],
}
def get_default_callbacks():
"""
Get the default callbacks for Ultralytics training, validation, prediction, and export processes.
Returns:
(dict): Dictionary of default callbacks for various training events. Each key represents an event during the
training process, and the corresponding value is a list of callback functions executed when that event
occurs.
Examples:
>>> callbacks = get_default_callbacks()
>>> print(list(callbacks.keys())) # show all available callback events
['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]
"""
return defaultdict(list, deepcopy(default_callbacks))
def add_integration_callbacks(instance):
"""
Add integration callbacks to the instance's callbacks dictionary.
This function loads and adds various integration callbacks to the provided instance. The specific callbacks added
depend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive
additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,
and Weights & Biases.
Args:
instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.
The type of instance determines which callbacks are loaded.
Examples:
>>> from ultralytics.engine.trainer import BaseTrainer
>>> trainer = BaseTrainer()
>>> add_integration_callbacks(trainer)
"""
from .hub import callbacks as hub_cb
from .platform import callbacks as platform_cb
# Load Ultralytics callbacks
callbacks_list = [hub_cb, platform_cb]
# Load training callbacks
if "Trainer" in instance.__class__.__name__:
from .clearml import callbacks as clear_cb
from .comet import callbacks as comet_cb
from .dvc import callbacks as dvc_cb
from .mlflow import callbacks as mlflow_cb
from .neptune import callbacks as neptune_cb
from .raytune import callbacks as tune_cb
from .tensorboard import callbacks as tb_cb
from .wb import callbacks as wb_cb
callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
# Add the callbacks to the callbacks dictionary
for callbacks in callbacks_list:
for k, v in callbacks.items():
if v not in instance.callbacks[k]:
instance.callbacks[k].append(v)
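
A sketch of hooking a custom function into these events through the public model API; add_callback appends to the same per-event lists defined above (weights and dataset names are illustrative):

from ultralytics import YOLO

def log_epoch(trainer):
    """Custom callback: report the finished epoch index after every training epoch."""
    print(f"Finished epoch {trainer.epoch}")

model = YOLO("yolo11n.pt")
model.add_callback("on_train_epoch_end", log_epoch)  # appended alongside the defaults above
model.train(data="coco8.yaml", epochs=3)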

View File

@@ -0,0 +1,154 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["clearml"] is True # verify integration is enabled
import clearml
from clearml import Task
assert hasattr(clearml, "__version__") # verify package is not directory
except (ImportError, AssertionError):
clearml = None
def _log_debug_samples(files, title: str = "Debug Samples") -> None:
"""
Log files (images) as debug samples in the ClearML task.
Args:
files (list[Path]): A list of file paths in PosixPath format.
title (str): A title that groups together images with the same values.
"""
import re
if task := Task.current_task():
for f in files:
if f.exists():
it = re.search(r"_batch(\d+)", f.name)
iteration = int(it.groups()[0]) if it else 0
task.get_logger().report_image(
title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
)
def _log_plot(title: str, plot_path: str) -> None:
"""
Log an image as a plot in the plot section of ClearML.
Args:
title (str): The title of the plot.
plot_path (str): The path to the saved image file.
"""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
Task.current_task().get_logger().report_matplotlib_figure(
title=title, series="", figure=fig, report_interactive=False
)
def on_pretrain_routine_start(trainer) -> None:
"""Initialize and connect ClearML task at the start of pretraining routine."""
try:
if task := Task.current_task():
# WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
# We are logging these plots and model files manually in the integration
from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
from clearml.binding.matplotlib_bind import PatchedMatplotlib
PatchPyTorchModelIO.update_current_task(None)
PatchedMatplotlib.update_current_task(None)
else:
task = Task.init(
project_name=trainer.args.project or "Ultralytics",
task_name=trainer.args.name,
tags=["Ultralytics"],
output_uri=True,
reuse_last_task_id=False,
auto_connect_frameworks={"pytorch": False, "matplotlib": False},
)
LOGGER.warning(
"ClearML Initialized a new task. If you want to run remotely, "
"please add clearml-init and connect your arguments before initializing YOLO."
)
task.connect(vars(trainer.args), name="General")
except Exception as e:
LOGGER.warning(f"ClearML installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer) -> None:
"""Log debug samples for the first epoch and report current training progress."""
if task := Task.current_task():
# Log debug samples for first epoch only
if trainer.epoch == 1:
_log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
# Report the current training progress
for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
for k, v in trainer.lr.items():
task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
def on_fit_epoch_end(trainer) -> None:
"""Report model information and metrics to logger at the end of an epoch."""
if task := Task.current_task():
# Report epoch time and validation metrics
task.get_logger().report_scalar(
title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
)
for k, v in trainer.metrics.items():
title = k.split("/")[0]
task.get_logger().report_scalar(title, k, v, iteration=trainer.epoch)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for k, v in model_info_for_loggers(trainer).items():
task.get_logger().report_single_value(k, v)
def on_val_end(validator) -> None:
"""Log validation results including labels and predictions."""
if Task.current_task():
# Log validation labels and predictions
_log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
def on_train_end(trainer) -> None:
"""Log final model and training results on training completion."""
if task := Task.current_task():
# Log final results, confusion matrix and PR plots
files = [
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter existing files
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Report final metrics
for k, v in trainer.validator.metrics.results_dict.items():
task.get_logger().report_single_value(k, v)
# Log the final model
task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if clearml
else {}
)
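
Since this module only registers its callbacks when SETTINGS["clearml"] is truthy, a sketch of enabling the integration via the ultralytics settings object (weights and dataset names are illustrative):

from ultralytics import YOLO, settings

settings.update({"clearml": True})  # persisted; lets the guard at the top of this module pass
YOLO("yolo11n.pt").train(data="coco8.yaml", epochs=1)  # training now logs to a ClearML task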

View File

@@ -0,0 +1,639 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from collections.abc import Callable
from types import SimpleNamespace
from typing import Any
import cv2
import numpy as np
from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["comet"] is True # verify integration is enabled
import comet_ml
assert hasattr(comet_ml, "__version__") # verify package is not directory
import os
from pathlib import Path
# Ensures certain logging functions only run for supported tasks
COMET_SUPPORTED_TASKS = ["detect", "segment"]
# Names of plots created by Ultralytics that are logged to Comet
CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
LABEL_PLOT_NAMES = ["labels"]
SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
DETECTION_METRICS_PLOT_PREFIX = ["Box"]
RESULTS_TABLE_NAME = "results.csv"
ARGS_YAML_NAME = "args.yaml"
_comet_image_prediction_count = 0
except (ImportError, AssertionError):
comet_ml = None
def _get_comet_mode() -> str:
"""Return the Comet mode from environment variables, defaulting to 'online'."""
comet_mode = os.getenv("COMET_MODE")
if comet_mode is not None:
LOGGER.warning(
"The COMET_MODE environment variable is deprecated. "
"Please use COMET_START_ONLINE to set the Comet experiment mode. "
"To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. "
"If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created."
)
return comet_mode
return "online"
def _get_comet_model_name() -> str:
"""Return the Comet model name from environment variable or default to 'Ultralytics'."""
return os.getenv("COMET_MODEL_NAME", "Ultralytics")
def _get_eval_batch_logging_interval() -> int:
"""Get the evaluation batch logging interval from environment variable or use default value 1."""
return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
def _get_max_image_predictions_to_log() -> int:
"""Get the maximum number of image predictions to log from environment variables."""
return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
def _scale_confidence_score(score: float) -> float:
"""Scale the confidence score by a factor specified in environment variable."""
scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
return score * scale
def _should_log_confusion_matrix() -> bool:
"""Determine if the confusion matrix should be logged based on environment variable settings."""
return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
def _should_log_image_predictions() -> bool:
"""Determine whether to log image predictions based on environment variable."""
return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
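# Example (sketch): the environment variables read by the helpers above can be set
# before training to tune Comet logging (variable names are taken from this module).
# >>> import os
# >>> os.environ["COMET_EVAL_BATCH_LOGGING_INTERVAL"] = "2"
# >>> os.environ["COMET_MAX_IMAGE_PREDICTIONS"] = "50"
# >>> os.environ["COMET_EVAL_LOG_CONFUSION_MATRIX"] = "true"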
def _resume_or_create_experiment(args: SimpleNamespace) -> None:
"""
Resume CometML experiment or create a new experiment based on args.
Ensures that the experiment object is only created in a single process during distributed training.
Args:
args (SimpleNamespace): Training arguments containing project configuration and other parameters.
"""
if RANK not in {-1, 0}:
return
# Set environment variable (if not set by the user) to configure the Comet experiment's online mode under the hood.
    # If COMET_START_ONLINE is set by the user, it will override the COMET_MODE value.
if os.getenv("COMET_START_ONLINE") is None:
comet_mode = _get_comet_mode()
os.environ["COMET_START_ONLINE"] = "1" if comet_mode != "offline" else "0"
try:
_project_name = os.getenv("COMET_PROJECT_NAME", args.project)
experiment = comet_ml.start(project_name=_project_name)
experiment.log_parameters(vars(args))
experiment.log_others(
{
"eval_batch_logging_interval": _get_eval_batch_logging_interval(),
"log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
"log_image_predictions": _should_log_image_predictions(),
"max_image_predictions": _get_max_image_predictions_to_log(),
}
)
experiment.log_other("Created from", "ultralytics")
except Exception as e:
LOGGER.warning(f"Comet installed but not initialized correctly, not logging this run. {e}")
def _fetch_trainer_metadata(trainer) -> dict:
"""
Return metadata for YOLO training including epoch and asset saving status.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.
Returns:
(dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.
"""
curr_epoch = trainer.epoch + 1
train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
curr_step = curr_epoch * train_num_steps_per_epoch
final_epoch = curr_epoch == trainer.epochs
save = trainer.args.save
save_period = trainer.args.save_period
save_interval = curr_epoch % save_period == 0
save_assets = save and save_period > 0 and save_interval and not final_epoch
return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
def _scale_bounding_box_to_original_image_shape(
box, resized_image_shape, original_image_shape, ratio_pad
) -> list[float]:
"""
Scale bounding box from resized image coordinates to original image coordinates.
YOLO resizes images during training and the label values are normalized based on this resized shape.
This function rescales the bounding box labels to the original image shape.
Args:
box (torch.Tensor): Bounding box in normalized xywh format.
resized_image_shape (tuple): Shape of the resized image (height, width).
original_image_shape (tuple): Shape of the original image (height, width).
ratio_pad (tuple): Ratio and padding information for scaling.
Returns:
(list[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.
"""
resized_image_height, resized_image_width = resized_image_shape
# Convert normalized xywh format predictions to xyxy in resized scale format
box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
# Scale box predictions from resized image scale back to original image scale
box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
# Convert bounding box format from xyxy to xywh for Comet logging
box = ops.xyxy2xywh(box)
# Adjust xy center to correspond top-left corner
box[:2] -= box[2:] / 2
box = box.tolist()
return box
def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> dict | None:
"""
Format ground truth annotations for object detection.
This function processes ground truth annotations from a batch of images for object detection tasks. It extracts
bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for
visualization or evaluation.
Args:
img_idx (int): Index of the image in the batch to process.
image_path (str | Path): Path to the image file.
batch (dict): Batch dictionary containing detection data with keys:
- 'batch_idx': Tensor of batch indices
- 'bboxes': Tensor of bounding boxes in normalized xywh format
- 'cls': Tensor of class labels
- 'ori_shape': Original image shapes
- 'resized_shape': Resized image shapes
- 'ratio_pad': Ratio and padding information
class_name_map (dict, optional): Mapping from class indices to class names.
Returns:
(dict | None): Formatted ground truth annotations with the following structure:
- 'boxes': List of box coordinates [x, y, width, height]
- 'label': Label string with format "gt_{class_name}"
- 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)
Returns None if no bounding boxes are found for the image.
"""
indices = batch["batch_idx"] == img_idx
bboxes = batch["bboxes"][indices]
if len(bboxes) == 0:
        LOGGER.debug(f"Comet Image: {image_path} has no bounding box labels")
return None
cls_labels = batch["cls"][indices].squeeze(1).tolist()
if class_name_map:
cls_labels = [str(class_name_map[label]) for label in cls_labels]
original_image_shape = batch["ori_shape"][img_idx]
resized_image_shape = batch["resized_shape"][img_idx]
ratio_pad = batch["ratio_pad"][img_idx]
data = []
for box, label in zip(bboxes, cls_labels):
box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
data.append(
{
"boxes": [box],
"label": f"gt_{label}",
"score": _scale_confidence_score(1.0),
}
)
return {"name": "ground_truth", "data": data}
def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> dict | None:
"""
Format YOLO predictions for object detection visualization.
Args:
image_path (Path): Path to the image file.
metadata (dict): Prediction metadata containing bounding boxes and class information.
class_label_map (dict, optional): Mapping from class indices to class names.
class_map (dict, optional): Additional class mapping for label conversion.
Returns:
(dict | None): Formatted prediction annotations or None if no predictions exist.
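Examples:
Illustrative only; metadata is a hypothetical map built by _create_prediction_metadata_map
>>> annotations = _format_prediction_annotations(Path("000001.jpg"), metadata, class_label_map)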
"""
stem = image_path.stem
image_id = int(stem) if stem.isnumeric() else stem
predictions = metadata.get(image_id)
if not predictions:
LOGGER.debug(f"Comet Image: {image_path} has no bounding boxes predictions")
return None
# apply the mapping that was used to map the predicted classes when the JSON was created
if class_label_map and class_map:
class_label_map = {class_map[k]: v for k, v in class_label_map.items()}
try:
# import faster-coco-eval mask utilities to decode compressed annotations for various tasks, e.g. segmentation
from faster_coco_eval.core.mask import decode # noqa
except ImportError:
decode = None
data = []
for prediction in predictions:
boxes = prediction["bbox"]
score = _scale_confidence_score(prediction["score"])
cls_label = prediction["category_id"]
if class_label_map:
cls_label = str(class_label_map[cls_label])
annotation_data = {"boxes": [boxes], "label": cls_label, "score": score}
if decode is not None:
# do segmentation processing only if we are able to decode it
segments = prediction.get("segmentation", None)
if segments is not None:
segments = _extract_segmentation_annotation(segments, decode)
if segments is not None:
annotation_data["points"] = segments
data.append(annotation_data)
return {"name": "prediction", "data": data}
def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> list[list[Any]] | None:
"""
Extract segmentation annotation from compressed segmentations as list of polygons.
Args:
segmentation_raw (str): Raw segmentation data in compressed format.
decode (Callable): Function to decode the compressed segmentation data.
Returns:
(list[list[Any]] | None): List of polygon points or None if extraction fails.
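Examples:
Illustrative only; assumes faster-coco-eval is installed to decode RLE-compressed masks
>>> from faster_coco_eval.core.mask import decode
>>> polygons = _extract_segmentation_annotation(prediction["segmentation"], decode)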
"""
try:
mask = decode(segmentation_raw)
contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3]
return [annotation.ravel().tolist() for annotation in annotations]
except Exception as e:
LOGGER.warning(f"Comet Failed to extract segmentation annotation: {e}")
return None
def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map) -> list | None:
"""
Join the ground truth and prediction annotations if they exist.
Args:
img_idx (int): Index of the image in the batch.
image_path (Path): Path to the image file.
batch (dict): Batch data containing ground truth annotations.
prediction_metadata_map (dict): Map of prediction metadata by image ID.
class_label_map (dict): Mapping from class indices to class names.
class_map (dict): Additional class mapping for label conversion.
Returns:
(list | None): List of annotation dictionaries or None if no annotations exist.
"""
ground_truth_annotations = _format_ground_truth_annotations_for_detection(
img_idx, image_path, batch, class_label_map
)
prediction_annotations = _format_prediction_annotations(
image_path, prediction_metadata_map, class_label_map, class_map
)
annotations = [
annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
]
return [annotations] if annotations else None
def _create_prediction_metadata_map(model_predictions) -> dict:
"""Create metadata map for model predictions by grouping them based on image ID."""
pred_metadata_map = {}
for prediction in model_predictions:
pred_metadata_map.setdefault(prediction["image_id"], [])
pred_metadata_map[prediction["image_id"]].append(prediction)
return pred_metadata_map
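# Illustrative sketch of the grouping above with hypothetical COCO-style prediction dicts:
# _create_prediction_metadata_map([{"image_id": 1, "score": 0.9}, {"image_id": 1, "score": 0.8}])
# -> {1: [{"image_id": 1, "score": 0.9}, {"image_id": 1, "score": 0.8}]}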
def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:
"""Log the confusion matrix to Comet experiment."""
conf_mat = trainer.validator.confusion_matrix.matrix
names = list(trainer.data["names"].values()) + ["background"]
experiment.log_confusion_matrix(
matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
)
def _log_images(experiment, image_paths, curr_step: int | None, annotations=None) -> None:
"""
Log images to the experiment with optional annotations.
This function logs images to a Comet ML experiment, optionally including annotation data for visualization
such as bounding boxes or segmentation masks.
Args:
experiment (comet_ml.CometExperiment): The Comet ML experiment to log images to.
image_paths (list[Path]): List of paths to images that will be logged.
curr_step (int): Current training step/iteration for tracking in the experiment timeline.
annotations (list[list[dict]], optional): Nested list of annotation dictionaries for each image. Each
annotation contains visualization data like bounding boxes, labels, and confidence scores.
"""
if annotations:
for image_path, annotation in zip(image_paths, annotations):
experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
else:
for image_path in image_paths:
experiment.log_image(image_path, name=image_path.stem, step=curr_step)
def _log_image_predictions(experiment, validator, curr_step) -> None:
"""
Log predicted boxes for validation images during training.
This function logs image predictions to a Comet ML experiment during model validation. It processes
validation data and formats both ground truth and prediction annotations for visualization in the Comet
dashboard. The function respects configured limits on the number of images to log.
Args:
experiment (comet_ml.CometExperiment): The Comet ML experiment to log to.
validator (BaseValidator): The validator instance containing validation data and predictions.
curr_step (int): The current training step for logging timeline.
Notes:
This function uses global state to track the number of logged predictions across calls.
It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.
The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.
"""
global _comet_image_prediction_count
task = validator.args.task
if task not in COMET_SUPPORTED_TASKS:
return
jdict = validator.jdict
if not jdict:
return
predictions_metadata_map = _create_prediction_metadata_map(jdict)
dataloader = validator.dataloader
class_label_map = validator.names
class_map = getattr(validator, "class_map", None)
batch_logging_interval = _get_eval_batch_logging_interval()
max_image_predictions = _get_max_image_predictions_to_log()
for batch_idx, batch in enumerate(dataloader):
if (batch_idx + 1) % batch_logging_interval != 0:
continue
image_paths = batch["im_file"]
for img_idx, image_path in enumerate(image_paths):
if _comet_image_prediction_count >= max_image_predictions:
return
image_path = Path(image_path)
annotations = _fetch_annotations(
img_idx,
image_path,
batch,
predictions_metadata_map,
class_label_map,
class_map=class_map,
)
_log_images(
experiment,
[image_path],
curr_step,
annotations=annotations,
)
_comet_image_prediction_count += 1
def _log_plots(experiment, trainer) -> None:
"""
Log evaluation plots and label plots for the experiment.
This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles
different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots
for each type.
Args:
experiment (comet_ml.CometExperiment): The Comet ML experiment to log plots to.
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save
directory information.
Examples:
>>> from ultralytics.utils.callbacks.comet import _log_plots
>>> _log_plots(experiment, trainer)
"""
plot_filenames = None
if isinstance(trainer.validator.metrics, SegmentMetrics):
plot_filenames = [
trainer.save_dir / f"{prefix}{plots}.png"
for plots in EVALUATION_PLOT_NAMES
for prefix in SEGMENT_METRICS_PLOT_PREFIX
]
elif isinstance(trainer.validator.metrics, PoseMetrics):
plot_filenames = [
trainer.save_dir / f"{prefix}{plots}.png"
for plots in EVALUATION_PLOT_NAMES
for prefix in POSE_METRICS_PLOT_PREFIX
]
elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
plot_filenames = [
trainer.save_dir / f"{prefix}{plots}.png"
for plots in EVALUATION_PLOT_NAMES
for prefix in DETECTION_METRICS_PLOT_PREFIX
]
if plot_filenames is not None:
_log_images(experiment, plot_filenames, None)
confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
_log_images(experiment, confusion_matrix_filenames, None)
if not isinstance(trainer.validator.metrics, ClassifyMetrics):
label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
_log_images(experiment, label_plot_filenames, None)
def _log_model(experiment, trainer) -> None:
"""Log the best-trained model to Comet.ml."""
model_name = _get_comet_model_name()
experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
def _log_image_batches(experiment, trainer, curr_step: int) -> None:
"""Log samples of image batches for train, validation, and test."""
_log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
_log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
def _log_asset(experiment, asset_path) -> None:
"""
Logs a specific asset file to the given experiment.
This function facilitates logging an asset, such as a file, to the provided
experiment. It enables integration with experiment tracking platforms.
Args:
experiment (comet_ml.CometExperiment): The experiment instance to which the asset will be logged.
asset_path (Path): The file path of the asset to log.
"""
experiment.log_asset(asset_path)
def _log_table(experiment, table_path) -> None:
"""
Logs a table to the provided experiment.
This function is used to log a table file to the given experiment. The table
is identified by its file path.
Args:
experiment (comet_ml.CometExperiment): The experiment object where the table file will be logged.
table_path (Path): The file path of the table to be logged.
"""
experiment.log_table(str(table_path))
def on_pretrain_routine_start(trainer) -> None:
"""Create or resume a CometML experiment at the start of a YOLO pre-training routine."""
_resume_or_create_experiment(trainer.args)
def on_train_epoch_end(trainer) -> None:
"""Log metrics and save batch images at the end of training epochs."""
experiment = comet_ml.get_running_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
def on_fit_epoch_end(trainer) -> None:
"""
Log model assets at the end of each epoch during training.
This function is called at the end of each training epoch to log metrics, learning rates, and model information
to a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on
configuration settings.
The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,
it also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled),
and image predictions (if enabled).
Args:
trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.
Examples:
>>> # Inside a training loop
>>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML
"""
experiment = comet_ml.get_running_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
save_assets = metadata["save_assets"]
experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
if curr_epoch == 1:
from ultralytics.utils.torch_utils import model_info_for_loggers
experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
if not save_assets:
return
_log_model(experiment, trainer)
if _should_log_confusion_matrix():
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
if _should_log_image_predictions():
_log_image_predictions(experiment, trainer.validator, curr_step)
def on_train_end(trainer) -> None:
"""Perform operations at the end of training."""
experiment = comet_ml.get_running_experiment()
if not experiment:
return
metadata = _fetch_trainer_metadata(trainer)
curr_epoch = metadata["curr_epoch"]
curr_step = metadata["curr_step"]
plots = trainer.args.plots
_log_model(experiment, trainer)
if plots:
_log_plots(experiment, trainer)
_log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
_log_image_predictions(experiment, trainer.validator, curr_step)
_log_image_batches(experiment, trainer, curr_step)
# log results table
table_path = trainer.save_dir / RESULTS_TABLE_NAME
if table_path.exists():
_log_table(experiment, table_path)
# log arguments YAML
args_path = trainer.save_dir / ARGS_YAML_NAME
if args_path.exists():
_log_asset(experiment, args_path)
experiment.end()
global _comet_image_prediction_count
_comet_image_prediction_count = 0
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if comet_ml
else {}
)

View File

@@ -0,0 +1,202 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from pathlib import Path
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["dvc"] is True # verify integration is enabled
import dvclive
assert checks.check_version("dvclive", "2.11.0", verbose=True)
import os
import re
# DVCLive logger instance
live = None
_processed_plots = {}
# `on_fit_epoch_end` is called on final validation (probably needs a fix); for now this is how we
# distinguish the final evaluation of the best model from last-epoch validation
_training_epoch = False
except (ImportError, AssertionError, TypeError):
dvclive = None
def _log_images(path: Path, prefix: str = "") -> None:
"""
Log images at specified path with an optional prefix using DVCLive.
This function logs images found at the given path to DVCLive, organizing them by batch to enable slider
functionality in the UI. It processes image filenames to extract batch information and restructures the path
accordingly.
Args:
path (Path): Path to the image file to be logged.
prefix (str, optional): Optional prefix to add to the image name when logging.
Examples:
>>> from pathlib import Path
>>> _log_images(Path("runs/train/exp/val_batch0_pred.jpg"), prefix="validation")
"""
if live:
name = path.name
# Group images by batch to enable sliders in UI
if m := re.search(r"_batch(\d+)", name):
ni = m[1]
new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
name = (Path(new_stem) / ni).with_suffix(path.suffix)
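# e.g. 'val_batch0_pred.jpg' -> 'val_batch_pred/0.jpg', so all batches share one name and the UI shows a slider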
live.log_image(os.path.join(prefix, name), path)
def _log_plots(plots: dict, prefix: str = "") -> None:
"""
Log plot images for training progress if they have not been previously processed.
Args:
plots (dict): Dictionary containing plot information with timestamps.
prefix (str, optional): Optional prefix to add to the logged image paths.
"""
for name, params in plots.items():
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
_log_images(name, prefix)
_processed_plots[name] = timestamp
def _log_confusion_matrix(validator) -> None:
"""
Log confusion matrix for a validator using DVCLive.
This function processes the confusion matrix from a validator object and logs it to DVCLive by converting
the matrix into lists of target and prediction labels.
Args:
validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have
attributes: confusion_matrix.matrix, confusion_matrix.task, and names.
"""
targets = []
preds = []
matrix = validator.confusion_matrix.matrix
names = list(validator.names.values())
if validator.confusion_matrix.task == "detect":
names += ["background"]
for ti, pred in enumerate(matrix.T.astype(int)):
for pi, num in enumerate(pred):
targets.extend([names[ti]] * num)
preds.extend([names[pi]] * num)
live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
def on_pretrain_routine_start(trainer) -> None:
"""Initialize DVCLive logger for training metadata during pre-training routine."""
try:
global live
live = dvclive.Live(save_dvc_exp=True, cache_images=True)
LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
except Exception as e:
LOGGER.warning(f"DVCLive installed but not initialized correctly, not logging this run. {e}")
def on_pretrain_routine_end(trainer) -> None:
"""Log plots related to the training process at the end of the pretraining routine."""
_log_plots(trainer.plots, "train")
def on_train_start(trainer) -> None:
"""Log the training parameters if DVCLive logging is active."""
if live:
live.log_params(trainer.args)
def on_train_epoch_start(trainer) -> None:
"""Set the global variable _training_epoch value to True at the start of training each epoch."""
global _training_epoch
_training_epoch = True
def on_fit_epoch_end(trainer) -> None:
"""
Log training metrics, model info, and advance to next step at the end of each fit epoch.
This function is called at the end of each fit epoch during training. It logs various metrics including
training loss items, validation metrics, and learning rates. On the first epoch, it also logs model
information. Additionally, it logs training and validation plots and advances the DVCLive step counter.
Args:
trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.
Notes:
This function only performs logging operations when DVCLive logging is active and during a training epoch.
The global variable _training_epoch is used to track whether the current epoch is a training epoch.
"""
global _training_epoch
if live and _training_epoch:
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value)
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
for metric, value in model_info_for_loggers(trainer).items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, "train")
_log_plots(trainer.validator.plots, "val")
live.next_step()
_training_epoch = False
def on_train_end(trainer) -> None:
"""
Log best metrics, plots, and confusion matrix at the end of training.
This function is called at the conclusion of the training process to log final metrics, visualizations, and
model artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,
validation plots, and confusion matrix for later analysis.
Args:
trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.
Examples:
>>> # Inside a custom training loop
>>> from ultralytics.utils.callbacks.dvc import on_train_end
>>> on_train_end(trainer) # Log final metrics and artifacts
"""
if live:
# At the end log the best metrics. It runs validator on the best model internally.
all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
for metric, value in all_metrics.items():
live.log_metric(metric, value, plot=False)
_log_plots(trainer.plots, "val")
_log_plots(trainer.validator.plots, "val")
_log_confusion_matrix(trainer.validator)
if trainer.best.exists():
live.log_artifact(trainer.best, copy=True, type="model")
live.end()
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_start": on_train_start,
"on_train_epoch_start": on_train_epoch_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if dvclive
else {}
)

View File

@@ -0,0 +1,110 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import json
from time import time
from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
from ultralytics.utils import LOGGER, RANK, SETTINGS
from ultralytics.utils.events import events
def on_pretrain_routine_start(trainer):
"""Create a remote Ultralytics HUB session to log local model training."""
if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None:
trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)
def on_pretrain_routine_end(trainer):
"""Initialize timers for upload rate limiting before training begins."""
if session := getattr(trainer, "hub_session", None):
# Start timers for upload rate limiting
session.timers = {"metrics": time(), "ckpt": time()}
def on_fit_epoch_end(trainer):
"""Upload training progress metrics to Ultralytics HUB at the end of each epoch."""
if session := getattr(trainer, "hub_session", None):
# Upload metrics after validation ends
all_plots = {
**trainer.label_loss_items(trainer.tloss, prefix="train"),
**trainer.metrics,
}
if trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
all_plots = {**all_plots, **model_info_for_loggers(trainer)}
session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
# If any metrics failed to upload previously, add them to the queue to attempt uploading again
if session.metrics_upload_failed_queue:
session.metrics_queue.update(session.metrics_upload_failed_queue)
if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
session.upload_metrics()
session.timers["metrics"] = time() # reset timer
session.metrics_queue = {} # reset queue
def on_model_save(trainer):
"""Upload model checkpoints to Ultralytics HUB with rate limiting."""
if session := getattr(trainer, "hub_session", None):
# Upload checkpoints with rate limiting
is_best = trainer.best_fitness == trainer.fitness
if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
session.upload_model(trainer.epoch, trainer.last, is_best)
session.timers["ckpt"] = time() # reset timer
def on_train_end(trainer):
"""Upload final model and metrics to Ultralytics HUB at the end of training."""
if session := getattr(trainer, "hub_session", None):
# Upload final model and metrics with exponential backoff
LOGGER.info(f"{PREFIX}Syncing final model...")
session.upload_model(
trainer.epoch,
trainer.best,
map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
final=True,
)
session.alive = False # stop heartbeats
LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀")
def on_train_start(trainer):
"""Run events on train start."""
events(trainer.args, trainer.device)
def on_val_start(validator):
"""Run events on validation start."""
if not validator.training:
events(validator.args, validator.device)
def on_predict_start(predictor):
"""Run events on predict start."""
events(predictor.args, predictor.device)
def on_export_start(exporter):
"""Run events on export start."""
events(exporter.args, exporter.device)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS["hub"] is True
else {}
)

View File

@@ -0,0 +1,135 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
MLflow Logging for Ultralytics YOLO.
This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
For setting up, a tracking URI should be specified. The logging can be customized using environment variables.
Commands:
1. To set a project name:
`export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument
2. To set a run name:
`export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument
3. To start a local MLflow server:
`mlflow server --backend-store-uri runs/mlflow`
By default, this starts a local server at http://127.0.0.1:5000.
To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
4. To kill all running MLflow server instances:
ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
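5. Example workflow combining the variables above (hypothetical values):
`export MLFLOW_EXPERIMENT_NAME=my_experiment MLFLOW_RUN=run1 && yolo train model=yolo11n.pt data=coco8.yaml`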
"""
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
try:
import os
assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
assert SETTINGS["mlflow"] is True # verify integration is enabled
import mlflow
assert hasattr(mlflow, "__version__") # verify package is not directory
from pathlib import Path
PREFIX = colorstr("MLflow: ")
except (ImportError, AssertionError):
mlflow = None
def sanitize_dict(x: dict) -> dict:
"""Sanitize dictionary keys by removing parentheses and converting values to floats."""
return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
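# Illustrative: sanitize_dict({"metrics/mAP50(B)": 0.5}) -> {"metrics/mAP50B": 0.5}, since MLflow metric names disallow parentheses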
def on_pretrain_routine_end(trainer):
"""
Log training parameters to MLflow at the end of the pretraining routine.
This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
from the trainer.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
Environment Variables:
MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
"""
global mlflow
uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
mlflow.set_tracking_uri(uri)
# Set experiment and run names
experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics"
run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
mlflow.set_experiment(experiment_name)
mlflow.autolog()
try:
active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
if Path(uri).is_dir():
LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
mlflow.log_params(dict(trainer.args))
except Exception as e:
LOGGER.warning(f"{PREFIX}Failed to initialize: {e}")
LOGGER.warning(f"{PREFIX}Not tracking this run")
def on_train_epoch_end(trainer):
"""Log training metrics at the end of each train epoch to MLflow."""
if mlflow:
mlflow.log_metrics(
metrics={
**sanitize_dict(trainer.lr),
**sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")),
},
step=trainer.epoch,
)
def on_fit_epoch_end(trainer):
"""Log training metrics at the end of each fit epoch to MLflow."""
if mlflow:
mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)
def on_train_end(trainer):
"""Log model artifacts at the end of training."""
if not mlflow:
return
mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
for f in trainer.save_dir.glob("*"): # log all other files in save_dir
if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
mlflow.log_artifact(str(f))
keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
if keep_run_active:
LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
else:
mlflow.end_run()
LOGGER.debug(f"{PREFIX}mlflow run ended")
LOGGER.info(
f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
)
callbacks = (
{
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if mlflow
else {}
)

View File

@@ -0,0 +1,134 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["neptune"] is True # verify integration is enabled
import neptune
from neptune.types import File
assert hasattr(neptune, "__version__")
run = None # NeptuneAI experiment logger instance
except (ImportError, AssertionError):
neptune = None
def _log_scalars(scalars: dict, step: int = 0) -> None:
"""
Log scalars to the NeptuneAI experiment logger.
Args:
scalars (dict): Dictionary of scalar values to log to NeptuneAI.
step (int, optional): The current step or iteration number for logging.
Examples:
>>> metrics = {"mAP": 0.85, "loss": 0.32}
>>> _log_scalars(metrics, step=100)
"""
if run:
for k, v in scalars.items():
run[k].append(value=v, step=step)
def _log_images(imgs_dict: dict, group: str = "") -> None:
"""
Log images to the NeptuneAI experiment logger.
This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized
under the specified group name.
Args:
imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.
group (str, optional): Group name to organize images under in the Neptune UI.
Examples:
>>> # Log validation images
>>> _log_images({"val_batch": img_tensor}, group="validation")
"""
if run:
for k, v in imgs_dict.items():
run[f"{group}/{k}"].upload(File(v))
def _log_plot(title: str, plot_path: str) -> None:
"""Log plots to the NeptuneAI experiment logger."""
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
img = mpimg.imread(plot_path)
fig = plt.figure()
ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
ax.imshow(img)
run[f"Plots/{title}"].upload(fig)
def on_pretrain_routine_start(trainer) -> None:
"""Initialize NeptuneAI run and log hyperparameters before training starts."""
try:
global run
run = neptune.init_run(
project=trainer.args.project or "Ultralytics",
name=trainer.args.name,
tags=["Ultralytics"],
)
run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
except Exception as e:
LOGGER.warning(f"NeptuneAI installed but not initialized correctly, not logging this run. {e}")
def on_train_epoch_end(trainer) -> None:
"""Log training metrics and learning rate at the end of each training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
if trainer.epoch == 1:
_log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
def on_fit_epoch_end(trainer) -> None:
"""Log model info and validation metrics at the end of each fit epoch."""
if run and trainer.epoch == 0:
from ultralytics.utils.torch_utils import model_info_for_loggers
run["Configuration/Model"] = model_info_for_loggers(trainer)
_log_scalars(trainer.metrics, trainer.epoch + 1)
def on_val_end(validator) -> None:
"""Log validation images at the end of validation."""
if run:
# Log val_labels and val_pred
_log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
def on_train_end(trainer) -> None:
"""Log final results, plots, and model weights at the end of training."""
if run:
# Log final results, confusion matrix and PR plots
files = [
"results.png",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
*(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
]
files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
for f in files:
_log_plot(title=f.stem, plot_path=f)
# Log the final model
run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_val_end": on_val_end,
"on_train_end": on_train_end,
}
if neptune
else {}
)

View File

@@ -0,0 +1,73 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import RANK, SETTINGS
def on_pretrain_routine_start(trainer):
"""Initialize and start console logging immediately at the very beginning."""
if RANK in {-1, 0}:
from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger
trainer.system_logger = SystemLogger()
trainer.console_logger = ConsoleLogger(DEFAULT_LOG_PATH)
trainer.console_logger.start_capture()
def on_pretrain_routine_end(trainer):
"""Handle pre-training routine completion event."""
pass
def on_fit_epoch_end(trainer):
"""Handle end of training epoch event and collect system metrics."""
if RANK in {-1, 0} and hasattr(trainer, "system_logger"):
system_metrics = trainer.system_logger.get_metrics()
print(system_metrics) # for debug
def on_model_save(trainer):
"""Handle model checkpoint save event."""
pass
def on_train_end(trainer):
"""Stop console capture and finalize logs."""
if logger := getattr(trainer, "console_logger", None):
logger.stop_capture()
def on_train_start(trainer):
"""Handle training start event."""
pass
def on_val_start(validator):
"""Handle validation start event."""
pass
def on_predict_start(predictor):
"""Handle prediction start event."""
pass
def on_export_start(exporter):
"""Handle model export start event."""
pass
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_pretrain_routine_end": on_pretrain_routine_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_model_save": on_model_save,
"on_train_end": on_train_end,
"on_train_start": on_train_start,
"on_val_start": on_val_start,
"on_predict_start": on_predict_start,
"on_export_start": on_export_start,
}
if SETTINGS.get("platform", False) is True # disabled for debugging
else {}
)

View File

@@ -0,0 +1,43 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import SETTINGS
try:
assert SETTINGS["raytune"] is True # verify integration is enabled
import ray
from ray import tune
from ray.air import session
except (ImportError, AssertionError):
tune = None
def on_fit_epoch_end(trainer):
"""
Report training metrics to Ray Tune at epoch end when a Ray session is active.
Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,
enabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.
Examples:
>>> # Called automatically by the Ultralytics training loop
>>> on_fit_epoch_end(trainer)
References:
Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html
"""
if ray.train._internal.session.get_session(): # check if Ray Tune session is active
metrics = trainer.metrics
session.report({**metrics, "epoch": trainer.epoch + 1})
callbacks = (
{
"on_fit_epoch_end": on_fit_epoch_end,
}
if tune
else {}
)

View File

@@ -0,0 +1,131 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["tensorboard"] is True # verify integration is enabled
WRITER = None # TensorBoard SummaryWriter instance
PREFIX = colorstr("TensorBoard: ")
# Imports below only required if TensorBoard enabled
import warnings
from copy import deepcopy
import torch
from torch.utils.tensorboard import SummaryWriter
except (ImportError, AssertionError, TypeError, AttributeError):
# TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
# AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
SummaryWriter = None
def _log_scalars(scalars: dict, step: int = 0) -> None:
"""
Log scalar values to TensorBoard.
Args:
scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the
corresponding scalar values.
step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
Examples:
Log training metrics
>>> metrics = {"loss": 0.5, "accuracy": 0.95}
>>> _log_scalars(metrics, step=100)
"""
if WRITER:
for k, v in scalars.items():
WRITER.add_scalar(k, v, step)
def _log_tensorboard_graph(trainer) -> None:
"""
Log model graph to TensorBoard.
This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
approach for models like RTDETR that may require special handling.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.
Must have attributes model and args with imgsz.
Notes:
This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different
model architectures.
"""
# Input image
imgsz = trainer.args.imgsz
imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
p = next(trainer.model.parameters()) # for device, type
im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
# Try simple method first (YOLO)
try:
trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
WRITER.add_graph(torch.jit.trace(torch_utils.unwrap_model(trainer.model), im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
return
except Exception:
# Fallback to TorchScript export steps (RTDETR)
try:
model = deepcopy(torch_utils.unwrap_model(trainer.model))
model.eval()
model = model.fuse(verbose=False)
for m in model.modules():
if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
m.export = True
m.format = "torchscript"
model(im) # dry run
WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
LOGGER.info(f"{PREFIX}model graph visualization added ✅")
except Exception as e:
LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure {e}")
def on_pretrain_routine_start(trainer) -> None:
"""Initialize TensorBoard logging with SummaryWriter."""
if SummaryWriter:
try:
global WRITER
WRITER = SummaryWriter(str(trainer.save_dir))
LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
except Exception as e:
LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. {e}")
def on_train_start(trainer) -> None:
"""Log TensorBoard graph."""
if WRITER:
_log_tensorboard_graph(trainer)
def on_train_epoch_end(trainer) -> None:
"""Log scalar statistics at the end of a training epoch."""
_log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
_log_scalars(trainer.lr, trainer.epoch + 1)
def on_fit_epoch_end(trainer) -> None:
"""Log epoch metrics at end of training epoch."""
_log_scalars(trainer.metrics, trainer.epoch + 1)
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_start": on_train_start,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_epoch_end": on_train_epoch_end,
}
if SummaryWriter
else {}
)

View File

@@ -0,0 +1,191 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers
try:
assert not TESTS_RUNNING # do not log pytest
assert SETTINGS["wandb"] is True # verify integration is enabled
import wandb as wb
assert hasattr(wb, "__version__") # verify package is not directory
_processed_plots = {}
except (ImportError, AssertionError):
wb = None
def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
"""
Create and log a custom metric visualization to wandb.plot.pr_curve.
This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
different classes.
Args:
x (list): Values for the x-axis; expected to have length N.
y (list): Corresponding values for the y-axis; also expected to have length N.
classes (list): Labels identifying the class of each point; length N.
title (str, optional): Title for the plot.
x_title (str, optional): Label for the x-axis.
y_title (str, optional): Label for the y-axis.
Returns:
(wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
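Examples:
Illustrative values; x, y, and classes must be equal-length lists
>>> table = _custom_table([0.1, 0.5], [0.9, 0.7], ["person", "car"])
>>> wb.log({"pr_curve": table})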
"""
import polars as pl # scope for faster 'import ultralytics'
import polars.selectors as cs
df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
data = df.select(["class", "y", "x"]).rows()
fields = {"x": "x", "y": "y", "class": "class"}
string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
return wb.plot_table(
"wandb/area-under-curve/v0",
wb.Table(data=data, columns=["class", "y", "x"]),
fields=fields,
string_fields=string_fields,
)
def _plot_curve(
x,
y,
names=None,
id="precision-recall",
title="Precision Recall Curve",
x_title="Recall",
y_title="Precision",
num_x=100,
only_mean=False,
):
"""
Log a metric curve visualization.
This function generates a metric curve based on input data and logs the visualization to wandb.
The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
Args:
x (np.ndarray): Data points for the x-axis with length N.
y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
names (list, optional): Names of the classes corresponding to the y-axis data; length C.
id (str, optional): Unique identifier for the logged data in wandb.
title (str, optional): Title for the visualization plot.
x_title (str, optional): Label for the x-axis.
y_title (str, optional): Label for the y-axis.
num_x (int, optional): Number of interpolated data points for visualization.
only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.
Notes:
The function leverages the '_custom_table' function to generate the actual visualization.
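Examples:
Illustrative only; assumes x of shape (N,) and y of shape (C, N) taken from validator metrics
>>> _plot_curve(x, y, names=["person", "car"], id="curves/PR", title="Precision Recall Curve")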
"""
import numpy as np
# Create new x
if names is None:
names = []
x_new = np.linspace(x[0], x[-1], num_x).round(5)
# Create arrays for logging
x_log = x_new.tolist()
y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
if only_mean:
table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
else:
classes = ["mean"] * len(x_log)
for i, yi in enumerate(y):
x_log.extend(x_new) # add new x
y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
classes.extend([names[i]] * len(x_new)) # add class names
wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
def _log_plots(plots, step):
"""
Log plots to WandB at a specific step if they haven't been logged already.
This function checks each plot in the input dictionary against previously processed plots and logs
new or updated plots to WandB at the specified step.
Args:
plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
containing plot metadata including timestamps.
step (int): The step/epoch at which to log the plots in the WandB run.
Notes:
The function uses a shallow copy of the plots dictionary to prevent modification during iteration.
Plots are identified by their stem name (filename without extension).
Each plot is logged as a WandB Image object.
"""
for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration
timestamp = params["timestamp"]
if _processed_plots.get(name) != timestamp:
wb.run.log({name.stem: wb.Image(str(name))}, step=step)
_processed_plots[name] = timestamp
def on_pretrain_routine_start(trainer):
"""Initialize and start wandb project if module is present."""
if not wb.run:
wb.init(
project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
name=str(trainer.args.name).replace("/", "-"),
config=vars(trainer.args),
)
def on_fit_epoch_end(trainer):
"""Log training metrics and model information at the end of an epoch."""
wb.run.log(trainer.metrics, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
if trainer.epoch == 0:
wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
def on_train_epoch_end(trainer):
"""Log metrics and save images at the end of each training epoch."""
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
wb.run.log(trainer.lr, step=trainer.epoch + 1)
if trainer.epoch == 1:
_log_plots(trainer.plots, step=trainer.epoch + 1)
def on_train_end(trainer):
"""Save the best model as an artifact and log final plots at the end of training."""
_log_plots(trainer.validator.plots, step=trainer.epoch + 1)
_log_plots(trainer.plots, step=trainer.epoch + 1)
art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
if trainer.best.exists():
art.add_file(trainer.best)
wb.run.log_artifact(art, aliases=["best"])
# Check if we actually have plots to save
if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"):
for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
x, y, x_title, y_title = curve_values
_plot_curve(
x,
y,
names=list(trainer.validator.metrics.names.values()),
id=f"curves/{curve_name}",
title=curve_name,
x_title=x_title,
y_title=y_title,
)
wb.run.finish() # required or run continues on dashboard
callbacks = (
{
"on_pretrain_routine_start": on_pretrain_routine_start,
"on_train_epoch_end": on_train_epoch_end,
"on_fit_epoch_end": on_fit_epoch_end,
"on_train_end": on_train_end,
}
if wb
else {}
)

964
ultralytics/utils/checks.py Normal file
View File

@@ -0,0 +1,964 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import functools
import glob
import inspect
import math
import os
import platform
import re
import shutil
import subprocess
import time
from importlib import metadata
from pathlib import Path
from types import SimpleNamespace
import cv2
import numpy as np
import torch
from ultralytics.utils import (
ARM64,
ASSETS,
AUTOINSTALL,
GIT,
IS_COLAB,
IS_JETSON,
IS_KAGGLE,
IS_PIP_PACKAGE,
LINUX,
LOGGER,
MACOS,
ONLINE,
PYTHON_VERSION,
RKNN_CHIPS,
ROOT,
TORCH_VERSION,
TORCHVISION_VERSION,
USER_CONFIG_DIR,
WINDOWS,
Retry,
ThreadingLocked,
TryExcept,
clean_url,
colorstr,
downloads,
is_github_action_running,
url2file,
)
def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
"""
Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
Args:
file_path (Path): Path to the requirements.txt file.
package (str, optional): Python package to use instead of requirements.txt file.
Returns:
requirements (list[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and
`specifier` attributes.
Examples:
>>> from ultralytics.utils.checks import parse_requirements
>>> parse_requirements(package="ultralytics")
"""
if package:
requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
else:
requires = Path(file_path).read_text().splitlines()
requirements = []
for line in requires:
line = line.strip()
if line and not line.startswith("#"):
line = line.partition("#")[0].strip() # ignore inline comments
if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line):
requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
return requirements
@functools.lru_cache
def parse_version(version="0.0.0") -> tuple:
"""
Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
Args:
version (str): Version string, i.e. '2.0.1+cpu'
Returns:
(tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
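Examples:
>>> parse_version("2.0.1+cpu")
(2, 0, 1)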
"""
try:
return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
except Exception as e:
LOGGER.warning(f"failure for parse_version({version}), returning (0, 0, 0): {e}")
return 0, 0, 0
def is_ascii(s) -> bool:
"""
Check if a string is composed of only ASCII characters.
Args:
s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).
Returns:
(bool): True if the string is composed only of ASCII characters, False otherwise.
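Examples:
>>> is_ascii("Ultralytics")
True
>>> is_ascii("🚀")
False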
"""
return all(ord(c) < 128 for c in str(s))
def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
"""
Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
Args:
imgsz (int | list[int]): Image size.
stride (int): Stride value.
min_dim (int): Minimum number of dimensions.
max_dim (int): Maximum number of dimensions.
floor (int): Minimum allowed value for image size.
Returns:
(list[int] | int): Updated image size.
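Examples:
>>> check_imgsz(630)  # rounded up to the nearest multiple of the default stride 32
640
>>> check_imgsz([640, 480], min_dim=2)
[640, 480]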
"""
# Convert stride to integer if it is a tensor
stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
# Convert image size to list if it is an integer
if isinstance(imgsz, int):
imgsz = [imgsz]
elif isinstance(imgsz, (list, tuple)):
imgsz = list(imgsz)
elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'
imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
else:
raise TypeError(
f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
)
# Apply max_dim
if len(imgsz) > max_dim:
msg = (
"'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
"or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
)
if max_dim != 1:
raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
LOGGER.warning(f"updating to 'imgsz={max(imgsz)}'. {msg}")
imgsz = [max(imgsz)]
# Make image size a multiple of the stride
sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
# Print warning message if image size was updated
if sz != imgsz:
LOGGER.warning(f"imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
# Add missing dimensions if necessary
sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
return sz
@functools.lru_cache
def check_uv():
"""Check if uv package manager is installed and can run successfully."""
try:
return subprocess.run(["uv", "-V"], capture_output=True).returncode == 0
except FileNotFoundError:
return False
@functools.lru_cache
def check_version(
current: str = "0.0.0",
required: str = "0.0.0",
name: str = "version",
hard: bool = False,
verbose: bool = False,
msg: str = "",
) -> bool:
"""
Check current version against the required version or range.
Args:
current (str): Current version or package name to get version from.
required (str): Required version or range (in pip-style format).
name (str): Name to be used in warning message.
hard (bool): If True, raise an AssertionError if the requirement is not met.
verbose (bool): If True, print warning message if requirement is not met.
msg (str): Extra message to display if verbose.
Returns:
(bool): True if requirement is met, False otherwise.
Examples:
Check if current version is exactly 22.04
>>> check_version(current="22.04", required="==22.04")
Check if current version is greater than or equal to 22.04
>>> check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed
Check if current version is less than or equal to 22.04
>>> check_version(current="22.04", required="<=22.04")
Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
>>> check_version(current="21.10", required=">20.04,<22.04")
"""
if not current: # if current is '' or None
LOGGER.warning(f"invalid check_version({current}, {required}) requested, please check values.")
return True
elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
try:
name = current # assigned package name to 'name' arg
current = metadata.version(current) # get version string from package name
except metadata.PackageNotFoundError as e:
if hard:
raise ModuleNotFoundError(f"{current} package is required but not installed") from e
else:
return False
if not required: # if required is '' or None
return True
if "sys_platform" in required and ( # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
(WINDOWS and "win32" not in required)
or (LINUX and "linux" not in required)
or (MACOS and "macos" not in required and "darwin" not in required)
):
return True
op = ""
version = ""
result = True
c = parse_version(current) # '1.2.3' -> (1, 2, 3)
for r in required.strip(",").split(","):
op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
if not op:
op = ">=" # assume >= if no op passed
v = parse_version(version) # '1.2.3' -> (1, 2, 3)
if op == "==" and c != v:
result = False
elif op == "!=" and c == v:
result = False
elif op == ">=" and not (c >= v):
result = False
elif op == "<=" and not (c <= v):
result = False
elif op == ">" and not (c > v):
result = False
elif op == "<" and not (c < v):
result = False
if not result:
warning = f"{name}{required} is required, but {name}=={current} is currently installed {msg}"
if hard:
raise ModuleNotFoundError(warning) # assert version requirements met
if verbose:
LOGGER.warning(warning)
return result
def check_latest_pypi_version(package_name="ultralytics"):
"""
Return the latest version of a PyPI package without downloading or installing it.
Args:
package_name (str): The name of the package to find the latest version for.
Returns:
(str): The latest version of the package.
"""
import requests # scoped as slow import
try:
requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
if response.status_code == 200:
return response.json()["info"]["version"]
except Exception:
return None
def check_pip_update_available():
"""
Check if a new version of the ultralytics package is available on PyPI.
Returns:
(bool): True if an update is available, False otherwise.
"""
if ONLINE and IS_PIP_PACKAGE:
try:
from ultralytics import __version__
latest = check_latest_pypi_version()
if check_version(__version__, f"<{latest}"): # check if current version is < latest version
LOGGER.info(
f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
f"Update with 'pip install -U ultralytics'"
)
return True
except Exception:
pass
return False
@ThreadingLocked()
@functools.lru_cache
def check_font(font="Arial.ttf"):
"""
Find font locally or download to user's configuration directory if it does not already exist.
Args:
font (str): Path or name of font.
Returns:
(Path): Resolved font file path.
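Examples:
Resolve a font, downloading it to USER_CONFIG_DIR on first use if needed
>>> font_file = check_font("Arial.ttf")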
"""
from matplotlib import font_manager # scope for faster 'import ultralytics'
# Check USER_CONFIG_DIR
name = Path(font).name
file = USER_CONFIG_DIR / name
if file.exists():
return file
# Check system fonts
matches = [s for s in font_manager.findSystemFonts() if font in s]
if any(matches):
return matches[0]
# Download to USER_CONFIG_DIR if missing
url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}"
if downloads.is_url(url, check=True):
downloads.safe_download(url=url, file=file)
return file
def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
"""
Check current python version against the required minimum version.
Args:
minimum (str): Required minimum version of python.
hard (bool): If True, raise an AssertionError if the requirement is not met.
verbose (bool): If True, print warning message if requirement is not met.
Returns:
(bool): Whether the installed Python version meets the minimum constraints.
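Examples:
Soft check that returns a bool instead of raising
>>> meets_310 = check_python("3.10", hard=False)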
"""
return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose)
@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
"""
Check if installed dependencies meet Ultralytics YOLO model requirements and attempt to auto-update if needed.
Args:
requirements (Path | str | list[str] | tuple[str]): Path to a requirements.txt file, a single package
requirement as a string, or a list of package requirements as strings.
exclude (tuple): Tuple of package names to exclude from checking.
install (bool): If True, attempt to auto-update packages that don't meet requirements.
cmds (str): Additional commands to pass to the pip install command when auto-updating.
Examples:
>>> from ultralytics.utils.checks import check_requirements
Check a requirements.txt file
>>> check_requirements("path/to/requirements.txt")
Check a single package
>>> check_requirements("ultralytics>=8.0.0")
Check multiple packages
>>> check_requirements(["numpy", "ultralytics>=8.0.0"])
"""
prefix = colorstr("red", "bold", "requirements:")
if isinstance(requirements, Path): # requirements.txt file
file = requirements.resolve()
assert file.exists(), f"{prefix} {file} not found, check failed."
requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
elif isinstance(requirements, str):
requirements = [requirements]
pkgs = []
for r in requirements:
r_stripped = r.rpartition("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
name, required = match[1], match[2].strip() if match[2] else ""
try:
assert check_version(metadata.version(name), required) # exception if requirements not met
except (AssertionError, metadata.PackageNotFoundError):
pkgs.append(r)
@Retry(times=2, delay=1)
def attempt_install(packages, commands, use_uv):
"""Attempt package installation with uv if available, falling back to pip."""
if use_uv:
base = (
f"uv pip install --no-cache-dir {packages} {commands} "
f"--index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
)
try:
return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE, text=True)
except subprocess.CalledProcessError as e:
if e.stderr and "No virtual environment found" in e.stderr:
return subprocess.check_output(
base.replace("uv pip install", "uv pip install --system"),
shell=True,
stderr=subprocess.PIPE,
text=True,
)
raise
return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True, text=True)
s = " ".join(f'"{x}"' for x in pkgs) # console string
if s:
if install and AUTOINSTALL: # check environment variable
# Note uv fails on arm64 macOS and Raspberry Pi runners
n = len(pkgs) # number of package updates
LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
try:
t = time.time()
assert ONLINE, "AutoUpdate skipped (offline)"
LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv()))
dt = time.time() - t
LOGGER.info(f"{prefix} AutoUpdate success ✅ {dt:.1f}s")
LOGGER.warning(
f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
)
except Exception as e:
LOGGER.warning(f"{prefix}{e}")
return False
else:
return False
return True
def check_torchvision():
"""
Check the installed versions of PyTorch and Torchvision to ensure they're compatible.
This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
to the compatibility table based on: https://github.com/pytorch/vision#installation.
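Examples:
>>> check_torchvision()  # logs a warning only when an incompatible pairing is detected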
"""
compatibility_table = {
"2.9": ["0.24"],
"2.8": ["0.23"],
"2.7": ["0.22"],
"2.6": ["0.21"],
"2.5": ["0.20"],
"2.4": ["0.19"],
"2.3": ["0.18"],
"2.2": ["0.17"],
"2.1": ["0.16"],
"2.0": ["0.15"],
"1.13": ["0.14"],
"1.12": ["0.13"],
}
# Check major and minor versions
v_torch = ".".join(TORCH_VERSION.split("+", 1)[0].split(".")[:2])
if v_torch in compatibility_table:
compatible_versions = compatibility_table[v_torch]
v_torchvision = ".".join(TORCHVISION_VERSION.split("+", 1)[0].split(".")[:2])
if all(v_torchvision != v for v in compatible_versions):
LOGGER.warning(
f"torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
"'pip install -U torch torchvision' to update both.\n"
"For a full compatibility table see https://github.com/pytorch/vision#installation"
)
def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
"""
Check file(s) for acceptable suffix.
Args:
file (str | list[str]): File or list of files to check.
suffix (str | tuple): Acceptable suffix or tuple of suffixes.
msg (str): Additional message to display in case of error.
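Examples:
Validate single and multiple files (raises AssertionError on a mismatch)
>>> check_suffix("yolo11n.pt", suffix=".pt")
>>> check_suffix(["yolo11n.pt", "yolo11n.onnx"], suffix=(".pt", ".onnx"))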
"""
if file and suffix:
if isinstance(suffix, str):
suffix = {suffix}
for f in file if isinstance(file, (list, tuple)) else [file]:
if s := str(f).rpartition(".")[-1].lower().strip(): # file suffix
assert f".{s}" in suffix, f"{msg}{f} acceptable suffix is {suffix}, not .{s}"
def check_yolov5u_filename(file: str, verbose: bool = True):
"""
Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.
Args:
file (str): Filename to check and potentially update.
verbose (bool): Whether to print information about the replacement.
Returns:
(str): Updated filename.
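Examples:
>>> check_yolov5u_filename("yolov5n.pt", verbose=False)
'yolov5nu.pt'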
"""
if "yolov3" in file or "yolov5" in file:
if "u.yaml" in file:
file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
elif ".pt" in file and "u" not in file:
original_file = file
file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
if file != original_file and verbose:
LOGGER.info(
f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
)
return file
def check_model_file_from_stem(model="yolo11n"):
"""
Return a model filename from a valid model stem.
Args:
model (str): Model stem to check.
Returns:
(str | Path): Model filename with appropriate suffix.
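Examples:
Add a '.pt' suffix to a known GitHub asset stem
>>> file = check_model_file_from_stem("yolo11n")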
"""
path = Path(model)
if not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS:
return path.with_suffix(".pt") # add suffix, i.e. yolo11n -> yolo11n.pt
return model
def check_file(file, suffix="", download=True, download_dir=".", hard=True):
"""
Search/download file (if necessary), check suffix (if provided), and return path.
Args:
file (str): File name or path.
suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.
download (bool): Whether to download the file if it doesn't exist locally.
download_dir (str): Directory to download the file to.
hard (bool): Whether to raise an error if the file is not found.
Returns:
(str): Path to the file.
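Examples:
Locate a file bundled with the package, or download it when given a URL
>>> path = check_file("coco8.yaml")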
"""
check_suffix(file, suffix) # optional
file = str(file).strip() # convert to string and strip spaces
file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
if (
not file
or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
or file.lower().startswith("grpc://")
): # file exists or gRPC Triton images
return file
elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
url = file # warning: Pathlib turns :// -> :/
file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
if file.exists():
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
downloads.safe_download(url=url, file=file, unzip=False)
return str(file)
else: # search
files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
if not files and hard:
raise FileNotFoundError(f"'{file}' does not exist")
elif len(files) > 1 and hard:
raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
return files[0] if len(files) else [] # return file
def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
"""
Search/download YAML file (if necessary) and return path, checking suffix.
Args:
file (str | Path): File name or path.
suffix (tuple): Tuple of acceptable YAML file suffixes.
hard (bool): Whether to raise an error if the file is not found or multiple files are found.
Returns:
(str): Path to the YAML file.
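Examples:
>>> file = check_yaml("coco8.yaml")  # resolved via check_file() with YAML suffix validation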
"""
return check_file(file, suffix, hard=hard)
def check_is_path_safe(basedir, path):
"""
Check if the resolved path is under the intended directory to prevent path traversal.
Args:
basedir (Path | str): The intended directory.
path (Path | str): The path to check.
Returns:
(bool): True if the path is safe, False otherwise.
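Examples:
Reject paths that escape the base directory via '..'
>>> check_is_path_safe("/home/user/datasets", "/home/user/datasets/../secrets")
False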
"""
base_dir_resolved = Path(basedir).resolve()
path_resolved = Path(path).resolve()
return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
@functools.lru_cache
def check_imshow(warn=False):
"""
Check if environment supports image displays.
Args:
warn (bool): Whether to warn if environment doesn't support image displays.
Returns:
(bool): True if environment supports image displays, False otherwise.
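Examples:
Guard GUI calls on headless systems
>>> if check_imshow(warn=True):
...     pass  # safe to call cv2.imshow() here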
"""
try:
if LINUX:
assert not IS_COLAB and not IS_KAGGLE
assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set."
cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8x8 pixel test image
cv2.waitKey(1)
cv2.destroyAllWindows()
cv2.waitKey(1)
return True
except Exception as e:
if warn:
LOGGER.warning(f"Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
return False
def check_yolo(verbose=True, device=""):
"""
Log a human-readable YOLO software and hardware summary.
Args:
verbose (bool): Whether to print verbose information.
device (str | torch.device): Device to use for YOLO.
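Examples:
>>> check_yolo(device="cpu")  # logs a system summary and selects the CPU device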
"""
import psutil # scoped as slow import
from ultralytics.utils.torch_utils import select_device
if IS_COLAB:
shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
if verbose:
# System info
gib = 1 << 30 # bytes per GiB
ram = psutil.virtual_memory().total
total, used, free = shutil.disk_usage("/")
s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
try:
from IPython import display
display.clear_output() # clear display if notebook
except ImportError:
pass
else:
s = ""
if GIT.is_repo:
check_multiple_install() # check conflicting installation if using local clone
select_device(device=device, newline=False)
LOGGER.info(f"Setup complete ✅ {s}")
def collect_system_info():
"""
Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.
Returns:
(dict): Dictionary containing system information.
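Examples:
>>> info = collect_system_info()  # logs and returns OS, Python, RAM, CPU, and CUDA details
>>> python_version = info["Python"]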
"""
import psutil # scoped as slow import
from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info
gib = 1 << 30 # bytes per GiB
cuda = torch.cuda.is_available()
check_yolo()
total, used, free = shutil.disk_usage("/")
info_dict = {
"OS": platform.platform(),
"Environment": ENVIRONMENT,
"Python": PYTHON_VERSION,
"Install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
"Path": str(ROOT),
"RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
"Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
"CPU": get_cpu_info(),
"CPU count": os.cpu_count(),
"GPU": get_gpu_info(index=0) if cuda else None,
"GPU count": torch.cuda.device_count() if cuda else None,
"CUDA": torch.version.cuda if cuda else None,
}
LOGGER.info("\n" + "\n".join(f"{k:<23}{v}" for k, v in info_dict.items()) + "\n")
package_info = {}
for r in parse_requirements(package="ultralytics"):
try:
current = metadata.version(r.name)
is_met = "" if check_version(current, str(r.specifier), name=r.name, hard=True) else ""
except metadata.PackageNotFoundError:
current = "(not installed)"
is_met = ""
package_info[r.name] = f"{is_met}{current}{r.specifier}"
LOGGER.info(f"{r.name:<23}{package_info[r.name]}")
info_dict["Package Info"] = package_info
if is_github_action_running():
github_info = {
"RUNNER_OS": os.getenv("RUNNER_OS"),
"GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
"GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
"GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
"GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
"GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
}
LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
info_dict["GitHub Info"] = github_info
return info_dict
def check_amp(model):
"""
Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.
If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
results, so AMP will be disabled during training.
Args:
model (torch.nn.Module): A YOLO model instance.
Returns:
(bool): True if the AMP functionality works correctly with a YOLO11 model, else False.
Examples:
>>> from ultralytics import YOLO
>>> from ultralytics.utils.checks import check_amp
>>> model = YOLO("yolo11n.pt").model.cuda()
>>> check_amp(model)
"""
from ultralytics.utils.torch_utils import autocast
device = next(model.parameters()).device # get model device
prefix = colorstr("AMP: ")
if device.type in {"cpu", "mps"}:
return False # AMP only used on CUDA devices
else:
# GPUs that have issues with AMP
pattern = re.compile(
r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
)
gpu = torch.cuda.get_device_name(device)
if bool(pattern.search(gpu)):
LOGGER.warning(
f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause "
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
)
return False
def amp_allclose(m, im):
"""All close FP32 vs AMP results."""
batch = [im] * 8
imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64
a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference
with autocast(enabled=True):
b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference
del m
return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
im = ASSETS / "bus.jpg" # image to check
LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
assert amp_allclose(YOLO("yolo11n.pt"), im)
LOGGER.info(f"{prefix}checks passed ✅")
except ConnectionError:
LOGGER.warning(f"{prefix}checks skipped. Offline and unable to download YOLO11n for AMP checks. {warning_msg}")
except (AttributeError, ModuleNotFoundError):
LOGGER.warning(
f"{prefix}checks skipped. "
f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}"
)
except AssertionError:
LOGGER.error(
f"{prefix}checks failed. Anomalies were detected with AMP on your system that may lead to "
f"NaN losses or zero-mAP results, so AMP will be disabled during training."
)
return False
return True
def check_multiple_install():
"""Check if there are multiple Ultralytics installations."""
import sys
try:
result = subprocess.run([sys.executable, "-m", "pip", "show", "ultralytics"], capture_output=True, text=True)
install_msg = (
f"Install your local copy in editable mode with 'pip install -e {ROOT.parent}' to avoid "
"issues. See https://docs.ultralytics.com/quickstart/"
)
if result.returncode != 0:
if "not found" in result.stderr.lower(): # Package not pip-installed but locally imported
LOGGER.warning(f"Ultralytics not found via pip but importing from: {ROOT}. {install_msg}")
return
yolo_path = (Path(re.findall(r"location:\s+(.+)", result.stdout, flags=re.I)[-1]) / "ultralytics").resolve()
if not yolo_path.samefile(ROOT.resolve()):
LOGGER.warning(
f"Multiple Ultralytics installations detected. The `yolo` command uses: {yolo_path}, "
f"but current session imports from: {ROOT}. This may cause version conflicts. {install_msg}"
)
except Exception:
return
def print_args(args: dict | None = None, show_file=True, show_func=False):
"""
Print function arguments (optional args dict).
Args:
args (dict, optional): Arguments to print.
show_file (bool): Whether to show the file name.
show_func (bool): Whether to show the function name.
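Examples:
Log a function's own arguments automatically from the calling frame
>>> def train(epochs=10, imgsz=640):
...     print_args()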
"""
def strip_auth(v):
"""Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
x = inspect.currentframe().f_back # previous frame
file, _, func, _, _ = inspect.getframeinfo(x)
if args is None: # get args automatically
args, _, _, frm = inspect.getargvalues(x)
args = {k: v for k, v in frm.items() if k in args}
try:
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
except ValueError:
file = Path(file).stem
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in sorted(args.items())))
def cuda_device_count() -> int:
"""
Get the number of NVIDIA GPUs available in the environment.
Returns:
(int): The number of NVIDIA GPUs available.
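Examples:
>>> n_gpus = cuda_device_count()  # 0 when nvidia-smi is unavailable or fails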
"""
if IS_JETSON:
# NVIDIA Jetson does not fully support nvidia-smi, so use PyTorch instead
return torch.cuda.device_count()
else:
try:
# Run the nvidia-smi command and capture its output
output = subprocess.check_output(
["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
)
# Take the first line and strip any leading/trailing white space
first_line = output.strip().split("\n", 1)[0]
return int(first_line)
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
# If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
return 0
def cuda_is_available() -> bool:
"""
Check if CUDA is available in the environment.
Returns:
(bool): True if one or more NVIDIA GPUs are available, False otherwise.
"""
return cuda_device_count() > 0
def is_rockchip():
"""
Check if the current environment is running on a Rockchip SoC.
Returns:
(bool): True if running on a Rockchip SoC, False otherwise.
"""
if LINUX and ARM64:
try:
with open("/proc/device-tree/compatible") as f:
dev_str = f.read()
*_, soc = dev_str.split(",")
if soc.replace("\x00", "") in RKNN_CHIPS:
return True
except OSError:
return False
else:
return False
def is_intel():
"""
Check if the system has Intel hardware (CPU or GPU).
Returns:
(bool): True if Intel hardware is detected, False otherwise.
"""
from ultralytics.utils.torch_utils import get_cpu_info
# Check CPU
if "intel" in get_cpu_info().lower():
return True
# Check GPU via xpu-smi
try:
result = subprocess.run(["xpu-smi", "discovery"], capture_output=True, text=True, timeout=5)
return "intel" in result.stdout.lower()
except Exception: # broad clause to capture all Intel GPU exception types
return False
def is_sudo_available() -> bool:
"""
Check if the sudo command is available in the environment.
Returns:
(bool): True if the sudo command is available, False otherwise.
"""
if WINDOWS:
return False
cmd = "sudo --version"
return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
# Run checks and define constants
check_python("3.8", hard=False, verbose=True) # check python version
check_torchvision() # check torch-torchvision compatibility
# Define constants
IS_PYTHON_3_8 = PYTHON_VERSION.startswith("3.8")
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")
IS_PYTHON_3_13 = PYTHON_VERSION.startswith("3.13")
IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
IS_PYTHON_MINIMUM_3_12 = check_python("3.12", hard=False)

90
ultralytics/utils/cpu.py Normal file
View File

@@ -0,0 +1,90 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import platform
import re
import subprocess
import sys
from pathlib import Path
class CPUInfo:
"""
Provide cross-platform CPU brand and model information.
Query platform-specific sources to retrieve a human-readable CPU descriptor and normalize it for consistent
presentation across macOS, Linux, and Windows. If platform-specific probing fails, generic platform identifiers are
used to ensure a stable string is always returned.
Methods:
name: Return the normalized CPU name using platform-specific sources with robust fallbacks.
_clean: Normalize and prettify common vendor brand strings and frequency patterns.
__str__: Return the normalized CPU name for string contexts.
Examples:
>>> CPUInfo.name()
'Apple M4 Pro'
>>> str(CPUInfo())
'Intel Core i7-9750H 2.60GHz'
"""
@staticmethod
def name() -> str:
"""Return a normalized CPU model string from platform-specific sources."""
try:
if sys.platform == "darwin":
# Query macOS sysctl for the CPU brand string
s = subprocess.run(
["sysctl", "-n", "machdep.cpu.brand_string"], capture_output=True, text=True
).stdout.strip()
if s:
return CPUInfo._clean(s)
elif sys.platform.startswith("linux"):
# Parse /proc/cpuinfo for the first "model name" entry
p = Path("/proc/cpuinfo")
if p.exists():
for line in p.read_text(errors="ignore").splitlines():
if "model name" in line:
return CPUInfo._clean(line.split(":", 1)[1])
elif sys.platform.startswith("win"):
try:
import winreg as wr
with wr.OpenKey(wr.HKEY_LOCAL_MACHINE, r"HARDWARE\DESCRIPTION\System\CentralProcessor\0") as k:
val, _ = wr.QueryValueEx(k, "ProcessorNameString")
if val:
return CPUInfo._clean(val)
except Exception:
# Fall through to generic platform fallbacks on Windows registry access failure
pass
# Generic platform fallbacks
s = platform.processor() or getattr(platform.uname(), "processor", "") or platform.machine()
return CPUInfo._clean(s or "Unknown CPU")
except Exception:
# Ensure a string is always returned even on unexpected failures
s = platform.processor() or platform.machine() or ""
return CPUInfo._clean(s or "Unknown CPU")
@staticmethod
def _clean(s: str) -> str:
"""Normalize and prettify a raw CPU descriptor string."""
s = re.sub(r"\s+", " ", s.strip())
s = s.replace("(TM)", "").replace("(tm)", "").replace("(R)", "").replace("(r)", "").strip()
# Normalize common Intel pattern to 'Model Freq'
m = re.search(r"(Intel.*?i\d[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
if m:
return f"{m.group(1)} {m.group(2)}"
# Normalize common AMD Ryzen pattern to 'Model Freq'
m = re.search(r"(AMD.*?Ryzen.*?[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
if m:
return f"{m.group(1)} {m.group(2)}"
return s
def __str__(self) -> str:
"""Return the normalized CPU name."""
return self.name()
if __name__ == "__main__":
print(CPUInfo.name())

127
ultralytics/utils/dist.py Normal file
View File

@@ -0,0 +1,127 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import os
import shutil
import sys
import tempfile
from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9
def find_free_network_port() -> int:
"""
Find a free port on localhost.
It is useful in single-node training when we don't want to connect to a real main node but have to set the
`MASTER_PORT` environment variable.
Returns:
(int): The available network port number.
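Examples:
>>> port = find_free_network_port()
>>> os.environ["MASTER_PORT"] = str(port)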
"""
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1] # port
def generate_ddp_file(trainer):
"""
Generate a DDP (Distributed Data Parallel) file for multi-GPU training.
This function creates a temporary Python file that enables distributed training across multiple GPUs.
The file contains the necessary configuration to initialize the trainer in a distributed environment.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.
Must have args attribute and be a class instance.
Returns:
(str): Path to the generated temporary DDP file.
Notes:
The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:
- Trainer class import
- Configuration overrides from the trainer arguments
- Model path configuration
- Training initialization code
"""
module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}
if __name__ == "__main__":
from {module} import {name}
from ultralytics.utils import DEFAULT_CFG_DICT
cfg = DEFAULT_CFG_DICT.copy()
cfg.update(save_dir='') # handle the extra key 'save_dir'
trainer = {name}(cfg=cfg, overrides=overrides)
trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}"
results = trainer.train()
"""
(USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
with tempfile.NamedTemporaryFile(
prefix="_temp_",
suffix=f"{id(trainer)}.py",
mode="w+",
encoding="utf-8",
dir=USER_CONFIG_DIR / "DDP",
delete=False,
) as file:
file.write(content)
return file.name
def generate_ddp_command(trainer):
"""
Generate command for distributed training.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.
Returns:
cmd (list[str]): The command to execute for distributed training.
file (str): Path to the temporary file created for DDP training.
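Examples:
A sketch assuming 'trainer' is an initialized trainer instance with world_size set
>>> cmd, ddp_file = generate_ddp_command(trainer)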
"""
import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218
if not trainer.resume:
shutil.rmtree(trainer.save_dir) # remove the save_dir
file = generate_ddp_file(trainer)
dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
port = find_free_network_port()
cmd = [
sys.executable,
"-m",
dist_cmd,
"--nproc_per_node",
f"{trainer.world_size}",
"--master_port",
f"{port}",
file,
]
return cmd, file
def ddp_cleanup(trainer, file):
"""
Delete temporary file if created during distributed data parallel (DDP) training.
This function checks if the provided file contains the trainer's ID in its name, indicating it was created
as a temporary file for DDP training, and deletes it if so.
Args:
trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.
file (str): Path to the file that might need to be deleted.
Examples:
>>> trainer = YOLOTrainer()
>>> file = "/tmp/ddp_temp_123456789.py"
>>> ddp_cleanup(trainer, file)
"""
if f"{id(trainer)}.py" in file: # if temp_file suffix in file
os.remove(file)

541
ultralytics/utils/downloads.py Normal file
View File

@@ -0,0 +1,541 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import re
import shutil
import subprocess
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from urllib import parse, request
from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = frozenset(
[f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
+ [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
+ [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)] # detect models only currently
+ [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
+ [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
+ [f"yolov8{k}-world.pt" for k in "smlx"]
+ [f"yolov8{k}-worldv2.pt" for k in "smlx"]
+ [f"yoloe-v8{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
+ [f"yoloe-11{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
+ [f"yolov9{k}.pt" for k in "tsmce"]
+ [f"yolov10{k}.pt" for k in "nsmblx"]
+ [f"yolo_nas_{k}.pt" for k in "sml"]
+ [f"sam_{k}.pt" for k in "bl"]
+ [f"sam2_{k}.pt" for k in "blst"]
+ [f"sam2.1_{k}.pt" for k in "blst"]
+ [f"FastSAM-{k}.pt" for k in "sx"]
+ [f"rtdetr-{k}.pt" for k in "lx"]
+ [
"mobile_sam.pt",
"mobileclip_blt.ts",
"yolo11n-grayscale.pt",
"calibration_image_sample_data_20x128x128x3_float32.npy.zip",
]
)
GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAMES)
def is_url(url: str | Path, check: bool = False) -> bool:
"""
Validate if the given string is a URL and optionally check if the URL exists online.
Args:
url (str): The string to be validated as a URL.
check (bool, optional): If True, performs an additional check to see if the URL exists online.
Returns:
(bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.
Examples:
>>> valid = is_url("https://www.example.com")
>>> valid_and_exists = is_url("https://www.example.com", check=True)
"""
try:
url = str(url)
result = parse.urlparse(url)
assert all([result.scheme, result.netloc]) # check if is url
if check:
with request.urlopen(url) as response:
return response.getcode() == 200 # check if exists online
return True
except Exception:
return False
def delete_dsstore(path: str | Path, files_to_delete: tuple[str, ...] = (".DS_Store", "__MACOSX")) -> None:
"""
Delete all specified system files in a directory.
Args:
path (str | Path): The directory path where the files should be deleted.
files_to_delete (tuple): The files to be deleted.
Examples:
>>> from ultralytics.utils.downloads import delete_dsstore
>>> delete_dsstore("path/to/dir")
Notes:
".DS_store" files are created by the Apple operating system and contain metadata about folders and files. They
are hidden system files and can cause issues when transferring files between different operating systems.
"""
for file in files_to_delete:
matches = list(Path(path).rglob(file))
LOGGER.info(f"Deleting {file} files: {matches}")
for f in matches:
f.unlink()
def zip_directory(
directory: str | Path,
compress: bool = True,
exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
progress: bool = True,
) -> Path:
"""
Zip the contents of a directory, excluding specified files.
The resulting zip file is named after the directory and placed alongside it.
Args:
directory (str | Path): The path to the directory to be zipped.
compress (bool): Whether to compress the files while zipping.
exclude (tuple, optional): A tuple of filename strings to be excluded.
progress (bool, optional): Whether to display a progress bar.
Returns:
(Path): The path to the resulting zip file.
Examples:
>>> from ultralytics.utils.downloads import zip_directory
>>> file = zip_directory("path/to/dir")
"""
from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile
delete_dsstore(directory)
directory = Path(directory)
if not directory.is_dir():
raise FileNotFoundError(f"Directory '{directory}' does not exist.")
# Zip with progress bar
files = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] # files to zip
zip_file = directory.with_suffix(".zip")
compression = ZIP_DEFLATED if compress else ZIP_STORED
with ZipFile(zip_file, "w", compression) as f:
for file in TQDM(files, desc=f"Zipping {directory} to {zip_file}...", unit="files", disable=not progress):
f.write(file, file.relative_to(directory))
return zip_file # return path to zip file
def unzip_file(
file: str | Path,
path: str | Path | None = None,
exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
exist_ok: bool = False,
progress: bool = True,
) -> Path:
"""
Unzip a *.zip file to the specified path, excluding specified files.
If the zipfile does not contain a single top-level directory, the function will create a new
directory with the same name as the zipfile (without the extension) to extract its contents.
If a path is not provided, the function will use the parent directory of the zipfile as the default path.
Args:
file (str | Path): The path to the zipfile to be extracted.
path (str | Path, optional): The path to extract the zipfile to.
exclude (tuple, optional): A tuple of filename strings to be excluded.
exist_ok (bool, optional): Whether to overwrite existing contents if they exist.
progress (bool, optional): Whether to display a progress bar.
Returns:
(Path): The path to the directory where the zipfile was extracted.
Raises:
BadZipFile: If the provided file does not exist or is not a valid zipfile.
Examples:
>>> from ultralytics.utils.downloads import unzip_file
>>> directory = unzip_file("path/to/file.zip")
"""
from zipfile import BadZipFile, ZipFile, is_zipfile
if not (Path(file).exists() and is_zipfile(file)):
raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
if path is None:
path = Path(file).parent # default path
# Unzip the file contents
with ZipFile(file) as zipObj:
files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
top_level_dirs = {Path(f).parts[0] for f in files}
# Decide to unzip directly or unzip into a directory
unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/"))
if unzip_as_dir:
# Zip has 1 top-level directory
extract_path = path # i.e. ../datasets
path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/
else:
# Zip has multiple files at top level
path = extract_path = Path(path) / Path(file).stem # i.e. extract multiple files to ../datasets/coco8/
# Check if destination directory already exists and contains files
if path.exists() and any(path.iterdir()) and not exist_ok:
# If it exists and is not empty, return the path without unzipping
LOGGER.warning(f"Skipping {file} unzip as destination directory {path} is not empty.")
return path
for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="files", disable=not progress):
# Ensure the file is within the extract_path to avoid path traversal security vulnerability
if ".." in Path(f).parts:
LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
continue
zipObj.extract(f, extract_path)
return path # return unzip dir
def check_disk_space(
file_bytes: int,
path: str | Path = Path.cwd(),
sf: float = 1.5,
hard: bool = True,
) -> bool:
"""
Check if there is sufficient disk space to download and store a file.
Args:
file_bytes (int): The file size in bytes.
path (str | Path, optional): The path or drive to check the available free space on.
sf (float, optional): Safety factor, the multiplier for the required free space.
hard (bool, optional): Whether to throw an error or not on insufficient disk space.
Returns:
(bool): True if there is sufficient disk space, False otherwise.
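Examples:
Check space for a 1 GB download without raising on failure
>>> ok = check_disk_space(1 << 30, hard=False)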
"""
total, used, free = shutil.disk_usage(path) # bytes
if file_bytes * sf < free:
return True # sufficient space
# Insufficient space
text = (
f"Insufficient free disk space {free >> 30:.3f} GB < {int(file_bytes * sf) >> 30:.3f} GB required, "
f"Please free {int(file_bytes * sf - free) >> 30:.3f} GB additional disk space and try again."
)
if hard:
raise MemoryError(text)
LOGGER.warning(text)
return False
def get_google_drive_file_info(link: str) -> tuple[str, str | None]:
"""
Retrieve the direct download link and filename for a shareable Google Drive file link.
Args:
link (str): The shareable link of the Google Drive file.
Returns:
url (str): Direct download URL for the Google Drive file.
filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.
Examples:
>>> from ultralytics.utils.downloads import get_google_drive_file_info
>>> link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
>>> url, filename = get_google_drive_file_info(link)
"""
import requests # scoped as slow import
file_id = link.split("/d/")[1].split("/view", 1)[0]
drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
filename = None
# Start session
with requests.Session() as session:
response = session.get(drive_url, stream=True)
if "quota exceeded" in str(response.content.lower()):
raise ConnectionError(
emojis(
f"❌ Google Drive file download quota exceeded. "
f"Please try again later or download this file manually at {link}."
)
)
for k, v in response.cookies.items():
if k.startswith("download_warning"):
drive_url += f"&confirm={v}" # v is token
if cd := response.headers.get("content-disposition"):
filename = re.findall('filename="(.+)"', cd)[0]
return drive_url, filename
def safe_download(
url: str | Path,
file: str | Path | None = None,
dir: str | Path | None = None,
unzip: bool = True,
delete: bool = False,
curl: bool = False,
retry: int = 3,
min_bytes: float = 1e0,
exist_ok: bool = False,
progress: bool = True,
) -> Path | str:
"""
Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. Includes robust
partial download detection using Content-Length validation.
Args:
url (str): The URL of the file to be downloaded.
file (str, optional): The filename of the downloaded file.
If not provided, the file will be saved with the same name as the URL.
dir (str | Path, optional): The directory to save the downloaded file.
If not provided, the file will be saved in the current working directory.
unzip (bool, optional): Whether to unzip the downloaded file.
delete (bool, optional): Whether to delete the downloaded file after unzipping.
curl (bool, optional): Whether to use curl command line tool for downloading.
retry (int, optional): The number of times to retry the download in case of failure.
min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
a successful download.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
progress (bool, optional): Whether to display a progress bar during the download.
Returns:
(Path | str): The path to the downloaded file or extracted directory.
Examples:
>>> from ultralytics.utils.downloads import safe_download
>>> link = "https://ultralytics.com/assets/bus.jpg"
>>> path = safe_download(link)
"""
gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
if gdrive:
url, file = get_google_drive_file_info(url)
f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
f = Path(url) # filename
elif not f.is_file(): # URL and file do not exist
uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url
"https://github.com/ultralytics/assets/releases/download/v0.0.0/",
"https://ultralytics.com/assets/", # assets alias
)
desc = f"Downloading {uri} to '{f}'"
f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
curl_installed = shutil.which("curl")
for i in range(retry + 1):
try:
if (curl or i > 0) and curl_installed: # curl download with retry, continue
s = "sS" * (not progress) # silent
r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
assert r == 0, f"Curl return value {r}"
expected_size = None # Can't get size with curl
else: # urllib download
with request.urlopen(url) as response:
expected_size = int(response.getheader("Content-Length", 0))
if i == 0 and expected_size > 1048576:
check_disk_space(expected_size, path=f.parent)
buffer_size = max(8192, min(1048576, expected_size // 1000)) if expected_size else 8192
with TQDM(
total=expected_size,
desc=desc,
disable=not progress,
unit="B",
unit_scale=True,
unit_divisor=1024,
) as pbar:
with open(f, "wb") as f_opened:
while True:
data = response.read(buffer_size)
if not data:
break
f_opened.write(data)
pbar.update(len(data))
if f.exists():
file_size = f.stat().st_size
if file_size > min_bytes:
# Check if download is complete (only if we have expected_size)
if expected_size and file_size != expected_size:
LOGGER.warning(
f"Partial download: {file_size}/{expected_size} bytes ({file_size / expected_size * 100:.1f}%)"
)
else:
break # success
f.unlink() # remove partial downloads
except MemoryError:
raise # Re-raise immediately - no point retrying if insufficient disk space
except Exception as e:
if i == 0 and not is_online():
raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e
elif i >= retry:
raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e
LOGGER.warning(f"Download failure, retrying {i + 1}/{retry} {uri}...")
if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
from zipfile import is_zipfile
unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
if is_zipfile(f):
unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
elif f.suffix in {".tar", ".gz"}:
LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
if delete:
f.unlink() # remove zip
return unzip_dir
return f
def get_github_assets(
repo: str = "ultralytics/assets",
version: str = "latest",
retry: bool = False,
) -> tuple[str, list[str]]:
"""
Retrieve the specified version's tag and assets from a GitHub repository.
If the version is not specified, the function fetches the latest release assets.
Args:
repo (str, optional): The GitHub repository in the format 'owner/repo'.
version (str, optional): The release version to fetch assets from.
retry (bool, optional): Flag to retry the request in case of a failure.
Returns:
tag (str): The release tag.
assets (list[str]): A list of asset names.
Examples:
>>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
"""
import requests # scoped as slow import
if version != "latest":
version = f"tags/{version}" # i.e. tags/v6.2
url = f"https://api.github.com/repos/{repo}/releases/{version}"
r = requests.get(url) # github api
if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
r = requests.get(url) # try again
if r.status_code != 200:
LOGGER.warning(f"GitHub assets check failure for {url}: {r.status_code} {r.reason}")
return "", []
data = r.json()
return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]
def attempt_download_asset(
file: str | Path,
repo: str = "ultralytics/assets",
release: str = "v8.3.0",
**kwargs,
) -> str:
"""
Attempt to download a file from GitHub release assets if it is not found locally.
Args:
file (str | Path): The filename or file path to be downloaded.
repo (str, optional): The GitHub repository in the format 'owner/repo'.
release (str, optional): The specific release version to be downloaded.
**kwargs (Any): Additional keyword arguments for the download process.
Returns:
(str): The path to the downloaded file.
Examples:
>>> file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest")
"""
from ultralytics.utils import SETTINGS # scoped for circular import
# YOLOv3/5u updates
file = str(file)
file = checks.check_yolov5u_filename(file)
file = Path(file.strip().replace("'", ""))
if file.exists():
return str(file)
elif (SETTINGS["weights_dir"] / file).exists():
return str(SETTINGS["weights_dir"] / file)
else:
# URL specified
name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
download_url = f"https://github.com/{repo}/releases/download"
if str(file).startswith(("http:/", "https:/")): # download
url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
file = url2file(name) # parse authentication https://url.com/file.txt?auth...
if Path(file).is_file():
LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
else:
safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
else:
tag, assets = get_github_assets(repo, release)
if not assets:
tag, assets = get_github_assets(repo) # latest release
if name in assets:
safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
return str(file)
def download(
url: str | list[str] | Path,
dir: Path = Path.cwd(),
unzip: bool = True,
delete: bool = False,
curl: bool = False,
threads: int = 1,
retry: int = 3,
exist_ok: bool = False,
) -> None:
"""
Download files from specified URLs to a given directory.
Supports concurrent downloads if multiple threads are specified.
Args:
url (str | list[str]): The URL or list of URLs of the files to be downloaded.
dir (Path, optional): The directory where the files will be saved.
unzip (bool, optional): Flag to unzip the files after downloading.
delete (bool, optional): Flag to delete the zip files after extraction.
curl (bool, optional): Flag to use curl for downloading.
threads (int, optional): Number of threads to use for concurrent downloads.
retry (int, optional): Number of retries in case of download failure.
exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
Examples:
>>> download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
"""
dir = Path(dir)
dir.mkdir(parents=True, exist_ok=True) # make directory
urls = [url] if isinstance(url, (str, Path)) else url
if threads > 1:
LOGGER.info(f"Downloading {len(urls)} file(s) with {threads} threads to {dir}...")
with ThreadPool(threads) as pool:
pool.map(
lambda x: safe_download(
url=x[0],
dir=x[1],
unzip=unzip,
delete=delete,
curl=curl,
retry=retry,
exist_ok=exist_ok,
progress=True,
),
zip(urls, repeat(dir)),
)
pool.close()
pool.join()
else:
for u in urls:
safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)

43
ultralytics/utils/errors.py Normal file
View File

@@ -0,0 +1,43 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.utils import emojis
class HUBModelError(Exception):
"""
Exception raised when a model cannot be found or retrieved from ultralytics HUB.
This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.
The error message is processed to include emojis for better user experience.
Attributes:
message (str): The error message displayed when the exception is raised.
Methods:
__init__: Initialize the HUBModelError with a custom message.
Examples:
>>> try:
... # Code that might fail to find a model
... raise HUBModelError("Custom model not found message")
... except HUBModelError as e:
... print(e) # Displays the emoji-enhanced error message
"""
def __init__(self, message: str = "Model not found. Please check model URL and try again."):
"""
Initialize a HUBModelError exception.
This exception is raised when a requested model is not found or cannot be retrieved from ultralytics HUB.
The message is processed to include emojis for better user experience.
Args:
message (str, optional): The error message to display when the exception is raised.
Examples:
>>> try:
... raise HUBModelError("Custom model error message")
... except HUBModelError as e:
... print(e)
"""
super().__init__(emojis(message))

115
ultralytics/utils/events.py Normal file
View File

@@ -0,0 +1,115 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import json
import random
import time
from pathlib import Path
from threading import Thread
from urllib.request import Request, urlopen
from ultralytics import SETTINGS, __version__
from ultralytics.utils import ARGV, ENVIRONMENT, GIT, IS_PIP_PACKAGE, ONLINE, PYTHON_VERSION, RANK, TESTS_RUNNING
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
from ultralytics.utils.torch_utils import get_cpu_info
def _post(url: str, data: dict, timeout: float = 5.0) -> None:
"""Send a one-shot JSON POST request."""
try:
body = json.dumps(data, separators=(",", ":")).encode() # compact JSON
req = Request(url, data=body, headers={"Content-Type": "application/json"})
urlopen(req, timeout=timeout).close()
except Exception:
pass
class Events:
"""
Collect and send anonymous usage analytics with rate-limiting.
Event collection and transmission are enabled when sync is enabled in settings, the current process is rank -1 or 0,
tests are not running, the environment is online, and the installation source is either pip or the official
Ultralytics GitHub repository.
Attributes:
url (str): Measurement Protocol endpoint for receiving anonymous events.
events (list[dict]): In-memory queue of event payloads awaiting transmission.
rate_limit (float): Minimum time in seconds between POST requests.
t (float): Timestamp of the last transmission in seconds since the epoch.
metadata (dict): Static metadata describing runtime, installation source, and environment.
enabled (bool): Flag indicating whether analytics collection is active.
Methods:
__init__: Initialize the event queue, rate limiter, and runtime metadata.
__call__: Queue an event and trigger a non-blocking send when the rate limit elapses.
"""
url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
def __init__(self) -> None:
"""Initialize the Events instance with queue, rate limiter, and environment metadata."""
self.events = [] # pending events
self.rate_limit = 30.0 # rate limit (seconds)
self.t = 0.0 # last send timestamp (seconds)
self.metadata = {
"cli": Path(ARGV[0]).name == "yolo",
"install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
"python": PYTHON_VERSION.rsplit(".", 1)[0], # i.e. 3.13
"CPU": get_cpu_info(),
# "GPU": get_gpu_info(index=0) if cuda else None,
"version": __version__,
"env": ENVIRONMENT,
"session_id": round(random.random() * 1e15),
"engagement_time_msec": 1000,
}
self.enabled = (
SETTINGS["sync"]
and RANK in {-1, 0}
and not TESTS_RUNNING
and ONLINE
and (IS_PIP_PACKAGE or GIT.origin == "https://github.com/ultralytics/ultralytics.git")
)
def __call__(self, cfg, device=None) -> None:
"""
Queue an event and flush the queue asynchronously when the rate limit elapses.
Args:
cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').
"""
if not self.enabled:
# Events disabled, do nothing
return
# Attempt to enqueue a new event
if len(self.events) < 25: # Queue limited to 25 events to bound memory and traffic
params = {
**self.metadata,
"task": cfg.task,
"model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
"device": str(device),
}
if cfg.mode == "export":
params["format"] = cfg.format
self.events.append({"name": cfg.mode, "params": params})
# Check rate limit and return early if under limit
t = time.time()
if (t - self.t) < self.rate_limit:
return
# Over the rate limit: send a snapshot of queued events in a background thread
payload_events = list(self.events) # snapshot to avoid race with queue reset
Thread(
target=_post,
args=(self.url, {"client_id": SETTINGS["uuid"], "events": payload_events}), # SHA-256 anonymized
daemon=True,
).start()
# Reset queue and rate limit timer
self.events = []
self.t = t
events = Events()

239
ultralytics/utils/export/__init__.py Normal file
View File

@@ -0,0 +1,239 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import json
from pathlib import Path
import torch
from ultralytics.utils import IS_JETSON, LOGGER
from .imx import torch2imx # noqa
def torch2onnx(
torch_model: torch.nn.Module,
im: torch.Tensor,
onnx_file: str,
opset: int = 14,
input_names: list[str] = ["images"],
output_names: list[str] = ["output0"],
dynamic: bool | dict = False,
) -> None:
"""
Export a PyTorch model to ONNX format.
Args:
torch_model (torch.nn.Module): The PyTorch model to export.
im (torch.Tensor): Example input tensor for the model.
onnx_file (str): Path to save the exported ONNX file.
opset (int): ONNX opset version to use for export.
input_names (list[str]): List of input tensor names.
output_names (list[str]): List of output tensor names.
dynamic (bool | dict, optional): Whether to enable dynamic axes.
Notes:
Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
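Examples:
A minimal sketch with a toy module (assumes the onnx package is installed)
>>> import torch
>>> m = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.SiLU())
>>> torch2onnx(m.eval(), torch.zeros(1, 3, 64, 64), "tiny.onnx")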
"""
torch.onnx.export(
torch_model,
im,
onnx_file,
verbose=False,
opset_version=opset,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic or None,
)
def onnx2engine(
onnx_file: str,
engine_file: str | None = None,
workspace: int | None = None,
half: bool = False,
int8: bool = False,
dynamic: bool = False,
shape: tuple[int, int, int, int] = (1, 3, 640, 640),
dla: int | None = None,
dataset=None,
metadata: dict | None = None,
verbose: bool = False,
prefix: str = "",
) -> None:
"""
Export a YOLO model to TensorRT engine format.
Args:
onnx_file (str): Path to the ONNX file to be converted.
engine_file (str, optional): Path to save the generated TensorRT engine file.
workspace (int, optional): Workspace size in GB for TensorRT.
half (bool, optional): Enable FP16 precision.
int8 (bool, optional): Enable INT8 precision.
dynamic (bool, optional): Enable dynamic input shapes.
shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
dla (int, optional): DLA core to use (Jetson devices only).
dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
metadata (dict, optional): Metadata to include in the engine file.
verbose (bool, optional): Enable verbose logging.
prefix (str, optional): Prefix for log messages.
Raises:
ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
RuntimeError: If the ONNX file cannot be parsed.
Notes:
TensorRT version compatibility is handled for workspace size and engine building.
INT8 calibration requires a dataset and generates a calibration cache.
Metadata is serialized and written to the engine file if provided.
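Examples:
A sketch assuming TensorRT is installed and 'yolo11n.onnx' already exists
>>> onnx2engine("yolo11n.onnx", half=True, workspace=4)  # writes yolo11n.engine alongside the ONNX file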
"""
import tensorrt as trt # noqa
engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
logger = trt.Logger(trt.Logger.INFO)
if verbose:
logger.min_severity = trt.Logger.Severity.VERBOSE
# Engine builder
builder = trt.Builder(logger)
config = builder.create_builder_config()
workspace_bytes = int((workspace or 0) * (1 << 30))
is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
if is_trt10 and workspace_bytes > 0:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
elif workspace_bytes > 0: # TensorRT versions 7, 8
config.max_workspace_size = workspace_bytes
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
half = builder.platform_has_fast_fp16 and half
int8 = builder.platform_has_fast_int8 and int8
# Optionally switch to DLA if enabled
if dla is not None:
if not IS_JETSON:
raise ValueError("DLA is only available on NVIDIA Jetson devices")
LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
if not half and not int8:
raise ValueError(
"DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
)
config.default_device_type = trt.DeviceType.DLA
config.DLA_core = int(dla)
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
# Read ONNX file
parser = trt.OnnxParser(network, logger)
if not parser.parse_from_file(onnx_file):
raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
# Network inputs
inputs = [network.get_input(i) for i in range(network.num_inputs)]
outputs = [network.get_output(i) for i in range(network.num_outputs)]
for inp in inputs:
LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
for out in outputs:
LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
if dynamic:
profile = builder.create_optimization_profile()
min_shape = (1, shape[1], 32, 32) # minimum input shape
max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape
for inp in inputs:
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
config.add_optimization_profile(profile)
if int8:
config.set_calibration_profile(profile)
LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
if int8:
config.set_flag(trt.BuilderFlag.INT8)
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
class EngineCalibrator(trt.IInt8Calibrator):
"""
Custom INT8 calibrator for TensorRT engine optimization.
This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
using a dataset. It handles batch generation, caching, and calibration algorithm selection.
Attributes:
dataset: Dataset for calibration.
data_iter: Iterator over the calibration dataset.
algo (trt.CalibrationAlgoType): Calibration algorithm type.
batch (int): Batch size for calibration.
cache (Path): Path to save the calibration cache.
Methods:
get_algorithm: Get the calibration algorithm to use.
get_batch_size: Get the batch size to use for calibration.
get_batch: Get the next batch to use for calibration.
read_calibration_cache: Use existing cache instead of calibrating again.
write_calibration_cache: Write calibration cache to disk.
"""
def __init__(
self,
dataset, # ultralytics.data.build.InfiniteDataLoader
cache: str = "",
) -> None:
"""Initialize the INT8 calibrator with dataset and cache path."""
trt.IInt8Calibrator.__init__(self)
self.dataset = dataset
self.data_iter = iter(dataset)
self.algo = (
trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2
if dla is not None
else trt.CalibrationAlgoType.MINMAX_CALIBRATION
)
self.batch = dataset.batch_size
self.cache = Path(cache)
def get_algorithm(self) -> trt.CalibrationAlgoType:
"""Get the calibration algorithm to use."""
return self.algo
def get_batch_size(self) -> int:
"""Get the batch size to use for calibration."""
return self.batch or 1
def get_batch(self, names) -> list[int] | None:
"""Get the next batch to use for calibration, as a list of device memory pointers."""
try:
im0s = next(self.data_iter)["img"] / 255.0
im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
return [int(im0s.data_ptr())]
except StopIteration:
# Return None to signal to TensorRT there is no calibration data remaining
return None
def read_calibration_cache(self) -> bytes | None:
"""Use existing cache instead of calibrating again, otherwise, implicitly return None."""
if self.cache.exists() and self.cache.suffix == ".cache":
return self.cache.read_bytes()
def write_calibration_cache(self, cache: bytes) -> None:
"""Write calibration cache to disk."""
_ = self.cache.write_bytes(cache)
# Load dataset w/ builder (for batching) and calibrate
config.int8_calibrator = EngineCalibrator(
dataset=dataset,
cache=str(Path(onnx_file).with_suffix(".cache")),
)
elif half:
config.set_flag(trt.BuilderFlag.FP16)
# Write file
build = builder.build_serialized_network if is_trt10 else builder.build_engine
with build(network, config) as engine, open(engine_file, "wb") as t:
# Metadata
if metadata is not None:
meta = json.dumps(metadata)
t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
t.write(meta.encode())
# Model
t.write(engine if is_trt10 else engine.serialize())
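A minimal usage sketch (paths are illustrative; assumes TensorRT is installed and an ONNX export already exists):

    # Build an FP16 engine with a 4 GiB workspace; engine path defaults to the .onnx path with .engine suffix
    onnx2engine("yolo11n.onnx", half=True, workspace=4)
    # Dynamic shapes add an optimization profile between (1, 3, 32, 32) and a workspace-scaled maximum
    onnx2engine("yolo11n.onnx", dynamic=True, shape=(1, 3, 640, 640), workspace=4)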


@@ -0,0 +1,289 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import subprocess
import types
from pathlib import Path
import torch
from ultralytics.nn.modules import Detect, Pose
from ultralytics.utils import LOGGER
from ultralytics.utils.tal import make_anchors
from ultralytics.utils.torch_utils import copy_attr
class FXModel(torch.nn.Module):
"""
A custom model class for torch.fx compatibility.
This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
copying.
Attributes:
model (nn.Module): The original model's layers.
"""
def __init__(self, model, imgsz=(640, 640)):
"""
Initialize the FXModel.
Args:
model (nn.Module): The original model to wrap for torch.fx compatibility.
imgsz (tuple[int, int]): The input image size (height, width). Default is (640, 640).
"""
super().__init__()
copy_attr(self, model)
# Explicitly set `model` since `copy_attr` somehow does not copy it.
self.model = model.model
self.imgsz = imgsz
def forward(self, x):
"""
Forward pass through the model.
This method performs the forward pass through the model, handling the dependencies between layers and saving
intermediate outputs.
Args:
x (torch.Tensor): The input tensor to the model.
Returns:
(torch.Tensor): The output tensor from the model.
"""
y = [] # outputs
for m in self.model:
if m.f != -1: # if not from previous layer
# from earlier layers
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
if isinstance(m, Detect):
m._inference = types.MethodType(_inference, m) # bind method to Detect
m.anchors, m.strides = (
x.transpose(0, 1)
for x in make_anchors(
torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
)
)
if type(m) is Pose:
m.forward = types.MethodType(pose_forward, m) # bind method to Pose
x = m(x) # run
y.append(x) # save output
return x
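A rough wrapping sketch (names are illustrative; assumes `det_model` is a de-paralleled Ultralytics detection model whose head is a `Detect` module):

    import torch

    fx_model = FXModel(det_model, imgsz=(640, 640)).eval()
    with torch.no_grad():
        out = fx_model(torch.zeros(1, 3, 640, 640))  # head outputs decoded by the bound _inference below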
def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
"""Decode boxes and cls scores for imx object detection."""
x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Forward pass for imx pose estimation, including keypoint decoding."""
bs = x[0].shape[0] # batch size
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
x = Detect.forward(self, x)
pred_kpt = self.kpts_decode(bs, kpt)
return (*x, pred_kpt.permute(0, 2, 1))
class NMSWrapper(torch.nn.Module):
"""Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
def __init__(
self,
model: torch.nn.Module,
score_threshold: float = 0.001,
iou_threshold: float = 0.7,
max_detections: int = 300,
task: str = "detect",
):
"""
Initialize NMSWrapper with PyTorch Module and NMS parameters.
Args:
model (torch.nn.Module): Model instance.
score_threshold (float): Score threshold for non-maximum suppression.
iou_threshold (float): Intersection over union threshold for non-maximum suppression.
max_detections (int): The number of detections to return.
task (str): Task type, either 'detect' or 'pose'.
"""
super().__init__()
self.model = model
self.score_threshold = score_threshold
self.iou_threshold = iou_threshold
self.max_detections = max_detections
self.task = task
def forward(self, images):
"""Forward pass with model inference and NMS post-processing."""
from sony_custom_layers.pytorch import multiclass_nms_with_indices
# model inference
outputs = self.model(images)
boxes, scores = outputs[0], outputs[1]
nms_outputs = multiclass_nms_with_indices(
boxes=boxes,
scores=scores,
score_threshold=self.score_threshold,
iou_threshold=self.iou_threshold,
max_detections=self.max_detections,
)
if self.task == "pose":
kpts = outputs[2] # (bs, max_detections, kpts 17*3)
out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
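A hedged usage sketch (requires the sony_custom_layers package; assumes `quant_model` emits (boxes, scores) in the layout produced by the decoded head above):

    import torch

    nms_model = NMSWrapper(quant_model, score_threshold=0.25, iou_threshold=0.7, max_detections=300, task="detect")
    images = torch.zeros(1, 3, 640, 640)  # placeholder input batch
    boxes, scores, labels, n_valid = nms_model(images)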
def torch2imx(
model: torch.nn.Module,
file: Path | str,
conf: float,
iou: float,
max_det: int,
metadata: dict | None = None,
gptq: bool = False,
dataset=None,
prefix: str = "",
):
"""
Export YOLO model to IMX format for deployment on Sony IMX500 devices.
This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it
to IMX format compatible with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n
models for detection and pose estimation tasks.
Args:
model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
file (Path | str): Output file path for the exported model.
conf (float): Confidence threshold for NMS post-processing.
iou (float): IoU threshold for NMS post-processing.
max_det (int): Maximum number of detections to return.
metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization.
If False, uses standard Post Training Quantization. Defaults to False.
dataset (optional): Representative dataset for quantization calibration. Defaults to None.
prefix (str, optional): Logging prefix string. Defaults to "".
Returns:
f (Path): Path to the exported IMX model directory
Raises:
ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
Example:
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt")
>>> path = torch2imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
Note:
- Requires model_compression_toolkit, onnx, edgemdt_tpc, and sony_custom_layers packages
- Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
- Output includes quantized ONNX model, IMX binary, and labels.txt file
"""
import model_compression_toolkit as mct
import onnx
from edgemdt_tpc import get_target_platform_capabilities
LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
def representative_dataset_gen(dataloader=dataset):
for batch in dataloader:
img = batch["img"]
img = img / 255.0
yield [img]
tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
bit_cfg = mct.core.BitWidthConfig()
if "C2PSA" in model.__str__(): # YOLO11
if model.task == "detect":
layer_names = ["sub", "mul_2", "add_14", "cat_21"]
weights_memory = 2585350.2439
n_layers = 238 # 238 layers for fused YOLO11n
elif model.task == "pose":
layer_names = ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"]
weights_memory = 2437771.67
n_layers = 257 # 257 layers for fused YOLO11n-pose
else: # YOLOv8
if model.task == "detect":
layer_names = ["sub", "mul", "add_6", "cat_17"]
weights_memory = 2550540.8
n_layers = 168 # 168 layers for fused YOLOv8n
elif model.task == "pose":
layer_names = ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"]
weights_memory = 2482451.85
n_layers = 187 # 187 layers for fused YOLOv8n-pose
# Check if the model has the expected number of layers
if len(list(model.modules())) != n_layers:
raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
for layer_name in layer_names:
bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
config = mct.core.CoreConfig(
mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
bit_width_config=bit_cfg,
)
resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
quant_model = (
mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization
model=model,
representative_data_gen=representative_dataset_gen,
target_resource_utilization=resource_utilization,
gptq_config=mct.gptq.get_pytorch_gptq_config(
n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
),
core_config=config,
target_platform_capabilities=tpc,
)[0]
if gptq
else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization
in_module=model,
representative_data_gen=representative_dataset_gen,
target_resource_utilization=resource_utilization,
core_config=config,
target_platform_capabilities=tpc,
)[0]
)
quant_model = NMSWrapper(
model=quant_model,
score_threshold=conf or 0.001,
iou_threshold=iou,
max_detections=max_det,
task=model.task,
)
f = Path(str(file).replace(file.suffix, "_imx_model"))
f.mkdir(exist_ok=True)
onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx")) # quantized ONNX output path
mct.exporter.pytorch_export_model(
model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
)
model_onnx = onnx.load(onnx_model) # load onnx model
for k, v in (metadata or {}).items(): # no-op when metadata is None
meta = model_onnx.metadata_props.add()
meta.key, meta.value = k, str(v)
onnx.save(model_onnx, onnx_model)
subprocess.run(
["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
check=True,
)
# Needed for imx models.
with open(f / "labels.txt", "w", encoding="utf-8") as file:
file.writelines([f"{name}\n" for _, name in model.names.items()])
return f
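The calibration loader passed as `dataset` is only assumed to yield dicts with a uint8 "img" batch; a toy stand-in for `ultralytics.data.build.InfiniteDataLoader` (illustrative only, sized to match num_of_images=10 above):

    import torch

    class ToyCalibrationLoader:
        """Yield a few random YOLO-shaped batches for quantization calibration."""

        def __iter__(self):
            for _ in range(10):
                yield {"img": torch.randint(0, 256, (1, 3, 640, 640), dtype=torch.uint8)}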

ultralytics/utils/files.py Normal file

@@ -0,0 +1,223 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import contextlib
import glob
import os
import shutil
import tempfile
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
class WorkingDirectory(contextlib.ContextDecorator):
"""
A context manager and decorator for temporarily changing the working directory.
This class allows for the temporary change of the working directory using a context manager or decorator.
It ensures that the original working directory is restored after the context or decorated function completes.
Attributes:
dir (Path | str): The new directory to switch to.
cwd (Path): The original current working directory before the switch.
Methods:
__enter__: Changes the current directory to the specified directory.
__exit__: Restores the original working directory on context exit.
Examples:
Using as a context manager:
>>> with WorkingDirectory('/path/to/new/dir'):
>>> # Perform operations in the new directory
>>> pass
Using as a decorator:
>>> @WorkingDirectory('/path/to/new/dir')
>>> def some_function():
>>> # Perform operations in the new directory
>>> pass
"""
def __init__(self, new_dir: str | Path):
"""Initialize the WorkingDirectory context manager with the target directory."""
self.dir = new_dir # new dir
self.cwd = Path.cwd().resolve() # current dir
def __enter__(self):
"""Change the current working directory to the specified directory upon entering the context."""
os.chdir(self.dir)
def __exit__(self, exc_type, exc_val, exc_tb): # noqa
"""Restore the original working directory when exiting the context."""
os.chdir(self.cwd)
@contextmanager
def spaces_in_path(path: str | Path):
"""
Context manager to handle paths with spaces in their names.
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes
the context code block, then copies the file/directory back to its original location.
Args:
path (str | Path): The original path that may contain spaces.
Yields:
(Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the
original path.
Examples:
>>> with spaces_in_path('/path/with spaces') as new_path:
>>> # Your code here
>>> pass
"""
# If path has spaces, replace them with underscores
if " " in str(path):
string = isinstance(path, str) # input type
path = Path(path)
# Create a temporary directory and construct the new path
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
# Copy file/directory
if path.is_dir():
shutil.copytree(path, tmp_path)
elif path.is_file():
tmp_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(path, tmp_path)
try:
# Yield the temporary path
yield str(tmp_path) if string else tmp_path
finally:
# Copy file/directory back
if tmp_path.is_dir():
shutil.copytree(tmp_path, path, dirs_exist_ok=True)
elif tmp_path.is_file():
shutil.copy2(tmp_path, path) # Copy back the file
else:
# If there are no spaces, just yield the original path
yield path
def increment_path(path: str | Path, exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
"""
Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
number will be appended directly to the end of the path.
Args:
path (str | Path): Path to increment.
exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.
sep (str, optional): Separator to use between the path and the incrementation number.
mkdir (bool, optional): Create a directory if it does not exist.
Returns:
(Path): Incremented path.
Examples:
Increment a directory path:
>>> from pathlib import Path
>>> path = Path("runs/exp")
>>> new_path = increment_path(path)
>>> print(new_path)
runs/exp2
Increment a file path:
>>> path = Path("runs/exp/results.txt")
>>> new_path = increment_path(path)
>>> print(new_path)
runs/exp/results2.txt
"""
path = Path(path) # os-agnostic
if path.exists() and not exist_ok:
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
# Method 1
for n in range(2, 9999):
p = f"{path}{sep}{n}{suffix}" # increment path
if not os.path.exists(p):
break
path = Path(p)
if mkdir:
path.mkdir(parents=True, exist_ok=True) # make directory
return path
def file_age(path: str | Path = __file__) -> int:
"""Return days since the last modification of the specified file."""
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
return dt.days # + dt.seconds / 86400 # fractional days
def file_date(path: str | Path = __file__) -> str:
"""Return the file modification date in 'YYYY-M-D' format."""
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
return f"{t.year}-{t.month}-{t.day}"
def file_size(path: str | Path) -> float:
"""Return the size of a file or directory in megabytes (MB)."""
if isinstance(path, (str, Path)):
mb = 1 << 20 # bytes to MiB (1024 ** 2)
path = Path(path)
if path.is_file():
return path.stat().st_size / mb
elif path.is_dir():
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
return 0.0
def get_latest_run(search_dir: str = ".") -> str:
"""Return the path to the most recent 'last.pt' file in the specified directory for resuming training."""
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
return max(last_list, key=os.path.getctime) if last_list else ""
def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), update_names: bool = False):
"""
Update and re-save specified YOLO models in an 'updated_models' subdirectory.
Args:
model_names (tuple, optional): Model filenames to update.
source_dir (Path, optional): Directory containing models and target subdirectory.
update_names (bool, optional): Update model names from a data YAML.
Examples:
Update specified YOLO models and save them in 'updated_models' subdirectory:
>>> from ultralytics.utils.files import update_models
>>> model_names = ("yolo11n.pt", "yolov8s.pt")
>>> update_models(model_names, source_dir=Path("/models"), update_names=True)
"""
from ultralytics import YOLO
from ultralytics.nn.autobackend import default_class_names
target_dir = source_dir / "updated_models"
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
for model_name in model_names:
model_path = source_dir / model_name
print(f"Loading model from {model_path}")
# Load model
model = YOLO(model_path)
model.half()
if update_names: # update model names from a dataset YAML
model.model.names = default_class_names("coco8.yaml")
# Define new save path
save_path = target_dir / model_name
# Save model using model.save()
print(f"Re-saving {model_name} model to {save_path}")
model.save(save_path)

ultralytics/utils/git.py Normal file

@@ -0,0 +1,139 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from functools import cached_property
from pathlib import Path
class GitRepo:
"""
Represent a local Git repository and expose branch, commit, and remote metadata.
This class discovers the repository root by searching for a .git entry from the given path upward, resolves the
actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not
invoke the git binary and therefore works in restricted environments. All metadata properties are resolved
lazily and cached; construct a new instance to refresh state.
Attributes:
root (Path | None): Repository root directory containing the .git entry; None if not in a repository.
gitdir (Path | None): Resolved .git directory path; handles worktrees; None if unresolved.
head (str | None): Raw contents of HEAD; a SHA for detached HEAD or "ref: <refname>" for branch heads.
is_repo (bool): Whether the provided path resides inside a Git repository.
branch (str | None): Current branch name when HEAD points to a branch; None for detached HEAD or non-repo.
commit (str | None): Current commit SHA for HEAD; None if not determinable.
origin (str | None): URL of the "origin" remote as read from gitdir/config; None if unset or unavailable.
Examples:
Initialize from the current working directory and read metadata
>>> from pathlib import Path
>>> repo = GitRepo(Path.cwd())
>>> repo.is_repo
True
>>> repo.branch, repo.commit[:7], repo.origin
('main', '1a2b3c4', 'https://example.com/owner/repo.git')
Notes:
- Resolves metadata by reading files: HEAD, packed-refs, and config; no subprocess calls are used.
- Caches properties on first access using cached_property; recreate the object to reflect repository changes.
"""
def __init__(self, path: Path = Path(__file__).resolve()):
"""
Initialize a Git repository context by discovering the repository root from a starting path.
Args:
path (Path, optional): File or directory path used as the starting point to locate the repository root.
"""
self.root = self._find_root(path)
self.gitdir = self._gitdir(self.root) if self.root else None
@staticmethod
def _find_root(p: Path) -> Path | None:
"""Return repo root or None."""
return next((d for d in [p] + list(p.parents) if (d / ".git").exists()), None)
@staticmethod
def _gitdir(root: Path) -> Path | None:
"""Resolve actual .git directory (handles worktrees)."""
g = root / ".git"
if g.is_dir():
return g
if g.is_file():
t = g.read_text(errors="ignore").strip()
if t.startswith("gitdir:"):
return (root / t.split(":", 1)[1].strip()).resolve()
return None
def _read(self, p: Path | None) -> str | None:
"""Read and strip file if exists."""
return p.read_text(errors="ignore").strip() if p and p.exists() else None
@cached_property
def head(self) -> str | None:
"""HEAD file contents."""
return self._read(self.gitdir / "HEAD" if self.gitdir else None)
def _ref_commit(self, ref: str) -> str | None:
"""Commit for ref (handles packed-refs)."""
rf = self.gitdir / ref
s = self._read(rf)
if s:
return s
pf = self.gitdir / "packed-refs"
b = pf.read_bytes().splitlines() if pf.exists() else []
tgt = ref.encode()
for line in b:
if line[:1] in (b"#", b"^") or b" " not in line:
continue
sha, name = line.split(b" ", 1)
if name.strip() == tgt:
return sha.decode()
return None
@property
def is_repo(self) -> bool:
"""True if inside a git repo."""
return self.gitdir is not None
@cached_property
def branch(self) -> str | None:
"""Current branch or None."""
if not self.is_repo or not self.head or not self.head.startswith("ref: "):
return None
ref = self.head[5:].strip()
return ref[len("refs/heads/") :] if ref.startswith("refs/heads/") else ref
@cached_property
def commit(self) -> str | None:
"""Current commit SHA or None."""
if not self.is_repo or not self.head:
return None
return self._ref_commit(self.head[5:].strip()) if self.head.startswith("ref: ") else self.head
@cached_property
def origin(self) -> str | None:
"""Origin URL or None."""
if not self.is_repo:
return None
cfg = self.gitdir / "config"
remote, url = None, None
for s in (self._read(cfg) or "").splitlines():
t = s.strip()
if t.startswith("[") and t.endswith("]"):
remote = t.lower()
elif t.lower().startswith("url =") and remote == '[remote "origin"]':
url = t.split("=", 1)[1].strip()
break
return url
if __name__ == "__main__":
import time
g = GitRepo()
if g.is_repo:
t0 = time.perf_counter()
print(f"repo={g.root}\nbranch={g.branch}\ncommit={g.commit}\norigin={g.origin}")
dt = (time.perf_counter() - t0) * 1000
print(f"\n⏱️ Profiling: total {dt:.3f} ms")


@@ -0,0 +1,505 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from collections import abc
from itertools import repeat
from numbers import Number
import numpy as np
from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
def _ntuple(n):
"""Create a function that converts input to n-tuple by repeating singleton values."""
def parse(x):
"""Parse input to return n-tuple by repeating singleton values n times."""
return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
to_4tuple = _ntuple(4)
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height(YOLO format)
# `ltwh` means left top and width, height(COCO format)
_formats = ["xyxy", "xywh", "ltwh"]
__all__ = ("Bboxes", "Instances") # tuple or list
class Bboxes:
"""
A class for handling bounding boxes in multiple formats.
The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format
conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.
Attributes:
bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).
format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
Methods:
convert: Convert bounding box format from one type to another.
areas: Calculate the area of bounding boxes.
mul: Multiply bounding box coordinates by scale factor(s).
add: Add offset to bounding box coordinates.
concatenate: Concatenate multiple Bboxes objects.
Examples:
Create bounding boxes in YOLO format
>>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format="xywh")
>>> bboxes.convert("xyxy")
>>> print(bboxes.areas())
Notes:
This class does not handle normalization or denormalization of bounding boxes.
"""
def __init__(self, bboxes: np.ndarray, format: str = "xyxy") -> None:
"""
Initialize the Bboxes class with bounding box data in a specified format.
Args:
bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).
format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'.
"""
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
assert bboxes.ndim == 2
assert bboxes.shape[1] == 4
self.bboxes = bboxes
self.format = format
def convert(self, format: str) -> None:
"""
Convert bounding box format from one type to another.
Args:
format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
"""
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
if self.format == format:
return
elif self.format == "xyxy":
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
elif self.format == "xywh":
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
else:
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
self.bboxes = func(self.bboxes)
self.format = format
def areas(self) -> np.ndarray:
"""Calculate the area of bounding boxes."""
return (
(self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy
if self.format == "xyxy"
else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh
)
def mul(self, scale: int | tuple | list) -> None:
"""
Multiply bounding box coordinates by scale factor(s).
Args:
scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to
all coordinates.
"""
if isinstance(scale, Number):
scale = to_4tuple(scale)
assert isinstance(scale, (tuple, list))
assert len(scale) == 4
self.bboxes[:, 0] *= scale[0]
self.bboxes[:, 1] *= scale[1]
self.bboxes[:, 2] *= scale[2]
self.bboxes[:, 3] *= scale[3]
def add(self, offset: int | tuple | list) -> None:
"""
Add offset to bounding box coordinates.
Args:
offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to
all coordinates.
"""
if isinstance(offset, Number):
offset = to_4tuple(offset)
assert isinstance(offset, (tuple, list))
assert len(offset) == 4
self.bboxes[:, 0] += offset[0]
self.bboxes[:, 1] += offset[1]
self.bboxes[:, 2] += offset[2]
self.bboxes[:, 3] += offset[3]
def __len__(self) -> int:
"""Return the number of bounding boxes."""
return len(self.bboxes)
@classmethod
def concatenate(cls, boxes_list: list[Bboxes], axis: int = 0) -> Bboxes:
"""
Concatenate a list of Bboxes objects into a single Bboxes object.
Args:
boxes_list (list[Bboxes]): A list of Bboxes objects to concatenate.
axis (int, optional): The axis along which to concatenate the bounding boxes.
Returns:
(Bboxes): A new Bboxes object containing the concatenated bounding boxes.
Notes:
The input should be a list or tuple of Bboxes objects.
"""
assert isinstance(boxes_list, (list, tuple))
if not boxes_list:
return cls(np.empty(0))
assert all(isinstance(box, Bboxes) for box in boxes_list)
if len(boxes_list) == 1:
return boxes_list[0]
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
def __getitem__(self, index: int | np.ndarray | slice) -> Bboxes:
"""
Retrieve a specific bounding box or a set of bounding boxes using indexing.
Args:
index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.
Returns:
(Bboxes): A new Bboxes object containing the selected bounding boxes.
Notes:
When using boolean indexing, make sure to provide a boolean array with the same length as the number of
bounding boxes.
"""
if isinstance(index, int):
return Bboxes(self.bboxes[index].reshape(1, -1))
b = self.bboxes[index]
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
return Bboxes(b)
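A short worked example of conversion and scaling (hand-picked values):

    import numpy as np

    boxes = Bboxes(np.array([[50.0, 50.0, 20.0, 10.0]]), format="xywh")  # cx, cy, w, h
    boxes.convert("xyxy")  # -> [[40., 45., 60., 55.]]
    boxes.mul(2)  # same scale on all four coordinates -> [[80., 90., 120., 110.]]
    print(boxes.areas())  # (120 - 80) * (110 - 90) = [800.]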
class Instances:
"""
Container for bounding boxes, segments, and keypoints of detected objects in an image.
This class provides a unified interface for handling different types of object annotations including bounding
boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,
and format conversion.
Attributes:
_bboxes (Bboxes): Internal object for handling bounding box operations.
keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).
normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.
Methods:
convert_bbox: Convert bounding box format.
scale: Scale coordinates by given factors.
denormalize: Convert normalized coordinates to absolute coordinates.
normalize: Convert absolute coordinates to normalized coordinates.
add_padding: Add padding to coordinates.
flipud: Flip coordinates vertically.
fliplr: Flip coordinates horizontally.
clip: Clip coordinates to stay within image boundaries.
remove_zero_area_boxes: Remove boxes with zero area.
update: Update instance variables.
concatenate: Concatenate multiple Instances objects.
Examples:
Create instances with bounding boxes and segments
>>> instances = Instances(
... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),
... )
"""
def __init__(
self,
bboxes: np.ndarray,
segments: np.ndarray = None,
keypoints: np.ndarray = None,
bbox_format: str = "xywh",
normalized: bool = True,
) -> None:
"""
Initialize the Instances object with bounding boxes, segments, and keypoints.
Args:
bboxes (np.ndarray): Bounding boxes with shape (N, 4).
segments (np.ndarray, optional): Segmentation masks.
keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).
bbox_format (str): Format of bboxes.
normalized (bool): Whether the coordinates are normalized.
"""
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
self.keypoints = keypoints
self.normalized = normalized
self.segments = segments
def convert_bbox(self, format: str) -> None:
"""
Convert bounding box format.
Args:
format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
"""
self._bboxes.convert(format=format)
@property
def bbox_areas(self) -> np.ndarray:
"""Calculate the area of bounding boxes."""
return self._bboxes.areas()
def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False):
"""
Scale coordinates by given factors.
Args:
scale_w (float): Scale factor for width.
scale_h (float): Scale factor for height.
bbox_only (bool, optional): Whether to scale only bounding boxes.
"""
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
if bbox_only:
return
self.segments[..., 0] *= scale_w
self.segments[..., 1] *= scale_h
if self.keypoints is not None:
self.keypoints[..., 0] *= scale_w
self.keypoints[..., 1] *= scale_h
def denormalize(self, w: int, h: int) -> None:
"""
Convert normalized coordinates to absolute coordinates.
Args:
w (int): Image width.
h (int): Image height.
"""
if not self.normalized:
return
self._bboxes.mul(scale=(w, h, w, h))
self.segments[..., 0] *= w
self.segments[..., 1] *= h
if self.keypoints is not None:
self.keypoints[..., 0] *= w
self.keypoints[..., 1] *= h
self.normalized = False
def normalize(self, w: int, h: int) -> None:
"""
Convert absolute coordinates to normalized coordinates.
Args:
w (int): Image width.
h (int): Image height.
"""
if self.normalized:
return
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
self.segments[..., 0] /= w
self.segments[..., 1] /= h
if self.keypoints is not None:
self.keypoints[..., 0] /= w
self.keypoints[..., 1] /= h
self.normalized = True
def add_padding(self, padw: int, padh: int) -> None:
"""
Add padding to coordinates.
Args:
padw (int): Padding width.
padh (int): Padding height.
"""
assert not self.normalized, "you should add padding with absolute coordinates."
self._bboxes.add(offset=(padw, padh, padw, padh))
self.segments[..., 0] += padw
self.segments[..., 1] += padh
if self.keypoints is not None:
self.keypoints[..., 0] += padw
self.keypoints[..., 1] += padh
def __getitem__(self, index: int | np.ndarray | slice) -> Instances:
"""
Retrieve a specific instance or a set of instances using indexing.
Args:
index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.
Returns:
(Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.
Notes:
When using boolean indexing, make sure to provide a boolean array with the same length as the number of
instances.
"""
segments = self.segments[index] if len(self.segments) else self.segments
keypoints = self.keypoints[index] if self.keypoints is not None else None
bboxes = self.bboxes[index]
bbox_format = self._bboxes.format
return Instances(
bboxes=bboxes,
segments=segments,
keypoints=keypoints,
bbox_format=bbox_format,
normalized=self.normalized,
)
def flipud(self, h: int) -> None:
"""
Flip coordinates vertically.
Args:
h (int): Image height.
"""
if self._bboxes.format == "xyxy":
y1 = self.bboxes[:, 1].copy()
y2 = self.bboxes[:, 3].copy()
self.bboxes[:, 1] = h - y2
self.bboxes[:, 3] = h - y1
else:
self.bboxes[:, 1] = h - self.bboxes[:, 1]
self.segments[..., 1] = h - self.segments[..., 1]
if self.keypoints is not None:
self.keypoints[..., 1] = h - self.keypoints[..., 1]
def fliplr(self, w: int) -> None:
"""
Flip coordinates horizontally.
Args:
w (int): Image width.
"""
if self._bboxes.format == "xyxy":
x1 = self.bboxes[:, 0].copy()
x2 = self.bboxes[:, 2].copy()
self.bboxes[:, 0] = w - x2
self.bboxes[:, 2] = w - x1
else:
self.bboxes[:, 0] = w - self.bboxes[:, 0]
self.segments[..., 0] = w - self.segments[..., 0]
if self.keypoints is not None:
self.keypoints[..., 0] = w - self.keypoints[..., 0]
def clip(self, w: int, h: int) -> None:
"""
Clip coordinates to stay within image boundaries.
Args:
w (int): Image width.
h (int): Image height.
"""
ori_format = self._bboxes.format
self.convert_bbox(format="xyxy")
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
if ori_format != "xyxy":
self.convert_bbox(format=ori_format)
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
if self.keypoints is not None:
# Set out of bounds visibility to zero
self.keypoints[..., 2][
(self.keypoints[..., 0] < 0)
| (self.keypoints[..., 0] > w)
| (self.keypoints[..., 1] < 0)
| (self.keypoints[..., 1] > h)
] = 0.0
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
def remove_zero_area_boxes(self) -> np.ndarray:
"""
Remove zero-area boxes, e.g. boxes left with zero width or height after clipping.
Returns:
(np.ndarray): Boolean array indicating which boxes were kept.
"""
good = self.bbox_areas > 0
if not all(good):
self._bboxes = self._bboxes[good]
if len(self.segments):
self.segments = self.segments[good]
if self.keypoints is not None:
self.keypoints = self.keypoints[good]
return good
def update(self, bboxes: np.ndarray, segments: np.ndarray = None, keypoints: np.ndarray = None):
"""
Update instance variables.
Args:
bboxes (np.ndarray): New bounding boxes.
segments (np.ndarray, optional): New segments.
keypoints (np.ndarray, optional): New keypoints.
"""
self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
if segments is not None:
self.segments = segments
if keypoints is not None:
self.keypoints = keypoints
def __len__(self) -> int:
"""Return the number of instances."""
return len(self.bboxes)
@classmethod
def concatenate(cls, instances_list: list[Instances], axis=0) -> Instances:
"""
Concatenate a list of Instances objects into a single Instances object.
Args:
instances_list (list[Instances]): A list of Instances objects to concatenate.
axis (int, optional): The axis along which the arrays will be concatenated.
Returns:
(Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints
if present.
Notes:
The `Instances` objects in the list should have the same properties, such as the format of the bounding
boxes, whether keypoints are present, and if the coordinates are normalized.
"""
assert isinstance(instances_list, (list, tuple))
if not instances_list:
return cls(np.empty(0))
assert all(isinstance(instance, Instances) for instance in instances_list)
if len(instances_list) == 1:
return instances_list[0]
use_keypoint = instances_list[0].keypoints is not None
bbox_format = instances_list[0]._bboxes.format
normalized = instances_list[0].normalized
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
seg_len = [b.segments.shape[1] for b in instances_list]
if len(frozenset(seg_len)) > 1: # resample segments if lengths differ
max_len = max(seg_len)
cat_segments = np.concatenate(
[
resample_segments(list(b.segments), max_len)
if len(b.segments)
else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments
for b in instances_list
],
axis=axis,
)
else:
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
@property
def bboxes(self) -> np.ndarray:
"""Return bounding boxes."""
return self._bboxes.bboxes
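A minimal end-to-end sketch (synthetic values; segments are passed as the resampled (N, M, 2) array this class expects):

    import numpy as np

    inst = Instances(
        bboxes=np.array([[0.5, 0.5, 0.25, 0.25]], dtype=np.float32),  # normalized xywh
        segments=np.zeros((1, 10, 2), dtype=np.float32),
        normalized=True,
    )
    inst.denormalize(w=640, h=480)  # absolute coordinates: [[320., 240., 160., 120.]]
    inst.convert_bbox("xyxy")  # [[240., 180., 400., 300.]]
    inst.clip(w=640, h=480)
    kept = inst.remove_zero_area_boxes()  # boolean mask of surviving boxes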

ultralytics/utils/logger.py Normal file

@@ -0,0 +1,408 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import logging
import queue
import shutil
import sys
import threading
import time
from datetime import datetime
from pathlib import Path
from ultralytics.utils import MACOS, RANK
from ultralytics.utils.checks import check_requirements
# Initialize default log file
DEFAULT_LOG_PATH = Path("train.log")
if RANK in {-1, 0} and DEFAULT_LOG_PATH.exists():
DEFAULT_LOG_PATH.unlink(missing_ok=True)
class ConsoleLogger:
"""
Console output capture with API/file streaming and deduplication.
Captures stdout/stderr output and streams it to either an API endpoint or local file, with intelligent
deduplication to reduce noise from repetitive console output.
Attributes:
destination (str | Path): Target destination for streaming (URL or Path object).
is_api (bool): Whether destination is an API endpoint (True) or local file (False).
original_stdout: Reference to original sys.stdout for restoration.
original_stderr: Reference to original sys.stderr for restoration.
log_queue (queue.Queue): Thread-safe queue for buffering log messages.
active (bool): Whether console capture is currently active.
worker_thread (threading.Thread): Background thread for processing log queue.
last_line (str): Last processed line for deduplication.
last_time (float): Timestamp of last processed line.
last_progress_line (str): Last progress bar line for progress deduplication.
last_was_progress (bool): Whether the last line was a progress bar.
Examples:
Basic file logging:
>>> logger = ConsoleLogger("training.log")
>>> logger.start_capture()
>>> print("This will be logged")
>>> logger.stop_capture()
API streaming:
>>> logger = ConsoleLogger("https://api.example.com/logs")
>>> logger.start_capture()
>>> # All output streams to API
>>> logger.stop_capture()
"""
def __init__(self, destination):
"""
Initialize with API endpoint or local file path.
Args:
destination (str | Path): API endpoint URL (http/https) or local file path for streaming output.
"""
self.destination = destination
self.is_api = isinstance(destination, str) and destination.startswith(("http://", "https://"))
if not self.is_api:
self.destination = Path(destination)
# Console capture
self.original_stdout = sys.stdout
self.original_stderr = sys.stderr
self.log_queue = queue.Queue(maxsize=1000)
self.active = False
self.worker_thread = None
# State tracking
self.last_line = ""
self.last_time = 0.0
self.last_progress_line = "" # Track last progress line for deduplication
self.last_was_progress = False # Track if last line was a progress bar
def start_capture(self):
"""Start capturing console output and redirect stdout/stderr to custom capture objects."""
if self.active:
return
self.active = True
sys.stdout = self._ConsoleCapture(self.original_stdout, self._queue_log)
sys.stderr = self._ConsoleCapture(self.original_stderr, self._queue_log)
# Hook Ultralytics logger
try:
handler = self._LogHandler(self._queue_log)
logging.getLogger("ultralytics").addHandler(handler)
except Exception:
pass
self.worker_thread = threading.Thread(target=self._stream_worker, daemon=True)
self.worker_thread.start()
def stop_capture(self):
"""Stop capturing console output and restore original stdout/stderr."""
if not self.active:
return
self.active = False
sys.stdout = self.original_stdout
sys.stderr = self.original_stderr
self.log_queue.put(None)
def _queue_log(self, text):
"""Queue console text with deduplication and timestamp processing."""
if not self.active:
return
current_time = time.time()
# Handle carriage returns and process lines
if "\r" in text:
text = text.split("\r")[-1]
lines = text.split("\n")
if lines and lines[-1] == "":
lines.pop()
for line in lines:
line = line.rstrip()
# Skip lines with only thin progress bars (partial progress)
if "" in line: # Has thin lines but no thick lines
continue
# Deduplicate completed progress bars only if they match the previous progress line
if " ━━" in line:
progress_core = line.split(" ━━")[0].strip()
if progress_core == self.last_progress_line and self.last_was_progress:
continue
self.last_progress_line = progress_core
self.last_was_progress = True
else:
# Skip empty line after progress bar
if not line and self.last_was_progress:
self.last_was_progress = False
continue
self.last_was_progress = False
# General deduplication
if line == self.last_line and current_time - self.last_time < 0.1:
continue
self.last_line = line
self.last_time = current_time
# Add timestamp if needed
if not line.startswith("[20"):
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
line = f"[{timestamp}] {line}"
# Queue with overflow protection
if not self._safe_put(f"{line}\n"):
continue # Skip if queue handling fails
def _safe_put(self, item):
"""Safely put item in queue with overflow handling."""
try:
self.log_queue.put_nowait(item)
return True
except queue.Full:
try:
self.log_queue.get_nowait() # Drop oldest
self.log_queue.put_nowait(item)
return True
except queue.Empty:
return False
def _stream_worker(self):
"""Background worker for streaming logs to destination."""
while self.active:
try:
log_text = self.log_queue.get(timeout=1)
if log_text is None:
break
self._write_log(log_text)
except queue.Empty:
continue
def _write_log(self, text):
"""Write log to API endpoint or local file destination."""
try:
if self.is_api:
import requests # scoped as slow import
payload = {"timestamp": datetime.now().isoformat(), "message": text.strip()}
requests.post(str(self.destination), json=payload, timeout=5)
else:
self.destination.parent.mkdir(parents=True, exist_ok=True)
with self.destination.open("a", encoding="utf-8") as f:
f.write(text)
except Exception as e:
print(f"Platform logging error: {e}", file=self.original_stderr)
class _ConsoleCapture:
"""Lightweight stdout/stderr capture."""
__slots__ = ("original", "callback")
def __init__(self, original, callback):
self.original = original
self.callback = callback
def write(self, text):
self.original.write(text)
self.callback(text)
def flush(self):
self.original.flush()
class _LogHandler(logging.Handler):
"""Lightweight logging handler."""
__slots__ = ("callback",)
def __init__(self, callback):
super().__init__()
self.callback = callback
def emit(self, record):
self.callback(self.format(record) + "\n")
class SystemLogger:
"""
Log dynamic system metrics for training monitoring.
Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for
training performance monitoring and analysis.
Attributes:
pynvml: NVIDIA pynvml module instance if successfully imported, None otherwise.
nvidia_initialized (bool): Whether NVIDIA GPU monitoring is available and initialized.
net_start: Initial network I/O counters for calculating cumulative usage.
disk_start: Initial disk I/O counters for calculating cumulative usage.
Examples:
Basic usage:
>>> logger = SystemLogger()
>>> metrics = logger.get_metrics()
>>> print(f"CPU: {metrics['cpu']}%, RAM: {metrics['ram']}%")
>>> if metrics["gpus"]:
... gpu0 = metrics["gpus"]["0"]
... print(f"GPU0: {gpu0['usage']}% usage, {gpu0['temp']}°C")
Training loop integration:
>>> system_logger = SystemLogger()
>>> for epoch in range(epochs):
... # Training code here
... metrics = system_logger.get_metrics()
... # Log to database/file
"""
def __init__(self):
"""Initialize the system logger."""
import psutil # scoped as slow import
self.pynvml = None
self.nvidia_initialized = self._init_nvidia()
self.net_start = psutil.net_io_counters()
self.disk_start = psutil.disk_io_counters()
def _init_nvidia(self):
"""Initialize NVIDIA GPU monitoring with pynvml."""
try:
assert not MACOS
check_requirements("nvidia-ml-py>=12.0.0")
self.pynvml = __import__("pynvml")
self.pynvml.nvmlInit()
return True
except Exception:
return False
def get_metrics(self):
"""
Get current system metrics.
Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics,
network I/O statistics, and GPU metrics (if available). Example output:
```python
metrics = {
"cpu": 45.2,
"ram": 78.9,
"disk": {"read_mb": 156.7, "write_mb": 89.3, "used_gb": 256.8},
"network": {"recv_mb": 157.2, "sent_mb": 89.1},
"gpus": {
0: {"usage": 95.6, "memory": 85.4, "temp": 72, "power": 285},
1: {"usage": 94.1, "memory": 82.7, "temp": 70, "power": 278},
},
}
```
- cpu (float): CPU usage percentage (0-100%)
- ram (float): RAM usage percentage (0-100%)
- disk (dict):
- read_mb (float): Cumulative disk read in MB since initialization
- write_mb (float): Cumulative disk write in MB since initialization
- used_gb (float): Total disk space used in GB
- network (dict):
- recv_mb (float): Cumulative network received in MB since initialization
- sent_mb (float): Cumulative network sent in MB since initialization
- gpus (dict): GPU metrics keyed by device index string (e.g., "0", "1") containing:
- usage (int): GPU utilization percentage (0-100%)
- memory (float): CUDA memory usage percentage (0-100%)
- temp (int): GPU temperature in degrees Celsius
- power (int): GPU power consumption in watts
Returns:
metrics (dict): System metrics containing 'cpu', 'ram', 'disk', 'network', 'gpus' with respective usage data.
"""
import psutil # scoped as slow import
net = psutil.net_io_counters()
disk = psutil.disk_io_counters()
memory = psutil.virtual_memory()
disk_usage = shutil.disk_usage("/")
metrics = {
"cpu": round(psutil.cpu_percent(), 3),
"ram": round(memory.percent, 3),
"disk": {
"read_mb": round((disk.read_bytes - self.disk_start.read_bytes) / (1 << 20), 3),
"write_mb": round((disk.write_bytes - self.disk_start.write_bytes) / (1 << 20), 3),
"used_gb": round(disk_usage.used / (1 << 30), 3),
},
"network": {
"recv_mb": round((net.bytes_recv - self.net_start.bytes_recv) / (1 << 20), 3),
"sent_mb": round((net.bytes_sent - self.net_start.bytes_sent) / (1 << 20), 3),
},
"gpus": {},
}
# Add GPU metrics (NVIDIA only)
if self.nvidia_initialized:
metrics["gpus"].update(self._get_nvidia_metrics())
return metrics
def _get_nvidia_metrics(self):
"""Get NVIDIA GPU metrics including utilization, memory, temperature, and power."""
gpus = {}
if not self.nvidia_initialized or not self.pynvml:
return gpus
try:
device_count = self.pynvml.nvmlDeviceGetCount()
for i in range(device_count):
handle = self.pynvml.nvmlDeviceGetHandleByIndex(i)
util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)
memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
temp = self.pynvml.nvmlDeviceGetTemperature(handle, self.pynvml.NVML_TEMPERATURE_GPU)
power = self.pynvml.nvmlDeviceGetPowerUsage(handle) // 1000
gpus[str(i)] = {
"usage": round(util.gpu, 3),
"memory": round((memory.used / memory.total) * 100, 3),
"temp": temp,
"power": power,
}
except Exception:
pass
return gpus
if __name__ == "__main__":
print("SystemLogger Real-time Metrics Monitor")
print("Press Ctrl+C to stop\n")
logger = SystemLogger()
try:
while True:
metrics = logger.get_metrics()
# Clear screen (works on most terminals)
print("\033[H\033[J", end="")
# Display system metrics
print(f"CPU: {metrics['cpu']:5.1f}%")
print(f"RAM: {metrics['ram']:5.1f}%")
print(f"Disk Read: {metrics['disk']['read_mb']:8.1f} MB")
print(f"Disk Write: {metrics['disk']['write_mb']:7.1f} MB")
print(f"Disk Used: {metrics['disk']['used_gb']:8.1f} GB")
print(f"Net Recv: {metrics['network']['recv_mb']:9.1f} MB")
print(f"Net Sent: {metrics['network']['sent_mb']:9.1f} MB")
# Display GPU metrics if available
if metrics["gpus"]:
print("\nGPU Metrics:")
for gpu_id, gpu_data in metrics["gpus"].items():
print(
f" GPU {gpu_id}: {gpu_data['usage']:3}% | "
f"Mem: {gpu_data['memory']:5.1f}% | "
f"Temp: {gpu_data['temp']:2}°C | "
f"Power: {gpu_data['power']:3}W"
)
else:
print("\nGPU: No NVIDIA GPUs detected")
time.sleep(1)
except KeyboardInterrupt:
print("\n\nStopped monitoring.")

ultralytics/utils/loss.py Normal file

@@ -0,0 +1,857 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.metrics import OKS_SIGMA
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
from ultralytics.utils.torch_utils import autocast
from .metrics import bbox_iou, probiou
from .tal import bbox2dist
class VarifocalLoss(nn.Module):
"""
Varifocal loss by Zhang et al.
Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
hard-to-classify examples and balancing positive/negative samples.
Attributes:
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
alpha (float): The balancing factor used to address class imbalance.
References:
https://arxiv.org/abs/2008.13367
"""
def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
"""Initialize the VarifocalLoss class with focusing and balancing parameters."""
super().__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
"""Compute varifocal loss between predictions and ground truth."""
weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
with autocast(enabled=False):
loss = (
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
.mean(1)
.sum()
)
return loss
class FocalLoss(nn.Module):
"""
Focal loss that applies a modulating factor to binary cross-entropy with logits, e.g. criterion = FocalLoss(gamma=1.5, alpha=0.25).
Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
on hard negatives during training.
Attributes:
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
alpha (torch.Tensor): The balancing factor used to address class imbalance.
"""
def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
"""Initialize FocalLoss class with focusing and balancing parameters."""
super().__init__()
self.gamma = gamma
self.alpha = torch.tensor(alpha)
def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
"""Calculate focal loss with modulating factors for class imbalance."""
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
# p_t = torch.exp(-loss)
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
pred_prob = pred.sigmoid() # prob from logits
p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
modulating_factor = (1.0 - p_t) ** self.gamma
loss *= modulating_factor
if (self.alpha > 0).any():
self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
loss *= alpha_factor
return loss.mean(1).sum()
class DFLoss(nn.Module):
"""Criterion class for computing Distribution Focal Loss (DFL)."""
def __init__(self, reg_max: int = 16) -> None:
"""Initialize the DFL module with regularization maximum."""
super().__init__()
self.reg_max = reg_max
def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
target = target.clamp_(0, self.reg_max - 1 - 0.01)
tl = target.long() # target left
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
return (
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
).mean(-1, keepdim=True)
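For intuition, DFL spreads each continuous target over its two neighboring integer bins; a quick numeric sketch (hand-picked values):

    import torch

    dfl = DFLoss(reg_max=16)
    pred_dist = torch.randn(4, 16)  # one 16-bin distribution per box side (l, t, r, b)
    target = torch.tensor([[3.7, 0.2, 7.9, 5.0]])  # 3.7 -> bins 3 and 4, weighted 0.3 and 0.7
    print(dfl(pred_dist, target).shape)  # (1, 1) weighted CE over neighboring bins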
class BboxLoss(nn.Module):
"""Criterion class for computing training losses for bounding boxes."""
def __init__(self, reg_max: int = 16):
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
super().__init__()
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
def forward(
self,
pred_dist: torch.Tensor,
pred_bboxes: torch.Tensor,
anchor_points: torch.Tensor,
target_bboxes: torch.Tensor,
target_scores: torch.Tensor,
target_scores_sum: torch.Tensor,
fg_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute IoU and DFL losses for bounding boxes."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.dfl_loss:
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
class RotatedBboxLoss(BboxLoss):
"""Criterion class for computing training losses for rotated bounding boxes."""
def __init__(self, reg_max: int):
"""Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
super().__init__(reg_max)
def forward(
self,
pred_dist: torch.Tensor,
pred_bboxes: torch.Tensor,
anchor_points: torch.Tensor,
target_bboxes: torch.Tensor,
target_scores: torch.Tensor,
target_scores_sum: torch.Tensor,
fg_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute IoU and DFL losses for rotated bounding boxes."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.dfl_loss:
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
class KeypointLoss(nn.Module):
"""Criterion class for computing keypoint losses."""
def __init__(self, sigmas: torch.Tensor) -> None:
"""Initialize the KeypointLoss class with keypoint sigmas."""
super().__init__()
self.sigmas = sigmas
def forward(
self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
"""Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
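# Illustrative sketch: an OKS-style Gaussian penalty over squared keypoint
# distances, normalized by object area and per-keypoint sigmas (uniform
# sigmas assumed here for simplicity).
# >>> kl = KeypointLoss(sigmas=torch.ones(17) / 17)
# >>> gt = torch.rand(2, 17, 3) * 640
# >>> pred = gt + torch.randn(2, 17, 3)  # near-perfect predictions
# >>> loss = kl(pred, gt, torch.ones(2, 17), torch.ones(2, 1) * 1e4)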
class v8DetectionLoss:
"""Criterion class for computing training losses for YOLOv8 object detection."""
def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
"""Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
device = next(model.parameters()).device # get model device
h = model.args # hyperparameters
m = model.model[-1] # Detect() module
self.bce = nn.BCEWithLogitsLoss(reduction="none")
self.hyp = h
self.stride = m.stride # model strides
self.nc = m.nc # number of classes
self.no = m.nc + m.reg_max * 4
self.reg_max = m.reg_max
self.device = device
self.use_dfl = m.reg_max > 1
self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = BboxLoss(m.reg_max).to(device)
self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
"""Preprocess targets by converting to tensor format and scaling coordinates."""
nl, ne = targets.shape
if nl == 0:
out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
for j in range(batch_size):
matches = i == j
if n := matches.sum():
out[j, :n] = targets[matches, 1:]
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
return out
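# Illustrative note: for a batch of 2 images with three labels
# [[0, c0, box0], [0, c1, box1], [1, c2, box2]], the output is padded to the
# per-image maximum, giving shape (2, 2, 5) with the unused slot left at zero;
# boxes are simultaneously scaled to pixels and converted xywh -> xyxy.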
def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
"""Decode predicted object bounding box coordinates from anchor points and distribution."""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
# pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
return dist2bbox(pred_dist, anchor_points, xywh=False)
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats = preds[1] if isinstance(preds, tuple) else preds
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
batch_size = pred_scores.shape[0]
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
# dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
# dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
# pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
loss[0], loss[2] = self.bbox_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
class v8SegmentationLoss(v8DetectionLoss):
"""Criterion class for computing training losses for YOLOv8 segmentation."""
def __init__(self, model): # model must be de-paralleled
"""Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
super().__init__(model)
self.overlap = model.args.overlap_mask
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate and return the combined loss for detection and segmentation."""
loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_masks = pred_masks.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
try:
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
except RuntimeError as e:
raise TypeError(
"ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
"This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
if fg_mask.sum():
# Bbox loss
loss[0], loss[3] = self.bbox_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
# Masks loss
masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
@staticmethod
def single_mask_loss(
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
"""
Compute the instance segmentation loss for a single image.
Args:
gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
proto (torch.Tensor): Prototype masks of shape (32, H, W).
xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, scaled to mask space, of shape (N, 4).
area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
Returns:
(torch.Tensor): The calculated mask loss for a single image.
Notes:
The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
predicted masks from the prototype masks and predicted mask coefficients.
"""
pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
def calculate_segmentation_loss(
self,
fg_mask: torch.Tensor,
masks: torch.Tensor,
target_gt_idx: torch.Tensor,
target_bboxes: torch.Tensor,
batch_idx: torch.Tensor,
proto: torch.Tensor,
pred_masks: torch.Tensor,
imgsz: torch.Tensor,
overlap: bool,
) -> torch.Tensor:
"""
Calculate the loss for instance segmentation.
Args:
fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
masks (torch.Tensor): Ground truth masks of shape (N_labels_in_batch, H, W) if `overlap` is False, otherwise (BS, H, W) with instance indices encoded in pixel values.
target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
overlap (bool): Whether the masks in `masks` tensor overlap.
Returns:
(torch.Tensor): The calculated loss for instance segmentation.
Notes:
The batch loss can be computed for improved speed at higher memory usage.
For example, pred_mask can be computed as follows:
pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
"""
_, _, mask_h, mask_w = proto.shape
loss = 0
# Normalize to 0-1
target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
# Areas of target bboxes
marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
# Normalize to mask size
mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
if fg_mask_i.any():
mask_idx = target_gt_idx_i[fg_mask_i]
if overlap:
gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
gt_mask = gt_mask.float()
else:
gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
loss += self.single_mask_loss(
gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
return loss / fg_mask.sum()
class v8PoseLoss(v8DetectionLoss):
"""Criterion class for computing training losses for YOLOv8 pose estimation."""
def __init__(self, model): # model must be de-paralleled
"""Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
super().__init__(model)
self.kpt_shape = model.model[-1].kpt_shape
self.bce_pose = nn.BCEWithLogitsLoss()
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0] # number of keypoints
sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
self.keypoint_loss = KeypointLoss(sigmas=sigmas)
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the total loss and detach it for pose estimation."""
loss = torch.zeros(5, device=self.device) # box, cls, dfl, kpt_location, kpt_visibility
feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# B, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
batch_size = pred_scores.shape[0]
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
target_bboxes /= stride_tensor
loss[0], loss[4] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
keypoints = batch["keypoints"].to(self.device).float().clone()
keypoints[..., 0] *= imgsz[1]
keypoints[..., 1] *= imgsz[0]
loss[1], loss[2] = self.calculate_keypoints_loss(
fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
)
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.pose # pose gain
loss[2] *= self.hyp.kobj # kobj gain
loss[3] *= self.hyp.cls # cls gain
loss[4] *= self.hyp.dfl # dfl gain
return loss * batch_size, loss.detach() # loss(box, pose, kobj, cls, dfl)
@staticmethod
def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
"""Decode predicted keypoints to image coordinates."""
y = pred_kpts.clone()
y[..., :2] *= 2.0
y[..., 0] += anchor_points[:, [0]] - 0.5
y[..., 1] += anchor_points[:, [1]] - 0.5
return y
def calculate_keypoints_loss(
self,
masks: torch.Tensor,
target_gt_idx: torch.Tensor,
keypoints: torch.Tensor,
batch_idx: torch.Tensor,
stride_tensor: torch.Tensor,
target_bboxes: torch.Tensor,
pred_kpts: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Calculate the keypoints loss for the model.
This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
a binary classification loss that classifies whether a keypoint is present or not.
Args:
masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
Returns:
kpts_loss (torch.Tensor): The keypoints loss.
kpts_obj_loss (torch.Tensor): The keypoints object loss.
"""
batch_idx = batch_idx.flatten()
batch_size = len(masks)
# Find the maximum number of keypoints in a single image
max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
# Create a tensor to hold batched keypoints
batched_keypoints = torch.zeros(
(batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
)
# TODO: any idea how to vectorize this?
# Fill batched_keypoints with keypoints based on batch_idx
for i in range(batch_size):
keypoints_i = keypoints[batch_idx == i]
batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
# Expand dimensions of target_gt_idx to match the shape of batched_keypoints
target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
# Use target_gt_idx_expanded to select keypoints from batched_keypoints
selected_keypoints = batched_keypoints.gather(
1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
)
# Divide coordinates by stride
selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
kpts_loss = 0
kpts_obj_loss = 0
if masks.any():
gt_kpt = selected_keypoints[masks]
area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
pred_kpt = pred_kpts[masks]
kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
if pred_kpt.shape[-1] == 3:
kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
return kpts_loss, kpts_obj_loss
class v8ClassificationLoss:
"""Criterion class for computing training losses for classification."""
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute the classification loss between predictions and true labels."""
preds = preds[1] if isinstance(preds, (list, tuple)) else preds
loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
return loss, loss.detach()
class v8OBBLoss(v8DetectionLoss):
"""Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
def __init__(self, model):
"""Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
super().__init__(model)
self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
"""Preprocess targets for oriented bounding box detection."""
if targets.shape[0] == 0:
out = torch.zeros(batch_size, 0, 6, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
for j in range(batch_size):
matches = i == j
if n := matches.sum():
bboxes = targets[matches, 2:]
bboxes[..., :4].mul_(scale_tensor)
out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
return out
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate and return the loss for oriented bounding box detection."""
loss = torch.zeros(3, device=self.device) # box, cls, dfl
feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
batch_size = pred_angle.shape[0] # batch size
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)
# b, grids, ..
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
pred_angle = pred_angle.permute(0, 2, 1).contiguous()
dtype = pred_scores.dtype
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# targets
try:
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
except RuntimeError as e:
raise TypeError(
"ERROR ❌ OBB dataset incorrectly formatted or not a OBB dataset.\n"
"This error can occur when incorrectly training a 'OBB' model on a 'detect' dataset, "
"i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
"correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
"as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
) from e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xywhr, (b, h*w, 5)
bboxes_for_assigner = pred_bboxes.clone().detach()
# Only the first four elements need to be scaled
bboxes_for_assigner[..., :4] *= stride_tensor
_, target_bboxes, target_scores, fg_mask, _ = self.assigner(
pred_scores.detach().sigmoid(),
bboxes_for_assigner.type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
# Bbox loss
if fg_mask.sum():
target_bboxes[..., :4] /= stride_tensor
loss[0], loss[2] = self.bbox_loss(
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
else:
loss[0] += (pred_angle * 0).sum()
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.cls # cls gain
loss[2] *= self.hyp.dfl # dfl gain
return loss * batch_size, loss.detach() # loss(box, cls, dfl)
def bbox_decode(
self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
) -> torch.Tensor:
"""
Decode predicted object bounding box coordinates from anchor points and distribution.
Args:
anchor_points (torch.Tensor): Anchor points, (h*w, 2).
pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
Returns:
(torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
"""
if self.use_dfl:
b, a, c = pred_dist.shape # batch, anchors, channels
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
class E2EDetectLoss:
"""Criterion class for computing training losses for end-to-end detection."""
def __init__(self, model):
"""Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
self.one2many = v8DetectionLoss(model, tal_topk=10)
self.one2one = v8DetectionLoss(model, tal_topk=1)
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
preds = preds[1] if isinstance(preds, tuple) else preds
one2many = preds["one2many"]
loss_one2many = self.one2many(one2many, batch)
one2one = preds["one2one"]
loss_one2one = self.one2one(one2one, batch)
return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
class TVPDetectLoss:
"""Criterion class for computing training losses for text-visual prompt detection."""
def __init__(self, model):
"""Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
self.vp_criterion = v8DetectionLoss(model)
# NOTE: store the original values since they are modified in __call__
self.ori_nc = self.vp_criterion.nc
self.ori_no = self.vp_criterion.no
self.ori_reg_max = self.vp_criterion.reg_max
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the loss for text-visual prompt detection."""
feats = preds[1] if isinstance(preds, tuple) else preds
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
return loss, loss.detach()
vp_feats = self._get_vp_features(feats)
vp_loss = self.vp_criterion(vp_feats, batch)
cls_loss = vp_loss[0][1] # cls component of (box, cls, dfl)
return cls_loss, vp_loss[1]
def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
"""Extract visual-prompt features from the model output."""
vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
self.vp_criterion.nc = vnc
self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
self.vp_criterion.assigner.num_classes = vnc
return [
torch.cat((box, cls_vp), dim=1)
for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
]
class TVPSegmentLoss(TVPDetectLoss):
"""Criterion class for computing training losses for text-visual prompt segmentation."""
def __init__(self, model):
"""Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
super().__init__(model)
self.vp_criterion = v8SegmentationLoss(model)
def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the loss for text-visual prompt segmentation."""
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
return loss, loss.detach()
vp_feats = self._get_vp_features(feats)
vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
cls_loss = vp_loss[0][2]
return cls_loss, vp_loss[1]

1592 ultralytics/utils/metrics.py Normal file
File diff suppressed because it is too large

340 ultralytics/utils/nms.py Normal file
@@ -0,0 +1,340 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import sys
import time
import torch
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou, box_iou
from ultralytics.utils.ops import xywh2xyxy
def non_max_suppression(
prediction,
conf_thres: float = 0.25,
iou_thres: float = 0.45,
classes=None,
agnostic: bool = False,
multi_label: bool = False,
labels=(),
max_det: int = 300,
nc: int = 0, # number of classes (optional)
max_time_img: float = 0.05,
max_nms: int = 30000,
max_wh: int = 7680,
rotated: bool = False,
end2end: bool = False,
return_idxs: bool = False,
):
"""
Perform non-maximum suppression (NMS) on prediction results.
Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
detection formats including standard boxes, rotated boxes, and masks.
Args:
prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
containing boxes, classes, and optional masks.
conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
classes (list[int], optional): List of class indices to consider. If None, all classes are considered.
agnostic (bool): Whether to perform class-agnostic NMS.
multi_label (bool): Whether each box can have multiple labels.
labels (list[list[Union[int, float, torch.Tensor]]]): A priori labels for each image.
max_det (int): Maximum number of detections to keep per image.
nc (int): Number of classes. Indices after this are considered masks.
max_time_img (float): Maximum time in seconds for processing one image.
max_nms (int): Maximum number of boxes for NMS.
max_wh (int): Maximum box width and height in pixels.
rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
end2end (bool): Whether the model is end-to-end and doesn't require NMS.
return_idxs (bool): Whether to return the indices of kept detections.
Returns:
output (list[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
keepi (list[torch.Tensor]): Indices of kept detections if return_idxs=True.
"""
# Checks
assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
if isinstance(prediction, (list, tuple)): # YOLOv8 model in validation mode, output = (inference_out, loss_out)
prediction = prediction[0] # select only inference output
if classes is not None:
classes = torch.tensor(classes, device=prediction.device)
if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)
output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
if classes is not None:
output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
return output
bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300)
nc = nc or (prediction.shape[1] - 4) # number of classes
extra = prediction.shape[1] - nc - 4 # number of extra info
mi = 4 + nc # mask start index
xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
xinds = torch.arange(prediction.shape[-1], device=prediction.device).expand(bs, -1)[..., None] # to track idxs
# Settings
# min_wh = 2 # (pixels) minimum box width and height
time_limit = 2.0 + max_time_img * bs # seconds to quit after
multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
if not rotated:
prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
t = time.time()
output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
keepi = [torch.zeros((0, 1), device=prediction.device)] * bs # to store the kept idxs
for xi, (x, xk) in enumerate(zip(prediction, xinds)): # image index, (preds, preds indices)
# Apply constraints
# x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
filt = xc[xi] # confidence
x = x[filt]
if return_idxs:
xk = xk[filt]
# Cat apriori labels if autolabelling
if labels and len(labels[xi]) and not rotated:
lb = labels[xi]
v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
x = torch.cat((x, v), 0)
# If none remain process next image
if not x.shape[0]:
continue
# Detections matrix nx6 (xyxy, conf, cls)
box, cls, mask = x.split((4, nc, extra), 1)
if multi_label:
i, j = torch.where(cls > conf_thres)
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
if return_idxs:
xk = xk[i]
else: # best class only
conf, j = cls.max(1, keepdim=True)
filt = conf.view(-1) > conf_thres
x = torch.cat((box, conf, j.float(), mask), 1)[filt]
if return_idxs:
xk = xk[filt]
# Filter by class
if classes is not None:
filt = (x[:, 5:6] == classes).any(1)
x = x[filt]
if return_idxs:
xk = xk[filt]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
if n > max_nms: # excess boxes
filt = x[:, 4].argsort(descending=True)[:max_nms] # sort by confidence and remove excess boxes
x = x[filt]
if return_idxs:
xk = xk[filt]
c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
scores = x[:, 4] # scores
if rotated:
boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr
i = TorchNMS.fast_nms(boxes, scores, iou_thres, iou_func=batch_probiou)
else:
boxes = x[:, :4] + c # boxes (offset by class)
# Speed strategy: torchvision for val or already loaded (faster), TorchNMS for predict (lower latency)
if "torchvision" in sys.modules:
import torchvision # scope as slow import
i = torchvision.ops.nms(boxes, scores, iou_thres)
else:
i = TorchNMS.nms(boxes, scores, iou_thres)
i = i[:max_det] # limit detections
output[xi] = x[i]
if return_idxs:
keepi[xi] = xk[i].view(-1)
if (time.time() - t) > time_limit:
LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
break # time limit exceeded
return (output, keepi) if return_idxs else output
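# Illustrative sketch: a raw head output of shape (batch, 4 + nc, anchors) is
# reduced to one (n, 6) tensor of (x1, y1, x2, y2, conf, cls) rows per image.
# Shapes below are hypothetical, e.g. an 80-class model with 8400 anchors.
# >>> pred = torch.randn(1, 84, 8400)
# >>> out = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# >>> len(out), out[0].shape[-1]
# (1, 6)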
class TorchNMS:
"""
Ultralytics custom NMS implementation optimized for YOLO.
This class provides static methods for performing non-maximum suppression (NMS) operations on bounding boxes,
including both standard NMS and batched NMS for multi-class scenarios.
Methods:
nms: Optimized NMS with early termination that matches torchvision behavior exactly.
batched_nms: Batched NMS for class-aware suppression.
Examples:
Perform standard NMS on boxes and scores
>>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
>>> scores = torch.tensor([0.9, 0.8])
>>> keep = TorchNMS.nms(boxes, scores, 0.5)
"""
@staticmethod
def fast_nms(
boxes: torch.Tensor,
scores: torch.Tensor,
iou_threshold: float,
use_triu: bool = True,
iou_func=box_iou,
exit_early: bool = True,
) -> torch.Tensor:
"""
Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.
Args:
boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
scores (torch.Tensor): Confidence scores with shape (N,).
iou_threshold (float): IoU threshold for suppression.
use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.
iou_func (callable): Function to compute IoU between boxes.
exit_early (bool): Whether to exit early if there are no boxes.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
Examples:
Apply Fast-NMS to a set of boxes
>>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
>>> scores = torch.tensor([0.9, 0.8])
>>> keep = TorchNMS.fast_nms(boxes, scores, 0.5)
"""
if boxes.numel() == 0 and exit_early:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
sorted_idx = torch.argsort(scores, descending=True)
boxes = boxes[sorted_idx]
ious = iou_func(boxes, boxes)
if use_triu:
ious = ious.triu_(diagonal=1)
# NOTE: computed without an if-else on len(boxes) so the op remains exportable
pick = torch.nonzero((ious >= iou_threshold).sum(0) <= 0).squeeze_(-1)
else:
n = boxes.shape[0]
row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
upper_mask = row_idx < col_idx
ious = ious * upper_mask
# Zeroing these scores ensures the additional indices would not affect the final results
scores[~((ious >= iou_threshold).sum(0) <= 0)] = 0
# NOTE: return indices with fixed length to avoid TFLite reshape error
pick = torch.topk(scores, scores.shape[0]).indices
return sorted_idx[pick]
@staticmethod
def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
"""
Optimized NMS with early termination that matches torchvision behavior exactly.
Args:
boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
scores (torch.Tensor): Confidence scores with shape (N,).
iou_threshold (float): IoU threshold for suppression.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
Examples:
Apply NMS to a set of boxes
>>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
>>> scores = torch.tensor([0.9, 0.8])
>>> keep = TorchNMS.nms(boxes, scores, 0.5)
"""
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
# Pre-allocate and extract coordinates once
x1, y1, x2, y2 = boxes.unbind(1)
areas = (x2 - x1) * (y2 - y1)
# Sort by scores descending
order = scores.argsort(0, descending=True)
# Pre-allocate keep list with maximum possible size
keep = torch.zeros(order.numel(), dtype=torch.int64, device=boxes.device)
keep_idx = 0
while order.numel() > 0:
i = order[0]
keep[keep_idx] = i
keep_idx += 1
if order.numel() == 1:
break
# Vectorized IoU calculation for remaining boxes
rest = order[1:]
xx1 = torch.maximum(x1[i], x1[rest])
yy1 = torch.maximum(y1[i], y1[rest])
xx2 = torch.minimum(x2[i], x2[rest])
yy2 = torch.minimum(y2[i], y2[rest])
# Fast intersection and IoU
w = (xx2 - xx1).clamp_(min=0)
h = (yy2 - yy1).clamp_(min=0)
inter = w * h
# Early exit: skip IoU calculation if no intersection
if inter.sum() == 0:
# No overlaps with current box, keep all remaining boxes
order = rest
continue
iou = inter / (areas[i] + areas[rest] - inter)
# Keep boxes with IoU <= threshold
order = rest[iou <= iou_threshold]
return keep[:keep_idx]
@staticmethod
def batched_nms(
boxes: torch.Tensor,
scores: torch.Tensor,
idxs: torch.Tensor,
iou_threshold: float,
use_fast_nms: bool = False,
) -> torch.Tensor:
"""
Batched NMS for class-aware suppression.
Args:
boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
scores (torch.Tensor): Confidence scores with shape (N,).
idxs (torch.Tensor): Class indices with shape (N,).
iou_threshold (float): IoU threshold for suppression.
use_fast_nms (bool): Whether to use the Fast-NMS implementation.
Returns:
(torch.Tensor): Indices of boxes to keep after NMS.
Examples:
Apply batched NMS across multiple classes
>>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
>>> scores = torch.tensor([0.9, 0.8])
>>> idxs = torch.tensor([0, 1])
>>> keep = TorchNMS.batched_nms(boxes, scores, idxs, 0.5)
"""
if boxes.numel() == 0:
return torch.empty((0,), dtype=torch.int64, device=boxes.device)
# Strategy: offset boxes by class index to prevent cross-class suppression
max_coordinate = boxes.max()
offsets = idxs.to(boxes) * (max_coordinate + 1)
boxes_for_nms = boxes + offsets[:, None]
return (
TorchNMS.fast_nms(boxes_for_nms, scores, iou_threshold)
if use_fast_nms
else TorchNMS.nms(boxes_for_nms, scores, iou_threshold)
)
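# Illustrative sketch: the per-class offset in batched_nms keeps identical
# boxes from different classes from suppressing each other.
# >>> b = torch.tensor([[0.0, 0.0, 10.0, 10.0], [0.0, 0.0, 10.0, 10.0]])
# >>> s = torch.tensor([0.9, 0.8])
# >>> TorchNMS.nms(b, s, 0.5).numel()  # same class: one box survives
# 1
# >>> TorchNMS.batched_nms(b, s, torch.tensor([0, 1]), 0.5).numel()  # two classes: both survive
# 2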

722 ultralytics/utils/ops.py Normal file
@@ -0,0 +1,722 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import contextlib
import math
import re
import time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.utils import NOT_MACOS14
class Profile(contextlib.ContextDecorator):
"""
Ultralytics Profile class for timing code execution.
Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
measurements with CUDA synchronization support for GPU operations.
Attributes:
t (float): Accumulated time in seconds.
device (torch.device): Device used for model inference.
cuda (bool): Whether CUDA is being used for timing synchronization.
Examples:
Use as a context manager to time code execution
>>> with Profile(device=device) as dt:
... pass # slow operation here
>>> print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
Use as a decorator to time function execution
>>> @Profile()
... def slow_function():
... time.sleep(0.1)
"""
def __init__(self, t: float = 0.0, device: torch.device | None = None):
"""
Initialize the Profile class.
Args:
t (float): Initial accumulated time in seconds.
device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
"""
self.t = t
self.device = device
self.cuda = bool(device and str(device).startswith("cuda"))
def __enter__(self):
"""Start timing."""
self.start = self.time()
return self
def __exit__(self, type, value, traceback): # noqa
"""Stop timing."""
self.dt = self.time() - self.start # delta-time
self.t += self.dt # accumulate dt
def __str__(self):
"""Return a human-readable string representing the accumulated elapsed time."""
return f"Elapsed time is {self.t} s"
def time(self):
"""Get current time with CUDA synchronization if applicable."""
if self.cuda:
torch.cuda.synchronize(self.device)
return time.perf_counter()
def segment2box(segment, width: int = 640, height: int = 640):
"""
Convert segment coordinates to bounding box coordinates.
Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
Applies inside-image constraint and clips coordinates when necessary.
Args:
segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
width (int): Width of the image in pixels.
height (int): Height of the image in pixels.
Returns:
(np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
"""
x, y = segment.T # segment xy
# Clip coordinates if 3 out of 4 sides are outside the image
if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
x = x.clip(0, width)
y = y.clip(0, height)
inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
x = x[inside]
y = y[inside]
return (
np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
if any(x)
else np.zeros(4, dtype=segment.dtype)
) # xyxy
def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
"""
Rescale bounding boxes from one image shape to another.
Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
Supports both xyxy and xywh box formats.
Args:
img1_shape (tuple): Shape of the source image (height, width).
boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
img0_shape (tuple): Shape of the target image (height, width).
ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
xywh (bool): Whether box format is xywh (True) or xyxy (False).
Returns:
(torch.Tensor): Rescaled bounding boxes in the same format as input.
"""
if ratio_pad is None: # calculate from img0_shape
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
else:
gain = ratio_pad[0][0]
pad_x, pad_y = ratio_pad[1]
if padding:
boxes[..., 0] -= pad_x # x padding
boxes[..., 1] -= pad_y # y padding
if not xywh:
boxes[..., 2] -= pad_x # x padding
boxes[..., 3] -= pad_y # y padding
boxes[..., :4] /= gain
return boxes if xywh else clip_boxes(boxes, img0_shape)
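# Illustrative sketch: boxes from a 640x640 letterboxed image mapped back to a
# 1280x960 original (gain 0.5, 80 px of padding on top and bottom).
# >>> b = torch.tensor([[0.0, 80.0, 640.0, 560.0]])  # full content area
# >>> scale_boxes((640, 640), b.clone(), (960, 1280))
# tensor([[   0.,    0., 1280.,  960.]])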
def make_divisible(x: int, divisor):
"""
Return the nearest number that is divisible by the given divisor.
Args:
x (int): The number to make divisible.
divisor (int | torch.Tensor): The divisor.
Returns:
(int): The nearest number divisible by the divisor.
"""
if isinstance(divisor, torch.Tensor):
divisor = int(divisor.max()) # to int
return math.ceil(x / divisor) * divisor
def clip_boxes(boxes, shape):
"""
Clip bounding boxes to image boundaries.
Args:
boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.
shape (tuple): Image shape as HWC or HW (supports both).
Returns:
(torch.Tensor | np.ndarray): Clipped bounding boxes.
"""
h, w = shape[:2] # supports both HWC or HW shapes
if isinstance(boxes, torch.Tensor): # faster individually
if NOT_MACOS14:
boxes[..., 0].clamp_(0, w) # x1
boxes[..., 1].clamp_(0, h) # y1
boxes[..., 2].clamp_(0, w) # x2
boxes[..., 3].clamp_(0, h) # y2
else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
boxes[..., 0] = boxes[..., 0].clamp(0, w)
boxes[..., 1] = boxes[..., 1].clamp(0, h)
boxes[..., 2] = boxes[..., 2].clamp(0, w)
boxes[..., 3] = boxes[..., 3].clamp(0, h)
else: # np.array (faster grouped)
boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, w) # x1, x2
boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, h) # y1, y2
return boxes
def clip_coords(coords, shape):
"""
Clip line coordinates to image boundaries.
Args:
coords (torch.Tensor | np.ndarray): Line coordinates to clip.
shape (tuple): Image shape as HWC or HW (supports both).
Returns:
(torch.Tensor | np.ndarray): Clipped coordinates.
"""
h, w = shape[:2] # supports both HWC or HW shapes
if isinstance(coords, torch.Tensor):
if NOT_MACOS14:
coords[..., 0].clamp_(0, w) # x
coords[..., 1].clamp_(0, h) # y
else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
coords[..., 0] = coords[..., 0].clamp(0, w)
coords[..., 1] = coords[..., 1].clamp(0, h)
else: # np.array
coords[..., 0] = coords[..., 0].clip(0, w) # x
coords[..., 1] = coords[..., 1].clip(0, h) # y
return coords
def scale_image(masks, im0_shape, ratio_pad=None):
"""
Rescale masks to original image size.
Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
that was applied during preprocessing.
Args:
masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
im0_shape (tuple): Original image shape as HWC or HW (supports both).
ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
Returns:
(np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
"""
# Rescale coordinates (xyxy) from im1_shape to im0_shape
im0_h, im0_w = im0_shape[:2] # supports both HWC or HW shapes
im1_h, im1_w, _ = masks.shape
if im1_h == im0_h and im1_w == im0_w:
return masks
if ratio_pad is None: # calculate from im0_shape
gain = min(im1_h / im0_h, im1_w / im0_w) # gain = old / new
pad = (im1_w - im0_w * gain) / 2, (im1_h - im0_h * gain) / 2 # wh padding
else:
pad = ratio_pad[1]
pad_w, pad_h = pad
top = int(round(pad_h - 0.1))
left = int(round(pad_w - 0.1))
bottom = im1_h - int(round(pad_h + 0.1))
right = im1_w - int(round(pad_w + 0.1))
if len(masks.shape) < 2:
raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
masks = masks[top:bottom, left:right]
# handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
masks = [cv2.resize(array, (im0_w, im0_h)) for array in np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)]
masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
if len(masks.shape) == 2:
masks = masks[:, :, None]
return masks
def xyxy2xywh(x):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
top-left corner and (x2, y2) is the bottom-right corner.
Args:
x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
Returns:
(np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = empty_like(x) # faster than clone/copy
x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
y[..., 0] = (x1 + x2) / 2 # x center
y[..., 1] = (y1 + y2) / 2 # y center
y[..., 2] = x2 - x1 # width
y[..., 3] = y2 - y1 # height
return y
def xywh2xyxy(x):
"""
Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
top-left corner and (x2, y2) is the bottom-right corner. Note: operating on two channels at a time is faster than per-channel ops.
Args:
x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.
Returns:
(np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = empty_like(x) # faster than clone/copy
xy = x[..., :2] # centers
wh = x[..., 2:] / 2 # half width-height
y[..., :2] = xy - wh # top left xy
y[..., 2:] = xy + wh # bottom right xy
return y
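# Illustrative sketch (uses the empty_like helper defined later in this file):
# the two conversions are inverses, so a round trip recovers the originals.
# >>> b = torch.tensor([[50.0, 50.0, 20.0, 10.0]])
# >>> xywh2xyxy(b)
# tensor([[40., 45., 60., 55.]])
# >>> xyxy2xywh(xywh2xyxy(b))
# tensor([[50., 50., 20., 10.]])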
def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
"""
Convert normalized bounding box coordinates to pixel coordinates.
Args:
x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
w (int): Image width in pixels.
h (int): Image height in pixels.
padw (int): Padding width in pixels.
padh (int): Padding height in pixels.
Returns:
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
"""
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = empty_like(x) # faster than clone/copy
xc, yc, xw, xh = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
half_w, half_h = xw / 2, xh / 2
y[..., 0] = w * (xc - half_w) + padw # top left x
y[..., 1] = h * (yc - half_h) + padh # top left y
y[..., 2] = w * (xc + half_w) + padw # bottom right x
y[..., 3] = h * (yc + half_h) + padh # bottom right y
return y
def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
"""
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
width and height are normalized to image dimensions.
Args:
x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
w (int): Image width in pixels.
h (int): Image height in pixels.
clip (bool): Whether to clip boxes to image boundaries.
eps (float): Minimum value for box width and height.
Returns:
(np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
"""
if clip:
x = clip_boxes(x, (h - eps, w - eps))
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
y = empty_like(x) # faster than clone/copy
x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
y[..., 0] = ((x1 + x2) / 2) / w # x center
y[..., 1] = ((y1 + y2) / 2) / h # y center
y[..., 2] = (x2 - x1) / w # width
y[..., 3] = (y2 - y1) / h # height
return y
def xywh2ltwh(x):
"""
Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
Args:
x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.
Returns:
(np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
return y
def xyxy2ltwh(x):
"""
Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
Args:
x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.
Returns:
(np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 2] = x[..., 2] - x[..., 0] # width
y[..., 3] = x[..., 3] - x[..., 1] # height
return y
def ltwh2xywh(x):
"""
Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
Args:
x (torch.Tensor): Input bounding box coordinates.
Returns:
(np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y
return y
def xyxyxyxy2xywhr(x):
"""
Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
Args:
x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.
Returns:
(np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
Rotation values are in radians from 0 to pi/2.
"""
is_torch = isinstance(x, torch.Tensor)
points = x.cpu().numpy() if is_torch else x
points = points.reshape(len(x), -1, 2)
rboxes = []
for pts in points:
# NOTE: Use cv2.minAreaRect to get accurate xywhr,
# especially some objects are cut off by augmentations in dataloader.
(cx, cy), (w, h), angle = cv2.minAreaRect(pts)
rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
def xywhr2xyxyxyxy(x):
"""
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
Args:
x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
Rotation values should be in radians from 0 to pi/2.
Returns:
(np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
"""
cos, sin, cat, stack = (
(torch.cos, torch.sin, torch.cat, torch.stack)
if isinstance(x, torch.Tensor)
else (np.cos, np.sin, np.concatenate, np.stack)
)
ctr = x[..., :2]
w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
vec1 = cat(vec1, -1)
vec2 = cat(vec2, -1)
pt1 = ctr + vec1 + vec2
pt2 = ctr + vec1 - vec2
pt3 = ctr - vec1 - vec2
pt4 = ctr - vec1 + vec2
return stack([pt1, pt2, pt3, pt4], -2)
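# Illustrative sketch: an axis-aligned 4x2 box centered at the origin with
# zero rotation yields the four corners (2, 1), (2, -1), (-2, -1), (-2, 1).
# >>> xywhr2xyxyxyxy(torch.tensor([[0.0, 0.0, 4.0, 2.0, 0.0]])).shape
# torch.Size([1, 4, 2])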
def ltwh2xyxy(x):
"""
Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
Args:
x (np.ndarray | torch.Tensor): Input bounding box coordinates.
Returns:
(np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
"""
y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
y[..., 2] = x[..., 2] + x[..., 0] # bottom right x
y[..., 3] = x[..., 3] + x[..., 1] # bottom right y
return y
def segments2boxes(segments):
"""
Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
Args:
segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
Returns:
(np.ndarray): Bounding box coordinates in xywh format.
"""
boxes = []
for s in segments:
x, y = s.T # segment xy
boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
return xyxy2xywh(np.array(boxes)) # cls, xywh
def resample_segments(segments, n: int = 1000):
"""
Resample segments to n points each using linear interpolation.
Args:
segments (list): List of (N, 2) arrays where N is the number of points in each segment.
n (int): Number of points to resample each segment to.
Returns:
(list): Resampled segments with n points each.
"""
for i, s in enumerate(segments):
if len(s) == n:
continue
s = np.concatenate((s, s[0:1, :]), axis=0)
x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
xp = np.arange(len(s))
x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x
segments[i] = (
np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
) # segment xy
return segments
def crop_mask(masks, boxes):
"""
Crop masks to bounding box regions.
Args:
masks (torch.Tensor): Masks with shape (N, H, W).
boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.
Returns:
(torch.Tensor): Cropped masks.
"""
_, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
"""
Apply masks to bounding boxes using mask head output.
Args:
protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
shape (tuple): Input image size as (height, width).
upsample (bool): Whether to upsample masks to original image size.
Returns:
(torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
are the height and width of the input image. The mask is applied to the bounding boxes.
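Examples:
Illustrative sketch with random prototypes and coefficients (assumed shapes only):
>>> protos = torch.rand(32, 160, 160)
>>> coeffs = torch.rand(2, 32)  # two detections after NMS
>>> boxes = torch.tensor([[10.0, 10.0, 200.0, 200.0], [50.0, 50.0, 400.0, 300.0]])
>>> process_mask(protos, coeffs, boxes, shape=(640, 640)).shape  # torch.Size([2, 160, 160])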
"""
c, mh, mw = protos.shape # CHW
ih, iw = shape
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW
width_ratio = mw / iw
height_ratio = mh / ih
downsampled_bboxes = bboxes.clone()
downsampled_bboxes[:, 0] *= width_ratio
downsampled_bboxes[:, 2] *= width_ratio
downsampled_bboxes[:, 3] *= height_ratio
downsampled_bboxes[:, 1] *= height_ratio
masks = crop_mask(masks, downsampled_bboxes) # CHW
if upsample:
masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
return masks.gt_(0.0)
def process_mask_native(protos, masks_in, bboxes, shape):
"""
Apply masks to bounding boxes using mask head output with native upsampling.
Args:
protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
shape (tuple): Input image size as (height, width).
Returns:
(torch.Tensor): Binary mask tensor with shape (H, W, N).
"""
c, mh, mw = protos.shape # CHW
masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
masks = scale_masks(masks[None], shape)[0] # CHW
masks = crop_mask(masks, bboxes) # CHW
return masks.gt_(0.0)
def scale_masks(masks, shape, padding: bool = True):
"""
Rescale segment masks to target shape.
Args:
masks (torch.Tensor): Masks with shape (N, C, H, W).
shape (tuple): Target height and width as (height, width).
padding (bool): Whether masks are based on YOLO-style augmented images with padding.
Returns:
(torch.Tensor): Rescaled masks.
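Examples:
>>> masks = torch.rand(1, 1, 160, 160)  # illustrative mask at prototype resolution
>>> scale_masks(masks, (640, 640)).shape  # torch.Size([1, 1, 640, 640])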
"""
mh, mw = masks.shape[2:]
gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
pad_w = mw - shape[1] * gain
pad_h = mh - shape[0] * gain
if padding:
pad_w /= 2
pad_h /= 2
top, left = (int(round(pad_h - 0.1)), int(round(pad_w - 0.1))) if padding else (0, 0)
bottom = mh - int(round(pad_h + 0.1))
right = mw - int(round(pad_w + 0.1))
return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear", align_corners=False) # NCHW masks
def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
"""
Rescale segment coordinates from img1_shape to img0_shape.
Args:
img1_shape (tuple): Source image shape as HWC or HW (supports both).
coords (torch.Tensor): Coordinates to scale with shape (N, 2).
img0_shape (tuple): Image 0 shape as HWC or HW (supports both).
ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
normalize (bool): Whether to normalize coordinates to range [0, 1].
padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.
Returns:
(torch.Tensor): Scaled coordinates.
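Examples:
Illustrative sketch: a point in a 640x640 letterboxed image mapped back to the original 480x640 image:
>>> coords = torch.tensor([[320.0, 320.0]])
>>> scale_coords((640, 640), coords, (480, 640))  # tensor([[320., 240.]]) after removing 80 px of top padding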
"""
img0_h, img0_w = img0_shape[:2] # supports both HWC or HW shapes
if ratio_pad is None: # calculate from img0_shape
img1_h, img1_w = img1_shape[:2] # supports both HWC or HW shapes
gain = min(img1_h / img0_h, img1_w / img0_w) # gain = old / new
pad = (img1_w - img0_w * gain) / 2, (img1_h - img0_h * gain) / 2 # wh padding
else:
gain = ratio_pad[0][0]
pad = ratio_pad[1]
if padding:
coords[..., 0] -= pad[0] # x padding
coords[..., 1] -= pad[1] # y padding
coords[..., 0] /= gain
coords[..., 1] /= gain
coords = clip_coords(coords, img0_shape)
if normalize:
coords[..., 0] /= img0_w # width
coords[..., 1] /= img0_h # height
return coords
def regularize_rboxes(rboxes):
"""
Regularize rotated bounding boxes to range [0, pi/2].
Args:
rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.
Returns:
(torch.Tensor): Regularized rotated boxes.
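Examples:
>>> box = torch.tensor([[50.0, 50.0, 20.0, 10.0, 2.0]])  # illustrative angle of 2.0 rad > pi/2
>>> regularize_rboxes(box)  # w and h swap, angle wraps to ~0.43 rad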
"""
x, y, w, h, t = rboxes.unbind(dim=-1)
# Swap w and h when the angle (mod pi) falls in [pi/2, pi), then wrap the angle into [0, pi/2)
swap = t % math.pi >= math.pi / 2
w_ = torch.where(swap, h, w)
h_ = torch.where(swap, w, h)
t = t % (math.pi / 2)
return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
def masks2segments(masks, strategy: str = "all"):
"""
Convert masks to segments using contour detection.
Args:
masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
strategy (str): Segmentation strategy, either 'all' or 'largest'.
Returns:
(list): List of segment masks as float32 arrays.
"""
from ultralytics.data.converter import merge_multi_segment
segments = []
for x in masks.int().cpu().numpy().astype("uint8"):
c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
if c:
if strategy == "all": # merge and concatenate all segments
c = (
np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
if len(c) > 1
else c[0].reshape(-1, 2)
)
elif strategy == "largest": # select largest segment
c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
else:
c = np.zeros((0, 2)) # no segments found
segments.append(c.astype("float32"))
return segments
def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
"""
Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
Args:
batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.
Returns:
(np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
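Examples:
>>> batch = torch.rand(2, 3, 4, 4)  # illustrative BCHW float32 batch in [0, 1]
>>> convert_torch2numpy_batch(batch).shape  # (2, 4, 4, 3), dtype uint8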
"""
return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
def clean_str(s):
"""
Clean a string by replacing special characters with '_' character.
Args:
s (str): A string needing special characters replaced.
Returns:
(str): A string with special characters replaced by an underscore _.
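Examples:
>>> clean_str("weights@v8.pt?")  # illustrative input
'weights_v8.pt_'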
"""
return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
def empty_like(x):
"""Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
return (
torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
)

189
ultralytics/utils/patches.py Normal file

@@ -0,0 +1,189 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Monkey patches to update/extend functionality of existing functions."""
from __future__ import annotations
import time
from contextlib import contextmanager
from copy import copy
from pathlib import Path
from typing import Any
import cv2
import numpy as np
import torch
# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
_imshow = cv2.imshow # copy to avoid recursion errors
def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
"""
Read an image from a file with multilanguage filename support.
Args:
filename (str): Path to the file to read.
flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
Returns:
(np.ndarray | None): The read image array, or None if reading fails.
Examples:
>>> img = imread("path/to/image.jpg")
>>> img = imread("path/to/image.jpg", cv2.IMREAD_GRAYSCALE)
"""
file_bytes = np.fromfile(filename, np.uint8)
if filename.endswith((".tiff", ".tif")):
success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
if success:
# Handle RGB images in tif/tiff format
return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
return None
else:
im = cv2.imdecode(file_bytes, flags)
return im[..., None] if im is not None and im.ndim == 2 else im # Always ensure 3 dimensions
def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
"""
Write an image to a file with multilanguage filename support.
Args:
filename (str): Path to the file to write.
img (np.ndarray): Image to write.
params (list[int], optional): Additional parameters for image encoding.
Returns:
(bool): True if the file was written successfully, False otherwise.
Examples:
>>> import numpy as np
>>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image
>>> success = imwrite("output.jpg", img) # Write image to file
>>> print(success)
True
"""
try:
cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
return True
except Exception:
return False
def imshow(winname: str, mat: np.ndarray) -> None:
"""
Display an image in the specified window with multilanguage window name support.
This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
multilanguage window names by encoding them properly for OpenCV compatibility.
Args:
winname (str): Name of the window where the image will be displayed. If a window with this name already
exists, the image will be displayed in that window.
mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
Examples:
>>> import numpy as np
>>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image
>>> img[:100, :100] = [255, 0, 0] # Add a blue square
>>> imshow("Example Window", img) # Display the image
"""
_imshow(winname.encode("unicode_escape").decode(), mat)
# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save
def torch_load(*args, **kwargs):
"""
Load a PyTorch model with updated arguments to avoid warnings.
This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
Args:
*args (Any): Variable length argument list to pass to torch.load.
**kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
Returns:
(Any): The loaded PyTorch object.
Notes:
For PyTorch versions 1.13 and above, this function automatically sets 'weights_only=False'
if the argument is not provided, to preserve full-checkpoint loading and avoid warnings.
"""
from ultralytics.utils.torch_utils import TORCH_1_13
if TORCH_1_13 and "weights_only" not in kwargs:
kwargs["weights_only"] = False
return torch.load(*args, **kwargs)
def torch_save(*args, **kwargs):
"""
Save PyTorch objects with retry mechanism for robustness.
This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
due to device flushing delays or antivirus scanning.
Args:
*args (Any): Positional arguments to pass to torch.save.
**kwargs (Any): Keyword arguments to pass to torch.save.
Examples:
>>> model = torch.nn.Linear(10, 1)
>>> torch_save(model.state_dict(), "model.pt")
"""
for i in range(4): # 3 retries
try:
return _torch_save(*args, **kwargs)
except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan
if i == 3:
raise e
time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
@contextmanager
def arange_patch(args):
"""
Workaround for ONNX torch.arange incompatibility with FP16.
https://github.com/pytorch/pytorch/issues/148041.
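Examples:
A minimal sketch, assuming an export-args namespace with the three flags the patch checks:
>>> from types import SimpleNamespace
>>> args = SimpleNamespace(dynamic=True, half=True, format="onnx")
>>> with arange_patch(args):
...     t = torch.arange(5, dtype=torch.float16)  # dtype is applied via .to() after creation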
"""
if args.dynamic and args.half and args.format == "onnx":
func = torch.arange
def arange(*args, dtype=None, **kwargs):
"""Return a 1-D tensor of size with values from the interval and common difference."""
return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype
torch.arange = arange # patch
yield
torch.arange = func # unpatch
else:
yield
@contextmanager
def override_configs(args, overrides: dict[str, Any] | None = None):
"""
Context manager to temporarily override configurations in args.
Args:
args (IterableSimpleNamespace): Original configuration arguments.
overrides (dict[str, Any]): Dictionary of overrides to apply.
Yields:
(IterableSimpleNamespace): Configuration arguments with overrides applied.
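Examples:
A minimal sketch using SimpleNamespace as a stand-in for IterableSimpleNamespace:
>>> from types import SimpleNamespace
>>> args = SimpleNamespace(imgsz=640, batch=16)
>>> with override_configs(args, {"imgsz": 320}) as a:
...     print(a.imgsz)  # 320 inside the context
320
>>> args.imgsz  # restored afterwards
640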
"""
if overrides:
original_args = copy(args)
for key, value in overrides.items():
setattr(args, key, value)
try:
yield args
finally:
args.__dict__.update(original_args.__dict__)
else:
yield args

File diff suppressed because it is too large

417
ultralytics/utils/tal.py Normal file

@@ -0,0 +1,417 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
import torch.nn as nn
from . import LOGGER
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy
from .torch_utils import TORCH_1_11
class TaskAlignedAssigner(nn.Module):
"""
A task-aligned assigner for object detection.
This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
classification and localization information.
Attributes:
topk (int): The number of top candidates to consider.
num_classes (int): The number of object classes.
alpha (float): The alpha parameter for the classification component of the task-aligned metric.
beta (float): The beta parameter for the localization component of the task-aligned metric.
eps (float): A small value to prevent division by zero.
"""
def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
"""
Initialize a TaskAlignedAssigner object with customizable hyperparameters.
Args:
topk (int, optional): The number of top candidates to consider.
num_classes (int, optional): The number of object classes.
alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
eps (float, optional): A small value to prevent division by zero.
"""
super().__init__()
self.topk = topk
self.num_classes = num_classes
self.alpha = alpha
self.beta = beta
self.eps = eps
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""
Compute the task-aligned assignment.
Args:
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
Returns:
target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
References:
https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
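Examples:
Illustrative sketch with assumed shapes (bs=1, 64 anchors, one gt box):
>>> assigner = TaskAlignedAssigner(topk=5, num_classes=80)
>>> pd_scores = torch.rand(1, 64, 80)
>>> xy = torch.rand(1, 64, 2) * 60
>>> pd_bboxes = torch.cat([xy, xy + 10], dim=-1)  # valid xyxy predictions
>>> anc_points = torch.rand(64, 2) * 80
>>> gt_labels = torch.zeros(1, 1, 1, dtype=torch.long)
>>> gt_bboxes = torch.tensor([[[20.0, 20.0, 60.0, 60.0]]])
>>> mask_gt = torch.ones(1, 1, 1)
>>> labels, boxes, scores, fg_mask, gt_idx = assigner(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)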
"""
self.bs = pd_scores.shape[0]
self.n_max_boxes = gt_bboxes.shape[1]
device = gt_bboxes.device
if self.n_max_boxes == 0:
return (
torch.full_like(pd_scores[..., 0], self.num_classes),
torch.zeros_like(pd_bboxes),
torch.zeros_like(pd_scores),
torch.zeros_like(pd_scores[..., 0]),
torch.zeros_like(pd_scores[..., 0]),
)
try:
return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
except torch.cuda.OutOfMemoryError:
# Move tensors to CPU, compute, then move back to original device
LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
result = self._forward(*cpu_tensors)
return tuple(t.to(device) for t in result)
def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
"""
Compute the task-aligned assignment.
Args:
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
Returns:
target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
"""
mask_pos, align_metric, overlaps = self.get_pos_mask(
pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
)
target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
# Assigned target
target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
# Normalize
align_metric *= mask_pos
pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj
pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
target_scores = target_scores * norm_align_metric
return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
"""
Get positive mask for each ground truth box.
Args:
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
Returns:
mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
"""
mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
# Get anchor_align metric, (b, max_num_obj, h*w)
align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
# Get topk_metric mask, (b, max_num_obj, h*w)
mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
# Merge all mask to a final mask, (b, max_num_obj, h*w)
mask_pos = mask_topk * mask_in_gts * mask_gt
return mask_pos, align_metric, overlaps
def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
"""
Compute alignment metric given predicted and ground truth bounding boxes.
Args:
pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
Returns:
align_metric (torch.Tensor): Alignment metric combining classification and localization.
overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
"""
na = pd_bboxes.shape[-2]
mask_gt = mask_gt.bool() # b, max_num_obj, h*w
overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
# Get the scores of each grid for each gt cls
bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
# (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
return align_metric, overlaps
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""
Calculate IoU for horizontal bounding boxes.
Args:
gt_bboxes (torch.Tensor): Ground truth boxes.
pd_bboxes (torch.Tensor): Predicted boxes.
Returns:
(torch.Tensor): IoU values between each pair of boxes.
"""
return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
def select_topk_candidates(self, metrics, topk_mask=None):
"""
Select the top-k candidates based on the given metrics.
Args:
metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
the maximum number of objects, and h*w represents the total number of anchor points.
topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
topk is the number of top candidates to consider. If not provided, the top-k values are automatically
computed based on the given metrics.
Returns:
(torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
"""
# (b, max_num_obj, topk)
topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
if topk_mask is None:
topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
# (b, max_num_obj, topk)
topk_idxs.masked_fill_(~topk_mask, 0)
# (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
for k in range(self.topk):
# Expand topk_idxs for each value of k and add 1 at the specified positions
count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
# Filter invalid bboxes
count_tensor.masked_fill_(count_tensor > 1, 0)
return count_tensor.to(metrics.dtype)
def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
"""
Compute target labels, target bounding boxes, and target scores for the positive anchor points.
Args:
gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
batch size and max_num_obj is the maximum number of objects.
gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
anchor points, with shape (b, h*w), where h*w is the total
number of anchor points.
fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
(foreground) anchor points.
Returns:
target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
"""
# Assigned target labels, (b, 1)
batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
# Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
# Assigned target scores
target_labels.clamp_(0)
# 10x faster than F.one_hot()
target_scores = torch.zeros(
(target_labels.shape[0], target_labels.shape[1], self.num_classes),
dtype=torch.int64,
device=target_labels.device,
) # (b, h*w, 80)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
return target_labels, target_bboxes, target_scores
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
"""
Select positive anchor centers within ground truth bounding boxes.
Args:
xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
eps (float, optional): Small value for numerical stability.
Returns:
(torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
Note:
b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
Bounding box format: [x_min, y_min, x_max, y_max].
"""
n_anchors = xy_centers.shape[0]
bs, n_boxes, _ = gt_bboxes.shape
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
return bbox_deltas.amin(3).gt_(eps)
@staticmethod
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
"""
Select anchor boxes with highest IoU when assigned to multiple ground truths.
Args:
mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
n_max_boxes (int): Maximum number of ground truth boxes.
Returns:
target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
"""
# Convert (b, n_max_boxes, h*w) -> (b, h*w)
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
fg_mask = mask_pos.sum(-2)
# Find which gt (index) each anchor serves
target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
return target_gt_idx, fg_mask, mask_pos
class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
"""Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
def iou_calculation(self, gt_bboxes, pd_bboxes):
"""Calculate IoU for rotated bounding boxes."""
return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
@staticmethod
def select_candidates_in_gts(xy_centers, gt_bboxes):
"""
Select the positive anchor center in gt for rotated bounding boxes.
Args:
xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
Returns:
(torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
"""
# (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
corners = xywhr2xyxyxyxy(gt_bboxes)
# (b, n_boxes, 1, 2)
a, b, _, d = corners.split(1, dim=-2)
ab = b - a
ad = d - a
# (b, n_boxes, h*w, 2)
ap = xy_centers - a
norm_ab = (ab * ab).sum(dim=-1)
norm_ad = (ad * ad).sum(dim=-1)
ap_dot_ab = (ap * ab).sum(dim=-1)
ap_dot_ad = (ap * ad).sum(dim=-1)
return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box
def make_anchors(feats, strides, grid_cell_offset=0.5):
"""Generate anchors from features."""
anchor_points, stride_tensor = [], []
assert feats is not None
dtype, device = feats[0].dtype, feats[0].device
for i, stride in enumerate(strides):
h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)
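# Illustrative usage (assumed shapes): two feature maps at strides 8 and 16
# >>> feats = [torch.zeros(1, 64, 80, 80), torch.zeros(1, 64, 40, 40)]
# >>> points, strides = make_anchors(feats, [8, 16])  # points: (8000, 2), strides: (8000, 1)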
def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
"""Transform distance(ltrb) to box(xywh or xyxy)."""
lt, rb = distance.chunk(2, dim)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
if xywh:
c_xy = (x1y1 + x2y2) / 2
wh = x2y2 - x1y1
return torch.cat([c_xy, wh], dim) # xywh bbox
return torch.cat((x1y1, x2y2), dim) # xyxy bbox
def bbox2dist(anchor_points, bbox, reg_max):
"""Transform bbox(xyxy) to dist(ltrb)."""
x1y1, x2y2 = bbox.chunk(2, -1)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
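# Round-trip sketch with illustrative values: bbox2dist inverts dist2bbox for in-range distances
# >>> ap = torch.tensor([[10.0, 10.0]])
# >>> box = dist2bbox(torch.tensor([[2.0, 3.0, 4.0, 5.0]]), ap, xywh=False)  # [[8., 7., 14., 15.]]
# >>> bbox2dist(ap, box, reg_max=16)  # [[2., 3., 4., 5.]]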
def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
"""
Decode predicted rotated bounding box coordinates from anchor points and distribution.
Args:
pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
dim (int, optional): Dimension along which to split.
Returns:
(torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
"""
lt, rb = pred_dist.split(2, dim=dim)
cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
# (bs, h*w, 1)
xf, yf = ((rb - lt) / 2).split(1, dim=dim)
x, y = xf * cos - yf * sin, xf * sin + yf * cos
xy = torch.cat([x, y], dim=dim) + anchor_points
return torch.cat([xy, lt + rb], dim=dim)

File diff suppressed because it is too large

440
ultralytics/utils/tqdm.py Normal file

@@ -0,0 +1,440 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import os
import sys
import time
from functools import lru_cache
from typing import IO, Any
@lru_cache(maxsize=1)
def is_noninteractive_console() -> bool:
"""Check for known non-interactive console environments."""
return "GITHUB_ACTIONS" in os.environ or "RUNPOD_POD_ID" in os.environ
class TQDM:
"""
Lightweight zero-dependency progress bar for Ultralytics.
Provides clean, rich-style progress bars suitable for various environments including Weights & Biases,
console outputs, and other logging systems. Features zero external dependencies, clean single-line output,
rich-style progress bars with Unicode block characters, context manager support, iterator protocol support,
and dynamic description updates.
Attributes:
iterable (object): Iterable to wrap with progress bar.
desc (str): Prefix description for the progress bar.
total (int): Expected number of iterations.
disable (bool): Whether to disable the progress bar.
unit (str): String for units of iteration.
unit_scale (bool): Auto-scale units flag.
unit_divisor (int): Divisor for unit scaling.
leave (bool): Whether to leave the progress bar after completion.
mininterval (float): Minimum time interval between updates.
initial (int): Initial counter value.
n (int): Current iteration count.
closed (bool): Whether the progress bar is closed.
bar_format (str): Custom bar format string.
file (object): Output file stream.
Methods:
update: Update progress by n steps.
set_description: Set or update the description.
set_postfix: Set postfix for the progress bar.
close: Close the progress bar and clean up.
refresh: Refresh the progress bar display.
clear: Clear the progress bar from display.
write: Write a message without breaking the progress bar.
Examples:
Basic usage with iterator:
>>> for i in TQDM(range(100)):
... time.sleep(0.01)
With custom description:
>>> pbar = TQDM(range(100), desc="Processing")
>>> for i in pbar:
... pbar.set_description(f"Processing item {i}")
Context manager usage:
>>> with TQDM(total=100, unit="B", unit_scale=True) as pbar:
... for i in range(100):
... pbar.update(1)
Manual updates:
>>> pbar = TQDM(total=100, desc="Training")
>>> for epoch in range(100):
... # Do work
... pbar.update(1)
>>> pbar.close()
"""
# Constants
MIN_RATE_CALC_INTERVAL = 0.01 # Minimum time interval for rate calculation
RATE_SMOOTHING_FACTOR = 0.3 # Factor for exponential smoothing of rates
MAX_SMOOTHED_RATE = 1000000 # Maximum rate to apply smoothing to
NONINTERACTIVE_MIN_INTERVAL = 60.0 # Minimum interval for non-interactive environments
def __init__(
self,
iterable: Any = None,
desc: str | None = None,
total: int | None = None,
leave: bool = True,
file: IO[str] | None = None,
mininterval: float = 0.1,
disable: bool | None = None,
unit: str = "it",
unit_scale: bool = True,
unit_divisor: int = 1000,
bar_format: str | None = None, # kept for API compatibility; not used for formatting
initial: int = 0,
**kwargs,
) -> None:
"""
Initialize the TQDM progress bar with specified configuration options.
Args:
iterable (object, optional): Iterable to wrap with progress bar.
desc (str, optional): Prefix description for the progress bar.
total (int, optional): Expected number of iterations.
leave (bool, optional): Whether to leave the progress bar after completion.
file (object, optional): Output file stream for progress display.
mininterval (float, optional): Minimum time interval between updates (default 0.1s, 60s in GitHub Actions).
disable (bool, optional): Whether to disable the progress bar. Auto-detected if None.
unit (str, optional): String for units of iteration (default "it" for items).
unit_scale (bool, optional): Auto-scale units for bytes/data units.
unit_divisor (int, optional): Divisor for unit scaling (default 1000).
bar_format (str, optional): Custom bar format string.
initial (int, optional): Initial counter value.
**kwargs (Any): Additional keyword arguments for compatibility (ignored).
Examples:
>>> pbar = TQDM(range(100), desc="Processing")
>>> with TQDM(total=1000, unit="B", unit_scale=True) as pbar:
... pbar.update(1024) # Updates by 1KB
"""
# Disable if not verbose
if disable is None:
try:
from ultralytics.utils import LOGGER, VERBOSE
disable = not VERBOSE or LOGGER.getEffectiveLevel() > 20
except ImportError:
disable = False
self.iterable = iterable
self.desc = desc or ""
self.total = total or (len(iterable) if hasattr(iterable, "__len__") else None) or None # prevent total=0
self.disable = disable
self.unit = unit
self.unit_scale = unit_scale
self.unit_divisor = unit_divisor
self.leave = leave
self.noninteractive = is_noninteractive_console()
self.mininterval = max(mininterval, self.NONINTERACTIVE_MIN_INTERVAL) if self.noninteractive else mininterval
self.initial = initial
# Kept for API compatibility (unused for f-string formatting)
self.bar_format = bar_format
self.file = file or sys.stdout
# Internal state
self.n = self.initial
self.last_print_n = self.initial
self.last_print_t = time.time()
self.start_t = time.time()
self.last_rate = 0.0
self.closed = False
self.is_bytes = unit_scale and unit in ("B", "bytes")
self.scales = (
[(1073741824, "GB/s"), (1048576, "MB/s"), (1024, "KB/s")]
if self.is_bytes
else [(1e9, f"G{self.unit}/s"), (1e6, f"M{self.unit}/s"), (1e3, f"K{self.unit}/s")]
)
if not self.disable and self.total and not self.noninteractive:
self._display()
def _format_rate(self, rate: float) -> str:
"""Format rate with units."""
if rate <= 0:
return ""
fallback = f"{rate:.1f}B/s" if self.is_bytes else f"{rate:.1f}{self.unit}/s"
return next((f"{rate / t:.1f}{u}" for t, u in self.scales if rate >= t), fallback)
def _format_num(self, num: int | float) -> str:
"""Format number with optional unit scaling."""
if not self.unit_scale or not self.is_bytes:
return str(num)
for unit in ("", "K", "M", "G", "T"):
if abs(num) < self.unit_divisor:
return f"{num:3.1f}{unit}B" if unit else f"{num:.0f}B"
num /= self.unit_divisor
return f"{num:.1f}PB"
def _format_time(self, seconds: float) -> str:
"""Format time duration."""
if seconds < 60:
return f"{seconds:.1f}s"
elif seconds < 3600:
return f"{int(seconds // 60)}:{seconds % 60:02.0f}"
else:
h, m = int(seconds // 3600), int((seconds % 3600) // 60)
return f"{h}:{m:02d}:{seconds % 60:02.0f}"
def _generate_bar(self, width: int = 12) -> str:
"""Generate progress bar."""
if self.total is None:
return "━" * width if self.closed else "─" * width
frac = min(1.0, self.n / self.total)
filled = int(frac * width)
bar = "━" * filled + "─" * (width - filled)
if filled < width and frac * width - filled > 0.5:
bar = f"{bar[:filled]}╸{bar[filled + 1 :]}"
return bar
def _should_update(self, dt: float, dn: int) -> bool:
"""Check if display should update."""
if self.noninteractive:
return False
return (self.total is not None and self.n >= self.total) or (dt >= self.mininterval)
def _display(self, final: bool = False) -> None:
"""Display progress bar."""
if self.disable or (self.closed and not final):
return
current_time = time.time()
dt = current_time - self.last_print_t
dn = self.n - self.last_print_n
if not final and not self._should_update(dt, dn):
return
# Calculate rate (avoid crazy numbers)
if dt > self.MIN_RATE_CALC_INTERVAL:
rate = dn / dt if dt else 0.0
# Smooth rate for reasonable values, use raw rate for very high values
if rate < self.MAX_SMOOTHED_RATE:
self.last_rate = self.RATE_SMOOTHING_FACTOR * rate + (1 - self.RATE_SMOOTHING_FACTOR) * self.last_rate
rate = self.last_rate
else:
rate = self.last_rate
# At completion, use overall rate
if self.total and self.n >= self.total:
overall_elapsed = current_time - self.start_t
if overall_elapsed > 0:
rate = self.n / overall_elapsed
# Update counters
self.last_print_n = self.n
self.last_print_t = current_time
elapsed = current_time - self.start_t
# Remaining time
remaining_str = ""
if self.total and 0 < self.n < self.total and elapsed > 0:
est_rate = rate or (self.n / elapsed)
remaining_str = f"<{self._format_time((self.total - self.n) / est_rate)}"
# Numbers and percent
if self.total:
percent = (self.n / self.total) * 100
n_str = self._format_num(self.n)
t_str = self._format_num(self.total)
if self.is_bytes:
# Collapse suffix only when identical (e.g. "5.4/5.4MB")
if n_str[-2] == t_str[-2]:
n_str = n_str.rstrip("KMGTPB")  # drop the duplicate unit suffix from the current value
else:
percent = 0.0
n_str, t_str = self._format_num(self.n), "?"
elapsed_str = self._format_time(elapsed)
rate_str = self._format_rate(rate) or (self._format_rate(self.n / elapsed) if elapsed > 0 else "")
bar = self._generate_bar()
# Compose progress line via f-strings (two shapes: with/without total)
if self.total:
if self.is_bytes and self.n >= self.total:
# Completed bytes: show only final size
progress_str = f"{self.desc}: {percent:.0f}% {bar} {t_str} {rate_str} {elapsed_str}"
else:
progress_str = (
f"{self.desc}: {percent:.0f}% {bar} {n_str}/{t_str} {rate_str} {elapsed_str}{remaining_str}"
)
else:
progress_str = f"{self.desc}: {bar} {n_str} {rate_str} {elapsed_str}"
# Write to output
try:
if self.noninteractive:
# In non-interactive environments, avoid carriage return which creates empty lines
self.file.write(progress_str)
else:
# In interactive terminals, use carriage return and clear line for updating display
self.file.write(f"\r\033[K{progress_str}")
self.file.flush()
except Exception:
pass
def update(self, n: int = 1) -> None:
"""Update progress by n steps."""
if not self.disable and not self.closed:
self.n += n
self._display()
def set_description(self, desc: str | None) -> None:
"""Set description."""
self.desc = desc or ""
if not self.disable:
self._display()
def set_postfix(self, **kwargs: Any) -> None:
"""Set postfix (appends to description)."""
if kwargs:
postfix = ", ".join(f"{k}={v}" for k, v in kwargs.items())
base_desc = self.desc.split(" | ")[0] if " | " in self.desc else self.desc
self.set_description(f"{base_desc} | {postfix}")
def close(self) -> None:
"""Close progress bar."""
if self.closed:
return
self.closed = True
if not self.disable:
# Final display
if self.total and self.n >= self.total:
self.n = self.total
self._display(final=True)
# Cleanup
if self.leave:
self.file.write("\n")
else:
self.file.write("\r\033[K")
try:
self.file.flush()
except Exception:
pass
def __enter__(self) -> TQDM:
"""Enter context manager."""
return self
def __exit__(self, *args: Any) -> None:
"""Exit context manager and close progress bar."""
self.close()
def __iter__(self) -> Any:
"""Iterate over the wrapped iterable with progress updates."""
if self.iterable is None:
raise TypeError("'NoneType' object is not iterable")
try:
for item in self.iterable:
yield item
self.update(1)
finally:
self.close()
def __del__(self) -> None:
"""Destructor to ensure cleanup."""
try:
self.close()
except Exception:
pass
def refresh(self) -> None:
"""Refresh display."""
if not self.disable:
self._display()
def clear(self) -> None:
"""Clear progress bar."""
if not self.disable:
try:
self.file.write("\r\033[K")
self.file.flush()
except Exception:
pass
@staticmethod
def write(s: str, file: IO[str] | None = None, end: str = "\n") -> None:
"""Static method to write without breaking progress bar."""
file = file or sys.stdout
try:
file.write(s + end)
file.flush()
except Exception:
pass
if __name__ == "__main__":
import time
print("1. Basic progress bar with known total:")
for i in TQDM(range(3), desc="Known total"):
time.sleep(0.05)
print("\n2. Manual updates with known total:")
pbar = TQDM(total=300, desc="Manual updates", unit="files")
for i in range(300):
time.sleep(0.03)
pbar.update(1)
if i % 10 == 9:
pbar.set_description(f"Processing batch {i // 10 + 1}")
pbar.close()
print("\n3. Progress bar with unknown total:")
pbar = TQDM(desc="Unknown total", unit="items")
for i in range(25):
time.sleep(0.08)
pbar.update(1)
if i % 5 == 4:
pbar.set_postfix(processed=i + 1, status="OK")
pbar.close()
print("\n4. Context manager with unknown total:")
with TQDM(desc="Processing stream", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
for i in range(30):
time.sleep(0.1)
pbar.update(1024 * 1024 * i) # Simulate processing MB of data
print("\n5. Iterator with unknown length:")
def data_stream():
"""Simulate a data stream of unknown length."""
import random
for i in range(random.randint(10, 20)):
yield f"data_chunk_{i}"
for chunk in TQDM(data_stream(), desc="Stream processing", unit="chunks"):
time.sleep(0.1)
print("\n6. File processing simulation (unknown size):")
def process_files():
"""Simulate processing files of unknown count."""
return [f"file_{i}.txt" for i in range(18)]
pbar = TQDM(desc="Scanning files", unit="files")
files = process_files()
for i, filename in enumerate(files):
time.sleep(0.06)
pbar.update(1)
pbar.set_description(f"Processing {filename}")
pbar.close()

118
ultralytics/utils/triton.py Normal file

@@ -0,0 +1,118 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from urllib.parse import urlsplit
import numpy as np
class TritonRemoteModel:
"""
Client for interacting with a remote Triton Inference Server model.
This class provides a convenient interface for sending inference requests to a Triton Inference Server
and processing the responses. Supports both HTTP and gRPC communication protocols.
Attributes:
endpoint (str): The name of the model on the Triton server.
url (str): The URL of the Triton server.
triton_client: The Triton client (either HTTP or gRPC).
InferInput: The input class for the Triton client.
InferRequestedOutput: The output request class for the Triton client.
input_formats (list[str]): The data types of the model inputs.
np_input_formats (list[type]): The numpy data types of the model inputs.
input_names (list[str]): The names of the model inputs.
output_names (list[str]): The names of the model outputs.
metadata: The metadata associated with the model.
Methods:
__call__: Call the model with the given inputs and return the outputs.
Examples:
Initialize a Triton client with HTTP
>>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
Make inference with numpy arrays
>>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
"""
def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
"""
Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.
Arguments may be provided individually or parsed from a collective 'url' argument of the form
<scheme>://<netloc>/<endpoint>/<task_name>
Args:
url (str): The URL of the Triton server.
endpoint (str, optional): The name of the model on the Triton server.
scheme (str, optional): The communication scheme ('http' or 'grpc').
Examples:
>>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
>>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
"""
if not endpoint and not scheme: # Parse all args from URL string
splits = urlsplit(url)
endpoint = splits.path.strip("/").split("/", 1)[0]
scheme = splits.scheme
url = splits.netloc
self.endpoint = endpoint
self.url = url
# Choose the Triton client based on the communication scheme
if scheme == "http":
import tritonclient.http as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint)
else:
import tritonclient.grpc as client # noqa
self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
# Sort output names alphabetically, i.e. 'output0', 'output1', etc.
config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
# Define model attributes
type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
self.InferRequestedOutput = client.InferRequestedOutput
self.InferInput = client.InferInput
self.input_formats = [x["data_type"] for x in config["input"]]
self.np_input_formats = [type_map[x] for x in self.input_formats]
self.input_names = [x["name"] for x in config["input"]]
self.output_names = [x["name"] for x in config["output"]]
self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))  # parse metadata dict stored as a string in the model config
def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]:
"""
Call the model with the given inputs and return inference results.
Args:
*inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
for the corresponding model input.
Returns:
(list[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list
corresponds to one of the model's output tensors.
Examples:
>>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
>>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
"""
infer_inputs = []
input_format = inputs[0].dtype
for i, x in enumerate(inputs):
if x.dtype != self.np_input_formats[i]:
x = x.astype(self.np_input_formats[i])
infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
infer_input.set_data_from_numpy(x)
infer_inputs.append(infer_input)
infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]

159
ultralytics/utils/tuner.py Normal file

@@ -0,0 +1,159 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr
def run_ray_tune(
model,
space: dict = None,
grace_period: int = 10,
gpu_per_trial: int = None,
max_samples: int = 10,
**train_args,
):
"""
Run hyperparameter tuning using Ray Tune.
Args:
model (YOLO): Model to run the tuner on.
space (dict, optional): The hyperparameter search space. If not provided, uses default space.
grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
max_samples (int, optional): The maximum number of trials to run.
**train_args (Any): Additional arguments to pass to the `train()` method.
Returns:
(ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.
Examples:
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt") # Load a YOLO11n model
Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
>>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
"""
LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
if train_args is None:
train_args = {}
try:
checks.check_requirements("ray[tune]")
import ray
from ray import tune
from ray.air import RunConfig
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.tune.schedulers import ASHAScheduler
except ImportError:
raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')
try:
import wandb
assert hasattr(wandb, "__version__")
except (ImportError, AssertionError):
wandb = False
checks.check_version(ray.__version__, ">=2.0.0", "ray")
default_space = {
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
"lr0": tune.uniform(1e-5, 1e-1),
"lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
"momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
"weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay
"warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
"warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
"box": tune.uniform(0.02, 0.2), # box loss gain
"cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
"hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
"hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
"hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
"degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
"translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
"scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
"shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
"perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
"flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
"fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
"bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability)
"mosaic": tune.uniform(0.0, 1.0), # image mosaic (probability)
"mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
"cutmix": tune.uniform(0.0, 1.0), # image cutmix (probability)
"copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
}
# Put the model in ray store
task = model.task
model_in_store = ray.put(model)
def _tune(config):
"""Train the YOLO model with the specified hyperparameters and return results."""
model_to_train = ray.get(model_in_store) # get the model from ray store for tuning
model_to_train.reset_callbacks()
config.update(train_args)
results = model_to_train.train(**config)
return results.results_dict
# Get search space
if not space and not train_args.get("resume"):
space = default_space
LOGGER.warning("Search space not provided, using default search space.")
# Get dataset
data = train_args.get("data", TASK2DATA[task])
space["data"] = data
if "data" not in train_args:
LOGGER.warning(f'Data not provided, using default "data={data}".')
# Define the trainable function with allocated resources
trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
# Define the ASHA scheduler for hyperparameter search
asha_scheduler = ASHAScheduler(
time_attr="epoch",
metric=TASK2METRIC[task],
mode="max",
max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
grace_period=grace_period,
reduction_factor=3,
)
# Define the callbacks for the hyperparameter search
tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
# Create the Ray Tune hyperparameter search tuner
tune_dir = get_save_dir(
get_cfg(
DEFAULT_CFG,
{**train_args, **{"exist_ok": train_args.pop("resume", False)}}, # resume w/ same tune_dir
),
name=train_args.pop("name", "tune"), # runs/{task}/{tune_dir}
) # must be absolute dir
tune_dir.mkdir(parents=True, exist_ok=True)
if tune.Tuner.can_restore(tune_dir):
LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
else:
tuner = tune.Tuner(
trainable_with_resources,
param_space=space,
tune_config=tune.TuneConfig(
scheduler=asha_scheduler,
num_samples=max_samples,
trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
),
run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),
)
# Run the hyperparameter search
tuner.fit()
# Get the results of the hyperparameter search
results = tuner.get_results()
# Shut down Ray to clean up workers
ray.shutdown()
return results