init commit
1450  ultralytics/utils/__init__.py  Normal file (diff suppressed because it is too large)
BIN  ultralytics/utils/__pycache__/__init__.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/autobatch.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/checks.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/cpu.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/dist.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/downloads.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/errors.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/events.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/files.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/git.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/instance.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/loss.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/metrics.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/nms.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/ops.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/patches.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/plotting.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/tal.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/torch_utils.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/__pycache__/tqdm.cpython-310.pyc  Normal file (binary file not shown)
120  ultralytics/utils/autobatch.py  Normal file
@@ -0,0 +1,120 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""

from __future__ import annotations

import os
from copy import deepcopy

import numpy as np
import torch

from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
from ultralytics.utils.torch_utils import autocast, profile_ops


def check_train_batch_size(
    model: torch.nn.Module,
    imgsz: int = 640,
    amp: bool = True,
    batch: int | float = -1,
    max_num_obj: int = 1,
) -> int:
    """
    Compute optimal YOLO training batch size using the autobatch() function.

    Args:
        model (torch.nn.Module): YOLO model to check batch size for.
        imgsz (int, optional): Image size used for training.
        amp (bool, optional): Use automatic mixed precision if True.
        batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.
        max_num_obj (int, optional): The maximum number of objects from dataset.

    Returns:
        (int): Optimal batch size computed using the autobatch() function.

    Notes:
        If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
        Otherwise, a default fraction of 0.6 is used.
    """
    with autocast(enabled=amp):
        return autobatch(
            deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj
        )


def autobatch(
    model: torch.nn.Module,
    imgsz: int = 640,
    fraction: float = 0.60,
    batch_size: int = DEFAULT_CFG.batch,
    max_num_obj: int = 1,
) -> int:
    """
    Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.

    Args:
        model (torch.nn.Module): YOLO model to compute batch size for.
        imgsz (int, optional): The image size used as input for the YOLO model.
        fraction (float, optional): The fraction of available CUDA memory to use.
        batch_size (int, optional): The default batch size to use if an error is detected.
        max_num_obj (int, optional): The maximum number of objects from dataset.

    Returns:
        (int): The optimal batch size.
    """
    # Check device
    prefix = colorstr("AutoBatch: ")
    LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
    device = next(model.parameters()).device  # get model device
    if device.type in {"cpu", "mps"}:
        LOGGER.warning(f"{prefix}intended for CUDA devices, using default batch-size {batch_size}")
        return batch_size
    if torch.backends.cudnn.benchmark:
        LOGGER.warning(f"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
        return batch_size

    # Inspect CUDA memory
    gb = 1 << 30  # bytes to GiB (1024 ** 3)
    d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}"  # 'CUDA:0'
    properties = torch.cuda.get_device_properties(device)  # device properties
    t = properties.total_memory / gb  # GiB total
    r = torch.cuda.memory_reserved(device) / gb  # GiB reserved
    a = torch.cuda.memory_allocated(device) / gb  # GiB allocated
    f = t - (r + a)  # GiB free
    LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")

    # Profile batch sizes
    batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
    try:
        img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
        results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)

        # Fit a solution
        xy = [
            [x, y[2]]
            for i, (x, y) in enumerate(zip(batch_sizes, results))
            if y  # valid result
            and isinstance(y[2], (int, float))  # is numeric
            and 0 < y[2] < t  # between 0 and GPU limit
            and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2])  # first item or increasing memory
        ]
        fit_x, fit_y = zip(*xy) if xy else ([], [])
        p = np.polyfit(fit_x, fit_y, deg=1)  # first-degree polynomial fit of memory vs batch size
        b = int((round(f * fraction) - p[1]) / p[0])  # solve linear fit for batch size at target memory (optimal batch size)
        if None in results:  # some sizes failed
            i = results.index(None)  # first fail index
            if b >= batch_sizes[i]:  # predicted batch size at or above failure point
                b = batch_sizes[max(i - 1, 0)]  # select prior safe point
        if b < 1 or b > 1024:  # b outside of safe range
            LOGGER.warning(f"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.")
            b = batch_size

        fraction = (np.polyval(p, b) + r + a) / t  # predicted fraction
        LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
        return b
    except Exception as e:
        LOGGER.warning(f"{prefix}error detected: {e}, using default batch-size {batch_size}.")
        return batch_size
    finally:
        torch.cuda.empty_cache()
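
Note: the sizing step above fits a straight line through the measured (batch size, memory) points and solves it for the target memory budget. A minimal standalone sketch of that fit, with made-up measurements (the numbers below are illustrative, not from a real profile):

    import numpy as np

    batch_sizes = [1, 2, 4, 8]  # profiled batch sizes
    mem_gib = [0.5, 0.9, 1.7, 3.3]  # hypothetical GiB consumed at each batch size
    p = np.polyfit(batch_sizes, mem_gib, deg=1)  # linear fit: mem ≈ p[0] * batch + p[1]
    target_gib = 0.6 * 8.0  # e.g. aim to fill 60% of 8 GiB of free memory
    b = int((target_gib - p[1]) / p[0])  # batch size whose predicted memory hits the target
    print(b)  # -> 11 for these illustrative numbers
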
206  ultralytics/utils/autodevice.py  Normal file
@@ -0,0 +1,206 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from typing import Any

from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_requirements


class GPUInfo:
    """
    Manages NVIDIA GPU information via pynvml with robust error handling.

    Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle
    GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml
    library by logging warnings and disabling related features, preventing application crashes.

    Includes fallback logic using `torch.cuda` for basic device counting if NVML is unavailable during GPU
    selection. Manages NVML initialization and shutdown internally.

    Attributes:
        pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.
        nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,
            False otherwise.
        gpu_stats (list[dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on
            initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),
            'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),
            'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.

    Methods:
        refresh_stats: Refresh the internal gpu_stats list by querying NVML.
        print_status: Print GPU status in a compact table format using current stats.
        select_idle_gpu: Select the most idle GPUs based on utilization and free memory.
        shutdown: Shut down NVML if it was initialized.

    Examples:
        Initialize GPUInfo and print status
        >>> gpu_info = GPUInfo()
        >>> gpu_info.print_status()

        Select idle GPUs with minimum memory requirements
        >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)
        >>> print(f"Selected GPU indices: {selected}")
    """

    def __init__(self):
        """Initialize GPUInfo, attempting to import and initialize pynvml."""
        self.pynvml: Any | None = None
        self.nvml_available: bool = False
        self.gpu_stats: list[dict[str, Any]] = []

        try:
            check_requirements("nvidia-ml-py>=12.0.0")
            self.pynvml = __import__("pynvml")
            self.pynvml.nvmlInit()
            self.nvml_available = True
            self.refresh_stats()
        except Exception as e:
            LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}")

    def __del__(self):
        """Ensure NVML is shut down when the object is garbage collected."""
        self.shutdown()

    def shutdown(self):
        """Shut down NVML if it was initialized."""
        if self.nvml_available and self.pynvml:
            try:
                self.pynvml.nvmlShutdown()
            except Exception:
                pass
            self.nvml_available = False

    def refresh_stats(self):
        """Refresh the internal gpu_stats list by querying NVML."""
        self.gpu_stats = []
        if not self.nvml_available or not self.pynvml:
            return

        try:
            device_count = self.pynvml.nvmlDeviceGetCount()
            self.gpu_stats.extend(self._get_device_stats(i) for i in range(device_count))
        except Exception as e:
            LOGGER.warning(f"Error during device query: {e}")
            self.gpu_stats = []

    def _get_device_stats(self, index: int) -> dict[str, Any]:
        """Get stats for a single GPU device."""
        handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)
        memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
        util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)

        def safe_get(func, *args, default=-1, divisor=1):
            try:
                val = func(*args)
                return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val
            except Exception:
                return default

        temp_type = getattr(self.pynvml, "NVML_TEMPERATURE_GPU", -1)

        return {
            "index": index,
            "name": self.pynvml.nvmlDeviceGetName(handle),
            "utilization": util.gpu if util else -1,
            "memory_used": memory.used >> 20 if memory else -1,  # Convert bytes to MiB
            "memory_total": memory.total >> 20 if memory else -1,
            "memory_free": memory.free >> 20 if memory else -1,
            "temperature": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),
            "power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000),  # Convert mW to W
            "power_limit": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),
        }

    def print_status(self):
        """Print GPU status in a compact table format using current stats."""
        self.refresh_stats()
        if not self.gpu_stats:
            LOGGER.warning("No GPU stats available.")
            return

        stats = self.gpu_stats
        name_len = max(len(gpu.get("name", "N/A")) for gpu in stats)
        hdr = f"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}"
        LOGGER.info(f"\n--- GPU Status ---\n{hdr}\n{'-' * len(hdr)}")

        for gpu in stats:
            u = f"{gpu['utilization']:>5}%" if gpu["utilization"] >= 0 else " N/A "
            m = f"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}" if gpu["memory_used"] >= 0 else " N/A / N/A "
            t = f"{gpu['temperature']}C" if gpu["temperature"] >= 0 else " N/A "
            p = f"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}" if gpu["power_draw"] >= 0 else " N/A "

            LOGGER.info(f"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}")

        LOGGER.info(f"{'-' * len(hdr)}\n")

    def select_idle_gpu(
        self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0
    ) -> list[int]:
        """
        Select the most idle GPUs based on utilization and free memory.

        Args:
            count (int): The number of idle GPUs to select.
            min_memory_fraction (float): Minimum free memory required as a fraction of total memory.
            min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.

        Returns:
            (list[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first).

        Notes:
            Returns fewer than 'count' if not enough qualify or exist.
            Returns basic CUDA indices if NVML fails. Empty list if no GPUs found.
        """
        assert min_memory_fraction <= 1.0, f"min_memory_fraction must be <= 1.0, got {min_memory_fraction}"
        assert min_util_fraction <= 1.0, f"min_util_fraction must be <= 1.0, got {min_util_fraction}"
        LOGGER.info(
            f"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%..."
        )

        if count <= 0:
            return []

        self.refresh_stats()
        if not self.gpu_stats:
            LOGGER.warning("NVML stats unavailable.")
            return []

        # Filter and sort eligible GPUs
        eligible_gpus = [
            gpu
            for gpu in self.gpu_stats
            if gpu.get("memory_free", 0) / gpu.get("memory_total", 1) >= min_memory_fraction
            and (100 - gpu.get("utilization", 100)) >= min_util_fraction * 100
        ]
        eligible_gpus.sort(key=lambda x: (x.get("utilization", 101), -x.get("memory_free", 0)))

        # Select top 'count' indices
        selected = [gpu["index"] for gpu in eligible_gpus[:count]]

        if selected:
            LOGGER.info(f"Selected idle CUDA devices {selected}")
        else:
            LOGGER.warning(
                f"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%)."
            )

        return selected


if __name__ == "__main__":
    required_free_mem_fraction = 0.2  # Require 20% free VRAM
    required_free_util_fraction = 0.2  # Require 20% free utilization
    num_gpus_to_select = 1

    gpu_info = GPUInfo()
    gpu_info.print_status()

    if selected := gpu_info.select_idle_gpu(
        count=num_gpus_to_select,
        min_memory_fraction=required_free_mem_fraction,
        min_util_fraction=required_free_util_fraction,
    ):
        print(f"\n==> Using selected GPU indices: {selected}")
        devices = [f"cuda:{idx}" for idx in selected]
        print(f"    Target devices: {devices}")
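
Note: the indices returned by select_idle_gpu() plug directly into downstream device arguments. A short usage sketch (assuming the ultralytics YOLO API is available; yolo11n.pt and coco8.yaml are placeholder names):

    from ultralytics import YOLO
    from ultralytics.utils.autodevice import GPUInfo

    info = GPUInfo()
    if selected := info.select_idle_gpu(count=1, min_memory_fraction=0.2):
        # train on the chosen card; a list of indices would target multiple GPUs
        YOLO("yolo11n.pt").train(data="coco8.yaml", epochs=1, device=selected[0])
    info.shutdown()
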
728  ultralytics/utils/benchmarks.py  Normal file
@@ -0,0 +1,728 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Benchmark YOLO model formats for speed and accuracy.

Usage:
    from ultralytics.utils.benchmarks import ProfileModels, benchmark
    ProfileModels(['yolo11n.yaml', 'yolov8s.yaml']).run()
    benchmark(model='yolo11n.pt', imgsz=160)

Format | `format=argument` | Model
--- | --- | ---
PyTorch | - | yolo11n.pt
TorchScript | `torchscript` | yolo11n.torchscript
ONNX | `onnx` | yolo11n.onnx
OpenVINO | `openvino` | yolo11n_openvino_model/
TensorRT | `engine` | yolo11n.engine
CoreML | `coreml` | yolo11n.mlpackage
TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
TensorFlow GraphDef | `pb` | yolo11n.pb
TensorFlow Lite | `tflite` | yolo11n.tflite
TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
TensorFlow.js | `tfjs` | yolo11n_web_model/
PaddlePaddle | `paddle` | yolo11n_paddle_model/
MNN | `mnn` | yolo11n.mnn
NCNN | `ncnn` | yolo11n_ncnn_model/
IMX | `imx` | yolo11n_imx_model/
RKNN | `rknn` | yolo11n_rknn_model/
"""

from __future__ import annotations

import glob
import os
import platform
import re
import shutil
import time
from pathlib import Path

import numpy as np
import torch.cuda

from ultralytics import YOLO, YOLOWorld
from ultralytics.cfg import TASK2DATA, TASK2METRIC
from ultralytics.engine.exporter import export_formats
from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML
from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip
from ultralytics.utils.downloads import safe_download
from ultralytics.utils.files import file_size
from ultralytics.utils.torch_utils import get_cpu_info, select_device


def benchmark(
    model=WEIGHTS_DIR / "yolo11n.pt",
    data=None,
    imgsz=160,
    half=False,
    int8=False,
    device="cpu",
    verbose=False,
    eps=1e-3,
    format="",
    **kwargs,
):
    """
    Benchmark a YOLO model across different formats for speed and accuracy.

    Args:
        model (str | Path): Path to the model file or directory.
        data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.
        imgsz (int): Image size for the benchmark.
        half (bool): Use half-precision for the model if True.
        int8 (bool): Use int8-precision for the model if True.
        device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.
        verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
        eps (float): Epsilon value for divide by zero prevention.
        format (str): Export format for benchmarking. If not supplied all formats are benchmarked.
        **kwargs (Any): Additional keyword arguments for exporter.

    Returns:
        (polars.DataFrame): A polars DataFrame with benchmark results for each format, including file size, metric,
            and inference time.

    Examples:
        Benchmark a YOLO model with default settings:
        >>> from ultralytics.utils.benchmarks import benchmark
        >>> benchmark(model="yolo11n.pt", imgsz=640)
    """
    imgsz = check_imgsz(imgsz)
    assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."

    import polars as pl  # scope for faster 'import ultralytics'

    pl.Config.set_tbl_cols(-1)  # Show all columns
    pl.Config.set_tbl_rows(-1)  # Show all rows
    pl.Config.set_tbl_width_chars(-1)  # No width limit
    pl.Config.set_tbl_hide_column_data_types(True)  # Hide data types
    pl.Config.set_tbl_hide_dataframe_shape(True)  # Hide shape info
    pl.Config.set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED")

    device = select_device(device, verbose=False)
    if isinstance(model, (str, Path)):
        model = YOLO(model)
    is_end2end = getattr(model.model.model[-1], "end2end", False)
    data = data or TASK2DATA[model.task]  # task to dataset, i.e. coco8.yaml for task=detect
    key = TASK2METRIC[model.task]  # task to metric, i.e. metrics/mAP50-95(B) for task=detect

    y = []
    t0 = time.time()

    format_arg = format.lower()
    if format_arg:
        formats = frozenset(export_formats()["Argument"])
        assert format in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
    for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
        emoji, filename = "❌", None  # export defaults
        try:
            if format_arg and format_arg != format:
                continue

            # Checks
            if format == "pb":
                assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
            elif format == "edgetpu":
                assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
            elif format in {"coreml", "tfjs"}:
                assert MACOS or (LINUX and not ARM64), (
                    "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
                )
            if format == "coreml":
                assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
            if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
                # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
            if format == "paddle":
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
                assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
                assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
                assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
            if format == "mnn":
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
            if format == "ncnn":
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
            if format == "imx":
                assert not is_end2end
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
                assert model.task == "detect", "IMX only supported for detection task"
                assert "C2f" in model.__str__(), "IMX only supported for YOLOv8n and YOLO11n"
            if format == "rknn":
                assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
                assert not is_end2end, "End-to-end models not supported by RKNN yet"
                assert LINUX, "RKNN only supported on Linux"
                assert not is_rockchip(), "RKNN Inference only supported on Rockchip devices"
            if "cpu" in device.type:
                assert cpu, "inference not supported on CPU"
            if "cuda" in device.type:
                assert gpu, "inference not supported on GPU"

            # Export
            if format == "-":
                filename = model.pt_path or model.ckpt_path or model.model_name
                exported_model = model  # PyTorch format
            else:
                filename = model.export(
                    imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs
                )
                exported_model = YOLO(filename, task=model.task)
                assert suffix in str(filename), "export failed"
            emoji = "❎"  # indicates export succeeded

            # Predict
            assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
            assert format not in {"edgetpu", "tfjs"}, "inference not supported"
            assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
            if format == "ncnn":
                assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
            exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)

            # Validate
            results = exported_model.val(
                data=data,
                batch=1,
                imgsz=imgsz,
                plots=False,
                device=device,
                half=half,
                int8=int8,
                verbose=False,
                conf=0.001,  # all the pre-set benchmark mAP values are based on conf=0.001
            )
            metric, speed = results.results_dict[key], results.speed["inference"]
            fps = round(1000 / (speed + eps), 2)  # frames per second
            y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])
        except Exception as e:
            if verbose:
                assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
            LOGGER.error(f"Benchmark failure for {name}: {e}")
            y.append([name, emoji, round(file_size(filename), 1), None, None, None])  # mAP, t_inference

    # Print results
    check_yolo(device=device)  # print system info
    df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"], orient="row")
    df = df.with_row_index(" ", offset=1)  # add index info
    df_display = df.with_columns(pl.all().cast(pl.String).fill_null("-"))

    name = model.model_name
    dt = time.time() - t0
    legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
    s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df_display}\n"
    LOGGER.info(s)
    with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
        f.write(s)

    if verbose and isinstance(verbose, float):
        metrics = df[key].to_numpy()  # values to compare to floor
        floor = verbose  # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
        assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"

    return df_display


class RF100Benchmark:
    """
    Benchmark YOLO model performance on the RF100 dataset collection.

    This class provides functionality to benchmark YOLO models on the RF100 dataset collection.

    Attributes:
        ds_names (list[str]): Names of datasets used for benchmarking.
        ds_cfg_list (list[Path]): List of paths to dataset configuration files.
        rf (Roboflow): Roboflow instance for accessing datasets.
        val_metrics (list[str]): Metrics used for validation.

    Methods:
        set_key: Set Roboflow API key for accessing datasets.
        parse_dataset: Parse dataset links and download datasets.
        fix_yaml: Fix train and validation paths in YAML files.
        evaluate: Evaluate model performance on validation results.
    """

    def __init__(self):
        """Initialize the RF100Benchmark class for benchmarking YOLO model performance on RF100 datasets."""
        self.ds_names = []
        self.ds_cfg_list = []
        self.rf = None
        self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]

    def set_key(self, api_key: str):
        """
        Set Roboflow API key for processing.

        Args:
            api_key (str): The API key.

        Examples:
            Set the Roboflow API key for accessing datasets:
            >>> benchmark = RF100Benchmark()
            >>> benchmark.set_key("your_roboflow_api_key")
        """
        check_requirements("roboflow")
        from roboflow import Roboflow

        self.rf = Roboflow(api_key=api_key)

    def parse_dataset(self, ds_link_txt: str = "datasets_links.txt"):
        """
        Parse dataset links and download datasets.

        Args:
            ds_link_txt (str): Path to the file containing dataset links.

        Returns:
            ds_names (list[str]): List of dataset names.
            ds_cfg_list (list[Path]): List of paths to dataset configuration files.

        Examples:
            >>> benchmark = RF100Benchmark()
            >>> benchmark.set_key("api_key")
            >>> benchmark.parse_dataset("datasets_links.txt")
        """
        (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
        os.chdir("rf-100")
        os.mkdir("ultralytics-benchmarks")
        safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")

        with open(ds_link_txt, encoding="utf-8") as file:
            for line in file:
                try:
                    _, url, workspace, project, version = re.split("/+", line.strip())
                    self.ds_names.append(project)
                    proj_version = f"{project}-{version}"
                    if not Path(proj_version).exists():
                        self.rf.workspace(workspace).project(project).version(version).download("yolov8")
                    else:
                        LOGGER.info("Dataset already downloaded.")
                    self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
                except Exception:
                    continue

        return self.ds_names, self.ds_cfg_list

    @staticmethod
    def fix_yaml(path: Path):
        """Fix the train and validation paths in a given YAML file."""
        yaml_data = YAML.load(path)
        yaml_data["train"] = "train/images"
        yaml_data["val"] = "valid/images"
        YAML.dump(yaml_data, path)

    def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int):
        """
        Evaluate model performance on validation results.

        Args:
            yaml_path (str): Path to the YAML configuration file.
            val_log_file (str): Path to the validation log file.
            eval_log_file (str): Path to the evaluation log file.
            list_ind (int): Index of the current dataset in the list.

        Returns:
            (float): The mean average precision (mAP) value for the evaluated model.

        Examples:
            Evaluate a model on a specific dataset
            >>> benchmark = RF100Benchmark()
            >>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0)
        """
        skip_symbols = ["🚀", "⚠️", "💡", "❌"]
        class_names = YAML.load(yaml_path)["names"]
        with open(val_log_file, encoding="utf-8") as f:
            lines = f.readlines()
            eval_lines = []
            for line in lines:
                if any(symbol in line for symbol in skip_symbols):
                    continue
                entries = line.split(" ")
                entries = list(filter(lambda val: val != "", entries))
                entries = [e.strip("\n") for e in entries]
                eval_lines.extend(
                    {
                        "class": entries[0],
                        "images": entries[1],
                        "targets": entries[2],
                        "precision": entries[3],
                        "recall": entries[4],
                        "map50": entries[5],
                        "map95": entries[6],
                    }
                    for e in entries
                    if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
                )
        map_val = 0.0
        if len(eval_lines) > 1:
            LOGGER.info("Multiple dicts found")
            for lst in eval_lines:
                if lst["class"] == "all":
                    map_val = lst["map50"]
        else:
            LOGGER.info("Single dict found")
            map_val = [res["map50"] for res in eval_lines][0]

        with open(eval_log_file, "a", encoding="utf-8") as f:
            f.write(f"{self.ds_names[list_ind]}: {map_val}\n")

        return float(map_val)


class ProfileModels:
    """
    ProfileModels class for profiling different models on ONNX and TensorRT.

    This class profiles the performance of different models, returning results such as model speed and FLOPs.

    Attributes:
        paths (list[str]): Paths of the models to profile.
        num_timed_runs (int): Number of timed runs for the profiling.
        num_warmup_runs (int): Number of warmup runs before profiling.
        min_time (float): Minimum number of seconds to profile for.
        imgsz (int): Image size used in the models.
        half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
        trt (bool): Flag to indicate whether to profile using TensorRT.
        device (torch.device): Device used for profiling.

    Methods:
        run: Profile YOLO models for speed and accuracy across various formats.
        get_files: Get all relevant model files.
        get_onnx_model_info: Extract metadata from an ONNX model.
        iterative_sigma_clipping: Apply sigma clipping to remove outliers.
        profile_tensorrt_model: Profile a TensorRT model.
        profile_onnx_model: Profile an ONNX model.
        generate_table_row: Generate a table row with model metrics.
        generate_results_dict: Generate a dictionary of profiling results.
        print_table: Print a formatted table of results.

    Examples:
        Profile models and print results
        >>> from ultralytics.utils.benchmarks import ProfileModels
        >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
        >>> profiler.run()
    """

    def __init__(
        self,
        paths: list[str],
        num_timed_runs: int = 100,
        num_warmup_runs: int = 10,
        min_time: float = 60,
        imgsz: int = 640,
        half: bool = True,
        trt: bool = True,
        device: torch.device | str | None = None,
    ):
        """
        Initialize the ProfileModels class for profiling models.

        Args:
            paths (list[str]): List of paths of the models to be profiled.
            num_timed_runs (int): Number of timed runs for the profiling.
            num_warmup_runs (int): Number of warmup runs before the actual profiling starts.
            min_time (float): Minimum time in seconds for profiling a model.
            imgsz (int): Size of the image used during profiling.
            half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
            trt (bool): Flag to indicate whether to profile using TensorRT.
            device (torch.device | str | None): Device used for profiling. If None, it is determined automatically.

        Notes:
            FP16 'half' argument option removed for ONNX as slower on CPU than FP32.

        Examples:
            Initialize and profile models
            >>> from ultralytics.utils.benchmarks import ProfileModels
            >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
            >>> profiler.run()
        """
        self.paths = paths
        self.num_timed_runs = num_timed_runs
        self.num_warmup_runs = num_warmup_runs
        self.min_time = min_time
        self.imgsz = imgsz
        self.half = half
        self.trt = trt  # run TensorRT profiling
        self.device = device if isinstance(device, torch.device) else select_device(device)

    def run(self):
        """
        Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.

        Returns:
            (list[dict]): List of dictionaries containing profiling results for each model.

        Examples:
            Profile models and print results
            >>> from ultralytics.utils.benchmarks import ProfileModels
            >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"])
            >>> results = profiler.run()
        """
        files = self.get_files()

        if not files:
            LOGGER.warning("No matching *.pt or *.onnx files found.")
            return []

        table_rows = []
        output = []
        for file in files:
            engine_file = file.with_suffix(".engine")
            if file.suffix in {".pt", ".yaml", ".yml"}:
                model = YOLO(str(file))
                model.fuse()  # to report correct params and GFLOPs in model.info()
                model_info = model.info()
                if self.trt and self.device.type != "cpu" and not engine_file.is_file():
                    engine_file = model.export(
                        format="engine",
                        half=self.half,
                        imgsz=self.imgsz,
                        device=self.device,
                        verbose=False,
                    )
                onnx_file = model.export(
                    format="onnx",
                    imgsz=self.imgsz,
                    device=self.device,
                    verbose=False,
                )
            elif file.suffix == ".onnx":
                model_info = self.get_onnx_model_info(file)
                onnx_file = file
            else:
                continue

            t_engine = self.profile_tensorrt_model(str(engine_file))
            t_onnx = self.profile_onnx_model(str(onnx_file))
            table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
            output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))

        self.print_table(table_rows)
        return output

    def get_files(self):
        """
        Return a list of paths for all relevant model files given by the user.

        Returns:
            (list[Path]): List of Path objects for the model files.
        """
        files = []
        for path in self.paths:
            path = Path(path)
            if path.is_dir():
                extensions = ["*.pt", "*.onnx", "*.yaml"]
                files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
            elif path.suffix in {".pt", ".yaml", ".yml"}:  # add non-existing
                files.append(str(path))
            else:
                files.extend(glob.glob(str(path)))

        LOGGER.info(f"Profiling: {sorted(files)}")
        return [Path(file) for file in sorted(files)]

    @staticmethod
    def get_onnx_model_info(onnx_file: str):
        """Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
        return 0.0, 0.0, 0.0, 0.0  # return (num_layers, num_params, num_gradients, num_flops)

    @staticmethod
    def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3):
        """
        Apply iterative sigma clipping to data to remove outliers.

        Args:
            data (np.ndarray): Input data array.
            sigma (float): Number of standard deviations to use for clipping.
            max_iters (int): Maximum number of iterations for the clipping process.

        Returns:
            (np.ndarray): Clipped data array with outliers removed.
        """
        data = np.array(data)
        for _ in range(max_iters):
            mean, std = np.mean(data), np.std(data)
            clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
            if len(clipped_data) == len(data):
                break
            data = clipped_data
        return data

    def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
        """
        Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.

        Args:
            engine_file (str): Path to the TensorRT engine file.
            eps (float): Small epsilon value to prevent division by zero.

        Returns:
            mean_time (float): Mean inference time in milliseconds.
            std_time (float): Standard deviation of inference time in milliseconds.
        """
        if not self.trt or not Path(engine_file).is_file():
            return 0.0, 0.0

        # Model and input
        model = YOLO(engine_file)
        input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8)  # use uint8 for Classify

        # Warmup runs
        elapsed = 0.0
        for _ in range(3):
            start_time = time.time()
            for _ in range(self.num_warmup_runs):
                model(input_data, imgsz=self.imgsz, verbose=False)
            elapsed = time.time() - start_time

        # Compute number of runs as higher of min_time or num_timed_runs
        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)

        # Timed runs
        run_times = []
        for _ in TQDM(range(num_runs), desc=engine_file):
            results = model(input_data, imgsz=self.imgsz, verbose=False)
            run_times.append(results[0].speed["inference"])  # Convert to milliseconds

        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3)  # sigma clipping
        return np.mean(run_times), np.std(run_times)

    def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
        """
        Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.

        Args:
            onnx_file (str): Path to the ONNX model file.
            eps (float): Small epsilon value to prevent division by zero.

        Returns:
            mean_time (float): Mean inference time in milliseconds.
            std_time (float): Standard deviation of inference time in milliseconds.
        """
        check_requirements("onnxruntime")
        import onnxruntime as ort

        # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.intra_op_num_threads = 8  # Limit the number of threads
        sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])

        input_tensor = sess.get_inputs()[0]
        input_type = input_tensor.type
        dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape)  # dynamic input shape
        input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape

        # Mapping ONNX datatype to numpy datatype
        if "float16" in input_type:
            input_dtype = np.float16
        elif "float" in input_type:
            input_dtype = np.float32
        elif "double" in input_type:
            input_dtype = np.float64
        elif "int64" in input_type:
            input_dtype = np.int64
        elif "int32" in input_type:
            input_dtype = np.int32
        else:
            raise ValueError(f"Unsupported ONNX datatype {input_type}")

        input_data = np.random.rand(*input_shape).astype(input_dtype)
        input_name = input_tensor.name
        output_name = sess.get_outputs()[0].name

        # Warmup runs
        elapsed = 0.0
        for _ in range(3):
            start_time = time.time()
            for _ in range(self.num_warmup_runs):
                sess.run([output_name], {input_name: input_data})
            elapsed = time.time() - start_time

        # Compute number of runs as higher of min_time or num_timed_runs
        num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)

        # Timed runs
        run_times = []
        for _ in TQDM(range(num_runs), desc=onnx_file):
            start_time = time.time()
            sess.run([output_name], {input_name: input_data})
            run_times.append((time.time() - start_time) * 1000)  # Convert to milliseconds

        run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5)  # sigma clipping
        return np.mean(run_times), np.std(run_times)

    def generate_table_row(
        self,
        model_name: str,
        t_onnx: tuple[float, float],
        t_engine: tuple[float, float],
        model_info: tuple[float, float, float, float],
    ):
        """
        Generate a table row string with model performance metrics.

        Args:
            model_name (str): Name of the model.
            t_onnx (tuple): ONNX model inference time statistics (mean, std).
            t_engine (tuple): TensorRT engine inference time statistics (mean, std).
            model_info (tuple): Model information (layers, params, gradients, flops).

        Returns:
            (str): Formatted table row string with model metrics.
        """
        layers, params, gradients, flops = model_info
        return (
            f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
            f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
        )

    @staticmethod
    def generate_results_dict(
        model_name: str,
        t_onnx: tuple[float, float],
        t_engine: tuple[float, float],
        model_info: tuple[float, float, float, float],
    ):
        """
        Generate a dictionary of profiling results.

        Args:
            model_name (str): Name of the model.
            t_onnx (tuple): ONNX model inference time statistics (mean, std).
            t_engine (tuple): TensorRT engine inference time statistics (mean, std).
            model_info (tuple): Model information (layers, params, gradients, flops).

        Returns:
            (dict): Dictionary containing profiling results.
        """
        layers, params, gradients, flops = model_info
        return {
            "model/name": model_name,
            "model/parameters": params,
            "model/GFLOPs": round(flops, 3),
            "model/speed_ONNX(ms)": round(t_onnx[0], 3),
            "model/speed_TensorRT(ms)": round(t_engine[0], 3),
        }

    @staticmethod
    def print_table(table_rows: list[str]):
        """
        Print a formatted table of model profiling results.

        Args:
            table_rows (list[str]): List of formatted table row strings.
        """
        gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
        headers = [
            "Model",
            "size<br><sup>(pixels)",
            "mAP<sup>val<br>50-95",
            f"Speed<br><sup>CPU ({get_cpu_info()}) ONNX<br>(ms)",
            f"Speed<br><sup>{gpu} TensorRT<br>(ms)",
            "params<br><sup>(M)",
            "FLOPs<br><sup>(B)",
        ]
        header = "|" + "|".join(f" {h} " for h in headers) + "|"
        separator = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|"

        LOGGER.info(f"\n\n{header}")
        LOGGER.info(separator)
        for row in table_rows:
            LOGGER.info(row)
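
Note: benchmark() walks every export format by default; passing the format argument restricts it to one, which is the quickest way to sanity-check a single export path. A hedged usage sketch based on the signature above:

    from ultralytics.utils.benchmarks import benchmark

    # Benchmark only the ONNX export of yolo11n on CPU at a small image size
    df = benchmark(model="yolo11n.pt", imgsz=160, format="onnx", device="cpu")
    print(df)
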
5  ultralytics/utils/callbacks/__init__.py  Normal file
@@ -0,0 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .base import add_integration_callbacks, default_callbacks, get_default_callbacks

__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"
BIN  ultralytics/utils/callbacks/__pycache__/__init__.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/callbacks/__pycache__/base.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/callbacks/__pycache__/hub.cpython-310.pyc  Normal file (binary file not shown)
BIN  ultralytics/utils/callbacks/__pycache__/platform.cpython-310.pyc  Normal file (binary file not shown)
235  ultralytics/utils/callbacks/base.py  Normal file
@@ -0,0 +1,235 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Base callbacks for Ultralytics training, validation, prediction, and export processes."""

from collections import defaultdict
from copy import deepcopy

# Trainer callbacks ----------------------------------------------------------------------------------------------------


def on_pretrain_routine_start(trainer):
    """Called before the pretraining routine starts."""
    pass


def on_pretrain_routine_end(trainer):
    """Called after the pretraining routine ends."""
    pass


def on_train_start(trainer):
    """Called when the training starts."""
    pass


def on_train_epoch_start(trainer):
    """Called at the start of each training epoch."""
    pass


def on_train_batch_start(trainer):
    """Called at the start of each training batch."""
    pass


def optimizer_step(trainer):
    """Called when the optimizer takes a step."""
    pass


def on_before_zero_grad(trainer):
    """Called before the gradients are set to zero."""
    pass


def on_train_batch_end(trainer):
    """Called at the end of each training batch."""
    pass


def on_train_epoch_end(trainer):
    """Called at the end of each training epoch."""
    pass


def on_fit_epoch_end(trainer):
    """Called at the end of each fit epoch (train + val)."""
    pass


def on_model_save(trainer):
    """Called when the model is saved."""
    pass


def on_train_end(trainer):
    """Called when the training ends."""
    pass


def on_params_update(trainer):
    """Called when the model parameters are updated."""
    pass


def teardown(trainer):
    """Called during the teardown of the training process."""
    pass


# Validator callbacks --------------------------------------------------------------------------------------------------


def on_val_start(validator):
    """Called when the validation starts."""
    pass


def on_val_batch_start(validator):
    """Called at the start of each validation batch."""
    pass


def on_val_batch_end(validator):
    """Called at the end of each validation batch."""
    pass


def on_val_end(validator):
    """Called when the validation ends."""
    pass


# Predictor callbacks --------------------------------------------------------------------------------------------------


def on_predict_start(predictor):
    """Called when the prediction starts."""
    pass


def on_predict_batch_start(predictor):
    """Called at the start of each prediction batch."""
    pass


def on_predict_batch_end(predictor):
    """Called at the end of each prediction batch."""
    pass


def on_predict_postprocess_end(predictor):
    """Called after the post-processing of the prediction ends."""
    pass


def on_predict_end(predictor):
    """Called when the prediction ends."""
    pass


# Exporter callbacks ---------------------------------------------------------------------------------------------------


def on_export_start(exporter):
    """Called when the model export starts."""
    pass


def on_export_end(exporter):
    """Called when the model export ends."""
    pass


default_callbacks = {
    # Run in trainer
    "on_pretrain_routine_start": [on_pretrain_routine_start],
    "on_pretrain_routine_end": [on_pretrain_routine_end],
    "on_train_start": [on_train_start],
    "on_train_epoch_start": [on_train_epoch_start],
    "on_train_batch_start": [on_train_batch_start],
    "optimizer_step": [optimizer_step],
    "on_before_zero_grad": [on_before_zero_grad],
    "on_train_batch_end": [on_train_batch_end],
    "on_train_epoch_end": [on_train_epoch_end],
    "on_fit_epoch_end": [on_fit_epoch_end],  # fit = train + val
    "on_model_save": [on_model_save],
    "on_train_end": [on_train_end],
    "on_params_update": [on_params_update],
    "teardown": [teardown],
    # Run in validator
    "on_val_start": [on_val_start],
    "on_val_batch_start": [on_val_batch_start],
    "on_val_batch_end": [on_val_batch_end],
    "on_val_end": [on_val_end],
    # Run in predictor
    "on_predict_start": [on_predict_start],
    "on_predict_batch_start": [on_predict_batch_start],
    "on_predict_postprocess_end": [on_predict_postprocess_end],
    "on_predict_batch_end": [on_predict_batch_end],
    "on_predict_end": [on_predict_end],
    # Run in exporter
    "on_export_start": [on_export_start],
    "on_export_end": [on_export_end],
}


def get_default_callbacks():
    """
    Get the default callbacks for Ultralytics training, validation, prediction, and export processes.

    Returns:
        (dict): Dictionary of default callbacks for various training events. Each key represents an event during the
            training process, and the corresponding value is a list of callback functions executed when that event
            occurs.

    Examples:
        >>> callbacks = get_default_callbacks()
        >>> print(list(callbacks.keys()))  # show all available callback events
        ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]
    """
    return defaultdict(list, deepcopy(default_callbacks))


def add_integration_callbacks(instance):
    """
    Add integration callbacks to the instance's callbacks dictionary.

    This function loads and adds various integration callbacks to the provided instance. The specific callbacks added
    depend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive
    additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,
    and Weights & Biases.

    Args:
        instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.
            The type of instance determines which callbacks are loaded.

    Examples:
        >>> from ultralytics.engine.trainer import BaseTrainer
        >>> trainer = BaseTrainer()
        >>> add_integration_callbacks(trainer)
    """
    from .hub import callbacks as hub_cb
    from .platform import callbacks as platform_cb

    # Load Ultralytics callbacks
    callbacks_list = [hub_cb, platform_cb]

    # Load training callbacks
    if "Trainer" in instance.__class__.__name__:
        from .clearml import callbacks as clear_cb
        from .comet import callbacks as comet_cb
        from .dvc import callbacks as dvc_cb
        from .mlflow import callbacks as mlflow_cb
        from .neptune import callbacks as neptune_cb
        from .raytune import callbacks as tune_cb
        from .tensorboard import callbacks as tb_cb
        from .wb import callbacks as wb_cb

        callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])

    # Add the callbacks to the callbacks dictionary
    for callbacks in callbacks_list:
        for k, v in callbacks.items():
            if v not in instance.callbacks[k]:
                instance.callbacks[k].append(v)
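
Note: default_callbacks maps each event name to a list, so extra hooks append rather than replace the defaults. A minimal sketch of attaching a custom hook through the model API (assuming the public Model.add_callback method; the print body is purely illustrative):

    from ultralytics import YOLO

    def log_epoch(trainer):
        # illustrative hook: report the epoch that just finished
        print(f"epoch {trainer.epoch} finished")

    model = YOLO("yolo11n.pt")
    model.add_callback("on_train_epoch_end", log_epoch)
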
154
ultralytics/utils/callbacks/clearml.py
Normal file
@@ -0,0 +1,154 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["clearml"] is True  # verify integration is enabled
    import clearml
    from clearml import Task

    assert hasattr(clearml, "__version__")  # verify package is not directory

except (ImportError, AssertionError):
    clearml = None


def _log_debug_samples(files, title: str = "Debug Samples") -> None:
    """
    Log files (images) as debug samples in the ClearML task.

    Args:
        files (list[Path]): A list of file paths in PosixPath format.
        title (str): A title that groups together images with the same values.
    """
    import re

    if task := Task.current_task():
        for f in files:
            if f.exists():
                it = re.search(r"_batch(\d+)", f.name)
                iteration = int(it.groups()[0]) if it else 0
                task.get_logger().report_image(
                    title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
                )
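As a quick standalone illustration of the filename convention this helper relies on (a sketch, not part of the file): the `_batch(\d+)` pattern pulls the iteration index out of an Ultralytics debug-image name, and the matched text is stripped from the series name.

import re

name = "train_batch2.jpg"  # example Ultralytics debug-image filename
it = re.search(r"_batch(\d+)", name)
iteration = int(it.groups()[0]) if it else 0  # -> 2
series = name.replace(it.group(), "")  # -> "train.jpg"
print(iteration, series)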


def _log_plot(title: str, plot_path: str) -> None:
    """
    Log an image as a plot in the plot section of ClearML.

    Args:
        title (str): The title of the plot.
        plot_path (str): The path to the saved image file.
    """
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
    ax.imshow(img)

    Task.current_task().get_logger().report_matplotlib_figure(
        title=title, series="", figure=fig, report_interactive=False
    )


def on_pretrain_routine_start(trainer) -> None:
    """Initialize and connect the ClearML task at the start of the pretraining routine."""
    try:
        if task := Task.current_task():
            # WARNING: make sure the automatic PyTorch and Matplotlib bindings are disabled!
            # We are logging these plots and model files manually in the integration
            from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
            from clearml.binding.matplotlib_bind import PatchedMatplotlib

            PatchPyTorchModelIO.update_current_task(None)
            PatchedMatplotlib.update_current_task(None)
        else:
            task = Task.init(
                project_name=trainer.args.project or "Ultralytics",
                task_name=trainer.args.name,
                tags=["Ultralytics"],
                output_uri=True,
                reuse_last_task_id=False,
                auto_connect_frameworks={"pytorch": False, "matplotlib": False},
            )
            LOGGER.warning(
                "ClearML initialized a new task. If you want to run remotely, "
                "please add clearml-init and connect your arguments before initializing YOLO."
            )
        task.connect(vars(trainer.args), name="General")
    except Exception as e:
        LOGGER.warning(f"ClearML installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer) -> None:
    """Log debug samples for the first epoch and report current training progress."""
    if task := Task.current_task():
        # Log debug samples for first epoch only
        if trainer.epoch == 1:
            _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
        # Report the current training progress
        for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
            task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
        for k, v in trainer.lr.items():
            task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)


def on_fit_epoch_end(trainer) -> None:
    """Report model information and metrics to the logger at the end of an epoch."""
    if task := Task.current_task():
        # Report epoch time and validation metrics
        task.get_logger().report_scalar(
            title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
        )
        for k, v in trainer.metrics.items():
            title = k.split("/")[0]
            task.get_logger().report_scalar(title, k, v, iteration=trainer.epoch)
        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for k, v in model_info_for_loggers(trainer).items():
                task.get_logger().report_single_value(k, v)


def on_val_end(validator) -> None:
    """Log validation results including labels and predictions."""
    if Task.current_task():
        # Log validation labels and predictions
        _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")


def on_train_end(trainer) -> None:
    """Log the final model and training results on training completion."""
    if task := Task.current_task():
        # Log final results, confusion matrix and PR plots
        files = [
            "results.png",
            "confusion_matrix.png",
            "confusion_matrix_normalized.png",
            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
        ]
        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter existing files
        for f in files:
            _log_plot(title=f.stem, plot_path=f)
        # Report final metrics
        for k, v in trainer.validator.metrics.results_dict.items():
            task.get_logger().report_single_value(k, v)
        # Log the final model
        task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if clearml
    else {}
)
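The integration is gated by the `SETTINGS["clearml"]` flag checked in the try block above; a minimal usage sketch (assuming the `ultralytics` and `clearml` packages are installed and a ClearML server is configured — the model name here is illustrative):

from ultralytics import YOLO
from ultralytics.utils import SETTINGS

SETTINGS.update({"clearml": True})  # enable the integration in Ultralytics settings
model = YOLO("yolo11n.pt")  # model name assumed; any Ultralytics detection model works
model.train(data="coco8.yaml", epochs=3, project="MyProject", name="clearml-demo")
# The callbacks above then create/reuse a ClearML Task and report scalars, debug samples, and the best model.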
639
ultralytics/utils/callbacks/comet.py
Normal file
@@ -0,0 +1,639 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from collections.abc import Callable
from types import SimpleNamespace
from typing import Any

import cv2
import numpy as np

from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["comet"] is True  # verify integration is enabled
    import comet_ml

    assert hasattr(comet_ml, "__version__")  # verify package is not directory

    import os
    from pathlib import Path

    # Ensures certain logging functions only run for supported tasks
    COMET_SUPPORTED_TASKS = ["detect", "segment"]

    # Names of plots created by Ultralytics that are logged to Comet
    CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
    EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
    LABEL_PLOT_NAMES = ["labels"]
    SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
    POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
    DETECTION_METRICS_PLOT_PREFIX = ["Box"]
    RESULTS_TABLE_NAME = "results.csv"
    ARGS_YAML_NAME = "args.yaml"

    _comet_image_prediction_count = 0

except (ImportError, AssertionError):
    comet_ml = None


def _get_comet_mode() -> str:
    """Return the Comet mode from environment variables, defaulting to 'online'."""
    comet_mode = os.getenv("COMET_MODE")
    if comet_mode is not None:
        LOGGER.warning(
            "The COMET_MODE environment variable is deprecated. "
            "Please use COMET_START_ONLINE to set the Comet experiment mode. "
            "To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. "
            "If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created."
        )
        return comet_mode

    return "online"


def _get_comet_model_name() -> str:
    """Return the Comet model name from the environment variable, or default to 'Ultralytics'."""
    return os.getenv("COMET_MODEL_NAME", "Ultralytics")


def _get_eval_batch_logging_interval() -> int:
    """Get the evaluation batch logging interval from the environment variable, or use the default value 1."""
    return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))


def _get_max_image_predictions_to_log() -> int:
    """Get the maximum number of image predictions to log from environment variables."""
    return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))


def _scale_confidence_score(score: float) -> float:
    """Scale the confidence score by a factor specified in an environment variable."""
    scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
    return score * scale


def _should_log_confusion_matrix() -> bool:
    """Determine if the confusion matrix should be logged based on environment variable settings."""
    return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"


def _should_log_image_predictions() -> bool:
    """Determine whether to log image predictions based on an environment variable."""
    return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
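Taken together, these getters mean the whole Comet integration can be tuned without touching code; a sketch of configuring it from Python before training (the variable names are exactly the ones read above):

import os

os.environ["COMET_START_ONLINE"] = "1"                  # online experiment (the deprecated COMET_MODE is then ignored)
os.environ["COMET_EVAL_BATCH_LOGGING_INTERVAL"] = "2"   # log every 2nd validation batch
os.environ["COMET_MAX_IMAGE_PREDICTIONS"] = "50"        # cap logged prediction images
os.environ["COMET_EVAL_LOG_CONFUSION_MATRIX"] = "true"  # also log confusion matrices on eval
# With these set: _get_eval_batch_logging_interval() -> 2, _should_log_confusion_matrix() -> True, etc.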


def _resume_or_create_experiment(args: SimpleNamespace) -> None:
    """
    Resume a CometML experiment or create a new experiment based on args.

    Ensures that the experiment object is only created in a single process during distributed training.

    Args:
        args (SimpleNamespace): Training arguments containing project configuration and other parameters.
    """
    if RANK not in {-1, 0}:
        return

    # Set environment variable (if not set by the user) to configure the Comet experiment's online mode under the hood.
    # If COMET_START_ONLINE is set by the user, it overrides the COMET_MODE value.
    if os.getenv("COMET_START_ONLINE") is None:
        comet_mode = _get_comet_mode()
        os.environ["COMET_START_ONLINE"] = "1" if comet_mode != "offline" else "0"

    try:
        _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
        experiment = comet_ml.start(project_name=_project_name)
        experiment.log_parameters(vars(args))
        experiment.log_others(
            {
                "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
                "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
                "log_image_predictions": _should_log_image_predictions(),
                "max_image_predictions": _get_max_image_predictions_to_log(),
            }
        )
        experiment.log_other("Created from", "ultralytics")

    except Exception as e:
        LOGGER.warning(f"Comet installed but not initialized correctly, not logging this run. {e}")


def _fetch_trainer_metadata(trainer) -> dict:
    """
    Return metadata for YOLO training, including the current epoch and asset saving status.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.

    Returns:
        (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.
    """
    curr_epoch = trainer.epoch + 1

    train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
    curr_step = curr_epoch * train_num_steps_per_epoch
    final_epoch = curr_epoch == trainer.epochs

    save = trainer.args.save
    save_period = trainer.args.save_period
    save_interval = curr_epoch % save_period == 0
    save_assets = save and save_period > 0 and save_interval and not final_epoch

    return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
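A tiny worked example of the bookkeeping above (a sketch with made-up numbers, not real trainer state): with 1000 training images, batch size 16, `save_period=5`, `save=True`, and 100 total epochs, at 0-indexed epoch 4:

curr_epoch = 4 + 1                        # trainer.epoch + 1 -> 5
steps_per_epoch = 1000 // 16              # -> 62
curr_step = curr_epoch * steps_per_epoch  # -> 310
final_epoch = curr_epoch == 100           # -> False
save_assets = True and 5 > 0 and curr_epoch % 5 == 0 and not final_epoch  # -> True
print(curr_step, save_assets)  # 310 True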


def _scale_bounding_box_to_original_image_shape(
    box, resized_image_shape, original_image_shape, ratio_pad
) -> list[float]:
    """
    Scale a bounding box from resized image coordinates to original image coordinates.

    YOLO resizes images during training, and the label values are normalized based on this resized shape.
    This function rescales the bounding box labels to the original image shape.

    Args:
        box (torch.Tensor): Bounding box in normalized xywh format.
        resized_image_shape (tuple): Shape of the resized image (height, width).
        original_image_shape (tuple): Shape of the original image (height, width).
        ratio_pad (tuple): Ratio and padding information for scaling.

    Returns:
        (list[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.
    """
    resized_image_height, resized_image_width = resized_image_shape

    # Convert normalized xywh format predictions to xyxy in resized scale format
    box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
    # Scale box predictions from resized image scale back to original image scale
    box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
    # Convert bounding box format from xyxy to xywh for Comet logging
    box = ops.xyxy2xywh(box)
    # Adjust the xy center to correspond to the top-left corner
    box[:2] -= box[2:] / 2
    box = box.tolist()

    return box
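The last two steps are easy to sanity-check by hand; a standalone sketch of the coordinate arithmetic, independent of `ops`:

# A box in xyxy pixel coordinates: corners at (40, 40) and (120, 100)
x1, y1, x2, y2 = 40.0, 40.0, 120.0, 100.0

# xyxy -> xywh (center x, center y, width, height), the conversion ops.xyxy2xywh performs
cx, cy, w, h = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1  # -> 80, 70, 80, 60

# Shift the center to the top-left corner, as Comet's annotation format expects
x, y = cx - w / 2, cy - h / 2  # -> 40, 40 (back to the original corner)
print([x, y, w, h])  # [40.0, 40.0, 80.0, 60.0]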


def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> dict | None:
    """
    Format ground truth annotations for object detection.

    This function processes ground truth annotations from a batch of images for object detection tasks. It extracts
    bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for
    visualization or evaluation.

    Args:
        img_idx (int): Index of the image in the batch to process.
        image_path (str | Path): Path to the image file.
        batch (dict): Batch dictionary containing detection data with keys:
            - 'batch_idx': Tensor of batch indices
            - 'bboxes': Tensor of bounding boxes in normalized xywh format
            - 'cls': Tensor of class labels
            - 'ori_shape': Original image shapes
            - 'resized_shape': Resized image shapes
            - 'ratio_pad': Ratio and padding information
        class_name_map (dict, optional): Mapping from class indices to class names.

    Returns:
        (dict | None): Formatted ground truth annotations with the following structure:
            - 'boxes': List of box coordinates [x, y, width, height]
            - 'label': Label string with format "gt_{class_name}"
            - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)
        Returns None if no bounding boxes are found for the image.
    """
    indices = batch["batch_idx"] == img_idx
    bboxes = batch["bboxes"][indices]
    if len(bboxes) == 0:
        LOGGER.debug(f"Comet Image: {image_path} has no bounding box labels")
        return None

    cls_labels = batch["cls"][indices].squeeze(1).tolist()
    if class_name_map:
        cls_labels = [str(class_name_map[label]) for label in cls_labels]

    original_image_shape = batch["ori_shape"][img_idx]
    resized_image_shape = batch["resized_shape"][img_idx]
    ratio_pad = batch["ratio_pad"][img_idx]

    data = []
    for box, label in zip(bboxes, cls_labels):
        box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
        data.append(
            {
                "boxes": [box],
                "label": f"gt_{label}",
                "score": _scale_confidence_score(1.0),
            }
        )

    return {"name": "ground_truth", "data": data}


def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> dict | None:
    """
    Format YOLO predictions for object detection visualization.

    Args:
        image_path (Path): Path to the image file.
        metadata (dict): Prediction metadata containing bounding boxes and class information.
        class_label_map (dict, optional): Mapping from class indices to class names.
        class_map (dict, optional): Additional class mapping for label conversion.

    Returns:
        (dict | None): Formatted prediction annotations, or None if no predictions exist.
    """
    stem = image_path.stem
    image_id = int(stem) if stem.isnumeric() else stem

    predictions = metadata.get(image_id)
    if not predictions:
        LOGGER.debug(f"Comet Image: {image_path} has no bounding box predictions")
        return None

    # Apply the mapping that was used to map the predicted classes when the JSON was created
    if class_label_map and class_map:
        class_label_map = {class_map[k]: v for k, v in class_label_map.items()}
    try:
        # Import pycocotools-style utilities to decompress annotations for various tasks, e.g. segmentation
        from faster_coco_eval.core.mask import decode  # noqa
    except ImportError:
        decode = None

    data = []
    for prediction in predictions:
        boxes = prediction["bbox"]
        score = _scale_confidence_score(prediction["score"])
        cls_label = prediction["category_id"]
        if class_label_map:
            cls_label = str(class_label_map[cls_label])

        annotation_data = {"boxes": [boxes], "label": cls_label, "score": score}

        if decode is not None:
            # Do segmentation processing only if we are able to decode it
            segments = prediction.get("segmentation", None)
            if segments is not None:
                segments = _extract_segmentation_annotation(segments, decode)
                if segments is not None:
                    annotation_data["points"] = segments

        data.append(annotation_data)

    return {"name": "prediction", "data": data}


def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> list[list[Any]] | None:
    """
    Extract a segmentation annotation from compressed segmentations as a list of polygons.

    Args:
        segmentation_raw (str): Raw segmentation data in compressed format.
        decode (Callable): Function to decode the compressed segmentation data.

    Returns:
        (list[list[Any]] | None): List of polygon points, or None if extraction fails.
    """
    try:
        mask = decode(segmentation_raw)
        contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3]
        return [annotation.ravel().tolist() for annotation in annotations]
    except Exception as e:
        LOGGER.warning(f"Comet Failed to extract segmentation annotation: {e}")
        return None


def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map) -> list | None:
    """
    Join the ground truth and prediction annotations if they exist.

    Args:
        img_idx (int): Index of the image in the batch.
        image_path (Path): Path to the image file.
        batch (dict): Batch data containing ground truth annotations.
        prediction_metadata_map (dict): Map of prediction metadata by image ID.
        class_label_map (dict): Mapping from class indices to class names.
        class_map (dict): Additional class mapping for label conversion.

    Returns:
        (list | None): List of annotation dictionaries, or None if no annotations exist.
    """
    ground_truth_annotations = _format_ground_truth_annotations_for_detection(
        img_idx, image_path, batch, class_label_map
    )
    prediction_annotations = _format_prediction_annotations(
        image_path, prediction_metadata_map, class_label_map, class_map
    )

    annotations = [
        annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
    ]
    return [annotations] if annotations else None


def _create_prediction_metadata_map(model_predictions) -> dict:
    """Create a metadata map for model predictions by grouping them based on image ID."""
    pred_metadata_map = {}
    for prediction in model_predictions:
        pred_metadata_map.setdefault(prediction["image_id"], [])
        pred_metadata_map[prediction["image_id"]].append(prediction)

    return pred_metadata_map
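A quick sketch of what the grouping produces, with made-up COCO-style prediction dicts:

preds = [
    {"image_id": 7, "bbox": [10, 10, 50, 30], "score": 0.9, "category_id": 0},
    {"image_id": 7, "bbox": [60, 20, 40, 40], "score": 0.8, "category_id": 2},
    {"image_id": 9, "bbox": [5, 5, 20, 20], "score": 0.7, "category_id": 1},
]
grouped = {}
for p in preds:
    grouped.setdefault(p["image_id"], []).append(p)
print(sorted(grouped))  # [7, 9]
print(len(grouped[7]))  # 2 predictions for image 7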


def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:
    """Log the confusion matrix to the Comet experiment."""
    conf_mat = trainer.validator.confusion_matrix.matrix
    names = list(trainer.data["names"].values()) + ["background"]
    experiment.log_confusion_matrix(
        matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
    )


def _log_images(experiment, image_paths, curr_step: int | None, annotations=None) -> None:
    """
    Log images to the experiment with optional annotations.

    This function logs images to a Comet ML experiment, optionally including annotation data for visualization
    such as bounding boxes or segmentation masks.

    Args:
        experiment (comet_ml.CometExperiment): The Comet ML experiment to log images to.
        image_paths (list[Path]): List of paths to images that will be logged.
        curr_step (int): Current training step/iteration for tracking in the experiment timeline.
        annotations (list[list[dict]], optional): Nested list of annotation dictionaries for each image. Each
            annotation contains visualization data like bounding boxes, labels, and confidence scores.
    """
    if annotations:
        for image_path, annotation in zip(image_paths, annotations):
            experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)

    else:
        for image_path in image_paths:
            experiment.log_image(image_path, name=image_path.stem, step=curr_step)


def _log_image_predictions(experiment, validator, curr_step) -> None:
    """
    Log prediction annotations for validation images during training.

    This function logs image predictions to a Comet ML experiment during model validation. It processes
    validation data and formats both ground truth and prediction annotations for visualization in the Comet
    dashboard. The function respects configured limits on the number of images to log.

    Args:
        experiment (comet_ml.CometExperiment): The Comet ML experiment to log to.
        validator (BaseValidator): The validator instance containing validation data and predictions.
        curr_step (int): The current training step for the logging timeline.

    Notes:
        This function uses global state to track the number of logged predictions across calls.
        It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.
        The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.
    """
    global _comet_image_prediction_count

    task = validator.args.task
    if task not in COMET_SUPPORTED_TASKS:
        return

    jdict = validator.jdict
    if not jdict:
        return

    predictions_metadata_map = _create_prediction_metadata_map(jdict)
    dataloader = validator.dataloader
    class_label_map = validator.names
    class_map = getattr(validator, "class_map", None)

    batch_logging_interval = _get_eval_batch_logging_interval()
    max_image_predictions = _get_max_image_predictions_to_log()

    for batch_idx, batch in enumerate(dataloader):
        if (batch_idx + 1) % batch_logging_interval != 0:
            continue

        image_paths = batch["im_file"]
        for img_idx, image_path in enumerate(image_paths):
            if _comet_image_prediction_count >= max_image_predictions:
                return

            image_path = Path(image_path)
            annotations = _fetch_annotations(
                img_idx,
                image_path,
                batch,
                predictions_metadata_map,
                class_label_map,
                class_map=class_map,
            )
            _log_images(
                experiment,
                [image_path],
                curr_step,
                annotations=annotations,
            )
            _comet_image_prediction_count += 1


def _log_plots(experiment, trainer) -> None:
    """
    Log evaluation plots and label plots for the experiment.

    This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles
    different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots
    for each type.

    Args:
        experiment (comet_ml.CometExperiment): The Comet ML experiment to log plots to.
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save
            directory information.

    Examples:
        >>> from ultralytics.utils.callbacks.comet import _log_plots
        >>> _log_plots(experiment, trainer)
    """
    plot_filenames = None
    if isinstance(trainer.validator.metrics, SegmentMetrics):
        plot_filenames = [
            trainer.save_dir / f"{prefix}{plots}.png"
            for plots in EVALUATION_PLOT_NAMES
            for prefix in SEGMENT_METRICS_PLOT_PREFIX
        ]
    elif isinstance(trainer.validator.metrics, PoseMetrics):
        plot_filenames = [
            trainer.save_dir / f"{prefix}{plots}.png"
            for plots in EVALUATION_PLOT_NAMES
            for prefix in POSE_METRICS_PLOT_PREFIX
        ]
    elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
        plot_filenames = [
            trainer.save_dir / f"{prefix}{plots}.png"
            for plots in EVALUATION_PLOT_NAMES
            for prefix in DETECTION_METRICS_PLOT_PREFIX
        ]

    if plot_filenames is not None:
        _log_images(experiment, plot_filenames, None)

    confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
    _log_images(experiment, confusion_matrix_filenames, None)

    if not isinstance(trainer.validator.metrics, ClassifyMetrics):
        label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
        _log_images(experiment, label_plot_filenames, None)


def _log_model(experiment, trainer) -> None:
    """Log the best-trained model to Comet.ml."""
    model_name = _get_comet_model_name()
    experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)


def _log_image_batches(experiment, trainer, curr_step: int) -> None:
    """Log samples of image batches for train, validation, and test."""
    _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
    _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)


def _log_asset(experiment, asset_path) -> None:
    """
    Log a specific asset file to the given experiment.

    This function facilitates logging an asset, such as a file, to the provided
    experiment. It enables integration with experiment tracking platforms.

    Args:
        experiment (comet_ml.CometExperiment): The experiment instance to which the asset will be logged.
        asset_path (Path): The file path of the asset to log.
    """
    experiment.log_asset(asset_path)


def _log_table(experiment, table_path) -> None:
    """
    Log a table to the provided experiment.

    This function is used to log a table file to the given experiment. The table
    is identified by its file path.

    Args:
        experiment (comet_ml.CometExperiment): The experiment object where the table file will be logged.
        table_path (Path): The file path of the table to be logged.
    """
    experiment.log_table(str(table_path))


def on_pretrain_routine_start(trainer) -> None:
    """Create or resume a CometML experiment at the start of a YOLO pre-training routine."""
    _resume_or_create_experiment(trainer.args)


def on_train_epoch_end(trainer) -> None:
    """Log metrics and save batch images at the end of training epochs."""
    experiment = comet_ml.get_running_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata["curr_epoch"]
    curr_step = metadata["curr_step"]

    experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)


def on_fit_epoch_end(trainer) -> None:
    """
    Log model assets at the end of each epoch during training.

    This function is called at the end of each training epoch to log metrics, learning rates, and model information
    to a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on
    configuration settings.

    The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first
    epoch, it also logs model information. On specified save intervals, it logs the model, confusion matrix (if
    enabled), and image predictions (if enabled).

    Args:
        trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.

    Examples:
        >>> # Inside a training loop
        >>> on_fit_epoch_end(trainer)  # Log metrics and assets to Comet ML
    """
    experiment = comet_ml.get_running_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata["curr_epoch"]
    curr_step = metadata["curr_step"]
    save_assets = metadata["save_assets"]

    experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
    experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
    if curr_epoch == 1:
        from ultralytics.utils.torch_utils import model_info_for_loggers

        experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)

    if not save_assets:
        return

    _log_model(experiment, trainer)
    if _should_log_confusion_matrix():
        _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
    if _should_log_image_predictions():
        _log_image_predictions(experiment, trainer.validator, curr_step)


def on_train_end(trainer) -> None:
    """Perform operations at the end of training."""
    experiment = comet_ml.get_running_experiment()
    if not experiment:
        return

    metadata = _fetch_trainer_metadata(trainer)
    curr_epoch = metadata["curr_epoch"]
    curr_step = metadata["curr_step"]
    plots = trainer.args.plots

    _log_model(experiment, trainer)
    if plots:
        _log_plots(experiment, trainer)

    _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
    _log_image_predictions(experiment, trainer.validator, curr_step)
    _log_image_batches(experiment, trainer, curr_step)

    # Log results table
    table_path = trainer.save_dir / RESULTS_TABLE_NAME
    if table_path.exists():
        _log_table(experiment, table_path)

    # Log arguments YAML
    args_path = trainer.save_dir / ARGS_YAML_NAME
    if args_path.exists():
        _log_asset(experiment, args_path)

    experiment.end()

    global _comet_image_prediction_count
    _comet_image_prediction_count = 0


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if comet_ml
    else {}
)
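As with ClearML, the integration activates through the settings flag checked in the try block; a usage sketch (assuming `comet_ml` is installed and a Comet API key is configured — the project and model names are illustrative):

import os
from ultralytics import YOLO
from ultralytics.utils import SETTINGS

SETTINGS.update({"comet": True})                   # enable the Comet callbacks
os.environ["COMET_PROJECT_NAME"] = "my-yolo-runs"  # hypothetical project name
model = YOLO("yolo11n.pt")                         # model name assumed for illustration
model.train(data="coco8.yaml", epochs=3)
# Metrics, plots, the best checkpoint, results.csv, and args.yaml are then logged to Comet.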
202
ultralytics/utils/callbacks/dvc.py
Normal file
@@ -0,0 +1,202 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from pathlib import Path

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["dvc"] is True  # verify integration is enabled
    import dvclive

    assert checks.check_version("dvclive", "2.11.0", verbose=True)

    import os
    import re

    # DVCLive logger instance
    live = None
    _processed_plots = {}

    # `on_fit_epoch_end` is called on final validation (this probably needs to be fixed); for now it is how we
    # distinguish the final evaluation of the best model from last-epoch validation
    _training_epoch = False

except (ImportError, AssertionError, TypeError):
    dvclive = None


def _log_images(path: Path, prefix: str = "") -> None:
    """
    Log images at the specified path with an optional prefix using DVCLive.

    This function logs images found at the given path to DVCLive, organizing them by batch to enable slider
    functionality in the UI. It processes image filenames to extract batch information and restructures the path
    accordingly.

    Args:
        path (Path): Path to the image file to be logged.
        prefix (str, optional): Optional prefix to add to the image name when logging.

    Examples:
        >>> from pathlib import Path
        >>> _log_images(Path("runs/train/exp/val_batch0_pred.jpg"), prefix="validation")
    """
    if live:
        name = path.name

        # Group images by batch to enable sliders in UI
        if m := re.search(r"_batch(\d+)", name):
            ni = m[1]
            new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
            name = (Path(new_stem) / ni).with_suffix(path.suffix)

        live.log_image(os.path.join(prefix, name), path)
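A standalone sketch of the renaming this performs, so the UI can group frames of the same batch kind under one slider:

import re
from pathlib import Path

path = Path("runs/train/exp/val_batch0_pred.jpg")
name = path.name
if m := re.search(r"_batch(\d+)", name):
    ni = m[1]  # batch index -> "0"
    new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)  # -> "val_batch_pred"
    name = (Path(new_stem) / ni).with_suffix(path.suffix)
print(name)  # val_batch_pred/0.jpg -> one slider series per image kind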


def _log_plots(plots: dict, prefix: str = "") -> None:
    """
    Log plot images for training progress if they have not been previously processed.

    Args:
        plots (dict): Dictionary containing plot information with timestamps.
        prefix (str, optional): Optional prefix to add to the logged image paths.
    """
    for name, params in plots.items():
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            _log_images(name, prefix)
            _processed_plots[name] = timestamp


def _log_confusion_matrix(validator) -> None:
    """
    Log confusion matrix for a validator using DVCLive.

    This function processes the confusion matrix from a validator object and logs it to DVCLive by converting
    the matrix into lists of target and prediction labels.

    Args:
        validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have
            attributes: confusion_matrix.matrix, confusion_matrix.task, and names.
    """
    targets = []
    preds = []
    matrix = validator.confusion_matrix.matrix
    names = list(validator.names.values())
    if validator.confusion_matrix.task == "detect":
        names += ["background"]

    for ti, pred in enumerate(matrix.T.astype(int)):
        for pi, num in enumerate(pred):
            targets.extend([names[ti]] * num)
            preds.extend([names[pi]] * num)

    live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
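A tiny worked example of the matrix-to-label-pairs expansion (standalone, with a made-up 2x2 confusion matrix, assuming the Ultralytics convention that rows index predictions and columns index targets):

import numpy as np

names = ["cat", "dog"]
matrix = np.array([[3, 1],
                   [0, 2]])  # matrix[pred][target] counts

targets, preds = [], []
for ti, pred_col in enumerate(matrix.T.astype(int)):  # iterate over target classes
    for pi, num in enumerate(pred_col):               # and over predicted classes
        targets.extend([names[ti]] * num)
        preds.extend([names[pi]] * num)

print(targets)  # ['cat', 'cat', 'cat', 'dog', 'dog', 'dog']
print(preds)    # ['cat', 'cat', 'cat', 'cat', 'dog', 'dog']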


def on_pretrain_routine_start(trainer) -> None:
    """Initialize the DVCLive logger for training metadata during the pre-training routine."""
    try:
        global live
        live = dvclive.Live(save_dvc_exp=True, cache_images=True)
        LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
    except Exception as e:
        LOGGER.warning(f"DVCLive installed but not initialized correctly, not logging this run. {e}")


def on_pretrain_routine_end(trainer) -> None:
    """Log plots related to the training process at the end of the pretraining routine."""
    _log_plots(trainer.plots, "train")


def on_train_start(trainer) -> None:
    """Log the training parameters if DVCLive logging is active."""
    if live:
        live.log_params(trainer.args)


def on_train_epoch_start(trainer) -> None:
    """Set the global variable _training_epoch to True at the start of each training epoch."""
    global _training_epoch
    _training_epoch = True


def on_fit_epoch_end(trainer) -> None:
    """
    Log training metrics, model info, and advance to the next step at the end of each fit epoch.

    This function is called at the end of each fit epoch during training. It logs various metrics including
    training loss items, validation metrics, and learning rates. On the first epoch, it also logs model
    information. Additionally, it logs training and validation plots and advances the DVCLive step counter.

    Args:
        trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.

    Notes:
        This function only performs logging operations when DVCLive logging is active and during a training epoch.
        The global variable _training_epoch is used to track whether the current epoch is a training epoch.
    """
    global _training_epoch
    if live and _training_epoch:
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value)

        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            for metric, value in model_info_for_loggers(trainer).items():
                live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, "train")
        _log_plots(trainer.validator.plots, "val")

        live.next_step()
        _training_epoch = False


def on_train_end(trainer) -> None:
    """
    Log the best metrics, plots, and confusion matrix at the end of training.

    This function is called at the conclusion of the training process to log final metrics, visualizations, and
    model artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,
    validation plots, and confusion matrix for later analysis.

    Args:
        trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.

    Examples:
        >>> # Inside a custom training loop
        >>> from ultralytics.utils.callbacks.dvc import on_train_end
        >>> on_train_end(trainer)  # Log final metrics and artifacts
    """
    if live:
        # At the end, log the best metrics; the validator runs on the best model internally
        all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
        for metric, value in all_metrics.items():
            live.log_metric(metric, value, plot=False)

        _log_plots(trainer.plots, "val")
        _log_plots(trainer.validator.plots, "val")
        _log_confusion_matrix(trainer.validator)

        if trainer.best.exists():
            live.log_artifact(trainer.best, copy=True, type="model")

        live.end()


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_start": on_train_start,
        "on_train_epoch_start": on_train_epoch_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if dvclive
    else {}
)
110
ultralytics/utils/callbacks/hub.py
Normal file
@@ -0,0 +1,110 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import json
from time import time

from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
from ultralytics.utils import LOGGER, RANK, SETTINGS
from ultralytics.utils.events import events


def on_pretrain_routine_start(trainer):
    """Create a remote Ultralytics HUB session to log local model training."""
    if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None:
        trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)


def on_pretrain_routine_end(trainer):
    """Initialize timers for upload rate limiting before training begins."""
    if session := getattr(trainer, "hub_session", None):
        # Start timer for upload rate limit
        session.timers = {"metrics": time(), "ckpt": time()}  # start timer for session rate limiting


def on_fit_epoch_end(trainer):
    """Upload training progress metrics to Ultralytics HUB at the end of each epoch."""
    if session := getattr(trainer, "hub_session", None):
        # Upload metrics after validation ends
        all_plots = {
            **trainer.label_loss_items(trainer.tloss, prefix="train"),
            **trainer.metrics,
        }
        if trainer.epoch == 0:
            from ultralytics.utils.torch_utils import model_info_for_loggers

            all_plots = {**all_plots, **model_info_for_loggers(trainer)}

        session.metrics_queue[trainer.epoch] = json.dumps(all_plots)

        # If any metrics failed to upload previously, add them to the queue to attempt uploading again
        if session.metrics_upload_failed_queue:
            session.metrics_queue.update(session.metrics_upload_failed_queue)

        if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
            session.upload_metrics()
            session.timers["metrics"] = time()  # reset timer
            session.metrics_queue = {}  # reset queue
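The queue-plus-timer pattern above is worth seeing in isolation; a minimal standalone sketch of the same rate-limiting idea (names are illustrative, not the HUB API):

from time import time, sleep

queue, last_upload, rate_limit_s = {}, time(), 2.0  # upload at most every 2 s

def upload(payload):  # stand-in for session.upload_metrics()
    print(f"uploading {len(payload)} queued epochs")

for epoch in range(5):
    queue[epoch] = {"loss": 1.0 / (epoch + 1)}  # always enqueue
    if time() - last_upload > rate_limit_s:     # but only upload when the timer allows
        upload(queue)
        last_upload, queue = time(), {}
    sleep(1)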


def on_model_save(trainer):
    """Upload model checkpoints to Ultralytics HUB with rate limiting."""
    if session := getattr(trainer, "hub_session", None):
        # Upload checkpoints with rate limiting
        is_best = trainer.best_fitness == trainer.fitness
        if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
            LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
            session.upload_model(trainer.epoch, trainer.last, is_best)
            session.timers["ckpt"] = time()  # reset timer


def on_train_end(trainer):
    """Upload the final model and metrics to Ultralytics HUB at the end of training."""
    if session := getattr(trainer, "hub_session", None):
        # Upload final model and metrics with exponential backoff
        LOGGER.info(f"{PREFIX}Syncing final model...")
        session.upload_model(
            trainer.epoch,
            trainer.best,
            map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
            final=True,
        )
        session.alive = False  # stop heartbeats
        LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀")


def on_train_start(trainer):
    """Run events on train start."""
    events(trainer.args, trainer.device)


def on_val_start(validator):
    """Run events on validation start."""
    if not validator.training:
        events(validator.args, validator.device)


def on_predict_start(predictor):
    """Run events on predict start."""
    events(predictor.args, predictor.device)


def on_export_start(exporter):
    """Run events on export start."""
    events(exporter.args, exporter.device)


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_model_save": on_model_save,
        "on_train_end": on_train_end,
        "on_train_start": on_train_start,
        "on_val_start": on_val_start,
        "on_predict_start": on_predict_start,
        "on_export_start": on_export_start,
    }
    if SETTINGS["hub"] is True
    else {}
)
135
ultralytics/utils/callbacks/mlflow.py
Normal file
@@ -0,0 +1,135 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
MLflow Logging for Ultralytics YOLO.

This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
To set it up, a tracking URI should be specified. The logging can be customized using environment variables.

Commands:
    1. To set a project name:
        `export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project> argument

    2. To set a run name:
        `export MLFLOW_RUN=<your_run_name>` or use the name=<name> argument

    3. To start a local MLflow server:
        mlflow server --backend-store-uri runs/mlflow
        By default this starts a local server at http://127.0.0.1:5000.
        To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.

    4. To kill all running MLflow server instances:
        ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
"""
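The same configuration can be done from Python before training; a sketch using the environment variables this module reads (the URI assumes the local server from step 3, and the experiment/run names are hypothetical):

import os

os.environ["MLFLOW_TRACKING_URI"] = "http://127.0.0.1:5000"  # assumed local server
os.environ["MLFLOW_EXPERIMENT_NAME"] = "yolo-experiments"    # hypothetical experiment name
os.environ["MLFLOW_RUN"] = "baseline"                        # hypothetical run name
os.environ["MLFLOW_KEEP_RUN_ACTIVE"] = "False"               # end the run when training ends
# The on_pretrain_routine_end callback below then picks these up via os.environ.get(...).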
from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr

try:
    import os

    assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "")  # do not log pytest
    assert SETTINGS["mlflow"] is True  # verify integration is enabled
    import mlflow

    assert hasattr(mlflow, "__version__")  # verify package is not directory
    from pathlib import Path

    PREFIX = colorstr("MLflow: ")

except (ImportError, AssertionError):
    mlflow = None


def sanitize_dict(x: dict) -> dict:
    """Sanitize dictionary keys by removing parentheses and converting values to floats."""
    return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
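A quick illustration of why this is needed: MLflow metric names reject characters such as parentheses, which some Ultralytics metric keys contain; a standalone sketch of the same transformation:

metrics = {"metrics/mAP50-95(B)": "0.42", "lr/pg0": 0.01}
sanitized = {k.replace("(", "").replace(")", ""): float(v) for k, v in metrics.items()}
print(sanitized)  # {'metrics/mAP50-95B': 0.42, 'lr/pg0': 0.01}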


def on_pretrain_routine_end(trainer):
    """
    Log training parameters to MLflow at the end of the pretraining routine.

    This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking
    URI, experiment name, and run name, then starts the MLflow run if not already active. It finally logs the
    parameters from the trainer.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.

    Environment Variables:
        MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
        MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
        MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
        MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
    """
    global mlflow

    uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
    LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
    mlflow.set_tracking_uri(uri)

    # Set experiment and run names
    experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics"
    run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
    mlflow.set_experiment(experiment_name)

    mlflow.autolog()
    try:
        active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
        LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
        if Path(uri).is_dir():
            LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
        LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
        mlflow.log_params(dict(trainer.args))
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to initialize: {e}")
        LOGGER.warning(f"{PREFIX}Not tracking this run")


def on_train_epoch_end(trainer):
    """Log training metrics at the end of each train epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(
            metrics={
                **sanitize_dict(trainer.lr),
                **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")),
            },
            step=trainer.epoch,
        )


def on_fit_epoch_end(trainer):
    """Log training metrics at the end of each fit epoch to MLflow."""
    if mlflow:
        mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)


def on_train_end(trainer):
    """Log model artifacts at the end of training."""
    if not mlflow:
        return
    mlflow.log_artifact(str(trainer.best.parent))  # log save_dir/weights directory with best.pt and last.pt
    for f in trainer.save_dir.glob("*"):  # log all other files in save_dir
        if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
            mlflow.log_artifact(str(f))
    keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
    if keep_run_active:
        LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
    else:
        mlflow.end_run()
        LOGGER.debug(f"{PREFIX}mlflow run ended")

    LOGGER.info(
        f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
    )


callbacks = (
    {
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if mlflow
    else {}
)
134
ultralytics/utils/callbacks/neptune.py
Normal file
@@ -0,0 +1,134 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["neptune"] is True  # verify integration is enabled

    import neptune
    from neptune.types import File

    assert hasattr(neptune, "__version__")

    run = None  # NeptuneAI experiment logger instance

except (ImportError, AssertionError):
    neptune = None


def _log_scalars(scalars: dict, step: int = 0) -> None:
    """
    Log scalars to the NeptuneAI experiment logger.

    Args:
        scalars (dict): Dictionary of scalar values to log to NeptuneAI.
        step (int, optional): The current step or iteration number for logging.

    Examples:
        >>> metrics = {"mAP": 0.85, "loss": 0.32}
        >>> _log_scalars(metrics, step=100)
    """
    if run:
        for k, v in scalars.items():
            run[k].append(value=v, step=step)


def _log_images(imgs_dict: dict, group: str = "") -> None:
    """
    Log images to the NeptuneAI experiment logger.

    This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized
    under the specified group name.

    Args:
        imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.
        group (str, optional): Group name to organize images under in the Neptune UI.

    Examples:
        >>> # Log validation images
        >>> _log_images({"val_batch": img_tensor}, group="validation")
    """
    if run:
        for k, v in imgs_dict.items():
            run[f"{group}/{k}"].upload(File(v))


def _log_plot(title: str, plot_path: str) -> None:
    """Log plots to the NeptuneAI experiment logger."""
    import matplotlib.image as mpimg
    import matplotlib.pyplot as plt

    img = mpimg.imread(plot_path)
    fig = plt.figure()
    ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[])  # no ticks
    ax.imshow(img)
    run[f"Plots/{title}"].upload(fig)


def on_pretrain_routine_start(trainer) -> None:
    """Initialize the NeptuneAI run and log hyperparameters before training starts."""
    try:
        global run
        run = neptune.init_run(
            project=trainer.args.project or "Ultralytics",
            name=trainer.args.name,
            tags=["Ultralytics"],
        )
        run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
    except Exception as e:
        LOGGER.warning(f"NeptuneAI installed but not initialized correctly, not logging this run. {e}")


def on_train_epoch_end(trainer) -> None:
    """Log training metrics and learning rate at the end of each training epoch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
    _log_scalars(trainer.lr, trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")


def on_fit_epoch_end(trainer) -> None:
    """Log model info and validation metrics at the end of each fit epoch."""
    if run and trainer.epoch == 0:
        from ultralytics.utils.torch_utils import model_info_for_loggers

        run["Configuration/Model"] = model_info_for_loggers(trainer)
    _log_scalars(trainer.metrics, trainer.epoch + 1)


def on_val_end(validator) -> None:
    """Log validation images at the end of validation."""
    if run:
        # Log val_labels and val_pred
        _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")


def on_train_end(trainer) -> None:
    """Log final results, plots, and model weights at the end of training."""
    if run:
        # Log final results, confusion matrix and PR plots
        files = [
            "results.png",
            "confusion_matrix.png",
            "confusion_matrix_normalized.png",
            *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
        ]
        files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()]  # filter
        for f in files:
            _log_plot(title=f.stem, plot_path=f)
        # Log the final model
        run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_val_end": on_val_end,
        "on_train_end": on_train_end,
    }
    if neptune
    else {}
)
73
ultralytics/utils/callbacks/platform.py
Normal file
@@ -0,0 +1,73 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import RANK, SETTINGS


def on_pretrain_routine_start(trainer):
    """Initialize and start console logging immediately at the very beginning."""
    if RANK in {-1, 0}:
        from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger

        trainer.system_logger = SystemLogger()
        trainer.console_logger = ConsoleLogger(DEFAULT_LOG_PATH)
        trainer.console_logger.start_capture()


def on_pretrain_routine_end(trainer):
    """Handle pre-training routine completion event."""
    pass


def on_fit_epoch_end(trainer):
    """Handle end of training epoch event and collect system metrics."""
    if RANK in {-1, 0} and hasattr(trainer, "system_logger"):
        system_metrics = trainer.system_logger.get_metrics()
        print(system_metrics)  # for debug


def on_model_save(trainer):
    """Handle model checkpoint save event."""
    pass


def on_train_end(trainer):
    """Stop console capture and finalize logs."""
    if logger := getattr(trainer, "console_logger", None):
        logger.stop_capture()


def on_train_start(trainer):
    """Handle training start event."""
    pass


def on_val_start(validator):
    """Handle validation start event."""
    pass


def on_predict_start(predictor):
    """Handle prediction start event."""
    pass


def on_export_start(exporter):
    """Handle model export start event."""
    pass


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_model_save": on_model_save,
        "on_train_end": on_train_end,
        "on_train_start": on_train_start,
        "on_val_start": on_val_start,
        "on_predict_start": on_predict_start,
        "on_export_start": on_export_start,
    }
    if SETTINGS.get("platform", False) is True  # disabled for debugging
    else {}
)
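
# Sketch of how hook points like these can be exercised from user code. `Model.add_callback`
# is the public mechanism for attaching extra functions to the same named events; the
# lambda below is a hypothetical example, not part of the platform module.
# >>> from ultralytics import YOLO
# >>> model = YOLO("yolo11n.pt")
# >>> model.add_callback("on_train_start", lambda trainer: print("training started"))
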
43
ultralytics/utils/callbacks/raytune.py
Normal file
@@ -0,0 +1,43 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import SETTINGS

try:
    assert SETTINGS["raytune"] is True  # verify integration is enabled
    import ray
    from ray import tune
    from ray.air import session

except (ImportError, AssertionError):
    tune = None


def on_fit_epoch_end(trainer):
    """
    Report training metrics to Ray Tune at epoch end when a Ray session is active.

    Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,
    enabling hyperparameter tuning optimization. Only executes when within an active Ray Tune session.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.

    Examples:
        >>> # Called automatically by the Ultralytics training loop
        >>> on_fit_epoch_end(trainer)

    References:
        Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html
    """
    if ray.train._internal.session.get_session():  # check if Ray Tune session is active
        metrics = trainer.metrics
        session.report({**metrics, **{"epoch": trainer.epoch + 1}})


callbacks = (
    {
        "on_fit_epoch_end": on_fit_epoch_end,
    }
    if tune
    else {}
)
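
# Usage sketch: driving this callback through Ray Tune from the model API. Assumes
# `ray[tune]` is installed and SETTINGS["raytune"] is enabled; the iteration count is
# illustrative. Each epoch, on_fit_epoch_end reports metrics back to the Tune session.
# >>> from ultralytics import YOLO
# >>> model = YOLO("yolo11n.pt")
# >>> model.tune(data="coco8.yaml", use_ray=True, iterations=4)
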
131
ultralytics/utils/callbacks/tensorboard.py
Normal file
@@ -0,0 +1,131 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["tensorboard"] is True  # verify integration is enabled
    WRITER = None  # TensorBoard SummaryWriter instance
    PREFIX = colorstr("TensorBoard: ")

    # Imports below only required if TensorBoard enabled
    import warnings
    from copy import deepcopy

    import torch
    from torch.utils.tensorboard import SummaryWriter

except (ImportError, AssertionError, TypeError, AttributeError):
    # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
    # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
    SummaryWriter = None


def _log_scalars(scalars: dict, step: int = 0) -> None:
    """
    Log scalar values to TensorBoard.

    Args:
        scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the
            corresponding scalar values.
        step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.

    Examples:
        Log training metrics
        >>> metrics = {"loss": 0.5, "accuracy": 0.95}
        >>> _log_scalars(metrics, step=100)
    """
    if WRITER:
        for k, v in scalars.items():
            WRITER.add_scalar(k, v, step)


def _log_tensorboard_graph(trainer) -> None:
    """
    Log model graph to TensorBoard.

    This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
    tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
    approach for models like RTDETR that may require special handling.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.
            Must have attributes model and args with imgsz.

    Notes:
        This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
        It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different
        model architectures.
    """
    # Input image
    imgsz = trainer.args.imgsz
    imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
    p = next(trainer.model.parameters())  # for device, type
    im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype)  # input image (must be zeros, not empty)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)  # suppress jit trace warning
        warnings.simplefilter("ignore", category=torch.jit.TracerWarning)  # suppress jit trace warning

        # Try simple method first (YOLO)
        try:
            trainer.model.eval()  # place in .eval() mode to avoid BatchNorm statistics changes
            WRITER.add_graph(torch.jit.trace(torch_utils.unwrap_model(trainer.model), im, strict=False), [])
            LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            return

        except Exception:
            # Fallback to TorchScript export steps (RTDETR)
            try:
                model = deepcopy(torch_utils.unwrap_model(trainer.model))
                model.eval()
                model = model.fuse(verbose=False)
                for m in model.modules():
                    if hasattr(m, "export"):  # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
                        m.export = True
                        m.format = "torchscript"
                model(im)  # dry run
                WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
                LOGGER.info(f"{PREFIX}model graph visualization added ✅")
            except Exception as e:
                LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure {e}")


def on_pretrain_routine_start(trainer) -> None:
    """Initialize TensorBoard logging with SummaryWriter."""
    if SummaryWriter:
        try:
            global WRITER
            WRITER = SummaryWriter(str(trainer.save_dir))
            LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
        except Exception as e:
            LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. {e}")


def on_train_start(trainer) -> None:
    """Log TensorBoard graph."""
    if WRITER:
        _log_tensorboard_graph(trainer)


def on_train_epoch_end(trainer) -> None:
    """Log scalar statistics at the end of a training epoch."""
    _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
    _log_scalars(trainer.lr, trainer.epoch + 1)


def on_fit_epoch_end(trainer) -> None:
    """Log epoch metrics at end of training epoch."""
    _log_scalars(trainer.metrics, trainer.epoch + 1)


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_start": on_train_start,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_epoch_end": on_train_epoch_end,
    }
    if SummaryWriter
    else {}
)
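
# Usage sketch: enabling the TensorBoard callbacks and viewing logs. Assumes the
# `tensorboard` package is installed; the log directory shown is the default for a
# detection run and is illustrative.
# >>> from ultralytics import YOLO, settings
# >>> settings.update({"tensorboard": True})
# >>> YOLO("yolo11n.pt").train(data="coco8.yaml", epochs=3)
# Then, from a shell:
#   tensorboard --logdir runs/detect/train  # view at http://localhost:6006/
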
191
ultralytics/utils/callbacks/wb.py
Normal file
@@ -0,0 +1,191 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import SETTINGS, TESTS_RUNNING
from ultralytics.utils.torch_utils import model_info_for_loggers

try:
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS["wandb"] is True  # verify integration is enabled
    import wandb as wb

    assert hasattr(wb, "__version__")  # verify package is not directory
    _processed_plots = {}

except (ImportError, AssertionError):
    wb = None


def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
    """
    Create and log a custom metric visualization to wandb.plot.pr_curve.

    This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
    curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
    different classes.

    Args:
        x (list): Values for the x-axis; expected to have length N.
        y (list): Corresponding values for the y-axis; also expected to have length N.
        classes (list): Labels identifying the class of each point; length N.
        title (str, optional): Title for the plot.
        x_title (str, optional): Label for the x-axis.
        y_title (str, optional): Label for the y-axis.

    Returns:
        (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
    """
    import polars as pl  # scope for faster 'import ultralytics'
    import polars.selectors as cs

    df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
    data = df.select(["class", "y", "x"]).rows()

    fields = {"x": "x", "y": "y", "class": "class"}
    string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
    return wb.plot_table(
        "wandb/area-under-curve/v0",
        wb.Table(data=data, columns=["class", "y", "x"]),
        fields=fields,
        string_fields=string_fields,
    )


def _plot_curve(
    x,
    y,
    names=None,
    id="precision-recall",
    title="Precision Recall Curve",
    x_title="Recall",
    y_title="Precision",
    num_x=100,
    only_mean=False,
):
    """
    Log a metric curve visualization.

    This function generates a metric curve based on input data and logs the visualization to wandb.
    The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.

    Args:
        x (np.ndarray): Data points for the x-axis with length N.
        y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
        names (list, optional): Names of the classes corresponding to the y-axis data; length C.
        id (str, optional): Unique identifier for the logged data in wandb.
        title (str, optional): Title for the visualization plot.
        x_title (str, optional): Label for the x-axis.
        y_title (str, optional): Label for the y-axis.
        num_x (int, optional): Number of interpolated data points for visualization.
        only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.

    Notes:
        The function leverages the '_custom_table' function to generate the actual visualization.
    """
    import numpy as np

    # Create new x
    if names is None:
        names = []
    x_new = np.linspace(x[0], x[-1], num_x).round(5)

    # Create arrays for logging
    x_log = x_new.tolist()
    y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()

    if only_mean:
        table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
        wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
    else:
        classes = ["mean"] * len(x_log)
        for i, yi in enumerate(y):
            x_log.extend(x_new)  # add new x
            y_log.extend(np.interp(x_new, x, yi))  # interpolate y to new x
            classes.extend([names[i]] * len(x_new))  # add class names
        wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)


def _log_plots(plots, step):
    """
    Log plots to WandB at a specific step if they haven't been logged already.

    This function checks each plot in the input dictionary against previously processed plots and logs
    new or updated plots to WandB at the specified step.

    Args:
        plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
            containing plot metadata including timestamps.
        step (int): The step/epoch at which to log the plots in the WandB run.

    Notes:
        The function uses a shallow copy of the plots dictionary to prevent modification during iteration.
        Plots are identified by their stem name (filename without extension).
        Each plot is logged as a WandB Image object.
    """
    for name, params in plots.copy().items():  # shallow copy to prevent plots dict changing during iteration
        timestamp = params["timestamp"]
        if _processed_plots.get(name) != timestamp:
            wb.run.log({name.stem: wb.Image(str(name))}, step=step)
            _processed_plots[name] = timestamp


def on_pretrain_routine_start(trainer):
    """Initialize and start wandb project if module is present."""
    if not wb.run:
        wb.init(
            project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
            name=str(trainer.args.name).replace("/", "-"),
            config=vars(trainer.args),
        )


def on_fit_epoch_end(trainer):
    """Log training metrics and model information at the end of an epoch."""
    wb.run.log(trainer.metrics, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    if trainer.epoch == 0:
        wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)


def on_train_epoch_end(trainer):
    """Log metrics and save images at the end of each training epoch."""
    wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
    wb.run.log(trainer.lr, step=trainer.epoch + 1)
    if trainer.epoch == 1:
        _log_plots(trainer.plots, step=trainer.epoch + 1)


def on_train_end(trainer):
    """Save the best model as an artifact and log final plots at the end of training."""
    _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
    _log_plots(trainer.plots, step=trainer.epoch + 1)
    art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
    if trainer.best.exists():
        art.add_file(trainer.best)
        wb.run.log_artifact(art, aliases=["best"])
    # Check if we actually have plots to save
    if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"):
        for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
            x, y, x_title, y_title = curve_values
            _plot_curve(
                x,
                y,
                names=list(trainer.validator.metrics.names.values()),
                id=f"curves/{curve_name}",
                title=curve_name,
                x_title=x_title,
                y_title=y_title,
            )
    wb.run.finish()  # required or run continues on dashboard


callbacks = (
    {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_train_epoch_end": on_train_epoch_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_train_end": on_train_end,
    }
    if wb
    else {}
)
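
# Shape sketch for _plot_curve above: x is (N,) and y is (C, N) with one row per class,
# per the docstring. A minimal synthetic call, assuming wandb is initialized (wb.run is
# active); the class names are illustrative.
# >>> import numpy as np
# >>> x = np.linspace(0, 1, 50)
# >>> y = np.stack([1 - x, (1 - x) ** 2])  # two hypothetical per-class curves
# >>> _plot_curve(x, y, names=["person", "car"], id="curves/demo", title="Demo Curve")
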
964
ultralytics/utils/checks.py
Normal file
@@ -0,0 +1,964 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import functools
import glob
import inspect
import math
import os
import platform
import re
import shutil
import subprocess
import time
from importlib import metadata
from pathlib import Path
from types import SimpleNamespace

import cv2
import numpy as np
import torch

from ultralytics.utils import (
    ARM64,
    ASSETS,
    AUTOINSTALL,
    GIT,
    IS_COLAB,
    IS_JETSON,
    IS_KAGGLE,
    IS_PIP_PACKAGE,
    LINUX,
    LOGGER,
    MACOS,
    ONLINE,
    PYTHON_VERSION,
    RKNN_CHIPS,
    ROOT,
    TORCH_VERSION,
    TORCHVISION_VERSION,
    USER_CONFIG_DIR,
    WINDOWS,
    Retry,
    ThreadingLocked,
    TryExcept,
    clean_url,
    colorstr,
    downloads,
    is_github_action_running,
    url2file,
)


def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
    """
    Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.

    Args:
        file_path (Path): Path to the requirements.txt file.
        package (str, optional): Python package to use instead of requirements.txt file.

    Returns:
        requirements (list[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and
            `specifier` attributes.

    Examples:
        >>> from ultralytics.utils.checks import parse_requirements
        >>> parse_requirements(package="ultralytics")
    """
    if package:
        requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
    else:
        requires = Path(file_path).read_text().splitlines()

    requirements = []
    for line in requires:
        line = line.strip()
        if line and not line.startswith("#"):
            line = line.partition("#")[0].strip()  # ignore inline comments
            if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line):
                requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

    return requirements


@functools.lru_cache
def parse_version(version="0.0.0") -> tuple:
    """
    Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.

    Args:
        version (str): Version string, i.e. '2.0.1+cpu'

    Returns:
        (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
    """
    try:
        return tuple(map(int, re.findall(r"\d+", version)[:3]))  # '2.0.1+cpu' -> (2, 0, 1)
    except Exception as e:
        LOGGER.warning(f"failure for parse_version({version}), returning (0, 0, 0): {e}")
        return 0, 0, 0
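
# Behavior sketch for parse_version: only the first three numeric groups are kept, so
# local version tags and pre-release suffixes are ignored.
# >>> parse_version("2.0.1+cpu")
# (2, 0, 1)
# >>> parse_version("1.13.0a0")
# (1, 13, 0)
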
def is_ascii(s) -> bool:
    """
    Check if a string is composed of only ASCII characters.

    Args:
        s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).

    Returns:
        (bool): True if the string is composed only of ASCII characters, False otherwise.
    """
    return all(ord(c) < 128 for c in str(s))


def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
    """
    Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
    stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.

    Args:
        imgsz (int | list[int]): Image size.
        stride (int): Stride value.
        min_dim (int): Minimum number of dimensions.
        max_dim (int): Maximum number of dimensions.
        floor (int): Minimum allowed value for image size.

    Returns:
        (list[int] | int): Updated image size.
    """
    # Convert stride to integer if it is a tensor
    stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)

    # Convert image size to list if it is an integer
    if isinstance(imgsz, int):
        imgsz = [imgsz]
    elif isinstance(imgsz, (list, tuple)):
        imgsz = list(imgsz)
    elif isinstance(imgsz, str):  # i.e. '640' or '[640,640]'
        imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
    else:
        raise TypeError(
            f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
            f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
        )

    # Apply max_dim
    if len(imgsz) > max_dim:
        msg = (
            "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
            "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
        )
        if max_dim != 1:
            raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
        LOGGER.warning(f"updating to 'imgsz={max(imgsz)}'. {msg}")
        imgsz = [max(imgsz)]
    # Make image size a multiple of the stride
    sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

    # Print warning message if image size was updated
    if sz != imgsz:
        LOGGER.warning(f"imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")

    # Add missing dimensions if necessary
    sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz

    return sz
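
# Behavior sketch for check_imgsz: sizes are rounded up to the nearest stride multiple.
# >>> check_imgsz(640, stride=32)
# 640
# >>> check_imgsz(641, stride=32)  # warns and rounds up: ceil(641 / 32) * 32
# 672
# >>> check_imgsz(640, stride=32, min_dim=2)
# [640, 640]
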
@functools.lru_cache
def check_uv():
    """Check if uv package manager is installed and can run successfully."""
    try:
        return subprocess.run(["uv", "-V"], capture_output=True).returncode == 0
    except FileNotFoundError:
        return False


@functools.lru_cache
def check_version(
    current: str = "0.0.0",
    required: str = "0.0.0",
    name: str = "version",
    hard: bool = False,
    verbose: bool = False,
    msg: str = "",
) -> bool:
    """
    Check current version against the required version or range.

    Args:
        current (str): Current version or package name to get version from.
        required (str): Required version or range (in pip-style format).
        name (str): Name to be used in warning message.
        hard (bool): If True, raise an AssertionError if the requirement is not met.
        verbose (bool): If True, print warning message if requirement is not met.
        msg (str): Extra message to display if verbose.

    Returns:
        (bool): True if requirement is met, False otherwise.

    Examples:
        Check if current version is exactly 22.04
        >>> check_version(current="22.04", required="==22.04")

        Check if current version is greater than or equal to 22.04
        >>> check_version(current="22.10", required="22.04")  # assumes '>=' inequality if none passed

        Check if current version is less than or equal to 22.04
        >>> check_version(current="22.04", required="<=22.04")

        Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
        >>> check_version(current="21.10", required=">20.04,<22.04")
    """
    if not current:  # if current is '' or None
        LOGGER.warning(f"invalid check_version({current}, {required}) requested, please check values.")
        return True
    elif not current[0].isdigit():  # current is package name rather than version string, i.e. current='ultralytics'
        try:
            name = current  # assigned package name to 'name' arg
            current = metadata.version(current)  # get version string from package name
        except metadata.PackageNotFoundError as e:
            if hard:
                raise ModuleNotFoundError(f"{current} package is required but not installed") from e
            else:
                return False

    if not required:  # if required is '' or None
        return True

    if "sys_platform" in required and (  # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
        (WINDOWS and "win32" not in required)
        or (LINUX and "linux" not in required)
        or (MACOS and "macos" not in required and "darwin" not in required)
    ):
        return True

    op = ""
    version = ""
    result = True
    c = parse_version(current)  # '1.2.3' -> (1, 2, 3)
    for r in required.strip(",").split(","):
        op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()  # split '>=22.04' -> ('>=', '22.04')
        if not op:
            op = ">="  # assume >= if no op passed
        v = parse_version(version)  # '1.2.3' -> (1, 2, 3)
        if op == "==" and c != v:
            result = False
        elif op == "!=" and c == v:
            result = False
        elif op == ">=" and not (c >= v):
            result = False
        elif op == "<=" and not (c <= v):
            result = False
        elif op == ">" and not (c > v):
            result = False
        elif op == "<" and not (c < v):
            result = False
    if not result:
        warning = f"{name}{required} is required, but {name}=={current} is currently installed {msg}"
        if hard:
            raise ModuleNotFoundError(warning)  # assert version requirements met
        if verbose:
            LOGGER.warning(warning)
    return result


def check_latest_pypi_version(package_name="ultralytics"):
    """
    Return the latest version of a PyPI package without downloading or installing it.

    Args:
        package_name (str): The name of the package to find the latest version for.

    Returns:
        (str): The latest version of the package.
    """
    import requests  # scoped as slow import

    try:
        requests.packages.urllib3.disable_warnings()  # Disable the InsecureRequestWarning
        response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
        if response.status_code == 200:
            return response.json()["info"]["version"]
    except Exception:
        return None


def check_pip_update_available():
    """
    Check if a new version of the ultralytics package is available on PyPI.

    Returns:
        (bool): True if an update is available, False otherwise.
    """
    if ONLINE and IS_PIP_PACKAGE:
        try:
            from ultralytics import __version__

            latest = check_latest_pypi_version()
            if check_version(__version__, f"<{latest}"):  # check if current version is < latest version
                LOGGER.info(
                    f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
                    f"Update with 'pip install -U ultralytics'"
                )
                return True
        except Exception:
            pass
    return False


@ThreadingLocked()
@functools.lru_cache
def check_font(font="Arial.ttf"):
    """
    Find font locally or download to user's configuration directory if it does not already exist.

    Args:
        font (str): Path or name of font.

    Returns:
        (Path): Resolved font file path.
    """
    from matplotlib import font_manager  # scope for faster 'import ultralytics'

    # Check USER_CONFIG_DIR
    name = Path(font).name
    file = USER_CONFIG_DIR / name
    if file.exists():
        return file

    # Check system fonts
    matches = [s for s in font_manager.findSystemFonts() if font in s]
    if any(matches):
        return matches[0]

    # Download to USER_CONFIG_DIR if missing
    url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}"
    if downloads.is_url(url, check=True):
        downloads.safe_download(url=url, file=file)
        return file


def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
    """
    Check current python version against the required minimum version.

    Args:
        minimum (str): Required minimum version of python.
        hard (bool): If True, raise an AssertionError if the requirement is not met.
        verbose (bool): If True, print warning message if requirement is not met.

    Returns:
        (bool): Whether the installed Python version meets the minimum constraints.
    """
    return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose)


@TryExcept()
def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
    """
    Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.

    Args:
        requirements (Path | str | list[str] | tuple[str]): Path to a requirements.txt file, a single package
            requirement as a string, or a list of package requirements as strings.
        exclude (tuple): Tuple of package names to exclude from checking.
        install (bool): If True, attempt to auto-update packages that don't meet requirements.
        cmds (str): Additional commands to pass to the pip install command when auto-updating.

    Examples:
        >>> from ultralytics.utils.checks import check_requirements

        Check a requirements.txt file
        >>> check_requirements("path/to/requirements.txt")

        Check a single package
        >>> check_requirements("ultralytics>=8.0.0")

        Check multiple packages
        >>> check_requirements(["numpy", "ultralytics>=8.0.0"])
    """
    prefix = colorstr("red", "bold", "requirements:")
    if isinstance(requirements, Path):  # requirements.txt file
        file = requirements.resolve()
        assert file.exists(), f"{prefix} {file} not found, check failed."
        requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
    elif isinstance(requirements, str):
        requirements = [requirements]

    pkgs = []
    for r in requirements:
        r_stripped = r.rpartition("/")[-1].replace(".git", "")  # replace git+https://org/repo.git -> 'repo'
        match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
        name, required = match[1], match[2].strip() if match[2] else ""
        try:
            assert check_version(metadata.version(name), required)  # exception if requirements not met
        except (AssertionError, metadata.PackageNotFoundError):
            pkgs.append(r)

    @Retry(times=2, delay=1)
    def attempt_install(packages, commands, use_uv):
        """Attempt package installation with uv if available, falling back to pip."""
        if use_uv:
            base = (
                f"uv pip install --no-cache-dir {packages} {commands} "
                f"--index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
            )
            try:
                return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE, text=True)
            except subprocess.CalledProcessError as e:
                if e.stderr and "No virtual environment found" in e.stderr:
                    return subprocess.check_output(
                        base.replace("uv pip install", "uv pip install --system"),
                        shell=True,
                        stderr=subprocess.PIPE,
                        text=True,
                    )
                raise
        return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True, text=True)

    s = " ".join(f'"{x}"' for x in pkgs)  # console string
    if s:
        if install and AUTOINSTALL:  # check environment variable
            # Note uv fails on arm64 macOS and Raspberry Pi runners
            n = len(pkgs)  # number of package updates
            LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
            try:
                t = time.time()
                assert ONLINE, "AutoUpdate skipped (offline)"
                LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv()))
                dt = time.time() - t
                LOGGER.info(f"{prefix} AutoUpdate success ✅ {dt:.1f}s")
                LOGGER.warning(
                    f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
                )
            except Exception as e:
                LOGGER.warning(f"{prefix} ❌ {e}")
                return False
        else:
            return False

    return True


def check_torchvision():
    """
    Check the installed versions of PyTorch and Torchvision to ensure they're compatible.

    This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
    to the compatibility table based on: https://github.com/pytorch/vision#installation.
    """
    compatibility_table = {
        "2.9": ["0.24"],
        "2.8": ["0.23"],
        "2.7": ["0.22"],
        "2.6": ["0.21"],
        "2.5": ["0.20"],
        "2.4": ["0.19"],
        "2.3": ["0.18"],
        "2.2": ["0.17"],
        "2.1": ["0.16"],
        "2.0": ["0.15"],
        "1.13": ["0.14"],
        "1.12": ["0.13"],
    }

    # Check major and minor versions
    v_torch = ".".join(TORCH_VERSION.split("+", 1)[0].split(".")[:2])
    if v_torch in compatibility_table:
        compatible_versions = compatibility_table[v_torch]
        v_torchvision = ".".join(TORCHVISION_VERSION.split("+", 1)[0].split(".")[:2])
        if all(v_torchvision != v for v in compatible_versions):
            LOGGER.warning(
                f"torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
                f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
                "'pip install -U torch torchvision' to update both.\n"
                "For a full compatibility table see https://github.com/pytorch/vision#installation"
            )


def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
    """
    Check file(s) for acceptable suffix.

    Args:
        file (str | list[str]): File or list of files to check.
        suffix (str | tuple): Acceptable suffix or tuple of suffixes.
        msg (str): Additional message to display in case of error.
    """
    if file and suffix:
        if isinstance(suffix, str):
            suffix = {suffix}
        for f in file if isinstance(file, (list, tuple)) else [file]:
            if s := str(f).rpartition(".")[-1].lower().strip():  # file suffix
                assert f".{s}" in suffix, f"{msg}{f} acceptable suffix is {suffix}, not .{s}"


def check_yolov5u_filename(file: str, verbose: bool = True):
    """
    Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.

    Args:
        file (str): Filename to check and potentially update.
        verbose (bool): Whether to print information about the replacement.

    Returns:
        (str): Updated filename.
    """
    if "yolov3" in file or "yolov5" in file:
        if "u.yaml" in file:
            file = file.replace("u.yaml", ".yaml")  # i.e. yolov5nu.yaml -> yolov5n.yaml
        elif ".pt" in file and "u" not in file:
            original_file = file
            file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file)  # i.e. yolov5n.pt -> yolov5nu.pt
            file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file)  # i.e. yolov5n6.pt -> yolov5n6u.pt
            file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file)  # i.e. yolov3-spp.pt -> yolov3-sppu.pt
            if file != original_file and verbose:
                LOGGER.info(
                    f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
                    f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
                    f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
                )
    return file


def check_model_file_from_stem(model="yolo11n"):
    """
    Return a model filename from a valid model stem.

    Args:
        model (str): Model stem to check.

    Returns:
        (str | Path): Model filename with appropriate suffix.
    """
    path = Path(model)
    if not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS:
        return path.with_suffix(".pt")  # add suffix, i.e. yolo11n -> yolo11n.pt
    return model


def check_file(file, suffix="", download=True, download_dir=".", hard=True):
    """
    Search/download file (if necessary), check suffix (if provided), and return path.

    Args:
        file (str): File name or path.
        suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.
        download (bool): Whether to download the file if it doesn't exist locally.
        download_dir (str): Directory to download the file to.
        hard (bool): Whether to raise an error if the file is not found.

    Returns:
        (str): Path to the file.
    """
    check_suffix(file, suffix)  # optional
    file = str(file).strip()  # convert to string and strip spaces
    file = check_yolov5u_filename(file)  # yolov5n -> yolov5nu
    if (
        not file
        or ("://" not in file and Path(file).exists())  # '://' check required in Windows Python<3.10
        or file.lower().startswith("grpc://")
    ):  # file exists or gRPC Triton images
        return file
    elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")):  # download
        url = file  # warning: Pathlib turns :// -> :/
        file = Path(download_dir) / url2file(file)  # '%2F' to '/', split https://url.com/file.txt?auth
        if file.exists():
            LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
        else:
            downloads.safe_download(url=url, file=file, unzip=False)
        return str(file)
    else:  # search
        files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file))  # find file
        if not files and hard:
            raise FileNotFoundError(f"'{file}' does not exist")
        elif len(files) > 1 and hard:
            raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
        return files[0] if len(files) else []  # return file


def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
    """
    Search/download YAML file (if necessary) and return path, checking suffix.

    Args:
        file (str | Path): File name or path.
        suffix (tuple): Tuple of acceptable YAML file suffixes.
        hard (bool): Whether to raise an error if the file is not found or multiple files are found.

    Returns:
        (str): Path to the YAML file.
    """
    return check_file(file, suffix, hard=hard)


def check_is_path_safe(basedir, path):
    """
    Check if the resolved path is under the intended directory to prevent path traversal.

    Args:
        basedir (Path | str): The intended directory.
        path (Path | str): The path to check.

    Returns:
        (bool): True if the path is safe, False otherwise.
    """
    base_dir_resolved = Path(basedir).resolve()
    path_resolved = Path(path).resolve()

    return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
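
# Behavior sketch for check_is_path_safe: resolved paths must stay inside the base
# directory, so '..' traversal is rejected. Paths below are illustrative and must exist
# for the function to return True.
# >>> check_is_path_safe("/tmp/datasets", "/tmp/datasets/coco/images")  # True if present
# >>> check_is_path_safe("/tmp/datasets", "/tmp/datasets/../etc/passwd")  # False: escapes basedir
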
@functools.lru_cache
def check_imshow(warn=False):
    """
    Check if environment supports image displays.

    Args:
        warn (bool): Whether to warn if environment doesn't support image displays.

    Returns:
        (bool): True if environment supports image displays, False otherwise.
    """
    try:
        if LINUX:
            assert not IS_COLAB and not IS_KAGGLE
            assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set."
        cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8))  # show a small 8-pixel image
        cv2.waitKey(1)
        cv2.destroyAllWindows()
        cv2.waitKey(1)
        return True
    except Exception as e:
        if warn:
            LOGGER.warning(f"Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
        return False


def check_yolo(verbose=True, device=""):
    """
    Return a human-readable YOLO software and hardware summary.

    Args:
        verbose (bool): Whether to print verbose information.
        device (str | torch.device): Device to use for YOLO.
    """
    import psutil  # scoped as slow import

    from ultralytics.utils.torch_utils import select_device

    if IS_COLAB:
        shutil.rmtree("sample_data", ignore_errors=True)  # remove colab /sample_data directory

    if verbose:
        # System info
        gib = 1 << 30  # bytes per GiB
        ram = psutil.virtual_memory().total
        total, used, free = shutil.disk_usage("/")
        s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
        try:
            from IPython import display

            display.clear_output()  # clear display if notebook
        except ImportError:
            pass
    else:
        s = ""

    if GIT.is_repo:
        check_multiple_install()  # check conflicting installation if using local clone

    select_device(device=device, newline=False)
    LOGGER.info(f"Setup complete ✅ {s}")


def collect_system_info():
    """
    Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.

    Returns:
        (dict): Dictionary containing system information.
    """
    import psutil  # scoped as slow import

    from ultralytics.utils import ENVIRONMENT  # scope to avoid circular import
    from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info

    gib = 1 << 30  # bytes per GiB
    cuda = torch.cuda.is_available()
    check_yolo()
    total, used, free = shutil.disk_usage("/")

    info_dict = {
        "OS": platform.platform(),
        "Environment": ENVIRONMENT,
        "Python": PYTHON_VERSION,
        "Install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
        "Path": str(ROOT),
        "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
        "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
        "CPU": get_cpu_info(),
        "CPU count": os.cpu_count(),
        "GPU": get_gpu_info(index=0) if cuda else None,
        "GPU count": torch.cuda.device_count() if cuda else None,
        "CUDA": torch.version.cuda if cuda else None,
    }
    LOGGER.info("\n" + "\n".join(f"{k:<23}{v}" for k, v in info_dict.items()) + "\n")

    package_info = {}
    for r in parse_requirements(package="ultralytics"):
        try:
            current = metadata.version(r.name)
            is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=True) else "❌ "
        except metadata.PackageNotFoundError:
            current = "(not installed)"
            is_met = "❌ "
        package_info[r.name] = f"{is_met}{current}{r.specifier}"
        LOGGER.info(f"{r.name:<23}{package_info[r.name]}")

    info_dict["Package Info"] = package_info

    if is_github_action_running():
        github_info = {
            "RUNNER_OS": os.getenv("RUNNER_OS"),
            "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
            "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
            "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
            "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
            "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
        }
        LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
        info_dict["GitHub Info"] = github_info

    return info_dict


def check_amp(model):
    """
    Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.

    If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
    results, so AMP will be disabled during training.

    Args:
        model (torch.nn.Module): A YOLO model instance.

    Returns:
        (bool): Returns True if the AMP functionality works correctly with a YOLO11 model, else False.

    Examples:
        >>> from ultralytics import YOLO
        >>> from ultralytics.utils.checks import check_amp
        >>> model = YOLO("yolo11n.pt").model.cuda()
        >>> check_amp(model)
    """
    from ultralytics.utils.torch_utils import autocast

    device = next(model.parameters()).device  # get model device
    prefix = colorstr("AMP: ")
    if device.type in {"cpu", "mps"}:
        return False  # AMP only used on CUDA devices
    else:
        # GPUs that have issues with AMP
        pattern = re.compile(
            r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
        )

        gpu = torch.cuda.get_device_name(device)
        if bool(pattern.search(gpu)):
            LOGGER.warning(
                f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause "
                f"NaN losses or zero-mAP results, so AMP will be disabled during training."
            )
            return False

    def amp_allclose(m, im):
        """All close FP32 vs AMP results."""
        batch = [im] * 8
        imgsz = max(256, int(model.stride.max() * 4))  # max stride P5-32 and P6-64
        a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data  # FP32 inference
        with autocast(enabled=True):
            b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data  # AMP inference
        del m
        return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5)  # close to 0.5 absolute tolerance

    im = ASSETS / "bus.jpg"  # image to check
    LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
    warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
    try:
        from ultralytics import YOLO

        assert amp_allclose(YOLO("yolo11n.pt"), im)
        LOGGER.info(f"{prefix}checks passed ✅")
    except ConnectionError:
        LOGGER.warning(f"{prefix}checks skipped. Offline and unable to download YOLO11n for AMP checks. {warning_msg}")
    except (AttributeError, ModuleNotFoundError):
        LOGGER.warning(
            f"{prefix}checks skipped. "
            f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}"
        )
    except AssertionError:
        LOGGER.error(
            f"{prefix}checks failed. Anomalies were detected with AMP on your system that may lead to "
            f"NaN losses or zero-mAP results, so AMP will be disabled during training."
        )
        return False
    return True


def check_multiple_install():
    """Check if there are multiple Ultralytics installations."""
    import sys

    try:
        result = subprocess.run([sys.executable, "-m", "pip", "show", "ultralytics"], capture_output=True, text=True)
        install_msg = (
            f"Install your local copy in editable mode with 'pip install -e {ROOT.parent}' to avoid "
            "issues. See https://docs.ultralytics.com/quickstart/"
        )
        if result.returncode != 0:
            if "not found" in result.stderr.lower():  # Package not pip-installed but locally imported
                LOGGER.warning(f"Ultralytics not found via pip but importing from: {ROOT}. {install_msg}")
            return
        yolo_path = (Path(re.findall(r"location:\s+(.+)", result.stdout, flags=re.I)[-1]) / "ultralytics").resolve()
        if not yolo_path.samefile(ROOT.resolve()):
            LOGGER.warning(
                f"Multiple Ultralytics installations detected. The `yolo` command uses: {yolo_path}, "
                f"but current session imports from: {ROOT}. This may cause version conflicts. {install_msg}"
            )
    except Exception:
        return


def print_args(args: dict | None = None, show_file=True, show_func=False):
    """
    Print function arguments (optional args dict).

    Args:
        args (dict, optional): Arguments to print.
        show_file (bool): Whether to show the file name.
        show_func (bool): Whether to show the function name.
    """

    def strip_auth(v):
        """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
        return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v

    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in sorted(args.items())))


def cuda_device_count() -> int:
    """
    Get the number of NVIDIA GPUs available in the environment.

    Returns:
        (int): The number of NVIDIA GPUs available.
    """
    if IS_JETSON:
        # NVIDIA Jetson does not fully support nvidia-smi, so we use PyTorch instead
        return torch.cuda.device_count()
    else:
        try:
            # Run the nvidia-smi command and capture its output
            output = subprocess.check_output(
                ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
            )

            # Take the first line and strip any leading/trailing white space
            first_line = output.strip().split("\n", 1)[0]

            return int(first_line)
        except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
            # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
            return 0
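
# Behavior sketch: cuda_device_count shells out to nvidia-smi (except on Jetson), so it
# can report GPUs even before torch initializes CUDA. The output below assumes a
# hypothetical two-GPU host.
# >>> cuda_device_count()  # parses `nvidia-smi --query-gpu=count`; 0 if nvidia-smi is missing
# 2
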
def cuda_is_available() -> bool:
    """
    Check if CUDA is available in the environment.

    Returns:
        (bool): True if one or more NVIDIA GPUs are available, False otherwise.
    """
    return cuda_device_count() > 0


def is_rockchip():
    """
    Check if the current environment is running on a Rockchip SoC.

    Returns:
        (bool): True if running on a Rockchip SoC, False otherwise.
    """
    if LINUX and ARM64:
        try:
            with open("/proc/device-tree/compatible") as f:
                dev_str = f.read()
                *_, soc = dev_str.split(",")
                if soc.replace("\x00", "") in RKNN_CHIPS:
                    return True
        except OSError:
            return False
    else:
        return False


def is_intel():
    """
    Check if the system has Intel hardware (CPU or GPU).

    Returns:
        (bool): True if Intel hardware is detected, False otherwise.
    """
    from ultralytics.utils.torch_utils import get_cpu_info

    # Check CPU
    if "intel" in get_cpu_info().lower():
        return True

    # Check GPU via xpu-smi
    try:
        result = subprocess.run(["xpu-smi", "discovery"], capture_output=True, text=True, timeout=5)
        return "intel" in result.stdout.lower()
    except Exception:  # broad clause to capture all Intel GPU exception types
        return False


def is_sudo_available() -> bool:
    """
    Check if the sudo command is available in the environment.

    Returns:
        (bool): True if the sudo command is available, False otherwise.
    """
    if WINDOWS:
        return False
    cmd = "sudo --version"
    return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0


# Run checks and define constants
check_python("3.8", hard=False, verbose=True)  # check python version
check_torchvision()  # check torch-torchvision compatibility

# Define constants
IS_PYTHON_3_8 = PYTHON_VERSION.startswith("3.8")
IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")
IS_PYTHON_3_13 = PYTHON_VERSION.startswith("3.13")

IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
IS_PYTHON_MINIMUM_3_12 = check_python("3.12", hard=False)
90
ultralytics/utils/cpu.py
Normal file
@@ -0,0 +1,90 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class CPUInfo:
|
||||
"""
|
||||
Provide cross-platform CPU brand and model information.
|
||||
|
||||
Query platform-specific sources to retrieve a human-readable CPU descriptor and normalize it for consistent
|
||||
presentation across macOS, Linux, and Windows. If platform-specific probing fails, generic platform identifiers are
|
||||
used to ensure a stable string is always returned.
|
||||
|
||||
Methods:
|
||||
name: Return the normalized CPU name using platform-specific sources with robust fallbacks.
|
||||
_clean: Normalize and prettify common vendor brand strings and frequency patterns.
|
||||
__str__: Return the normalized CPU name for string contexts.
|
||||
|
||||
Examples:
|
||||
>>> CPUInfo.name()
|
||||
'Apple M4 Pro'
|
||||
>>> str(CPUInfo())
|
||||
'Intel Core i7-9750H 2.60GHz'
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return a normalized CPU model string from platform-specific sources."""
|
||||
try:
|
||||
if sys.platform == "darwin":
|
||||
# Query macOS sysctl for the CPU brand string
|
||||
s = subprocess.run(
|
||||
["sysctl", "-n", "machdep.cpu.brand_string"], capture_output=True, text=True
|
||||
).stdout.strip()
|
||||
if s:
|
||||
return CPUInfo._clean(s)
|
||||
elif sys.platform.startswith("linux"):
|
||||
# Parse /proc/cpuinfo for the first "model name" entry
|
||||
p = Path("/proc/cpuinfo")
|
||||
if p.exists():
|
||||
for line in p.read_text(errors="ignore").splitlines():
|
||||
if "model name" in line:
|
||||
return CPUInfo._clean(line.split(":", 1)[1])
|
||||
elif sys.platform.startswith("win"):
|
||||
try:
|
||||
import winreg as wr
|
||||
|
||||
with wr.OpenKey(wr.HKEY_LOCAL_MACHINE, r"HARDWARE\DESCRIPTION\System\CentralProcessor\0") as k:
|
||||
val, _ = wr.QueryValueEx(k, "ProcessorNameString")
|
||||
if val:
|
||||
return CPUInfo._clean(val)
|
||||
except Exception:
|
||||
# Fall through to generic platform fallbacks on Windows registry access failure
|
||||
pass
|
||||
# Generic platform fallbacks
|
||||
s = platform.processor() or getattr(platform.uname(), "processor", "") or platform.machine()
|
||||
return CPUInfo._clean(s or "Unknown CPU")
|
||||
except Exception:
|
||||
# Ensure a string is always returned even on unexpected failures
|
||||
s = platform.processor() or platform.machine() or ""
|
||||
return CPUInfo._clean(s or "Unknown CPU")
|
||||
|
||||
@staticmethod
|
||||
def _clean(s: str) -> str:
|
||||
"""Normalize and prettify a raw CPU descriptor string."""
|
||||
s = re.sub(r"\s+", " ", s.strip())
|
||||
s = s.replace("(TM)", "").replace("(tm)", "").replace("(R)", "").replace("(r)", "").strip()
|
||||
# Normalize common Intel pattern to 'Model Freq'
|
||||
m = re.search(r"(Intel.*?i\d[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
|
||||
if m:
|
||||
return f"{m.group(1)} {m.group(2)}"
|
||||
# Normalize common AMD Ryzen pattern to 'Model Freq'
|
||||
m = re.search(r"(AMD.*?Ryzen.*?[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
|
||||
if m:
|
||||
return f"{m.group(1)} {m.group(2)}"
|
||||
return s
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return the normalized CPU name."""
|
||||
return self.name()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(CPUInfo.name())
|
||||
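
A minimal sketch (illustrative, not part of the commit) of how `_clean` normalizes a raw vendor string via the Intel pattern handled above:

    raw = "Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz"
    print(CPUInfo._clean(raw))  # -> 'Intel Core i7-9750H 2.60GHz'
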
127
ultralytics/utils/dist.py
Normal file
@@ -0,0 +1,127 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
import shutil
import sys
import tempfile

from . import USER_CONFIG_DIR
from .torch_utils import TORCH_1_9


def find_free_network_port() -> int:
    """
    Find a free port on localhost.

    It is useful in single-node training when we don't want to connect to a real main node but have to set the
    `MASTER_PORT` environment variable.

    Returns:
        (int): The available network port number.
    """
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]  # port


def generate_ddp_file(trainer):
    """
    Generate a DDP (Distributed Data Parallel) file for multi-GPU training.

    This function creates a temporary Python file that enables distributed training across multiple GPUs.
    The file contains the necessary configuration to initialize the trainer in a distributed environment.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.
            Must have args attribute and be a class instance.

    Returns:
        (str): Path to the generated temporary DDP file.

    Notes:
        The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:
        - Trainer class import
        - Configuration overrides from the trainer arguments
        - Model path configuration
        - Training initialization code
    """
    module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)

    content = f"""
# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
overrides = {vars(trainer.args)}

if __name__ == "__main__":
    from {module} import {name}
    from ultralytics.utils import DEFAULT_CFG_DICT

    cfg = DEFAULT_CFG_DICT.copy()
    cfg.update(save_dir='')  # handle the extra key 'save_dir'
    trainer = {name}(cfg=cfg, overrides=overrides)
    trainer.args.model = "{getattr(trainer.hub_session, 'model_url', trainer.args.model)}"
    results = trainer.train()
"""
    (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
    with tempfile.NamedTemporaryFile(
        prefix="_temp_",
        suffix=f"{id(trainer)}.py",
        mode="w+",
        encoding="utf-8",
        dir=USER_CONFIG_DIR / "DDP",
        delete=False,
    ) as file:
        file.write(content)
    return file.name


def generate_ddp_command(trainer):
    """
    Generate command for distributed training.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.

    Returns:
        cmd (list[str]): The command to execute for distributed training.
        file (str): Path to the temporary file created for DDP training.
    """
    import __main__  # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218

    if not trainer.resume:
        shutil.rmtree(trainer.save_dir)  # remove the save_dir
    file = generate_ddp_file(trainer)
    dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
    port = find_free_network_port()
    cmd = [
        sys.executable,
        "-m",
        dist_cmd,
        "--nproc_per_node",
        f"{trainer.world_size}",
        "--master_port",
        f"{port}",
        file,
    ]
    return cmd, file


def ddp_cleanup(trainer, file):
    """
    Delete temporary file if created during distributed data parallel (DDP) training.

    This function checks if the provided file contains the trainer's ID in its name, indicating it was created
    as a temporary file for DDP training, and deletes it if so.

    Args:
        trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.
        file (str): Path to the file that might need to be deleted.

    Examples:
        >>> trainer = YOLOTrainer()
        >>> file = "/tmp/ddp_temp_123456789.py"
        >>> ddp_cleanup(trainer, file)
    """
    if f"{id(trainer)}.py" in file:  # if temp_file suffix in file
        os.remove(file)
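
A sketch of the intended call sequence (illustrative; `trainer` is assumed to be a BaseTrainer instance):

    import subprocess

    cmd, tmp = generate_ddp_command(trainer)  # build the torch.distributed launch command
    subprocess.run(cmd, check=True)           # spawn the distributed workers
    ddp_cleanup(trainer, tmp)                 # delete the temporary DDP file afterwards
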
541
ultralytics/utils/downloads.py
Normal file
@@ -0,0 +1,541 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import re
import shutil
import subprocess
from itertools import repeat
from multiprocessing.pool import ThreadPool
from pathlib import Path
from urllib import parse, request

from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file

# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
GITHUB_ASSETS_REPO = "ultralytics/assets"
GITHUB_ASSETS_NAMES = frozenset(
    [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
    + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
    + [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)]  # detect models only currently
    + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
    + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
    + [f"yolov8{k}-world.pt" for k in "smlx"]
    + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
    + [f"yoloe-v8{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
    + [f"yoloe-11{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
    + [f"yolov9{k}.pt" for k in "tsmce"]
    + [f"yolov10{k}.pt" for k in "nsmblx"]
    + [f"yolo_nas_{k}.pt" for k in "sml"]
    + [f"sam_{k}.pt" for k in "bl"]
    + [f"sam2_{k}.pt" for k in "blst"]
    + [f"sam2.1_{k}.pt" for k in "blst"]
    + [f"FastSAM-{k}.pt" for k in "sx"]
    + [f"rtdetr-{k}.pt" for k in "lx"]
    + [
        "mobile_sam.pt",
        "mobileclip_blt.ts",
        "yolo11n-grayscale.pt",
        "calibration_image_sample_data_20x128x128x3_float32.npy.zip",
    ]
)
GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAMES)
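
A small sanity sketch (illustrative) of membership checks against the asset registry above:

    assert "yolo11n.pt" in GITHUB_ASSETS_NAMES
    assert "yolo11n-seg" in GITHUB_ASSETS_STEMS  # stems drop the .pt suffix
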


def is_url(url: str | Path, check: bool = False) -> bool:
    """
    Validate if the given string is a URL and optionally check if the URL exists online.

    Args:
        url (str): The string to be validated as a URL.
        check (bool, optional): If True, performs an additional check to see if the URL exists online.

    Returns:
        (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.

    Examples:
        >>> valid = is_url("https://www.example.com")
        >>> valid_and_exists = is_url("https://www.example.com", check=True)
    """
    try:
        url = str(url)
        result = parse.urlparse(url)
        assert all([result.scheme, result.netloc])  # check if is url
        if check:
            with request.urlopen(url) as response:
                return response.getcode() == 200  # check if exists online
        return True
    except Exception:
        return False


def delete_dsstore(path: str | Path, files_to_delete: tuple[str, ...] = (".DS_Store", "__MACOSX")) -> None:
    """
    Delete all specified system files in a directory.

    Args:
        path (str | Path): The directory path where the files should be deleted.
        files_to_delete (tuple): The files to be deleted.

    Examples:
        >>> from ultralytics.utils.downloads import delete_dsstore
        >>> delete_dsstore("path/to/dir")

    Notes:
        ".DS_Store" files are created by the Apple operating system and contain metadata about folders and files.
        They are hidden system files and can cause issues when transferring files between different operating
        systems.
    """
    for file in files_to_delete:
        matches = list(Path(path).rglob(file))
        LOGGER.info(f"Deleting {file} files: {matches}")
        for f in matches:
            f.unlink()


def zip_directory(
    directory: str | Path,
    compress: bool = True,
    exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
    progress: bool = True,
) -> Path:
    """
    Zip the contents of a directory, excluding specified files.

    The resulting zip file is named after the directory and placed alongside it.

    Args:
        directory (str | Path): The path to the directory to be zipped.
        compress (bool): Whether to compress the files while zipping.
        exclude (tuple, optional): A tuple of filename strings to be excluded.
        progress (bool, optional): Whether to display a progress bar.

    Returns:
        (Path): The path to the resulting zip file.

    Examples:
        >>> from ultralytics.utils.downloads import zip_directory
        >>> file = zip_directory("path/to/dir")
    """
    from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile

    delete_dsstore(directory)
    directory = Path(directory)
    if not directory.is_dir():
        raise FileNotFoundError(f"Directory '{directory}' does not exist.")

    # Zip with progress bar
    files = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)]  # files to zip
    zip_file = directory.with_suffix(".zip")
    compression = ZIP_DEFLATED if compress else ZIP_STORED
    with ZipFile(zip_file, "w", compression) as f:
        for file in TQDM(files, desc=f"Zipping {directory} to {zip_file}...", unit="files", disable=not progress):
            f.write(file, file.relative_to(directory))

    return zip_file  # return path to zip file


def unzip_file(
    file: str | Path,
    path: str | Path | None = None,
    exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
    exist_ok: bool = False,
    progress: bool = True,
) -> Path:
    """
    Unzip a *.zip file to the specified path, excluding specified files.

    If the zipfile does not contain a single top-level directory, the function will create a new
    directory with the same name as the zipfile (without the extension) to extract its contents.
    If a path is not provided, the function will use the parent directory of the zipfile as the default path.

    Args:
        file (str | Path): The path to the zipfile to be extracted.
        path (str | Path, optional): The path to extract the zipfile to.
        exclude (tuple, optional): A tuple of filename strings to be excluded.
        exist_ok (bool, optional): Whether to overwrite existing contents if they exist.
        progress (bool, optional): Whether to display a progress bar.

    Returns:
        (Path): The path to the directory where the zipfile was extracted.

    Raises:
        BadZipFile: If the provided file does not exist or is not a valid zipfile.

    Examples:
        >>> from ultralytics.utils.downloads import unzip_file
        >>> directory = unzip_file("path/to/file.zip")
    """
    from zipfile import BadZipFile, ZipFile, is_zipfile

    if not (Path(file).exists() and is_zipfile(file)):
        raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
    if path is None:
        path = Path(file).parent  # default path

    # Unzip the file contents
    with ZipFile(file) as zipObj:
        files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
        top_level_dirs = {Path(f).parts[0] for f in files}

        # Decide to unzip directly or unzip into a directory
        unzip_as_dir = len(top_level_dirs) == 1  # (len(files) > 1 and not files[0].endswith("/"))
        if unzip_as_dir:
            # Zip has 1 top-level directory
            extract_path = path  # i.e. ../datasets
            path = Path(path) / list(top_level_dirs)[0]  # i.e. extract coco8/ dir to ../datasets/
        else:
            # Zip has multiple files at top level
            path = extract_path = Path(path) / Path(file).stem  # i.e. extract multiple files to ../datasets/coco8/

        # Check if destination directory already exists and contains files
        if path.exists() and any(path.iterdir()) and not exist_ok:
            # If it exists and is not empty, return the path without unzipping
            LOGGER.warning(f"Skipping {file} unzip as destination directory {path} is not empty.")
            return path

        for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="files", disable=not progress):
            # Ensure the file is within the extract_path to avoid path traversal security vulnerability
            if ".." in Path(f).parts:
                LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
                continue
            zipObj.extract(f, extract_path)

    return path  # return unzip dir
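
A round-trip sketch (illustrative; assumes a local 'coco8' directory exists):

    zf = zip_directory("coco8")                    # -> Path('coco8.zip')
    out = unzip_file(zf, path=".", exist_ok=True)  # -> Path('coco8')
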
def check_disk_space(
    file_bytes: int,
    path: str | Path = Path.cwd(),
    sf: float = 1.5,
    hard: bool = True,
) -> bool:
    """
    Check if there is sufficient disk space to download and store a file.

    Args:
        file_bytes (int): The file size in bytes.
        path (str | Path, optional): The path or drive to check the available free space on.
        sf (float, optional): Safety factor, the multiplier for the required free space.
        hard (bool, optional): Whether to throw an error or not on insufficient disk space.

    Returns:
        (bool): True if there is sufficient disk space, False otherwise.
    """
    gib = 1 << 30  # bytes per GiB
    total, used, free = shutil.disk_usage(path)  # bytes
    if file_bytes * sf < free:
        return True  # sufficient space

    # Insufficient space
    text = (
        f"Insufficient free disk space {free / gib:.3f} GB < {file_bytes * sf / gib:.3f} GB required, "
        f"Please free {(file_bytes * sf - free) / gib:.3f} GB additional disk space and try again."
    )
    if hard:
        raise MemoryError(text)
    LOGGER.warning(text)
    return False
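
A usage sketch (illustrative): verify 1.5x headroom for a 2 GiB download without raising on failure:

    ok = check_disk_space(2 * (1 << 30), path=".", sf=1.5, hard=False)
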
def get_google_drive_file_info(link: str) -> tuple[str, str | None]:
    """
    Retrieve the direct download link and filename for a shareable Google Drive file link.

    Args:
        link (str): The shareable link of the Google Drive file.

    Returns:
        url (str): Direct download URL for the Google Drive file.
        filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.

    Examples:
        >>> from ultralytics.utils.downloads import get_google_drive_file_info
        >>> link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
        >>> url, filename = get_google_drive_file_info(link)
    """
    import requests  # scoped as slow import

    file_id = link.split("/d/")[1].split("/view", 1)[0]
    drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
    filename = None

    # Start session
    with requests.Session() as session:
        response = session.get(drive_url, stream=True)
        if "quota exceeded" in str(response.content.lower()):
            raise ConnectionError(
                emojis(
                    f"❌ Google Drive file download quota exceeded. "
                    f"Please try again later or download this file manually at {link}."
                )
            )
        for k, v in response.cookies.items():
            if k.startswith("download_warning"):
                drive_url += f"&confirm={v}"  # v is token
        if cd := response.headers.get("content-disposition"):
            filename = re.findall('filename="(.+)"', cd)[0]
    return drive_url, filename


def safe_download(
    url: str | Path,
    file: str | Path | None = None,
    dir: str | Path | None = None,
    unzip: bool = True,
    delete: bool = False,
    curl: bool = False,
    retry: int = 3,
    min_bytes: float = 1e0,
    exist_ok: bool = False,
    progress: bool = True,
) -> Path | str:
    """
    Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. Enhanced with
    robust partial download detection using Content-Length validation.

    Args:
        url (str): The URL of the file to be downloaded.
        file (str, optional): The filename of the downloaded file.
            If not provided, the file will be saved with the same name as the URL.
        dir (str | Path, optional): The directory to save the downloaded file.
            If not provided, the file will be saved in the current working directory.
        unzip (bool, optional): Whether to unzip the downloaded file.
        delete (bool, optional): Whether to delete the downloaded file after unzipping.
        curl (bool, optional): Whether to use curl command line tool for downloading.
        retry (int, optional): The number of times to retry the download in case of failure.
        min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be
            considered a successful download.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
        progress (bool, optional): Whether to display a progress bar during the download.

    Returns:
        (Path | str): The path to the downloaded file or extracted directory.

    Examples:
        >>> from ultralytics.utils.downloads import safe_download
        >>> link = "https://ultralytics.com/assets/bus.jpg"
        >>> path = safe_download(link)
    """
    gdrive = url.startswith("https://drive.google.com/")  # check if the URL is a Google Drive link
    if gdrive:
        url, file = get_google_drive_file_info(url)

    f = Path(dir or ".") / (file or url2file(url))  # URL converted to filename
    if "://" not in str(url) and Path(url).is_file():  # URL exists ('://' check required in Windows Python<3.10)
        f = Path(url)  # filename
    elif not f.is_file():  # URL and file do not exist
        uri = (url if gdrive else clean_url(url)).replace(  # cleaned and aliased url
            "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
            "https://ultralytics.com/assets/",  # assets alias
        )
        desc = f"Downloading {uri} to '{f}'"
        f.parent.mkdir(parents=True, exist_ok=True)  # make directory if missing
        curl_installed = shutil.which("curl")
        for i in range(retry + 1):
            try:
                if (curl or i > 0) and curl_installed:  # curl download with retry, continue
                    s = "sS" * (not progress)  # silent
                    r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
                    assert r == 0, f"Curl return value {r}"
                    expected_size = None  # can't get size with curl
                else:  # urllib download
                    with request.urlopen(url) as response:
                        expected_size = int(response.getheader("Content-Length", 0))
                        if i == 0 and expected_size > 1048576:
                            check_disk_space(expected_size, path=f.parent)
                        buffer_size = max(8192, min(1048576, expected_size // 1000)) if expected_size else 8192
                        with TQDM(
                            total=expected_size,
                            desc=desc,
                            disable=not progress,
                            unit="B",
                            unit_scale=True,
                            unit_divisor=1024,
                        ) as pbar:
                            with open(f, "wb") as f_opened:
                                while True:
                                    data = response.read(buffer_size)
                                    if not data:
                                        break
                                    f_opened.write(data)
                                    pbar.update(len(data))

                if f.exists():
                    file_size = f.stat().st_size
                    if file_size > min_bytes:
                        # Check if download is complete (only if we have expected_size)
                        if expected_size and file_size != expected_size:
                            LOGGER.warning(
                                f"Partial download: {file_size}/{expected_size} bytes "
                                f"({file_size / expected_size * 100:.1f}%)"
                            )
                        else:
                            break  # success
                    f.unlink()  # remove partial downloads
            except MemoryError:
                raise  # re-raise immediately - no point retrying if insufficient disk space
            except Exception as e:
                if i == 0 and not is_online():
                    raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e
                elif i >= retry:
                    raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e
                LOGGER.warning(f"Download failure, retrying {i + 1}/{retry} {uri}...")

    if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
        from zipfile import is_zipfile

        unzip_dir = (Path(dir) if dir else f.parent).resolve()  # unzip to dir if provided else unzip in place
        if is_zipfile(f):
            unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress)  # unzip
        elif f.suffix in {".tar", ".gz"}:
            LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
            subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
        if delete:
            f.unlink()  # remove zip
        return unzip_dir
    return f
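
Another usage sketch (illustrative URL): download an archive, extract it into the download directory, and remove the zip:

    d = safe_download("https://ultralytics.com/assets/coco8.zip", dir="datasets", unzip=True, delete=True)
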


def get_github_assets(
    repo: str = "ultralytics/assets",
    version: str = "latest",
    retry: bool = False,
) -> tuple[str, list[str]]:
    """
    Retrieve the specified version's tag and assets from a GitHub repository.

    If the version is not specified, the function fetches the latest release assets.

    Args:
        repo (str, optional): The GitHub repository in the format 'owner/repo'.
        version (str, optional): The release version to fetch assets from.
        retry (bool, optional): Flag to retry the request in case of a failure.

    Returns:
        tag (str): The release tag.
        assets (list[str]): A list of asset names.

    Examples:
        >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
    """
    import requests  # scoped as slow import

    if version != "latest":
        version = f"tags/{version}"  # i.e. tags/v6.2
    url = f"https://api.github.com/repos/{repo}/releases/{version}"
    r = requests.get(url)  # github api
    if r.status_code != 200 and r.reason != "rate limit exceeded" and retry:  # failed and not 403 rate limit exceeded
        r = requests.get(url)  # try again
    if r.status_code != 200:
        LOGGER.warning(f"GitHub assets check failure for {url}: {r.status_code} {r.reason}")
        return "", []
    data = r.json()
    return data["tag_name"], [x["name"] for x in data["assets"]]  # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]


def attempt_download_asset(
    file: str | Path,
    repo: str = "ultralytics/assets",
    release: str = "v8.3.0",
    **kwargs,
) -> str:
    """
    Attempt to download a file from GitHub release assets if it is not found locally.

    Args:
        file (str | Path): The filename or file path to be downloaded.
        repo (str, optional): The GitHub repository in the format 'owner/repo'.
        release (str, optional): The specific release version to be downloaded.
        **kwargs (Any): Additional keyword arguments for the download process.

    Returns:
        (str): The path to the downloaded file.

    Examples:
        >>> file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest")
    """
    from ultralytics.utils import SETTINGS  # scoped for circular import

    # YOLOv3/5u updates
    file = str(file)
    file = checks.check_yolov5u_filename(file)
    file = Path(file.strip().replace("'", ""))
    if file.exists():
        return str(file)
    elif (SETTINGS["weights_dir"] / file).exists():
        return str(SETTINGS["weights_dir"] / file)
    else:
        # URL specified
        name = Path(parse.unquote(str(file))).name  # decode '%2F' to '/' etc.
        download_url = f"https://github.com/{repo}/releases/download"
        if str(file).startswith(("http:/", "https:/")):  # download
            url = str(file).replace(":/", "://")  # Pathlib turns :// -> :/
            file = url2file(name)  # parse authentication https://url.com/file.txt?auth...
            if Path(file).is_file():
                LOGGER.info(f"Found {clean_url(url)} locally at {file}")  # file already exists
            else:
                safe_download(url=url, file=file, min_bytes=1e5, **kwargs)

        elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
            safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)

        else:
            tag, assets = get_github_assets(repo, release)
            if not assets:
                tag, assets = get_github_assets(repo)  # latest release
            if name in assets:
                safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)

        return str(file)


def download(
    url: str | list[str] | Path,
    dir: Path = Path.cwd(),
    unzip: bool = True,
    delete: bool = False,
    curl: bool = False,
    threads: int = 1,
    retry: int = 3,
    exist_ok: bool = False,
) -> None:
    """
    Download files from specified URLs to a given directory.

    Supports concurrent downloads if multiple threads are specified.

    Args:
        url (str | list[str]): The URL or list of URLs of the files to be downloaded.
        dir (Path, optional): The directory where the files will be saved.
        unzip (bool, optional): Flag to unzip the files after downloading.
        delete (bool, optional): Flag to delete the zip files after extraction.
        curl (bool, optional): Flag to use curl for downloading.
        threads (int, optional): Number of threads to use for concurrent downloads.
        retry (int, optional): Number of retries in case of download failure.
        exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.

    Examples:
        >>> download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
    """
    dir = Path(dir)
    dir.mkdir(parents=True, exist_ok=True)  # make directory
    urls = [url] if isinstance(url, (str, Path)) else url
    if threads > 1:
        LOGGER.info(f"Downloading {len(urls)} file(s) with {threads} threads to {dir}...")
        with ThreadPool(threads) as pool:
            pool.map(
                lambda x: safe_download(
                    url=x[0],
                    dir=x[1],
                    unzip=unzip,
                    delete=delete,
                    curl=curl,
                    retry=retry,
                    exist_ok=exist_ok,
                    progress=True,
                ),
                zip(urls, repeat(dir)),
            )
            pool.close()
            pool.join()
    else:
        for u in urls:
            safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
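
A concurrent-download sketch (illustrative URLs):

    download(
        ["https://ultralytics.com/assets/bus.jpg", "https://ultralytics.com/assets/zidane.jpg"],
        dir="images",
        threads=2,
    )
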
43
ultralytics/utils/errors.py
Normal file
@@ -0,0 +1,43 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.utils import emojis


class HUBModelError(Exception):
    """
    Exception raised when a model cannot be found or retrieved from Ultralytics HUB.

    This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.
    The error message is processed to include emojis for better user experience.

    Attributes:
        message (str): The error message displayed when the exception is raised.

    Methods:
        __init__: Initialize the HUBModelError with a custom message.

    Examples:
        >>> try:
        ...     # Code that might fail to find a model
        ...     raise HUBModelError("Custom model not found message")
        ... except HUBModelError as e:
        ...     print(e)  # Displays the emoji-enhanced error message
    """

    def __init__(self, message: str = "Model not found. Please check model URL and try again."):
        """
        Initialize a HUBModelError exception.

        This exception is raised when a requested model is not found or cannot be retrieved from Ultralytics HUB.
        The message is processed to include emojis for better user experience.

        Args:
            message (str, optional): The error message to display when the exception is raised.

        Examples:
            >>> try:
            ...     raise HUBModelError("Custom model error message")
            ... except HUBModelError as e:
            ...     print(e)
        """
        super().__init__(emojis(message))
115
ultralytics/utils/events.py
Normal file
@@ -0,0 +1,115 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import json
import random
import time
from pathlib import Path
from threading import Thread
from urllib.request import Request, urlopen

from ultralytics import SETTINGS, __version__
from ultralytics.utils import ARGV, ENVIRONMENT, GIT, IS_PIP_PACKAGE, ONLINE, PYTHON_VERSION, RANK, TESTS_RUNNING
from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
from ultralytics.utils.torch_utils import get_cpu_info


def _post(url: str, data: dict, timeout: float = 5.0) -> None:
    """Send a one-shot JSON POST request."""
    try:
        body = json.dumps(data, separators=(",", ":")).encode()  # compact JSON
        req = Request(url, data=body, headers={"Content-Type": "application/json"})
        urlopen(req, timeout=timeout).close()
    except Exception:
        pass


class Events:
    """
    Collect and send anonymous usage analytics with rate-limiting.

    Event collection and transmission are enabled when sync is enabled in settings, the current process is rank -1
    or 0, tests are not running, the environment is online, and the installation source is either pip or the
    official Ultralytics GitHub repository.

    Attributes:
        url (str): Measurement Protocol endpoint for receiving anonymous events.
        events (list[dict]): In-memory queue of event payloads awaiting transmission.
        rate_limit (float): Minimum time in seconds between POST requests.
        t (float): Timestamp of the last transmission in seconds since the epoch.
        metadata (dict): Static metadata describing runtime, installation source, and environment.
        enabled (bool): Flag indicating whether analytics collection is active.

    Methods:
        __init__: Initialize the event queue, rate limiter, and runtime metadata.
        __call__: Queue an event and trigger a non-blocking send when the rate limit elapses.
    """

    url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"

    def __init__(self) -> None:
        """Initialize the Events instance with queue, rate limiter, and environment metadata."""
        self.events = []  # pending events
        self.rate_limit = 30.0  # rate limit (seconds)
        self.t = 0.0  # last send timestamp (seconds)
        self.metadata = {
            "cli": Path(ARGV[0]).name == "yolo",
            "install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
            "python": PYTHON_VERSION.rsplit(".", 1)[0],  # i.e. 3.13
            "CPU": get_cpu_info(),
            # "GPU": get_gpu_info(index=0) if cuda else None,
            "version": __version__,
            "env": ENVIRONMENT,
            "session_id": round(random.random() * 1e15),
            "engagement_time_msec": 1000,
        }
        self.enabled = (
            SETTINGS["sync"]
            and RANK in {-1, 0}
            and not TESTS_RUNNING
            and ONLINE
            and (IS_PIP_PACKAGE or GIT.origin == "https://github.com/ultralytics/ultralytics.git")
        )

    def __call__(self, cfg, device=None) -> None:
        """
        Queue an event and flush the queue asynchronously when the rate limit elapses.

        Args:
            cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
            device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').
        """
        if not self.enabled:
            # Events disabled, do nothing
            return

        # Attempt to enqueue a new event
        if len(self.events) < 25:  # queue limited to 25 events to bound memory and traffic
            params = {
                **self.metadata,
                "task": cfg.task,
                "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
                "device": str(device),
            }
            if cfg.mode == "export":
                params["format"] = cfg.format
            self.events.append({"name": cfg.mode, "params": params})

        # Check rate limit and return early if under limit
        t = time.time()
        if (t - self.t) < self.rate_limit:
            return

        # Over the rate limit: send a snapshot of queued events in a background thread
        payload_events = list(self.events)  # snapshot to avoid race with queue reset
        Thread(
            target=_post,
            args=(self.url, {"client_id": SETTINGS["uuid"], "events": payload_events}),  # SHA-256 anonymized
            daemon=True,
        ).start()

        # Reset queue and rate limit timer
        self.events = []
        self.t = t


events = Events()
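
A call sketch (illustrative; `cfg` is assumed to be an IterableSimpleNamespace with task/mode/model set, as the trainers provide):

    events(cfg, device="cuda:0")  # queues one event; flushes at most every 30 s
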
239
ultralytics/utils/export/__init__.py
Normal file
@@ -0,0 +1,239 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import json
from pathlib import Path

import torch

from ultralytics.utils import IS_JETSON, LOGGER

from .imx import torch2imx  # noqa


def torch2onnx(
    torch_model: torch.nn.Module,
    im: torch.Tensor,
    onnx_file: str,
    opset: int = 14,
    input_names: list[str] = ["images"],
    output_names: list[str] = ["output0"],
    dynamic: bool | dict = False,
) -> None:
    """
    Export a PyTorch model to ONNX format.

    Args:
        torch_model (torch.nn.Module): The PyTorch model to export.
        im (torch.Tensor): Example input tensor for the model.
        onnx_file (str): Path to save the exported ONNX file.
        opset (int): ONNX opset version to use for export.
        input_names (list[str]): List of input tensor names.
        output_names (list[str]): List of output tensor names.
        dynamic (bool | dict, optional): Whether to enable dynamic axes.

    Notes:
        Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
    """
    torch.onnx.export(
        torch_model,
        im,
        onnx_file,
        verbose=False,
        opset_version=opset,
        do_constant_folding=True,  # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic or None,
    )
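
A minimal export sketch (illustrative module and file name):

    import torch

    m = torch.nn.Conv2d(3, 16, 3, padding=1).eval()
    torch2onnx(m, torch.zeros(1, 3, 640, 640), "conv.onnx", opset=14)
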

def onnx2engine(
    onnx_file: str,
    engine_file: str | None = None,
    workspace: int | None = None,
    half: bool = False,
    int8: bool = False,
    dynamic: bool = False,
    shape: tuple[int, int, int, int] = (1, 3, 640, 640),
    dla: int | None = None,
    dataset=None,
    metadata: dict | None = None,
    verbose: bool = False,
    prefix: str = "",
) -> None:
    """
    Export a YOLO model to TensorRT engine format.

    Args:
        onnx_file (str): Path to the ONNX file to be converted.
        engine_file (str, optional): Path to save the generated TensorRT engine file.
        workspace (int, optional): Workspace size in GB for TensorRT.
        half (bool, optional): Enable FP16 precision.
        int8 (bool, optional): Enable INT8 precision.
        dynamic (bool, optional): Enable dynamic input shapes.
        shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
        dla (int, optional): DLA core to use (Jetson devices only).
        dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
        metadata (dict, optional): Metadata to include in the engine file.
        verbose (bool, optional): Enable verbose logging.
        prefix (str, optional): Prefix for log messages.

    Raises:
        ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
        RuntimeError: If the ONNX file cannot be parsed.

    Notes:
        TensorRT version compatibility is handled for workspace size and engine building.
        INT8 calibration requires a dataset and generates a calibration cache.
        Metadata is serialized and written to the engine file if provided.
    """
    import tensorrt as trt  # noqa

    engine_file = engine_file or Path(onnx_file).with_suffix(".engine")

    logger = trt.Logger(trt.Logger.INFO)
    if verbose:
        logger.min_severity = trt.Logger.Severity.VERBOSE

    # Engine builder
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    workspace_bytes = int((workspace or 0) * (1 << 30))
    is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10  # is TensorRT >= 10
    if is_trt10 and workspace_bytes > 0:
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    elif workspace_bytes > 0:  # TensorRT versions 7, 8
        config.max_workspace_size = workspace_bytes
    flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(flag)
    half = builder.platform_has_fast_fp16 and half
    int8 = builder.platform_has_fast_int8 and int8

    # Optionally switch to DLA if enabled
    if dla is not None:
        if not IS_JETSON:
            raise ValueError("DLA is only available on NVIDIA Jetson devices")
        LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
        if not half and not int8:
            raise ValueError(
                "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. "
                "Please enable one of them and try again."
            )
        config.default_device_type = trt.DeviceType.DLA
        config.DLA_core = int(dla)
        config.set_flag(trt.BuilderFlag.GPU_FALLBACK)

    # Read ONNX file
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(onnx_file):
        raise RuntimeError(f"failed to load ONNX file: {onnx_file}")

    # Network inputs
    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic:
        profile = builder.create_optimization_profile()
        min_shape = (1, shape[1], 32, 32)  # minimum input shape
        max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:]))  # max input shape
        for inp in inputs:
            profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
        config.add_optimization_profile(profile)
        if int8:
            config.set_calibration_profile(profile)

    LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
    if int8:
        config.set_flag(trt.BuilderFlag.INT8)
        config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED

        class EngineCalibrator(trt.IInt8Calibrator):
            """
            Custom INT8 calibrator for TensorRT engine optimization.

            This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
            using a dataset. It handles batch generation, caching, and calibration algorithm selection.

            Attributes:
                dataset: Dataset for calibration.
                data_iter: Iterator over the calibration dataset.
                algo (trt.CalibrationAlgoType): Calibration algorithm type.
                batch (int): Batch size for calibration.
                cache (Path): Path to save the calibration cache.

            Methods:
                get_algorithm: Get the calibration algorithm to use.
                get_batch_size: Get the batch size to use for calibration.
                get_batch: Get the next batch to use for calibration.
                read_calibration_cache: Use existing cache instead of calibrating again.
                write_calibration_cache: Write calibration cache to disk.
            """

            def __init__(
                self,
                dataset,  # ultralytics.data.build.InfiniteDataLoader
                cache: str = "",
            ) -> None:
                """Initialize the INT8 calibrator with dataset and cache path."""
                trt.IInt8Calibrator.__init__(self)
                self.dataset = dataset
                self.data_iter = iter(dataset)
                self.algo = (
                    trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2  # DLA quantization needs ENTROPY_CALIBRATION_2
                    if dla is not None
                    else trt.CalibrationAlgoType.MINMAX_CALIBRATION
                )
                self.batch = dataset.batch_size
                self.cache = Path(cache)

            def get_algorithm(self) -> trt.CalibrationAlgoType:
                """Get the calibration algorithm to use."""
                return self.algo

            def get_batch_size(self) -> int:
                """Get the batch size to use for calibration."""
                return self.batch or 1

            def get_batch(self, names) -> list[int] | None:
                """Get the next batch to use for calibration, as a list of device memory pointers."""
                try:
                    im0s = next(self.data_iter)["img"] / 255.0
                    im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
                    return [int(im0s.data_ptr())]
                except StopIteration:
                    # Return None to signal to TensorRT there is no calibration data remaining
                    return None

            def read_calibration_cache(self) -> bytes | None:
                """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
                if self.cache.exists() and self.cache.suffix == ".cache":
                    return self.cache.read_bytes()

            def write_calibration_cache(self, cache: bytes) -> None:
                """Write calibration cache to disk."""
                _ = self.cache.write_bytes(cache)

        # Load dataset w/ builder (for batching) and calibrate
        config.int8_calibrator = EngineCalibrator(
            dataset=dataset,
            cache=str(Path(onnx_file).with_suffix(".cache")),
        )

    elif half:
        config.set_flag(trt.BuilderFlag.FP16)

    # Write file
    build = builder.build_serialized_network if is_trt10 else builder.build_engine
    with build(network, config) as engine, open(engine_file, "wb") as t:
        # Metadata
        if metadata is not None:
            meta = json.dumps(metadata)
            t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
            t.write(meta.encode())
        # Model
        t.write(engine if is_trt10 else engine.serialize())
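
A follow-on sketch (illustrative; requires the tensorrt package and a CUDA GPU): build an FP16 engine from the ONNX file exported above:

    onnx2engine("conv.onnx", engine_file="conv.engine", workspace=2, half=True)
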
BIN
ultralytics/utils/export/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/utils/export/__pycache__/imx.cpython-310.pyc
Normal file
Binary file not shown.
289
ultralytics/utils/export/imx.py
Normal file
@@ -0,0 +1,289 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import subprocess
import types
from pathlib import Path

import torch

from ultralytics.nn.modules import Detect, Pose
from ultralytics.utils import LOGGER
from ultralytics.utils.tal import make_anchors
from ultralytics.utils.torch_utils import copy_attr


class FXModel(torch.nn.Module):
    """
    A custom model class for torch.fx compatibility.

    This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
    manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure
    proper copying.

    Attributes:
        model (nn.Module): The original model's layers.
    """

    def __init__(self, model, imgsz=(640, 640)):
        """
        Initialize the FXModel.

        Args:
            model (nn.Module): The original model to wrap for torch.fx compatibility.
            imgsz (tuple[int, int]): The input image size (height, width). Default is (640, 640).
        """
        super().__init__()
        copy_attr(self, model)
        # Explicitly set `model` since `copy_attr` somehow does not copy it.
        self.model = model.model
        self.imgsz = imgsz

    def forward(self, x):
        """
        Forward pass through the model.

        This method performs the forward pass through the model, handling the dependencies between layers and saving
        intermediate outputs.

        Args:
            x (torch.Tensor): The input tensor to the model.

        Returns:
            (torch.Tensor): The output tensor from the model.
        """
        y = []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                # from earlier layers
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
            if isinstance(m, Detect):
                m._inference = types.MethodType(_inference, m)  # bind method to Detect
                m.anchors, m.strides = (
                    x.transpose(0, 1)
                    for x in make_anchors(
                        torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
                    )
                )
                if type(m) is Pose:
                    m.forward = types.MethodType(pose_forward, m)  # bind method to Pose
            x = m(x)  # run
            y.append(x)  # save output
        return x


def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
    """Decode boxes and cls scores for imx object detection."""
    x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
    box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
    dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
    return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)


def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward pass for imx pose estimation, including keypoint decoding."""
    bs = x[0].shape[0]  # batch size
    kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1)  # (bs, 17*3, h*w)
    x = Detect.forward(self, x)
    pred_kpt = self.kpts_decode(bs, kpt)
    return (*x, pred_kpt.permute(0, 2, 1))

class NMSWrapper(torch.nn.Module):
    """Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""

    def __init__(
        self,
        model: torch.nn.Module,
        score_threshold: float = 0.001,
        iou_threshold: float = 0.7,
        max_detections: int = 300,
        task: str = "detect",
    ):
        """
        Initialize NMSWrapper with PyTorch Module and NMS parameters.

        Args:
            model (torch.nn.Module): Model instance.
            score_threshold (float): Score threshold for non-maximum suppression.
            iou_threshold (float): Intersection over union threshold for non-maximum suppression.
            max_detections (int): The number of detections to return.
            task (str): Task type, either 'detect' or 'pose'.
        """
        super().__init__()
        self.model = model
        self.score_threshold = score_threshold
        self.iou_threshold = iou_threshold
        self.max_detections = max_detections
        self.task = task

    def forward(self, images):
        """Forward pass with model inference and NMS post-processing."""
        from sony_custom_layers.pytorch import multiclass_nms_with_indices

        # Model inference
        outputs = self.model(images)
        boxes, scores = outputs[0], outputs[1]
        nms_outputs = multiclass_nms_with_indices(
            boxes=boxes,
            scores=scores,
            score_threshold=self.score_threshold,
            iou_threshold=self.iou_threshold,
            max_detections=self.max_detections,
        )
        if self.task == "pose":
            kpts = outputs[2]  # (bs, max_detections, kpts 17*3)
            out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
            return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
        return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
def torch2imx(
|
||||
model: torch.nn.Module,
|
||||
file: Path | str,
|
||||
conf: float,
|
||||
iou: float,
|
||||
max_det: int,
|
||||
metadata: dict | None = None,
|
||||
gptq: bool = False,
|
||||
dataset=None,
|
||||
prefix: str = "",
|
||||
):
|
||||
"""
|
||||
Export YOLO model to IMX format for deployment on Sony IMX500 devices.
|
||||
|
||||
This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it
|
||||
to IMX format compatible with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n
|
||||
models for detection and pose estimation tasks.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
|
||||
file (Path | str): Output file path for the exported model.
|
||||
conf (float): Confidence threshold for NMS post-processing.
|
||||
iou (float): IoU threshold for NMS post-processing.
|
||||
max_det (int): Maximum number of detections to return.
|
||||
metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
|
||||
gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization.
|
||||
If False, uses standard Post Training Quantization. Defaults to False.
|
||||
dataset (optional): Representative dataset for quantization calibration. Defaults to None.
|
||||
prefix (str, optional): Logging prefix string. Defaults to "".
|
||||
|
||||
Returns:
|
||||
f (Path): Path to the exported IMX model directory
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
|
||||
|
||||
Example:
|
||||
>>> from ultralytics import YOLO
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
>>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
|
||||
|
||||
Note:
|
||||
- Requires model_compression_toolkit, onnx, edgemdt_tpc, and sony_custom_layers packages
|
||||
- Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
|
||||
- Output includes quantized ONNX model, IMX binary, and labels.txt file
|
||||
"""
|
||||
import model_compression_toolkit as mct
|
||||
import onnx
|
||||
from edgemdt_tpc import get_target_platform_capabilities
|
||||
|
||||
LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
|
||||
|
||||
def representative_dataset_gen(dataloader=dataset):
|
||||
for batch in dataloader:
|
||||
img = batch["img"]
|
||||
img = img / 255.0
|
||||
yield [img]
|
||||
|
||||
tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
|
||||
|
||||
bit_cfg = mct.core.BitWidthConfig()
|
||||
if "C2PSA" in model.__str__(): # YOLO11
|
||||
if model.task == "detect":
|
||||
layer_names = ["sub", "mul_2", "add_14", "cat_21"]
|
||||
weights_memory = 2585350.2439
|
||||
n_layers = 238 # 238 layers for fused YOLO11n
|
||||
elif model.task == "pose":
|
||||
layer_names = ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"]
|
||||
weights_memory = 2437771.67
|
||||
n_layers = 257 # 257 layers for fused YOLO11n-pose
|
||||
else: # YOLOv8
|
||||
if model.task == "detect":
|
||||
layer_names = ["sub", "mul", "add_6", "cat_17"]
|
||||
weights_memory = 2550540.8
|
||||
n_layers = 168 # 168 layers for fused YOLOv8n
|
||||
elif model.task == "pose":
|
||||
layer_names = ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"]
|
||||
weights_memory = 2482451.85
|
||||
n_layers = 187 # 187 layers for fused YOLO11n-pose
|
||||
|
||||
# Check if the model has the expected number of layers
|
||||
if len(list(model.modules())) != n_layers:
|
||||
raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
|
||||
|
||||
for layer_name in layer_names:
|
||||
bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
|
||||
|
||||
config = mct.core.CoreConfig(
|
||||
mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
|
||||
quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
|
||||
bit_width_config=bit_cfg,
|
||||
)
|
||||
|
||||
resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
|
||||
|
||||
quant_model = (
|
||||
mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization
|
||||
model=model,
|
||||
representative_data_gen=representative_dataset_gen,
|
||||
target_resource_utilization=resource_utilization,
|
||||
gptq_config=mct.gptq.get_pytorch_gptq_config(
|
||||
n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
|
||||
),
|
||||
core_config=config,
|
||||
target_platform_capabilities=tpc,
|
||||
)[0]
|
||||
if gptq
|
||||
else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization
|
||||
in_module=model,
|
||||
representative_data_gen=representative_dataset_gen,
|
||||
target_resource_utilization=resource_utilization,
|
||||
core_config=config,
|
||||
target_platform_capabilities=tpc,
|
||||
)[0]
|
||||
)
|
||||
|
||||
quant_model = NMSWrapper(
|
||||
model=quant_model,
|
||||
score_threshold=conf or 0.001,
|
||||
iou_threshold=iou,
|
||||
max_detections=max_det,
|
||||
task=model.task,
|
||||
)
|
||||
|
||||
f = Path(str(file).replace(file.suffix, "_imx_model"))
|
||||
f.mkdir(exist_ok=True)
|
||||
onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx")) # js dir
|
||||
mct.exporter.pytorch_export_model(
|
||||
model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
|
||||
)
|
||||
|
||||
model_onnx = onnx.load(onnx_model) # load onnx model
|
||||
for k, v in metadata.items():
|
||||
meta = model_onnx.metadata_props.add()
|
||||
meta.key, meta.value = k, str(v)
|
||||
|
||||
onnx.save(model_onnx, onnx_model)
|
||||
|
||||
subprocess.run(
|
||||
["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Write the labels file required by IMX models.
with open(f / "labels.txt", "w", encoding="utf-8") as fh:
fh.writelines([f"{name}\n" for _, name in model.names.items()])
|
||||
|
||||
return f
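For context, the pipeline above (MCT quantization, ONNX export, imxconv-pt conversion, labels file) is what the exporter runs end to end. A minimal usage sketch, assuming the standard Ultralytics `YOLO.export` API with `format="imx"` and all other arguments left at defaults:

```python
# Minimal sketch: trigger the IMX export path above. Assumes ultralytics plus the
# Sony MCT / imxconv-pt tooling are installed; per the layer-count check above,
# only YOLOv8n and YOLO11n models are supported.
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.export(format="imx")  # quantize with MCT, export ONNX, convert with imxconv-pt
```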
|
||||
223
ultralytics/utils/files.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class WorkingDirectory(contextlib.ContextDecorator):
|
||||
"""
|
||||
A context manager and decorator for temporarily changing the working directory.
|
||||
|
||||
This class allows for the temporary change of the working directory using a context manager or decorator.
|
||||
It ensures that the original working directory is restored after the context or decorated function completes.
|
||||
|
||||
Attributes:
|
||||
dir (Path | str): The new directory to switch to.
|
||||
cwd (Path): The original current working directory before the switch.
|
||||
|
||||
Methods:
|
||||
__enter__: Changes the current directory to the specified directory.
|
||||
__exit__: Restores the original working directory on context exit.
|
||||
|
||||
Examples:
|
||||
Using as a context manager:
|
||||
>>> with WorkingDirectory('/path/to/new/dir'):
|
||||
>>> # Perform operations in the new directory
|
||||
>>> pass
|
||||
|
||||
Using as a decorator:
|
||||
>>> @WorkingDirectory('/path/to/new/dir')
|
||||
>>> def some_function():
|
||||
>>> # Perform operations in the new directory
|
||||
>>> pass
|
||||
"""
|
||||
|
||||
def __init__(self, new_dir: str | Path):
|
||||
"""Initialize the WorkingDirectory context manager with the target directory."""
|
||||
self.dir = new_dir # new dir
|
||||
self.cwd = Path.cwd().resolve() # current dir
|
||||
|
||||
def __enter__(self):
|
||||
"""Change the current working directory to the specified directory upon entering the context."""
|
||||
os.chdir(self.dir)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb): # noqa
|
||||
"""Restore the original working directory when exiting the context."""
|
||||
os.chdir(self.cwd)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def spaces_in_path(path: str | Path):
|
||||
"""
|
||||
Context manager to handle paths with spaces in their names.
|
||||
|
||||
If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes
|
||||
the context code block, then copies the file/directory back to its original location.
|
||||
|
||||
Args:
|
||||
path (str | Path): The original path that may contain spaces.
|
||||
|
||||
Yields:
|
||||
(Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the
|
||||
original path.
|
||||
|
||||
Examples:
|
||||
>>> with spaces_in_path('/path/with spaces') as new_path:
|
||||
>>> # Your code here
|
||||
>>> pass
|
||||
"""
|
||||
# If path has spaces, replace them with underscores
|
||||
if " " in str(path):
|
||||
string = isinstance(path, str) # input type
|
||||
path = Path(path)
|
||||
|
||||
# Create a temporary directory and construct the new path
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
|
||||
|
||||
# Copy file/directory
|
||||
if path.is_dir():
|
||||
shutil.copytree(path, tmp_path)
|
||||
elif path.is_file():
|
||||
tmp_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(path, tmp_path)
|
||||
|
||||
try:
|
||||
# Yield the temporary path
|
||||
yield str(tmp_path) if string else tmp_path
|
||||
|
||||
finally:
|
||||
# Copy file/directory back
|
||||
if tmp_path.is_dir():
|
||||
shutil.copytree(tmp_path, path, dirs_exist_ok=True)
|
||||
elif tmp_path.is_file():
|
||||
shutil.copy2(tmp_path, path) # Copy back the file
|
||||
|
||||
else:
|
||||
# If there are no spaces, just yield the original path
|
||||
yield path
|
||||
|
||||
|
||||
def increment_path(path: str | Path, exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
|
||||
"""
|
||||
Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
|
||||
If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
|
||||
the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
|
||||
number will be appended directly to the end of the path.
|
||||
|
||||
Args:
|
||||
path (str | Path): Path to increment.
|
||||
exist_ok (bool, optional): If True, the path is returned as-is without incrementing.
|
||||
sep (str, optional): Separator to use between the path and the incrementation number.
|
||||
mkdir (bool, optional): Create a directory if it does not exist.
|
||||
|
||||
Returns:
|
||||
(Path): Incremented path.
|
||||
|
||||
Examples:
|
||||
Increment a directory path:
|
||||
>>> from pathlib import Path
|
||||
>>> path = Path("runs/exp")
|
||||
>>> new_path = increment_path(path)
|
||||
>>> print(new_path)
|
||||
runs/exp2
|
||||
|
||||
Increment a file path:
|
||||
>>> path = Path("runs/exp/results.txt")
|
||||
>>> new_path = increment_path(path)
|
||||
>>> print(new_path)
|
||||
runs/exp/results2.txt
|
||||
"""
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
|
||||
|
||||
# Find the first non-existing incremented path
|
||||
for n in range(2, 9999):
|
||||
p = f"{path}{sep}{n}{suffix}" # increment path
|
||||
if not os.path.exists(p):
|
||||
break
|
||||
path = Path(p)
|
||||
|
||||
if mkdir:
|
||||
path.mkdir(parents=True, exist_ok=True) # make directory
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def file_age(path: str | Path = __file__) -> int:
|
||||
"""Return days since the last modification of the specified file."""
|
||||
dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
|
||||
return dt.days # + dt.seconds / 86400 # fractional days
|
||||
|
||||
|
||||
def file_date(path: str | Path = __file__) -> str:
|
||||
"""Return the file modification date in 'YYYY-M-D' format."""
|
||||
t = datetime.fromtimestamp(Path(path).stat().st_mtime)
|
||||
return f"{t.year}-{t.month}-{t.day}"
|
||||
|
||||
|
||||
def file_size(path: str | Path) -> float:
|
||||
"""Return the size of a file or directory in megabytes (MB)."""
|
||||
if isinstance(path, (str, Path)):
|
||||
mb = 1 << 20 # bytes to MiB (1024 ** 2)
|
||||
path = Path(path)
|
||||
if path.is_file():
|
||||
return path.stat().st_size / mb
|
||||
elif path.is_dir():
|
||||
return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
|
||||
return 0.0
|
||||
|
||||
|
||||
def get_latest_run(search_dir: str = ".") -> str:
|
||||
"""Return the path to the most recent 'last.pt' file in the specified directory for resuming training."""
|
||||
last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
|
||||
return max(last_list, key=os.path.getctime) if last_list else ""
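A quick usage sketch for the helpers above (paths and printed values are illustrative):

```python
# Illustrative calls to the file utilities defined above.
print(file_age(__file__))             # e.g. 3 (days since last modification)
print(file_date(__file__))            # e.g. '2024-3-1'
print(f"{file_size(__file__):.3f} MB")
print(get_latest_run("runs") or "no checkpoints found")
```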
|
||||
|
||||
|
||||
def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), update_names: bool = False):
|
||||
"""
|
||||
Update and re-save specified YOLO models in an 'updated_models' subdirectory.
|
||||
|
||||
Args:
|
||||
model_names (tuple, optional): Model filenames to update.
|
||||
source_dir (Path, optional): Directory containing models and target subdirectory.
|
||||
update_names (bool, optional): Update model names from a data YAML.
|
||||
|
||||
Examples:
|
||||
Update specified YOLO models and save them in 'updated_models' subdirectory:
|
||||
>>> from ultralytics.utils.files import update_models
|
||||
>>> model_names = ("yolo11n.pt", "yolov8s.pt")
|
||||
>>> update_models(model_names, source_dir=Path("/models"), update_names=True)
|
||||
"""
|
||||
from ultralytics import YOLO
|
||||
from ultralytics.nn.autobackend import default_class_names
|
||||
|
||||
target_dir = source_dir / "updated_models"
|
||||
target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
|
||||
|
||||
for model_name in model_names:
|
||||
model_path = source_dir / model_name
|
||||
print(f"Loading model from {model_path}")
|
||||
|
||||
# Load model
|
||||
model = YOLO(model_path)
|
||||
model.half()
|
||||
if update_names: # update model names from a dataset YAML
|
||||
model.model.names = default_class_names("coco8.yaml")
|
||||
|
||||
# Define new save path
|
||||
save_path = target_dir / model_name
|
||||
|
||||
# Save model using model.save()
|
||||
print(f"Re-saving {model_name} model to {save_path}")
|
||||
model.save(save_path)
|
||||
139
ultralytics/utils/git.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class GitRepo:
|
||||
"""
|
||||
Represent a local Git repository and expose branch, commit, and remote metadata.
|
||||
|
||||
This class discovers the repository root by searching for a .git entry from the given path upward, resolves the
|
||||
actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not
|
||||
invoke the git binary and therefore works in restricted environments. All metadata properties are resolved
|
||||
lazily and cached; construct a new instance to refresh state.
|
||||
|
||||
Attributes:
|
||||
root (Path | None): Repository root directory containing the .git entry; None if not in a repository.
|
||||
gitdir (Path | None): Resolved .git directory path; handles worktrees; None if unresolved.
|
||||
head (str | None): Raw contents of HEAD; a SHA for detached HEAD or "ref: <refname>" for branch heads.
|
||||
is_repo (bool): Whether the provided path resides inside a Git repository.
|
||||
branch (str | None): Current branch name when HEAD points to a branch; None for detached HEAD or non-repo.
|
||||
commit (str | None): Current commit SHA for HEAD; None if not determinable.
|
||||
origin (str | None): URL of the "origin" remote as read from gitdir/config; None if unset or unavailable.
|
||||
|
||||
Examples:
|
||||
Initialize from the current working directory and read metadata
|
||||
>>> from pathlib import Path
|
||||
>>> repo = GitRepo(Path.cwd())
|
||||
>>> repo.is_repo
|
||||
True
|
||||
>>> repo.branch, repo.commit[:7], repo.origin
|
||||
('main', '1a2b3c4', 'https://example.com/owner/repo.git')
|
||||
|
||||
Notes:
|
||||
- Resolves metadata by reading files: HEAD, packed-refs, and config; no subprocess calls are used.
|
||||
- Caches properties on first access using cached_property; recreate the object to reflect repository changes.
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path = Path(__file__).resolve()):
|
||||
"""
|
||||
Initialize a Git repository context by discovering the repository root from a starting path.
|
||||
|
||||
Args:
|
||||
path (Path, optional): File or directory path used as the starting point to locate the repository root.
|
||||
"""
|
||||
self.root = self._find_root(path)
|
||||
self.gitdir = self._gitdir(self.root) if self.root else None
|
||||
|
||||
@staticmethod
|
||||
def _find_root(p: Path) -> Path | None:
|
||||
"""Return repo root or None."""
|
||||
return next((d for d in [p] + list(p.parents) if (d / ".git").exists()), None)
|
||||
|
||||
@staticmethod
|
||||
def _gitdir(root: Path) -> Path | None:
|
||||
"""Resolve actual .git directory (handles worktrees)."""
|
||||
g = root / ".git"
|
||||
if g.is_dir():
|
||||
return g
|
||||
if g.is_file():
|
||||
t = g.read_text(errors="ignore").strip()
|
||||
if t.startswith("gitdir:"):
|
||||
return (root / t.split(":", 1)[1].strip()).resolve()
|
||||
return None
|
||||
|
||||
def _read(self, p: Path | None) -> str | None:
|
||||
"""Read and strip file if exists."""
|
||||
return p.read_text(errors="ignore").strip() if p and p.exists() else None
|
||||
|
||||
@cached_property
|
||||
def head(self) -> str | None:
|
||||
"""HEAD file contents."""
|
||||
return self._read(self.gitdir / "HEAD" if self.gitdir else None)
|
||||
|
||||
def _ref_commit(self, ref: str) -> str | None:
|
||||
"""Commit for ref (handles packed-refs)."""
|
||||
rf = self.gitdir / ref
|
||||
s = self._read(rf)
|
||||
if s:
|
||||
return s
|
||||
pf = self.gitdir / "packed-refs"
|
||||
b = pf.read_bytes().splitlines() if pf.exists() else []
|
||||
tgt = ref.encode()
|
||||
for line in b:
|
||||
if line[:1] in (b"#", b"^") or b" " not in line:
|
||||
continue
|
||||
sha, name = line.split(b" ", 1)
|
||||
if name.strip() == tgt:
|
||||
return sha.decode()
|
||||
return None
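For reference, a packed-refs file stores one `<sha> <refname>` pair per line; the parser above skips `#` header lines and `^` peeled-tag entries. A self-contained sketch of the same matching logic on hypothetical contents:

```python
# Hypothetical packed-refs contents; the SHA and ref name are made up.
packed_refs = b"""# pack-refs with: peeled fully-peeled sorted
1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0b refs/heads/main
"""
for line in packed_refs.splitlines():
    if line[:1] in (b"#", b"^") or b" " not in line:  # skip headers and peeled tags
        continue
    sha, name = line.split(b" ", 1)
    if name.strip() == b"refs/heads/main":
        print(sha.decode())  # -> 1a2b3c4d...
```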
|
||||
|
||||
@property
|
||||
def is_repo(self) -> bool:
|
||||
"""True if inside a git repo."""
|
||||
return self.gitdir is not None
|
||||
|
||||
@cached_property
|
||||
def branch(self) -> str | None:
|
||||
"""Current branch or None."""
|
||||
if not self.is_repo or not self.head or not self.head.startswith("ref: "):
|
||||
return None
|
||||
ref = self.head[5:].strip()
|
||||
return ref[len("refs/heads/") :] if ref.startswith("refs/heads/") else ref
|
||||
|
||||
@cached_property
|
||||
def commit(self) -> str | None:
|
||||
"""Current commit SHA or None."""
|
||||
if not self.is_repo or not self.head:
|
||||
return None
|
||||
return self._ref_commit(self.head[5:].strip()) if self.head.startswith("ref: ") else self.head
|
||||
|
||||
@cached_property
|
||||
def origin(self) -> str | None:
|
||||
"""Origin URL or None."""
|
||||
if not self.is_repo:
|
||||
return None
|
||||
cfg = self.gitdir / "config"
|
||||
remote, url = None, None
|
||||
for s in (self._read(cfg) or "").splitlines():
|
||||
t = s.strip()
|
||||
if t.startswith("[") and t.endswith("]"):
|
||||
remote = t.lower()
|
||||
elif t.lower().startswith("url =") and remote == '[remote "origin"]':
|
||||
url = t.split("=", 1)[1].strip()
|
||||
break
|
||||
return url
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
g = GitRepo()
|
||||
if g.is_repo:
|
||||
t0 = time.perf_counter()
|
||||
print(f"repo={g.root}\nbranch={g.branch}\ncommit={g.commit}\norigin={g.origin}")
|
||||
dt = (time.perf_counter() - t0) * 1000
|
||||
print(f"\n⏱️ Profiling: total {dt:.3f} ms")
|
||||
505
ultralytics/utils/instance.py
Normal file
@@ -0,0 +1,505 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import abc
|
||||
from itertools import repeat
|
||||
from numbers import Number
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
"""Create a function that converts input to n-tuple by repeating singleton values."""
|
||||
|
||||
def parse(x):
|
||||
"""Parse input to return n-tuple by repeating singleton values n times."""
|
||||
return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
|
||||
|
||||
return parse
|
||||
|
||||
|
||||
to_2tuple = _ntuple(2)
|
||||
to_4tuple = _ntuple(4)
|
||||
|
||||
# `xyxy` means left top and right bottom
# `xywh` means center x, center y and width, height (YOLO format)
# `ltwh` means left top and width, height (COCO format)
|
||||
_formats = ["xyxy", "xywh", "ltwh"]
|
||||
|
||||
__all__ = ("Bboxes", "Instances") # tuple or list
|
||||
|
||||
|
||||
class Bboxes:
|
||||
"""
|
||||
A class for handling bounding boxes in multiple formats.
|
||||
|
||||
The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format
|
||||
conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.
|
||||
|
||||
Attributes:
|
||||
bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).
|
||||
format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
|
||||
|
||||
Methods:
|
||||
convert: Convert bounding box format from one type to another.
|
||||
areas: Calculate the area of bounding boxes.
|
||||
mul: Multiply bounding box coordinates by scale factor(s).
|
||||
add: Add offset to bounding box coordinates.
|
||||
concatenate: Concatenate multiple Bboxes objects.
|
||||
|
||||
Examples:
|
||||
Create bounding boxes in YOLO format
|
||||
>>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format="xywh")
|
||||
>>> bboxes.convert("xyxy")
|
||||
>>> print(bboxes.areas())
|
||||
|
||||
Notes:
|
||||
This class does not handle normalization or denormalization of bounding boxes.
|
||||
"""
|
||||
|
||||
def __init__(self, bboxes: np.ndarray, format: str = "xyxy") -> None:
|
||||
"""
|
||||
Initialize the Bboxes class with bounding box data in a specified format.
|
||||
|
||||
Args:
|
||||
bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).
|
||||
format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'.
|
||||
"""
|
||||
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
|
||||
bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
|
||||
assert bboxes.ndim == 2
|
||||
assert bboxes.shape[1] == 4
|
||||
self.bboxes = bboxes
|
||||
self.format = format
|
||||
|
||||
def convert(self, format: str) -> None:
|
||||
"""
|
||||
Convert bounding box format from one type to another.
|
||||
|
||||
Args:
|
||||
format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
|
||||
"""
|
||||
assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
|
||||
if self.format == format:
|
||||
return
|
||||
elif self.format == "xyxy":
|
||||
func = xyxy2xywh if format == "xywh" else xyxy2ltwh
|
||||
elif self.format == "xywh":
|
||||
func = xywh2xyxy if format == "xyxy" else xywh2ltwh
|
||||
else:
|
||||
func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
|
||||
self.bboxes = func(self.bboxes)
|
||||
self.format = format
|
||||
|
||||
def areas(self) -> np.ndarray:
|
||||
"""Calculate the area of bounding boxes."""
|
||||
return (
|
||||
(self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy
|
||||
if self.format == "xyxy"
|
||||
else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh
|
||||
)
|
||||
|
||||
def mul(self, scale: int | tuple | list) -> None:
|
||||
"""
|
||||
Multiply bounding box coordinates by scale factor(s).
|
||||
|
||||
Args:
|
||||
scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to
|
||||
all coordinates.
|
||||
"""
|
||||
if isinstance(scale, Number):
|
||||
scale = to_4tuple(scale)
|
||||
assert isinstance(scale, (tuple, list))
|
||||
assert len(scale) == 4
|
||||
self.bboxes[:, 0] *= scale[0]
|
||||
self.bboxes[:, 1] *= scale[1]
|
||||
self.bboxes[:, 2] *= scale[2]
|
||||
self.bboxes[:, 3] *= scale[3]
|
||||
|
||||
def add(self, offset: int | tuple | list) -> None:
|
||||
"""
|
||||
Add offset to bounding box coordinates.
|
||||
|
||||
Args:
|
||||
offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to
|
||||
all coordinates.
|
||||
"""
|
||||
if isinstance(offset, Number):
|
||||
offset = to_4tuple(offset)
|
||||
assert isinstance(offset, (tuple, list))
|
||||
assert len(offset) == 4
|
||||
self.bboxes[:, 0] += offset[0]
|
||||
self.bboxes[:, 1] += offset[1]
|
||||
self.bboxes[:, 2] += offset[2]
|
||||
self.bboxes[:, 3] += offset[3]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of bounding boxes."""
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, boxes_list: list[Bboxes], axis: int = 0) -> Bboxes:
|
||||
"""
|
||||
Concatenate a list of Bboxes objects into a single Bboxes object.
|
||||
|
||||
Args:
|
||||
boxes_list (list[Bboxes]): A list of Bboxes objects to concatenate.
|
||||
axis (int, optional): The axis along which to concatenate the bounding boxes.
|
||||
|
||||
Returns:
|
||||
(Bboxes): A new Bboxes object containing the concatenated bounding boxes.
|
||||
|
||||
Notes:
|
||||
The input should be a list or tuple of Bboxes objects.
|
||||
"""
|
||||
assert isinstance(boxes_list, (list, tuple))
|
||||
if not boxes_list:
|
||||
return cls(np.empty((0, 4)))  # empty (0, 4) array satisfies the shape asserts in __init__
|
||||
assert all(isinstance(box, Bboxes) for box in boxes_list)
|
||||
|
||||
if len(boxes_list) == 1:
|
||||
return boxes_list[0]
|
||||
return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
|
||||
|
||||
def __getitem__(self, index: int | np.ndarray | slice) -> Bboxes:
|
||||
"""
|
||||
Retrieve a specific bounding box or a set of bounding boxes using indexing.
|
||||
|
||||
Args:
|
||||
index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.
|
||||
|
||||
Returns:
|
||||
(Bboxes): A new Bboxes object containing the selected bounding boxes.
|
||||
|
||||
Notes:
|
||||
When using boolean indexing, make sure to provide a boolean array with the same length as the number of
|
||||
bounding boxes.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
return Bboxes(self.bboxes[index].reshape(1, -1))
|
||||
b = self.bboxes[index]
|
||||
assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
|
||||
return Bboxes(b)
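A short usage sketch for `Bboxes` (values are illustrative; all operations except `concatenate` modify the object in place):

```python
import numpy as np

boxes = Bboxes(np.array([[100.0, 50.0, 150.0, 100.0]], dtype=np.float32), format="xyxy")
boxes.convert("xywh")                        # now (cx, cy, w, h)
boxes.mul(2)                                 # scale all four coordinates by 2
print(boxes.areas(), len(boxes))             # [10000.] 1
merged = Bboxes.concatenate([boxes, boxes])  # shape (2, 4); note the result is labeled with the default "xyxy" format
```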
|
||||
|
||||
|
||||
class Instances:
|
||||
"""
|
||||
Container for bounding boxes, segments, and keypoints of detected objects in an image.
|
||||
|
||||
This class provides a unified interface for handling different types of object annotations including bounding
|
||||
boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,
|
||||
and format conversion.
|
||||
|
||||
Attributes:
|
||||
_bboxes (Bboxes): Internal object for handling bounding box operations.
|
||||
keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).
|
||||
normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
|
||||
segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.
|
||||
|
||||
Methods:
|
||||
convert_bbox: Convert bounding box format.
|
||||
scale: Scale coordinates by given factors.
|
||||
denormalize: Convert normalized coordinates to absolute coordinates.
|
||||
normalize: Convert absolute coordinates to normalized coordinates.
|
||||
add_padding: Add padding to coordinates.
|
||||
flipud: Flip coordinates vertically.
|
||||
fliplr: Flip coordinates horizontally.
|
||||
clip: Clip coordinates to stay within image boundaries.
|
||||
remove_zero_area_boxes: Remove boxes with zero area.
|
||||
update: Update instance variables.
|
||||
concatenate: Concatenate multiple Instances objects.
|
||||
|
||||
Examples:
|
||||
Create instances with bounding boxes and segments
|
||||
>>> instances = Instances(
|
||||
... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
|
||||
... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
|
||||
... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bboxes: np.ndarray,
|
||||
segments: np.ndarray | None = None,
keypoints: np.ndarray | None = None,
|
||||
bbox_format: str = "xywh",
|
||||
normalized: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Instances object with bounding boxes, segments, and keypoints.
|
||||
|
||||
Args:
|
||||
bboxes (np.ndarray): Bounding boxes with shape (N, 4).
|
||||
segments (np.ndarray, optional): Segmentation masks.
|
||||
keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).
|
||||
bbox_format (str): Format of bboxes.
|
||||
normalized (bool): Whether the coordinates are normalized.
|
||||
"""
|
||||
self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
|
||||
self.keypoints = keypoints
|
||||
self.normalized = normalized
|
||||
self.segments = segments
|
||||
|
||||
def convert_bbox(self, format: str) -> None:
|
||||
"""
|
||||
Convert bounding box format.
|
||||
|
||||
Args:
|
||||
format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
|
||||
"""
|
||||
self._bboxes.convert(format=format)
|
||||
|
||||
@property
|
||||
def bbox_areas(self) -> np.ndarray:
|
||||
"""Calculate the area of bounding boxes."""
|
||||
return self._bboxes.areas()
|
||||
|
||||
def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False):
|
||||
"""
|
||||
Scale coordinates by given factors.
|
||||
|
||||
Args:
|
||||
scale_w (float): Scale factor for width.
|
||||
scale_h (float): Scale factor for height.
|
||||
bbox_only (bool, optional): Whether to scale only bounding boxes.
|
||||
"""
|
||||
self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
|
||||
if bbox_only:
|
||||
return
|
||||
self.segments[..., 0] *= scale_w
|
||||
self.segments[..., 1] *= scale_h
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] *= scale_w
|
||||
self.keypoints[..., 1] *= scale_h
|
||||
|
||||
def denormalize(self, w: int, h: int) -> None:
|
||||
"""
|
||||
Convert normalized coordinates to absolute coordinates.
|
||||
|
||||
Args:
|
||||
w (int): Image width.
|
||||
h (int): Image height.
|
||||
"""
|
||||
if not self.normalized:
|
||||
return
|
||||
self._bboxes.mul(scale=(w, h, w, h))
|
||||
self.segments[..., 0] *= w
|
||||
self.segments[..., 1] *= h
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] *= w
|
||||
self.keypoints[..., 1] *= h
|
||||
self.normalized = False
|
||||
|
||||
def normalize(self, w: int, h: int) -> None:
|
||||
"""
|
||||
Convert absolute coordinates to normalized coordinates.
|
||||
|
||||
Args:
|
||||
w (int): Image width.
|
||||
h (int): Image height.
|
||||
"""
|
||||
if self.normalized:
|
||||
return
|
||||
self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
|
||||
self.segments[..., 0] /= w
|
||||
self.segments[..., 1] /= h
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] /= w
|
||||
self.keypoints[..., 1] /= h
|
||||
self.normalized = True
|
||||
|
||||
def add_padding(self, padw: int, padh: int) -> None:
|
||||
"""
|
||||
Add padding to coordinates.
|
||||
|
||||
Args:
|
||||
padw (int): Padding width.
|
||||
padh (int): Padding height.
|
||||
"""
|
||||
assert not self.normalized, "Padding must be added with absolute coordinates."
|
||||
self._bboxes.add(offset=(padw, padh, padw, padh))
|
||||
self.segments[..., 0] += padw
|
||||
self.segments[..., 1] += padh
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] += padw
|
||||
self.keypoints[..., 1] += padh
|
||||
|
||||
def __getitem__(self, index: int | np.ndarray | slice) -> Instances:
|
||||
"""
|
||||
Retrieve a specific instance or a set of instances using indexing.
|
||||
|
||||
Args:
|
||||
index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.
|
||||
|
||||
Returns:
|
||||
(Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.
|
||||
|
||||
Notes:
|
||||
When using boolean indexing, make sure to provide a boolean array with the same length as the number of
|
||||
instances.
|
||||
"""
|
||||
segments = self.segments[index] if len(self.segments) else self.segments
|
||||
keypoints = self.keypoints[index] if self.keypoints is not None else None
|
||||
bboxes = self.bboxes[index]
|
||||
bbox_format = self._bboxes.format
|
||||
return Instances(
|
||||
bboxes=bboxes,
|
||||
segments=segments,
|
||||
keypoints=keypoints,
|
||||
bbox_format=bbox_format,
|
||||
normalized=self.normalized,
|
||||
)
|
||||
|
||||
def flipud(self, h: int) -> None:
|
||||
"""
|
||||
Flip coordinates vertically.
|
||||
|
||||
Args:
|
||||
h (int): Image height.
|
||||
"""
|
||||
if self._bboxes.format == "xyxy":
|
||||
y1 = self.bboxes[:, 1].copy()
|
||||
y2 = self.bboxes[:, 3].copy()
|
||||
self.bboxes[:, 1] = h - y2
|
||||
self.bboxes[:, 3] = h - y1
|
||||
else:
|
||||
self.bboxes[:, 1] = h - self.bboxes[:, 1]
|
||||
self.segments[..., 1] = h - self.segments[..., 1]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 1] = h - self.keypoints[..., 1]
|
||||
|
||||
def fliplr(self, w: int) -> None:
|
||||
"""
|
||||
Flip coordinates horizontally.
|
||||
|
||||
Args:
|
||||
w (int): Image width.
|
||||
"""
|
||||
if self._bboxes.format == "xyxy":
|
||||
x1 = self.bboxes[:, 0].copy()
|
||||
x2 = self.bboxes[:, 2].copy()
|
||||
self.bboxes[:, 0] = w - x2
|
||||
self.bboxes[:, 2] = w - x1
|
||||
else:
|
||||
self.bboxes[:, 0] = w - self.bboxes[:, 0]
|
||||
self.segments[..., 0] = w - self.segments[..., 0]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints[..., 0] = w - self.keypoints[..., 0]
|
||||
|
||||
def clip(self, w: int, h: int) -> None:
|
||||
"""
|
||||
Clip coordinates to stay within image boundaries.
|
||||
|
||||
Args:
|
||||
w (int): Image width.
|
||||
h (int): Image height.
|
||||
"""
|
||||
ori_format = self._bboxes.format
|
||||
self.convert_bbox(format="xyxy")
|
||||
self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
|
||||
self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
|
||||
if ori_format != "xyxy":
|
||||
self.convert_bbox(format=ori_format)
|
||||
self.segments[..., 0] = self.segments[..., 0].clip(0, w)
|
||||
self.segments[..., 1] = self.segments[..., 1].clip(0, h)
|
||||
if self.keypoints is not None:
|
||||
# Set out of bounds visibility to zero
|
||||
self.keypoints[..., 2][
|
||||
(self.keypoints[..., 0] < 0)
|
||||
| (self.keypoints[..., 0] > w)
|
||||
| (self.keypoints[..., 1] < 0)
|
||||
| (self.keypoints[..., 1] > h)
|
||||
] = 0.0
|
||||
self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
|
||||
self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
|
||||
|
||||
def remove_zero_area_boxes(self) -> np.ndarray:
|
||||
"""
|
||||
Remove zero-area boxes, i.e. boxes whose width or height became zero, e.g. after clipping.
|
||||
|
||||
Returns:
|
||||
(np.ndarray): Boolean array indicating which boxes were kept.
|
||||
"""
|
||||
good = self.bbox_areas > 0
|
||||
if not all(good):
|
||||
self._bboxes = self._bboxes[good]
|
||||
if len(self.segments):
|
||||
self.segments = self.segments[good]
|
||||
if self.keypoints is not None:
|
||||
self.keypoints = self.keypoints[good]
|
||||
return good
|
||||
|
||||
def update(self, bboxes: np.ndarray, segments: np.ndarray | None = None, keypoints: np.ndarray | None = None):
|
||||
"""
|
||||
Update instance variables.
|
||||
|
||||
Args:
|
||||
bboxes (np.ndarray): New bounding boxes.
|
||||
segments (np.ndarray, optional): New segments.
|
||||
keypoints (np.ndarray, optional): New keypoints.
|
||||
"""
|
||||
self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
|
||||
if segments is not None:
|
||||
self.segments = segments
|
||||
if keypoints is not None:
|
||||
self.keypoints = keypoints
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of instances."""
|
||||
return len(self.bboxes)
|
||||
|
||||
@classmethod
|
||||
def concatenate(cls, instances_list: list[Instances], axis=0) -> Instances:
|
||||
"""
|
||||
Concatenate a list of Instances objects into a single Instances object.
|
||||
|
||||
Args:
|
||||
instances_list (list[Instances]): A list of Instances objects to concatenate.
|
||||
axis (int, optional): The axis along which the arrays will be concatenated.
|
||||
|
||||
Returns:
|
||||
(Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints
|
||||
if present.
|
||||
|
||||
Notes:
|
||||
The `Instances` objects in the list should have the same properties, such as the format of the bounding
|
||||
boxes, whether keypoints are present, and if the coordinates are normalized.
|
||||
"""
|
||||
assert isinstance(instances_list, (list, tuple))
|
||||
if not instances_list:
|
||||
return cls(np.empty((0, 4)))  # empty (0, 4) array satisfies the Bboxes shape asserts
|
||||
assert all(isinstance(instance, Instances) for instance in instances_list)
|
||||
|
||||
if len(instances_list) == 1:
|
||||
return instances_list[0]
|
||||
|
||||
use_keypoint = instances_list[0].keypoints is not None
|
||||
bbox_format = instances_list[0]._bboxes.format
|
||||
normalized = instances_list[0].normalized
|
||||
|
||||
cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
|
||||
seg_len = [b.segments.shape[1] for b in instances_list]
|
||||
if len(frozenset(seg_len)) > 1: # resample segments if lengths differ
|
||||
max_len = max(seg_len)
|
||||
cat_segments = np.concatenate(
|
||||
[
|
||||
resample_segments(list(b.segments), max_len)
|
||||
if len(b.segments)
|
||||
else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments
|
||||
for b in instances_list
|
||||
],
|
||||
axis=axis,
|
||||
)
|
||||
else:
|
||||
cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
|
||||
cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
|
||||
return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
|
||||
|
||||
@property
|
||||
def bboxes(self) -> np.ndarray:
|
||||
"""Return bounding boxes."""
|
||||
return self._bboxes.bboxes
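A short usage sketch for `Instances` (values are illustrative; segments are passed explicitly because the flip/clip/scale helpers index into them):

```python
import numpy as np

inst = Instances(
    bboxes=np.array([[0.5, 0.5, 0.2, 0.2]], dtype=np.float32),  # normalized xywh
    segments=np.zeros((1, 10, 2), dtype=np.float32),
    bbox_format="xywh",
    normalized=True,
)
inst.denormalize(w=640, h=480)        # to absolute pixel coordinates
inst.fliplr(w=640)                    # mirror horizontally
inst.clip(w=640, h=480)               # constrain coordinates to the image
kept = inst.remove_zero_area_boxes()  # boolean mask of surviving boxes
```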
|
||||
408
ultralytics/utils/logger.py
Normal file
@@ -0,0 +1,408 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.utils import MACOS, RANK
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
|
||||
# Initialize default log file
|
||||
DEFAULT_LOG_PATH = Path("train.log")
|
||||
if RANK in {-1, 0} and DEFAULT_LOG_PATH.exists():
|
||||
DEFAULT_LOG_PATH.unlink(missing_ok=True)
|
||||
|
||||
|
||||
class ConsoleLogger:
|
||||
"""
|
||||
Console output capture with API/file streaming and deduplication.
|
||||
|
||||
Captures stdout/stderr output and streams it to either an API endpoint or local file, with intelligent
|
||||
deduplication to reduce noise from repetitive console output.
|
||||
|
||||
Attributes:
|
||||
destination (str | Path): Target destination for streaming (URL or Path object).
|
||||
is_api (bool): Whether destination is an API endpoint (True) or local file (False).
|
||||
original_stdout: Reference to original sys.stdout for restoration.
|
||||
original_stderr: Reference to original sys.stderr for restoration.
|
||||
log_queue (queue.Queue): Thread-safe queue for buffering log messages.
|
||||
active (bool): Whether console capture is currently active.
|
||||
worker_thread (threading.Thread): Background thread for processing log queue.
|
||||
last_line (str): Last processed line for deduplication.
|
||||
last_time (float): Timestamp of last processed line.
|
||||
last_progress_line (str): Last progress bar line for progress deduplication.
|
||||
last_was_progress (bool): Whether the last line was a progress bar.
|
||||
|
||||
Examples:
|
||||
Basic file logging:
|
||||
>>> logger = ConsoleLogger("training.log")
|
||||
>>> logger.start_capture()
|
||||
>>> print("This will be logged")
|
||||
>>> logger.stop_capture()
|
||||
|
||||
API streaming:
|
||||
>>> logger = ConsoleLogger("https://api.example.com/logs")
|
||||
>>> logger.start_capture()
|
||||
>>> # All output streams to API
|
||||
>>> logger.stop_capture()
|
||||
"""
|
||||
|
||||
def __init__(self, destination):
|
||||
"""
|
||||
Initialize with API endpoint or local file path.
|
||||
|
||||
Args:
|
||||
destination (str | Path): API endpoint URL (http/https) or local file path for streaming output.
|
||||
"""
|
||||
self.destination = destination
|
||||
self.is_api = isinstance(destination, str) and destination.startswith(("http://", "https://"))
|
||||
if not self.is_api:
|
||||
self.destination = Path(destination)
|
||||
|
||||
# Console capture
|
||||
self.original_stdout = sys.stdout
|
||||
self.original_stderr = sys.stderr
|
||||
self.log_queue = queue.Queue(maxsize=1000)
|
||||
self.active = False
|
||||
self.worker_thread = None
|
||||
|
||||
# State tracking
|
||||
self.last_line = ""
|
||||
self.last_time = 0.0
|
||||
self.last_progress_line = "" # Track last progress line for deduplication
|
||||
self.last_was_progress = False # Track if last line was a progress bar
|
||||
|
||||
def start_capture(self):
|
||||
"""Start capturing console output and redirect stdout/stderr to custom capture objects."""
|
||||
if self.active:
|
||||
return
|
||||
|
||||
self.active = True
|
||||
sys.stdout = self._ConsoleCapture(self.original_stdout, self._queue_log)
|
||||
sys.stderr = self._ConsoleCapture(self.original_stderr, self._queue_log)
|
||||
|
||||
# Hook Ultralytics logger
|
||||
try:
|
||||
handler = self._LogHandler(self._queue_log)
|
||||
logging.getLogger("ultralytics").addHandler(handler)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self.worker_thread = threading.Thread(target=self._stream_worker, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
def stop_capture(self):
|
||||
"""Stop capturing console output and restore original stdout/stderr."""
|
||||
if not self.active:
|
||||
return
|
||||
|
||||
self.active = False
|
||||
sys.stdout = self.original_stdout
|
||||
sys.stderr = self.original_stderr
|
||||
self.log_queue.put(None)
|
||||
|
||||
def _queue_log(self, text):
|
||||
"""Queue console text with deduplication and timestamp processing."""
|
||||
if not self.active:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# Handle carriage returns and process lines
|
||||
if "\r" in text:
|
||||
text = text.split("\r")[-1]
|
||||
|
||||
lines = text.split("\n")
|
||||
if lines and lines[-1] == "":
|
||||
lines.pop()
|
||||
|
||||
for line in lines:
|
||||
line = line.rstrip()
|
||||
|
||||
# Skip lines with only thin progress bars (partial progress)
|
||||
if "─" in line: # Has thin lines but no thick lines
|
||||
continue
|
||||
|
||||
# Deduplicate completed progress bars only if they match the previous progress line
|
||||
if " ━━" in line:
|
||||
progress_core = line.split(" ━━")[0].strip()
|
||||
if progress_core == self.last_progress_line and self.last_was_progress:
|
||||
continue
|
||||
self.last_progress_line = progress_core
|
||||
self.last_was_progress = True
|
||||
else:
|
||||
# Skip empty line after progress bar
|
||||
if not line and self.last_was_progress:
|
||||
self.last_was_progress = False
|
||||
continue
|
||||
self.last_was_progress = False
|
||||
|
||||
# General deduplication
|
||||
if line == self.last_line and current_time - self.last_time < 0.1:
|
||||
continue
|
||||
|
||||
self.last_line = line
|
||||
self.last_time = current_time
|
||||
|
||||
# Add timestamp if needed
|
||||
if not line.startswith("[20"):
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
line = f"[{timestamp}] {line}"
|
||||
|
||||
# Queue with overflow protection; _safe_put drops the oldest entry when full
self._safe_put(f"{line}\n")
|
||||
|
||||
def _safe_put(self, item):
|
||||
"""Safely put item in queue with overflow handling."""
|
||||
try:
|
||||
self.log_queue.put_nowait(item)
|
||||
return True
|
||||
except queue.Full:
|
||||
try:
|
||||
self.log_queue.get_nowait() # Drop oldest
|
||||
self.log_queue.put_nowait(item)
|
||||
return True
|
||||
except queue.Empty:
|
||||
return False
|
||||
|
||||
def _stream_worker(self):
|
||||
"""Background worker for streaming logs to destination."""
|
||||
while self.active:
|
||||
try:
|
||||
log_text = self.log_queue.get(timeout=1)
|
||||
if log_text is None:
|
||||
break
|
||||
self._write_log(log_text)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
def _write_log(self, text):
|
||||
"""Write log to API endpoint or local file destination."""
|
||||
try:
|
||||
if self.is_api:
|
||||
import requests # scoped as slow import
|
||||
|
||||
payload = {"timestamp": datetime.now().isoformat(), "message": text.strip()}
|
||||
requests.post(str(self.destination), json=payload, timeout=5)
|
||||
else:
|
||||
self.destination.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.destination.open("a", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
except Exception as e:
|
||||
print(f"Platform logging error: {e}", file=self.original_stderr)
|
||||
|
||||
class _ConsoleCapture:
|
||||
"""Lightweight stdout/stderr capture."""
|
||||
|
||||
__slots__ = ("original", "callback")
|
||||
|
||||
def __init__(self, original, callback):
|
||||
self.original = original
|
||||
self.callback = callback
|
||||
|
||||
def write(self, text):
|
||||
self.original.write(text)
|
||||
self.callback(text)
|
||||
|
||||
def flush(self):
|
||||
self.original.flush()
|
||||
|
||||
class _LogHandler(logging.Handler):
|
||||
"""Lightweight logging handler."""
|
||||
|
||||
__slots__ = ("callback",)
|
||||
|
||||
def __init__(self, callback):
|
||||
super().__init__()
|
||||
self.callback = callback
|
||||
|
||||
def emit(self, record):
|
||||
self.callback(self.format(record) + "\n")
|
||||
|
||||
|
||||
class SystemLogger:
|
||||
"""
|
||||
Log dynamic system metrics for training monitoring.
|
||||
|
||||
Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for
|
||||
training performance monitoring and analysis.
|
||||
|
||||
Attributes:
|
||||
pynvml: NVIDIA pynvml module instance if successfully imported, None otherwise.
|
||||
nvidia_initialized (bool): Whether NVIDIA GPU monitoring is available and initialized.
|
||||
net_start: Initial network I/O counters for calculating cumulative usage.
|
||||
disk_start: Initial disk I/O counters for calculating cumulative usage.
|
||||
|
||||
Examples:
|
||||
Basic usage:
|
||||
>>> logger = SystemLogger()
|
||||
>>> metrics = logger.get_metrics()
|
||||
>>> print(f"CPU: {metrics['cpu']}%, RAM: {metrics['ram']}%")
|
||||
>>> if metrics["gpus"]:
|
||||
... gpu0 = metrics["gpus"]["0"]
|
||||
... print(f"GPU0: {gpu0['usage']}% usage, {gpu0['temp']}°C")
|
||||
|
||||
Training loop integration:
|
||||
>>> system_logger = SystemLogger()
|
||||
>>> for epoch in range(epochs):
|
||||
... # Training code here
|
||||
... metrics = system_logger.get_metrics()
|
||||
... # Log to database/file
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the system logger."""
|
||||
import psutil # scoped as slow import
|
||||
|
||||
self.pynvml = None
|
||||
self.nvidia_initialized = self._init_nvidia()
|
||||
self.net_start = psutil.net_io_counters()
|
||||
self.disk_start = psutil.disk_io_counters()
|
||||
|
||||
def _init_nvidia(self):
|
||||
"""Initialize NVIDIA GPU monitoring with pynvml."""
|
||||
try:
|
||||
assert not MACOS
|
||||
check_requirements("nvidia-ml-py>=12.0.0")
|
||||
self.pynvml = __import__("pynvml")
|
||||
self.pynvml.nvmlInit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def get_metrics(self):
|
||||
"""
|
||||
Get current system metrics.
|
||||
|
||||
Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics,
|
||||
network I/O statistics, and GPU metrics (if available). Example output:
|
||||
|
||||
```python
|
||||
metrics = {
|
||||
"cpu": 45.2,
|
||||
"ram": 78.9,
|
||||
"disk": {"read_mb": 156.7, "write_mb": 89.3, "used_gb": 256.8},
|
||||
"network": {"recv_mb": 157.2, "sent_mb": 89.1},
|
||||
"gpus": {
|
||||
0: {"usage": 95.6, "memory": 85.4, "temp": 72, "power": 285},
|
||||
1: {"usage": 94.1, "memory": 82.7, "temp": 70, "power": 278},
|
||||
},
|
||||
}
|
||||
```
|
||||
|
||||
- cpu (float): CPU usage percentage (0-100%)
|
||||
- ram (float): RAM usage percentage (0-100%)
|
||||
- disk (dict):
|
||||
- read_mb (float): Cumulative disk read in MB since initialization
|
||||
- write_mb (float): Cumulative disk write in MB since initialization
|
||||
- used_gb (float): Total disk space used in GB
|
||||
- network (dict):
|
||||
- recv_mb (float): Cumulative network received in MB since initialization
|
||||
- sent_mb (float): Cumulative network sent in MB since initialization
|
||||
- gpus (dict): GPU metrics by device index (e.g., 0, 1) containing:
|
||||
- usage (int): GPU utilization percentage (0-100%)
|
||||
- memory (float): CUDA memory usage percentage (0-100%)
|
||||
- temp (int): GPU temperature in degrees Celsius
|
||||
- power (int): GPU power consumption in watts
|
||||
|
||||
Returns:
|
||||
metrics (dict): System metrics containing 'cpu', 'ram', 'disk', 'network', 'gpus' with respective usage data.
|
||||
"""
|
||||
import psutil # scoped as slow import
|
||||
|
||||
net = psutil.net_io_counters()
|
||||
disk = psutil.disk_io_counters()
|
||||
memory = psutil.virtual_memory()
|
||||
disk_usage = shutil.disk_usage("/")
|
||||
|
||||
metrics = {
|
||||
"cpu": round(psutil.cpu_percent(), 3),
|
||||
"ram": round(memory.percent, 3),
|
||||
"disk": {
|
||||
"read_mb": round((disk.read_bytes - self.disk_start.read_bytes) / (1 << 20), 3),
|
||||
"write_mb": round((disk.write_bytes - self.disk_start.write_bytes) / (1 << 20), 3),
|
||||
"used_gb": round(disk_usage.used / (1 << 30), 3),
|
||||
},
|
||||
"network": {
|
||||
"recv_mb": round((net.bytes_recv - self.net_start.bytes_recv) / (1 << 20), 3),
|
||||
"sent_mb": round((net.bytes_sent - self.net_start.bytes_sent) / (1 << 20), 3),
|
||||
},
|
||||
"gpus": {},
|
||||
}
|
||||
|
||||
# Add GPU metrics (NVIDIA only)
|
||||
if self.nvidia_initialized:
|
||||
metrics["gpus"].update(self._get_nvidia_metrics())
|
||||
|
||||
return metrics
|
||||
|
||||
def _get_nvidia_metrics(self):
|
||||
"""Get NVIDIA GPU metrics including utilization, memory, temperature, and power."""
|
||||
gpus = {}
|
||||
if not self.nvidia_initialized or not self.pynvml:
|
||||
return gpus
|
||||
try:
|
||||
device_count = self.pynvml.nvmlDeviceGetCount()
|
||||
for i in range(device_count):
|
||||
handle = self.pynvml.nvmlDeviceGetHandleByIndex(i)
|
||||
util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)
|
||||
memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
|
||||
temp = self.pynvml.nvmlDeviceGetTemperature(handle, self.pynvml.NVML_TEMPERATURE_GPU)
|
||||
power = self.pynvml.nvmlDeviceGetPowerUsage(handle) // 1000
|
||||
|
||||
gpus[str(i)] = {
|
||||
"usage": round(util.gpu, 3),
|
||||
"memory": round((memory.used / memory.total) * 100, 3),
|
||||
"temp": temp,
|
||||
"power": power,
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
return gpus
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("SystemLogger Real-time Metrics Monitor")
|
||||
print("Press Ctrl+C to stop\n")
|
||||
|
||||
logger = SystemLogger()
|
||||
|
||||
try:
|
||||
while True:
|
||||
metrics = logger.get_metrics()
|
||||
|
||||
# Clear screen (works on most terminals)
|
||||
print("\033[H\033[J", end="")
|
||||
|
||||
# Display system metrics
|
||||
print(f"CPU: {metrics['cpu']:5.1f}%")
|
||||
print(f"RAM: {metrics['ram']:5.1f}%")
|
||||
print(f"Disk Read: {metrics['disk']['read_mb']:8.1f} MB")
|
||||
print(f"Disk Write: {metrics['disk']['write_mb']:7.1f} MB")
|
||||
print(f"Disk Used: {metrics['disk']['used_gb']:8.1f} GB")
|
||||
print(f"Net Recv: {metrics['network']['recv_mb']:9.1f} MB")
|
||||
print(f"Net Sent: {metrics['network']['sent_mb']:9.1f} MB")
|
||||
|
||||
# Display GPU metrics if available
|
||||
if metrics["gpus"]:
|
||||
print("\nGPU Metrics:")
|
||||
for gpu_id, gpu_data in metrics["gpus"].items():
|
||||
print(
|
||||
f" GPU {gpu_id}: {gpu_data['usage']:3}% | "
|
||||
f"Mem: {gpu_data['memory']:5.1f}% | "
|
||||
f"Temp: {gpu_data['temp']:2}°C | "
|
||||
f"Power: {gpu_data['power']:3}W"
|
||||
)
|
||||
else:
|
||||
print("\nGPU: No NVIDIA GPUs detected")
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nStopped monitoring.")
|
||||
857
ultralytics/utils/loss.py
Normal file
@@ -0,0 +1,857 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils.metrics import OKS_SIGMA
|
||||
from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
|
||||
from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
|
||||
from ultralytics.utils.torch_utils import autocast
|
||||
|
||||
from .metrics import bbox_iou, probiou
|
||||
from .tal import bbox2dist
|
||||
|
||||
|
||||
class VarifocalLoss(nn.Module):
|
||||
"""
|
||||
Varifocal loss by Zhang et al.
|
||||
|
||||
Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
|
||||
hard-to-classify examples and balancing positive/negative samples.
|
||||
|
||||
Attributes:
|
||||
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
||||
alpha (float): The balancing factor used to address class imbalance.
|
||||
|
||||
References:
|
||||
https://arxiv.org/abs/2008.13367
|
||||
"""
|
||||
|
||||
def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
|
||||
"""Initialize the VarifocalLoss class with focusing and balancing parameters."""
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute varifocal loss between predictions and ground truth."""
|
||||
weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
|
||||
with autocast(enabled=False):
|
||||
loss = (
|
||||
(F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
|
||||
.mean(1)
|
||||
.sum()
|
||||
)
|
||||
return loss
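In math form, the weighting above implements the varifocal target from the paper (my transcription of the code, with p the predicted score, q the IoU-aware ground-truth score, and y the binary label):

```latex
\mathrm{VFL}(p, q) =
\begin{cases}
  -q \left( q \log p + (1 - q) \log (1 - p) \right) & \text{if } y = 1 \\
  -\alpha \, p^{\gamma} \log (1 - p) & \text{if } y = 0
\end{cases}
```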
|
||||
|
||||
|
||||
class FocalLoss(nn.Module):
|
||||
"""
|
||||
Wrap focal loss around an existing loss_fcn(), i.e. criterion = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5).
|
||||
|
||||
Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
|
||||
on hard negatives during training.
|
||||
|
||||
Attributes:
|
||||
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
||||
alpha (torch.Tensor): The balancing factor used to address class imbalance.
|
||||
"""
|
||||
|
||||
def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
|
||||
"""Initialize FocalLoss class with focusing and balancing parameters."""
|
||||
super().__init__()
|
||||
self.gamma = gamma
|
||||
self.alpha = torch.tensor(alpha)
|
||||
|
||||
def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate focal loss with modulating factors for class imbalance."""
|
||||
loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
|
||||
# p_t = torch.exp(-loss)
|
||||
# loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
|
||||
|
||||
# TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
|
||||
pred_prob = pred.sigmoid() # prob from logits
|
||||
p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
|
||||
modulating_factor = (1.0 - p_t) ** self.gamma
|
||||
loss *= modulating_factor
|
||||
if (self.alpha > 0).any():
|
||||
self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
|
||||
alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
|
||||
loss *= alpha_factor
|
||||
return loss.mean(1).sum()
|
||||
|
||||
|
||||
class DFLoss(nn.Module):
|
||||
"""Criterion class for computing Distribution Focal Loss (DFL)."""
|
||||
|
||||
def __init__(self, reg_max: int = 16) -> None:
|
||||
"""Initialize the DFL module with regularization maximum."""
|
||||
super().__init__()
|
||||
self.reg_max = reg_max
|
||||
|
||||
def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
|
||||
target = target.clamp_(0, self.reg_max - 1 - 0.01)
|
||||
tl = target.long() # target left
|
||||
tr = tl + 1 # target right
|
||||
wl = tr - target # weight left
|
||||
wr = 1 - wl # weight right
|
||||
return (
|
||||
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
|
||||
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
|
||||
).mean(-1, keepdim=True)
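In math form, the left/right interpolation above is the standard Distribution Focal Loss from the Generalized Focal Loss paper, where the continuous target y lies between the discrete bins y_i and y_{i+1} = y_i + 1:

```latex
\mathrm{DFL}(\mathcal{S}_i, \mathcal{S}_{i+1}) =
  -\left( (y_{i+1} - y) \log \mathcal{S}_i + (y - y_i) \log \mathcal{S}_{i+1} \right)
```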
|
||||
|
||||
|
||||
class BboxLoss(nn.Module):
|
||||
"""Criterion class for computing training losses for bounding boxes."""
|
||||
|
||||
def __init__(self, reg_max: int = 16):
|
||||
"""Initialize the BboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__()
|
||||
self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_dist: torch.Tensor,
|
||||
pred_bboxes: torch.Tensor,
|
||||
anchor_points: torch.Tensor,
|
||||
target_bboxes: torch.Tensor,
|
||||
target_scores: torch.Tensor,
|
||||
target_scores_sum: torch.Tensor,
|
||||
fg_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute IoU and DFL losses for bounding boxes."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.dfl_loss:
|
||||
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
|
||||
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
|
||||
class RotatedBboxLoss(BboxLoss):
|
||||
"""Criterion class for computing training losses for rotated bounding boxes."""
|
||||
|
||||
def __init__(self, reg_max: int):
|
||||
"""Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
|
||||
super().__init__(reg_max)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_dist: torch.Tensor,
|
||||
pred_bboxes: torch.Tensor,
|
||||
anchor_points: torch.Tensor,
|
||||
target_bboxes: torch.Tensor,
|
||||
target_scores: torch.Tensor,
|
||||
target_scores_sum: torch.Tensor,
|
||||
fg_mask: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute IoU and DFL losses for rotated bounding boxes."""
|
||||
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
|
||||
iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
|
||||
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
|
||||
|
||||
# DFL loss
|
||||
if self.dfl_loss:
|
||||
target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
|
||||
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
|
||||
loss_dfl = loss_dfl.sum() / target_scores_sum
|
||||
else:
|
||||
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
|
||||
|
||||
return loss_iou, loss_dfl
|
||||
|
||||
|
||||
class KeypointLoss(nn.Module):
|
||||
"""Criterion class for computing keypoint losses."""
|
||||
|
||||
def __init__(self, sigmas: torch.Tensor) -> None:
|
||||
"""Initialize the KeypointLoss class with keypoint sigmas."""
|
||||
super().__init__()
|
||||
self.sigmas = sigmas
|
||||
|
||||
def forward(
|
||||
self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
|
||||
d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
|
||||
kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
|
||||
# e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
|
||||
e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
|
||||
return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
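The exponential term above follows the COCO OKS formulation (my transcription of the code, with d_i^2 the squared keypoint distance, sigma_i the per-keypoint sigma, A the object area, v_i the visibility, and K the number of keypoints):

```latex
\ell_{\text{kpt}} = \operatorname{mean}_{i}\!\left[
  \frac{K}{\sum_{j} \mathbb{1}[v_j \neq 0]}
  \left( 1 - e^{-d_i^2 / \left( 2 (2\sigma_i)^2 (A + \varepsilon) \right)} \right)
  \mathbb{1}[v_i \neq 0] \right]
```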
|
||||
|
||||
|
||||
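# A minimal sketch of driving KeypointLoss directly (illustrative only; the
# uniform sigmas stand in for the 17 COCO OKS sigmas used when kpt_shape is [17, 3]):
# >>> import torch
# >>> criterion = KeypointLoss(sigmas=torch.ones(17) / 17)
# >>> pred = torch.rand(2, 17, 3)  # (objects, keypoints, x + y + visibility)
# >>> gt = torch.rand(2, 17, 3)
# >>> loss = criterion(pred, gt, kpt_mask=gt[..., 2] != 0, area=torch.rand(2, 1))
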
class v8DetectionLoss:
    """Criterion class for computing training losses for YOLOv8 object detection."""

    def __init__(self, model, tal_topk: int = 10):  # model must be de-paralleled
        """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
        device = next(model.parameters()).device  # get model device
        h = model.args  # hyperparameters

        m = model.model[-1]  # Detect() module
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.hyp = h
        self.stride = m.stride  # model strides
        self.nc = m.nc  # number of classes
        self.no = m.nc + m.reg_max * 4
        self.reg_max = m.reg_max
        self.device = device

        self.use_dfl = m.reg_max > 1

        self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = BboxLoss(m.reg_max).to(device)
        self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)

    def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
        """Preprocess targets by converting to tensor format and scaling coordinates."""
        nl, ne = targets.shape
        if nl == 0:
            out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    out[j, :n] = targets[matches, 1:]
            out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
        return out

    def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
        """Decode predicted object bounding box coordinates from anchor points and distribution."""
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2, 3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
            # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
        return dist2bbox(pred_dist, anchor_points, xywh=False)

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats = preds[1] if isinstance(preds, tuple) else preds
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
        # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2

        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            loss[0], loss[2] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

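# A minimal sketch of the target layout that preprocess() above produces
# (illustrative only; `criterion` is a hypothetical v8DetectionLoss instance):
# >>> import torch
# >>> # two labels for image 0, one for image 1: rows of (batch_idx, cls, xywh)
# >>> targets = torch.tensor(
# ...     [[0, 1, 0.5, 0.5, 0.2, 0.2], [0, 2, 0.3, 0.3, 0.1, 0.1], [1, 0, 0.5, 0.5, 0.4, 0.4]]
# ... )
# >>> # out = criterion.preprocess(targets, batch_size=2, scale_tensor=torch.tensor([640.0, 640, 640, 640]))
# >>> # out.shape -> (2, 2, 5): per-image (cls, x1, y1, x2, y2), image 1 zero-padded
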
class v8SegmentationLoss(v8DetectionLoss):
    """Criterion class for computing training losses for YOLOv8 segmentation."""

    def __init__(self, model):  # model must be de-paralleled
        """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
        super().__init__(model)
        self.overlap = model.args.overlap_mask

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate and return the combined loss for detection and segmentation."""
        loss = torch.zeros(4, device=self.device)  # box, seg, cls, dfl
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        batch_size, _, mask_h, mask_w = proto.shape  # batch size, number of masks, mask height, mask width
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_masks = pred_masks.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
                "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        if fg_mask.sum():
            # Bbox loss
            loss[0], loss[3] = self.bbox_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes / stride_tensor,
                target_scores,
                target_scores_sum,
                fg_mask,
            )
            # Masks loss
            masks = batch["masks"].to(self.device).float()
            if tuple(masks.shape[-2:]) != (mask_h, mask_w):  # downsample
                masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]

            loss[1] = self.calculate_segmentation_loss(
                fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
            )

        # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
        else:
            loss[1] += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.box  # seg gain
        loss[2] *= self.hyp.cls  # cls gain
        loss[3] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, seg, cls, dfl)

    @staticmethod
    def single_mask_loss(
        gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the instance segmentation loss for a single image.

        Args:
            gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
            pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
            proto (torch.Tensor): Prototype masks of shape (32, H, W).
            xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
            area (torch.Tensor): Area of each ground truth bounding box of shape (N,).

        Returns:
            (torch.Tensor): The calculated mask loss for a single image.

        Notes:
            The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
            predicted masks from the prototype masks and predicted mask coefficients.
        """
        pred_mask = torch.einsum("in,nhw->ihw", pred, proto)  # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
        loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
        return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()

    def calculate_segmentation_loss(
        self,
        fg_mask: torch.Tensor,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        target_bboxes: torch.Tensor,
        batch_idx: torch.Tensor,
        proto: torch.Tensor,
        pred_masks: torch.Tensor,
        imgsz: torch.Tensor,
        overlap: bool,
    ) -> torch.Tensor:
        """
        Calculate the loss for instance segmentation.

        Args:
            fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
            masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
            target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
            target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
            batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
            proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
            pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
            imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
            overlap (bool): Whether the masks in `masks` tensor overlap.

        Returns:
            (torch.Tensor): The calculated loss for instance segmentation.

        Notes:
            The batch loss can be computed for improved speed at higher memory usage.
            For example, pred_mask can be computed as follows:
            pred_mask = torch.einsum('in,nhw->ihw', pred, proto)  # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
        """
        _, _, mask_h, mask_w = proto.shape
        loss = 0

        # Normalize to 0-1
        target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]

        # Areas of target bboxes
        marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)

        # Normalize to mask size
        mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)

        for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
            fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
            if fg_mask_i.any():
                mask_idx = target_gt_idx_i[fg_mask_i]
                if overlap:
                    gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
                    gt_mask = gt_mask.float()
                else:
                    gt_mask = masks[batch_idx.view(-1) == i][mask_idx]

                loss += self.single_mask_loss(
                    gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
                )

            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
            else:
                loss += (proto * 0).sum() + (pred_masks * 0).sum()  # inf sums may lead to nan loss

        return loss / fg_mask.sum()

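# A minimal sketch of the prototype-mask assembly used in single_mask_loss above
# (illustrative only): each instance mask is a coefficient-weighted sum of the
# shared prototypes.
# >>> import torch
# >>> proto = torch.randn(32, 160, 160)  # prototype masks for one image
# >>> coef = torch.randn(5, 32)  # mask coefficients for 5 instances
# >>> pred_masks = torch.einsum("in,nhw->ihw", coef, proto)  # (5, 160, 160) logits
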
class v8PoseLoss(v8DetectionLoss):
    """Criterion class for computing training losses for YOLOv8 pose estimation."""

    def __init__(self, model):  # model must be de-paralleled
        """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
        super().__init__(model)
        self.kpt_shape = model.model[-1].kpt_shape
        self.bce_pose = nn.BCEWithLogitsLoss()
        is_pose = self.kpt_shape == [17, 3]
        nkpt = self.kpt_shape[0]  # number of keypoints
        sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
        self.keypoint_loss = KeypointLoss(sigmas=sigmas)

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate the total loss and detach it for pose estimation."""
        loss = torch.zeros(5, device=self.device)  # box, cls, dfl, kpt_location, kpt_visibility
        feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # B, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # Targets
        batch_size = pred_scores.shape[0]
        batch_idx = batch["batch_idx"].view(-1, 1)
        targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
        targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
        gt_labels, gt_bboxes = targets.split((1, 4), 2)  # cls, xyxy
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri)  # xyxy, (b, h*w, 4)
        pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape))  # (b, h*w, 17, 3)

        _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss[0], loss[4] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
            keypoints = batch["keypoints"].to(self.device).float().clone()
            keypoints[..., 0] *= imgsz[1]
            keypoints[..., 1] *= imgsz[0]

            loss[1], loss[2] = self.calculate_keypoints_loss(
                fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
            )

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.pose  # pose gain
        loss[2] *= self.hyp.kobj  # kobj gain
        loss[3] *= self.hyp.cls  # cls gain
        loss[4] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)

    @staticmethod
    def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
        """Decode predicted keypoints to image coordinates."""
        y = pred_kpts.clone()
        y[..., :2] *= 2.0
        y[..., 0] += anchor_points[:, [0]] - 0.5
        y[..., 1] += anchor_points[:, [1]] - 0.5
        return y

    def calculate_keypoints_loss(
        self,
        masks: torch.Tensor,
        target_gt_idx: torch.Tensor,
        keypoints: torch.Tensor,
        batch_idx: torch.Tensor,
        stride_tensor: torch.Tensor,
        target_bboxes: torch.Tensor,
        pred_kpts: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate the keypoints loss for the model.

        This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
        based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
        a binary classification loss that classifies whether a keypoint is present or not.

        Args:
            masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
            target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
            keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
            batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
            stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
            target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
            pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).

        Returns:
            kpts_loss (torch.Tensor): The keypoints loss.
            kpts_obj_loss (torch.Tensor): The keypoints object loss.
        """
        batch_idx = batch_idx.flatten()
        batch_size = len(masks)

        # Find the maximum number of keypoints in a single image
        max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()

        # Create a tensor to hold batched keypoints
        batched_keypoints = torch.zeros(
            (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
        )

        # TODO: any idea how to vectorize this?
        # Fill batched_keypoints with keypoints based on batch_idx
        for i in range(batch_size):
            keypoints_i = keypoints[batch_idx == i]
            batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i

        # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
        target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)

        # Use target_gt_idx_expanded to select keypoints from batched_keypoints
        selected_keypoints = batched_keypoints.gather(
            1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
        )

        # Divide coordinates by stride
        selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)

        kpts_loss = 0
        kpts_obj_loss = 0

        if masks.any():
            gt_kpt = selected_keypoints[masks]
            area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
            pred_kpt = pred_kpts[masks]
            kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
            kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area)  # pose loss

            if pred_kpt.shape[-1] == 3:
                kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float())  # keypoint obj loss

        return kpts_loss, kpts_obj_loss

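# A minimal sketch of kpts_decode above (illustrative only): raw offsets are
# doubled and re-centered on each anchor point, so zero offsets land at
# (anchor - 0.5) in grid units.
# >>> import torch
# >>> anchors = torch.tensor([[10.0, 20.0]])  # one anchor point (x, y)
# >>> raw = torch.zeros(1, 1, 17, 3)  # (batch, anchors, keypoints, dims)
# >>> xy = v8PoseLoss.kpts_decode(anchors, raw)[..., :2]  # every point at (9.5, 19.5)
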
class v8ClassificationLoss:
    """Criterion class for computing training losses for classification."""

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the classification loss between predictions and true labels."""
        preds = preds[1] if isinstance(preds, (list, tuple)) else preds
        loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
        return loss, loss.detach()

class v8OBBLoss(v8DetectionLoss):
    """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""

    def __init__(self, model):
        """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
        super().__init__(model)
        self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
        self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)

    def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
        """Preprocess targets for oriented bounding box detection."""
        if targets.shape[0] == 0:
            out = torch.zeros(batch_size, 0, 6, device=self.device)
        else:
            i = targets[:, 0]  # image index
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
            for j in range(batch_size):
                matches = i == j
                if n := matches.sum():
                    bboxes = targets[matches, 2:]
                    bboxes[..., :4].mul_(scale_tensor)
                    out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
        return out

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate and return the loss for oriented bounding box detection."""
        loss = torch.zeros(3, device=self.device)  # box, cls, dfl
        feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
        batch_size = pred_angle.shape[0]  # batch size
        pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
            (self.reg_max * 4, self.nc), 1
        )

        # b, grids, ..
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()
        pred_angle = pred_angle.permute(0, 2, 1).contiguous()

        dtype = pred_scores.dtype
        imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0]  # image size (h,w)
        anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)

        # targets
        try:
            batch_idx = batch["batch_idx"].view(-1, 1)
            targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
            rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
            targets = targets[(rw >= 2) & (rh >= 2)]  # filter rboxes of tiny size to stabilize training
            targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
            gt_labels, gt_bboxes = targets.split((1, 5), 2)  # cls, xywhr
            mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
        except RuntimeError as e:
            raise TypeError(
                "ERROR ❌ OBB dataset incorrectly formatted or not an OBB dataset.\n"
                "This error can occur when incorrectly training an 'OBB' model on a 'detect' dataset, "
                "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
                "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
                "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
            ) from e

        # Pboxes
        pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle)  # xywhr, (b, h*w, 5)

        bboxes_for_assigner = pred_bboxes.clone().detach()
        # Only the first four elements need to be scaled
        bboxes_for_assigner[..., :4] *= stride_tensor
        _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
            pred_scores.detach().sigmoid(),
            bboxes_for_assigner.type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )

        target_scores_sum = max(target_scores.sum(), 1)

        # Cls loss
        # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum  # VFL way
        loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum  # BCE

        # Bbox loss
        if fg_mask.sum():
            target_bboxes[..., :4] /= stride_tensor
            loss[0], loss[2] = self.bbox_loss(
                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
            )
        else:
            loss[0] += (pred_angle * 0).sum()

        loss[0] *= self.hyp.box  # box gain
        loss[1] *= self.hyp.cls  # cls gain
        loss[2] *= self.hyp.dfl  # dfl gain

        return loss * batch_size, loss.detach()  # loss(box, cls, dfl)

    def bbox_decode(
        self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
    ) -> torch.Tensor:
        """
        Decode predicted object bounding box coordinates from anchor points and distribution.

        Args:
            anchor_points (torch.Tensor): Anchor points, (h*w, 2).
            pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
            pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).

        Returns:
            (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
        """
        if self.use_dfl:
            b, a, c = pred_dist.shape  # batch, anchors, channels
            pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)

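# A minimal sketch of the rotated-box convention used by v8OBBLoss above
# (illustrative only): decoded boxes are (x, y, w, h, angle), and only the first
# four channels are multiplied by the stride; the angle passes through unchanged.
# >>> import torch
# >>> rboxes = torch.randn(1, 8400, 5)  # (batch, anchors, xywh + angle)
# >>> stride = torch.full((8400, 1), 8.0)
# >>> rboxes[..., :4] *= stride  # scale xywh only, keep angle
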
class E2EDetectLoss:
    """Criterion class for computing training losses for end-to-end detection."""

    def __init__(self, model):
        """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
        self.one2many = v8DetectionLoss(model, tal_topk=10)
        self.one2one = v8DetectionLoss(model, tal_topk=1)

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
        preds = preds[1] if isinstance(preds, tuple) else preds
        one2many = preds["one2many"]
        loss_one2many = self.one2many(one2many, batch)
        one2one = preds["one2one"]
        loss_one2one = self.one2one(one2one, batch)
        return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]

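# A minimal sketch of the head output E2EDetectLoss expects (illustrative only;
# feats_o2m and feats_o2o are hypothetical feature lists from the two branches):
# >>> # preds = {"one2many": feats_o2m, "one2one": feats_o2o}
# >>> # total, items = E2EDetectLoss(model)(preds, batch)
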
class TVPDetectLoss:
    """Criterion class for computing training losses for text-visual prompt detection."""

    def __init__(self, model):
        """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
        self.vp_criterion = v8DetectionLoss(model)
        # NOTE: store following info as it's changeable in __call__
        self.ori_nc = self.vp_criterion.nc
        self.ori_no = self.vp_criterion.no
        self.ori_reg_max = self.vp_criterion.reg_max

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate the loss for text-visual prompt detection."""
        feats = preds[1] if isinstance(preds, tuple) else preds
        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
            loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
            return loss, loss.detach()

        vp_feats = self._get_vp_features(feats)
        vp_loss = self.vp_criterion(vp_feats, batch)
        box_loss = vp_loss[0][1]
        return box_loss, vp_loss[1]

    def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
        """Extract visual-prompt features from the model output."""
        vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc

        self.vp_criterion.nc = vnc
        self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
        self.vp_criterion.assigner.num_classes = vnc

        return [
            torch.cat((box, cls_vp), dim=1)
            for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
        ]

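# A minimal sketch of the channel split performed by _get_vp_features above
# (illustrative only): text-prompt class channels are dropped, keeping box
# regression plus visual-prompt class channels.
# >>> import torch
# >>> reg_max, nc, vnc = 16, 80, 32
# >>> feat = torch.randn(1, reg_max * 4 + nc + vnc, 80, 80)
# >>> box, cls_text, cls_vp = feat.split((reg_max * 4, nc, vnc), dim=1)
# >>> vp_feat = torch.cat((box, cls_vp), dim=1)  # (1, 96, 80, 80)
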
class TVPSegmentLoss(TVPDetectLoss):
    """Criterion class for computing training losses for text-visual prompt segmentation."""

    def __init__(self, model):
        """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
        super().__init__(model)
        self.vp_criterion = v8SegmentationLoss(model)

    def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """Calculate the loss for text-visual prompt segmentation."""
        feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
        assert self.ori_reg_max == self.vp_criterion.reg_max  # TODO: remove it

        if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
            loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
            return loss, loss.detach()

        vp_feats = self._get_vp_features(feats)
        vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
        cls_loss = vp_loss[0][2]
        return cls_loss, vp_loss[1]
1592
ultralytics/utils/metrics.py
Normal file
File diff suppressed because it is too large
Load Diff
340
ultralytics/utils/nms.py
Normal file
@@ -0,0 +1,340 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import sys
import time

import torch

from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import batch_probiou, box_iou
from ultralytics.utils.ops import xywh2xyxy


def non_max_suppression(
    prediction,
    conf_thres: float = 0.25,
    iou_thres: float = 0.45,
    classes=None,
    agnostic: bool = False,
    multi_label: bool = False,
    labels=(),
    max_det: int = 300,
    nc: int = 0,  # number of classes (optional)
    max_time_img: float = 0.05,
    max_nms: int = 30000,
    max_wh: int = 7680,
    rotated: bool = False,
    end2end: bool = False,
    return_idxs: bool = False,
):
    """
    Perform non-maximum suppression (NMS) on prediction results.

    Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
    detection formats including standard boxes, rotated boxes, and masks.

    Args:
        prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing boxes, classes, and optional masks.
        conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
        iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
        classes (list[int], optional): List of class indices to consider. If None, all classes are considered.
        agnostic (bool): Whether to perform class-agnostic NMS.
        multi_label (bool): Whether each box can have multiple labels.
        labels (list[list[Union[int, float, torch.Tensor]]]): A priori labels for each image.
        max_det (int): Maximum number of detections to keep per image.
        nc (int): Number of classes. Indices after this are considered masks.
        max_time_img (float): Maximum time in seconds for processing one image.
        max_nms (int): Maximum number of boxes for NMS.
        max_wh (int): Maximum box width and height in pixels.
        rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
        end2end (bool): Whether the model is end-to-end and doesn't require NMS.
        return_idxs (bool): Whether to return the indices of kept detections.

    Returns:
        output (list[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
            containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
        keepi (list[torch.Tensor]): Indices of kept detections if return_idxs=True.
    """
    # Checks
    assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
    assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output
    if classes is not None:
        classes = torch.tensor(classes, device=prediction.device)

    if prediction.shape[-1] == 6 or end2end:  # end-to-end model (BNC, i.e. 1,300,6)
        output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
        if classes is not None:
            output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
        return output

    bs = prediction.shape[0]  # batch size (BCN, i.e. 1,84,6300)
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    extra = prediction.shape[1] - nc - 4  # number of extra info
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates
    xinds = torch.arange(prediction.shape[-1], device=prediction.device).expand(bs, -1)[..., None]  # to track idxs

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 2.0 + max_time_img * bs  # seconds to quit after
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)

    prediction = prediction.transpose(-1, -2)  # shape(1,84,6300) to shape(1,6300,84)
    if not rotated:
        prediction[..., :4] = xywh2xyxy(prediction[..., :4])  # xywh to xyxy

    t = time.time()
    output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
    keepi = [torch.zeros((0, 1), device=prediction.device)] * bs  # to store the kept idxs
    for xi, (x, xk) in enumerate(zip(prediction, xinds)):  # image index, (preds, preds indices)
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        filt = xc[xi]  # confidence
        x = x[filt]
        if return_idxs:
            xk = xk[filt]

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]) and not rotated:
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
            v[:, :4] = xywh2xyxy(lb[:, 1:5])  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, extra), 1)

        if multi_label:
            i, j = torch.where(cls > conf_thres)
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
            if return_idxs:
                xk = xk[i]
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            filt = conf.view(-1) > conf_thres
            x = torch.cat((box, conf, j.float(), mask), 1)[filt]
            if return_idxs:
                xk = xk[filt]

        # Filter by class
        if classes is not None:
            filt = (x[:, 5:6] == classes).any(1)
            x = x[filt]
            if return_idxs:
                xk = xk[filt]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        if n > max_nms:  # excess boxes
            filt = x[:, 4].argsort(descending=True)[:max_nms]  # sort by confidence and remove excess boxes
            x = x[filt]
            if return_idxs:
                xk = xk[filt]

        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        scores = x[:, 4]  # scores
        if rotated:
            boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1)  # xywhr
            i = TorchNMS.fast_nms(boxes, scores, iou_thres, iou_func=batch_probiou)
        else:
            boxes = x[:, :4] + c  # boxes (offset by class)
            # Speed strategy: torchvision for val or already loaded (faster), TorchNMS for predict (lower latency)
            if "torchvision" in sys.modules:
                import torchvision  # scope as slow import

                i = torchvision.ops.nms(boxes, scores, iou_thres)
            else:
                i = TorchNMS.nms(boxes, scores, iou_thres)
        i = i[:max_det]  # limit detections

        output[xi] = x[i]
        if return_idxs:
            keepi[xi] = xk[i].view(-1)
        if (time.time() - t) > time_limit:
            LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
            break  # time limit exceeded

    return (output, keepi) if return_idxs else output

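# A minimal usage sketch of non_max_suppression above (illustrative only; random
# logits stand in for real model output of shape (batch, 4 + nc, anchors)):
# >>> import torch
# >>> pred = torch.randn(1, 84, 1000)  # 80-class model, 1000 candidate boxes
# >>> out = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
# >>> out[0].shape[-1]  # columns are (x1, y1, x2, y2, conf, cls)
# 6
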
class TorchNMS:
    """
    Ultralytics custom NMS implementation optimized for YOLO.

    This class provides static methods for performing non-maximum suppression (NMS) operations on bounding boxes,
    including both standard NMS and batched NMS for multi-class scenarios.

    Methods:
        nms: Optimized NMS with early termination that matches torchvision behavior exactly.
        batched_nms: Batched NMS for class-aware suppression.

    Examples:
        Perform standard NMS on boxes and scores
        >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
        >>> scores = torch.tensor([0.9, 0.8])
        >>> keep = TorchNMS.nms(boxes, scores, 0.5)
    """

    @staticmethod
    def fast_nms(
        boxes: torch.Tensor,
        scores: torch.Tensor,
        iou_threshold: float,
        use_triu: bool = True,
        iou_func=box_iou,
        exit_early: bool = True,
    ) -> torch.Tensor:
        """
        Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            iou_threshold (float): IoU threshold for suppression.
            use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.
            iou_func (callable): Function to compute IoU between boxes.
            exit_early (bool): Whether to exit early if there are no boxes.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS.

        Examples:
            Apply Fast-NMS to a set of boxes
            >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> keep = TorchNMS.fast_nms(boxes, scores, 0.5)
        """
        if boxes.numel() == 0 and exit_early:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        sorted_idx = torch.argsort(scores, descending=True)
        boxes = boxes[sorted_idx]
        ious = iou_func(boxes, boxes)
        if use_triu:
            ious = ious.triu_(diagonal=1)
            # NOTE: handles len(boxes) == 0 without a data-dependent if-else branch, keeping the op exportable
            pick = torch.nonzero((ious >= iou_threshold).sum(0) <= 0).squeeze_(-1)
        else:
            n = boxes.shape[0]
            row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
            col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
            upper_mask = row_idx < col_idx
            ious = ious * upper_mask
            # Zeroing these scores ensures the additional indices would not affect the final results
            scores[~((ious >= iou_threshold).sum(0) <= 0)] = 0
            # NOTE: return indices with fixed length to avoid TFLite reshape error
            pick = torch.topk(scores, scores.shape[0]).indices
        return sorted_idx[pick]

    @staticmethod
    def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
        """
        Optimized NMS with early termination that matches torchvision behavior exactly.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            iou_threshold (float): IoU threshold for suppression.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS.

        Examples:
            Apply NMS to a set of boxes
            >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> keep = TorchNMS.nms(boxes, scores, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        # Pre-allocate and extract coordinates once
        x1, y1, x2, y2 = boxes.unbind(1)
        areas = (x2 - x1) * (y2 - y1)

        # Sort by scores descending
        order = scores.argsort(0, descending=True)

        # Pre-allocate keep list with maximum possible size
        keep = torch.zeros(order.numel(), dtype=torch.int64, device=boxes.device)
        keep_idx = 0
        while order.numel() > 0:
            i = order[0]
            keep[keep_idx] = i
            keep_idx += 1

            if order.numel() == 1:
                break
            # Vectorized IoU calculation for remaining boxes
            rest = order[1:]
            xx1 = torch.maximum(x1[i], x1[rest])
            yy1 = torch.maximum(y1[i], y1[rest])
            xx2 = torch.minimum(x2[i], x2[rest])
            yy2 = torch.minimum(y2[i], y2[rest])

            # Fast intersection and IoU
            w = (xx2 - xx1).clamp_(min=0)
            h = (yy2 - yy1).clamp_(min=0)
            inter = w * h
            # Early exit: skip IoU calculation if no intersection
            if inter.sum() == 0:
                # No overlaps with current box, keep all remaining boxes
                order = rest
                continue
            iou = inter / (areas[i] + areas[rest] - inter)
            # Keep boxes with IoU <= threshold
            order = rest[iou <= iou_threshold]

        return keep[:keep_idx]

    @staticmethod
    def batched_nms(
        boxes: torch.Tensor,
        scores: torch.Tensor,
        idxs: torch.Tensor,
        iou_threshold: float,
        use_fast_nms: bool = False,
    ) -> torch.Tensor:
        """
        Batched NMS for class-aware suppression.

        Args:
            boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
            scores (torch.Tensor): Confidence scores with shape (N,).
            idxs (torch.Tensor): Class indices with shape (N,).
            iou_threshold (float): IoU threshold for suppression.
            use_fast_nms (bool): Whether to use the Fast-NMS implementation.

        Returns:
            (torch.Tensor): Indices of boxes to keep after NMS.

        Examples:
            Apply batched NMS across multiple classes
            >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
            >>> scores = torch.tensor([0.9, 0.8])
            >>> idxs = torch.tensor([0, 1])
            >>> keep = TorchNMS.batched_nms(boxes, scores, idxs, 0.5)
        """
        if boxes.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=boxes.device)

        # Strategy: offset boxes by class index to prevent cross-class suppression
        max_coordinate = boxes.max()
        offsets = idxs.to(boxes) * (max_coordinate + 1)
        boxes_for_nms = boxes + offsets[:, None]

        return (
            TorchNMS.fast_nms(boxes_for_nms, scores, iou_threshold)
            if use_fast_nms
            else TorchNMS.nms(boxes_for_nms, scores, iou_threshold)
        )
722
ultralytics/utils/ops.py
Normal file
@@ -0,0 +1,722 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import contextlib
import math
import re
import time

import cv2
import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.utils import NOT_MACOS14


class Profile(contextlib.ContextDecorator):
    """
    Ultralytics Profile class for timing code execution.

    Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
    measurements with CUDA synchronization support for GPU operations.

    Attributes:
        t (float): Accumulated time in seconds.
        device (torch.device): Device used for model inference.
        cuda (bool): Whether CUDA is being used for timing synchronization.

    Examples:
        Use as a context manager to time code execution
        >>> with Profile(device=device) as dt:
        ...     pass  # slow operation here
        >>> print(dt)  # prints "Elapsed time is 9.5367431640625e-07 s"

        Use as a decorator to time function execution
        >>> @Profile()
        ... def slow_function():
        ...     time.sleep(0.1)
    """

    def __init__(self, t: float = 0.0, device: torch.device | None = None):
        """
        Initialize the Profile class.

        Args:
            t (float): Initial accumulated time in seconds.
            device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
        """
        self.t = t
        self.device = device
        self.cuda = bool(device and str(device).startswith("cuda"))

    def __enter__(self):
        """Start timing."""
        self.start = self.time()
        return self

    def __exit__(self, type, value, traceback):  # noqa
        """Stop timing."""
        self.dt = self.time() - self.start  # delta-time
        self.t += self.dt  # accumulate dt

    def __str__(self):
        """Return a human-readable string representing the accumulated elapsed time."""
        return f"Elapsed time is {self.t} s"

    def time(self):
        """Get current time with CUDA synchronization if applicable."""
        if self.cuda:
            torch.cuda.synchronize(self.device)
        return time.perf_counter()

def segment2box(segment, width: int = 640, height: int = 640):
    """
    Convert segment coordinates to bounding box coordinates.

    Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
    Applies inside-image constraint and clips coordinates when necessary.

    Args:
        segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
        width (int): Width of the image in pixels.
        height (int): Height of the image in pixels.

    Returns:
        (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
    """
    x, y = segment.T  # segment xy
    # Clip coordinates if 3 out of 4 sides are outside the image
    if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
        x = x.clip(0, width)
        y = y.clip(0, height)
    inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
    x = x[inside]
    y = y[inside]
    return (
        np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
        if any(x)
        else np.zeros(4, dtype=segment.dtype)
    )  # xyxy

def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
    """
    Rescale bounding boxes from one image shape to another.

    Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
    Supports both xyxy and xywh box formats.

    Args:
        img1_shape (tuple): Shape of the source image (height, width).
        boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
        img0_shape (tuple): Shape of the target image (height, width).
        ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
        padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
        xywh (bool): Whether box format is xywh (True) or xyxy (False).

    Returns:
        (torch.Tensor): Rescaled bounding boxes in the same format as input.
    """
    if ratio_pad is None:  # calculate from img0_shape
        gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain = old / new
        pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
        pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
    else:
        gain = ratio_pad[0][0]
        pad_x, pad_y = ratio_pad[1]

    if padding:
        boxes[..., 0] -= pad_x  # x padding
        boxes[..., 1] -= pad_y  # y padding
        if not xywh:
            boxes[..., 2] -= pad_x  # x padding
            boxes[..., 3] -= pad_y  # y padding
    boxes[..., :4] /= gain
    return boxes if xywh else clip_boxes(boxes, img0_shape)

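# A worked sketch of scale_boxes above (illustrative only): undoing a 640x640
# letterbox for a 480x640 original gives gain = min(640/480, 640/640) = 1.0 and
# pad_y = (640 - 480) / 2 = 80, so y coordinates shift up by 80 px.
# >>> import torch
# >>> boxes = torch.tensor([[100.0, 180.0, 200.0, 280.0]])  # xyxy in the letterbox
# >>> scale_boxes((640, 640), boxes.clone(), (480, 640))
# tensor([[100., 100., 200., 200.]])
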
def make_divisible(x: int, divisor):
    """
    Return the nearest number that is divisible by the given divisor.

    Args:
        x (int): The number to make divisible.
        divisor (int | torch.Tensor): The divisor.

    Returns:
        (int): The nearest number divisible by the divisor.
    """
    if isinstance(divisor, torch.Tensor):
        divisor = int(divisor.max())  # to int
    return math.ceil(x / divisor) * divisor


def clip_boxes(boxes, shape):
    """
    Clip bounding boxes to image boundaries.

    Args:
        boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.
        shape (tuple): Image shape as HWC or HW (supports both).

    Returns:
        (torch.Tensor | np.ndarray): Clipped bounding boxes.
    """
    h, w = shape[:2]  # supports both HWC or HW shapes
    if isinstance(boxes, torch.Tensor):  # faster individually
        if NOT_MACOS14:
            boxes[..., 0].clamp_(0, w)  # x1
            boxes[..., 1].clamp_(0, h)  # y1
            boxes[..., 2].clamp_(0, w)  # x2
            boxes[..., 3].clamp_(0, h)  # y2
        else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
            boxes[..., 0] = boxes[..., 0].clamp(0, w)
            boxes[..., 1] = boxes[..., 1].clamp(0, h)
            boxes[..., 2] = boxes[..., 2].clamp(0, w)
            boxes[..., 3] = boxes[..., 3].clamp(0, h)
    else:  # np.array (faster grouped)
        boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, w)  # x1, x2
        boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, h)  # y1, y2
    return boxes


def clip_coords(coords, shape):
    """
    Clip line coordinates to image boundaries.

    Args:
        coords (torch.Tensor | np.ndarray): Line coordinates to clip.
        shape (tuple): Image shape as HWC or HW (supports both).

    Returns:
        (torch.Tensor | np.ndarray): Clipped coordinates.
    """
    h, w = shape[:2]  # supports both HWC or HW shapes
    if isinstance(coords, torch.Tensor):
        if NOT_MACOS14:
            coords[..., 0].clamp_(0, w)  # x
            coords[..., 1].clamp_(0, h)  # y
        else:  # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
            coords[..., 0] = coords[..., 0].clamp(0, w)
            coords[..., 1] = coords[..., 1].clamp(0, h)
    else:  # np.array
        coords[..., 0] = coords[..., 0].clip(0, w)  # x
        coords[..., 1] = coords[..., 1].clip(0, h)  # y
    return coords


def scale_image(masks, im0_shape, ratio_pad=None):
    """
    Rescale masks to original image size.

    Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
    that was applied during preprocessing.

    Args:
        masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
        im0_shape (tuple): Original image shape as HWC or HW (supports both).
        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).

    Returns:
        (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
    """
    # Rescale coordinates (xyxy) from im1_shape to im0_shape
    im0_h, im0_w = im0_shape[:2]  # supports both HWC or HW shapes
    im1_h, im1_w, _ = masks.shape
    if im1_h == im0_h and im1_w == im0_w:
        return masks

    if ratio_pad is None:  # calculate from im0_shape
        gain = min(im1_h / im0_h, im1_w / im0_w)  # gain = old / new
        pad = (im1_w - im0_w * gain) / 2, (im1_h - im0_h * gain) / 2  # wh padding
    else:
        pad = ratio_pad[1]

    pad_w, pad_h = pad
    top = int(round(pad_h - 0.1))
    left = int(round(pad_w - 0.1))
    bottom = im1_h - int(round(pad_h + 0.1))
    right = im1_w - int(round(pad_w + 0.1))

    if len(masks.shape) < 2:
        raise ValueError(f'"len(masks.shape)" should be 2 or 3, but got {len(masks.shape)}')
    masks = masks[top:bottom, left:right]
    # handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
    masks = [cv2.resize(array, (im0_w, im0_h)) for array in np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)]
    masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
    if len(masks.shape) == 2:
        masks = masks[:, :, None]

    return masks


def xyxy2xywh(x):
    """
    Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = empty_like(x)  # faster than clone/copy
    x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
    y[..., 0] = (x1 + x2) / 2  # x center
    y[..., 1] = (y1 + y2) / 2  # y center
    y[..., 2] = x2 - x1  # width
    y[..., 3] = y2 - y1  # height
    return y


def xywh2xyxy(x):
    """
    Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
    top-left corner and (x2, y2) is the bottom-right corner. Note: ops per 2 channels faster than per channel.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
    """
    assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
    y = empty_like(x)  # faster than clone/copy
    xy = x[..., :2]  # centers
    wh = x[..., 2:] / 2  # half width-height
    y[..., :2] = xy - wh  # top left xy
    y[..., 2:] = xy + wh  # bottom right xy
    return y

def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
|
||||
"""
|
||||
Convert normalized bounding box coordinates to pixel coordinates.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
|
||||
w (int): Image width in pixels.
|
||||
h (int): Image height in pixels.
|
||||
padw (int): Padding width in pixels.
|
||||
padh (int): Padding height in pixels.
|
||||
|
||||
Returns:
|
||||
y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
|
||||
x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
|
||||
"""
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = empty_like(x) # faster than clone/copy
|
||||
xc, yc, xw, xh = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
|
||||
half_w, half_h = xw / 2, xh / 2
|
||||
y[..., 0] = w * (xc - half_w) + padw # top left x
|
||||
y[..., 1] = h * (yc - half_h) + padh # top left y
|
||||
y[..., 2] = w * (xc + half_w) + padw # bottom right x
|
||||
y[..., 3] = h * (yc + half_h) + padh # bottom right y
|
||||
return y
|
||||
|
||||
|
||||
def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
|
||||
"""
|
||||
Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
|
||||
width and height are normalized to image dimensions.
|
||||
|
||||
Args:
|
||||
x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
|
||||
w (int): Image width in pixels.
|
||||
h (int): Image height in pixels.
|
||||
clip (bool): Whether to clip boxes to image boundaries.
|
||||
eps (float): Minimum value for box width and height.
|
||||
|
||||
Returns:
|
||||
(np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
|
||||
"""
|
||||
if clip:
|
||||
x = clip_boxes(x, (h - eps, w - eps))
|
||||
assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
|
||||
y = empty_like(x) # faster than clone/copy
|
||||
x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
|
||||
y[..., 0] = ((x1 + x2) / 2) / w # x center
|
||||
y[..., 1] = ((y1 + y2) / 2) / h # y center
|
||||
y[..., 2] = (x2 - x1) / w # width
|
||||
y[..., 3] = (y2 - y1) / h # height
|
||||
return y
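
# Normalization sketch (illustrative values): a normalized, centered box mapped to a
# 640x480 (w, h) image and back.
#   >>> import numpy as np
#   >>> nxywh = np.array([[0.5, 0.5, 0.25, 0.5]])  # normalized cx, cy, w, h
#   >>> xyxy = xywhn2xyxy(nxywh, w=640, h=480)  # -> [[240., 120., 400., 360.]]
#   >>> np.allclose(xyxy2xywhn(xyxy, w=640, h=480), nxywh)
#   True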


def xywh2ltwh(x):
    """
    Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    return y


def xyxy2ltwh(x):
    """
    Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def ltwh2xywh(x):
    """
    Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 0] = x[..., 0] + x[..., 2] / 2  # center x
    y[..., 1] = x[..., 1] + x[..., 3] / 2  # center y
    return y
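
# LTWH sketch (illustrative values): COCO-style [left, top, w, h] boxes convert to the
# center-based xywh form and back without loss.
#   >>> import numpy as np
#   >>> coco = np.array([[10.0, 20.0, 40.0, 60.0]])  # ltwh
#   >>> ltwh2xywh(coco)  # -> [[30., 50., 40., 60.]]
#   >>> xywh2ltwh(ltwh2xywh(coco))  # -> [[10., 20., 40., 60.]]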


def xyxyxyxy2xywhr(x):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.

    Args:
        x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.

    Returns:
        (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
            Rotation values are in radians from 0 to pi/2.
    """
    is_torch = isinstance(x, torch.Tensor)
    points = x.cpu().numpy() if is_torch else x
    points = points.reshape(len(x), -1, 2)
    rboxes = []
    for pts in points:
        # NOTE: Use cv2.minAreaRect to get accurate xywhr,
        # especially when some objects are cut off by augmentations in the dataloader.
        (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
        rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
    return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)


def xywhr2xyxyxyxy(x):
    """
    Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.

    Args:
        x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
            Rotation values should be in radians from 0 to pi/2.

    Returns:
        (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
    """
    cos, sin, cat, stack = (
        (torch.cos, torch.sin, torch.cat, torch.stack)
        if isinstance(x, torch.Tensor)
        else (np.cos, np.sin, np.concatenate, np.stack)
    )

    ctr = x[..., :2]
    w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
    cos_value, sin_value = cos(angle), sin(angle)
    vec1 = [w / 2 * cos_value, w / 2 * sin_value]
    vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
    vec1 = cat(vec1, -1)
    vec2 = cat(vec2, -1)
    pt1 = ctr + vec1 + vec2
    pt2 = ctr + vec1 - vec2
    pt3 = ctr - vec1 - vec2
    pt4 = ctr - vec1 + vec2
    return stack([pt1, pt2, pt3, pt4], -2)
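
# OBB sketch (illustrative values): an axis-aligned rotated box expands to its four
# corners; note that xyxyxyxy2xywhr may return the swapped-edge equivalent because
# cv2.minAreaRect reports angles in (0, 90].
#   >>> import torch
#   >>> rbox = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0]])  # cx, cy, w, h, r
#   >>> xywhr2xyxyxyxy(rbox)  # shape (1, 4, 2): corners (60,55), (60,45), (40,45), (40,55)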


def ltwh2xyxy(x):
    """
    Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.

    Args:
        x (np.ndarray | torch.Tensor): Input bounding box coordinates.

    Returns:
        (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
    """
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[..., 2] = x[..., 2] + x[..., 0]  # x2 = x1 + w
    y[..., 3] = x[..., 3] + x[..., 1]  # y2 = y1 + h
    return y


def segments2boxes(segments):
    """
    Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).

    Args:
        segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.

    Returns:
        (np.ndarray): Bounding box coordinates in xywh format.
    """
    boxes = []
    for s in segments:
        x, y = s.T  # segment xy
        boxes.append([x.min(), y.min(), x.max(), y.max()])  # xyxy
    return xyxy2xywh(np.array(boxes))  # xywh


def resample_segments(segments, n: int = 1000):
    """
    Resample segments to n points each using linear interpolation.

    Args:
        segments (list): List of (N, 2) arrays where N is the number of points in each segment.
        n (int): Number of points to resample each segment to.

    Returns:
        (list): Resampled segments with n points each.
    """
    for i, s in enumerate(segments):
        if len(s) == n:
            continue
        s = np.concatenate((s, s[0:1, :]), axis=0)  # close the polygon
        x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
        xp = np.arange(len(s))
        x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x  # keep original vertices
        segments[i] = (
            np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
        )  # segment xy
    return segments
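
# Resampling sketch (illustrative values): a 3-point polygon is closed and resampled
# to a fixed number of points, keeping the original vertices among them.
#   >>> import numpy as np
#   >>> seg = [np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]], dtype=np.float32)]
#   >>> resample_segments(seg, n=10)[0].shape
#   (10, 2)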


def crop_mask(masks, boxes):
    """
    Crop masks to bounding box regions.

    Args:
        masks (torch.Tensor): Masks with shape (N, H, W).
        boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.

    Returns:
        (torch.Tensor): Cropped masks.
    """
    _, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # x coordinates, shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # y coordinates, shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
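
# Broadcasting sketch (illustrative values): r is (1, 1, W), c is (1, H, 1) and each box
# coordinate is (N, 1, 1), so the comparisons broadcast to an (N, H, W) boolean region
# that zeroes mask pixels outside each corresponding box.
#   >>> import torch
#   >>> m = torch.ones(1, 4, 4)
#   >>> crop_mask(m, torch.tensor([[1.0, 1.0, 3.0, 3.0]])).sum()  # 2x2 region kept
#   tensor(4.)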


def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
    """
    Apply masks to bounding boxes using mask head output.

    Args:
        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
        shape (tuple): Input image size as (height, width).
        upsample (bool): Whether to upsample masks to original image size.

    Returns:
        (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
            are the height and width of the input image. The mask is applied to the bounding boxes.
    """
    c, mh, mw = protos.shape  # CHW
    ih, iw = shape
    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)  # CHW
    width_ratio = mw / iw
    height_ratio = mh / ih

    downsampled_bboxes = bboxes.clone()
    downsampled_bboxes[:, 0] *= width_ratio
    downsampled_bboxes[:, 2] *= width_ratio
    downsampled_bboxes[:, 3] *= height_ratio
    downsampled_bboxes[:, 1] *= height_ratio

    masks = crop_mask(masks, downsampled_bboxes)  # CHW
    if upsample:
        masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0]  # CHW
    return masks.gt_(0.0)


def process_mask_native(protos, masks_in, bboxes, shape):
    """
    Apply masks to bounding boxes using mask head output with native upsampling.

    Args:
        protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
        masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
        bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
        shape (tuple): Input image size as (height, width).

    Returns:
        (torch.Tensor): Binary mask tensor with shape (H, W, N).
    """
    c, mh, mw = protos.shape  # CHW
    masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
    masks = scale_masks(masks[None], shape)[0]  # CHW
    masks = crop_mask(masks, bboxes)  # CHW
    return masks.gt_(0.0)


def scale_masks(masks, shape, padding: bool = True):
    """
    Rescale segment masks to target shape.

    Args:
        masks (torch.Tensor): Masks with shape (N, C, H, W).
        shape (tuple): Target height and width as (height, width).
        padding (bool): Whether masks are based on YOLO-style augmented images with padding.

    Returns:
        (torch.Tensor): Rescaled masks.
    """
    mh, mw = masks.shape[2:]
    gain = min(mh / shape[0], mw / shape[1])  # gain = old / new
    pad_w = mw - shape[1] * gain
    pad_h = mh - shape[0] * gain
    if padding:
        pad_w /= 2
        pad_h /= 2
    top, left = (int(round(pad_h - 0.1)), int(round(pad_w - 0.1))) if padding else (0, 0)
    bottom = mh - int(round(pad_h + 0.1))
    right = mw - int(round(pad_w + 0.1))
    return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear", align_corners=False)  # NCHW masks


def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
    """
    Rescale segment coordinates from img1_shape to img0_shape.

    Args:
        img1_shape (tuple): Source image shape as HWC or HW (supports both).
        coords (torch.Tensor): Coordinates to scale with shape (N, 2).
        img0_shape (tuple): Destination image shape as HWC or HW (supports both).
        ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
        normalize (bool): Whether to normalize coordinates to range [0, 1].
        padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.

    Returns:
        (torch.Tensor): Scaled coordinates.
    """
    img0_h, img0_w = img0_shape[:2]  # supports both HWC or HW shapes
    if ratio_pad is None:  # calculate from img0_shape
        img1_h, img1_w = img1_shape[:2]  # supports both HWC or HW shapes
        gain = min(img1_h / img0_h, img1_w / img0_w)  # gain = old / new
        pad = (img1_w - img0_w * gain) / 2, (img1_h - img0_h * gain) / 2  # wh padding
    else:
        gain = ratio_pad[0][0]
        pad = ratio_pad[1]

    if padding:
        coords[..., 0] -= pad[0]  # x padding
        coords[..., 1] -= pad[1]  # y padding
    coords[..., 0] /= gain
    coords[..., 1] /= gain
    coords = clip_coords(coords, img0_shape)
    if normalize:
        coords[..., 0] /= img0_w  # width
        coords[..., 1] /= img0_h  # height
    return coords
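
# Letterbox sketch (illustrative shapes): mapping points from a 640x640 letterboxed
# image back to a 480x640 (h, w) original. Here gain = min(640/480, 640/640) = 1.0 and
# the vertical padding is (640 - 480) / 2 = 80, so y shifts down by 80 before scaling.
# Note that scale_coords modifies coords in place, hence the clone().
#   >>> import torch
#   >>> pts = torch.tensor([[320.0, 400.0]])
#   >>> scale_coords((640, 640), pts.clone(), (480, 640))  # -> tensor([[320., 320.]])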


def regularize_rboxes(rboxes):
    """
    Regularize rotated bounding boxes to range [0, pi/2].

    Args:
        rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.

    Returns:
        (torch.Tensor): Regularized rotated boxes.
    """
    x, y, w, h, t = rboxes.unbind(dim=-1)
    # Swap edge if t >= pi/2 while not being symmetrically opposite
    swap = t % math.pi >= math.pi / 2
    w_ = torch.where(swap, h, w)
    h_ = torch.where(swap, w, h)
    t = t % (math.pi / 2)
    return torch.stack([x, y, w_, h_, t], dim=-1)  # regularized boxes


def masks2segments(masks, strategy: str = "all"):
    """
    Convert masks to segments using contour detection.

    Args:
        masks (torch.Tensor): Binary masks with shape (N, H, W), e.g. (batch_size, 160, 160).
        strategy (str): Segmentation strategy, either 'all' or 'largest'.

    Returns:
        (list): List of segment masks as float32 arrays.
    """
    from ultralytics.data.converter import merge_multi_segment

    segments = []
    for x in masks.int().cpu().numpy().astype("uint8"):
        c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
        if c:
            if strategy == "all":  # merge and concatenate all segments
                c = (
                    np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
                    if len(c) > 1
                    else c[0].reshape(-1, 2)
                )
            elif strategy == "largest":  # select largest segment
                c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
        else:
            c = np.zeros((0, 2))  # no segments found
        segments.append(c.astype("float32"))
    return segments


def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
    """
    Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.

    Args:
        batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.

    Returns:
        (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
    """
    return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
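
# Conversion sketch (illustrative values): a float batch in [0, 1] becomes uint8 BHWC.
#   >>> import torch
#   >>> out = convert_torch2numpy_batch(torch.rand(2, 3, 4, 4))  # BCHW in, BHWC out
#   >>> out.shape, out.dtype  # ((2, 4, 4, 3), dtype('uint8'))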


def clean_str(s):
    """
    Clean a string by replacing special characters with '_' character.

    Args:
        s (str): A string needing special characters replaced.

    Returns:
        (str): A string with special characters replaced by an underscore _.
    """
    return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)


def empty_like(x):
    """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
    return (
        torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
    )
189
ultralytics/utils/patches.py
Normal file
189
ultralytics/utils/patches.py
Normal file
@@ -0,0 +1,189 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""Monkey patches to update/extend functionality of existing functions."""

from __future__ import annotations

import time
from contextlib import contextmanager
from copy import copy
from pathlib import Path
from typing import Any

import cv2
import numpy as np
import torch

# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
_imshow = cv2.imshow  # copy to avoid recursion errors


def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
    """
    Read an image from a file with multilanguage filename support.

    Args:
        filename (str): Path to the file to read.
        flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.

    Returns:
        (np.ndarray | None): The read image array, or None if reading fails.

    Examples:
        >>> img = imread("path/to/image.jpg")
        >>> img = imread("path/to/image.jpg", cv2.IMREAD_GRAYSCALE)
    """
    file_bytes = np.fromfile(filename, np.uint8)
    if filename.endswith((".tiff", ".tif")):
        success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
        if success:
            # Handle RGB images in tif/tiff format
            return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
        return None
    else:
        im = cv2.imdecode(file_bytes, flags)
        return im[..., None] if im is not None and im.ndim == 2 else im  # Always ensure 3 dimensions


def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
    """
    Write an image to a file with multilanguage filename support.

    Args:
        filename (str): Path to the file to write.
        img (np.ndarray): Image to write.
        params (list[int], optional): Additional parameters for image encoding.

    Returns:
        (bool): True if the file was written successfully, False otherwise.

    Examples:
        >>> import numpy as np
        >>> img = np.zeros((100, 100, 3), dtype=np.uint8)  # Create a black image
        >>> success = imwrite("output.jpg", img)  # Write image to file
        >>> print(success)
        True
    """
    try:
        cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
        return True
    except Exception:
        return False


def imshow(winname: str, mat: np.ndarray) -> None:
    """
    Display an image in the specified window with multilanguage window name support.

    This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
    multilanguage window names by encoding them properly for OpenCV compatibility.

    Args:
        winname (str): Name of the window where the image will be displayed. If a window with this name already
            exists, the image will be displayed in that window.
        mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.

    Examples:
        >>> import numpy as np
        >>> img = np.zeros((300, 300, 3), dtype=np.uint8)  # Create a black image
        >>> img[:100, :100] = [255, 0, 0]  # Add a blue square
        >>> imshow("Example Window", img)  # Display the image
    """
    _imshow(winname.encode("unicode_escape").decode(), mat)


# PyTorch functions ----------------------------------------------------------------------------------------------------
_torch_save = torch.save


def torch_load(*args, **kwargs):
    """
    Load a PyTorch model with updated arguments to avoid warnings.

    This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.

    Args:
        *args (Any): Variable length argument list to pass to torch.load.
        **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.

    Returns:
        (Any): The loaded PyTorch object.

    Notes:
        For PyTorch versions 1.13.0 and above, this function automatically sets 'weights_only=False'
        if the argument is not provided, to avoid deprecation warnings.
    """
    from ultralytics.utils.torch_utils import TORCH_1_13

    if TORCH_1_13 and "weights_only" not in kwargs:
        kwargs["weights_only"] = False

    return torch.load(*args, **kwargs)


def torch_save(*args, **kwargs):
    """
    Save PyTorch objects with retry mechanism for robustness.

    This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
    due to device flushing delays or antivirus scanning.

    Args:
        *args (Any): Positional arguments to pass to torch.save.
        **kwargs (Any): Keyword arguments to pass to torch.save.

    Examples:
        >>> model = torch.nn.Linear(10, 1)
        >>> torch_save(model.state_dict(), "model.pt")
    """
    for i in range(4):  # 3 retries
        try:
            return _torch_save(*args, **kwargs)
        except RuntimeError as e:  # Unable to save, possibly waiting for device to flush or antivirus scan
            if i == 3:
                raise e
            time.sleep((2**i) / 2)  # Exponential backoff: 0.5s, 1.0s, 2.0s


@contextmanager
def arange_patch(args):
    """
    Workaround for ONNX torch.arange incompatibility with FP16.

    https://github.com/pytorch/pytorch/issues/148041.
    """
    if args.dynamic and args.half and args.format == "onnx":
        func = torch.arange

        def arange(*args, dtype=None, **kwargs):
            """Call torch.arange without dtype, then cast the result to the requested dtype."""
            return func(*args, **kwargs).to(dtype)  # cast to dtype instead of passing dtype

        torch.arange = arange  # patch
        yield
        torch.arange = func  # unpatch
    else:
        yield


@contextmanager
def override_configs(args, overrides: dict[str, Any] | None = None):
    """
    Context manager to temporarily override configurations in args.

    Args:
        args (IterableSimpleNamespace): Original configuration arguments.
        overrides (dict[str, Any]): Dictionary of overrides to apply.

    Yields:
        (IterableSimpleNamespace): Configuration arguments with overrides applied.
    """
    if overrides:
        original_args = copy(args)
        for key, value in overrides.items():
            setattr(args, key, value)
        try:
            yield args
        finally:
            args.__dict__.update(original_args.__dict__)
    else:
        yield args
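
# Usage sketch (illustrative; a SimpleNamespace stands in for the project's
# IterableSimpleNamespace config object).
#   >>> from types import SimpleNamespace
#   >>> args = SimpleNamespace(imgsz=640, half=False)
#   >>> with override_configs(args, {"half": True}) as a:
#   ...     assert a.half is True
#   >>> args.half  # restored on exit
#   False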
1031
ultralytics/utils/plotting.py
Normal file
1031
ultralytics/utils/plotting.py
Normal file
File diff suppressed because it is too large
Load Diff
417
ultralytics/utils/tal.py
Normal file
417
ultralytics/utils/tal.py
Normal file
@@ -0,0 +1,417 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
import torch.nn as nn

from . import LOGGER
from .metrics import bbox_iou, probiou
from .ops import xywhr2xyxyxyxy
from .torch_utils import TORCH_1_11


class TaskAlignedAssigner(nn.Module):
    """
    A task-aligned assigner for object detection.

    This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
    classification and localization information.

    Attributes:
        topk (int): The number of top candidates to consider.
        num_classes (int): The number of object classes.
        alpha (float): The alpha parameter for the classification component of the task-aligned metric.
        beta (float): The beta parameter for the localization component of the task-aligned metric.
        eps (float): A small value to prevent division by zero.
    """

    def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
        """
        Initialize a TaskAlignedAssigner object with customizable hyperparameters.

        Args:
            topk (int, optional): The number of top candidates to consider.
            num_classes (int, optional): The number of object classes.
            alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
            beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
            eps (float, optional): A small value to prevent division by zero.
        """
        super().__init__()
        self.topk = topk
        self.num_classes = num_classes
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).

        References:
            https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
        """
        self.bs = pd_scores.shape[0]
        self.n_max_boxes = gt_bboxes.shape[1]
        device = gt_bboxes.device

        if self.n_max_boxes == 0:
            return (
                torch.full_like(pd_scores[..., 0], self.num_classes),
                torch.zeros_like(pd_bboxes),
                torch.zeros_like(pd_scores),
                torch.zeros_like(pd_scores[..., 0]),
                torch.zeros_like(pd_scores[..., 0]),
            )

        try:
            return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
        except torch.cuda.OutOfMemoryError:
            # Move tensors to CPU, compute, then move back to original device
            LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
            cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
            result = self._forward(*cpu_tensors)
            return tuple(t.to(device) for t in result)

    def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Compute the task-aligned assignment.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
            target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
            target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
            fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
            target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
        """
        mask_pos, align_metric, overlaps = self.get_pos_mask(
            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
        )

        target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)

        # Assigned target
        target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)  # b, max_num_obj
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)  # b, max_num_obj
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx

    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
        """
        Get positive mask for each ground truth box.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).

        Returns:
            mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
            align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
            overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape
                (bs, max_num_obj, h*w).
        """
        mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
        # Get anchor_align metric, (b, max_num_obj, h*w)
        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
        # Get topk_metric mask, (b, max_num_obj, h*w)
        mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
        # Merge all masks into a final mask, (b, max_num_obj, h*w)
        mask_pos = mask_topk * mask_in_gts * mask_gt

        return mask_pos, align_metric, overlaps

    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
        """
        Compute alignment metric given predicted and ground truth bounding boxes.

        Args:
            pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
            pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
            gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
            gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
            mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).

        Returns:
            align_metric (torch.Tensor): Alignment metric combining classification and localization.
            overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
        """
        na = pd_bboxes.shape[-2]
        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        # Get the scores of each grid for each gt cls
        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w

        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
        overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        return align_metric, overlaps

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """
        Calculate IoU for horizontal bounding boxes.

        Args:
            gt_bboxes (torch.Tensor): Ground truth boxes.
            pd_bboxes (torch.Tensor): Predicted boxes.

        Returns:
            (torch.Tensor): IoU values between each pair of boxes.
        """
        return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)

    def select_topk_candidates(self, metrics, topk_mask=None):
        """
        Select the top-k candidates based on the given metrics.

        Args:
            metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
                the maximum number of objects, and h*w represents the total number of anchor points.
            topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
                topk is the number of top candidates to consider. If not provided, the top-k values are automatically
                computed based on the given metrics.

        Returns:
            (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
        """
        # (b, max_num_obj, topk)
        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
        if topk_mask is None:
            topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
        # (b, max_num_obj, topk)
        topk_idxs.masked_fill_(~topk_mask, 0)

        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
        for k in range(self.topk):
            # Expand topk_idxs for each value of k and add 1 at the specified positions
            count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
        # Filter invalid bboxes
        count_tensor.masked_fill_(count_tensor > 1, 0)

        return count_tensor.to(metrics.dtype)
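
    # Counting sketch (illustrative values): scatter_add_ turns top-k indices into a
    # per-anchor hit count, and the final masked_fill_ zeroes anchors hit more than once
    # (index collisions introduced by masking invalid candidates to index 0).
    #   >>> import torch
    #   >>> count = torch.zeros(1, 1, 5, dtype=torch.int8)
    #   >>> idxs = torch.tensor([[[1, 3]]])
    #   >>> ones = torch.ones(1, 1, 1, dtype=torch.int8)
    #   >>> for k in range(2):
    #   ...     count.scatter_add_(-1, idxs[:, :, k : k + 1], ones)
    #   >>> count  # tensor([[[0, 1, 0, 1, 0]]], dtype=torch.int8)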

    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
        """
        Compute target labels, target bounding boxes, and target scores for the positive anchor points.

        Args:
            gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
                batch size and max_num_obj is the maximum number of objects.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
            target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
                anchor points, with shape (b, h*w), where h*w is the total number of anchor points.
            fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
                (foreground) anchor points.

        Returns:
            target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
            target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
            target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
        """
        # Assigned target labels, (b, 1)
        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)

        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]

        # Assigned target scores
        target_labels.clamp_(0)

        # 10x faster than F.one_hot()
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.num_classes),
            dtype=torch.int64,
            device=target_labels.device,
        )  # (b, h*w, 80)
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        return target_labels, target_bboxes, target_scores

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
        """
        Select positive anchor centers within ground truth bounding boxes.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
            eps (float, optional): Small value for numerical stability.

        Returns:
            (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).

        Note:
            b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
            Bounding box format: [x_min, y_min, x_max, y_max].
        """
        n_anchors = xy_centers.shape[0]
        bs, n_boxes, _ = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)  # left-top, right-bottom
        bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
        return bbox_deltas.amin(3).gt_(eps)

    @staticmethod
    def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
        """
        Select anchor boxes with highest IoU when assigned to multiple ground truths.

        Args:
            mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
            overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
            n_max_boxes (int): Maximum number of ground truth boxes.

        Returns:
            target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
            fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
            mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
        """
        # Convert (b, n_max_boxes, h*w) -> (b, h*w)
        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
            max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
            fg_mask = mask_pos.sum(-2)
        # Find which gt each grid cell serves (index)
        target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
        return target_gt_idx, fg_mask, mask_pos


class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
    """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""

    def iou_calculation(self, gt_bboxes, pd_bboxes):
        """Calculate IoU for rotated bounding boxes."""
        return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)

    @staticmethod
    def select_candidates_in_gts(xy_centers, gt_bboxes):
        """
        Select the positive anchor center in gt for rotated bounding boxes.

        Args:
            xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).

        Returns:
            (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
        """
        # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
        corners = xywhr2xyxyxyxy(gt_bboxes)
        # (b, n_boxes, 1, 2)
        a, b, _, d = corners.split(1, dim=-2)
        ab = b - a
        ad = d - a

        # (b, n_boxes, h*w, 2)
        ap = xy_centers - a
        norm_ab = (ab * ab).sum(dim=-1)
        norm_ad = (ad * ad).sum(dim=-1)
        ap_dot_ab = (ap * ab).sum(dim=-1)
        ap_dot_ad = (ap * ad).sum(dim=-1)
        return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad)  # is_in_box


def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)
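
# Anchor sketch (illustrative shapes): an 80x80 P3 feature map at stride 8 yields 6400
# anchor points at cell centers (offset 0.5), starting (0.5, 0.5), (1.5, 0.5), ...
#   >>> import torch
#   >>> points, strides = make_anchors([torch.zeros(1, 64, 80, 80)], [8])
#   >>> points.shape, strides.shape  # (torch.Size([6400, 2]), torch.Size([6400, 1]))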


def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance (ltrb) to box (xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat([c_xy, wh], dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox


def bbox2dist(anchor_points, bbox, reg_max):
    """Transform bbox (xyxy) to dist (ltrb)."""
    x1y1, x2y2 = bbox.chunk(2, -1)
    return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)  # dist (lt, rb)
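
# Round-trip sketch (illustrative values): for distances within reg_max, bbox2dist
# inverts dist2bbox (up to the 0.01 clamp margin).
#   >>> import torch
#   >>> anchor = torch.tensor([[10.0, 10.0]])
#   >>> dist = torch.tensor([[2.0, 3.0, 4.0, 5.0]])  # l, t, r, b
#   >>> box = dist2bbox(dist, anchor, xywh=False)  # -> [[8., 7., 14., 15.]]
#   >>> bbox2dist(anchor, box, reg_max=16)  # -> [[2., 3., 4., 5.]]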


def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
    """
    Decode predicted rotated bounding box coordinates from anchor points and distribution.

    Args:
        pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
        pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
        anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
        dim (int, optional): Dimension along which to split.

    Returns:
        (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
    """
    lt, rb = pred_dist.split(2, dim=dim)
    cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
    # (bs, h*w, 1)
    xf, yf = ((rb - lt) / 2).split(1, dim=dim)
    x, y = xf * cos - yf * sin, xf * sin + yf * cos
    xy = torch.cat([x, y], dim=dim) + anchor_points
    return torch.cat([xy, lt + rb], dim=dim)
1010
ultralytics/utils/torch_utils.py
Normal file
1010
ultralytics/utils/torch_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
440
ultralytics/utils/tqdm.py
Normal file
440
ultralytics/utils/tqdm.py
Normal file
@@ -0,0 +1,440 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import os
import sys
import time
from functools import lru_cache
from typing import IO, Any


@lru_cache(maxsize=1)
def is_noninteractive_console() -> bool:
    """Check for known non-interactive console environments."""
    return "GITHUB_ACTIONS" in os.environ or "RUNPOD_POD_ID" in os.environ


class TQDM:
    """
    Lightweight zero-dependency progress bar for Ultralytics.

    Provides clean, rich-style progress bars suitable for various environments including Weights & Biases,
    console outputs, and other logging systems. Features zero external dependencies, clean single-line output,
    rich-style progress bars with Unicode block characters, context manager support, iterator protocol support,
    and dynamic description updates.

    Attributes:
        iterable (object): Iterable to wrap with progress bar.
        desc (str): Prefix description for the progress bar.
        total (int): Expected number of iterations.
        disable (bool): Whether to disable the progress bar.
        unit (str): String for units of iteration.
        unit_scale (bool): Auto-scale units flag.
        unit_divisor (int): Divisor for unit scaling.
        leave (bool): Whether to leave the progress bar after completion.
        mininterval (float): Minimum time interval between updates.
        initial (int): Initial counter value.
        n (int): Current iteration count.
        closed (bool): Whether the progress bar is closed.
        bar_format (str): Custom bar format string.
        file (object): Output file stream.

    Methods:
        update: Update progress by n steps.
        set_description: Set or update the description.
        set_postfix: Set postfix for the progress bar.
        close: Close the progress bar and clean up.
        refresh: Refresh the progress bar display.
        clear: Clear the progress bar from display.
        write: Write a message without breaking the progress bar.

    Examples:
        Basic usage with iterator:
        >>> for i in TQDM(range(100)):
        ...     time.sleep(0.01)

        With custom description:
        >>> pbar = TQDM(range(100), desc="Processing")
        >>> for i in pbar:
        ...     pbar.set_description(f"Processing item {i}")

        Context manager usage:
        >>> with TQDM(total=100, unit="B", unit_scale=True) as pbar:
        ...     for i in range(100):
        ...         pbar.update(1)

        Manual updates:
        >>> pbar = TQDM(total=100, desc="Training")
        >>> for epoch in range(100):
        ...     # Do work
        ...     pbar.update(1)
        >>> pbar.close()
    """

    # Constants
    MIN_RATE_CALC_INTERVAL = 0.01  # Minimum time interval for rate calculation
    RATE_SMOOTHING_FACTOR = 0.3  # Factor for exponential smoothing of rates
    MAX_SMOOTHED_RATE = 1000000  # Maximum rate to apply smoothing to
    NONINTERACTIVE_MIN_INTERVAL = 60.0  # Minimum interval for non-interactive environments

    def __init__(
        self,
        iterable: Any = None,
        desc: str | None = None,
        total: int | None = None,
        leave: bool = True,
        file: IO[str] | None = None,
        mininterval: float = 0.1,
        disable: bool | None = None,
        unit: str = "it",
        unit_scale: bool = True,
        unit_divisor: int = 1000,
        bar_format: str | None = None,  # kept for API compatibility; not used for formatting
        initial: int = 0,
        **kwargs,
    ) -> None:
        """
        Initialize the TQDM progress bar with specified configuration options.

        Args:
            iterable (object, optional): Iterable to wrap with progress bar.
            desc (str, optional): Prefix description for the progress bar.
            total (int, optional): Expected number of iterations.
            leave (bool, optional): Whether to leave the progress bar after completion.
            file (object, optional): Output file stream for progress display.
            mininterval (float, optional): Minimum time interval between updates (default 0.1s, 60s in GitHub Actions).
            disable (bool, optional): Whether to disable the progress bar. Auto-detected if None.
            unit (str, optional): String for units of iteration (default "it" for items).
            unit_scale (bool, optional): Auto-scale units for bytes/data units.
            unit_divisor (int, optional): Divisor for unit scaling (default 1000).
            bar_format (str, optional): Custom bar format string.
            initial (int, optional): Initial counter value.
            **kwargs (Any): Additional keyword arguments for compatibility (ignored).

        Examples:
            >>> pbar = TQDM(range(100), desc="Processing")
            >>> with TQDM(total=1000, unit="B", unit_scale=True) as pbar:
            ...     pbar.update(1024)  # Updates by 1KB
        """
        # Disable if not verbose
        if disable is None:
            try:
                from ultralytics.utils import LOGGER, VERBOSE

                disable = not VERBOSE or LOGGER.getEffectiveLevel() > 20
            except ImportError:
                disable = False

        self.iterable = iterable
        self.desc = desc or ""
        self.total = total or (len(iterable) if hasattr(iterable, "__len__") else None) or None  # prevent total=0
        self.disable = disable
        self.unit = unit
        self.unit_scale = unit_scale
        self.unit_divisor = unit_divisor
        self.leave = leave
        self.noninteractive = is_noninteractive_console()
        self.mininterval = max(mininterval, self.NONINTERACTIVE_MIN_INTERVAL) if self.noninteractive else mininterval
        self.initial = initial

        # Kept for API compatibility (unused for f-string formatting)
        self.bar_format = bar_format

        self.file = file or sys.stdout

        # Internal state
        self.n = self.initial
        self.last_print_n = self.initial
        self.last_print_t = time.time()
        self.start_t = time.time()
        self.last_rate = 0.0
        self.closed = False
        self.is_bytes = unit_scale and unit in ("B", "bytes")
        self.scales = (
            [(1073741824, "GB/s"), (1048576, "MB/s"), (1024, "KB/s")]
            if self.is_bytes
            else [(1e9, f"G{self.unit}/s"), (1e6, f"M{self.unit}/s"), (1e3, f"K{self.unit}/s")]
        )

        if not self.disable and self.total and not self.noninteractive:
            self._display()

    def _format_rate(self, rate: float) -> str:
        """Format rate with units."""
        if rate <= 0:
            return ""
        fallback = f"{rate:.1f}B/s" if self.is_bytes else f"{rate:.1f}{self.unit}/s"
        return next((f"{rate / t:.1f}{u}" for t, u in self.scales if rate >= t), fallback)

    def _format_num(self, num: int | float) -> str:
        """Format number with optional unit scaling."""
        if not self.unit_scale or not self.is_bytes:
            return str(num)

        for unit in ("", "K", "M", "G", "T"):
            if abs(num) < self.unit_divisor:
                return f"{num:3.1f}{unit}B" if unit else f"{num:.0f}B"
            num /= self.unit_divisor
        return f"{num:.1f}PB"

    def _format_time(self, seconds: float) -> str:
        """Format time duration."""
        if seconds < 60:
            return f"{seconds:.1f}s"
        elif seconds < 3600:
            return f"{int(seconds // 60)}:{seconds % 60:02.0f}"
        else:
            h, m = int(seconds // 3600), int((seconds % 3600) // 60)
            return f"{h}:{m:02d}:{seconds % 60:02.0f}"

    def _generate_bar(self, width: int = 12) -> str:
        """Generate progress bar."""
        if self.total is None:
            return "━" * width if self.closed else "─" * width

        frac = min(1.0, self.n / self.total)
        filled = int(frac * width)
        bar = "━" * filled + "─" * (width - filled)
        if filled < width and frac * width - filled > 0.5:
            bar = f"{bar[:filled]}╸{bar[filled + 1 :]}"
        return bar

    def _should_update(self, dt: float, dn: int) -> bool:
        """Check if display should update."""
        if self.noninteractive:
            return False
        return (self.total is not None and self.n >= self.total) or (dt >= self.mininterval)

    def _display(self, final: bool = False) -> None:
        """Display progress bar."""
        if self.disable or (self.closed and not final):
            return

        current_time = time.time()
        dt = current_time - self.last_print_t
        dn = self.n - self.last_print_n

        if not final and not self._should_update(dt, dn):
            return

        # Calculate rate (avoid crazy numbers)
        if dt > self.MIN_RATE_CALC_INTERVAL:
            rate = dn / dt if dt else 0.0
            # Smooth rate for reasonable values, use raw rate for very high values
            if rate < self.MAX_SMOOTHED_RATE:
                self.last_rate = self.RATE_SMOOTHING_FACTOR * rate + (1 - self.RATE_SMOOTHING_FACTOR) * self.last_rate
                rate = self.last_rate
        else:
            rate = self.last_rate

        # At completion, use overall rate
        if self.total and self.n >= self.total:
            overall_elapsed = current_time - self.start_t
            if overall_elapsed > 0:
                rate = self.n / overall_elapsed

        # Update counters
        self.last_print_n = self.n
        self.last_print_t = current_time
        elapsed = current_time - self.start_t

        # Remaining time
        remaining_str = ""
        if self.total and 0 < self.n < self.total and elapsed > 0:
            est_rate = rate or (self.n / elapsed)
            remaining_str = f"<{self._format_time((self.total - self.n) / est_rate)}"

        # Numbers and percent
        if self.total:
            percent = (self.n / self.total) * 100
            n_str = self._format_num(self.n)
            t_str = self._format_num(self.total)
            if self.is_bytes:
                # Collapse the unit suffix on the current value only when it matches the total's (e.g. "5.4/5.4MB")
                if n_str[-2] == t_str[-2]:
                    n_str = n_str.rstrip("KMGTPB")  # the total already carries the shared suffix
        else:
            percent = 0.0
            n_str, t_str = self._format_num(self.n), "?"

        elapsed_str = self._format_time(elapsed)
        rate_str = self._format_rate(rate) or (self._format_rate(self.n / elapsed) if elapsed > 0 else "")

        bar = self._generate_bar()

        # Compose progress line via f-strings (two shapes: with/without total)
        if self.total:
            if self.is_bytes and self.n >= self.total:
                # Completed bytes: show only final size
                progress_str = f"{self.desc}: {percent:.0f}% {bar} {t_str} {rate_str} {elapsed_str}"
            else:
                progress_str = (
                    f"{self.desc}: {percent:.0f}% {bar} {n_str}/{t_str} {rate_str} {elapsed_str}{remaining_str}"
                )
        else:
            progress_str = f"{self.desc}: {bar} {n_str} {rate_str} {elapsed_str}"

        # Write to output
        try:
            if self.noninteractive:
                # In non-interactive environments, avoid carriage return which creates empty lines
                self.file.write(progress_str)
            else:
                # In interactive terminals, use carriage return and clear line for updating display
                self.file.write(f"\r\033[K{progress_str}")
            self.file.flush()
        except Exception:
            pass
|
||||
|
||||
    def update(self, n: int = 1) -> None:
        """Update progress by n steps."""
        if not self.disable and not self.closed:
            self.n += n
            self._display()

    def set_description(self, desc: str | None) -> None:
        """Set description."""
        self.desc = desc or ""
        if not self.disable:
            self._display()

    def set_postfix(self, **kwargs: Any) -> None:
        """Set postfix (appends to description)."""
        if kwargs:
            postfix = ", ".join(f"{k}={v}" for k, v in kwargs.items())
            base_desc = self.desc.split(" | ")[0] if " | " in self.desc else self.desc
            self.set_description(f"{base_desc} | {postfix}")

    def close(self) -> None:
        """Close progress bar."""
        if self.closed:
            return

        self.closed = True

        if not self.disable:
            # Final display
            if self.total and self.n >= self.total:
                self.n = self.total
            self._display(final=True)

            # Cleanup
            if self.leave:
                self.file.write("\n")
            else:
                self.file.write("\r\033[K")

            try:
                self.file.flush()
            except Exception:
                pass

    def __enter__(self) -> TQDM:
        """Enter context manager."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Exit context manager and close progress bar."""
        self.close()

    def __iter__(self) -> Any:
        """Iterate over the wrapped iterable with progress updates."""
        if self.iterable is None:
            raise TypeError("'NoneType' object is not iterable")

        try:
            for item in self.iterable:
                yield item
                self.update(1)
        finally:
            self.close()

    def __del__(self) -> None:
        """Destructor to ensure cleanup."""
        try:
            self.close()
        except Exception:
            pass

    def refresh(self) -> None:
        """Refresh display."""
        if not self.disable:
            self._display()

    def clear(self) -> None:
        """Clear progress bar."""
        if not self.disable:
            try:
                self.file.write("\r\033[K")
                self.file.flush()
            except Exception:
                pass

    @staticmethod
    def write(s: str, file: IO[str] | None = None, end: str = "\n") -> None:
        """Static method to write without breaking progress bar."""
        file = file or sys.stdout
        try:
            file.write(s + end)
            file.flush()
        except Exception:
            pass

if __name__ == "__main__":
    import time

    print("1. Basic progress bar with known total:")
    for i in TQDM(range(3), desc="Known total"):
        time.sleep(0.05)

    print("\n2. Manual updates with known total:")
    pbar = TQDM(total=300, desc="Manual updates", unit="files")
    for i in range(300):
        time.sleep(0.03)
        pbar.update(1)
        if i % 10 == 9:
            pbar.set_description(f"Processing batch {i // 10 + 1}")
    pbar.close()

    print("\n3. Progress bar with unknown total:")
    pbar = TQDM(desc="Unknown total", unit="items")
    for i in range(25):
        time.sleep(0.08)
        pbar.update(1)
        if i % 5 == 4:
            pbar.set_postfix(processed=i + 1, status="OK")
    pbar.close()

    print("\n4. Context manager with unknown total:")
    with TQDM(desc="Processing stream", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
        for i in range(30):
            time.sleep(0.1)
            pbar.update(1024 * 1024 * i)  # Simulate processing MB of data

    print("\n5. Iterator with unknown length:")

    def data_stream():
        """Simulate a data stream of unknown length."""
        import random

        for i in range(random.randint(10, 20)):
            yield f"data_chunk_{i}"

    for chunk in TQDM(data_stream(), desc="Stream processing", unit="chunks"):
        time.sleep(0.1)

    print("\n6. File processing simulation (unknown size):")

    def process_files():
        """Simulate processing files of unknown count."""
        return [f"file_{i}.txt" for i in range(18)]

    pbar = TQDM(desc="Scanning files", unit="files")
    files = process_files()
    for i, filename in enumerate(files):
        time.sleep(0.06)
        pbar.update(1)
        pbar.set_description(f"Processing {filename}")
    pbar.close()
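Two details in the display logic above are easy to miss: the iteration rate is smoothed with an exponential moving average, and the unit suffix on the current byte count is collapsed when it matches the total's. The standalone sketch below reproduces both behaviors outside the class; ALPHA and the helper names are hypothetical stand-ins for the class constant RATE_SMOOTHING_FACTOR and the strings produced by _format_num.

ALPHA = 0.3  # hypothetical; the class uses its RATE_SMOOTHING_FACTOR constant


def smooth_rate(prev: float, raw: float) -> float:
    """Exponential moving average: higher ALPHA weights the newest sample more."""
    return ALPHA * raw + (1 - ALPHA) * prev


def collapse_suffix(n_str: str, t_str: str) -> str:
    """Drop the unit from the current value when it matches the total's unit."""
    if n_str[-2] == t_str[-2]:  # same unit letter, e.g. both end in "MB"
        n_str = n_str.rstrip("KMGTPB")
    return n_str


rate = 0.0
for raw in (100.0, 120.0, 80.0, 110.0):  # noisy instantaneous it/s samples
    rate = smooth_rate(rate, raw)
print(f"{rate:.1f} it/s")  # smoothed estimate, steadier than the raw samples

print(collapse_suffix("4.2MB", "5.4MB"))  # "4.2" -> renders as "4.2/5.4MB"
print(collapse_suffix("512.0KB", "5.4MB"))  # units differ, keeps "512.0KB"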
118
ultralytics/utils/triton.py
Normal file
@@ -0,0 +1,118 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from urllib.parse import urlsplit

import numpy as np


class TritonRemoteModel:
    """
    Client for interacting with a remote Triton Inference Server model.

    This class provides a convenient interface for sending inference requests to a Triton Inference Server
    and processing the responses. Supports both HTTP and gRPC communication protocols.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (list[str]): The data types of the model inputs.
        np_input_formats (list[type]): The numpy data types of the model inputs.
        input_names (list[str]): The names of the model inputs.
        output_names (list[str]): The names of the model outputs.
        metadata: The metadata associated with the model.

    Methods:
        __call__: Call the model with the given inputs and return the outputs.

    Examples:
        Initialize a Triton client with HTTP
        >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")

        Make inference with numpy arrays
        >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """
        Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
        <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str, optional): The name of the model on the Triton server.
            scheme (str, optional): The communication scheme ('http' or 'grpc').

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/", 1)[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client  # noqa

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
        self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))

    def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]:
        """
        Call the model with the given inputs and return inference results.

        Args:
            *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
                for the corresponding model input.

        Returns:
            (list[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list
                corresponds to one of the model's output tensors.

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
        """
        infer_inputs = []
        input_format = inputs[0].dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
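When only a collective URL is passed, the constructor above derives the scheme, network location, and model endpoint with urllib.parse.urlsplit. A quick sketch of that parsing in isolation; the host, port, and endpoint values are illustrative only:

from urllib.parse import urlsplit

splits = urlsplit("http://localhost:8000/yolov8")
scheme = splits.scheme  # "http" -> selects tritonclient.http over grpc
netloc = splits.netloc  # "localhost:8000" -> the server address
endpoint = splits.path.strip("/").split("/", 1)[0]  # "yolov8" -> the model name
print(scheme, netloc, endpoint)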
159
ultralytics/utils/tuner.py
Normal file
@@ -0,0 +1,159 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr


def run_ray_tune(
    model,
    space: dict = None,
    grace_period: int = 10,
    gpu_per_trial: int = None,
    max_samples: int = 10,
    **train_args,
):
    """
    Run hyperparameter tuning using Ray Tune.

    Args:
        model (YOLO): Model to run the tuner on.
        space (dict, optional): The hyperparameter search space. If not provided, uses default space.
        grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
        gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
        max_samples (int, optional): The maximum number of trials to run.
        **train_args (Any): Additional arguments to pass to the `train()` method.

    Returns:
        (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.

    Examples:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")  # Load a YOLO11n model

        Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
        >>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
    """
    LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
    try:
        checks.check_requirements("ray[tune]")

        import ray
        from ray import tune
        from ray.air import RunConfig
        from ray.air.integrations.wandb import WandbLoggerCallback
        from ray.tune.schedulers import ASHAScheduler
    except ImportError:
        raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')

    try:
        import wandb

        assert hasattr(wandb, "__version__")
    except (ImportError, AssertionError):
        wandb = False

    checks.check_version(ray.__version__, ">=2.0.0", "ray")
    default_space = {
        # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
        "lr0": tune.uniform(1e-5, 1e-1),
        "lrf": tune.uniform(0.01, 1.0),  # final OneCycleLR learning rate (lr0 * lrf)
        "momentum": tune.uniform(0.6, 0.98),  # SGD momentum/Adam beta1
        "weight_decay": tune.uniform(0.0, 0.001),  # optimizer weight decay
        "warmup_epochs": tune.uniform(0.0, 5.0),  # warmup epochs (fractions ok)
        "warmup_momentum": tune.uniform(0.0, 0.95),  # warmup initial momentum
        "box": tune.uniform(0.02, 0.2),  # box loss gain
        "cls": tune.uniform(0.2, 4.0),  # cls loss gain (scale with pixels)
        "hsv_h": tune.uniform(0.0, 0.1),  # image HSV-Hue augmentation (fraction)
        "hsv_s": tune.uniform(0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
        "hsv_v": tune.uniform(0.0, 0.9),  # image HSV-Value augmentation (fraction)
        "degrees": tune.uniform(0.0, 45.0),  # image rotation (+/- deg)
        "translate": tune.uniform(0.0, 0.9),  # image translation (+/- fraction)
        "scale": tune.uniform(0.0, 0.9),  # image scale (+/- gain)
        "shear": tune.uniform(0.0, 10.0),  # image shear (+/- deg)
        "perspective": tune.uniform(0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
        "flipud": tune.uniform(0.0, 1.0),  # image flip up-down (probability)
        "fliplr": tune.uniform(0.0, 1.0),  # image flip left-right (probability)
        "bgr": tune.uniform(0.0, 1.0),  # image channel BGR (probability)
        "mosaic": tune.uniform(0.0, 1.0),  # image mosaic (probability)
        "mixup": tune.uniform(0.0, 1.0),  # image mixup (probability)
        "cutmix": tune.uniform(0.0, 1.0),  # image cutmix (probability)
        "copy_paste": tune.uniform(0.0, 1.0),  # segment copy-paste (probability)
    }

    # Put the model in ray store
    task = model.task
    model_in_store = ray.put(model)

    def _tune(config):
        """Train the YOLO model with the specified hyperparameters and return results."""
        model_to_train = ray.get(model_in_store)  # get the model from ray store for tuning
        model_to_train.reset_callbacks()
        config.update(train_args)
        results = model_to_train.train(**config)
        return results.results_dict

    # Get search space
    if not space and not train_args.get("resume"):
        space = default_space
        LOGGER.warning("Search space not provided, using default search space.")

    # Get dataset
    data = train_args.get("data", TASK2DATA[task])
    space["data"] = data
    if "data" not in train_args:
        LOGGER.warning(f'Data not provided, using default "data={data}".')

    # Define the trainable function with allocated resources
    trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})

    # Define the ASHA scheduler for hyperparameter search
    asha_scheduler = ASHAScheduler(
        time_attr="epoch",
        metric=TASK2METRIC[task],
        mode="max",
        max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
        grace_period=grace_period,
        reduction_factor=3,
    )

    # Define the callbacks for the hyperparameter search
    tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []

    # Create the Ray Tune hyperparameter search tuner
    tune_dir = get_save_dir(
        get_cfg(
            DEFAULT_CFG,
            {**train_args, **{"exist_ok": train_args.pop("resume", False)}},  # resume w/ same tune_dir
        ),
        name=train_args.pop("name", "tune"),  # runs/{task}/{tune_dir}
    )  # must be absolute dir
    tune_dir.mkdir(parents=True, exist_ok=True)
    if tune.Tuner.can_restore(tune_dir):
        LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
        tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
    else:
        tuner = tune.Tuner(
            trainable_with_resources,
            param_space=space,
            tune_config=tune.TuneConfig(
                scheduler=asha_scheduler,
                num_samples=max_samples,
                trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
                trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
            ),
            run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),
        )

    # Run the hyperparameter search
    tuner.fit()

    # Get the results of the hyperparameter search
    results = tuner.get_results()

    # Shut down Ray to clean up workers
    ray.shutdown()

    return results
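A minimal end-to-end sketch of calling run_ray_tune directly, assuming "ray[tune]" is installed and one GPU is available; the weights file, dataset, trial counts, and the metric key passed to get_best_result (the detection-task entry from TASK2METRIC) are illustrative, not requirements:

from ray import tune

from ultralytics import YOLO
from ultralytics.utils.tuner import run_ray_tune

model = YOLO("yolo11n.pt")
results = run_ray_tune(
    model,
    space={"lr0": tune.uniform(1e-5, 1e-1), "momentum": tune.uniform(0.6, 0.98)},
    grace_period=5,  # epochs each trial is guaranteed before ASHA may stop it
    gpu_per_trial=1,
    max_samples=4,  # total number of trials
    data="coco8.yaml",  # forwarded to train() via **train_args
    epochs=10,
)
print(results.get_best_result(metric="metrics/mAP50-95(B)", mode="max"))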