", self.on_canvas_click)
-
- self.rg_data.clear(), self.current_box.clear()
-
- def on_canvas_click(self, event) -> None:
- """Handle mouse clicks to add points for bounding boxes on the canvas."""
- self.current_box.append((event.x, event.y))
- self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red")
- if len(self.current_box) == 4:
- self.rg_data.append(self.current_box.copy())
- self.draw_box(self.current_box)
- self.current_box.clear()
-
- def draw_box(self, box: list[tuple[int, int]]) -> None:
- """Draw a bounding box on the canvas using the provided coordinates."""
- for i in range(4):
- self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2)
-
- def remove_last_bounding_box(self) -> None:
- """Remove the last bounding box from the list and redraw the canvas."""
- if not self.rg_data:
- self.messagebox.showwarning("Warning", "No bounding boxes to remove.")
- return
- self.rg_data.pop()
- self.redraw_canvas()
-
- def redraw_canvas(self) -> None:
- """Redraw the canvas with the image and all bounding boxes."""
- self.canvas.delete("all")
- self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
- for box in self.rg_data:
- self.draw_box(box)
-
- def save_to_json(self) -> None:
- """Save the selected parking zone points to a JSON file with scaled coordinates."""
- scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()
- data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]
-
- from io import StringIO  # Function-level import, only needed when saving coordinates
-
- write_buffer = StringIO()
- json.dump(data, write_buffer, indent=4)
- with open("bounding_boxes.json", "w", encoding="utf-8") as f:
- f.write(write_buffer.getvalue())
- self.messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json")
-
-
-class ParkingManagement(BaseSolution):
- """
- Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.
-
- This class extends BaseSolution to provide functionality for parking lot management, including detection of
- occupied spaces, visualization of parking regions, and display of occupancy statistics.
-
- Attributes:
- json_file (str): Path to the JSON file containing parking region details.
- json (list[dict]): Loaded JSON data containing parking region information.
- pr_info (dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).
- arc (tuple[int, int, int]): RGB color tuple for available region visualization.
- occ (tuple[int, int, int]): RGB color tuple for occupied region visualization.
- dc (tuple[int, int, int]): RGB color tuple for centroid visualization of detected objects.
-
- Methods:
- process: Process the input image for parking lot management and visualization.
-
- Examples:
- >>> from ultralytics.solutions import ParkingManagement
- >>> parking_manager = ParkingManagement(model="yolo11n.pt", json_file="parking_regions.json")
- >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}")
- >>> print(f"Available spaces: {parking_manager.pr_info['Available']}")
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the parking management system with a YOLO model and visualization settings."""
- super().__init__(**kwargs)
-
- self.json_file = self.CFG["json_file"] # Load parking regions JSON data
- if self.json_file is None:
- LOGGER.warning("json_file argument missing. Parking region details required.")
- raise ValueError("❌ JSON file path cannot be empty")
-
- with open(self.json_file) as f:
- self.json = json.load(f)
-
- self.pr_info = {"Occupancy": 0, "Available": 0} # Dictionary for parking information
-
- self.arc = (0, 0, 255) # Available region color
- self.occ = (0, 255, 0) # Occupied region color
- self.dc = (255, 0, 189) # Centroid color for each box
-
- def process(self, im0: np.ndarray) -> SolutionResults:
- """
- Process the input image for parking lot management and visualization.
-
- This function analyzes the input image, extracts tracks, and determines the occupancy status of parking
- regions defined in the JSON file. It annotates the image with occupied and available parking spots,
- and updates the parking information.
-
- Args:
- im0 (np.ndarray): The input inference image.
-
- Returns:
- (SolutionResults): Contains processed image `plot_im`, 'filled_slots' (number of occupied parking slots),
- 'available_slots' (number of available parking slots), and 'total_tracks' (total number of tracked objects).
-
- Examples:
- >>> parking_manager = ParkingManagement(json_file="parking_regions.json")
- >>> image = cv2.imread("parking_lot.jpg")
- >>> results = parking_manager.process(image)
- """
- self.extract_tracks(im0) # Extract tracks from im0
- es, fs = len(self.json), 0 # Empty slots, filled slots
- annotator = SolutionAnnotator(im0, self.line_width) # Initialize annotator
-
- for region in self.json:
- # Convert points to a NumPy array with the correct dtype and reshape properly
- pts_array = np.array(region["points"], dtype=np.int32).reshape((-1, 1, 2))
- rg_occupied = False # Occupied region initialization
- for box, cls in zip(self.boxes, self.clss):
- xc, yc = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
- dist = cv2.pointPolygonTest(pts_array, (xc, yc), False)
- if dist >= 0:
- # cv2.circle(im0, (xc, yc), radius=self.line_width * 4, color=self.dc, thickness=-1)
- annotator.display_objects_labels(
- im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10
- )
- rg_occupied = True
- break
- fs, es = (fs + 1, es - 1) if rg_occupied else (fs, es)
- # Plot regions
- cv2.polylines(im0, [pts_array], isClosed=True, color=self.occ if rg_occupied else self.arc, thickness=2)
-
- self.pr_info["Occupancy"], self.pr_info["Available"] = fs, es
-
- annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10)
-
- plot_im = annotator.result()
- self.display_output(plot_im) # Display output with base class function
-
- # Return SolutionResults
- return SolutionResults(
- plot_im=plot_im,
- filled_slots=self.pr_info["Occupancy"],
- available_slots=self.pr_info["Available"],
- total_tracks=len(self.track_ids),
- )
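# --- Illustrative sketch (not part of the original diff) ---
# A minimal video loop for ParkingManagement, assuming a "parking_regions.json"
# produced by ParkingPtsSelection.save_to_json() above, i.e. a list of
# {"points": [[x, y], ...]} entries. File names here are hypothetical.
import cv2
from ultralytics.solutions import ParkingManagement

parking_manager = ParkingManagement(model="yolo11n.pt", json_file="parking_regions.json")
cap = cv2.VideoCapture("parking_lot.mp4")
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    results = parking_manager.process(frame)
    print(f"Occupied: {results.filled_slots}, Available: {results.available_slots}")
cap.release()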
diff --git a/ultralytics/solutions/queue_management.py b/ultralytics/solutions/queue_management.py
deleted file mode 100644
index 4cdbfa8..0000000
--- a/ultralytics/solutions/queue_management.py
+++ /dev/null
@@ -1,95 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from typing import Any
-
-from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
-from ultralytics.utils.plotting import colors
-
-
-class QueueManager(BaseSolution):
- """
- Manages queue counting in real-time video streams based on object tracks.
-
- This class extends BaseSolution to provide functionality for tracking and counting objects within a specified
- region in video frames.
-
- Attributes:
- counts (int): The current count of objects in the queue.
- rect_color (tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.
- region_length (int): The number of points defining the queue region.
- track_line (list[tuple[int, int]]): List of track line coordinates.
- track_history (dict[int, list[tuple[int, int]]]): Dictionary storing tracking history for each object.
-
- Methods:
- initialize_region: Initialize the queue region.
- process: Process a single frame for queue management.
- extract_tracks: Extract object tracks from the current frame.
- store_tracking_history: Store the tracking history for an object.
- display_output: Display the processed output.
-
- Examples:
- >>> cap = cv2.VideoCapture("path/to/video.mp4")
- >>> queue_manager = QueueManager(region=[(100, 100), (200, 100), (200, 200), (100, 200)])
- >>> while cap.isOpened():
- >>> success, im0 = cap.read()
- >>> if not success:
- >>> break
- >>> results = queue_manager.process(im0)
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the QueueManager with parameters for tracking and counting objects in a video stream."""
- super().__init__(**kwargs)
- self.initialize_region()
- self.counts = 0 # Queue counts information
- self.rect_color = (255, 255, 255) # Rectangle color for visualization
- self.region_length = len(self.region) # Store region length for further usage
-
- def process(self, im0) -> SolutionResults:
- """
- Process queue management for a single frame of video.
-
- Args:
- im0 (np.ndarray): Input image for processing, typically a frame from a video stream.
-
- Returns:
- (SolutionResults): Contains processed image `plot_im`, 'queue_count' (int, number of objects in the queue) and
- 'total_tracks' (int, total number of tracked objects).
-
- Examples:
- >>> queue_manager = QueueManager()
- >>> frame = cv2.imread("frame.jpg")
- >>> results = queue_manager.process(frame)
- """
- self.counts = 0 # Reset counts every frame
- self.extract_tracks(im0) # Extract tracks from the current frame
- annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator
- annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2) # Draw region
-
- for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):
- # Draw bounding box and counting region
- annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))
- self.store_tracking_history(track_id, box) # Store track history
-
- # Cache frequently accessed attributes
- track_history = self.track_history.get(track_id, [])
-
- # Store previous position of track and check if the object is inside the counting region
- prev_position = None
- if len(track_history) > 1:
- prev_position = track_history[-2]
- if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):
- self.counts += 1
-
- # Display queue counts
- annotator.queue_counts_display(
- f"Queue Counts : {self.counts}",
- points=self.region,
- region_color=self.rect_color,
- txt_color=(104, 31, 17),
- )
- plot_im = annotator.result()
- self.display_output(plot_im) # Display output with base class function
-
- # Return a SolutionResults object with processed data
- return SolutionResults(plot_im=plot_im, queue_count=self.counts, total_tracks=len(self.track_ids))
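# --- Illustrative sketch (not part of the original diff) ---
# Single-frame queue check, assuming the region is given as (x, y) point tuples
# (the format consumed by initialize_region). The frame path is hypothetical.
import cv2
from ultralytics.solutions import QueueManager

queue_manager = QueueManager(model="yolo11n.pt", region=[(20, 400), (1080, 400), (1080, 360), (20, 360)])
frame = cv2.imread("frame.jpg")
results = queue_manager.process(frame)
print(f"Objects in queue: {results.queue_count}")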
diff --git a/ultralytics/solutions/region_counter.py b/ultralytics/solutions/region_counter.py
deleted file mode 100644
index 2f4d6fa..0000000
--- a/ultralytics/solutions/region_counter.py
+++ /dev/null
@@ -1,136 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from typing import Any
-
-import numpy as np
-
-from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
-from ultralytics.utils.plotting import colors
-
-
-class RegionCounter(BaseSolution):
- """
- A class for real-time counting of objects within user-defined regions in a video stream.
-
- This class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,
- track objects, and count those objects that pass through each defined region. Useful for applications requiring
- counting in specified areas, such as monitoring zones or segmented sections.
-
- Attributes:
- region_template (dict): Template for creating new counting regions with default attributes including name,
- polygon coordinates, and display colors.
- counting_regions (list): List storing all defined regions, where each entry is based on `region_template`
- and includes specific region settings like name, coordinates, and color.
- region_counts (dict): Dictionary storing the count of objects for each named region.
-
- Methods:
- add_region: Add a new counting region with specified attributes.
- process: Process video frames to count objects in each region.
- initialize_regions: Initialize one or more counting zones from the configured region(s).
-
- Examples:
- Initialize a RegionCounter and add a counting region
- >>> counter = RegionCounter()
- >>> counter.add_region("Zone1", [(100, 100), (200, 100), (200, 200), (100, 200)], (255, 0, 0), (255, 255, 255))
- >>> results = counter.process(frame)
- >>> print(f"Total tracks: {results.total_tracks}")
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the RegionCounter for real-time object counting in user-defined regions."""
- super().__init__(**kwargs)
- self.region_template = {
- "name": "Default Region",
- "polygon": None,
- "counts": 0,
- "region_color": (255, 255, 255),
- "text_color": (0, 0, 0),
- }
- self.region_counts = {}
- self.counting_regions = []
- self.initialize_regions()
-
- def add_region(
- self,
- name: str,
- polygon_points: list[tuple],
- region_color: tuple[int, int, int],
- text_color: tuple[int, int, int],
- ) -> dict[str, Any]:
- """
- Add a new region to the counting list based on the provided template with specific attributes.
-
- Args:
- name (str): Name assigned to the new region.
- polygon_points (list[tuple]): List of (x, y) coordinates defining the region's polygon.
- region_color (tuple[int, int, int]): BGR color for region visualization.
- text_color (tuple[int, int, int]): BGR color for the text within the region.
-
- Returns:
- (dict[str, Any]): The added region entry, including its name, polygon, and display colors.
- """
- region = self.region_template.copy()
- region.update(
- {
- "name": name,
- "polygon": self.Polygon(polygon_points),
- "region_color": region_color,
- "text_color": text_color,
- }
- )
- self.counting_regions.append(region)
- return region
-
- def initialize_regions(self):
- """Initialize regions only once."""
- if self.region is None:
- self.initialize_region()
- if not isinstance(self.region, dict): # Ensure self.region is initialized and structured as a dictionary
- self.region = {"Region#01": self.region}
- for i, (name, pts) in enumerate(self.region.items()):
- region = self.add_region(name, pts, colors(i, True), (255, 255, 255))
- region["prepared_polygon"] = self.prep(region["polygon"])
-
- def process(self, im0: np.ndarray) -> SolutionResults:
- """
- Process the input frame to detect and count objects within each defined region.
-
- Args:
- im0 (np.ndarray): Input image frame where objects and regions are annotated.
-
- Returns:
- (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked objects),
- and 'region_counts' (dict, counts of objects per region).
- """
- self.extract_tracks(im0)
- annotator = SolutionAnnotator(im0, line_width=self.line_width)
-
- for box, cls, track_id, conf in zip(self.boxes, self.clss, self.track_ids, self.confs):
- annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))
- center = self.Point(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
- for region in self.counting_regions:
- if region["prepared_polygon"].contains(center):
- region["counts"] += 1
- self.region_counts[region["name"]] = region["counts"]
-
- # Display region counts
- for region in self.counting_regions:
- poly = region["polygon"]
- pts = list(map(tuple, np.array(poly.exterior.coords, dtype=np.int32)))
- (x1, y1), (x2, y2) = [(int(poly.centroid.x), int(poly.centroid.y))] * 2  # Duplicate the centroid so adaptive_label centers the count there
- annotator.draw_region(pts, region["region_color"], self.line_width * 2)
- annotator.adaptive_label(
- [x1, y1, x2, y2],
- label=str(region["counts"]),
- color=region["region_color"],
- txt_color=region["text_color"],
- margin=self.line_width * 4,
- shape="rect",
- )
- region["counts"] = 0 # Reset for next frame
- plot_im = annotator.result()
- self.display_output(plot_im)
-
- return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), region_counts=self.region_counts)
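# --- Illustrative sketch (not part of the original diff) ---
# Counting in several named zones at once. As initialize_regions above shows,
# `region` may be a dict mapping names to point lists (a plain list is wrapped
# as "Region#01"); accepting a dict at the config layer is assumed here, and the
# names, coordinates, and example output are hypothetical.
import cv2
from ultralytics.solutions import RegionCounter

regions = {
    "Entrance": [(50, 50), (250, 50), (250, 250), (50, 250)],
    "Checkout": [(400, 50), (600, 50), (600, 250), (400, 250)],
}
counter = RegionCounter(model="yolo11n.pt", region=regions)
frame = cv2.imread("store.jpg")
results = counter.process(frame)
print(results.region_counts)  # e.g. {"Entrance": 3, "Checkout": 1}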
diff --git a/ultralytics/solutions/security_alarm.py b/ultralytics/solutions/security_alarm.py
deleted file mode 100644
index d34f78d..0000000
--- a/ultralytics/solutions/security_alarm.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from typing import Any
-
-from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
-from ultralytics.utils import LOGGER
-from ultralytics.utils.plotting import colors
-
-
-class SecurityAlarm(BaseSolution):
- """
- A class to manage security alarm functionalities for real-time monitoring.
-
- This class extends the BaseSolution class and provides features to monitor objects in a frame, send email
- notifications when specific thresholds are exceeded for total detections, and annotate the output frame for
- visualization.
-
- Attributes:
- email_sent (bool): Flag to track if an email has already been sent for the current event.
- records (int): Threshold for the number of detected objects to trigger an alert.
- server (smtplib.SMTP): SMTP server connection for sending email alerts.
- to_email (str): Recipient's email address for alerts.
- from_email (str): Sender's email address for alerts.
-
- Methods:
- authenticate: Set up email server authentication for sending alerts.
- send_email: Send an email notification with details and an image attachment.
- process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.
-
- Examples:
- >>> security = SecurityAlarm()
- >>> security.authenticate("abc@gmail.com", "1111222233334444", "xyz@gmail.com")
- >>> frame = cv2.imread("frame.jpg")
- >>> results = security.process(frame)
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """
- Initialize the SecurityAlarm class with parameters for real-time object monitoring.
-
- Args:
- **kwargs (Any): Additional keyword arguments passed to the parent class.
- """
- super().__init__(**kwargs)
- self.email_sent = False
- self.records = self.CFG["records"]
- self.server = None
- self.to_email = ""
- self.from_email = ""
-
- def authenticate(self, from_email: str, password: str, to_email: str) -> None:
- """
- Authenticate the email server for sending alert notifications.
-
- Args:
- from_email (str): Sender's email address.
- password (str): Password for the sender's email account.
- to_email (str): Recipient's email address.
-
- This method initializes a secure connection with the SMTP server and logs in using the provided credentials.
-
- Examples:
- >>> alarm = SecurityAlarm()
- >>> alarm.authenticate("sender@example.com", "password123", "recipient@example.com")
- """
- import smtplib
-
- self.server = smtplib.SMTP("smtp.gmail.com", 587)
- self.server.starttls()
- self.server.login(from_email, password)
- self.to_email = to_email
- self.from_email = from_email
-
- def send_email(self, im0, records: int = 5) -> None:
- """
- Send an email notification with an image attachment indicating the number of objects detected.
-
- Args:
- im0 (np.ndarray): The input image or frame to be attached to the email.
- records (int, optional): The number of detected objects to be included in the email message.
-
- This method encodes the input image, composes the email message with details about the detection, and sends it
- to the specified recipient.
-
- Examples:
- >>> alarm = SecurityAlarm()
- >>> frame = cv2.imread("path/to/image.jpg")
- >>> alarm.send_email(frame, records=10)
- """
- from email.mime.image import MIMEImage
- from email.mime.multipart import MIMEMultipart
- from email.mime.text import MIMEText
-
- import cv2
-
- img_bytes = cv2.imencode(".jpg", im0)[1].tobytes() # Encode the image as JPEG
-
- # Create the email
- message = MIMEMultipart()
- message["From"] = self.from_email
- message["To"] = self.to_email
- message["Subject"] = "Security Alert"
-
- # Add the text message body
- message_body = f"Ultralytics ALERT!!! {records} objects have been detected!!"
- message.attach(MIMEText(message_body))
-
- # Attach the image
- image_attachment = MIMEImage(img_bytes, name="ultralytics.jpg")
- message.attach(image_attachment)
-
- # Send the email
- try:
- self.server.send_message(message)
- LOGGER.info("Email sent successfully!")
- except Exception as e:
- LOGGER.error(f"Failed to send email: {e}")
-
- def process(self, im0) -> SolutionResults:
- """
- Monitor the frame, process object detections, and trigger alerts if thresholds are exceeded.
-
- Args:
- im0 (np.ndarray): The input image or frame to be processed and annotated.
-
- Returns:
- (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (total number of tracked objects) and
- 'email_sent' (whether an email alert was triggered).
-
- This method processes the input frame, extracts detections, annotates the frame with bounding boxes, and sends
- an email notification if the number of detected objects surpasses the specified threshold and an alert has not
- already been sent.
-
- Examples:
- >>> alarm = SecurityAlarm()
- >>> frame = cv2.imread("path/to/image.jpg")
- >>> results = alarm.process(frame)
- """
- self.extract_tracks(im0) # Extract tracks
- annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator
-
- # Iterate over bounding boxes and classes index
- for box, cls in zip(self.boxes, self.clss):
- # Draw bounding box
- annotator.box_label(box, label=self.names[cls], color=colors(cls, True))
-
- total_det = len(self.clss)
- if total_det >= self.records and not self.email_sent: # Only send email if not sent before
- self.send_email(im0, total_det)
- self.email_sent = True
-
- plot_im = annotator.result()
- self.display_output(plot_im) # Display output with base class function
-
- # Return a SolutionResults
- return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), email_sent=self.email_sent)
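# --- Illustrative sketch (not part of the original diff) ---
# End-to-end use, with the alert threshold supplied via the "records" config key
# read in __init__. Gmail typically requires an app password for SMTP logins.
# Credentials and file names below are placeholders.
import cv2
from ultralytics.solutions import SecurityAlarm

alarm = SecurityAlarm(model="yolo11n.pt", records=5)  # alert once >= 5 objects are detected
alarm.authenticate("sender@example.com", "app-password", "recipient@example.com")
frame = cv2.imread("frame.jpg")
results = alarm.process(frame)
print(f"Email sent: {results.email_sent}")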
diff --git a/ultralytics/solutions/similarity_search.py b/ultralytics/solutions/similarity_search.py
deleted file mode 100644
index 37a13ec..0000000
--- a/ultralytics/solutions/similarity_search.py
+++ /dev/null
@@ -1,224 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import os
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-from PIL import Image
-
-from ultralytics.data.utils import IMG_FORMATS
-from ultralytics.utils import LOGGER, TORCH_VERSION
-from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.torch_utils import TORCH_2_4, select_device
-
-os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # Avoid OpenMP conflict on some systems
-
-
-class VisualAISearch:
- """
- A semantic image search system that leverages OpenCLIP for generating high-quality image and text embeddings and
- FAISS for fast similarity-based retrieval.
-
- This class aligns image and text embeddings in a shared semantic space, enabling users to search large collections
- of images using natural language queries with high accuracy and speed.
-
- Attributes:
- data (str): Directory containing images.
- device (str): Computation device, e.g., 'cpu' or 'cuda'.
- faiss_index (str): Path to the FAISS index file.
- data_path_npy (str): Path to the numpy file storing image paths.
- data_dir (Path): Path object for the data directory.
- model: Loaded CLIP model.
- index: FAISS index for similarity search.
- image_paths (list[str]): List of image file paths.
-
- Methods:
- extract_image_feature: Extract CLIP embedding from an image.
- extract_text_feature: Extract CLIP embedding from text.
- load_or_build_index: Load existing FAISS index or build new one.
- search: Perform semantic search for similar images.
-
- Examples:
- Initialize and search for images
- >>> searcher = VisualAISearch(data="path/to/images", device="cuda")
- >>> results = searcher.search("a cat sitting on a chair", k=10)
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """Initialize the VisualAISearch class with FAISS index and CLIP model."""
- assert TORCH_2_4, f"VisualAISearch requires torch>=2.4 (found torch=={TORCH_VERSION})"
- from ultralytics.nn.text_model import build_text_model
-
- check_requirements("faiss-cpu")
-
- self.faiss = __import__("faiss")
- self.faiss_index = "faiss.index"
- self.data_path_npy = "paths.npy"
- self.data_dir = Path(kwargs.get("data", "images"))
- self.device = select_device(kwargs.get("device", "cpu"))
-
- if not self.data_dir.exists():
- from ultralytics.utils import ASSETS_URL
-
- LOGGER.warning(f"{self.data_dir} not found. Downloading images.zip from {ASSETS_URL}/images.zip")
- from ultralytics.utils.downloads import safe_download
-
- safe_download(url=f"{ASSETS_URL}/images.zip", unzip=True, retry=3)
- self.data_dir = Path("images")
-
- self.model = build_text_model("clip:ViT-B/32", device=self.device)
-
- self.index = None
- self.image_paths = []
-
- self.load_or_build_index()
-
- def extract_image_feature(self, path: Path) -> np.ndarray:
- """Extract CLIP image embedding from the given image path."""
- return self.model.encode_image(Image.open(path)).cpu().numpy()
-
- def extract_text_feature(self, text: str) -> np.ndarray:
- """Extract CLIP text embedding from the given text query."""
- return self.model.encode_text(self.model.tokenize([text])).cpu().numpy()
-
- def load_or_build_index(self) -> None:
- """
- Load existing FAISS index or build a new one from image features.
-
- Checks if FAISS index and image paths exist on disk. If found, loads them directly. Otherwise, builds a new
- index by extracting features from all images in the data directory, normalizes the features, and saves both the
- index and image paths for future use.
- """
- # Check if the FAISS index and corresponding image paths already exist
- if Path(self.faiss_index).exists() and Path(self.data_path_npy).exists():
- LOGGER.info("Loading existing FAISS index...")
- self.index = self.faiss.read_index(self.faiss_index) # Load the FAISS index from disk
- self.image_paths = np.load(self.data_path_npy) # Load the saved image path list
- return # Exit the function as the index is successfully loaded
-
- # If the index doesn't exist, start building it from scratch
- LOGGER.info("Building FAISS index from images...")
- vectors = [] # List to store feature vectors of images
-
- # Iterate over all image files in the data directory
- for file in self.data_dir.iterdir():
- # Skip files that are not valid image formats
- if file.suffix.lower().lstrip(".") not in IMG_FORMATS:
- continue
- try:
- # Extract feature vector for the image and add to the list
- vectors.append(self.extract_image_feature(file))
- self.image_paths.append(file.name) # Store the corresponding image name
- except Exception as e:
- LOGGER.warning(f"Skipping {file.name}: {e}")
-
- # If no vectors were successfully created, raise an error
- if not vectors:
- raise RuntimeError("No image embeddings could be generated.")
-
- vectors = np.vstack(vectors).astype("float32") # Stack all vectors into a NumPy array and convert to float32
- self.faiss.normalize_L2(vectors) # Normalize vectors to unit length for cosine similarity
-
- self.index = self.faiss.IndexFlatIP(vectors.shape[1]) # Create a new FAISS index using inner product
- self.index.add(vectors) # Add the normalized vectors to the FAISS index
- self.faiss.write_index(self.index, self.faiss_index) # Save the newly built FAISS index to disk
- np.save(self.data_path_npy, np.array(self.image_paths)) # Save the list of image paths to disk
-
- LOGGER.info(f"Indexed {len(self.image_paths)} images.")
-
- def search(self, query: str, k: int = 30, similarity_thresh: float = 0.1) -> list[str]:
- """
- Return top-k semantically similar images to the given query.
-
- Args:
- query (str): Natural language text query to search for.
- k (int, optional): Maximum number of results to return.
- similarity_thresh (float, optional): Minimum similarity threshold for filtering results.
-
- Returns:
- (list[str]): List of image filenames ranked by similarity score.
-
- Examples:
- Search for images matching a query
- >>> searcher = VisualAISearch(data="images")
- >>> results = searcher.search("red car", k=5, similarity_thresh=0.2)
- """
- text_feat = self.extract_text_feature(query).astype("float32")
- self.faiss.normalize_L2(text_feat)
-
- D, index = self.index.search(text_feat, k)
- results = [
- (self.image_paths[i], float(D[0][idx])) for idx, i in enumerate(index[0]) if D[0][idx] >= similarity_thresh
- ]
- results.sort(key=lambda x: x[1], reverse=True)
-
- LOGGER.info("\nRanked Results:")
- for name, score in results:
- LOGGER.info(f" - {name} | Similarity: {score:.4f}")
-
- return [r[0] for r in results]
-
- def __call__(self, query: str) -> list[str]:
- """Direct call interface for the search function."""
- return self.search(query)
-
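# --- Illustrative note (not part of the original diff) ---
# load_or_build_index() L2-normalizes the embeddings and uses faiss.IndexFlatIP,
# so the inner product of unit vectors equals cosine similarity. A tiny check:
import numpy as np

a = np.array([3.0, 4.0])
b = np.array([4.0, 3.0])
a /= np.linalg.norm(a)  # [0.6, 0.8]
b /= np.linalg.norm(b)  # [0.8, 0.6]
print(float(a @ b))  # ~0.96, the cosine similarity of the original vectors (24 / 25)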
-
-class SearchApp:
- """
- A Flask-based web interface for semantic image search with natural language queries.
-
- This class provides a clean, responsive frontend that enables users to input natural language queries and
- instantly view the most relevant images retrieved from the indexed database.
-
- Attributes:
- render_template: Flask template rendering function.
- request: Flask request object.
- searcher (VisualAISearch): Instance of the VisualAISearch class.
- app (Flask): Flask application instance.
-
- Methods:
- index: Process user queries and display search results.
- run: Start the Flask web application.
-
- Examples:
- Start a search application
- >>> app = SearchApp(data="path/to/images", device="cuda")
- >>> app.run(debug=True)
- """
-
- def __init__(self, data: str = "images", device: str | None = None) -> None:
- """
- Initialize the SearchApp with VisualAISearch backend.
-
- Args:
- data (str, optional): Path to directory containing images to index and search.
- device (str, optional): Device to run inference on (e.g. 'cpu', 'cuda').
- """
- check_requirements("flask>=3.0.1")
- from flask import Flask, render_template, request
-
- self.render_template = render_template
- self.request = request
- self.searcher = VisualAISearch(data=data, device=device)
- self.app = Flask(
- __name__,
- template_folder="templates",
- static_folder=Path(data).resolve(), # Absolute path to serve images
- static_url_path="/images", # URL prefix for images
- )
- self.app.add_url_rule("/", view_func=self.index, methods=["GET", "POST"])
-
- def index(self) -> str:
- """Process user query and display search results in the web interface."""
- results = []
- if self.request.method == "POST":
- query = self.request.form.get("query", "").strip()
- results = self.searcher(query)
- return self.render_template("similarity-search.html", results=results)
-
- def run(self, debug: bool = False) -> None:
- """Start the Flask web application server."""
- self.app.run(debug=debug)
diff --git a/ultralytics/solutions/solutions.py b/ultralytics/solutions/solutions.py
deleted file mode 100644
index b7da9e6..0000000
--- a/ultralytics/solutions/solutions.py
+++ /dev/null
@@ -1,827 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import math
-from collections import Counter, defaultdict
-from functools import lru_cache
-from typing import Any
-
-import cv2
-import numpy as np
-
-from ultralytics import YOLO
-from ultralytics.solutions.config import SolutionConfig
-from ultralytics.utils import ASSETS_URL, LOGGER, ops
-from ultralytics.utils.checks import check_imshow, check_requirements
-from ultralytics.utils.plotting import Annotator
-
-
-class BaseSolution:
- """
- A base class for managing Ultralytics Solutions.
-
- This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
- and region initialization. It serves as the foundation for implementing specific computer vision solutions such as
- object counting, pose estimation, and analytics.
-
- Attributes:
- LineString: Class for creating line string geometries from shapely.
- Polygon: Class for creating polygon geometries from shapely.
- Point: Class for creating point geometries from shapely.
- prep: Prepared geometry function from shapely for optimized spatial operations.
- CFG (dict[str, Any]): Configuration dictionary loaded from YAML file and updated with kwargs.
- LOGGER: Logger instance for solution-specific logging.
- annotator: Annotator instance for drawing on images.
- tracks: YOLO tracking results from the latest inference.
- track_data: Extracted tracking data (boxes or OBB) from tracks.
- boxes (list): Bounding box coordinates from tracking results.
- clss (list[int]): Class indices from tracking results.
- track_ids (list[int]): Track IDs from tracking results.
- confs (list[float]): Confidence scores from tracking results.
- track_line: Current track line for storing tracking history.
- masks: Segmentation masks from tracking results.
- r_s: Region or line geometry object for spatial operations.
- frame_no (int): Current frame number for logging purposes.
- region (list[tuple[int, int]]): List of coordinate tuples defining region of interest.
- line_width (int): Width of lines used in visualizations.
- model (YOLO): Loaded YOLO model instance.
- names (dict[int, str]): Dictionary mapping class indices to class names.
- classes (list[int]): List of class indices to track.
- show_conf (bool): Flag to show confidence scores in annotations.
- show_labels (bool): Flag to show class labels in annotations.
- device (str): Device for model inference.
- track_add_args (dict[str, Any]): Additional arguments for tracking configuration.
- env_check (bool): Flag indicating whether environment supports image display.
- track_history (defaultdict): Dictionary storing tracking history for each object.
- profilers (tuple): Profiler instances for performance monitoring.
-
- Methods:
- adjust_box_label: Generate formatted label for bounding box.
- extract_tracks: Apply object tracking and extract tracks from input image.
- store_tracking_history: Store object tracking history for given track ID and bounding box.
- initialize_region: Initialize counting region and line segment based on configuration.
- display_output: Display processing results including frames or saved results.
- process: Process method to be implemented by each Solution subclass.
-
- Examples:
- >>> solution = BaseSolution(model="yolo11n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
- >>> solution.initialize_region()
- >>> image = cv2.imread("image.jpg")
- >>> solution.extract_tracks(image)
- >>> solution.display_output(image)
- """
-
- def __init__(self, is_cli: bool = False, **kwargs: Any) -> None:
- """
- Initialize the BaseSolution class with configuration settings and YOLO model.
-
- Args:
- is_cli (bool): Enable CLI mode if set to True.
- **kwargs (Any): Additional configuration parameters that override defaults.
- """
- self.CFG = vars(SolutionConfig().update(**kwargs))
- self.LOGGER = LOGGER # Store logger object to be used in multiple solution classes
-
- check_requirements("shapely>=2.0.0")
- from shapely.geometry import LineString, Point, Polygon
- from shapely.prepared import prep
-
- self.LineString = LineString
- self.Polygon = Polygon
- self.Point = Point
- self.prep = prep
- self.annotator = None # Initialize annotator
- self.tracks = None
- self.track_data = None
- self.boxes = []
- self.clss = []
- self.track_ids = []
- self.track_line = None
- self.masks = None
- self.r_s = None
- self.frame_no = -1 # Only for logging
-
- self.LOGGER.info(f"Ultralytics Solutions: ✅ {self.CFG}")
- self.region = self.CFG["region"] # Store region data for other classes usage
- self.line_width = self.CFG["line_width"]
-
- # Load Model and store additional information (classes, show_conf, show_label)
- if self.CFG["model"] is None:
- self.CFG["model"] = "yolo11n.pt"
- self.model = YOLO(self.CFG["model"])
- self.names = self.model.names
- self.classes = self.CFG["classes"]
- self.show_conf = self.CFG["show_conf"]
- self.show_labels = self.CFG["show_labels"]
- self.device = self.CFG["device"]
-
- self.track_add_args = { # Tracker additional arguments for advance configuration
- k: self.CFG[k] for k in {"iou", "conf", "device", "max_det", "half", "tracker"}
- } # verbose must be passed to track method; setting it False in YOLO still logs the track information.
-
- if is_cli and self.CFG["source"] is None:
- d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4"
- self.LOGGER.warning(f"source not provided. using default source {ASSETS_URL}/{d_s}")
- from ultralytics.utils.downloads import safe_download
-
- safe_download(f"{ASSETS_URL}/{d_s}") # download source from ultralytics assets
- self.CFG["source"] = d_s # set default source
-
- # Initialize environment and region setup
- self.env_check = check_imshow(warn=True)
- self.track_history = defaultdict(list)
-
- self.profilers = (
- ops.Profile(device=self.device), # track
- ops.Profile(device=self.device), # solution
- )
-
- def adjust_box_label(self, cls: int, conf: float, track_id: int | None = None) -> str | None:
- """
- Generate a formatted label for a bounding box.
-
- This method constructs a label string for a bounding box using the class index and confidence score.
- Optionally includes the track ID if provided. The label format adapts based on the display settings
- defined in `self.show_conf` and `self.show_labels`.
-
- Args:
- cls (int): The class index of the detected object.
- conf (float): The confidence score of the detection.
- track_id (int, optional): The unique identifier for the tracked object.
-
- Returns:
- (str | None): The formatted label string if `self.show_labels` is True; otherwise, None.
- """
- name = ("" if track_id is None else f"{track_id} ") + self.names[cls]
- return (f"{name} {conf:.2f}" if self.show_conf else name) if self.show_labels else None
-
- def extract_tracks(self, im0: np.ndarray) -> None:
- """
- Apply object tracking and extract tracks from an input image or frame.
-
- Args:
- im0 (np.ndarray): The input image or frame.
-
- Examples:
- >>> solution = BaseSolution()
- >>> frame = cv2.imread("path/to/image.jpg")
- >>> solution.extract_tracks(frame)
- """
- with self.profilers[0]:
- self.tracks = self.model.track(
- source=im0, persist=True, classes=self.classes, verbose=False, **self.track_add_args
- )[0]
- is_obb = self.tracks.obb is not None
- self.track_data = self.tracks.obb if is_obb else self.tracks.boxes # Extract tracks for OBB or object detection
-
- if self.track_data and self.track_data.is_track:
- self.boxes = (self.track_data.xyxyxyxy if is_obb else self.track_data.xyxy).cpu()
- self.clss = self.track_data.cls.cpu().tolist()
- self.track_ids = self.track_data.id.int().cpu().tolist()
- self.confs = self.track_data.conf.cpu().tolist()
- else:
- self.LOGGER.warning("no tracks found!")
- self.boxes, self.clss, self.track_ids, self.confs = [], [], [], []
-
- def store_tracking_history(self, track_id: int, box) -> None:
- """
- Store the tracking history of an object.
-
- This method updates the tracking history for a given object by appending the center point of its
- bounding box to the track line. It maintains a maximum of 30 points in the tracking history.
-
- Args:
- track_id (int): The unique identifier for the tracked object.
- box (list[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].
-
- Examples:
- >>> solution = BaseSolution()
- >>> solution.store_tracking_history(1, [100, 200, 300, 400])
- """
- # Store tracking history
- self.track_line = self.track_history[track_id]
- self.track_line.append(tuple(box.mean(dim=0)) if box.numel() > 4 else (box[:4:2].mean(), box[1:4:2].mean()))
- if len(self.track_line) > 30:
- self.track_line.pop(0)
-
- def initialize_region(self) -> None:
- """Initialize the counting region and line segment based on configuration settings."""
- if self.region is None:
- self.region = [(10, 200), (540, 200), (540, 180), (10, 180)]
- self.r_s = (
- self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
- ) # region or line
-
- def display_output(self, plot_im: np.ndarray) -> None:
- """
- Display the results of the processing, which could involve showing frames, printing counts, or saving results.
-
- This method is responsible for visualizing the output of the object detection and tracking process. It displays
- the processed frame with annotations, and allows for user interaction to close the display.
-
- Args:
- plot_im (np.ndarray): The image or frame that has been processed and annotated.
-
- Examples:
- >>> solution = BaseSolution()
- >>> frame = cv2.imread("path/to/image.jpg")
- >>> solution.display_output(frame)
-
- Notes:
- - This method will only display output if the 'show' configuration is set to True and the environment
- supports image display.
- - The display can be closed by pressing the 'q' key.
- """
- if self.CFG.get("show") and self.env_check:
- cv2.imshow("Ultralytics Solutions", plot_im)
- if cv2.waitKey(1) & 0xFF == ord("q"):
- cv2.destroyAllWindows() # Closes current frame window
- return
-
- def process(self, *args: Any, **kwargs: Any):
- """Process method should be implemented by each Solution subclass."""
-
- def __call__(self, *args: Any, **kwargs: Any):
- """Allow instances to be called like a function with flexible arguments."""
- with self.profilers[1]:
- result = self.process(*args, **kwargs) # Call the subclass-specific process method
- track_or_predict = "predict" if type(self).__name__ == "ObjectCropper" else "track"
- track_or_predict_speed = self.profilers[0].dt * 1e3
- solution_speed = (self.profilers[1].dt - self.profilers[0].dt) * 1e3 # solution time = process - track
- result.speed = {track_or_predict: track_or_predict_speed, "solution": solution_speed}
- if self.CFG["verbose"]:
- self.frame_no += 1
- counts = Counter(self.clss) # Only for logging.
- LOGGER.info(
- f"{self.frame_no}: {result.plot_im.shape[0]}x{result.plot_im.shape[1]} {solution_speed:.1f}ms,"
- f" {', '.join([f'{v} {self.names[k]}' for k, v in counts.items()])}\n"
- f"Speed: {track_or_predict_speed:.1f}ms {track_or_predict}, "
- f"{solution_speed:.1f}ms solution per image at shape "
- f"(1, {getattr(self.model, 'ch', 3)}, {result.plot_im.shape[0]}, {result.plot_im.shape[1]})\n"
- )
- return result
-
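# --- Illustrative sketch (not part of the original diff) ---
# A minimal, hypothetical BaseSolution subclass showing the pattern the built-in
# solutions follow: implement process(), reuse extract_tracks()/display_output(),
# and return a SolutionResults so that __call__ can attach speed metrics and logging.
# Assumes the ultralytics.solutions.solutions module shown here is importable.
import cv2
from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults


class ObjectHighlighter(BaseSolution):  # hypothetical example, not an Ultralytics class
    def process(self, im0):
        self.extract_tracks(im0)  # populates self.boxes, self.clss, self.track_ids, self.confs
        annotator = SolutionAnnotator(im0, line_width=self.line_width)
        for box, cls, track_id, conf in zip(self.boxes, self.clss, self.track_ids, self.confs):
            annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=(0, 255, 0))
        plot_im = annotator.result()
        self.display_output(plot_im)
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))


highlighter = ObjectHighlighter(model="yolo11n.pt", show=False)
results = highlighter(cv2.imread("frame.jpg"))  # __call__ wraps process() with profiling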
-
-class SolutionAnnotator(Annotator):
- """
- A specialized annotator class for visualizing and analyzing computer vision tasks.
-
- This class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking
- trails, and visual annotations for Ultralytics Solutions. It offers comprehensive visualization capabilities for
- various computer vision applications including object detection, tracking, pose estimation, and analytics.
-
- Attributes:
- im (np.ndarray): The image being annotated.
- line_width (int): Thickness of lines used in annotations.
- font_size (int): Size of the font used for text annotations.
- font (str): Path to the font file used for text rendering.
- pil (bool): Whether to use PIL for text rendering.
- example (str): An example attribute for demonstration purposes.
-
- Methods:
- draw_region: Draw a region using specified points, colors, and thickness.
- queue_counts_display: Display queue counts in the specified region.
- display_analytics: Display overall statistics for parking lot management.
- estimate_pose_angle: Calculate the angle between three points in an object pose.
- draw_specific_kpts: Draw specific keypoints on the image.
- plot_workout_information: Draw a labeled text box on the image.
- plot_angle_and_count_and_stage: Visualize angle, step count, and stage for workout monitoring.
- plot_distance_and_line: Display the distance between centroids and connect them with a line.
- display_objects_labels: Annotate bounding boxes with object class labels.
- sweep_annotator: Visualize a vertical sweep line and optional label.
- visioneye: Map and connect object centroids to a visual "eye" point.
- adaptive_label: Draw a label with a circular or rectangular background centered within a bounding box.
-
- Examples:
- >>> annotator = SolutionAnnotator(image)
- >>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5)
- >>> annotator.display_analytics(
- ... image, text={"Available Spots": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10
- ... )
- """
-
- def __init__(
- self,
- im: np.ndarray,
- line_width: int | None = None,
- font_size: int | None = None,
- font: str = "Arial.ttf",
- pil: bool = False,
- example: str = "abc",
- ):
- """
- Initialize the SolutionAnnotator class with an image for annotation.
-
- Args:
- im (np.ndarray): The image to be annotated.
- line_width (int, optional): Line thickness for drawing on the image.
- font_size (int, optional): Font size for text annotations.
- font (str): Path to the font file.
- pil (bool): Indicates whether to use PIL for rendering text.
- example (str): An example parameter for demonstration purposes.
- """
- super().__init__(im, line_width, font_size, font, pil, example)
-
- def draw_region(
- self,
- reg_pts: list[tuple[int, int]] | None = None,
- color: tuple[int, int, int] = (0, 255, 0),
- thickness: int = 5,
- ):
- """
- Draw a region or line on the image.
-
- Args:
- reg_pts (list[tuple[int, int]], optional): Region points (2 points for a line, 4+ points for a region).
- color (tuple[int, int, int]): RGB color value for the region.
- thickness (int): Line thickness for drawing the region.
- """
- cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
-
- # Draw small circles at the corner points
- for point in reg_pts:
- cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle
-
- def queue_counts_display(
- self,
- label: str,
- points: list[tuple[int, int]] | None = None,
- region_color: tuple[int, int, int] = (255, 255, 255),
- txt_color: tuple[int, int, int] = (0, 0, 0),
- ):
- """
- Display queue counts on an image centered at the points with customizable font size and colors.
-
- Args:
- label (str): Queue counts label.
- points (list[tuple[int, int]], optional): Region points for center point calculation to display text.
- region_color (tuple[int, int, int]): RGB queue region color.
- txt_color (tuple[int, int, int]): RGB text display color.
- """
- x_values = [point[0] for point in points]
- y_values = [point[1] for point in points]
- center_x = sum(x_values) // len(points)
- center_y = sum(y_values) // len(points)
-
- text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
- text_width = text_size[0]
- text_height = text_size[1]
-
- rect_width = text_width + 20
- rect_height = text_height + 20
- rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
- rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
- cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
-
- text_x = center_x - text_width // 2
- text_y = center_y + text_height // 2
-
- # Draw text
- cv2.putText(
- self.im,
- label,
- (text_x, text_y),
- 0,
- fontScale=self.sf,
- color=txt_color,
- thickness=self.tf,
- lineType=cv2.LINE_AA,
- )
-
- def display_analytics(
- self,
- im0: np.ndarray,
- text: dict[str, Any],
- txt_color: tuple[int, int, int],
- bg_color: tuple[int, int, int],
- margin: int,
- ):
- """
- Display overall statistics for solutions such as parking management and object counting.
-
- Args:
- im0 (np.ndarray): Inference image.
- text (dict[str, Any]): Labels dictionary.
- txt_color (tuple[int, int, int]): Display color for text foreground.
- bg_color (tuple[int, int, int]): Display color for text background.
- margin (int): Gap between text and rectangle for better display.
- """
- horizontal_gap = int(im0.shape[1] * 0.02)
- vertical_gap = int(im0.shape[0] * 0.01)
- text_y_offset = 0
- for label, value in text.items():
- txt = f"{label}: {value}"
- text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
- if text_size[0] < 5 or text_size[1] < 5:
- text_size = (5, 5)
- text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
- text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
- rect_x1 = text_x - margin * 2
- rect_y1 = text_y - text_size[1] - margin * 2
- rect_x2 = text_x + text_size[0] + margin * 2
- rect_y2 = text_y + margin * 2
- cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
- cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
- text_y_offset = rect_y2
-
- @staticmethod
- @lru_cache(maxsize=256)
- def estimate_pose_angle(a: list[float], b: list[float], c: list[float]) -> float:
- """
- Calculate the angle between three points for workout monitoring.
-
- Args:
- a (list[float]): The coordinates of the first point.
- b (list[float]): The coordinates of the second point (vertex).
- c (list[float]): The coordinates of the third point.
-
- Returns:
- (float): The angle in degrees between the three points.
- """
- radians = math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])
- angle = abs(radians * 180.0 / math.pi)
- return angle if angle <= 180.0 else (360 - angle)
-
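# --- Illustrative note (not part of the original diff) ---
# Worked example for estimate_pose_angle: with a=(1, 0), b=(0, 0), c=(0, 1) it computes
# atan2(1, 0) - atan2(0, 1) = pi/2 - 0, i.e. a 90.0 degree angle at vertex b.
# Arguments must be hashable (e.g. tuples) because of the lru_cache decorator.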
- def draw_specific_kpts(
- self,
- keypoints: list[list[float]],
- indices: list[int] | None = None,
- radius: int = 2,
- conf_thresh: float = 0.25,
- ) -> np.ndarray:
- """
- Draw specific keypoints for gym steps counting.
-
- Args:
- keypoints (list[list[float]]): Keypoints data to be plotted, each in format [x, y, confidence].
- indices (list[int], optional): Keypoint indices to be plotted.
- radius (int): Keypoint radius.
- conf_thresh (float): Confidence threshold for keypoints.
-
- Returns:
- (np.ndarray): Image with drawn keypoints.
-
- Notes:
- Keypoint format: [x, y] or [x, y, confidence].
- Modifies self.im in-place.
- """
- indices = indices or [2, 5, 7]
- points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thresh]
-
- # Draw lines between consecutive points
- for start, end in zip(points[:-1], points[1:]):
- cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)
-
- # Draw circles for keypoints
- for pt in points:
- cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)
-
- return self.im
-
- def plot_workout_information(
- self,
- display_text: str,
- position: tuple[int, int],
- color: tuple[int, int, int] = (104, 31, 17),
- txt_color: tuple[int, int, int] = (255, 255, 255),
- ) -> int:
- """
- Draw workout text with a background on the image.
-
- Args:
- display_text (str): The text to be displayed.
- position (tuple[int, int]): Coordinates (x, y) on the image where the text will be placed.
- color (tuple[int, int, int]): Text background color.
- txt_color (tuple[int, int, int]): Text foreground color.
-
- Returns:
- (int): The height of the text.
- """
- (text_width, text_height), _ = cv2.getTextSize(display_text, 0, fontScale=self.sf, thickness=self.tf)
-
- # Draw background rectangle
- cv2.rectangle(
- self.im,
- (position[0], position[1] - text_height - 5),
- (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),
- color,
- -1,
- )
- # Draw text
- cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)
-
- return text_height
-
- def plot_angle_and_count_and_stage(
- self,
- angle_text: str,
- count_text: str,
- stage_text: str,
- center_kpt: list[int],
- color: tuple[int, int, int] = (104, 31, 17),
- txt_color: tuple[int, int, int] = (255, 255, 255),
- ):
- """
- Plot the pose angle, count value, and step stage for workout monitoring.
-
- Args:
- angle_text (str): Angle value for workout monitoring.
- count_text (str): Counts value for workout monitoring.
- stage_text (str): Stage decision for workout monitoring.
- center_kpt (list[int]): Center keypoint (x, y) coordinates used to anchor the text.
- color (tuple[int, int, int]): Text background color.
- txt_color (tuple[int, int, int]): Text foreground color.
- """
- # Format text
- angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}"
-
- # Draw angle, count and stage text
- angle_height = self.plot_workout_information(
- angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color
- )
- count_height = self.plot_workout_information(
- count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color
- )
- self.plot_workout_information(
- stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color
- )
-
- def plot_distance_and_line(
- self,
- pixels_distance: float,
- centroids: list[tuple[int, int]],
- line_color: tuple[int, int, int] = (104, 31, 17),
- centroid_color: tuple[int, int, int] = (255, 0, 255),
- ):
- """
- Plot the distance and line between two centroids on the frame.
-
- Args:
- pixels_distance (float): Distance in pixels between the two bounding box centroids.
- centroids (list[tuple[int, int]]): Bounding box centroids data.
- line_color (tuple[int, int, int]): Distance line color.
- centroid_color (tuple[int, int, int]): Bounding box centroid color.
- """
- # Get the text size
- text = f"Pixels Distance: {pixels_distance:.2f}"
- (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)
-
- # Define corners with 10-pixel margin and draw rectangle
- cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)
-
- # Calculate the position for the text with a 10-pixel margin and draw text
- text_position = (25, 25 + text_height_m + 10)
- cv2.putText(
- self.im,
- text,
- text_position,
- 0,
- self.sf,
- (255, 255, 255),
- self.tf,
- cv2.LINE_AA,
- )
-
- cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
- cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
- cv2.circle(self.im, centroids[1], 6, centroid_color, -1)
-
- def display_objects_labels(
- self,
- im0: np.ndarray,
- text: str,
- txt_color: tuple[int, int, int],
- bg_color: tuple[int, int, int],
- x_center: float,
- y_center: float,
- margin: int,
- ):
- """
- Display bounding box labels in the parking management app.
-
- Args:
- im0 (np.ndarray): Inference image.
- text (str): Object/class name.
- txt_color (tuple[int, int, int]): Display color for text foreground.
- bg_color (tuple[int, int, int]): Display color for text background.
- x_center (float): The x position center point for bounding box.
- y_center (float): The y position center point for bounding box.
- margin (int): The gap between text and rectangle for better display.
- """
- text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
- text_x = x_center - text_size[0] // 2
- text_y = y_center + text_size[1] // 2
-
- rect_x1 = text_x - margin
- rect_y1 = text_y - text_size[1] - margin
- rect_x2 = text_x + text_size[0] + margin
- rect_y2 = text_y + margin
- cv2.rectangle(
- im0,
- (int(rect_x1), int(rect_y1)),
- (int(rect_x2), int(rect_y2)),
- tuple(map(int, bg_color)), # Ensure color values are int
- -1,
- )
-
- cv2.putText(
- im0,
- text,
- (int(text_x), int(text_y)),
- 0,
- self.sf,
- tuple(map(int, txt_color)), # Ensure color values are int
- self.tf,
- lineType=cv2.LINE_AA,
- )
-
- def sweep_annotator(
- self,
- line_x: int = 0,
- line_y: int = 0,
- label: str | None = None,
- color: tuple[int, int, int] = (221, 0, 186),
- txt_color: tuple[int, int, int] = (255, 255, 255),
- ):
- """
- Draw a sweep annotation line and an optional label.
-
- Args:
- line_x (int): The x-coordinate of the sweep line.
- line_y (int): The y-coordinate limit of the sweep line.
- label (str, optional): Text label to be drawn in center of sweep line. If None, no label is drawn.
- color (tuple[int, int, int]): RGB color for the line and label background.
- txt_color (tuple[int, int, int]): RGB color for the label text.
- """
- # Draw the sweep line
- cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)
-
- # Draw label, if provided
- if label:
- (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)
- cv2.rectangle(
- self.im,
- (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),
- (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),
- color,
- -1,
- )
- cv2.putText(
- self.im,
- label,
- (line_x - text_width // 2, line_y // 2 + text_height // 2),
- cv2.FONT_HERSHEY_SIMPLEX,
- self.sf,
- txt_color,
- self.tf,
- )
-
- def visioneye(
- self,
- box: list[float],
- center_point: tuple[int, int],
- color: tuple[int, int, int] = (235, 219, 11),
- pin_color: tuple[int, int, int] = (255, 0, 255),
- ):
- """
- Map and connect an object's bounding-box centroid to a fixed vision-eye point.
-
- Args:
- box (list[float]): Bounding box coordinates in format [x1, y1, x2, y2].
- center_point (tuple[int, int]): Center point for vision eye view.
- color (tuple[int, int, int]): Object centroid and line color.
- pin_color (tuple[int, int, int]): Visioneye point color.
- """
- center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
- cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
- cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
- cv2.line(self.im, center_point, center_bbox, color, self.tf)
-
- def adaptive_label(
- self,
- box: tuple[float, float, float, float],
- label: str = "",
- color: tuple[int, int, int] = (128, 128, 128),
- txt_color: tuple[int, int, int] = (255, 255, 255),
- shape: str = "rect",
- margin: int = 5,
- ):
- """
- Draw a label with a background rectangle or circle centered within a given bounding box.
-
- Args:
- box (tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).
- label (str): The text label to be displayed.
- color (tuple[int, int, int]): The background color of the rectangle (B, G, R).
- txt_color (tuple[int, int, int]): The color of the text (R, G, B).
-            shape (str): The shape of the label, i.e. "circle" or "rect".
- margin (int): The margin between the text and the rectangle border.
- """
- if shape == "circle" and len(label) > 3:
- LOGGER.warning(f"Length of label is {len(label)}, only first 3 letters will be used for circle annotation.")
- label = label[:3]
-
- x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2) # Calculate center of the bbox
- text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0] # Get size of the text
-        text_x, text_y = x_center - text_size[0] // 2, y_center + text_size[1] // 2  # Text origin (baseline-left) for cv2.putText
-
- if shape == "circle":
- cv2.circle(
- self.im,
- (x_center, y_center),
- int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin, # Calculate the radius
- color,
- -1,
- )
- else:
- cv2.rectangle(
- self.im,
- (text_x - margin, text_y - text_size[1] - margin), # Calculate coordinates of the rectangle
- (text_x + text_size[0] + margin, text_y + margin), # Calculate coordinates of the rectangle
- color,
- -1,
- )
-
- # Draw the text on top of the rectangle
- cv2.putText(
- self.im,
- label,
- (text_x, text_y), # Calculate top-left corner of the text
- cv2.FONT_HERSHEY_SIMPLEX,
- self.sf - 0.15,
- self.get_txt_color(color, txt_color),
- self.tf,
- lineType=cv2.LINE_AA,
- )
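For the `circle` shape above, the background radius is half the text's diagonal plus the margin, so short labels always fit. A quick standalone check of that rule (illustrative values):

```python
import cv2

label, font_scale, thickness, margin = "car", 0.5, 1, 5
(w, h), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
radius = int(((w**2 + h**2) ** 0.5) / 2) + margin  # half the text diagonal plus the margin
print(w, h, radius)
```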
-
-
-class SolutionResults:
- """
- A class to encapsulate the results of Ultralytics Solutions.
-
- This class is designed to store and manage various outputs generated by the solution pipeline, including counts,
- angles, workout stages, and other analytics data. It provides a structured way to access and manipulate results
- from different computer vision solutions such as object counting, pose estimation, and tracking analytics.
-
- Attributes:
-        plot_im (np.ndarray): Processed image with counts, blurring, or other solution-specific effects applied.
- in_count (int): The total number of "in" counts in a video stream.
- out_count (int): The total number of "out" counts in a video stream.
- classwise_count (dict[str, int]): A dictionary containing counts of objects categorized by class.
- queue_count (int): The count of objects in a queue or waiting area.
- workout_count (int): The count of workout repetitions.
- workout_angle (float): The angle calculated during a workout exercise.
- workout_stage (str): The current stage of the workout.
- pixels_distance (float): The calculated distance in pixels between two points or objects.
- available_slots (int): The number of available slots in a monitored area.
- filled_slots (int): The number of filled slots in a monitored area.
- email_sent (bool): A flag indicating whether an email notification was sent.
- total_tracks (int): The total number of tracked objects.
- region_counts (dict[str, int]): The count of objects within a specific region.
- speed_dict (dict[str, float]): A dictionary containing speed information for tracked objects.
-        total_crop_objects (int): Total number of objects cropped by the ObjectCropper class.
- speed (dict[str, float]): Performance timing information for tracking and solution processing.
- """
-
- def __init__(self, **kwargs):
- """
- Initialize a SolutionResults object with default or user-specified values.
-
- Args:
- **kwargs (Any): Optional arguments to override default attribute values.
- """
- self.plot_im = None
- self.in_count = 0
- self.out_count = 0
- self.classwise_count = {}
- self.queue_count = 0
- self.workout_count = 0
- self.workout_angle = 0.0
- self.workout_stage = None
- self.pixels_distance = 0.0
- self.available_slots = 0
- self.filled_slots = 0
- self.email_sent = False
- self.total_tracks = 0
- self.region_counts = {}
- self.speed_dict = {} # for speed estimation
- self.total_crop_objects = 0
- self.speed = {}
-
- # Override with user-defined values
- self.__dict__.update(kwargs)
-
- def __str__(self) -> str:
- """
- Return a formatted string representation of the SolutionResults object.
-
- Returns:
- (str): A string representation listing non-null attributes.
- """
- attrs = {
- k: v
- for k, v in self.__dict__.items()
- if k != "plot_im" and v not in [None, {}, 0, 0.0, False] # Exclude `plot_im` explicitly
- }
- return ", ".join(f"{k}={v}" for k, v in attrs.items())
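For reference, a minimal usage sketch of the `SolutionResults` container removed above: it simply stores whatever keyword arguments a solution passes and prints only the non-empty fields (import path as it existed before this removal).

```python
from ultralytics.solutions.solutions import SolutionResults

# Minimal sketch: the container keeps any keyword arguments and hides empty fields in __str__.
results = SolutionResults(in_count=4, out_count=2, classwise_count={"car": 6})
print(results)           # in_count=4, out_count=2, classwise_count={'car': 6}
print(results.in_count)  # 4
```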
diff --git a/ultralytics/solutions/speed_estimation.py b/ultralytics/solutions/speed_estimation.py
deleted file mode 100644
index 0da4223..0000000
--- a/ultralytics/solutions/speed_estimation.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from collections import deque
-from math import sqrt
-from typing import Any
-
-from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
-from ultralytics.utils.plotting import colors
-
-
-class SpeedEstimator(BaseSolution):
- """
- A class to estimate the speed of objects in a real-time video stream based on their tracks.
-
- This class extends the BaseSolution class and provides functionality for estimating object speeds using
- tracking data in video streams. Speed is calculated based on pixel displacement over time and converted
- to real-world units using a configurable meters-per-pixel scale factor.
-
- Attributes:
- fps (float): Video frame rate for time calculations.
- frame_count (int): Global frame counter for tracking temporal information.
- trk_frame_ids (dict): Maps track IDs to their first frame index.
- spd (dict): Final speed per object in km/h once locked.
- trk_hist (dict): Maps track IDs to deque of position history.
- locked_ids (set): Track IDs whose speed has been finalized.
- max_hist (int): Required frame history before computing speed.
- meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
- max_speed (int): Maximum allowed object speed; values above this will be capped.
-
- Methods:
- process: Process input frames to estimate object speeds based on tracking data.
- store_tracking_history: Store the tracking history for an object.
- extract_tracks: Extract tracks from the current frame.
- display_output: Display the output with annotations.
-
- Examples:
- Initialize speed estimator and process a frame
- >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
- >>> frame = cv2.imread("frame.jpg")
- >>> results = estimator.process(frame)
- >>> cv2.imshow("Speed Estimation", results.plot_im)
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """
- Initialize the SpeedEstimator object with speed estimation parameters and data structures.
-
- Args:
- **kwargs (Any): Additional keyword arguments passed to the parent class.
- """
- super().__init__(**kwargs)
-
- self.fps = self.CFG["fps"] # Video frame rate for time calculations
- self.frame_count = 0 # Global frame counter
- self.trk_frame_ids = {} # Track ID → first frame index
- self.spd = {} # Final speed per object (km/h), once locked
- self.trk_hist = {} # Track ID → deque of (time, position)
- self.locked_ids = set() # Track IDs whose speed has been finalized
- self.max_hist = self.CFG["max_hist"] # Required frame history before computing speed
- self.meter_per_pixel = self.CFG["meter_per_pixel"] # Scene scale, depends on camera details
- self.max_speed = self.CFG["max_speed"] # Maximum speed adjustment
-
- def process(self, im0) -> SolutionResults:
- """
- Process an input frame to estimate object speeds based on tracking data.
-
- Args:
- im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.
-
- Returns:
- (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).
-
- Examples:
- Process a frame for speed estimation
- >>> estimator = SpeedEstimator()
- >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
- >>> results = estimator.process(image)
- """
- self.frame_count += 1
- self.extract_tracks(im0)
- annotator = SolutionAnnotator(im0, line_width=self.line_width)
-
- for box, track_id, _, _ in zip(self.boxes, self.track_ids, self.clss, self.confs):
- self.store_tracking_history(track_id, box)
-
- if track_id not in self.trk_hist: # Initialize history if new track found
- self.trk_hist[track_id] = deque(maxlen=self.max_hist)
- self.trk_frame_ids[track_id] = self.frame_count
-
- if track_id not in self.locked_ids: # Update history until speed is locked
- trk_hist = self.trk_hist[track_id]
- trk_hist.append(self.track_line[-1])
-
- # Compute and lock speed once enough history is collected
- if len(trk_hist) == self.max_hist:
- p0, p1 = trk_hist[0], trk_hist[-1] # First and last points of track
- dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps # Time in seconds
- if dt > 0:
- dx, dy = p1[0] - p0[0], p1[1] - p0[1] # Pixel displacement
- pixel_distance = sqrt(dx * dx + dy * dy) # Calculate pixel distance
- meters = pixel_distance * self.meter_per_pixel # Convert to meters
- self.spd[track_id] = int(
- min((meters / dt) * 3.6, self.max_speed)
- ) # Convert to km/h and store final speed
- self.locked_ids.add(track_id) # Prevent further updates
- self.trk_hist.pop(track_id, None) # Free memory
- self.trk_frame_ids.pop(track_id, None) # Remove frame start reference
-
- if track_id in self.spd:
- speed_label = f"{self.spd[track_id]} km/h"
- annotator.box_label(box, label=speed_label, color=colors(track_id, True)) # Draw bounding box
-
- plot_im = annotator.result()
- self.display_output(plot_im) # Display output with base class function
-
- # Return results with processed image and tracking summary
- return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
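The speed logic removed above locks a value once `max_hist` positions are collected: pixel displacement over the elapsed frames is scaled by `meter_per_pixel` and converted to km/h, capped at `max_speed`. A small standalone sketch with assumed numbers:

```python
from math import sqrt

# Illustrative values only; they mirror the SpeedEstimator.process computation above.
fps, meter_per_pixel, max_speed = 30.0, 0.04, 120
p0, p1 = (100, 400), (220, 355)  # first and last tracked center points (pixels)
frames_elapsed = 15              # frame_count - first frame index for this track

dt = frames_elapsed / fps                                   # elapsed time in seconds
pixels = sqrt((p1[0] - p0[0]) ** 2 + (p1[1] - p0[1]) ** 2)  # pixel displacement
meters = pixels * meter_per_pixel                           # convert to meters
speed_kmh = min((meters / dt) * 3.6, max_speed)             # m/s -> km/h, capped
print(f"{speed_kmh:.1f} km/h")  # ~36.9 km/h for these numbers
```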
diff --git a/ultralytics/solutions/streamlit_inference.py b/ultralytics/solutions/streamlit_inference.py
deleted file mode 100644
index 44e2029..0000000
--- a/ultralytics/solutions/streamlit_inference.py
+++ /dev/null
@@ -1,262 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import io
-import os
-from typing import Any
-
-import cv2
-import torch
-
-from ultralytics import YOLO
-from ultralytics.utils import LOGGER
-from ultralytics.utils.checks import check_requirements
-from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS
-
-torch.classes.__path__ = [] # Torch module __path__._path issue: https://github.com/datalab-to/marker/issues/442
-
-
-class Inference:
- """
-    A class to perform object detection, image classification, image segmentation, and pose estimation inference.
-
- This class provides functionalities for loading models, configuring settings, uploading video files, and performing
- real-time inference using Streamlit and Ultralytics YOLO models.
-
- Attributes:
- st (module): Streamlit module for UI creation.
- temp_dict (dict): Temporary dictionary to store the model path and other configuration.
- model_path (str): Path to the loaded model.
- model (YOLO): The YOLO model instance.
- source (str): Selected video source (webcam or video file).
- enable_trk (bool): Enable tracking option.
- conf (float): Confidence threshold for detection.
- iou (float): IoU threshold for non-maximum suppression.
- org_frame (Any): Container for the original frame to be displayed.
- ann_frame (Any): Container for the annotated frame to be displayed.
- vid_file_name (str | int): Name of the uploaded video file or webcam index.
- selected_ind (list[int]): List of selected class indices for detection.
-
- Methods:
- web_ui: Set up the Streamlit web interface with custom HTML elements.
- sidebar: Configure the Streamlit sidebar for model and inference settings.
- source_upload: Handle video file uploads through the Streamlit interface.
- configure: Configure the model and load selected classes for inference.
- inference: Perform real-time object detection inference.
-
- Examples:
- Create an Inference instance with a custom model
- >>> inf = Inference(model="path/to/model.pt")
- >>> inf.inference()
-
- Create an Inference instance with default settings
- >>> inf = Inference()
- >>> inf.inference()
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """
- Initialize the Inference class, checking Streamlit requirements and setting up the model path.
-
- Args:
- **kwargs (Any): Additional keyword arguments for model configuration.
- """
- check_requirements("streamlit>=1.29.0") # scope imports for faster ultralytics package load speeds
- import streamlit as st
-
- self.st = st # Reference to the Streamlit module
- self.source = None # Video source selection (webcam or video file)
- self.img_file_names = [] # List of image file names
- self.enable_trk = False # Flag to toggle object tracking
- self.conf = 0.25 # Confidence threshold for detection
- self.iou = 0.45 # Intersection-over-Union (IoU) threshold for non-maximum suppression
- self.org_frame = None # Container for the original frame display
- self.ann_frame = None # Container for the annotated frame display
- self.vid_file_name = None # Video file name or webcam index
- self.selected_ind: list[int] = [] # List of selected class indices for detection
- self.model = None # YOLO model instance
-
- self.temp_dict = {"model": None, **kwargs}
- self.model_path = None # Model file path
- if self.temp_dict["model"] is not None:
- self.model_path = self.temp_dict["model"]
-
- LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")
-
- def web_ui(self) -> None:
- """Set up the Streamlit web interface with custom HTML elements."""
- menu_style_cfg = """""" # Hide main menu style
-
- # Main title of streamlit application
-        main_title_cfg = """Ultralytics YOLO Streamlit Application"""
-
- # Subtitle of streamlit application
-        sub_title_cfg = """Experience real-time object detection on your webcam, videos, and images
-        with the power of Ultralytics YOLO! 🚀"""
-
- # Set html page configuration and append custom HTML
- self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
- self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
- self.st.markdown(main_title_cfg, unsafe_allow_html=True)
- self.st.markdown(sub_title_cfg, unsafe_allow_html=True)
-
- def sidebar(self) -> None:
- """Configure the Streamlit sidebar for model and inference settings."""
- with self.st.sidebar: # Add Ultralytics LOGO
- logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
- self.st.image(logo, width=250)
-
- self.st.sidebar.title("User Configuration") # Add elements to vertical setting menu
- self.source = self.st.sidebar.selectbox(
- "Source",
- ("webcam", "video", "image"),
- ) # Add source selection dropdown
- if self.source in ["webcam", "video"]:
- self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes" # Enable object tracking
- self.conf = float(
- self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
- ) # Slider for confidence
- self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01)) # Slider for NMS threshold
-
- if self.source != "image": # Only create columns for video/webcam
- col1, col2 = self.st.columns(2) # Create two columns for displaying frames
- self.org_frame = col1.empty() # Container for original frame
- self.ann_frame = col2.empty() # Container for annotated frame
-
- def source_upload(self) -> None:
- """Handle video file uploads through the Streamlit interface."""
- from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS # scope import
-
- self.vid_file_name = ""
- if self.source == "video":
- vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
- if vid_file is not None:
- g = io.BytesIO(vid_file.read()) # BytesIO Object
- with open("ultralytics.mp4", "wb") as out: # Open temporary file as bytes
- out.write(g.read()) # Read bytes into file
- self.vid_file_name = "ultralytics.mp4"
- elif self.source == "webcam":
- self.vid_file_name = 0 # Use webcam index 0
- elif self.source == "image":
- import tempfile # scope import
-
- if imgfiles := self.st.sidebar.file_uploader(
- "Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True
- ):
- for imgfile in imgfiles: # Save each uploaded image to a temporary file
- with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf:
- tf.write(imgfile.read())
- self.img_file_names.append({"path": tf.name, "name": imgfile.name})
-
- def configure(self) -> None:
- """Configure the model and load selected classes for inference."""
- # Add dropdown menu for model selection
- M_ORD, T_ORD = ["yolo11n", "yolo11s", "yolo11m", "yolo11l", "yolo11x"], ["", "-seg", "-pose", "-obb", "-cls"]
- available_models = sorted(
- [
- x.replace("yolo", "YOLO")
- for x in GITHUB_ASSETS_STEMS
- if any(x.startswith(b) for b in M_ORD) and "grayscale" not in x
- ],
- key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or "")),
- )
- if self.model_path: # Insert user provided custom model in available_models
- available_models.insert(0, self.model_path)
- selected_model = self.st.sidebar.selectbox("Model", available_models)
-
- with self.st.spinner("Model is downloading..."):
- if selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine")) or any(
- fmt in selected_model for fmt in ("openvino_model", "rknn_model")
- ):
- model_path = selected_model
- else:
- model_path = f"{selected_model.lower()}.pt" # Default to .pt if no model provided during function call.
- self.model = YOLO(model_path) # Load the YOLO model
- class_names = list(self.model.names.values()) # Convert dictionary to list of class names
- self.st.success("Model loaded successfully!")
-
- # Multiselect box with class names and get indices of selected classes
- selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
- self.selected_ind = [class_names.index(option) for option in selected_classes]
-
-        if not isinstance(self.selected_ind, list): # Ensure selected_ind is a list
- self.selected_ind = list(self.selected_ind)
-
- def image_inference(self) -> None:
- """Perform inference on uploaded images."""
- for img_info in self.img_file_names:
- img_path = img_info["path"]
- image = cv2.imread(img_path) # Load and display the original image
- if image is not None:
- self.st.markdown(f"#### Processed: {img_info['name']}")
- col1, col2 = self.st.columns(2)
- with col1:
- self.st.image(image, channels="BGR", caption="Original Image")
- results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
- annotated_image = results[0].plot()
- with col2:
- self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
- try: # Clean up temporary file
- os.unlink(img_path)
- except FileNotFoundError:
- pass # File doesn't exist, ignore
- else:
- self.st.error("Could not load the uploaded image.")
-
- def inference(self) -> None:
- """Perform real-time object detection inference on video or webcam feed."""
- self.web_ui() # Initialize the web interface
- self.sidebar() # Create the sidebar
- self.source_upload() # Upload the video source
- self.configure() # Configure the app
-
- if self.st.sidebar.button("Start"):
- if self.source == "image":
- if self.img_file_names:
- self.image_inference()
- else:
- self.st.info("Please upload an image file to perform inference.")
- return
-
- stop_button = self.st.sidebar.button("Stop") # Button to stop the inference
- cap = cv2.VideoCapture(self.vid_file_name) # Capture the video
- if not cap.isOpened():
- self.st.error("Could not open webcam or video source.")
- return
-
- while cap.isOpened():
- success, frame = cap.read()
- if not success:
- self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
- break
-
- # Process frame with model
- if self.enable_trk:
- results = self.model.track(
- frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
- )
- else:
- results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)
-
- annotated_frame = results[0].plot() # Add annotations on frame
-
- if stop_button:
- cap.release() # Release the capture
- self.st.stop() # Stop streamlit app
-
- self.org_frame.image(frame, channels="BGR", caption="Original Frame") # Display original frame
-                self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame") # Display processed frame
-
- cap.release() # Release the capture
- cv2.destroyAllWindows() # Destroy all OpenCV windows
-
-
-if __name__ == "__main__":
- import sys # Import the sys module for accessing command-line arguments
-
- # Check if a model name is provided as a command-line argument
- args = len(sys.argv)
- model = sys.argv[1] if args > 1 else None # Assign first argument as the model name if provided
- # Create an instance of the Inference class and run inference
- Inference(model=model).inference()
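The deleted module's `__main__` block shows the intended entry point: construct `Inference` and call `inference()`. A hedged sketch of that usage; it assumes Streamlit is installed, the UI only renders when served by Streamlit's runner, and the exact launch command depends on your setup.

```python
# Sketch only: typically served via Streamlit's runner, e.g. `streamlit run this_script.py -- yolo11n.pt`
from ultralytics.solutions.streamlit_inference import Inference  # path as it existed before this removal

Inference(model="yolo11n.pt").inference()  # builds the sidebar, handles source upload, and starts inference
```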
diff --git a/ultralytics/solutions/templates/similarity-search.html b/ultralytics/solutions/templates/similarity-search.html
deleted file mode 100644
index 6a24179..0000000
--- a/ultralytics/solutions/templates/similarity-search.html
+++ /dev/null
@@ -1,167 +0,0 @@
-<!-- Semantic Image Search page template (condensed): title, heading, and results loop -->
-    Semantic Image Search
-    Semantic Image Search with AI
-                    {% for img in results %}
-                    {% endfor %}
diff --git a/ultralytics/solutions/trackzone.py b/ultralytics/solutions/trackzone.py
deleted file mode 100644
index 5505317..0000000
--- a/ultralytics/solutions/trackzone.py
+++ /dev/null
@@ -1,91 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from typing import Any
-
-import cv2
-import numpy as np
-
-from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
-from ultralytics.utils.plotting import colors
-
-
-class TrackZone(BaseSolution):
- """
- A class to manage region-based object tracking in a video stream.
-
- This class extends the BaseSolution class and provides functionality for tracking objects within a specific region
- defined by a polygonal area. Objects outside the region are excluded from tracking.
-
- Attributes:
- region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.
- line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.
- names (list[str]): List of class names that the model can detect.
- boxes (list[np.ndarray]): Bounding boxes of tracked objects.
- track_ids (list[int]): Unique identifiers for each tracked object.
- clss (list[int]): Class indices of tracked objects.
-
- Methods:
- process: Process each frame of the video, applying region-based tracking.
- extract_tracks: Extract tracking information from the input frame.
- display_output: Display the processed output.
-
- Examples:
- >>> tracker = TrackZone()
- >>> frame = cv2.imread("frame.jpg")
- >>> results = tracker.process(frame)
- >>> cv2.imshow("Tracked Frame", results.plot_im)
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """
- Initialize the TrackZone class for tracking objects within a defined region in video streams.
-
- Args:
- **kwargs (Any): Additional keyword arguments passed to the parent class.
- """
- super().__init__(**kwargs)
- default_region = [(75, 75), (565, 75), (565, 285), (75, 285)]
- self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32))
- self.mask = None
-
- def process(self, im0: np.ndarray) -> SolutionResults:
- """
- Process the input frame to track objects within a defined region.
-
- This method initializes the annotator, creates a mask for the specified region, extracts tracks
- only from the masked area, and updates tracking information. Objects outside the region are ignored.
-
- Args:
- im0 (np.ndarray): The input image or frame to be processed.
-
- Returns:
- (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the
- total number of tracked objects within the defined region.
-
- Examples:
- >>> tracker = TrackZone()
- >>> frame = cv2.imread("path/to/image.jpg")
- >>> results = tracker.process(frame)
- """
- annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator
-
- if self.mask is None: # Create a mask for the region
- self.mask = np.zeros_like(im0[:, :, 0])
- cv2.fillPoly(self.mask, [self.region], 255)
- masked_frame = cv2.bitwise_and(im0, im0, mask=self.mask)
- self.extract_tracks(masked_frame)
-
- # Draw the region boundary
- cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2)
-
- # Iterate over boxes, track ids, classes indexes list and draw bounding boxes
- for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):
- annotator.box_label(
- box, label=self.adjust_box_label(cls, conf, track_id=track_id), color=colors(track_id, True)
- )
-
- plot_im = annotator.result()
- self.display_output(plot_im) # Display output with base class function
-
- # Return a SolutionResults
- return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
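A hedged usage sketch for the removed `TrackZone` solution: pass a polygon via the standard `region` argument so only detections inside it are tracked (the keyword arguments shown are the usual solution options and may differ in your version).

```python
import cv2

from ultralytics import solutions

# Illustrative region; adjust the points to your scene.
zone = [(150, 100), (1100, 100), (1100, 600), (150, 600)]
trackzone = solutions.TrackZone(model="yolo11n.pt", region=zone, show=False)

frame = cv2.imread("frame.jpg")  # any BGR frame
results = trackzone.process(frame)
print(results.total_tracks)
```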
diff --git a/ultralytics/solutions/vision_eye.py b/ultralytics/solutions/vision_eye.py
deleted file mode 100644
index 7732345..0000000
--- a/ultralytics/solutions/vision_eye.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from typing import Any
-
-from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
-from ultralytics.utils.plotting import colors
-
-
-class VisionEye(BaseSolution):
- """
- A class to manage object detection and vision mapping in images or video streams.
-
- This class extends the BaseSolution class and provides functionality for detecting objects,
- mapping vision points, and annotating results with bounding boxes and labels.
-
- Attributes:
- vision_point (tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.
-
- Methods:
- process: Process the input image to detect objects, annotate them, and apply vision mapping.
-
- Examples:
- >>> vision_eye = VisionEye()
- >>> frame = cv2.imread("frame.jpg")
- >>> results = vision_eye.process(frame)
- >>> print(f"Total detected instances: {results.total_tracks}")
- """
-
- def __init__(self, **kwargs: Any) -> None:
- """
- Initialize the VisionEye class for detecting objects and applying vision mapping.
-
- Args:
- **kwargs (Any): Keyword arguments passed to the parent class and for configuring vision_point.
- """
- super().__init__(**kwargs)
- # Set the vision point where the system will view objects and draw tracks
- self.vision_point = self.CFG["vision_point"]
-
- def process(self, im0) -> SolutionResults:
- """
- Perform object detection, vision mapping, and annotation on the input image.
-
- Args:
- im0 (np.ndarray): The input image for detection and annotation.
-
- Returns:
- (SolutionResults): Object containing the annotated image and tracking statistics.
- - plot_im: Annotated output image with bounding boxes and vision mapping
- - total_tracks: Number of tracked objects in the frame
-
- Examples:
- >>> vision_eye = VisionEye()
- >>> frame = cv2.imread("image.jpg")
- >>> results = vision_eye.process(frame)
- >>> print(f"Detected {results.total_tracks} objects")
- """
- self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)
- annotator = SolutionAnnotator(im0, self.line_width)
-
- for cls, t_id, box, conf in zip(self.clss, self.track_ids, self.boxes, self.confs):
- # Annotate the image with bounding boxes, labels, and vision mapping
- annotator.box_label(box, label=self.adjust_box_label(cls, conf, t_id), color=colors(int(t_id), True))
- annotator.visioneye(box, self.vision_point)
-
- plot_im = annotator.result()
- self.display_output(plot_im) # Display the annotated output using the base class function
-
- # Return a SolutionResults object with the annotated image and tracking statistics
- return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
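Similarly, a hedged sketch for the removed `VisionEye` solution, assuming `vision_point` can be supplied as a configuration keyword (it is read from `self.CFG` above).

```python
import cv2

from ultralytics import solutions

visioneye = solutions.VisionEye(model="yolo11n.pt", vision_point=(50, 50), show=False)

frame = cv2.imread("frame.jpg")
results = visioneye.process(frame)
print(f"Total detected instances: {results.total_tracks}")
```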
diff --git a/ultralytics/trackers/README.md b/ultralytics/trackers/README.md
deleted file mode 100644
index de6acbd..0000000
--- a/ultralytics/trackers/README.md
+++ /dev/null
@@ -1,295 +0,0 @@
-
-
-# Multi-Object Tracking with Ultralytics YOLO
-
-
-
-[Object tracking](https://www.ultralytics.com/glossary/object-tracking), a key aspect of [video analytics](https://en.wikipedia.org/wiki/Video_content_analysis), involves identifying the location and class of objects within video frames and assigning a unique ID to each detected object as it moves. This capability enables a wide range of applications, from surveillance and security systems to [real-time](https://www.ultralytics.com/glossary/real-time-inference) sports analysis and autonomous vehicle navigation. Learn more about tracking on our [tracking documentation page](https://docs.ultralytics.com/modes/track/).
-
-## 🎯 Why Choose Ultralytics YOLO for Object Tracking?
-
-Ultralytics YOLO trackers provide output consistent with standard [object detection](https://docs.ultralytics.com/tasks/detect/) but add persistent object IDs. This simplifies the process of tracking objects in video streams and performing subsequent analyses. Here’s why Ultralytics YOLO is an excellent choice for your object tracking needs:
-
-- **Efficiency:** Process video streams in real-time without sacrificing accuracy.
-- **Flexibility:** Supports multiple robust tracking algorithms and configurations.
-- **Ease of Use:** Offers straightforward [Python API](https://docs.ultralytics.com/usage/python/) and [CLI](https://docs.ultralytics.com/usage/cli/) options for rapid integration and deployment.
-- **Customizability:** Easily integrates with [custom-trained YOLO models](https://docs.ultralytics.com/modes/train/), enabling deployment in specialized, domain-specific applications.
-
-**Watch:** Object Detection and Tracking with Ultralytics YOLOv8.
-
-[Watch on YouTube](https://www.youtube.com/watch?v=hHyHmOtmEgs)
-
-## ✨ Features at a Glance
-
-Ultralytics YOLO extends its powerful object detection features to deliver robust and versatile object tracking:
-
-- **Real-Time Tracking:** Seamlessly track objects in high-frame-rate videos.
-- **Multiple Tracker Support:** Choose from a selection of established tracking algorithms.
-- **Customizable Tracker Configurations:** Adapt the tracking algorithm to specific requirements by adjusting various parameters.
-
-## 🛠️ Available Trackers
-
-Ultralytics YOLO supports the following tracking algorithms. Enable them by passing the relevant YAML configuration file, such as `tracker=tracker_type.yaml`:
-
-- **BoT-SORT:** Use [`botsort.yaml`](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/trackers/botsort.yaml) to enable this tracker. Based on the [BoT-SORT paper](https://arxiv.org/abs/2206.14651) and its official [code implementation](https://github.com/NirAharon/BoT-SORT).
-- **ByteTrack:** Use [`bytetrack.yaml`](https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/trackers/bytetrack.yaml) to enable this tracker. Based on the [ByteTrack paper](https://arxiv.org/abs/2110.06864) and its official [code implementation](https://github.com/FoundationVision/ByteTrack).
-
-The default tracker is **BoT-SORT**.
-
-## ⚙️ Usage
-
-To run the tracker on video streams, use a trained Detect, Segment, or Pose model like [Ultralytics YOLO11n](https://docs.ultralytics.com/models/yolo11/), YOLO11n-seg, or YOLO11n-pose.
-
-```python
-# Python
-from ultralytics import YOLO
-
-# Load an official or custom model
-model = YOLO("yolo11n.pt") # Load an official Detect model
-# model = YOLO("yolo11n-seg.pt") # Load an official Segment model
-# model = YOLO("yolo11n-pose.pt") # Load an official Pose model
-# model = YOLO("path/to/best.pt") # Load a custom trained model
-
-# Perform tracking with the model
-results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True) # Tracking with default tracker
-# results = model.track(source="https://youtu.be/LNwODJXcvt4", show=True, tracker="bytetrack.yaml") # Tracking with ByteTrack tracker
-```
-
-```bash
-# CLI
-# Perform tracking with various models using the command line interface
-yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" # Official Detect model
-# yolo track model=yolo11n-seg.pt source="https://youtu.be/LNwODJXcvt4" # Official Segment model
-# yolo track model=yolo11n-pose.pt source="https://youtu.be/LNwODJXcvt4" # Official Pose model
-# yolo track model=path/to/best.pt source="https://youtu.be/LNwODJXcvt4" # Custom trained model
-
-# Track using ByteTrack tracker
-# yolo track model=path/to/best.pt tracker="bytetrack.yaml"
-```
-
-As shown above, tracking is available for all [Detect](https://docs.ultralytics.com/tasks/detect/), [Segment](https://docs.ultralytics.com/tasks/segment/), and [Pose](https://docs.ultralytics.com/tasks/pose/) models when run on videos or streaming sources.
-
-## 🔧 Configuration
-
-### Tracking Arguments
-
-Tracking configuration shares properties with the Predict mode, such as `conf` (confidence threshold), `iou` ([Intersection over Union](https://www.ultralytics.com/glossary/intersection-over-union-iou) threshold), and `show` (display results). For additional configurations, refer to the [Predict mode documentation](https://docs.ultralytics.com/modes/predict/).
-
-```python
-# Python
-from ultralytics import YOLO
-
-# Configure the tracking parameters and run the tracker
-model = YOLO("yolo11n.pt")
-results = model.track(source="https://youtu.be/LNwODJXcvt4", conf=0.3, iou=0.5, show=True)
-```
-
-```bash
-# CLI
-# Configure tracking parameters and run the tracker using the command line interface
-yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" conf=0.3 iou=0.5 show
-```
-
-### Tracker Selection
-
-Ultralytics allows you to use a modified tracker configuration file. Create a copy of a tracker config file (e.g., `custom_tracker.yaml`) from [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) and adjust any configurations (except `tracker_type`) according to your needs.
-
-```python
-# Python
-from ultralytics import YOLO
-
-# Load the model and run the tracker with a custom configuration file
-model = YOLO("yolo11n.pt")
-results = model.track(source="https://youtu.be/LNwODJXcvt4", tracker="custom_tracker.yaml")
-```
-
-```bash
-# CLI
-# Load the model and run the tracker with a custom configuration file using the command line interface
-yolo track model=yolo11n.pt source="https://youtu.be/LNwODJXcvt4" tracker='custom_tracker.yaml'
-```
-
-For a comprehensive list of tracking arguments, consult the [Tracking Configuration files](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers) in the repository.
-
-## 🐍 Python Examples
-
-### Persisting Tracks Loop
-
-This Python script uses [OpenCV (`cv2`)](https://opencv.org/) and Ultralytics YOLO11 to perform object tracking on video frames. Ensure you have installed the necessary packages (`opencv-python` and `ultralytics`). The [`persist=True`](https://docs.ultralytics.com/modes/predict/#tracking) argument indicates that the current frame is the next in a sequence, allowing the tracker to maintain track continuity from the previous frame.
-
-```python
-# Python
-import cv2
-
-from ultralytics import YOLO
-
-# Load the YOLO11 model
-model = YOLO("yolo11n.pt")
-
-# Open the video file
-video_path = "path/to/video.mp4"
-cap = cv2.VideoCapture(video_path)
-
-# Loop through the video frames
-while cap.isOpened():
- # Read a frame from the video
- success, frame = cap.read()
-
- if success:
- # Run YOLO11 tracking on the frame, persisting tracks between frames
- results = model.track(frame, persist=True)
-
- # Visualize the results on the frame
- annotated_frame = results[0].plot()
-
- # Display the annotated frame
- cv2.imshow("YOLO11 Tracking", annotated_frame)
-
- # Break the loop if 'q' is pressed
- if cv2.waitKey(1) & 0xFF == ord("q"):
- break
- else:
- # Break the loop if the end of the video is reached
- break
-
-# Release the video capture object and close the display window
-cap.release()
-cv2.destroyAllWindows()
-```
-
-Note the use of `model.track(frame)` instead of `model(frame)`, which specifically enables object tracking. This script processes each video frame, visualizes the tracking results, and displays them. Press 'q' to exit the loop.
-
-### Plotting Tracks Over Time
-
-Visualizing object tracks across consecutive frames offers valuable insights into movement patterns within a video. Ultralytics YOLO11 makes plotting these tracks efficient.
-
-The following example demonstrates how to use YOLO11's tracking capabilities to plot the movement of detected objects. The script opens a video, reads it frame by frame, and uses the YOLO model built on [PyTorch](https://pytorch.org/) to identify and track objects. By storing the center points of the detected [bounding boxes](https://www.ultralytics.com/glossary/bounding-box) and connecting them, we can draw lines representing the paths of tracked objects using [NumPy](https://numpy.org/) for numerical operations.
-
-```python
-# Python
-from collections import defaultdict
-
-import cv2
-import numpy as np
-
-from ultralytics import YOLO
-
-# Load the YOLO11 model
-model = YOLO("yolo11n.pt")
-
-# Open the video file
-video_path = "path/to/video.mp4"
-cap = cv2.VideoCapture(video_path)
-
-# Store the track history
-track_history = defaultdict(lambda: [])
-
-# Loop through the video frames
-while cap.isOpened():
- # Read a frame from the video
- success, frame = cap.read()
-
- if success:
- # Run YOLO11 tracking on the frame, persisting tracks between frames
- result = model.track(frame, persist=True)[0]
-
- # Get the boxes and track IDs
- if result.boxes and result.boxes.is_track:
- boxes = result.boxes.xywh.cpu()
- track_ids = result.boxes.id.int().cpu().tolist()
-
- # Visualize the result on the frame
- frame = result.plot()
-
- # Plot the tracks
- for box, track_id in zip(boxes, track_ids):
- x, y, w, h = box
- track = track_history[track_id]
- track.append((float(x), float(y))) # x, y center point
-                if len(track) > 30:  # keep only the last 30 center points
- track.pop(0)
-
- # Draw the tracking lines
- points = np.hstack(track).astype(np.int32).reshape((-1, 1, 2))
- cv2.polylines(frame, [points], isClosed=False, color=(230, 230, 230), thickness=10)
-
- # Display the annotated frame
- cv2.imshow("YOLO11 Tracking", frame)
-
- # Break the loop if 'q' is pressed
- if cv2.waitKey(1) & 0xFF == ord("q"):
- break
- else:
- # Break the loop if the end of the video is reached
- break
-
-# Release the video capture object and close the display window
-cap.release()
-cv2.destroyAllWindows()
-```
-
-### Multithreaded Tracking
-
-Multithreaded tracking allows running object tracking on multiple video streams simultaneously, which is highly beneficial for systems handling inputs from several cameras, improving efficiency through concurrent processing.
-
-This Python script utilizes Python's [`threading`](https://docs.python.org/3/library/threading.html) module for concurrent tracker execution. Each thread manages tracking for a single video file.
-
-The `run_tracker_in_thread` function accepts a model name and a video source. It runs the tracker on that source in streaming mode, consuming results frame by frame within its own thread.
-
-This example uses two models, `yolo11n.pt` and `yolo11n-seg.pt`, each tracking objects in one of the two sources defined in `SOURCES` (a local video file and webcam index `0`), respectively.
-
-Setting `daemon=True` in `threading.Thread` ensures threads exit when the main program finishes. Threads are started with `start()` and the main thread waits for their completion using `join()`.
-
-Finally, `cv2.destroyAllWindows()` closes all OpenCV windows after the threads finish.
-
-```python
-# Python
-import threading
-
-import cv2
-
-from ultralytics import YOLO
-
-# Define model names and video sources
-MODEL_NAMES = ["yolo11n.pt", "yolo11n-seg.pt"]
-SOURCES = ["path/to/video.mp4", "0"] # local video, 0 for webcam
-
-
-def run_tracker_in_thread(model_name, filename):
- """
- Run YOLO tracker in its own thread for concurrent processing.
-
- Args:
-        model_name (str): Name or path of the YOLO11 model file to load.
- filename (str): The path to the video file or the identifier for the webcam/external camera source.
- """
- model = YOLO(model_name)
- results = model.track(filename, save=True, stream=True)
- for r in results:
-        pass  # Consume the generator so frames are processed as they stream
-
-
-# Create and start tracker threads using a for loop
-tracker_threads = []
-for video_file, model_name in zip(SOURCES, MODEL_NAMES):
- thread = threading.Thread(target=run_tracker_in_thread, args=(model_name, video_file), daemon=True)
- tracker_threads.append(thread)
- thread.start()
-
-# Wait for all tracker threads to finish
-for thread in tracker_threads:
- thread.join()
-
-# Clean up and close windows
-cv2.destroyAllWindows()
-```
-
-This setup can be easily scaled to handle more video streams by creating additional threads following the same pattern. Explore more applications in our [blog post on object tracking](https://www.ultralytics.com/blog/object-detection-and-tracking-with-ultralytics-yolov8).
-
-## 🤝 Contribute New Trackers
-
-Are you experienced in multi-object tracking and have implemented or adapted an algorithm with Ultralytics YOLO? We encourage you to contribute to our Trackers section in [ultralytics/cfg/trackers](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/trackers)! Your contributions can help expand the tracking solutions available within the Ultralytics [ecosystem](https://docs.ultralytics.com/).
-
-To contribute, please review our [Contributing Guide](https://docs.ultralytics.com/help/contributing/) for instructions on submitting a [Pull Request (PR)](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/about-pull-requests) 🛠️. We look forward to your contributions!
-
-Let's work together to enhance the tracking capabilities of Ultralytics YOLO and provide more powerful tools for the [computer vision](https://www.ultralytics.com/glossary/computer-vision-cv) and [deep learning](https://www.ultralytics.com/glossary/deep-learning-dl) community 🙏!
diff --git a/ultralytics/trackers/__init__.py b/ultralytics/trackers/__init__.py
deleted file mode 100644
index 2919511..0000000
--- a/ultralytics/trackers/__init__.py
+++ /dev/null
@@ -1,7 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from .bot_sort import BOTSORT
-from .byte_tracker import BYTETracker
-from .track import register_tracker
-
-__all__ = "register_tracker", "BOTSORT", "BYTETracker" # allow simpler import
diff --git a/ultralytics/trackers/basetrack.py b/ultralytics/trackers/basetrack.py
deleted file mode 100644
index d254883..0000000
--- a/ultralytics/trackers/basetrack.py
+++ /dev/null
@@ -1,117 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""Module defines the base classes and structures for object tracking in YOLO."""
-
-from collections import OrderedDict
-from typing import Any
-
-import numpy as np
-
-
-class TrackState:
- """
- Enumeration class representing the possible states of an object being tracked.
-
- Attributes:
- New (int): State when the object is newly detected.
- Tracked (int): State when the object is successfully tracked in subsequent frames.
- Lost (int): State when the object is no longer tracked.
- Removed (int): State when the object is removed from tracking.
-
- Examples:
- >>> state = TrackState.New
- >>> if state == TrackState.New:
- >>> print("Object is newly detected.")
- """
-
- New = 0
- Tracked = 1
- Lost = 2
- Removed = 3
-
-
-class BaseTrack:
- """
- Base class for object tracking, providing foundational attributes and methods.
-
- Attributes:
- _count (int): Class-level counter for unique track IDs.
- track_id (int): Unique identifier for the track.
- is_activated (bool): Flag indicating whether the track is currently active.
- state (TrackState): Current state of the track.
- history (OrderedDict): Ordered history of the track's states.
- features (list): List of features extracted from the object for tracking.
- curr_feature (Any): The current feature of the object being tracked.
- score (float): The confidence score of the tracking.
- start_frame (int): The frame number where tracking started.
- frame_id (int): The most recent frame ID processed by the track.
- time_since_update (int): Frames passed since the last update.
- location (tuple): The location of the object in the context of multi-camera tracking.
-
- Methods:
- end_frame: Returns the ID of the last frame where the object was tracked.
- next_id: Increments and returns the next global track ID.
- activate: Abstract method to activate the track.
- predict: Abstract method to predict the next state of the track.
- update: Abstract method to update the track with new data.
- mark_lost: Marks the track as lost.
- mark_removed: Marks the track as removed.
- reset_id: Resets the global track ID counter.
-
- Examples:
- Initialize a new track and mark it as lost:
- >>> track = BaseTrack()
- >>> track.mark_lost()
- >>> print(track.state) # Output: 2 (TrackState.Lost)
- """
-
- _count = 0
-
- def __init__(self):
- """Initialize a new track with a unique ID and foundational tracking attributes."""
- self.track_id = 0
- self.is_activated = False
- self.state = TrackState.New
- self.history = OrderedDict()
- self.features = []
- self.curr_feature = None
- self.score = 0
- self.start_frame = 0
- self.frame_id = 0
- self.time_since_update = 0
- self.location = (np.inf, np.inf)
-
- @property
- def end_frame(self) -> int:
- """Return the ID of the most recent frame where the object was tracked."""
- return self.frame_id
-
- @staticmethod
- def next_id() -> int:
- """Increment and return the next unique global track ID for object tracking."""
- BaseTrack._count += 1
- return BaseTrack._count
-
- def activate(self, *args: Any) -> None:
- """Activate the track with provided arguments, initializing necessary attributes for tracking."""
- raise NotImplementedError
-
- def predict(self) -> None:
- """Predict the next state of the track based on the current state and tracking model."""
- raise NotImplementedError
-
- def update(self, *args: Any, **kwargs: Any) -> None:
- """Update the track with new observations and data, modifying its state and attributes accordingly."""
- raise NotImplementedError
-
- def mark_lost(self) -> None:
- """Mark the track as lost by updating its state to TrackState.Lost."""
- self.state = TrackState.Lost
-
- def mark_removed(self) -> None:
- """Mark the track as removed by setting its state to TrackState.Removed."""
- self.state = TrackState.Removed
-
- @staticmethod
- def reset_id() -> None:
- """Reset the global track ID counter to its initial value."""
- BaseTrack._count = 0
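`BaseTrack` is an abstract contract: subclasses supply `activate`, `predict`, and `update`, while ID allocation and state flags come from the base class. A toy sketch of that contract (`SimpleTrack` is hypothetical and only illustrative):

```python
import numpy as np

from ultralytics.trackers.basetrack import BaseTrack, TrackState


class SimpleTrack(BaseTrack):
    """Hypothetical minimal subclass illustrating the BaseTrack contract."""

    def activate(self, box: np.ndarray, frame_id: int) -> None:
        self.track_id = self.next_id()  # allocate a unique global ID
        self.start_frame = self.frame_id = frame_id
        self.state = TrackState.Tracked
        self.is_activated = True
        self.history[frame_id] = box

    def predict(self) -> None:
        pass  # a real tracker would propagate a motion model (e.g. a Kalman filter) here

    def update(self, box: np.ndarray, frame_id: int) -> None:
        self.frame_id = frame_id
        self.history[frame_id] = box


track = SimpleTrack()
track.activate(np.array([10, 20, 50, 80]), frame_id=1)
track.mark_lost()
print(track.track_id, track.state)  # e.g. 1 2 (TrackState.Lost)
```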
diff --git a/ultralytics/trackers/bot_sort.py b/ultralytics/trackers/bot_sort.py
deleted file mode 100644
index 30f9463..0000000
--- a/ultralytics/trackers/bot_sort.py
+++ /dev/null
@@ -1,274 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from collections import deque
-from typing import Any
-
-import numpy as np
-import torch
-
-from ultralytics.utils.ops import xywh2xyxy
-from ultralytics.utils.plotting import save_one_box
-
-from .basetrack import TrackState
-from .byte_tracker import BYTETracker, STrack
-from .utils import matching
-from .utils.gmc import GMC
-from .utils.kalman_filter import KalmanFilterXYWH
-
-
-class BOTrack(STrack):
- """
- An extended version of the STrack class for YOLO, adding object tracking features.
-
- This class extends the STrack class to include additional functionalities for object tracking, such as feature
- smoothing, Kalman filter prediction, and reactivation of tracks.
-
- Attributes:
- shared_kalman (KalmanFilterXYWH): A shared Kalman filter for all instances of BOTrack.
- smooth_feat (np.ndarray): Smoothed feature vector.
- curr_feat (np.ndarray): Current feature vector.
- features (deque): A deque to store feature vectors with a maximum length defined by `feat_history`.
- alpha (float): Smoothing factor for the exponential moving average of features.
- mean (np.ndarray): The mean state of the Kalman filter.
- covariance (np.ndarray): The covariance matrix of the Kalman filter.
-
- Methods:
- update_features: Update features vector and smooth it using exponential moving average.
- predict: Predict the mean and covariance using Kalman filter.
- re_activate: Reactivate a track with updated features and optionally new ID.
- update: Update the track with new detection and frame ID.
- tlwh: Property that gets the current position in tlwh format `(top left x, top left y, width, height)`.
- multi_predict: Predict the mean and covariance of multiple object tracks using shared Kalman filter.
- convert_coords: Convert tlwh bounding box coordinates to xywh format.
- tlwh_to_xywh: Convert bounding box to xywh format `(center x, center y, width, height)`.
-
- Examples:
- Create a BOTrack instance and update its features
-        >>> bo_track = BOTrack(xywh=[100, 50, 80, 40], score=0.9, cls=1, feat=np.random.rand(128))
-        >>> bo_track.predict()
-        >>> new_track = BOTrack(xywh=[110, 60, 80, 40], score=0.85, cls=1, feat=np.random.rand(128))
- >>> bo_track.update(new_track, frame_id=2)
- """
-
- shared_kalman = KalmanFilterXYWH()
-
- def __init__(
- self, xywh: np.ndarray, score: float, cls: int, feat: np.ndarray | None = None, feat_history: int = 50
- ):
- """
- Initialize a BOTrack object with temporal parameters, such as feature history, alpha, and current features.
-
- Args:
- xywh (np.ndarray): Bounding box coordinates in xywh format (center x, center y, width, height).
- score (float): Confidence score of the detection.
- cls (int): Class ID of the detected object.
- feat (np.ndarray, optional): Feature vector associated with the detection.
- feat_history (int): Maximum length of the feature history deque.
-
- Examples:
- Initialize a BOTrack object with bounding box, score, class ID, and feature vector
- >>> xywh = np.array([100, 150, 60, 50])
- >>> score = 0.9
- >>> cls = 1
- >>> feat = np.random.rand(128)
- >>> bo_track = BOTrack(xywh, score, cls, feat)
- """
- super().__init__(xywh, score, cls)
-
- self.smooth_feat = None
- self.curr_feat = None
- if feat is not None:
- self.update_features(feat)
- self.features = deque([], maxlen=feat_history)
- self.alpha = 0.9
-
- def update_features(self, feat: np.ndarray) -> None:
- """Update the feature vector and apply exponential moving average smoothing."""
- feat /= np.linalg.norm(feat)
- self.curr_feat = feat
- if self.smooth_feat is None:
- self.smooth_feat = feat
- else:
- self.smooth_feat = self.alpha * self.smooth_feat + (1 - self.alpha) * feat
- self.features.append(feat)
- self.smooth_feat /= np.linalg.norm(self.smooth_feat)
-
- def predict(self) -> None:
- """Predict the object's future state using the Kalman filter to update its mean and covariance."""
- mean_state = self.mean.copy()
- if self.state != TrackState.Tracked:
- mean_state[6] = 0
- mean_state[7] = 0
-
- self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
-
- def re_activate(self, new_track: BOTrack, frame_id: int, new_id: bool = False) -> None:
- """Reactivate a track with updated features and optionally assign a new ID."""
- if new_track.curr_feat is not None:
- self.update_features(new_track.curr_feat)
- super().re_activate(new_track, frame_id, new_id)
-
- def update(self, new_track: BOTrack, frame_id: int) -> None:
- """Update the track with new detection information and the current frame ID."""
- if new_track.curr_feat is not None:
- self.update_features(new_track.curr_feat)
- super().update(new_track, frame_id)
-
- @property
- def tlwh(self) -> np.ndarray:
- """Return the current bounding box position in `(top left x, top left y, width, height)` format."""
- if self.mean is None:
- return self._tlwh.copy()
- ret = self.mean[:4].copy()
- ret[:2] -= ret[2:] / 2
- return ret
-
- @staticmethod
- def multi_predict(stracks: list[BOTrack]) -> None:
- """Predict the mean and covariance for multiple object tracks using a shared Kalman filter."""
- if len(stracks) <= 0:
- return
- multi_mean = np.asarray([st.mean.copy() for st in stracks])
- multi_covariance = np.asarray([st.covariance for st in stracks])
- for i, st in enumerate(stracks):
- if st.state != TrackState.Tracked:
- multi_mean[i][6] = 0
- multi_mean[i][7] = 0
- multi_mean, multi_covariance = BOTrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
- for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
- stracks[i].mean = mean
- stracks[i].covariance = cov
-
- def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
- """Convert tlwh bounding box coordinates to xywh format."""
- return self.tlwh_to_xywh(tlwh)
-
- @staticmethod
- def tlwh_to_xywh(tlwh: np.ndarray) -> np.ndarray:
- """Convert bounding box from tlwh (top-left-width-height) to xywh (center-x-center-y-width-height) format."""
- ret = np.asarray(tlwh).copy()
- ret[:2] += ret[2:] / 2
- return ret
-
-
-class BOTSORT(BYTETracker):
- """
- An extended version of the BYTETracker class for YOLO, designed for object tracking with ReID and GMC algorithm.
-
- Attributes:
- proximity_thresh (float): Threshold for spatial proximity (IoU) between tracks and detections.
- appearance_thresh (float): Threshold for appearance similarity (ReID embeddings) between tracks and detections.
- encoder (Any): Object to handle ReID embeddings, set to None if ReID is not enabled.
- gmc (GMC): An instance of the GMC algorithm for data association.
- args (Any): Parsed command-line arguments containing tracking parameters.
-
- Methods:
- get_kalmanfilter: Return an instance of KalmanFilterXYWH for object tracking.
-        init_track: Initialize object tracks from detection results and an optional image for ReID features.
- get_dists: Get distances between tracks and detections using IoU and (optionally) ReID.
- multi_predict: Predict and track multiple objects with a YOLO model.
- reset: Reset the BOTSORT tracker to its initial state.
-
- Examples:
- Initialize BOTSORT and process detections
- >>> bot_sort = BOTSORT(args, frame_rate=30)
-        >>> bot_sort.init_track(results, img)
- >>> bot_sort.multi_predict(tracks)
-
- Note:
- The class is designed to work with a YOLO object detection model and supports ReID only if enabled via args.
- """
-
- def __init__(self, args: Any, frame_rate: int = 30):
- """
- Initialize BOTSORT object with ReID module and GMC algorithm.
-
- Args:
- args (Any): Parsed command-line arguments containing tracking parameters.
- frame_rate (int): Frame rate of the video being processed.
-
- Examples:
- Initialize BOTSORT with command-line arguments and a specified frame rate:
- >>> args = parse_args()
- >>> bot_sort = BOTSORT(args, frame_rate=30)
- """
- super().__init__(args, frame_rate)
- self.gmc = GMC(method=args.gmc_method)
-
- # ReID module
- self.proximity_thresh = args.proximity_thresh
- self.appearance_thresh = args.appearance_thresh
- self.encoder = (
- (lambda feats, s: [f.cpu().numpy() for f in feats]) # native features do not require any model
- if args.with_reid and self.args.model == "auto"
- else ReID(args.model)
- if args.with_reid
- else None
- )
-
- def get_kalmanfilter(self) -> KalmanFilterXYWH:
- """Return an instance of KalmanFilterXYWH for predicting and updating object states in the tracking process."""
- return KalmanFilterXYWH()
-
- def init_track(self, results, img: np.ndarray | None = None) -> list[BOTrack]:
- """Initialize object tracks using detection bounding boxes, scores, class labels, and optional ReID features."""
- if len(results) == 0:
- return []
- bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
- bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
- if self.args.with_reid and self.encoder is not None:
- features_keep = self.encoder(img, bboxes)
- return [BOTrack(xywh, s, c, f) for (xywh, s, c, f) in zip(bboxes, results.conf, results.cls, features_keep)]
- else:
- return [BOTrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
-
- def get_dists(self, tracks: list[BOTrack], detections: list[BOTrack]) -> np.ndarray:
- """Calculate distances between tracks and detections using IoU and optionally ReID embeddings."""
- dists = matching.iou_distance(tracks, detections)
- dists_mask = dists > (1 - self.proximity_thresh)
-
- if self.args.fuse_score:
- dists = matching.fuse_score(dists, detections)
-
- if self.args.with_reid and self.encoder is not None:
- emb_dists = matching.embedding_distance(tracks, detections) / 2.0
- emb_dists[emb_dists > (1 - self.appearance_thresh)] = 1.0
- emb_dists[dists_mask] = 1.0
- dists = np.minimum(dists, emb_dists)
- return dists
-
- def multi_predict(self, tracks: list[BOTrack]) -> None:
- """Predict the mean and covariance of multiple object tracks using a shared Kalman filter."""
- BOTrack.multi_predict(tracks)
-
- def reset(self) -> None:
- """Reset the BOTSORT tracker to its initial state, clearing all tracked objects and internal states."""
- super().reset()
- self.gmc.reset_params()
-
-
-class ReID:
- """YOLO model as encoder for re-identification."""
-
- def __init__(self, model: str):
- """
- Initialize encoder for re-identification.
-
- Args:
- model (str): Path to the YOLO model for re-identification.
- """
- from ultralytics import YOLO
-
- self.model = YOLO(model)
- self.model(embed=[len(self.model.model.model) - 2 if ".pt" in model else -1], verbose=False, save=False) # init
-
- def __call__(self, img: np.ndarray, dets: np.ndarray) -> list[np.ndarray]:
- """Extract embeddings for detected objects."""
- feats = self.model.predictor(
- [save_one_box(det, img, save=False) for det in xywh2xyxy(torch.from_numpy(dets[:, :4]))]
- )
- if len(feats) != dets.shape[0] and feats[0].shape[0] == dets.shape[0]:
- feats = feats[0] # batched prediction with non-PyTorch backend
- return [f.cpu().numpy() for f in feats]
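The get_dists method above fuses a spatial cue (IoU distance) with an appearance cue (ReID embedding distance), keeping the appearance term only where both the proximity and appearance gates pass. A minimal numpy sketch of that fusion, with placeholder threshold values standing in for args.proximity_thresh and args.appearance_thresh:

```python
import numpy as np


def fuse_iou_and_appearance(iou_dist: np.ndarray, emb_dist: np.ndarray,
                            proximity_thresh: float = 0.5,
                            appearance_thresh: float = 0.25) -> np.ndarray:
    """Sketch of BOT-SORT-style cost fusion on (num_tracks, num_dets) distance matrices."""
    emb = emb_dist / 2.0                           # scale embedding distance as in get_dists
    emb[emb > (1 - appearance_thresh)] = 1.0       # reject pairs that look too different
    emb[iou_dist > (1 - proximity_thresh)] = 1.0   # reject pairs that are spatially far apart
    return np.minimum(iou_dist, emb)               # keep the cheaper of the two cues


# Toy example: two tracks vs. two detections
iou = np.array([[0.2, 0.9], [0.8, 0.3]])
emb = np.array([[0.1, 0.9], [0.9, 0.2]])
print(fuse_iou_and_appearance(iou, emb))
```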
diff --git a/ultralytics/trackers/byte_tracker.py b/ultralytics/trackers/byte_tracker.py
deleted file mode 100644
index cdc7dcb..0000000
--- a/ultralytics/trackers/byte_tracker.py
+++ /dev/null
@@ -1,485 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from typing import Any
-
-import numpy as np
-
-from ..utils import LOGGER
-from ..utils.ops import xywh2ltwh
-from .basetrack import BaseTrack, TrackState
-from .utils import matching
-from .utils.kalman_filter import KalmanFilterXYAH
-
-
-class STrack(BaseTrack):
- """
- Single object tracking representation that uses Kalman filtering for state estimation.
-
- This class is responsible for storing all the information regarding individual tracklets and performs state updates
- and predictions based on Kalman filter.
-
- Attributes:
- shared_kalman (KalmanFilterXYAH): Shared Kalman filter used across all STrack instances for prediction.
- _tlwh (np.ndarray): Private attribute to store top-left corner coordinates and width and height of bounding box.
- kalman_filter (KalmanFilterXYAH): Instance of Kalman filter used for this particular object track.
- mean (np.ndarray): Mean state estimate vector.
- covariance (np.ndarray): Covariance of state estimate.
- is_activated (bool): Boolean flag indicating if the track has been activated.
- score (float): Confidence score of the track.
- tracklet_len (int): Length of the tracklet.
- cls (Any): Class label for the object.
- idx (int): Index or identifier for the object.
- frame_id (int): Current frame ID.
- start_frame (int): Frame where the object was first detected.
- angle (float | None): Optional angle information for oriented bounding boxes.
-
- Methods:
- predict: Predict the next state of the object using Kalman filter.
- multi_predict: Predict the next states for multiple tracks.
- multi_gmc: Update multiple track states using a homography matrix.
- activate: Activate a new tracklet.
- re_activate: Reactivate a previously lost tracklet.
- update: Update the state of a matched track.
- convert_coords: Convert bounding box to x-y-aspect-height format.
- tlwh_to_xyah: Convert tlwh bounding box to xyah format.
-
- Examples:
- Initialize and activate a new track
- >>> track = STrack(xywh=[100, 200, 50, 80, 0], score=0.9, cls="person")
- >>> track.activate(kalman_filter=KalmanFilterXYAH(), frame_id=1)
- """
-
- shared_kalman = KalmanFilterXYAH()
-
- def __init__(self, xywh: list[float], score: float, cls: Any):
- """
- Initialize a new STrack instance.
-
- Args:
- xywh (list[float]): Bounding box coordinates and dimensions in the format (x, y, w, h, [a], idx), where
- (x, y) is the center, (w, h) are width and height, [a] is optional aspect ratio, and idx is the id.
- score (float): Confidence score of the detection.
- cls (Any): Class label for the detected object.
-
- Examples:
- >>> xywh = [100.0, 150.0, 50.0, 75.0, 1]
- >>> score = 0.9
- >>> cls = "person"
- >>> track = STrack(xywh, score, cls)
- """
- super().__init__()
- # xywh+idx or xywha+idx
- assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
- self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
- self.kalman_filter = None
- self.mean, self.covariance = None, None
- self.is_activated = False
-
- self.score = score
- self.tracklet_len = 0
- self.cls = cls
- self.idx = xywh[-1]
- self.angle = xywh[4] if len(xywh) == 6 else None
-
- def predict(self):
- """Predict the next state (mean and covariance) of the object using the Kalman filter."""
- mean_state = self.mean.copy()
- if self.state != TrackState.Tracked:
- mean_state[7] = 0
- self.mean, self.covariance = self.kalman_filter.predict(mean_state, self.covariance)
-
- @staticmethod
- def multi_predict(stracks: list[STrack]):
- """Perform multi-object predictive tracking using Kalman filter for the provided list of STrack instances."""
- if len(stracks) <= 0:
- return
- multi_mean = np.asarray([st.mean.copy() for st in stracks])
- multi_covariance = np.asarray([st.covariance for st in stracks])
- for i, st in enumerate(stracks):
- if st.state != TrackState.Tracked:
- multi_mean[i][7] = 0
- multi_mean, multi_covariance = STrack.shared_kalman.multi_predict(multi_mean, multi_covariance)
- for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
- stracks[i].mean = mean
- stracks[i].covariance = cov
-
- @staticmethod
- def multi_gmc(stracks: list[STrack], H: np.ndarray = np.eye(2, 3)):
- """Update state tracks positions and covariances using a homography matrix for multiple tracks."""
- if stracks:
- multi_mean = np.asarray([st.mean.copy() for st in stracks])
- multi_covariance = np.asarray([st.covariance for st in stracks])
-
- R = H[:2, :2]
- R8x8 = np.kron(np.eye(4, dtype=float), R)
- t = H[:2, 2]
-
- for i, (mean, cov) in enumerate(zip(multi_mean, multi_covariance)):
- mean = R8x8.dot(mean)
- mean[:2] += t
- cov = R8x8.dot(cov).dot(R8x8.transpose())
-
- stracks[i].mean = mean
- stracks[i].covariance = cov
-
- def activate(self, kalman_filter: KalmanFilterXYAH, frame_id: int):
- """Activate a new tracklet using the provided Kalman filter and initialize its state and covariance."""
- self.kalman_filter = kalman_filter
- self.track_id = self.next_id()
- self.mean, self.covariance = self.kalman_filter.initiate(self.convert_coords(self._tlwh))
-
- self.tracklet_len = 0
- self.state = TrackState.Tracked
- if frame_id == 1:
- self.is_activated = True
- self.frame_id = frame_id
- self.start_frame = frame_id
-
- def re_activate(self, new_track: STrack, frame_id: int, new_id: bool = False):
- """Reactivate a previously lost track using new detection data and update its state and attributes."""
- self.mean, self.covariance = self.kalman_filter.update(
- self.mean, self.covariance, self.convert_coords(new_track.tlwh)
- )
- self.tracklet_len = 0
- self.state = TrackState.Tracked
- self.is_activated = True
- self.frame_id = frame_id
- if new_id:
- self.track_id = self.next_id()
- self.score = new_track.score
- self.cls = new_track.cls
- self.angle = new_track.angle
- self.idx = new_track.idx
-
- def update(self, new_track: STrack, frame_id: int):
- """
- Update the state of a matched track.
-
- Args:
- new_track (STrack): The new track containing updated information.
- frame_id (int): The ID of the current frame.
-
- Examples:
- Update the state of a track with new detection information
- >>> track = STrack([100, 200, 50, 80, 1], 0.9, "person")
- >>> new_track = STrack([105, 205, 55, 85, 1], 0.95, "person")
- >>> track.update(new_track, 2)
- """
- self.frame_id = frame_id
- self.tracklet_len += 1
-
- new_tlwh = new_track.tlwh
- self.mean, self.covariance = self.kalman_filter.update(
- self.mean, self.covariance, self.convert_coords(new_tlwh)
- )
- self.state = TrackState.Tracked
- self.is_activated = True
-
- self.score = new_track.score
- self.cls = new_track.cls
- self.angle = new_track.angle
- self.idx = new_track.idx
-
- def convert_coords(self, tlwh: np.ndarray) -> np.ndarray:
- """Convert a bounding box's top-left-width-height format to its x-y-aspect-height equivalent."""
- return self.tlwh_to_xyah(tlwh)
-
- @property
- def tlwh(self) -> np.ndarray:
- """Get the bounding box in top-left-width-height format from the current state estimate."""
- if self.mean is None:
- return self._tlwh.copy()
- ret = self.mean[:4].copy()
- ret[2] *= ret[3]
- ret[:2] -= ret[2:] / 2
- return ret
-
- @property
- def xyxy(self) -> np.ndarray:
- """Convert bounding box from (top left x, top left y, width, height) to (min x, min y, max x, max y) format."""
- ret = self.tlwh.copy()
- ret[2:] += ret[:2]
- return ret
-
- @staticmethod
- def tlwh_to_xyah(tlwh: np.ndarray) -> np.ndarray:
- """Convert bounding box from tlwh format to center-x-center-y-aspect-height (xyah) format."""
- ret = np.asarray(tlwh).copy()
- ret[:2] += ret[2:] / 2
- ret[2] /= ret[3]
- return ret
-
- @property
- def xywh(self) -> np.ndarray:
- """Get the current position of the bounding box in (center x, center y, width, height) format."""
- ret = np.asarray(self.tlwh).copy()
- ret[:2] += ret[2:] / 2
- return ret
-
- @property
- def xywha(self) -> np.ndarray:
- """Get position in (center x, center y, width, height, angle) format, warning if angle is missing."""
- if self.angle is None:
- LOGGER.warning("`angle` attr not found, returning `xywh` instead.")
- return self.xywh
- return np.concatenate([self.xywh, self.angle[None]])
-
- @property
- def result(self) -> list[float]:
- """Get the current tracking results in the appropriate bounding box format."""
- coords = self.xyxy if self.angle is None else self.xywha
- return coords.tolist() + [self.track_id, self.score, self.cls, self.idx]
-
- def __repr__(self) -> str:
- """Return a string representation of the STrack object including start frame, end frame, and track ID."""
- return f"OT_{self.track_id}_({self.start_frame}-{self.end_frame})"
-
-
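A single STrack moves through activate → predict → update across consecutive frames. A minimal lifecycle sketch, assuming the ultralytics.trackers modules shown in this diff are importable; the coordinate and score values are arbitrary:

```python
import numpy as np
from ultralytics.trackers.byte_tracker import STrack
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

det = STrack([100.0, 200.0, 50.0, 80.0, 0], score=0.9, cls=0)  # (x, y, w, h, idx)
det.activate(KalmanFilterXYAH(), frame_id=1)                    # start a new tracklet on frame 1

new_det = STrack([104.0, 202.0, 50.0, 80.0, 0], score=0.95, cls=0)
det.predict()                                                   # constant-velocity prediction
det.update(new_det, frame_id=2)                                 # correct with the matched detection
print(det.track_id, det.xyxy)                                   # assigned track ID and current box
```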
-class BYTETracker:
- """
- BYTETracker: A tracking algorithm built on top of YOLOv8 for object detection and tracking.
-
- This class encapsulates the functionality for initializing, updating, and managing the tracks for detected objects in a
- video sequence. It maintains the state of tracked, lost, and removed tracks over frames, utilizes Kalman filtering for
- predicting the new object locations, and performs data association.
-
- Attributes:
- tracked_stracks (list[STrack]): List of successfully activated tracks.
- lost_stracks (list[STrack]): List of lost tracks.
- removed_stracks (list[STrack]): List of removed tracks.
- frame_id (int): The current frame ID.
- args (Namespace): Command-line arguments.
- max_time_lost (int): Maximum number of frames a track may remain 'lost' before it is removed.
- kalman_filter (KalmanFilterXYAH): Kalman Filter object.
-
- Methods:
- update: Update object tracker with new detections.
- get_kalmanfilter: Return a Kalman filter object for tracking bounding boxes.
- init_track: Initialize object tracking with detections.
- get_dists: Calculate the distance between tracks and detections.
- multi_predict: Predict the location of tracks.
- reset_id: Reset the ID counter of STrack.
- reset: Reset the tracker by clearing all tracks.
- joint_stracks: Combine two lists of stracks.
- sub_stracks: Filter out the stracks present in the second list from the first list.
- remove_duplicate_stracks: Remove duplicate stracks based on IoU.
-
- Examples:
- Initialize BYTETracker and update with detection results
- >>> tracker = BYTETracker(args, frame_rate=30)
- >>> results = yolo_model.detect(image)
- >>> tracked_objects = tracker.update(results)
- """
-
- def __init__(self, args, frame_rate: int = 30):
- """
- Initialize a BYTETracker instance for object tracking.
-
- Args:
- args (Namespace): Command-line arguments containing tracking parameters.
- frame_rate (int): Frame rate of the video sequence.
-
- Examples:
- Initialize BYTETracker with command-line arguments and a frame rate of 30
- >>> args = Namespace(track_buffer=30)
- >>> tracker = BYTETracker(args, frame_rate=30)
- """
- self.tracked_stracks = [] # type: list[STrack]
- self.lost_stracks = [] # type: list[STrack]
- self.removed_stracks = [] # type: list[STrack]
-
- self.frame_id = 0
- self.args = args
- self.max_time_lost = int(frame_rate / 30.0 * args.track_buffer)
- self.kalman_filter = self.get_kalmanfilter()
- self.reset_id()
-
- def update(self, results, img: np.ndarray | None = None, feats: np.ndarray | None = None) -> np.ndarray:
- """Update the tracker with new detections and return the current list of tracked objects."""
- self.frame_id += 1
- activated_stracks = []
- refind_stracks = []
- lost_stracks = []
- removed_stracks = []
-
- scores = results.conf
- remain_inds = scores >= self.args.track_high_thresh
- inds_low = scores > self.args.track_low_thresh
- inds_high = scores < self.args.track_high_thresh
-
- inds_second = inds_low & inds_high
- results_second = results[inds_second]
- results = results[remain_inds]
- feats_keep = feats_second = img
- if feats is not None and len(feats):
- feats_keep = feats[remain_inds]
- feats_second = feats[inds_second]
-
- detections = self.init_track(results, feats_keep)
- # Add newly detected tracklets to tracked_stracks
- unconfirmed = []
- tracked_stracks = [] # type: list[STrack]
- for track in self.tracked_stracks:
- if not track.is_activated:
- unconfirmed.append(track)
- else:
- tracked_stracks.append(track)
- # Step 2: First association, with high score detection boxes
- strack_pool = self.joint_stracks(tracked_stracks, self.lost_stracks)
- # Predict the current location with KF
- self.multi_predict(strack_pool)
- if hasattr(self, "gmc") and img is not None:
- # use try-except here to bypass errors from gmc module
- try:
- warp = self.gmc.apply(img, results.xyxy)
- except Exception:
- warp = np.eye(2, 3)
- STrack.multi_gmc(strack_pool, warp)
- STrack.multi_gmc(unconfirmed, warp)
-
- dists = self.get_dists(strack_pool, detections)
- matches, u_track, u_detection = matching.linear_assignment(dists, thresh=self.args.match_thresh)
-
- for itracked, idet in matches:
- track = strack_pool[itracked]
- det = detections[idet]
- if track.state == TrackState.Tracked:
- track.update(det, self.frame_id)
- activated_stracks.append(track)
- else:
- track.re_activate(det, self.frame_id, new_id=False)
- refind_stracks.append(track)
- # Step 3: Second association, matching remaining tracked stracks with low score detection boxes
- detections_second = self.init_track(results_second, feats_second)
- r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
- # TODO
- dists = matching.iou_distance(r_tracked_stracks, detections_second)
- matches, u_track, u_detection_second = matching.linear_assignment(dists, thresh=0.5)
- for itracked, idet in matches:
- track = r_tracked_stracks[itracked]
- det = detections_second[idet]
- if track.state == TrackState.Tracked:
- track.update(det, self.frame_id)
- activated_stracks.append(track)
- else:
- track.re_activate(det, self.frame_id, new_id=False)
- refind_stracks.append(track)
-
- for it in u_track:
- track = r_tracked_stracks[it]
- if track.state != TrackState.Lost:
- track.mark_lost()
- lost_stracks.append(track)
- # Deal with unconfirmed tracks, usually tracks with only one beginning frame
- detections = [detections[i] for i in u_detection]
- dists = self.get_dists(unconfirmed, detections)
- matches, u_unconfirmed, u_detection = matching.linear_assignment(dists, thresh=0.7)
- for itracked, idet in matches:
- unconfirmed[itracked].update(detections[idet], self.frame_id)
- activated_stracks.append(unconfirmed[itracked])
- for it in u_unconfirmed:
- track = unconfirmed[it]
- track.mark_removed()
- removed_stracks.append(track)
- # Step 4: Init new stracks
- for inew in u_detection:
- track = detections[inew]
- if track.score < self.args.new_track_thresh:
- continue
- track.activate(self.kalman_filter, self.frame_id)
- activated_stracks.append(track)
- # Step 5: Update state
- for track in self.lost_stracks:
- if self.frame_id - track.end_frame > self.max_time_lost:
- track.mark_removed()
- removed_stracks.append(track)
-
- self.tracked_stracks = [t for t in self.tracked_stracks if t.state == TrackState.Tracked]
- self.tracked_stracks = self.joint_stracks(self.tracked_stracks, activated_stracks)
- self.tracked_stracks = self.joint_stracks(self.tracked_stracks, refind_stracks)
- self.lost_stracks = self.sub_stracks(self.lost_stracks, self.tracked_stracks)
- self.lost_stracks.extend(lost_stracks)
- self.lost_stracks = self.sub_stracks(self.lost_stracks, self.removed_stracks)
- self.tracked_stracks, self.lost_stracks = self.remove_duplicate_stracks(self.tracked_stracks, self.lost_stracks)
- self.removed_stracks.extend(removed_stracks)
- if len(self.removed_stracks) > 1000:
- self.removed_stracks = self.removed_stracks[-999:]  # clip removed stracks to 1000 maximum
-
- return np.asarray([x.result for x in self.tracked_stracks if x.is_activated], dtype=np.float32)
-
- def get_kalmanfilter(self) -> KalmanFilterXYAH:
- """Return a Kalman filter object for tracking bounding boxes using KalmanFilterXYAH."""
- return KalmanFilterXYAH()
-
- def init_track(self, results, img: np.ndarray | None = None) -> list[STrack]:
- """Initialize object tracking with given detections, scores, and class labels using the STrack algorithm."""
- if len(results) == 0:
- return []
- bboxes = results.xywhr if hasattr(results, "xywhr") else results.xywh
- bboxes = np.concatenate([bboxes, np.arange(len(bboxes)).reshape(-1, 1)], axis=-1)
- return [STrack(xywh, s, c) for (xywh, s, c) in zip(bboxes, results.conf, results.cls)]
-
- def get_dists(self, tracks: list[STrack], detections: list[STrack]) -> np.ndarray:
- """Calculate the distance between tracks and detections using IoU and optionally fuse scores."""
- dists = matching.iou_distance(tracks, detections)
- if self.args.fuse_score:
- dists = matching.fuse_score(dists, detections)
- return dists
-
- def multi_predict(self, tracks: list[STrack]):
- """Predict the next states for multiple tracks using Kalman filter."""
- STrack.multi_predict(tracks)
-
- @staticmethod
- def reset_id():
- """Reset the ID counter for STrack instances to ensure unique track IDs across tracking sessions."""
- STrack.reset_id()
-
- def reset(self):
- """Reset the tracker by clearing all tracked, lost, and removed tracks and reinitializing the Kalman filter."""
- self.tracked_stracks = [] # type: list[STrack]
- self.lost_stracks = [] # type: list[STrack]
- self.removed_stracks = [] # type: list[STrack]
- self.frame_id = 0
- self.kalman_filter = self.get_kalmanfilter()
- self.reset_id()
-
- @staticmethod
- def joint_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
- """Combine two lists of STrack objects into a single list, ensuring no duplicates based on track IDs."""
- exists = {}
- res = []
- for t in tlista:
- exists[t.track_id] = 1
- res.append(t)
- for t in tlistb:
- tid = t.track_id
- if not exists.get(tid, 0):
- exists[tid] = 1
- res.append(t)
- return res
-
- @staticmethod
- def sub_stracks(tlista: list[STrack], tlistb: list[STrack]) -> list[STrack]:
- """Filter out the stracks present in the second list from the first list."""
- track_ids_b = {t.track_id for t in tlistb}
- return [t for t in tlista if t.track_id not in track_ids_b]
-
- @staticmethod
- def remove_duplicate_stracks(stracksa: list[STrack], stracksb: list[STrack]) -> tuple[list[STrack], list[STrack]]:
- """Remove duplicate stracks from two lists based on Intersection over Union (IoU) distance."""
- pdist = matching.iou_distance(stracksa, stracksb)
- pairs = np.where(pdist < 0.15)
- dupa, dupb = [], []
- for p, q in zip(*pairs):
- timep = stracksa[p].frame_id - stracksa[p].start_frame
- timeq = stracksb[q].frame_id - stracksb[q].start_frame
- if timep > timeq:
- dupb.append(q)
- else:
- dupa.append(p)
- resa = [t for i, t in enumerate(stracksa) if i not in dupa]
- resb = [t for i, t in enumerate(stracksb) if i not in dupb]
- return resa, resb
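The core BYTE idea in update() above is to split detections by confidence and associate in two passes: high-score boxes first, then the low-score leftovers against still-unmatched tracks. A small numpy sketch of the split, with made-up threshold values standing in for args.track_high_thresh and args.track_low_thresh:

```python
import numpy as np

# Hypothetical thresholds; the real values come from the tracker config
TRACK_HIGH_THRESH, TRACK_LOW_THRESH = 0.5, 0.1

scores = np.array([0.92, 0.45, 0.07, 0.60, 0.15])
first_pass = scores >= TRACK_HIGH_THRESH                                    # confident detections, matched first
second_pass = (scores > TRACK_LOW_THRESH) & (scores < TRACK_HIGH_THRESH)    # low-score leftovers, matched second

print(np.flatnonzero(first_pass))   # [0 3]
print(np.flatnonzero(second_pass))  # [1 4]
```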
diff --git a/ultralytics/trackers/track.py b/ultralytics/trackers/track.py
deleted file mode 100644
index 8720f73..0000000
--- a/ultralytics/trackers/track.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from functools import partial
-from pathlib import Path
-
-import torch
-
-from ultralytics.utils import YAML, IterableSimpleNamespace
-from ultralytics.utils.checks import check_yaml
-
-from .bot_sort import BOTSORT
-from .byte_tracker import BYTETracker
-
-# A mapping of tracker types to corresponding tracker classes
-TRACKER_MAP = {"bytetrack": BYTETracker, "botsort": BOTSORT}
-
-
-def on_predict_start(predictor: object, persist: bool = False) -> None:
- """
- Initialize trackers for object tracking during prediction.
-
- Args:
- predictor (ultralytics.engine.predictor.BasePredictor): The predictor object to initialize trackers for.
- persist (bool, optional): Whether to persist the trackers if they already exist.
-
- Examples:
- Initialize trackers for a predictor object
- >>> predictor = SomePredictorClass()
- >>> on_predict_start(predictor, persist=True)
- """
- if predictor.args.task == "classify":
- raise ValueError("❌ Classification doesn't support 'mode=track'")
-
- if hasattr(predictor, "trackers") and persist:
- return
-
- tracker = check_yaml(predictor.args.tracker)
- cfg = IterableSimpleNamespace(**YAML.load(tracker))
-
- if cfg.tracker_type not in {"bytetrack", "botsort"}:
- raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
-
- predictor._feats = None # reset in case used earlier
- if hasattr(predictor, "_hook"):
- predictor._hook.remove()
- if cfg.tracker_type == "botsort" and cfg.with_reid and cfg.model == "auto":
- from ultralytics.nn.modules.head import Detect
-
- if not (
- isinstance(predictor.model.model, torch.nn.Module)
- and isinstance(predictor.model.model.model[-1], Detect)
- and not predictor.model.model.model[-1].end2end
- ):
- cfg.model = "yolo11n-cls.pt"
- else:
- # Register hook to extract input of Detect layer
- def pre_hook(module, input):
- predictor._feats = list(input[0]) # unroll to new list to avoid mutation in forward
-
- predictor._hook = predictor.model.model.model[-1].register_forward_pre_hook(pre_hook)
-
- trackers = []
- for _ in range(predictor.dataset.bs):
- tracker = TRACKER_MAP[cfg.tracker_type](args=cfg, frame_rate=30)
- trackers.append(tracker)
- if predictor.dataset.mode != "stream": # only need one tracker for other modes
- break
- predictor.trackers = trackers
- predictor.vid_path = [None] * predictor.dataset.bs # for determining when to reset tracker on new video
-
-
-def on_predict_postprocess_end(predictor: object, persist: bool = False) -> None:
- """
- Postprocess detected boxes and update with object tracking.
-
- Args:
- predictor (object): The predictor object containing the predictions.
- persist (bool, optional): Whether to persist the trackers if they already exist.
-
- Examples:
- Postprocess predictions and update with tracking
- >>> predictor = YourPredictorClass()
- >>> on_predict_postprocess_end(predictor, persist=True)
- """
- is_obb = predictor.args.task == "obb"
- is_stream = predictor.dataset.mode == "stream"
- for i, result in enumerate(predictor.results):
- tracker = predictor.trackers[i if is_stream else 0]
- vid_path = predictor.save_dir / Path(result.path).name
- if not persist and predictor.vid_path[i if is_stream else 0] != vid_path:
- tracker.reset()
- predictor.vid_path[i if is_stream else 0] = vid_path
-
- det = (result.obb if is_obb else result.boxes).cpu().numpy()
- tracks = tracker.update(det, result.orig_img, getattr(result, "feats", None))
- if len(tracks) == 0:
- continue
- idx = tracks[:, -1].astype(int)
- predictor.results[i] = result[idx]
-
- update_args = {"obb" if is_obb else "boxes": torch.as_tensor(tracks[:, :-1])}
- predictor.results[i].update(**update_args)
-
-
-def register_tracker(model: object, persist: bool) -> None:
- """
- Register tracking callbacks to the model for object tracking during prediction.
-
- Args:
- model (object): The model object to register tracking callbacks for.
- persist (bool): Whether to persist the trackers if they already exist.
-
- Examples:
- Register tracking callbacks to a YOLO model
- >>> model = YOLOModel()
- >>> register_tracker(model, persist=True)
- """
- model.add_callback("on_predict_start", partial(on_predict_start, persist=persist))
- model.add_callback("on_predict_postprocess_end", partial(on_predict_postprocess_end, persist=persist))
diff --git a/ultralytics/trackers/utils/__init__.py b/ultralytics/trackers/utils/__init__.py
deleted file mode 100644
index 77a19dc..0000000
--- a/ultralytics/trackers/utils/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
diff --git a/ultralytics/trackers/utils/gmc.py b/ultralytics/trackers/utils/gmc.py
deleted file mode 100644
index 0eab5f2..0000000
--- a/ultralytics/trackers/utils/gmc.py
+++ /dev/null
@@ -1,350 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import copy
-
-import cv2
-import numpy as np
-
-from ultralytics.utils import LOGGER
-
-
-class GMC:
- """
- Generalized Motion Compensation (GMC) class for tracking and object detection in video frames.
-
- This class provides methods for tracking and detecting objects based on several tracking algorithms including ORB,
- SIFT, ECC, and Sparse Optical Flow. It also supports downscaling of frames for computational efficiency.
-
- Attributes:
- method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
- downscale (int): Factor by which to downscale the frames for processing.
- prevFrame (np.ndarray): Previous frame for tracking.
- prevKeyPoints (list): Keypoints from the previous frame.
- prevDescriptors (np.ndarray): Descriptors from the previous frame.
- initializedFirstFrame (bool): Flag indicating if the first frame has been processed.
-
- Methods:
- apply: Apply the chosen method to a raw frame and optionally use provided detections.
- apply_ecc: Apply the ECC algorithm to a raw frame.
- apply_features: Apply feature-based methods like ORB or SIFT to a raw frame.
- apply_sparseoptflow: Apply the Sparse Optical Flow method to a raw frame.
- reset_params: Reset the internal parameters of the GMC object.
-
- Examples:
- Create a GMC object and apply it to a frame
- >>> gmc = GMC(method="sparseOptFlow", downscale=2)
- >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
- >>> warp_matrix = gmc.apply(frame)
- >>> print(warp_matrix.shape)
- (2, 3)
- """
-
- def __init__(self, method: str = "sparseOptFlow", downscale: int = 2) -> None:
- """
- Initialize a Generalized Motion Compensation (GMC) object with tracking method and downscale factor.
-
- Args:
- method (str): The tracking method to use. Options include 'orb', 'sift', 'ecc', 'sparseOptFlow', 'none'.
- downscale (int): Downscale factor for processing frames.
-
- Examples:
- Initialize a GMC object with the 'sparseOptFlow' method and a downscale factor of 2
- >>> gmc = GMC(method="sparseOptFlow", downscale=2)
- """
- super().__init__()
-
- self.method = method
- self.downscale = max(1, downscale)
-
- if self.method == "orb":
- self.detector = cv2.FastFeatureDetector_create(20)
- self.extractor = cv2.ORB_create()
- self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
-
- elif self.method == "sift":
- self.detector = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
- self.extractor = cv2.SIFT_create(nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
- self.matcher = cv2.BFMatcher(cv2.NORM_L2)
-
- elif self.method == "ecc":
- number_of_iterations = 5000
- termination_eps = 1e-6
- self.warp_mode = cv2.MOTION_EUCLIDEAN
- self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, number_of_iterations, termination_eps)
-
- elif self.method == "sparseOptFlow":
- self.feature_params = dict(
- maxCorners=1000, qualityLevel=0.01, minDistance=1, blockSize=3, useHarrisDetector=False, k=0.04
- )
-
- elif self.method in {"none", "None", None}:
- self.method = None
- else:
- raise ValueError(f"Unknown GMC method: {method}")
-
- self.prevFrame = None
- self.prevKeyPoints = None
- self.prevDescriptors = None
- self.initializedFirstFrame = False
-
- def apply(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
- """
- Apply the selected motion compensation method to a raw frame and estimate the frame-to-frame transform.
-
- Args:
- raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
- detections (list, optional): List of detections to be used in the processing.
-
- Returns:
- (np.ndarray): Transformation matrix with shape (2, 3).
-
- Examples:
- >>> gmc = GMC(method="sparseOptFlow")
- >>> raw_frame = np.random.rand(480, 640, 3)
- >>> transformation_matrix = gmc.apply(raw_frame)
- >>> print(transformation_matrix.shape)
- (2, 3)
- """
- if self.method in {"orb", "sift"}:
- return self.apply_features(raw_frame, detections)
- elif self.method == "ecc":
- return self.apply_ecc(raw_frame)
- elif self.method == "sparseOptFlow":
- return self.apply_sparseoptflow(raw_frame)
- else:
- return np.eye(2, 3)
-
- def apply_ecc(self, raw_frame: np.ndarray) -> np.ndarray:
- """
- Apply the ECC (Enhanced Correlation Coefficient) algorithm to a raw frame for motion compensation.
-
- Args:
- raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-
- Returns:
- (np.ndarray): Transformation matrix with shape (2, 3).
-
- Examples:
- >>> gmc = GMC(method="ecc")
- >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
- >>> warp_matrix = gmc.apply_ecc(frame)  # first frame initializes state and returns identity
- >>> print(warp_matrix)
- [[1. 0. 0.]
- [0. 1. 0.]]
- """
- height, width, c = raw_frame.shape
- frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
- H = np.eye(2, 3, dtype=np.float32)
-
- # Downscale image for computational efficiency
- if self.downscale > 1.0:
- frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
- frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
-
- # Handle first frame initialization
- if not self.initializedFirstFrame:
- self.prevFrame = frame.copy()
- self.initializedFirstFrame = True
- return H
-
- # Run the ECC algorithm to find transformation matrix
- try:
- (_, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria, None, 1)
- except Exception as e:
- LOGGER.warning(f"find transform failed. Set warp as identity {e}")
-
- return H
-
- def apply_features(self, raw_frame: np.ndarray, detections: list | None = None) -> np.ndarray:
- """
- Apply feature-based methods like ORB or SIFT to a raw frame.
-
- Args:
- raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
- detections (list, optional): List of detections to be used in the processing.
-
- Returns:
- (np.ndarray): Transformation matrix with shape (2, 3).
-
- Examples:
- >>> gmc = GMC(method="orb")
- >>> raw_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
- >>> transformation_matrix = gmc.apply_features(raw_frame)
- >>> print(transformation_matrix.shape)
- (2, 3)
- """
- height, width, c = raw_frame.shape
- frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
- H = np.eye(2, 3)
-
- # Downscale image for computational efficiency
- if self.downscale > 1.0:
- frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
- width = width // self.downscale
- height = height // self.downscale
-
- # Create mask for keypoint detection, excluding border regions
- mask = np.zeros_like(frame)
- mask[int(0.02 * height) : int(0.98 * height), int(0.02 * width) : int(0.98 * width)] = 255
-
- # Exclude detection regions from mask to avoid tracking detected objects
- if detections is not None:
- for det in detections:
- tlbr = (det[:4] / self.downscale).astype(np.int_)
- mask[tlbr[1] : tlbr[3], tlbr[0] : tlbr[2]] = 0
-
- # Find keypoints and compute descriptors
- keypoints = self.detector.detect(frame, mask)
- keypoints, descriptors = self.extractor.compute(frame, keypoints)
-
- # Handle first frame initialization
- if not self.initializedFirstFrame:
- self.prevFrame = frame.copy()
- self.prevKeyPoints = copy.copy(keypoints)
- self.prevDescriptors = copy.copy(descriptors)
- self.initializedFirstFrame = True
- return H
-
- # Match descriptors between previous and current frame
- knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
-
- # Filter matches based on spatial distance constraints
- matches = []
- spatialDistances = []
- maxSpatialDistance = 0.25 * np.array([width, height])
-
- # Handle empty matches case
- if len(knnMatches) == 0:
- self.prevFrame = frame.copy()
- self.prevKeyPoints = copy.copy(keypoints)
- self.prevDescriptors = copy.copy(descriptors)
- return H
-
- # Apply Lowe's ratio test and spatial distance filtering
- for m, n in knnMatches:
- if m.distance < 0.9 * n.distance:
- prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
- currKeyPointLocation = keypoints[m.trainIdx].pt
-
- spatialDistance = (
- prevKeyPointLocation[0] - currKeyPointLocation[0],
- prevKeyPointLocation[1] - currKeyPointLocation[1],
- )
-
- if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and (
- np.abs(spatialDistance[1]) < maxSpatialDistance[1]
- ):
- spatialDistances.append(spatialDistance)
- matches.append(m)
-
- # Filter outliers using statistical analysis
- meanSpatialDistances = np.mean(spatialDistances, 0)
- stdSpatialDistances = np.std(spatialDistances, 0)
- inliers = (spatialDistances - meanSpatialDistances) < 2.5 * stdSpatialDistances
-
- # Extract good matches and corresponding points
- goodMatches = []
- prevPoints = []
- currPoints = []
- for i in range(len(matches)):
- if inliers[i, 0] and inliers[i, 1]:
- goodMatches.append(matches[i])
- prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
- currPoints.append(keypoints[matches[i].trainIdx].pt)
-
- prevPoints = np.array(prevPoints)
- currPoints = np.array(currPoints)
-
- # Estimate transformation matrix using RANSAC
- if prevPoints.shape[0] > 4:
- H, inliers = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
-
- # Scale translation components back to original resolution
- if self.downscale > 1.0:
- H[0, 2] *= self.downscale
- H[1, 2] *= self.downscale
- else:
- LOGGER.warning("not enough matching points")
-
- # Store current frame data for next iteration
- self.prevFrame = frame.copy()
- self.prevKeyPoints = copy.copy(keypoints)
- self.prevDescriptors = copy.copy(descriptors)
-
- return H
-
- def apply_sparseoptflow(self, raw_frame: np.ndarray) -> np.ndarray:
- """
- Apply Sparse Optical Flow method to a raw frame.
-
- Args:
- raw_frame (np.ndarray): The raw frame to be processed, with shape (H, W, C).
-
- Returns:
- (np.ndarray): Transformation matrix with shape (2, 3).
-
- Examples:
- >>> gmc = GMC()
- >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
- >>> warp_matrix = gmc.apply_sparseoptflow(frame)  # first frame initializes state and returns identity
- >>> print(warp_matrix)
- [[1. 0. 0.]
- [0. 1. 0.]]
- """
- height, width, c = raw_frame.shape
- frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY) if c == 3 else raw_frame
- H = np.eye(2, 3)
-
- # Downscale image for computational efficiency
- if self.downscale > 1.0:
- frame = cv2.resize(frame, (width // self.downscale, height // self.downscale))
-
- # Find good features to track
- keypoints = cv2.goodFeaturesToTrack(frame, mask=None, **self.feature_params)
-
- # Handle first frame initialization
- if not self.initializedFirstFrame or self.prevKeyPoints is None:
- self.prevFrame = frame.copy()
- self.prevKeyPoints = copy.copy(keypoints)
- self.initializedFirstFrame = True
- return H
-
- # Calculate optical flow using Lucas-Kanade method
- matchedKeypoints, status, _ = cv2.calcOpticalFlowPyrLK(self.prevFrame, frame, self.prevKeyPoints, None)
-
- # Extract successfully tracked points
- prevPoints = []
- currPoints = []
-
- for i in range(len(status)):
- if status[i]:
- prevPoints.append(self.prevKeyPoints[i])
- currPoints.append(matchedKeypoints[i])
-
- prevPoints = np.array(prevPoints)
- currPoints = np.array(currPoints)
-
- # Estimate transformation matrix using RANSAC
- if (prevPoints.shape[0] > 4) and (prevPoints.shape[0] == currPoints.shape[0]):
- H, _ = cv2.estimateAffinePartial2D(prevPoints, currPoints, cv2.RANSAC)
-
- # Scale translation components back to original resolution
- if self.downscale > 1.0:
- H[0, 2] *= self.downscale
- H[1, 2] *= self.downscale
- else:
- LOGGER.warning("not enough matching points")
-
- # Store current frame data for next iteration
- self.prevFrame = frame.copy()
- self.prevKeyPoints = copy.copy(keypoints)
-
- return H
-
- def reset_params(self) -> None:
- """Reset the internal parameters including previous frame, keypoints, and descriptors."""
- self.prevFrame = None
- self.prevKeyPoints = None
- self.prevDescriptors = None
- self.initializedFirstFrame = False
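The 2x3 matrices returned by apply() are consumed elsewhere (for example by STrack.multi_gmc) to move track states along with the camera. A small sketch of applying such an affine warp to xyxy boxes, using a made-up pure-translation matrix:

```python
import numpy as np


def warp_boxes(xyxy: np.ndarray, H: np.ndarray) -> np.ndarray:
    """Sketch: apply a 2x3 affine warp (as returned by GMC.apply) to (N, 4) xyxy boxes."""
    R, t = H[:2, :2], H[:2, 2]
    pts = xyxy.reshape(-1, 2)      # treat box corners as points
    warped = pts @ R.T + t         # rotate/scale, then translate
    return warped.reshape(-1, 4)


H = np.array([[1.0, 0.0, 5.0],     # pure translation: camera panned 5 px right, 2 px down
              [0.0, 1.0, 2.0]])
boxes = np.array([[10.0, 20.0, 50.0, 80.0]])
print(warp_boxes(boxes, H))        # [[15. 22. 55. 82.]]
```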
diff --git a/ultralytics/trackers/utils/kalman_filter.py b/ultralytics/trackers/utils/kalman_filter.py
deleted file mode 100644
index 82fd515..0000000
--- a/ultralytics/trackers/utils/kalman_filter.py
+++ /dev/null
@@ -1,493 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import numpy as np
-import scipy.linalg
-
-
-class KalmanFilterXYAH:
- """
- A KalmanFilterXYAH class for tracking bounding boxes in image space using a Kalman filter.
-
- Implements a simple Kalman filter for tracking bounding boxes in image space. The 8-dimensional state space
- (x, y, a, h, vx, vy, va, vh) contains the bounding box center position (x, y), aspect ratio a, height h, and their
- respective velocities. Object motion follows a constant velocity model, and bounding box location (x, y, a, h) is
- taken as a direct observation of the state space (linear observation model).
-
- Attributes:
- _motion_mat (np.ndarray): The motion matrix for the Kalman filter.
- _update_mat (np.ndarray): The update matrix for the Kalman filter.
- _std_weight_position (float): Standard deviation weight for position.
- _std_weight_velocity (float): Standard deviation weight for velocity.
-
- Methods:
- initiate: Create a track from an unassociated measurement.
- predict: Run the Kalman filter prediction step.
- project: Project the state distribution to measurement space.
- multi_predict: Run the Kalman filter prediction step (vectorized version).
- update: Run the Kalman filter correction step.
- gating_distance: Compute the gating distance between state distribution and measurements.
-
- Examples:
- Initialize the Kalman filter and create a track from a measurement
- >>> kf = KalmanFilterXYAH()
- >>> measurement = np.array([100, 200, 1.5, 50])
- >>> mean, covariance = kf.initiate(measurement)
- >>> print(mean)
- >>> print(covariance)
- """
-
- def __init__(self):
- """
- Initialize Kalman filter model matrices with motion and observation uncertainty weights.
-
- The Kalman filter is initialized with an 8-dimensional state space (x, y, a, h, vx, vy, va, vh), where (x, y)
- represents the bounding box center position, 'a' is the aspect ratio, 'h' is the height, and their respective
- velocities are (vx, vy, va, vh). The filter uses a constant velocity model for object motion and a linear
- observation model for bounding box location.
-
- Examples:
- Initialize a Kalman filter for tracking:
- >>> kf = KalmanFilterXYAH()
- """
- ndim, dt = 4, 1.0
-
- # Create Kalman filter model matrices
- self._motion_mat = np.eye(2 * ndim, 2 * ndim)
- for i in range(ndim):
- self._motion_mat[i, ndim + i] = dt
- self._update_mat = np.eye(ndim, 2 * ndim)
-
- # Motion and observation uncertainty are chosen relative to the current state estimate
- self._std_weight_position = 1.0 / 20
- self._std_weight_velocity = 1.0 / 160
-
- def initiate(self, measurement: np.ndarray):
- """
- Create a track from an unassociated measurement.
-
- Args:
- measurement (np.ndarray): Bounding box coordinates (x, y, a, h) with center position (x, y), aspect ratio a,
- and height h.
-
- Returns:
- mean (np.ndarray): Mean vector (8-dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
- covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
-
- Examples:
- >>> kf = KalmanFilterXYAH()
- >>> measurement = np.array([100, 50, 1.5, 200])
- >>> mean, covariance = kf.initiate(measurement)
- """
- mean_pos = measurement
- mean_vel = np.zeros_like(mean_pos)
- mean = np.r_[mean_pos, mean_vel]
-
- std = [
- 2 * self._std_weight_position * measurement[3],
- 2 * self._std_weight_position * measurement[3],
- 1e-2,
- 2 * self._std_weight_position * measurement[3],
- 10 * self._std_weight_velocity * measurement[3],
- 10 * self._std_weight_velocity * measurement[3],
- 1e-5,
- 10 * self._std_weight_velocity * measurement[3],
- ]
- covariance = np.diag(np.square(std))
- return mean, covariance
-
- def predict(self, mean: np.ndarray, covariance: np.ndarray):
- """
- Run Kalman filter prediction step.
-
- Args:
- mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
- covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
-
- Returns:
- mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
- covariance (np.ndarray): Covariance matrix of the predicted state.
-
- Examples:
- >>> kf = KalmanFilterXYAH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
- """
- std_pos = [
- self._std_weight_position * mean[3],
- self._std_weight_position * mean[3],
- 1e-2,
- self._std_weight_position * mean[3],
- ]
- std_vel = [
- self._std_weight_velocity * mean[3],
- self._std_weight_velocity * mean[3],
- 1e-5,
- self._std_weight_velocity * mean[3],
- ]
- motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
-
- mean = np.dot(mean, self._motion_mat.T)
- covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
-
- return mean, covariance
-
- def project(self, mean: np.ndarray, covariance: np.ndarray):
- """
- Project state distribution to measurement space.
-
- Args:
- mean (np.ndarray): The state's mean vector (8 dimensional array).
- covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
-
- Returns:
- mean (np.ndarray): Projected mean of the given state estimate.
- covariance (np.ndarray): Projected covariance matrix of the given state estimate.
-
- Examples:
- >>> kf = KalmanFilterXYAH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> projected_mean, projected_covariance = kf.project(mean, covariance)
- """
- std = [
- self._std_weight_position * mean[3],
- self._std_weight_position * mean[3],
- 1e-1,
- self._std_weight_position * mean[3],
- ]
- innovation_cov = np.diag(np.square(std))
-
- mean = np.dot(self._update_mat, mean)
- covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
- return mean, covariance + innovation_cov
-
- def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
- """
- Run Kalman filter prediction step for multiple object states (Vectorized version).
-
- Args:
- mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
- covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
-
- Returns:
- mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
- covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
-
- Examples:
- >>> mean = np.random.rand(10, 8) # 10 object states
- >>> covariance = np.random.rand(10, 8, 8) # Covariance matrices for 10 object states
- >>> predicted_mean, predicted_covariance = kalman_filter.multi_predict(mean, covariance)
- """
- std_pos = [
- self._std_weight_position * mean[:, 3],
- self._std_weight_position * mean[:, 3],
- 1e-2 * np.ones_like(mean[:, 3]),
- self._std_weight_position * mean[:, 3],
- ]
- std_vel = [
- self._std_weight_velocity * mean[:, 3],
- self._std_weight_velocity * mean[:, 3],
- 1e-5 * np.ones_like(mean[:, 3]),
- self._std_weight_velocity * mean[:, 3],
- ]
- sqr = np.square(np.r_[std_pos, std_vel]).T
-
- motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
- motion_cov = np.asarray(motion_cov)
-
- mean = np.dot(mean, self._motion_mat.T)
- left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
- covariance = np.dot(left, self._motion_mat.T) + motion_cov
-
- return mean, covariance
-
- def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
- """
- Run Kalman filter correction step.
-
- Args:
- mean (np.ndarray): The predicted state's mean vector (8 dimensional).
- covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
- measurement (np.ndarray): The 4 dimensional measurement vector (x, y, a, h), where (x, y) is the center
- position, a the aspect ratio, and h the height of the bounding box.
-
- Returns:
- new_mean (np.ndarray): Measurement-corrected state mean.
- new_covariance (np.ndarray): Measurement-corrected state covariance.
-
- Examples:
- >>> kf = KalmanFilterXYAH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> measurement = np.array([1, 1, 1, 1])
- >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
- """
- projected_mean, projected_cov = self.project(mean, covariance)
-
- chol_factor, lower = scipy.linalg.cho_factor(projected_cov, lower=True, check_finite=False)
- kalman_gain = scipy.linalg.cho_solve(
- (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, check_finite=False
- ).T
- innovation = measurement - projected_mean
-
- new_mean = mean + np.dot(innovation, kalman_gain.T)
- new_covariance = covariance - np.linalg.multi_dot((kalman_gain, projected_cov, kalman_gain.T))
- return new_mean, new_covariance
-
- def gating_distance(
- self,
- mean: np.ndarray,
- covariance: np.ndarray,
- measurements: np.ndarray,
- only_position: bool = False,
- metric: str = "maha",
- ) -> np.ndarray:
- """
- Compute gating distance between state distribution and measurements.
-
- A suitable distance threshold can be obtained from `chi2inv95`. If `only_position` is False, the chi-square
- distribution has 4 degrees of freedom, otherwise 2.
-
- Args:
- mean (np.ndarray): Mean vector over the state distribution (8 dimensional).
- covariance (np.ndarray): Covariance of the state distribution (8x8 dimensional).
- measurements (np.ndarray): An (N, 4) matrix of N measurements, each in format (x, y, a, h) where (x, y) is the
- bounding box center position, a the aspect ratio, and h the height.
- only_position (bool, optional): If True, distance computation is done with respect to box center position only.
- metric (str, optional): The metric to use for calculating the distance. Options are 'gaussian' for the squared
- Euclidean distance and 'maha' for the squared Mahalanobis distance.
-
- Returns:
- (np.ndarray): Returns an array of length N, where the i-th element contains the squared distance between
- (mean, covariance) and `measurements[i]`.
-
- Examples:
- Compute gating distance using Mahalanobis metric:
- >>> kf = KalmanFilterXYAH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> measurements = np.array([[1, 1, 1, 1], [2, 2, 1, 1]])
- >>> distances = kf.gating_distance(mean, covariance, measurements, only_position=False, metric="maha")
- """
- mean, covariance = self.project(mean, covariance)
- if only_position:
- mean, covariance = mean[:2], covariance[:2, :2]
- measurements = measurements[:, :2]
-
- d = measurements - mean
- if metric == "gaussian":
- return np.sum(d * d, axis=1)
- elif metric == "maha":
- cholesky_factor = np.linalg.cholesky(covariance)
- z = scipy.linalg.solve_triangular(cholesky_factor, d.T, lower=True, check_finite=False, overwrite_b=True)
- return np.sum(z * z, axis=0) # square maha
- else:
- raise ValueError("Invalid distance metric")
-
-
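As the gating_distance docstring notes, the squared Mahalanobis distances it returns are compared against a chi-square threshold (chi2inv95-style, 4 degrees of freedom when position, aspect, and height are all used). A minimal gating sketch, assuming the kalman_filter module from this diff is importable; the measurement values are arbitrary:

```python
import numpy as np
from scipy.stats import chi2

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

GATING_THRESHOLD = chi2.ppf(0.95, df=4)  # ~9.4877, the 95% gate for a 4D measurement

kf = KalmanFilterXYAH()
mean, cov = kf.initiate(np.array([100.0, 50.0, 0.5, 80.0]))
measurements = np.array([
    [102.0, 51.0, 0.5, 80.0],   # close to the track state
    [400.0, 300.0, 0.5, 80.0],  # far away
])
d2 = kf.gating_distance(mean, cov, measurements, metric="maha")
print(d2 < GATING_THRESHOLD)  # [ True False]
```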
-class KalmanFilterXYWH(KalmanFilterXYAH):
- """
- A KalmanFilterXYWH class for tracking bounding boxes in image space using a Kalman filter.
-
- Implements a Kalman filter for tracking bounding boxes with state space (x, y, w, h, vx, vy, vw, vh), where
- (x, y) is the center position, w is the width, h is the height, and vx, vy, vw, vh are their respective velocities.
- The object motion follows a constant velocity model, and the bounding box location (x, y, w, h) is taken as a direct
- observation of the state space (linear observation model).
-
- Attributes:
- _motion_mat (np.ndarray): The motion matrix for the Kalman filter.
- _update_mat (np.ndarray): The update matrix for the Kalman filter.
- _std_weight_position (float): Standard deviation weight for position.
- _std_weight_velocity (float): Standard deviation weight for velocity.
-
- Methods:
- initiate: Create a track from an unassociated measurement.
- predict: Run the Kalman filter prediction step.
- project: Project the state distribution to measurement space.
- multi_predict: Run the Kalman filter prediction step in a vectorized manner.
- update: Run the Kalman filter correction step.
-
- Examples:
- Create a Kalman filter and initialize a track
- >>> kf = KalmanFilterXYWH()
- >>> measurement = np.array([100, 50, 20, 40])
- >>> mean, covariance = kf.initiate(measurement)
- >>> print(mean)
- >>> print(covariance)
- """
-
- def initiate(self, measurement: np.ndarray):
- """
- Create track from unassociated measurement.
-
- Args:
- measurement (np.ndarray): Bounding box coordinates (x, y, w, h) with center position (x, y), width, and height.
-
- Returns:
- mean (np.ndarray): Mean vector (8 dimensional) of the new track. Unobserved velocities are initialized to 0 mean.
- covariance (np.ndarray): Covariance matrix (8x8 dimensional) of the new track.
-
- Examples:
- >>> kf = KalmanFilterXYWH()
- >>> measurement = np.array([100.0, 50.0, 20.0, 40.0])
- >>> mean, covariance = kf.initiate(measurement)
- >>> print(mean)
- [100. 50. 20. 40. 0. 0. 0. 0.]
- >>> print(covariance.shape)
- (8, 8)
- """
- mean_pos = measurement
- mean_vel = np.zeros_like(mean_pos)
- mean = np.r_[mean_pos, mean_vel]
-
- std = [
- 2 * self._std_weight_position * measurement[2],
- 2 * self._std_weight_position * measurement[3],
- 2 * self._std_weight_position * measurement[2],
- 2 * self._std_weight_position * measurement[3],
- 10 * self._std_weight_velocity * measurement[2],
- 10 * self._std_weight_velocity * measurement[3],
- 10 * self._std_weight_velocity * measurement[2],
- 10 * self._std_weight_velocity * measurement[3],
- ]
- covariance = np.diag(np.square(std))
- return mean, covariance
-
- def predict(self, mean: np.ndarray, covariance: np.ndarray):
- """
- Run Kalman filter prediction step.
-
- Args:
- mean (np.ndarray): The 8-dimensional mean vector of the object state at the previous time step.
- covariance (np.ndarray): The 8x8-dimensional covariance matrix of the object state at the previous time step.
-
- Returns:
- mean (np.ndarray): Mean vector of the predicted state. Unobserved velocities are initialized to 0 mean.
- covariance (np.ndarray): Covariance matrix of the predicted state.
-
- Examples:
- >>> kf = KalmanFilterXYWH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> predicted_mean, predicted_covariance = kf.predict(mean, covariance)
- """
- std_pos = [
- self._std_weight_position * mean[2],
- self._std_weight_position * mean[3],
- self._std_weight_position * mean[2],
- self._std_weight_position * mean[3],
- ]
- std_vel = [
- self._std_weight_velocity * mean[2],
- self._std_weight_velocity * mean[3],
- self._std_weight_velocity * mean[2],
- self._std_weight_velocity * mean[3],
- ]
- motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
-
- mean = np.dot(mean, self._motion_mat.T)
- covariance = np.linalg.multi_dot((self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
-
- return mean, covariance
-
- def project(self, mean: np.ndarray, covariance: np.ndarray):
- """
- Project state distribution to measurement space.
-
- Args:
- mean (np.ndarray): The state's mean vector (8 dimensional array).
- covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
-
- Returns:
- mean (np.ndarray): Projected mean of the given state estimate.
- covariance (np.ndarray): Projected covariance matrix of the given state estimate.
-
- Examples:
- >>> kf = KalmanFilterXYWH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> projected_mean, projected_cov = kf.project(mean, covariance)
- """
- std = [
- self._std_weight_position * mean[2],
- self._std_weight_position * mean[3],
- self._std_weight_position * mean[2],
- self._std_weight_position * mean[3],
- ]
- innovation_cov = np.diag(np.square(std))
-
- mean = np.dot(self._update_mat, mean)
- covariance = np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))
- return mean, covariance + innovation_cov
-
- def multi_predict(self, mean: np.ndarray, covariance: np.ndarray):
- """
- Run Kalman filter prediction step (Vectorized version).
-
- Args:
- mean (np.ndarray): The Nx8 dimensional mean matrix of the object states at the previous time step.
- covariance (np.ndarray): The Nx8x8 covariance matrix of the object states at the previous time step.
-
- Returns:
- mean (np.ndarray): Mean matrix of the predicted states with shape (N, 8).
- covariance (np.ndarray): Covariance matrix of the predicted states with shape (N, 8, 8).
-
- Examples:
- >>> mean = np.random.rand(5, 8) # 5 objects with 8-dimensional state vectors
- >>> covariance = np.random.rand(5, 8, 8) # 5 objects with 8x8 covariance matrices
- >>> kf = KalmanFilterXYWH()
- >>> predicted_mean, predicted_covariance = kf.multi_predict(mean, covariance)
- """
- std_pos = [
- self._std_weight_position * mean[:, 2],
- self._std_weight_position * mean[:, 3],
- self._std_weight_position * mean[:, 2],
- self._std_weight_position * mean[:, 3],
- ]
- std_vel = [
- self._std_weight_velocity * mean[:, 2],
- self._std_weight_velocity * mean[:, 3],
- self._std_weight_velocity * mean[:, 2],
- self._std_weight_velocity * mean[:, 3],
- ]
- sqr = np.square(np.r_[std_pos, std_vel]).T
-
- motion_cov = [np.diag(sqr[i]) for i in range(len(mean))]
- motion_cov = np.asarray(motion_cov)
-
- mean = np.dot(mean, self._motion_mat.T)
- left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
- covariance = np.dot(left, self._motion_mat.T) + motion_cov
-
- return mean, covariance
-
- def update(self, mean: np.ndarray, covariance: np.ndarray, measurement: np.ndarray):
- """
- Run Kalman filter correction step.
-
- Args:
- mean (np.ndarray): The predicted state's mean vector (8 dimensional).
- covariance (np.ndarray): The state's covariance matrix (8x8 dimensional).
- measurement (np.ndarray): The 4 dimensional measurement vector (x, y, w, h), where (x, y) is the center
- position, w the width, and h the height of the bounding box.
-
- Returns:
- new_mean (np.ndarray): Measurement-corrected state mean.
- new_covariance (np.ndarray): Measurement-corrected state covariance.
-
- Examples:
- >>> kf = KalmanFilterXYWH()
- >>> mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
- >>> covariance = np.eye(8)
- >>> measurement = np.array([0.5, 0.5, 1.2, 1.2])
- >>> new_mean, new_covariance = kf.update(mean, covariance, measurement)
- """
- return super().update(mean, covariance, measurement)
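A compact predict/correct cycle ties the two filter classes above together. The sketch below uses the XYWH variant; the numbers are arbitrary and the import path assumes the module from this diff is available:

```python
import numpy as np

from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYWH

kf = KalmanFilterXYWH()
mean, cov = kf.initiate(np.array([100.0, 50.0, 20.0, 40.0]))  # (x, y, w, h)

# One constant-velocity prediction, then a correction with a slightly shifted observation
mean, cov = kf.predict(mean, cov)
mean, cov = kf.update(mean, cov, np.array([104.0, 52.0, 20.0, 40.0]))

print(mean[:4])   # corrected (x, y, w, h), pulled toward the new measurement
print(mean[4:6])  # inferred (vx, vy) velocities
```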
diff --git a/ultralytics/trackers/utils/matching.py b/ultralytics/trackers/utils/matching.py
deleted file mode 100644
index e85a78c..0000000
--- a/ultralytics/trackers/utils/matching.py
+++ /dev/null
@@ -1,157 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import numpy as np
-import scipy
-from scipy.spatial.distance import cdist
-
-from ultralytics.utils.metrics import batch_probiou, bbox_ioa
-
-try:
- import lap # for linear_assignment
-
- assert lap.__version__ # verify package is not a directory
-except (ImportError, AssertionError, AttributeError):
- from ultralytics.utils.checks import check_requirements
-
- check_requirements("lap>=0.5.12") # https://github.com/gatagat/lap
- import lap
-
-
-def linear_assignment(cost_matrix: np.ndarray, thresh: float, use_lap: bool = True):
- """
- Perform linear assignment using either the scipy or lap.lapjv method.
-
- Args:
- cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
- thresh (float): Threshold for considering an assignment valid.
- use_lap (bool): Use lap.lapjv for the assignment. If False, scipy.optimize.linear_sum_assignment is used.
-
- Returns:
- matched_indices (np.ndarray): Array of matched indices of shape (K, 2), where K is the number of matches.
- unmatched_a (np.ndarray): Array of unmatched indices from the first set, with shape (L,).
- unmatched_b (np.ndarray): Array of unmatched indices from the second set, with shape (M,).
-
- Examples:
- >>> cost_matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
- >>> thresh = 5.0
- >>> matched_indices, unmatched_a, unmatched_b = linear_assignment(cost_matrix, thresh, use_lap=True)
- """
- if cost_matrix.size == 0:
- return np.empty((0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(range(cost_matrix.shape[1]))
-
- if use_lap:
- # Use lap.lapjv
- # https://github.com/gatagat/lap
- _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
- matches = [[ix, mx] for ix, mx in enumerate(x) if mx >= 0]
- unmatched_a = np.where(x < 0)[0]
- unmatched_b = np.where(y < 0)[0]
- else:
- # Use scipy.optimize.linear_sum_assignment
- # https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
- x, y = scipy.optimize.linear_sum_assignment(cost_matrix) # row x, col y
- matches = np.asarray([[x[i], y[i]] for i in range(len(x)) if cost_matrix[x[i], y[i]] <= thresh])
- if len(matches) == 0:
- unmatched_a = list(np.arange(cost_matrix.shape[0]))
- unmatched_b = list(np.arange(cost_matrix.shape[1]))
- else:
- unmatched_a = list(frozenset(np.arange(cost_matrix.shape[0])) - frozenset(matches[:, 0]))
- unmatched_b = list(frozenset(np.arange(cost_matrix.shape[1])) - frozenset(matches[:, 1]))
-
- return matches, unmatched_a, unmatched_b
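The SciPy fallback branch above can be exercised on a toy cost matrix without installing lap; a minimal sketch:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[1.0, 9.0], [9.0, 2.0], [9.0, 9.0]])  # 3 tracks x 2 detections
thresh = 5.0

rows, cols = linear_sum_assignment(cost)                # optimal row/col pairing
matches = np.asarray([[r, c] for r, c in zip(rows, cols) if cost[r, c] <= thresh])
unmatched_a = sorted(set(range(cost.shape[0])) - set(matches[:, 0]))
unmatched_b = sorted(set(range(cost.shape[1])) - set(matches[:, 1]))
print(matches.tolist(), unmatched_a, unmatched_b)       # [[0, 0], [1, 1]] [2] []
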
-
-
-def iou_distance(atracks: list, btracks: list) -> np.ndarray:
- """
- Compute cost based on Intersection over Union (IoU) between tracks.
-
- Args:
- atracks (list[STrack] | list[np.ndarray]): List of tracks 'a' or bounding boxes.
- btracks (list[STrack] | list[np.ndarray]): List of tracks 'b' or bounding boxes.
-
- Returns:
- (np.ndarray): Cost matrix computed based on IoU with shape (len(atracks), len(btracks)).
-
- Examples:
- Compute IoU distance between two sets of tracks
- >>> atracks = [np.array([0, 0, 10, 10]), np.array([20, 20, 30, 30])]
- >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
- >>> cost_matrix = iou_distance(atracks, btracks)
- """
- if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
- atlbrs = atracks
- btlbrs = btracks
- else:
- atlbrs = [track.xywha if track.angle is not None else track.xyxy for track in atracks]
- btlbrs = [track.xywha if track.angle is not None else track.xyxy for track in btracks]
-
- ious = np.zeros((len(atlbrs), len(btlbrs)), dtype=np.float32)
- if len(atlbrs) and len(btlbrs):
- if len(atlbrs[0]) == 5 and len(btlbrs[0]) == 5:
- ious = batch_probiou(
- np.ascontiguousarray(atlbrs, dtype=np.float32),
- np.ascontiguousarray(btlbrs, dtype=np.float32),
- ).numpy()
- else:
- ious = bbox_ioa(
- np.ascontiguousarray(atlbrs, dtype=np.float32),
- np.ascontiguousarray(btlbrs, dtype=np.float32),
- iou=True,
- )
- return 1 - ious # cost matrix
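For plain axis-aligned xyxy boxes the cost construction reduces to pairwise IoU. A standalone NumPy sketch (not calling the deleted bbox_ioa helper, and skipping the rotated-box batch_probiou branch):

import numpy as np

def iou_cost(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Return a (len(a), len(b)) cost matrix of 1 - IoU for xyxy boxes."""
    tl = np.maximum(a[:, None, :2], b[None, :, :2])     # pairwise top-left corners
    br = np.minimum(a[:, None, 2:], b[None, :, 2:])     # pairwise bottom-right corners
    inter = np.prod(np.clip(br - tl, 0, None), axis=2)  # intersection areas
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    union = area_a[:, None] + area_b[None, :] - inter
    return 1.0 - inter / np.clip(union, 1e-9, None)

atracks = np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=np.float32)
btracks = np.array([[5, 5, 15, 15], [25, 25, 35, 35]], dtype=np.float32)
print(iou_cost(atracks, btracks).round(3))  # ~0.857 on the diagonal (overlapping pairs), 1.0 off it
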
-
-
-def embedding_distance(tracks: list, detections: list, metric: str = "cosine") -> np.ndarray:
- """
- Compute distance between tracks and detections based on embeddings.
-
- Args:
- tracks (list[STrack]): List of tracks, where each track contains embedding features.
- detections (list[BaseTrack]): List of detections, where each detection contains embedding features.
- metric (str): Metric for distance computation. Supported metrics include 'cosine', 'euclidean', etc.
-
- Returns:
- (np.ndarray): Cost matrix computed based on embeddings with shape (N, M), where N is the number of tracks
- and M is the number of detections.
-
- Examples:
- Compute the embedding distance between tracks and detections using cosine metric
- >>> tracks = [STrack(...), STrack(...)] # List of track objects with embedding features
- >>> detections = [BaseTrack(...), BaseTrack(...)] # List of detection objects with embedding features
- >>> cost_matrix = embedding_distance(tracks, detections, metric="cosine")
- """
- cost_matrix = np.zeros((len(tracks), len(detections)), dtype=np.float32)
- if cost_matrix.size == 0:
- return cost_matrix
- det_features = np.asarray([track.curr_feat for track in detections], dtype=np.float32)
- # for i, track in enumerate(tracks):
- # cost_matrix[i, :] = np.maximum(0.0, cdist(track.smooth_feat.reshape(1,-1), det_features, metric))
- track_features = np.asarray([track.smooth_feat for track in tracks], dtype=np.float32)
- cost_matrix = np.maximum(0.0, cdist(track_features, det_features, metric)) # Normalized features
- return cost_matrix
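The same cosine-distance cost can be reproduced with cdist alone; random feature matrices stand in for the STrack/BaseTrack embeddings here:

import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
track_features = rng.normal(size=(3, 128)).astype(np.float32)  # 3 tracks, 128-D embeddings
det_features = rng.normal(size=(5, 128)).astype(np.float32)    # 5 detections

cost_matrix = np.maximum(0.0, cdist(track_features, det_features, "cosine"))
print(cost_matrix.shape)  # (3, 5), lower values mean more similar embeddings
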
-
-
-def fuse_score(cost_matrix: np.ndarray, detections: list) -> np.ndarray:
- """
- Fuse cost matrix with detection scores to produce a single similarity matrix.
-
- Args:
- cost_matrix (np.ndarray): The matrix containing cost values for assignments, with shape (N, M).
- detections (list[BaseTrack]): List of detections, each containing a score attribute.
-
- Returns:
- (np.ndarray): Fused similarity matrix with shape (N, M).
-
- Examples:
- Fuse a cost matrix with detection scores
- >>> cost_matrix = np.random.rand(5, 10) # 5 tracks and 10 detections
- >>> detections = [BaseTrack(score=np.random.rand()) for _ in range(10)]
- >>> fused_matrix = fuse_score(cost_matrix, detections)
- """
- if cost_matrix.size == 0:
- return cost_matrix
- iou_sim = 1 - cost_matrix
- det_scores = np.array([det.score for det in detections])
- det_scores = np.expand_dims(det_scores, axis=0).repeat(cost_matrix.shape[0], axis=0)
- fuse_sim = iou_sim * det_scores
- return 1 - fuse_sim # fuse_cost
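And the score fusion reduced to bare arrays, with plain floats standing in for detection objects:

import numpy as np

cost_matrix = np.random.rand(5, 10)        # 5 tracks x 10 detections (IoU cost)
det_scores = np.random.rand(10)            # per-detection confidence scores

iou_sim = 1 - cost_matrix
fuse_sim = iou_sim * det_scores[None, :]   # broadcast scores across the track axis
fused_cost = 1 - fuse_sim
print(fused_cost.shape)                    # (5, 10)
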
diff --git a/ultralytics/utils/__init__.py b/ultralytics/utils/__init__.py
deleted file mode 100644
index f97d5b9..0000000
--- a/ultralytics/utils/__init__.py
+++ /dev/null
@@ -1,1450 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import contextlib
-import importlib.metadata
-import inspect
-import json
-import logging
-import os
-import platform
-import re
-import socket
-import sys
-import threading
-import time
-from functools import lru_cache
-from pathlib import Path
-from threading import Lock
-from types import SimpleNamespace
-from urllib.parse import unquote
-
-import cv2
-import numpy as np
-import torch
-
-from ultralytics import __version__
-from ultralytics.utils.git import GitRepo
-from ultralytics.utils.patches import imread, imshow, imwrite, torch_save # for patches
-from ultralytics.utils.tqdm import TQDM # noqa
-
-# PyTorch Multi-GPU DDP Constants
-RANK = int(os.getenv("RANK", -1))
-LOCAL_RANK = int(os.getenv("LOCAL_RANK", -1)) # https://pytorch.org/docs/stable/elastic/run.html
-
-# Other Constants
-ARGV = sys.argv or ["", ""] # sometimes sys.argv = []
-FILE = Path(__file__).resolve()
-ROOT = FILE.parents[1] # YOLO
-ASSETS = ROOT / "assets" # default images
-ASSETS_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0" # assets GitHub URL
-DEFAULT_CFG_PATH = ROOT / "cfg/default.yaml"
-NUM_THREADS = min(8, max(1, os.cpu_count() - 1)) # number of YOLO multiprocessing threads
-AUTOINSTALL = str(os.getenv("YOLO_AUTOINSTALL", True)).lower() == "true" # global auto-install mode
-VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbose mode
-LOGGING_NAME = "ultralytics"
-MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
-MACOS_VERSION = platform.mac_ver()[0] if MACOS else None
-NOT_MACOS14 = not (MACOS and MACOS_VERSION.startswith("14."))
-ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
-PYTHON_VERSION = platform.python_version()
-TORCH_VERSION = str(torch.__version__) # Normalize torch.__version__ (PyTorch>1.9 returns TorchVersion objects)
-TORCHVISION_VERSION = importlib.metadata.version("torchvision") # faster than importing torchvision
-IS_VSCODE = os.environ.get("TERM_PROGRAM", False) == "vscode"
-RKNN_CHIPS = frozenset(
- {
- "rk3588",
- "rk3576",
- "rk3566",
- "rk3568",
- "rk3562",
- "rv1103",
- "rv1106",
- "rv1103b",
- "rv1106b",
- "rk2118",
- }
-) # Rockchip processors available for export
-HELP_MSG = """
- Examples for running Ultralytics:
-
- 1. Install the ultralytics package:
-
- pip install ultralytics
-
- 2. Use the Python SDK:
-
- from ultralytics import YOLO
-
- # Load a model
- model = YOLO("yolo11n.yaml") # build a new model from scratch
- model = YOLO("yolo11n.pt") # load a pretrained model (recommended for training)
-
- # Use the model
- results = model.train(data="coco8.yaml", epochs=3) # train the model
- results = model.val() # evaluate model performance on the validation set
- results = model("https://ultralytics.com/images/bus.jpg") # predict on an image
- success = model.export(format="onnx") # export the model to ONNX format
-
- 3. Use the command line interface (CLI):
-
- Ultralytics 'yolo' CLI commands use the following syntax:
-
- yolo TASK MODE ARGS
-
- Where TASK (optional) is one of [detect, segment, classify, pose, obb]
- MODE (required) is one of [train, val, predict, export, track, benchmark]
- ARGS (optional) are any number of custom "arg=value" pairs like "imgsz=320" that override defaults.
- See all ARGS at https://docs.ultralytics.com/usage/cfg or with "yolo cfg"
-
- - Train a detection model for 10 epochs with an initial learning_rate of 0.01
- yolo detect train data=coco8.yaml model=yolo11n.pt epochs=10 lr0=0.01
-
- - Predict a YouTube video using a pretrained segmentation model at image size 320:
- yolo segment predict model=yolo11n-seg.pt source='https://youtu.be/LNwODJXcvt4' imgsz=320
-
- - Val a pretrained detection model at batch-size 1 and image size 640:
- yolo detect val model=yolo11n.pt data=coco8.yaml batch=1 imgsz=640
-
- - Export a YOLO11n classification model to ONNX format at image size 224 by 128 (no TASK required)
- yolo export model=yolo11n-cls.pt format=onnx imgsz=224,128
-
- - Run special commands:
- yolo help
- yolo checks
- yolo version
- yolo settings
- yolo copy-cfg
- yolo cfg
-
- Docs: https://docs.ultralytics.com
- Community: https://community.ultralytics.com
- GitHub: https://github.com/ultralytics/ultralytics
- """
-
-# Settings and Environment Variables
-torch.set_printoptions(linewidth=320, precision=4, profile="default")
-np.set_printoptions(linewidth=320, formatter=dict(float_kind="{:11.5g}".format)) # format short g, %precision=5
-cv2.setNumThreads(0) # prevent OpenCV from multithreading (incompatible with PyTorch DataLoader)
-os.environ["NUMEXPR_MAX_THREADS"] = str(NUM_THREADS) # NumExpr max threads
-os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # suppress verbose TF compiler warnings in Colab
-os.environ["TORCH_CPP_LOG_LEVEL"] = "ERROR" # suppress "NNPACK.cpp could not initialize NNPACK" warnings
-os.environ["KINETO_LOG_LEVEL"] = "5" # suppress verbose PyTorch profiler output when computing FLOPs
-
-# Precompiled type tuples for faster isinstance() checks
-FLOAT_OR_INT = (float, int)
-STR_OR_PATH = (str, Path)
-
-
-class DataExportMixin:
- """
- Mixin class for exporting validation metrics or prediction results in various formats.
-
- This class provides utilities to export performance metrics (e.g., mAP, precision, recall) or prediction results
- from classification, object detection, segmentation, or pose estimation tasks into various formats: Polars
- DataFrame, CSV and JSON.
-
- Methods:
- to_df: Convert summary to a Polars DataFrame.
- to_csv: Export results as a CSV string.
- to_json: Export results as a JSON string.
- tojson: Deprecated alias for `to_json()`.
-
- Examples:
- >>> model = YOLO("yolo11n.pt")
- >>> results = model("image.jpg")
- >>> df = results.to_df()
- >>> print(df)
- >>> csv_data = results.to_csv()
- """
-
- def to_df(self, normalize=False, decimals=5):
- """
- Create a polars DataFrame from the prediction results summary or validation metrics.
-
- Args:
- normalize (bool, optional): Normalize numerical values for easier comparison.
- decimals (int, optional): Decimal places to round floats.
-
- Returns:
- (DataFrame): DataFrame containing the summary data.
- """
- import polars as pl # scope for faster 'import ultralytics'
-
- return pl.DataFrame(self.summary(normalize=normalize, decimals=decimals))
-
- def to_csv(self, normalize=False, decimals=5):
- """
- Export results or metrics to CSV string format.
-
- Args:
- normalize (bool, optional): Normalize numeric values.
- decimals (int, optional): Decimal precision.
-
- Returns:
- (str): CSV content as string.
- """
- import polars as pl
-
- df = self.to_df(normalize=normalize, decimals=decimals)
-
- try:
- return df.write_csv()
- except Exception:
- # Minimal string conversion for any remaining complex types
- def _to_str_simple(v):
- if v is None:
- return ""
- elif isinstance(v, (dict, list, tuple, set)):
- return repr(v)
- else:
- return str(v)
-
- df_str = df.select(
- [pl.col(c).map_elements(_to_str_simple, return_dtype=pl.String).alias(c) for c in df.columns]
- )
- return df_str.write_csv()
-
- def to_json(self, normalize=False, decimals=5):
- """
- Export results to JSON format.
-
- Args:
- normalize (bool, optional): Normalize numeric values.
- decimals (int, optional): Decimal precision.
-
- Returns:
- (str): JSON-formatted string of the results.
- """
- return self.to_df(normalize=normalize, decimals=decimals).write_json()
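A sketch of how this mixin's to_df/to_csv path behaves, using a stand-in object whose summary() payload is invented purely for illustration (Polars assumed installed):

import polars as pl

class DummyResults:
    """Minimal stand-in exposing a summary() like real Results/metrics objects."""

    def summary(self, normalize=False, decimals=5):
        return [{"name": "person", "confidence": round(0.91234, decimals)},
                {"name": "car", "confidence": round(0.85678, decimals)}]

    def to_df(self, normalize=False, decimals=5):
        return pl.DataFrame(self.summary(normalize=normalize, decimals=decimals))

res = DummyResults()
print(res.to_df())              # two-row Polars DataFrame
print(res.to_df().write_csv())  # CSV string, matching the mixin's to_csv() fast path
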
-
-
-class SimpleClass:
- """
- A simple base class for creating objects with string representations of their attributes.
-
- This class provides a foundation for creating objects that can be easily printed or represented as strings,
- showing all their non-callable attributes. It's useful for debugging and introspection of object states.
-
- Methods:
- __str__: Return a human-readable string representation of the object.
- __repr__: Return a machine-readable string representation of the object.
- __getattr__: Provide a custom attribute access error message with helpful information.
-
- Examples:
- >>> class MyClass(SimpleClass):
- ... def __init__(self):
- ... self.x = 10
- ... self.y = "hello"
- >>> obj = MyClass()
- >>> print(obj)
- __main__.MyClass object with attributes:
-
- x: 10
- y: 'hello'
-
- Notes:
- - This class is designed to be subclassed. It provides a convenient way to inspect object attributes.
- - The string representation includes the module and class name of the object.
- - Callable attributes and attributes starting with an underscore are excluded from the string representation.
- """
-
- def __str__(self):
- """Return a human-readable string representation of the object."""
- attr = []
- for a in dir(self):
- v = getattr(self, a)
- if not callable(v) and not a.startswith("_"):
- if isinstance(v, SimpleClass):
- # Display only the module and class name for subclasses
- s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
- else:
- s = f"{a}: {repr(v)}"
- attr.append(s)
- return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)
-
- def __repr__(self):
- """Return a machine-readable string representation of the object."""
- return self.__str__()
-
- def __getattr__(self, attr):
- """Provide a custom attribute access error message with helpful information."""
- name = self.__class__.__name__
- raise AttributeError(f"'{name}' object has no attribute '{attr}'. See valid attributes below.\n{self.__doc__}")
-
-
-class IterableSimpleNamespace(SimpleNamespace):
- """
- An iterable SimpleNamespace class that provides enhanced functionality for attribute access and iteration.
-
- This class extends the SimpleNamespace class with additional methods for iteration, string representation,
- and attribute access. It is designed to be used as a convenient container for storing and accessing
- configuration parameters.
-
- Methods:
- __iter__: Return an iterator of key-value pairs from the namespace's attributes.
- __str__: Return a human-readable string representation of the object.
- __getattr__: Provide a custom attribute access error message with helpful information.
- get: Retrieve the value of a specified key, or a default value if the key doesn't exist.
-
- Examples:
- >>> cfg = IterableSimpleNamespace(a=1, b=2, c=3)
- >>> for k, v in cfg:
- ... print(f"{k}: {v}")
- a: 1
- b: 2
- c: 3
- >>> print(cfg)
- a=1
- b=2
- c=3
- >>> cfg.get("b")
- 2
- >>> cfg.get("d", "default")
- 'default'
-
- Notes:
- This class is particularly useful for storing configuration parameters in a more accessible
- and iterable format compared to a standard dictionary.
- """
-
- def __iter__(self):
- """Return an iterator of key-value pairs from the namespace's attributes."""
- return iter(vars(self).items())
-
- def __str__(self):
- """Return a human-readable string representation of the object."""
- return "\n".join(f"{k}={v}" for k, v in vars(self).items())
-
- def __getattr__(self, attr):
- """Provide a custom attribute access error message with helpful information."""
- name = self.__class__.__name__
- raise AttributeError(
- f"""
- '{name}' object has no attribute '{attr}'. This may be caused by a modified or out of date ultralytics
- 'default.yaml' file.\nPlease update your code with 'pip install -U ultralytics' and if necessary replace
- {DEFAULT_CFG_PATH} with the latest version from
- https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/default.yaml
- """
- )
-
- def get(self, key, default=None):
- """Return the value of the specified key if it exists; otherwise, return the default value."""
- return getattr(self, key, default)
-
-
-def plt_settings(rcparams=None, backend="Agg"):
- """
- Decorator to temporarily set rc parameters and the backend for a plotting function.
-
- Args:
- rcparams (dict, optional): Dictionary of rc parameters to set.
- backend (str, optional): Name of the backend to use.
-
- Returns:
- (Callable): Decorated function with temporarily set rc parameters and backend.
-
- Examples:
- >>> @plt_settings({"font.size": 12})
- >>> def plot_function():
- ... plt.figure()
- ... plt.plot([1, 2, 3])
- ... plt.show()
-
- >>> @plt_settings({"font.size": 12}, backend="TkAgg")
- >>> def plot_function_tk():
- ... plt.figure()
- ... plt.plot([1, 2, 3])
- ... plt.show()
- """
- if rcparams is None:
- rcparams = {"font.size": 11}
-
- def decorator(func):
- """Decorator to apply temporary rc parameters and backend to a function."""
-
- def wrapper(*args, **kwargs):
- """Set rc parameters and backend, call the original function, and restore the settings."""
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
-
- original_backend = plt.get_backend()
- switch = backend.lower() != original_backend.lower()
- if switch:
- plt.close("all") # automatic closing of figures when switching backends is deprecated since matplotlib 3.8
- plt.switch_backend(backend)
-
- # Plot with backend and always revert to original backend
- try:
- with plt.rc_context(rcparams):
- result = func(*args, **kwargs)
- finally:
- if switch:
- plt.close("all")
- plt.switch_backend(original_backend)
- return result
-
- return wrapper
-
- return decorator
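Decorator usage in practice might look like the sketch below, which saves a figure so the default Agg backend suffices (the output file name is illustrative):

import matplotlib.pyplot as plt

@plt_settings({"font.size": 8}, backend="Agg")
def save_plot(path="demo.png"):
    plt.figure()
    plt.plot([1, 2, 3], [1, 4, 9])
    plt.savefig(path)
    plt.close()

save_plot()  # runs with font.size=8 on Agg, then restores the previous backend
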
-
-
-def set_logging(name="LOGGING_NAME", verbose=True):
- """
- Set up logging with UTF-8 encoding and configurable verbosity.
-
- This function configures logging for the Ultralytics library, setting the appropriate logging level and
- formatter based on the verbosity flag and the current process rank. It handles special cases for Windows
- environments where UTF-8 encoding might not be the default.
-
- Args:
- name (str): Name of the logger.
- verbose (bool): Flag to set logging level to INFO if True, ERROR otherwise.
-
- Returns:
- (logging.Logger): Configured logger object.
-
- Examples:
- >>> set_logging(name="ultralytics", verbose=True)
- >>> logger = logging.getLogger("ultralytics")
- >>> logger.info("This is an info message")
-
- Notes:
- - On Windows, this function attempts to reconfigure stdout to use UTF-8 encoding if possible.
- - If reconfiguration is not possible, it falls back to a custom formatter that handles non-UTF-8 environments.
- - The function sets up a StreamHandler with the appropriate formatter and level.
- - The logger's propagate flag is set to False to prevent duplicate logging in parent loggers.
- """
- level = logging.INFO if verbose and RANK in {-1, 0} else logging.ERROR # rank in world for Multi-GPU trainings
-
- class PrefixFormatter(logging.Formatter):
- def format(self, record):
- """Format log records with prefixes based on level."""
- # Apply prefixes based on log level
- if record.levelno == logging.WARNING:
- prefix = "WARNING" if WINDOWS else "WARNING ⚠️"
- record.msg = f"{prefix} {record.msg}"
- elif record.levelno == logging.ERROR:
- prefix = "ERROR" if WINDOWS else "ERROR ❌"
- record.msg = f"{prefix} {record.msg}"
-
- # Handle emojis in message based on platform
- formatted_message = super().format(record)
- return emojis(formatted_message)
-
- formatter = PrefixFormatter("%(message)s")
-
- # Handle Windows UTF-8 encoding issues
- if WINDOWS and hasattr(sys.stdout, "encoding") and sys.stdout.encoding != "utf-8":
- with contextlib.suppress(Exception):
- # Attempt to reconfigure stdout to use UTF-8 encoding if possible
- if hasattr(sys.stdout, "reconfigure"):
- sys.stdout.reconfigure(encoding="utf-8")
- # For environments where reconfigure is not available, wrap stdout in a TextIOWrapper
- elif hasattr(sys.stdout, "buffer"):
- import io
-
- sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8")
-
- # Create and configure the StreamHandler with the appropriate formatter and level
- stream_handler = logging.StreamHandler(sys.stdout)
- stream_handler.setFormatter(formatter)
- stream_handler.setLevel(level)
-
- # Set up the logger
- logger = logging.getLogger(name)
- logger.setLevel(level)
- logger.addHandler(stream_handler)
- logger.propagate = False
- return logger
-
-
-# Set logger
-LOGGER = set_logging(LOGGING_NAME, verbose=VERBOSE) # define globally (used in train.py, val.py, predict.py, etc.)
-logging.getLogger("sentry_sdk").setLevel(logging.CRITICAL + 1)
-
-
-def emojis(string=""):
- """Return platform-dependent emoji-safe version of string."""
- return string.encode().decode("ascii", "ignore") if WINDOWS else string
-
-
-class ThreadingLocked:
- """
- A decorator class for ensuring thread-safe execution of a function or method.
-
- This class can be used as a decorator to make sure that if the decorated function is called from multiple threads,
- only one thread at a time will be able to execute the function.
-
- Attributes:
- lock (threading.Lock): A lock object used to manage access to the decorated function.
-
- Examples:
- >>> from ultralytics.utils import ThreadingLocked
- >>> @ThreadingLocked()
- >>> def my_function():
- ... # Your code here
- """
-
- def __init__(self):
- """Initialize the decorator class with a threading lock."""
- self.lock = threading.Lock()
-
- def __call__(self, f):
- """Run thread-safe execution of function or method."""
- from functools import wraps
-
- @wraps(f)
- def decorated(*args, **kwargs):
- """Apply thread-safety to the decorated function or method."""
- with self.lock:
- return f(*args, **kwargs)
-
- return decorated
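A runnable sketch of the decorator serializing access to shared state across threads (the counter and thread count are illustrative):

import threading

counter = {"value": 0}

@ThreadingLocked()
def bump(n=100_000):
    for _ in range(n):
        counter["value"] += 1  # serialized by the decorator's lock

threads = [threading.Thread(target=bump) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter["value"])  # 400000, no lost updates
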
-
-
-class YAML:
- """
- YAML utility class for efficient file operations with automatic C-implementation detection.
-
- This class provides optimized YAML loading and saving operations using PyYAML's fastest available implementation
- (C-based when possible). It implements a singleton pattern with lazy initialization, allowing direct class method
- usage without explicit instantiation. The class handles file path creation, validation, and character encoding
- issues automatically.
-
- The implementation prioritizes performance through:
- - Automatic C-based loader/dumper selection when available
- - Singleton pattern to reuse the same instance
- - Lazy initialization to defer import costs until needed
- - Fallback mechanisms for handling problematic YAML content
-
- Attributes:
- _instance: Internal singleton instance storage.
- yaml: Reference to the PyYAML module.
- SafeLoader: Best available YAML loader (CSafeLoader if available).
- SafeDumper: Best available YAML dumper (CSafeDumper if available).
-
- Examples:
- >>> data = YAML.load("config.yaml")
- >>> data["new_value"] = 123
- >>> YAML.save("updated_config.yaml", data)
- >>> YAML.print(data)
- """
-
- _instance = None
-
- @classmethod
- def _get_instance(cls):
- """Initialize singleton instance on first use."""
- if cls._instance is None:
- cls._instance = cls()
- return cls._instance
-
- def __init__(self):
- """Initialize with optimal YAML implementation (C-based when available)."""
- import yaml
-
- self.yaml = yaml
- # Use C-based implementation if available for better performance
- try:
- self.SafeLoader = yaml.CSafeLoader
- self.SafeDumper = yaml.CSafeDumper
- except (AttributeError, ImportError):
- self.SafeLoader = yaml.SafeLoader
- self.SafeDumper = yaml.SafeDumper
-
- @classmethod
- def save(cls, file="data.yaml", data=None, header=""):
- """
- Save Python object as YAML file.
-
- Args:
- file (str | Path): Path to save YAML file.
- data (dict | None): Dict or compatible object to save.
- header (str): Optional string to add at file beginning.
- """
- instance = cls._get_instance()
- if data is None:
- data = {}
-
- # Create parent directories if needed
- file = Path(file)
- file.parent.mkdir(parents=True, exist_ok=True)
-
- # Convert non-serializable objects to strings
- valid_types = int, float, str, bool, list, tuple, dict, type(None)
- for k, v in data.items():
- if not isinstance(v, valid_types):
- data[k] = str(v)
-
- # Write YAML file
- with open(file, "w", errors="ignore", encoding="utf-8") as f:
- if header:
- f.write(header)
- instance.yaml.dump(data, f, sort_keys=False, allow_unicode=True, Dumper=instance.SafeDumper)
-
- @classmethod
- def load(cls, file="data.yaml", append_filename=False):
- """
- Load YAML file to Python object with robust error handling.
-
- Args:
- file (str | Path): Path to YAML file.
- append_filename (bool): Whether to add filename to returned dict.
-
- Returns:
- (dict): Loaded YAML content.
- """
- instance = cls._get_instance()
- assert str(file).endswith((".yaml", ".yml")), f"Not a YAML file: {file}"
-
- # Read file content
- with open(file, errors="ignore", encoding="utf-8") as f:
- s = f.read()
-
- # Try loading YAML with fallback for problematic characters
- try:
- data = instance.yaml.load(s, Loader=instance.SafeLoader) or {}
- except Exception:
- # Remove problematic characters and retry
- s = re.sub(r"[^\x09\x0A\x0D\x20-\x7E\x85\xA0-\uD7FF\uE000-\uFFFD\U00010000-\U0010ffff]+", "", s)
- data = instance.yaml.load(s, Loader=instance.SafeLoader) or {}
-
- # Check for accidental user-error None strings (should be 'null' in YAML)
- if "None" in data.values():
- data = {k: None if v == "None" else v for k, v in data.items()}
-
- if append_filename:
- data["yaml_file"] = str(file)
- return data
-
- @classmethod
- def print(cls, yaml_file):
- """
- Pretty print YAML file or object to console.
-
- Args:
- yaml_file (str | Path | dict): Path to YAML file or dict to print.
- """
- instance = cls._get_instance()
-
- # Load file if path provided
- yaml_dict = cls.load(yaml_file) if isinstance(yaml_file, (str, Path)) else yaml_file
-
- # Use -1 for unlimited width in C implementation
- dump = instance.yaml.dump(yaml_dict, sort_keys=False, allow_unicode=True, width=-1, Dumper=instance.SafeDumper)
-
- LOGGER.info(f"Printing '{colorstr('bold', 'black', yaml_file)}'\n\n{dump}")
-
-
-# Default configuration
-DEFAULT_CFG_DICT = YAML.load(DEFAULT_CFG_PATH)
-DEFAULT_CFG_KEYS = DEFAULT_CFG_DICT.keys()
-DEFAULT_CFG = IterableSimpleNamespace(**DEFAULT_CFG_DICT)
-
-
-def read_device_model() -> str:
- """
- Read the device model information from the system and cache it for quick access.
-
- Returns:
- (str): Kernel release information.
- """
- return platform.release().lower()
-
-
-def is_ubuntu() -> bool:
- """
- Check if the OS is Ubuntu.
-
- Returns:
- (bool): True if OS is Ubuntu, False otherwise.
- """
- try:
- with open("/etc/os-release") as f:
- return "ID=ubuntu" in f.read()
- except FileNotFoundError:
- return False
-
-
-def is_colab():
- """
- Check if the current script is running inside a Google Colab notebook.
-
- Returns:
- (bool): True if running inside a Colab notebook, False otherwise.
- """
- return "COLAB_RELEASE_TAG" in os.environ or "COLAB_BACKEND_VERSION" in os.environ
-
-
-def is_kaggle():
- """
- Check if the current script is running inside a Kaggle kernel.
-
- Returns:
- (bool): True if running inside a Kaggle kernel, False otherwise.
- """
- return os.environ.get("PWD") == "/kaggle/working" and os.environ.get("KAGGLE_URL_BASE") == "https://www.kaggle.com"
-
-
-def is_jupyter():
- """
- Check if the current script is running inside a Jupyter Notebook.
-
- Returns:
- (bool): True if running inside a Jupyter Notebook, False otherwise.
-
- Notes:
- - Only works on Colab and Kaggle; other environments such as JupyterLab and Paperspace are not reliably detectable.
- - The '"get_ipython" in globals()' check gives false positives when the IPython package is installed manually.
- """
- return IS_COLAB or IS_KAGGLE
-
-
-def is_runpod():
- """
- Check if the current script is running inside a RunPod container.
-
- Returns:
- (bool): True if running in RunPod, False otherwise.
- """
- return "RUNPOD_POD_ID" in os.environ
-
-
-def is_docker() -> bool:
- """
- Determine if the script is running inside a Docker container.
-
- Returns:
- (bool): True if the script is running inside a Docker container, False otherwise.
- """
- try:
- return os.path.exists("/.dockerenv")
- except Exception:
- return False
-
-
-def is_raspberrypi() -> bool:
- """
- Determine if the Python environment is running on a Raspberry Pi.
-
- Returns:
- (bool): True if running on a Raspberry Pi, False otherwise.
- """
- return "rpi" in DEVICE_MODEL
-
-
-@lru_cache(maxsize=3)
-def is_jetson(jetpack=None) -> bool:
- """
- Determine if the Python environment is running on an NVIDIA Jetson device.
-
- Args:
- jetpack (int | None): If specified, check for specific JetPack version (4, 5, 6).
-
- Returns:
- (bool): True if running on an NVIDIA Jetson device, False otherwise.
- """
- if jetson := ("tegra" in DEVICE_MODEL):
- if jetpack:
- try:
- content = open("/etc/nv_tegra_release").read()
- version_map = {4: "R32", 5: "R35", 6: "R36"} # JetPack to L4T major version mapping
- return jetpack in version_map and version_map[jetpack] in content
- except Exception:
- return False
- return jetson
-
-
-def is_online() -> bool:
- """
- Fast online check using DNS (v4/v6) resolution (Cloudflare + Google).
-
- Returns:
- (bool): True if connection is successful, False otherwise.
- """
- if str(os.getenv("YOLO_OFFLINE", "")).lower() == "true":
- return False
-
- for host in ("one.one.one.one", "dns.google"):
- try:
- socket.getaddrinfo(host, 0, socket.AF_UNSPEC, 0, 0, socket.AI_ADDRCONFIG)
- return True
- except OSError:
- continue
- return False
-
-
-def is_pip_package(filepath: str = __name__) -> bool:
- """
- Determine if the file at the given filepath is part of a pip package.
-
- Args:
- filepath (str): The filepath to check.
-
- Returns:
- (bool): True if the file is part of a pip package, False otherwise.
- """
- import importlib.util
-
- # Get the spec for the module
- spec = importlib.util.find_spec(filepath)
-
- # Return whether the spec is not None and the origin is not None (indicating it is a package)
- return spec is not None and spec.origin is not None
-
-
-def is_dir_writeable(dir_path: str | Path) -> bool:
- """
- Check if a directory is writeable.
-
- Args:
- dir_path (str | Path): The path to the directory.
-
- Returns:
- (bool): True if the directory is writeable, False otherwise.
- """
- return os.access(str(dir_path), os.W_OK)
-
-
-def is_pytest_running():
- """
- Determine whether pytest is currently running or not.
-
- Returns:
- (bool): True if pytest is running, False otherwise.
- """
- return ("PYTEST_CURRENT_TEST" in os.environ) or ("pytest" in sys.modules) or ("pytest" in Path(ARGV[0]).stem)
-
-
-def is_github_action_running() -> bool:
- """
- Determine if the current environment is a GitHub Actions runner.
-
- Returns:
- (bool): True if the current environment is a GitHub Actions runner, False otherwise.
- """
- return "GITHUB_ACTIONS" in os.environ and "GITHUB_WORKFLOW" in os.environ and "RUNNER_OS" in os.environ
-
-
-def get_default_args(func):
- """
- Return a dictionary of default arguments for a function.
-
- Args:
- func (callable): The function to inspect.
-
- Returns:
- (dict): A dictionary where each key is a parameter name, and each value is the default value of that parameter.
- """
- signature = inspect.signature(func)
- return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
-
-
-def get_ubuntu_version():
- """
- Retrieve the Ubuntu version if the OS is Ubuntu.
-
- Returns:
- (str): Ubuntu version or None if not an Ubuntu OS.
- """
- if is_ubuntu():
- try:
- with open("/etc/os-release") as f:
- return re.search(r'VERSION_ID="(\d+\.\d+)"', f.read())[1]
- except (FileNotFoundError, AttributeError):
- return None
-
-
-def get_user_config_dir(sub_dir="Ultralytics"):
- """
- Return a writable config dir, preferring YOLO_CONFIG_DIR and being OS-aware.
-
- Args:
- sub_dir (str): The name of the subdirectory to create.
-
- Returns:
- (Path): The path to the user config directory.
- """
- if env_dir := os.getenv("YOLO_CONFIG_DIR"):
- p = Path(env_dir).expanduser() / sub_dir
- elif LINUX:
- p = Path(os.getenv("XDG_CONFIG_HOME", Path.home() / ".config")) / sub_dir
- elif WINDOWS:
- p = Path.home() / "AppData" / "Roaming" / sub_dir
- elif MACOS:
- p = Path.home() / "Library" / "Application Support" / sub_dir
- else:
- raise ValueError(f"Unsupported operating system: {platform.system()}")
-
- if p.exists(): # already created → trust it
- return p
- if is_dir_writeable(p.parent): # create if possible
- p.mkdir(parents=True, exist_ok=True)
- return p
-
- # Fallbacks for Docker, GCP/AWS functions where only /tmp is writeable
- for alt in [Path("/tmp") / sub_dir, Path.cwd() / sub_dir]:
- if alt.exists():
- return alt
- if is_dir_writeable(alt.parent):
- alt.mkdir(parents=True, exist_ok=True)
- LOGGER.warning(
- f"user config directory '{p}' is not writeable, using '{alt}'. Set YOLO_CONFIG_DIR to override."
- )
- return alt
-
- # Last fallback → CWD
- p = Path.cwd() / sub_dir
- p.mkdir(parents=True, exist_ok=True)
- return p
-
-
-# Define constants (required below)
-DEVICE_MODEL = read_device_model() # is_jetson() and is_raspberrypi() depend on this constant
-ONLINE = is_online()
-IS_COLAB = is_colab()
-IS_KAGGLE = is_kaggle()
-IS_DOCKER = is_docker()
-IS_JETSON = is_jetson()
-IS_JUPYTER = is_jupyter()
-IS_PIP_PACKAGE = is_pip_package()
-IS_RASPBERRYPI = is_raspberrypi()
-GIT = GitRepo()
-USER_CONFIG_DIR = get_user_config_dir() # Ultralytics settings dir
-SETTINGS_FILE = USER_CONFIG_DIR / "settings.json"
-
-
-def colorstr(*input):
- r"""
- Color a string based on the provided color and style arguments using ANSI escape codes.
-
- This function can be called in two ways:
- - colorstr('color', 'style', 'your string')
- - colorstr('your string')
-
- In the second form, 'blue' and 'bold' will be applied by default.
-
- Args:
- *input (str | Path): A sequence of strings where the first n-1 strings are color and style arguments,
- and the last string is the one to be colored.
-
- Returns:
- (str): The input string wrapped with ANSI escape codes for the specified color and style.
-
- Notes:
- Supported Colors and Styles:
- - Basic Colors: 'black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'
- - Bright Colors: 'bright_black', 'bright_red', 'bright_green', 'bright_yellow',
- 'bright_blue', 'bright_magenta', 'bright_cyan', 'bright_white'
- - Misc: 'end', 'bold', 'underline'
-
- Examples:
- >>> colorstr("blue", "bold", "hello world")
- >>> "\033[34m\033[1mhello world\033[0m"
-
- References:
- https://en.wikipedia.org/wiki/ANSI_escape_code
- """
- *args, string = input if len(input) > 1 else ("blue", "bold", input[0]) # color arguments, string
- colors = {
- "black": "\033[30m", # basic colors
- "red": "\033[31m",
- "green": "\033[32m",
- "yellow": "\033[33m",
- "blue": "\033[34m",
- "magenta": "\033[35m",
- "cyan": "\033[36m",
- "white": "\033[37m",
- "bright_black": "\033[90m", # bright colors
- "bright_red": "\033[91m",
- "bright_green": "\033[92m",
- "bright_yellow": "\033[93m",
- "bright_blue": "\033[94m",
- "bright_magenta": "\033[95m",
- "bright_cyan": "\033[96m",
- "bright_white": "\033[97m",
- "end": "\033[0m", # misc
- "bold": "\033[1m",
- "underline": "\033[4m",
- }
- return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
-
-
-def remove_colorstr(input_string):
- """
- Remove ANSI escape codes from a string, effectively un-coloring it.
-
- Args:
- input_string (str): The string to remove color and style from.
-
- Returns:
- (str): A new string with all ANSI escape codes removed.
-
- Examples:
- >>> remove_colorstr(colorstr("blue", "bold", "hello world"))
- >>> "hello world"
- """
- ansi_escape = re.compile(r"\x1B\[[0-9;]*[A-Za-z]")
- return ansi_escape.sub("", input_string)
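A quick round trip with the two helpers (rendering the color requires an ANSI-capable terminal):

colored = colorstr("red", "bold", "training complete")
print(colored)                   # renders bold red in ANSI-capable terminals
print(remove_colorstr(colored))  # 'training complete' with the escape codes stripped
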
-
-
-class TryExcept(contextlib.ContextDecorator):
- """
- Ultralytics TryExcept class for handling exceptions gracefully.
-
- This class can be used as a decorator or context manager to catch exceptions and optionally print warning messages.
- It allows code to continue execution even when exceptions occur, which is useful for non-critical operations.
-
- Attributes:
- msg (str): Optional message to display when an exception occurs.
- verbose (bool): Whether to print the exception message.
-
- Examples:
- As a decorator:
- >>> @TryExcept(msg="Error occurred in func", verbose=True)
- >>> def func():
- >>> # Function logic here
- >>> pass
-
- As a context manager:
- >>> with TryExcept(msg="Error occurred in block", verbose=True):
- >>> # Code block here
- >>> pass
- """
-
- def __init__(self, msg="", verbose=True):
- """Initialize TryExcept class with optional message and verbosity settings."""
- self.msg = msg
- self.verbose = verbose
-
- def __enter__(self):
- """Execute when entering TryExcept context, initialize instance."""
- pass
-
- def __exit__(self, exc_type, value, traceback):
- """Define behavior when exiting a 'with' block, print error message if necessary."""
- if self.verbose and value:
- LOGGER.warning(f"{self.msg}{': ' if self.msg else ''}{value}")
- return True
-
-
-class Retry(contextlib.ContextDecorator):
- """
- Retry class for function execution with exponential backoff.
-
- This decorator can be used to retry a function on exceptions, up to a specified number of times with an
- exponentially increasing delay between retries. It's useful for handling transient failures in network
- operations or other unreliable processes.
-
- Attributes:
- times (int): Maximum number of retry attempts.
- delay (int): Initial delay between retries in seconds.
-
- Examples:
- Example usage as a decorator:
- >>> @Retry(times=3, delay=2)
- >>> def test_func():
- >>> # Replace with function logic that may raise exceptions
- >>> return True
- """
-
- def __init__(self, times=3, delay=2):
- """Initialize Retry class with specified number of retries and delay."""
- self.times = times
- self.delay = delay
- self._attempts = 0
-
- def __call__(self, func):
- """Decorator implementation for Retry with exponential backoff."""
-
- def wrapped_func(*args, **kwargs):
- """Apply retries to the decorated function or method."""
- self._attempts = 0
- while self._attempts < self.times:
- try:
- return func(*args, **kwargs)
- except Exception as e:
- self._attempts += 1
- LOGGER.warning(f"Retry {self._attempts}/{self.times} failed: {e}")
- if self._attempts >= self.times:
- raise e
- time.sleep(self.delay * (2**self._attempts)) # exponential backoff delay
-
- return wrapped_func
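Usage sketch with a deliberately flaky function; the failure counter exists only to force the first two attempts to raise:

attempts = {"n": 0}

@Retry(times=3, delay=1)
def flaky_download():
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("transient network error")  # first two calls fail
    return "ok"

print(flaky_download())  # sleeps 2s then 4s between attempts, returns 'ok' on the third try
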
-
-
-def threaded(func):
- """
- Multi-thread a target function by default and return the thread or function result.
-
- This decorator provides flexible execution of the target function, either in a separate thread or synchronously.
- By default, the function runs in a thread, but this can be controlled via the 'threaded=False' keyword argument
- which is removed from kwargs before calling the function.
-
- Args:
- func (callable): The function to be potentially executed in a separate thread.
-
- Returns:
- (callable): A wrapper function that either returns a daemon thread or the direct function result.
-
- Examples:
- >>> @threaded
- ... def process_data(data):
- ... return data
- >>>
- >>> thread = process_data(my_data) # Runs in background thread
- >>> result = process_data(my_data, threaded=False) # Runs synchronously, returns function result
- """
-
- def wrapper(*args, **kwargs):
- """Multi-thread a given function based on 'threaded' kwarg and return the thread or function result."""
- if kwargs.pop("threaded", True): # run in thread
- thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
- thread.start()
- return thread
- else:
- return func(*args, **kwargs)
-
- return wrapper
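Both call modes of the decorator in one sketch:

import time

@threaded
def slow_square(x):
    time.sleep(0.1)
    return x * x

t = slow_square(4)                     # returns a daemon Thread; the result is discarded
t.join()
print(slow_square(4, threaded=False))  # 16, runs synchronously and returns the value
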
-
-
-def set_sentry():
- """
- Initialize the Sentry SDK for error tracking and reporting.
-
- Only used if sentry_sdk package is installed and sync=True in settings. Run 'yolo settings' to see and update
- settings.
-
- Conditions required to send errors (ALL conditions must be met or no errors will be reported):
- - sentry_sdk package is installed
- - sync=True in YOLO settings
- - pytest is not running
- - running in a pip package installation
- - running in a non-git directory
- - running with rank -1 or 0
- - online environment
- - CLI used to run package (checked with 'yolo' as the name of the main CLI command)
- """
- if (
- not SETTINGS["sync"]
- or RANK not in {-1, 0}
- or Path(ARGV[0]).name != "yolo"
- or TESTS_RUNNING
- or not ONLINE
- or not IS_PIP_PACKAGE
- or GIT.is_repo
- ):
- return
- # If sentry_sdk package is not installed then return and do not use Sentry
- try:
- import sentry_sdk # noqa
- except ImportError:
- return
-
- def before_send(event, hint):
- """
- Modify the event before sending it to Sentry based on specific exception types and messages.
-
- Args:
- event (dict): The event dictionary containing information about the error.
- hint (dict): A dictionary containing additional information about the error.
-
- Returns:
- (dict | None): The modified event or None if the event should not be sent to Sentry.
- """
- if "exc_info" in hint:
- exc_type, exc_value, _ = hint["exc_info"]
- if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
- return None # do not send event
-
- event["tags"] = {
- "sys_argv": ARGV[0],
- "sys_argv_name": Path(ARGV[0]).name,
- "install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
- "os": ENVIRONMENT,
- }
- return event
-
- sentry_sdk.init(
- dsn="https://888e5a0778212e1d0314c37d4b9aae5d@o4504521589325824.ingest.us.sentry.io/4504521592406016",
- debug=False,
- auto_enabling_integrations=False,
- traces_sample_rate=1.0,
- release=__version__,
- environment="runpod" if is_runpod() else "production",
- before_send=before_send,
- ignore_errors=[KeyboardInterrupt, FileNotFoundError],
- )
- sentry_sdk.set_user({"id": SETTINGS["uuid"]}) # SHA-256 anonymized UUID hash
-
-
-class JSONDict(dict):
- """
- A dictionary-like class that provides JSON persistence for its contents.
-
- This class extends the built-in dictionary to automatically save its contents to a JSON file whenever they are
- modified. It ensures thread-safe operations using a lock and handles JSON serialization of Path objects.
-
- Attributes:
- file_path (Path): The path to the JSON file used for persistence.
- lock (threading.Lock): A lock object to ensure thread-safe operations.
-
- Methods:
- _load: Load the data from the JSON file into the dictionary.
- _save: Save the current state of the dictionary to the JSON file.
- __setitem__: Store a key-value pair and persist it to disk.
- __delitem__: Remove an item and update the persistent storage.
- update: Update the dictionary and persist changes.
- clear: Clear all entries and update the persistent storage.
-
- Examples:
- >>> json_dict = JSONDict("data.json")
- >>> json_dict["key"] = "value"
- >>> print(json_dict["key"])
- value
- >>> del json_dict["key"]
- >>> json_dict.update({"new_key": "new_value"})
- >>> json_dict.clear()
- """
-
- def __init__(self, file_path: str | Path = "data.json"):
- """Initialize a JSONDict object with a specified file path for JSON persistence."""
- super().__init__()
- self.file_path = Path(file_path)
- self.lock = Lock()
- self._load()
-
- def _load(self):
- """Load the data from the JSON file into the dictionary."""
- try:
- if self.file_path.exists():
- with open(self.file_path) as f:
- self.update(json.load(f))
- except json.JSONDecodeError:
- LOGGER.warning(f"Error decoding JSON from {self.file_path}. Starting with an empty dictionary.")
- except Exception as e:
- LOGGER.error(f"Error reading from {self.file_path}: {e}")
-
- def _save(self):
- """Save the current state of the dictionary to the JSON file."""
- try:
- self.file_path.parent.mkdir(parents=True, exist_ok=True)
- with open(self.file_path, "w", encoding="utf-8") as f:
- json.dump(dict(self), f, indent=2, default=self._json_default)
- except Exception as e:
- LOGGER.error(f"Error writing to {self.file_path}: {e}")
-
- @staticmethod
- def _json_default(obj):
- """Handle JSON serialization of Path objects."""
- if isinstance(obj, Path):
- return str(obj)
- raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
-
- def __setitem__(self, key, value):
- """Store a key-value pair and persist to disk."""
- with self.lock:
- super().__setitem__(key, value)
- self._save()
-
- def __delitem__(self, key):
- """Remove an item and update the persistent storage."""
- with self.lock:
- super().__delitem__(key)
- self._save()
-
- def __str__(self):
- """Return a pretty-printed JSON string representation of the dictionary."""
- contents = json.dumps(dict(self), indent=2, ensure_ascii=False, default=self._json_default)
- return f'JSONDict("{self.file_path}"):\n{contents}'
-
- def update(self, *args, **kwargs):
- """Update the dictionary and persist changes."""
- with self.lock:
- super().update(*args, **kwargs)
- self._save()
-
- def clear(self):
- """Clear all entries and update the persistent storage."""
- with self.lock:
- super().clear()
- self._save()
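A persistence round trip in a temporary directory shows the write-through behavior (the file path is illustrative):

import tempfile
from pathlib import Path

path = Path(tempfile.gettempdir()) / "jsondict_demo.json"
d1 = JSONDict(path)
d1["epochs"] = 100        # written to disk immediately

d2 = JSONDict(path)       # a fresh instance reloads the same file
print(d2["epochs"])       # 100
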
-
-
-class SettingsManager(JSONDict):
- """
- SettingsManager class for managing and persisting Ultralytics settings.
-
- This class extends JSONDict to provide JSON persistence for settings, ensuring thread-safe operations and default
- values. It validates settings on initialization and provides methods to update or reset settings. The settings
- include directories for datasets, weights, and runs, as well as various integration flags.
-
- Attributes:
- file (Path): The path to the JSON file used for persistence.
- version (str): The version of the settings schema.
- defaults (dict): A dictionary containing default settings.
- help_msg (str): A help message for users on how to view and update settings.
-
- Methods:
- _validate_settings: Validate the current settings and reset if necessary.
- update: Update settings, validating keys and types.
- reset: Reset the settings to default and save them.
-
- Examples:
- Initialize and update settings:
- >>> settings = SettingsManager()
- >>> settings.update(runs_dir="/new/runs/dir")
- >>> print(settings["runs_dir"])
- /new/runs/dir
- """
-
- def __init__(self, file=SETTINGS_FILE, version="0.0.6"):
- """Initialize the SettingsManager with default settings and load user settings."""
- import hashlib
- import uuid
-
- from ultralytics.utils.torch_utils import torch_distributed_zero_first
-
- root = GIT.root or Path()
- datasets_root = (root.parent if GIT.root and is_dir_writeable(root.parent) else root).resolve()
-
- self.file = Path(file)
- self.version = version
- self.defaults = {
- "settings_version": version, # Settings schema version
- "datasets_dir": str(datasets_root / "datasets"), # Datasets directory
- "weights_dir": str(root / "weights"), # Model weights directory
- "runs_dir": str(root / "runs"), # Experiment runs directory
- "uuid": hashlib.sha256(str(uuid.getnode()).encode()).hexdigest(), # SHA-256 anonymized UUID hash
- "sync": True, # Enable synchronization
- "api_key": "", # Ultralytics API Key
- "openai_api_key": "", # OpenAI API Key
- "clearml": True, # ClearML integration
- "comet": True, # Comet integration
- "dvc": True, # DVC integration
- "hub": True, # Ultralytics HUB integration
- "mlflow": True, # MLflow integration
- "neptune": True, # Neptune integration
- "raytune": True, # Ray Tune integration
- "tensorboard": False, # TensorBoard logging
- "wandb": False, # Weights & Biases logging
- "vscode_msg": True, # VSCode message
- "openvino_msg": True, # OpenVINO export on Intel CPU message
- }
-
- self.help_msg = (
- f"\nView Ultralytics Settings with 'yolo settings' or at '{self.file}'"
- "\nUpdate Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. "
- "For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings."
- )
-
- with torch_distributed_zero_first(LOCAL_RANK):
- super().__init__(self.file)
-
- if not self.file.exists() or not self: # Check if file doesn't exist or is empty
- LOGGER.info(f"Creating new Ultralytics Settings v{version} file ✅ {self.help_msg}")
- self.reset()
-
- self._validate_settings()
-
- def _validate_settings(self):
- """Validate the current settings and reset if necessary."""
- correct_keys = frozenset(self.keys()) == frozenset(self.defaults.keys())
- correct_types = all(isinstance(self.get(k), type(v)) for k, v in self.defaults.items())
- correct_version = self.get("settings_version", "") == self.version
-
- if not (correct_keys and correct_types and correct_version):
- LOGGER.warning(
- "Ultralytics settings reset to default values. This may be due to a possible problem "
- f"with your settings or a recent ultralytics package update. {self.help_msg}"
- )
- self.reset()
-
- if self.get("datasets_dir") == self.get("runs_dir"):
- LOGGER.warning(
- f"Ultralytics setting 'datasets_dir: {self.get('datasets_dir')}' "
- f"must be different than 'runs_dir: {self.get('runs_dir')}'. "
- f"Please change one to avoid possible issues during training. {self.help_msg}"
- )
-
- def __setitem__(self, key, value):
- """Update one key: value pair."""
- self.update({key: value})
-
- def update(self, *args, **kwargs):
- """Update settings, validating keys and types."""
- for arg in args:
- if isinstance(arg, dict):
- kwargs.update(arg)
- for k, v in kwargs.items():
- if k not in self.defaults:
- raise KeyError(f"No Ultralytics setting '{k}'. {self.help_msg}")
- t = type(self.defaults[k])
- if not isinstance(v, t):
- raise TypeError(
- f"Ultralytics setting '{k}' must be '{t.__name__}' type, not '{type(v).__name__}'. {self.help_msg}"
- )
- super().update(*args, **kwargs)
-
- def reset(self):
- """Reset the settings to default and save them."""
- self.clear()
- self.update(self.defaults)
-
-
-def deprecation_warn(arg, new_arg=None):
- """Issue a deprecation warning when a deprecated argument is used, suggesting an updated argument."""
- msg = f"'{arg}' is deprecated and will be removed in the future."
- if new_arg is not None:
- msg += f" Use '{new_arg}' instead."
- LOGGER.warning(msg)
-
-
-def clean_url(url):
- """Strip auth from URL, i.e. https://url.com/file.txt?auth -> https://url.com/file.txt."""
- url = Path(url).as_posix().replace(":/", "://") # Pathlib turns :// -> :/, as_posix() for Windows
- return unquote(url).split("?", 1)[0] # '%2F' to '/', split https://url.com/file.txt?auth
-
-
-def url2file(url):
- """Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt."""
- return Path(clean_url(url)).name
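Quick demonstration of the two URL helpers on a sample URL:

url = "https://ultralytics.com/assets/file.txt?token=abc123"
print(clean_url(url))  # https://ultralytics.com/assets/file.txt
print(url2file(url))   # file.txt
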
-
-
-def vscode_msg(ext="ultralytics.ultralytics-snippets") -> str:
- """Display a message to install Ultralytics-Snippets for VS Code if not already installed."""
- path = (USER_CONFIG_DIR.parents[2] if WINDOWS else USER_CONFIG_DIR.parents[1]) / ".vscode/extensions"
- obs_file = path / ".obsolete" # file tracks uninstalled extensions, while source directory remains
- installed = any(path.glob(f"{ext}*")) and ext not in (obs_file.read_text("utf-8") if obs_file.exists() else "")
- url = "https://docs.ultralytics.com/integrations/vscode"
- return "" if installed else f"{colorstr('VS Code:')} view Ultralytics VS Code Extension ⚡ at {url}"
-
-
-# Run below code on utils init ------------------------------------------------------------------------------------
-
-# Check first-install steps
-PREFIX = colorstr("Ultralytics: ")
-SETTINGS = SettingsManager() # initialize settings
-PERSISTENT_CACHE = JSONDict(USER_CONFIG_DIR / "persistent_cache.json") # initialize persistent cache
-DATASETS_DIR = Path(SETTINGS["datasets_dir"]) # global datasets directory
-WEIGHTS_DIR = Path(SETTINGS["weights_dir"]) # global weights directory
-RUNS_DIR = Path(SETTINGS["runs_dir"]) # global runs directory
-ENVIRONMENT = (
- "Colab"
- if IS_COLAB
- else "Kaggle"
- if IS_KAGGLE
- else "Jupyter"
- if IS_JUPYTER
- else "Docker"
- if IS_DOCKER
- else platform.system()
-)
-TESTS_RUNNING = is_pytest_running() or is_github_action_running()
-set_sentry()
-
-# Apply monkey patches
-torch.save = torch_save
-if WINDOWS:
- # Apply cv2 patches for non-ASCII and non-UTF characters in image paths
- cv2.imread, cv2.imwrite, cv2.imshow = imread, imwrite, imshow
diff --git a/ultralytics/utils/__pycache__/__init__.cpython-310.pyc b/ultralytics/utils/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index ba71194..0000000
Binary files a/ultralytics/utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/autobatch.cpython-310.pyc b/ultralytics/utils/__pycache__/autobatch.cpython-310.pyc
deleted file mode 100644
index 27fd61c..0000000
Binary files a/ultralytics/utils/__pycache__/autobatch.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/checks.cpython-310.pyc b/ultralytics/utils/__pycache__/checks.cpython-310.pyc
deleted file mode 100644
index 808e940..0000000
Binary files a/ultralytics/utils/__pycache__/checks.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/cpu.cpython-310.pyc b/ultralytics/utils/__pycache__/cpu.cpython-310.pyc
deleted file mode 100644
index a474098..0000000
Binary files a/ultralytics/utils/__pycache__/cpu.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/dist.cpython-310.pyc b/ultralytics/utils/__pycache__/dist.cpython-310.pyc
deleted file mode 100644
index 3b0477d..0000000
Binary files a/ultralytics/utils/__pycache__/dist.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/downloads.cpython-310.pyc b/ultralytics/utils/__pycache__/downloads.cpython-310.pyc
deleted file mode 100644
index 54af114..0000000
Binary files a/ultralytics/utils/__pycache__/downloads.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/errors.cpython-310.pyc b/ultralytics/utils/__pycache__/errors.cpython-310.pyc
deleted file mode 100644
index a819cab..0000000
Binary files a/ultralytics/utils/__pycache__/errors.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/events.cpython-310.pyc b/ultralytics/utils/__pycache__/events.cpython-310.pyc
deleted file mode 100644
index f837209..0000000
Binary files a/ultralytics/utils/__pycache__/events.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/files.cpython-310.pyc b/ultralytics/utils/__pycache__/files.cpython-310.pyc
deleted file mode 100644
index 8ffb38d..0000000
Binary files a/ultralytics/utils/__pycache__/files.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/git.cpython-310.pyc b/ultralytics/utils/__pycache__/git.cpython-310.pyc
deleted file mode 100644
index 84c553d..0000000
Binary files a/ultralytics/utils/__pycache__/git.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/instance.cpython-310.pyc b/ultralytics/utils/__pycache__/instance.cpython-310.pyc
deleted file mode 100644
index 8de8105..0000000
Binary files a/ultralytics/utils/__pycache__/instance.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/loss.cpython-310.pyc b/ultralytics/utils/__pycache__/loss.cpython-310.pyc
deleted file mode 100644
index aa0dbce..0000000
Binary files a/ultralytics/utils/__pycache__/loss.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/metrics.cpython-310.pyc b/ultralytics/utils/__pycache__/metrics.cpython-310.pyc
deleted file mode 100644
index e8de956..0000000
Binary files a/ultralytics/utils/__pycache__/metrics.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/nms.cpython-310.pyc b/ultralytics/utils/__pycache__/nms.cpython-310.pyc
deleted file mode 100644
index 4085b3f..0000000
Binary files a/ultralytics/utils/__pycache__/nms.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/ops.cpython-310.pyc b/ultralytics/utils/__pycache__/ops.cpython-310.pyc
deleted file mode 100644
index 64cb8d9..0000000
Binary files a/ultralytics/utils/__pycache__/ops.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/patches.cpython-310.pyc b/ultralytics/utils/__pycache__/patches.cpython-310.pyc
deleted file mode 100644
index 81b64c3..0000000
Binary files a/ultralytics/utils/__pycache__/patches.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/plotting.cpython-310.pyc b/ultralytics/utils/__pycache__/plotting.cpython-310.pyc
deleted file mode 100644
index 94a759e..0000000
Binary files a/ultralytics/utils/__pycache__/plotting.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/tal.cpython-310.pyc b/ultralytics/utils/__pycache__/tal.cpython-310.pyc
deleted file mode 100644
index 5b2a39b..0000000
Binary files a/ultralytics/utils/__pycache__/tal.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/torch_utils.cpython-310.pyc b/ultralytics/utils/__pycache__/torch_utils.cpython-310.pyc
deleted file mode 100644
index 8708a1d..0000000
Binary files a/ultralytics/utils/__pycache__/torch_utils.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/__pycache__/tqdm.cpython-310.pyc b/ultralytics/utils/__pycache__/tqdm.cpython-310.pyc
deleted file mode 100644
index 1d1c1b8..0000000
Binary files a/ultralytics/utils/__pycache__/tqdm.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/autobatch.py b/ultralytics/utils/autobatch.py
deleted file mode 100644
index ef67cb4..0000000
--- a/ultralytics/utils/autobatch.py
+++ /dev/null
@@ -1,120 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""Functions for estimating the best YOLO batch size to use a fraction of the available CUDA memory in PyTorch."""
-
-from __future__ import annotations
-
-import os
-from copy import deepcopy
-
-import numpy as np
-import torch
-
-from ultralytics.utils import DEFAULT_CFG, LOGGER, colorstr
-from ultralytics.utils.torch_utils import autocast, profile_ops
-
-
-def check_train_batch_size(
- model: torch.nn.Module,
- imgsz: int = 640,
- amp: bool = True,
- batch: int | float = -1,
- max_num_obj: int = 1,
-) -> int:
- """
- Compute optimal YOLO training batch size using the autobatch() function.
-
- Args:
- model (torch.nn.Module): YOLO model to check batch size for.
- imgsz (int, optional): Image size used for training.
- amp (bool, optional): Use automatic mixed precision if True.
- batch (int | float, optional): Fraction of GPU memory to use. If -1, use default.
- max_num_obj (int, optional): The maximum number of objects from dataset.
-
- Returns:
- (int): Optimal batch size computed using the autobatch() function.
-
- Notes:
- If 0.0 < batch < 1.0, it's used as the fraction of GPU memory to use.
- Otherwise, a default fraction of 0.6 is used.
- """
- with autocast(enabled=amp):
- return autobatch(
- deepcopy(model).train(), imgsz, fraction=batch if 0.0 < batch < 1.0 else 0.6, max_num_obj=max_num_obj
- )
-
-
-def autobatch(
- model: torch.nn.Module,
- imgsz: int = 640,
- fraction: float = 0.60,
- batch_size: int = DEFAULT_CFG.batch,
- max_num_obj: int = 1,
-) -> int:
- """
- Automatically estimate the best YOLO batch size to use a fraction of the available CUDA memory.
-
- Args:
- model (torch.nn.Module): YOLO model to compute batch size for.
- imgsz (int, optional): The image size used as input for the YOLO model.
- fraction (float, optional): The fraction of available CUDA memory to use.
- batch_size (int, optional): The default batch size to use if an error is detected.
- max_num_obj (int, optional): The maximum number of objects from dataset.
-
- Returns:
- (int): The optimal batch size.
- """
- # Check device
- prefix = colorstr("AutoBatch: ")
- LOGGER.info(f"{prefix}Computing optimal batch size for imgsz={imgsz} at {fraction * 100}% CUDA memory utilization.")
- device = next(model.parameters()).device # get model device
- if device.type in {"cpu", "mps"}:
- LOGGER.warning(f"{prefix}intended for CUDA devices, using default batch-size {batch_size}")
- return batch_size
- if torch.backends.cudnn.benchmark:
- LOGGER.warning(f"{prefix}Requires torch.backends.cudnn.benchmark=False, using default batch-size {batch_size}")
- return batch_size
-
- # Inspect CUDA memory
- gb = 1 << 30 # bytes to GiB (1024 ** 3)
- d = f"CUDA:{os.getenv('CUDA_VISIBLE_DEVICES', '0').strip()[0]}" # 'CUDA:0'
- properties = torch.cuda.get_device_properties(device) # device properties
- t = properties.total_memory / gb # GiB total
- r = torch.cuda.memory_reserved(device) / gb # GiB reserved
- a = torch.cuda.memory_allocated(device) / gb # GiB allocated
- f = t - (r + a) # GiB free
- LOGGER.info(f"{prefix}{d} ({properties.name}) {t:.2f}G total, {r:.2f}G reserved, {a:.2f}G allocated, {f:.2f}G free")
-
- # Profile batch sizes
- batch_sizes = [1, 2, 4, 8, 16] if t < 16 else [1, 2, 4, 8, 16, 32, 64]
- try:
- img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
- results = profile_ops(img, model, n=1, device=device, max_num_obj=max_num_obj)
-
- # Fit a solution
- xy = [
- [x, y[2]]
- for i, (x, y) in enumerate(zip(batch_sizes, results))
- if y # valid result
- and isinstance(y[2], (int, float)) # is numeric
- and 0 < y[2] < t # between 0 and GPU limit
- and (i == 0 or not results[i - 1] or y[2] > results[i - 1][2]) # first item or increasing memory
- ]
- fit_x, fit_y = zip(*xy) if xy else ([], [])
-        p = np.polyfit(fit_x, fit_y, deg=1)  # first-degree polynomial fit of memory vs batch size
-        b = int((round(f * fraction) - p[1]) / p[0])  # solve the fit for the optimal batch size
- if None in results: # some sizes failed
- i = results.index(None) # first fail index
-            if b >= batch_sizes[i]:  # estimated batch size at or above the failure point
- b = batch_sizes[max(i - 1, 0)] # select prior safe point
- if b < 1 or b > 1024: # b outside of safe range
- LOGGER.warning(f"{prefix}batch={b} outside safe range, using default batch-size {batch_size}.")
- b = batch_size
-
- fraction = (np.polyval(p, b) + r + a) / t # predicted fraction
- LOGGER.info(f"{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅")
- return b
- except Exception as e:
- LOGGER.warning(f"{prefix}error detected: {e}, using default batch-size {batch_size}.")
- return batch_size
- finally:
- torch.cuda.empty_cache()
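Note on the "fit a solution" step above: autobatch profiles a few batch sizes, fits a first-degree polynomial to the measured memory, and solves that line for the batch size that lands at the target memory fraction. A minimal, self-contained sketch with made-up measurements (the numbers are illustrative only, not from any real profiling run):

import numpy as np

batch_sizes = [1, 2, 4, 8, 16]          # profiled batch sizes
mem_gib = [0.5, 0.9, 1.7, 3.3, 6.5]     # hypothetical memory used at each size (GiB)
free_gib, fraction = 10.0, 0.60         # free CUDA memory and target utilization

p = np.polyfit(batch_sizes, mem_gib, deg=1)           # memory ≈ p[0] * batch + p[1]
b = int((round(free_gib * fraction) - p[1]) / p[0])   # batch size at the target memory
print(b)  # -> 14 with these example numbers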
diff --git a/ultralytics/utils/autodevice.py b/ultralytics/utils/autodevice.py
deleted file mode 100644
index a0971bc..0000000
--- a/ultralytics/utils/autodevice.py
+++ /dev/null
@@ -1,206 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from typing import Any
-
-from ultralytics.utils import LOGGER
-from ultralytics.utils.checks import check_requirements
-
-
-class GPUInfo:
- """
- Manages NVIDIA GPU information via pynvml with robust error handling.
-
- Provides methods to query detailed GPU statistics (utilization, memory, temp, power) and select the most idle
- GPUs based on configurable criteria. It safely handles the absence or initialization failure of the pynvml
- library by logging warnings and disabling related features, preventing application crashes.
-
-    If NVML is unavailable or fails to initialize, GPU statistics are disabled and GPU selection returns an empty
-    list. Manages NVML initialization and shutdown internally.
-
- Attributes:
- pynvml (module | None): The `pynvml` module if successfully imported and initialized, otherwise `None`.
- nvml_available (bool): Indicates if `pynvml` is ready for use. True if import and `nvmlInit()` succeeded,
- False otherwise.
- gpu_stats (list[dict[str, Any]]): A list of dictionaries, each holding stats for one GPU. Populated on
- initialization and by `refresh_stats()`. Keys include: 'index', 'name', 'utilization' (%),
- 'memory_used' (MiB), 'memory_total' (MiB), 'memory_free' (MiB), 'temperature' (C), 'power_draw' (W),
- 'power_limit' (W or 'N/A'). Empty if NVML is unavailable or queries fail.
-
- Methods:
- refresh_stats: Refresh the internal gpu_stats list by querying NVML.
- print_status: Print GPU status in a compact table format using current stats.
- select_idle_gpu: Select the most idle GPUs based on utilization and free memory.
- shutdown: Shut down NVML if it was initialized.
-
- Examples:
- Initialize GPUInfo and print status
- >>> gpu_info = GPUInfo()
- >>> gpu_info.print_status()
-
- Select idle GPUs with minimum memory requirements
- >>> selected = gpu_info.select_idle_gpu(count=2, min_memory_fraction=0.2)
- >>> print(f"Selected GPU indices: {selected}")
- """
-
- def __init__(self):
- """Initialize GPUInfo, attempting to import and initialize pynvml."""
- self.pynvml: Any | None = None
- self.nvml_available: bool = False
- self.gpu_stats: list[dict[str, Any]] = []
-
- try:
- check_requirements("nvidia-ml-py>=12.0.0")
- self.pynvml = __import__("pynvml")
- self.pynvml.nvmlInit()
- self.nvml_available = True
- self.refresh_stats()
- except Exception as e:
- LOGGER.warning(f"Failed to initialize pynvml, GPU stats disabled: {e}")
-
- def __del__(self):
- """Ensure NVML is shut down when the object is garbage collected."""
- self.shutdown()
-
- def shutdown(self):
- """Shut down NVML if it was initialized."""
- if self.nvml_available and self.pynvml:
- try:
- self.pynvml.nvmlShutdown()
- except Exception:
- pass
- self.nvml_available = False
-
- def refresh_stats(self):
- """Refresh the internal gpu_stats list by querying NVML."""
- self.gpu_stats = []
- if not self.nvml_available or not self.pynvml:
- return
-
- try:
- device_count = self.pynvml.nvmlDeviceGetCount()
- self.gpu_stats.extend(self._get_device_stats(i) for i in range(device_count))
- except Exception as e:
- LOGGER.warning(f"Error during device query: {e}")
- self.gpu_stats = []
-
- def _get_device_stats(self, index: int) -> dict[str, Any]:
- """Get stats for a single GPU device."""
- handle = self.pynvml.nvmlDeviceGetHandleByIndex(index)
- memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
- util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)
-
- def safe_get(func, *args, default=-1, divisor=1):
- try:
- val = func(*args)
- return val // divisor if divisor != 1 and isinstance(val, (int, float)) else val
- except Exception:
- return default
-
- temp_type = getattr(self.pynvml, "NVML_TEMPERATURE_GPU", -1)
-
- return {
- "index": index,
- "name": self.pynvml.nvmlDeviceGetName(handle),
- "utilization": util.gpu if util else -1,
- "memory_used": memory.used >> 20 if memory else -1, # Convert bytes to MiB
- "memory_total": memory.total >> 20 if memory else -1,
- "memory_free": memory.free >> 20 if memory else -1,
- "temperature": safe_get(self.pynvml.nvmlDeviceGetTemperature, handle, temp_type),
- "power_draw": safe_get(self.pynvml.nvmlDeviceGetPowerUsage, handle, divisor=1000), # Convert mW to W
- "power_limit": safe_get(self.pynvml.nvmlDeviceGetEnforcedPowerLimit, handle, divisor=1000),
- }
-
- def print_status(self):
- """Print GPU status in a compact table format using current stats."""
- self.refresh_stats()
- if not self.gpu_stats:
- LOGGER.warning("No GPU stats available.")
- return
-
- stats = self.gpu_stats
- name_len = max(len(gpu.get("name", "N/A")) for gpu in stats)
- hdr = f"{'Idx':<3} {'Name':<{name_len}} {'Util':>6} {'Mem (MiB)':>15} {'Temp':>5} {'Pwr (W)':>10}"
- LOGGER.info(f"\n--- GPU Status ---\n{hdr}\n{'-' * len(hdr)}")
-
- for gpu in stats:
- u = f"{gpu['utilization']:>5}%" if gpu["utilization"] >= 0 else " N/A "
- m = f"{gpu['memory_used']:>6}/{gpu['memory_total']:<6}" if gpu["memory_used"] >= 0 else " N/A / N/A "
- t = f"{gpu['temperature']}C" if gpu["temperature"] >= 0 else " N/A "
- p = f"{gpu['power_draw']:>3}/{gpu['power_limit']:<3}" if gpu["power_draw"] >= 0 else " N/A "
-
- LOGGER.info(f"{gpu.get('index'):<3d} {gpu.get('name', 'N/A'):<{name_len}} {u:>6} {m:>15} {t:>5} {p:>10}")
-
- LOGGER.info(f"{'-' * len(hdr)}\n")
-
- def select_idle_gpu(
- self, count: int = 1, min_memory_fraction: float = 0, min_util_fraction: float = 0
- ) -> list[int]:
- """
- Select the most idle GPUs based on utilization and free memory.
-
- Args:
- count (int): The number of idle GPUs to select.
- min_memory_fraction (float): Minimum free memory required as a fraction of total memory.
- min_util_fraction (float): Minimum free utilization rate required from 0.0 - 1.0.
-
- Returns:
- (list[int]): Indices of the selected GPUs, sorted by idleness (lowest utilization first).
-
- Notes:
- Returns fewer than 'count' if not enough qualify or exist.
-            Returns an empty list if NVML stats are unavailable or no GPUs are found.
- """
- assert min_memory_fraction <= 1.0, f"min_memory_fraction must be <= 1.0, got {min_memory_fraction}"
- assert min_util_fraction <= 1.0, f"min_util_fraction must be <= 1.0, got {min_util_fraction}"
- LOGGER.info(
- f"Searching for {count} idle GPUs with free memory >= {min_memory_fraction * 100:.1f}% and free utilization >= {min_util_fraction * 100:.1f}%..."
- )
-
- if count <= 0:
- return []
-
- self.refresh_stats()
- if not self.gpu_stats:
- LOGGER.warning("NVML stats unavailable.")
- return []
-
- # Filter and sort eligible GPUs
- eligible_gpus = [
- gpu
- for gpu in self.gpu_stats
- if gpu.get("memory_free", 0) / gpu.get("memory_total", 1) >= min_memory_fraction
- and (100 - gpu.get("utilization", 100)) >= min_util_fraction * 100
- ]
- eligible_gpus.sort(key=lambda x: (x.get("utilization", 101), -x.get("memory_free", 0)))
-
- # Select top 'count' indices
- selected = [gpu["index"] for gpu in eligible_gpus[:count]]
-
- if selected:
- LOGGER.info(f"Selected idle CUDA devices {selected}")
- else:
- LOGGER.warning(
- f"No GPUs met criteria (Free Mem >= {min_memory_fraction * 100:.1f}% and Free Util >= {min_util_fraction * 100:.1f}%)."
- )
-
- return selected
-
-
-if __name__ == "__main__":
- required_free_mem_fraction = 0.2 # Require 20% free VRAM
- required_free_util_fraction = 0.2 # Require 20% free utilization
- num_gpus_to_select = 1
-
- gpu_info = GPUInfo()
- gpu_info.print_status()
-
- if selected := gpu_info.select_idle_gpu(
- count=num_gpus_to_select,
- min_memory_fraction=required_free_mem_fraction,
- min_util_fraction=required_free_util_fraction,
- ):
- print(f"\n==> Using selected GPU indices: {selected}")
- devices = [f"cuda:{idx}" for idx in selected]
- print(f" Target devices: {devices}")
diff --git a/ultralytics/utils/benchmarks.py b/ultralytics/utils/benchmarks.py
deleted file mode 100644
index da8f263..0000000
--- a/ultralytics/utils/benchmarks.py
+++ /dev/null
@@ -1,728 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-Benchmark YOLO model formats for speed and accuracy.
-
-Usage:
- from ultralytics.utils.benchmarks import ProfileModels, benchmark
- ProfileModels(['yolo11n.yaml', 'yolov8s.yaml']).run()
- benchmark(model='yolo11n.pt', imgsz=160)
-
-Format | `format=argument` | Model
---- | --- | ---
-PyTorch | - | yolo11n.pt
-TorchScript | `torchscript` | yolo11n.torchscript
-ONNX | `onnx` | yolo11n.onnx
-OpenVINO | `openvino` | yolo11n_openvino_model/
-TensorRT | `engine` | yolo11n.engine
-CoreML | `coreml` | yolo11n.mlpackage
-TensorFlow SavedModel | `saved_model` | yolo11n_saved_model/
-TensorFlow GraphDef | `pb` | yolo11n.pb
-TensorFlow Lite | `tflite` | yolo11n.tflite
-TensorFlow Edge TPU | `edgetpu` | yolo11n_edgetpu.tflite
-TensorFlow.js | `tfjs` | yolo11n_web_model/
-PaddlePaddle | `paddle` | yolo11n_paddle_model/
-MNN | `mnn` | yolo11n.mnn
-NCNN | `ncnn` | yolo11n_ncnn_model/
-IMX | `imx` | yolo11n_imx_model/
-RKNN | `rknn` | yolo11n_rknn_model/
-"""
-
-from __future__ import annotations
-
-import glob
-import os
-import platform
-import re
-import shutil
-import time
-from pathlib import Path
-
-import numpy as np
-import torch.cuda
-
-from ultralytics import YOLO, YOLOWorld
-from ultralytics.cfg import TASK2DATA, TASK2METRIC
-from ultralytics.engine.exporter import export_formats
-from ultralytics.utils import ARM64, ASSETS, IS_JETSON, LINUX, LOGGER, MACOS, TQDM, WEIGHTS_DIR, YAML
-from ultralytics.utils.checks import IS_PYTHON_3_13, check_imgsz, check_requirements, check_yolo, is_rockchip
-from ultralytics.utils.downloads import safe_download
-from ultralytics.utils.files import file_size
-from ultralytics.utils.torch_utils import get_cpu_info, select_device
-
-
-def benchmark(
- model=WEIGHTS_DIR / "yolo11n.pt",
- data=None,
- imgsz=160,
- half=False,
- int8=False,
- device="cpu",
- verbose=False,
- eps=1e-3,
- format="",
- **kwargs,
-):
- """
- Benchmark a YOLO model across different formats for speed and accuracy.
-
- Args:
- model (str | Path): Path to the model file or directory.
- data (str | None): Dataset to evaluate on, inherited from TASK2DATA if not passed.
- imgsz (int): Image size for the benchmark.
- half (bool): Use half-precision for the model if True.
- int8 (bool): Use int8-precision for the model if True.
- device (str): Device to run the benchmark on, either 'cpu' or 'cuda'.
- verbose (bool | float): If True or a float, assert benchmarks pass with given metric.
- eps (float): Epsilon value for divide by zero prevention.
-        format (str): Export format for benchmarking. If not supplied, all formats are benchmarked.
- **kwargs (Any): Additional keyword arguments for exporter.
-
- Returns:
- (polars.DataFrame): A polars DataFrame with benchmark results for each format, including file size, metric,
- and inference time.
-
- Examples:
- Benchmark a YOLO model with default settings:
- >>> from ultralytics.utils.benchmarks import benchmark
- >>> benchmark(model="yolo11n.pt", imgsz=640)
- """
- imgsz = check_imgsz(imgsz)
- assert imgsz[0] == imgsz[1] if isinstance(imgsz, list) else True, "benchmark() only supports square imgsz."
-
- import polars as pl # scope for faster 'import ultralytics'
-
- pl.Config.set_tbl_cols(-1) # Show all columns
- pl.Config.set_tbl_rows(-1) # Show all rows
- pl.Config.set_tbl_width_chars(-1) # No width limit
- pl.Config.set_tbl_hide_column_data_types(True) # Hide data types
- pl.Config.set_tbl_hide_dataframe_shape(True) # Hide shape info
- pl.Config.set_tbl_formatting("ASCII_BORDERS_ONLY_CONDENSED")
-
- device = select_device(device, verbose=False)
- if isinstance(model, (str, Path)):
- model = YOLO(model)
- is_end2end = getattr(model.model.model[-1], "end2end", False)
- data = data or TASK2DATA[model.task] # task to dataset, i.e. coco8.yaml for task=detect
- key = TASK2METRIC[model.task] # task to metric, i.e. metrics/mAP50-95(B) for task=detect
-
- y = []
- t0 = time.time()
-
- format_arg = format.lower()
- if format_arg:
- formats = frozenset(export_formats()["Argument"])
-        assert format_arg in formats, f"Expected format to be one of {formats}, but got '{format_arg}'."
- for name, format, suffix, cpu, gpu, _ in zip(*export_formats().values()):
- emoji, filename = "❌", None # export defaults
- try:
- if format_arg and format_arg != format:
- continue
-
- # Checks
- if format == "pb":
- assert model.task != "obb", "TensorFlow GraphDef not supported for OBB task"
- elif format == "edgetpu":
- assert LINUX and not ARM64, "Edge TPU export only supported on non-aarch64 Linux"
- elif format in {"coreml", "tfjs"}:
- assert MACOS or (LINUX and not ARM64), (
- "CoreML and TF.js export only supported on macOS and non-aarch64 Linux"
- )
- if format == "coreml":
- assert not IS_PYTHON_3_13, "CoreML not supported on Python 3.13"
- if format in {"saved_model", "pb", "tflite", "edgetpu", "tfjs"}:
- assert not isinstance(model, YOLOWorld), "YOLOWorldv2 TensorFlow exports not supported by onnx2tf yet"
- # assert not IS_PYTHON_MINIMUM_3_12, "TFLite exports not supported on Python>=3.12 yet"
- if format == "paddle":
- assert not isinstance(model, YOLOWorld), "YOLOWorldv2 Paddle exports not supported yet"
- assert model.task != "obb", "Paddle OBB bug https://github.com/PaddlePaddle/Paddle/issues/72024"
- assert not is_end2end, "End-to-end models not supported by PaddlePaddle yet"
- assert (LINUX and not IS_JETSON) or MACOS, "Windows and Jetson Paddle exports not supported yet"
- if format == "mnn":
- assert not isinstance(model, YOLOWorld), "YOLOWorldv2 MNN exports not supported yet"
- if format == "ncnn":
- assert not isinstance(model, YOLOWorld), "YOLOWorldv2 NCNN exports not supported yet"
- if format == "imx":
- assert not is_end2end
- assert not isinstance(model, YOLOWorld), "YOLOWorldv2 IMX exports not supported"
- assert model.task == "detect", "IMX only supported for detection task"
- assert "C2f" in model.__str__(), "IMX only supported for YOLOv8n and YOLO11n"
- if format == "rknn":
- assert not isinstance(model, YOLOWorld), "YOLOWorldv2 RKNN exports not supported yet"
- assert not is_end2end, "End-to-end models not supported by RKNN yet"
- assert LINUX, "RKNN only supported on Linux"
- assert not is_rockchip(), "RKNN Inference only supported on Rockchip devices"
- if "cpu" in device.type:
- assert cpu, "inference not supported on CPU"
- if "cuda" in device.type:
- assert gpu, "inference not supported on GPU"
-
- # Export
- if format == "-":
- filename = model.pt_path or model.ckpt_path or model.model_name
- exported_model = model # PyTorch format
- else:
- filename = model.export(
- imgsz=imgsz, format=format, half=half, int8=int8, data=data, device=device, verbose=False, **kwargs
- )
- exported_model = YOLO(filename, task=model.task)
- assert suffix in str(filename), "export failed"
- emoji = "❎" # indicates export succeeded
-
- # Predict
- assert model.task != "pose" or format != "pb", "GraphDef Pose inference is not supported"
- assert format not in {"edgetpu", "tfjs"}, "inference not supported"
- assert format != "coreml" or platform.system() == "Darwin", "inference only supported on macOS>=10.13"
- if format == "ncnn":
- assert not is_end2end, "End-to-end torch.topk operation is not supported for NCNN prediction yet"
- exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half, verbose=False)
-
- # Validate
- results = exported_model.val(
- data=data,
- batch=1,
- imgsz=imgsz,
- plots=False,
- device=device,
- half=half,
- int8=int8,
- verbose=False,
- conf=0.001, # all the pre-set benchmark mAP values are based on conf=0.001
- )
- metric, speed = results.results_dict[key], results.speed["inference"]
- fps = round(1000 / (speed + eps), 2) # frames per second
- y.append([name, "✅", round(file_size(filename), 1), round(metric, 4), round(speed, 2), fps])
- except Exception as e:
- if verbose:
- assert type(e) is AssertionError, f"Benchmark failure for {name}: {e}"
- LOGGER.error(f"Benchmark failure for {name}: {e}")
- y.append([name, emoji, round(file_size(filename), 1), None, None, None]) # mAP, t_inference
-
- # Print results
- check_yolo(device=device) # print system info
- df = pl.DataFrame(y, schema=["Format", "Status❔", "Size (MB)", key, "Inference time (ms/im)", "FPS"], orient="row")
- df = df.with_row_index(" ", offset=1) # add index info
- df_display = df.with_columns(pl.all().cast(pl.String).fill_null("-"))
-
- name = model.model_name
- dt = time.time() - t0
- legend = "Benchmarks legend: - ✅ Success - ❎ Export passed but validation failed - ❌️ Export failed"
- s = f"\nBenchmarks complete for {name} on {data} at imgsz={imgsz} ({dt:.2f}s)\n{legend}\n{df_display}\n"
- LOGGER.info(s)
- with open("benchmarks.log", "a", errors="ignore", encoding="utf-8") as f:
- f.write(s)
-
- if verbose and isinstance(verbose, float):
- metrics = df[key].to_numpy() # values to compare to floor
- floor = verbose # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
- assert all(x > floor for x in metrics if not np.isnan(x)), f"Benchmark failure: metric(s) < floor {floor}"
-
- return df_display
-
-
-class RF100Benchmark:
- """
-    Benchmark YOLO model performance on the Roboflow 100 (RF100) dataset collection.
-
- This class provides functionality to benchmark YOLO models on the RF100 dataset collection.
-
- Attributes:
- ds_names (list[str]): Names of datasets used for benchmarking.
- ds_cfg_list (list[Path]): List of paths to dataset configuration files.
- rf (Roboflow): Roboflow instance for accessing datasets.
- val_metrics (list[str]): Metrics used for validation.
-
- Methods:
- set_key: Set Roboflow API key for accessing datasets.
- parse_dataset: Parse dataset links and download datasets.
- fix_yaml: Fix train and validation paths in YAML files.
- evaluate: Evaluate model performance on validation results.
- """
-
- def __init__(self):
- """Initialize the RF100Benchmark class for benchmarking YOLO model performance across various formats."""
- self.ds_names = []
- self.ds_cfg_list = []
- self.rf = None
- self.val_metrics = ["class", "images", "targets", "precision", "recall", "map50", "map95"]
-
- def set_key(self, api_key: str):
- """
- Set Roboflow API key for processing.
-
- Args:
- api_key (str): The API key.
-
- Examples:
- Set the Roboflow API key for accessing datasets:
- >>> benchmark = RF100Benchmark()
- >>> benchmark.set_key("your_roboflow_api_key")
- """
- check_requirements("roboflow")
- from roboflow import Roboflow
-
- self.rf = Roboflow(api_key=api_key)
-
- def parse_dataset(self, ds_link_txt: str = "datasets_links.txt"):
- """
- Parse dataset links and download datasets.
-
- Args:
- ds_link_txt (str): Path to the file containing dataset links.
-
- Returns:
- ds_names (list[str]): List of dataset names.
- ds_cfg_list (list[Path]): List of paths to dataset configuration files.
-
- Examples:
- >>> benchmark = RF100Benchmark()
- >>> benchmark.set_key("api_key")
- >>> benchmark.parse_dataset("datasets_links.txt")
- """
- (shutil.rmtree("rf-100"), os.mkdir("rf-100")) if os.path.exists("rf-100") else os.mkdir("rf-100")
- os.chdir("rf-100")
- os.mkdir("ultralytics-benchmarks")
- safe_download("https://github.com/ultralytics/assets/releases/download/v0.0.0/datasets_links.txt")
-
- with open(ds_link_txt, encoding="utf-8") as file:
- for line in file:
- try:
- _, url, workspace, project, version = re.split("/+", line.strip())
- self.ds_names.append(project)
- proj_version = f"{project}-{version}"
- if not Path(proj_version).exists():
- self.rf.workspace(workspace).project(project).version(version).download("yolov8")
- else:
- LOGGER.info("Dataset already downloaded.")
- self.ds_cfg_list.append(Path.cwd() / proj_version / "data.yaml")
- except Exception:
- continue
-
- return self.ds_names, self.ds_cfg_list
-
- @staticmethod
- def fix_yaml(path: Path):
- """Fix the train and validation paths in a given YAML file."""
- yaml_data = YAML.load(path)
- yaml_data["train"] = "train/images"
- yaml_data["val"] = "valid/images"
- YAML.dump(yaml_data, path)
-
- def evaluate(self, yaml_path: str, val_log_file: str, eval_log_file: str, list_ind: int):
- """
- Evaluate model performance on validation results.
-
- Args:
- yaml_path (str): Path to the YAML configuration file.
- val_log_file (str): Path to the validation log file.
- eval_log_file (str): Path to the evaluation log file.
- list_ind (int): Index of the current dataset in the list.
-
- Returns:
- (float): The mean average precision (mAP) value for the evaluated model.
-
- Examples:
- Evaluate a model on a specific dataset
- >>> benchmark = RF100Benchmark()
- >>> benchmark.evaluate("path/to/data.yaml", "path/to/val_log.txt", "path/to/eval_log.txt", 0)
- """
- skip_symbols = ["🚀", "⚠️", "💡", "❌"]
- class_names = YAML.load(yaml_path)["names"]
- with open(val_log_file, encoding="utf-8") as f:
- lines = f.readlines()
- eval_lines = []
- for line in lines:
- if any(symbol in line for symbol in skip_symbols):
- continue
- entries = line.split(" ")
- entries = list(filter(lambda val: val != "", entries))
- entries = [e.strip("\n") for e in entries]
- eval_lines.extend(
- {
- "class": entries[0],
- "images": entries[1],
- "targets": entries[2],
- "precision": entries[3],
- "recall": entries[4],
- "map50": entries[5],
- "map95": entries[6],
- }
- for e in entries
- if e in class_names or (e == "all" and "(AP)" not in entries and "(AR)" not in entries)
- )
- map_val = 0.0
- if len(eval_lines) > 1:
- LOGGER.info("Multiple dicts found")
- for lst in eval_lines:
- if lst["class"] == "all":
- map_val = lst["map50"]
- else:
- LOGGER.info("Single dict found")
- map_val = [res["map50"] for res in eval_lines][0]
-
- with open(eval_log_file, "a", encoding="utf-8") as f:
- f.write(f"{self.ds_names[list_ind]}: {map_val}\n")
-
- return float(map_val)
-
-
-class ProfileModels:
- """
- ProfileModels class for profiling different models on ONNX and TensorRT.
-
- This class profiles the performance of different models, returning results such as model speed and FLOPs.
-
- Attributes:
- paths (list[str]): Paths of the models to profile.
- num_timed_runs (int): Number of timed runs for the profiling.
- num_warmup_runs (int): Number of warmup runs before profiling.
- min_time (float): Minimum number of seconds to profile for.
- imgsz (int): Image size used in the models.
- half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
- trt (bool): Flag to indicate whether to profile using TensorRT.
- device (torch.device): Device used for profiling.
-
- Methods:
- run: Profile YOLO models for speed and accuracy across various formats.
- get_files: Get all relevant model files.
- get_onnx_model_info: Extract metadata from an ONNX model.
- iterative_sigma_clipping: Apply sigma clipping to remove outliers.
- profile_tensorrt_model: Profile a TensorRT model.
- profile_onnx_model: Profile an ONNX model.
- generate_table_row: Generate a table row with model metrics.
- generate_results_dict: Generate a dictionary of profiling results.
- print_table: Print a formatted table of results.
-
- Examples:
- Profile models and print results
- >>> from ultralytics.utils.benchmarks import ProfileModels
- >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
- >>> profiler.run()
- """
-
- def __init__(
- self,
- paths: list[str],
- num_timed_runs: int = 100,
- num_warmup_runs: int = 10,
- min_time: float = 60,
- imgsz: int = 640,
- half: bool = True,
- trt: bool = True,
- device: torch.device | str | None = None,
- ):
- """
- Initialize the ProfileModels class for profiling models.
-
- Args:
- paths (list[str]): List of paths of the models to be profiled.
- num_timed_runs (int): Number of timed runs for the profiling.
- num_warmup_runs (int): Number of warmup runs before the actual profiling starts.
- min_time (float): Minimum time in seconds for profiling a model.
- imgsz (int): Size of the image used during profiling.
- half (bool): Flag to indicate whether to use FP16 half-precision for TensorRT profiling.
- trt (bool): Flag to indicate whether to profile using TensorRT.
- device (torch.device | str | None): Device used for profiling. If None, it is determined automatically.
-
- Notes:
-            The FP16 'half' option is not applied to ONNX profiling, since FP16 is slower than FP32 on CPU.
-
- Examples:
- Initialize and profile models
- >>> from ultralytics.utils.benchmarks import ProfileModels
- >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"], imgsz=640)
- >>> profiler.run()
- """
- self.paths = paths
- self.num_timed_runs = num_timed_runs
- self.num_warmup_runs = num_warmup_runs
- self.min_time = min_time
- self.imgsz = imgsz
- self.half = half
- self.trt = trt # run TensorRT profiling
- self.device = device if isinstance(device, torch.device) else select_device(device)
-
- def run(self):
- """
- Profile YOLO models for speed and accuracy across various formats including ONNX and TensorRT.
-
- Returns:
- (list[dict]): List of dictionaries containing profiling results for each model.
-
- Examples:
- Profile models and print results
- >>> from ultralytics.utils.benchmarks import ProfileModels
- >>> profiler = ProfileModels(["yolo11n.yaml", "yolov8s.yaml"])
- >>> results = profiler.run()
- """
- files = self.get_files()
-
- if not files:
- LOGGER.warning("No matching *.pt or *.onnx files found.")
- return []
-
- table_rows = []
- output = []
- for file in files:
- engine_file = file.with_suffix(".engine")
- if file.suffix in {".pt", ".yaml", ".yml"}:
- model = YOLO(str(file))
- model.fuse() # to report correct params and GFLOPs in model.info()
- model_info = model.info()
- if self.trt and self.device.type != "cpu" and not engine_file.is_file():
- engine_file = model.export(
- format="engine",
- half=self.half,
- imgsz=self.imgsz,
- device=self.device,
- verbose=False,
- )
- onnx_file = model.export(
- format="onnx",
- imgsz=self.imgsz,
- device=self.device,
- verbose=False,
- )
- elif file.suffix == ".onnx":
- model_info = self.get_onnx_model_info(file)
- onnx_file = file
- else:
- continue
-
- t_engine = self.profile_tensorrt_model(str(engine_file))
- t_onnx = self.profile_onnx_model(str(onnx_file))
- table_rows.append(self.generate_table_row(file.stem, t_onnx, t_engine, model_info))
- output.append(self.generate_results_dict(file.stem, t_onnx, t_engine, model_info))
-
- self.print_table(table_rows)
- return output
-
- def get_files(self):
- """
- Return a list of paths for all relevant model files given by the user.
-
- Returns:
- (list[Path]): List of Path objects for the model files.
- """
- files = []
- for path in self.paths:
- path = Path(path)
- if path.is_dir():
- extensions = ["*.pt", "*.onnx", "*.yaml"]
- files.extend([file for ext in extensions for file in glob.glob(str(path / ext))])
- elif path.suffix in {".pt", ".yaml", ".yml"}: # add non-existing
- files.append(str(path))
- else:
- files.extend(glob.glob(str(path)))
-
- LOGGER.info(f"Profiling: {sorted(files)}")
- return [Path(file) for file in sorted(files)]
-
- @staticmethod
- def get_onnx_model_info(onnx_file: str):
- """Extract metadata from an ONNX model file including parameters, GFLOPs, and input shape."""
- return 0.0, 0.0, 0.0, 0.0 # return (num_layers, num_params, num_gradients, num_flops)
-
- @staticmethod
- def iterative_sigma_clipping(data: np.ndarray, sigma: float = 2, max_iters: int = 3):
- """
- Apply iterative sigma clipping to data to remove outliers.
-
- Args:
- data (np.ndarray): Input data array.
- sigma (float): Number of standard deviations to use for clipping.
- max_iters (int): Maximum number of iterations for the clipping process.
-
- Returns:
- (np.ndarray): Clipped data array with outliers removed.
- """
- data = np.array(data)
- for _ in range(max_iters):
- mean, std = np.mean(data), np.std(data)
- clipped_data = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
- if len(clipped_data) == len(data):
- break
- data = clipped_data
- return data
-
- def profile_tensorrt_model(self, engine_file: str, eps: float = 1e-3):
- """
- Profile YOLO model performance with TensorRT, measuring average run time and standard deviation.
-
- Args:
- engine_file (str): Path to the TensorRT engine file.
- eps (float): Small epsilon value to prevent division by zero.
-
- Returns:
- mean_time (float): Mean inference time in milliseconds.
- std_time (float): Standard deviation of inference time in milliseconds.
- """
- if not self.trt or not Path(engine_file).is_file():
- return 0.0, 0.0
-
- # Model and input
- model = YOLO(engine_file)
- input_data = np.zeros((self.imgsz, self.imgsz, 3), dtype=np.uint8) # use uint8 for Classify
-
- # Warmup runs
- elapsed = 0.0
- for _ in range(3):
- start_time = time.time()
- for _ in range(self.num_warmup_runs):
- model(input_data, imgsz=self.imgsz, verbose=False)
- elapsed = time.time() - start_time
-
- # Compute number of runs as higher of min_time or num_timed_runs
- num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs * 50)
-
- # Timed runs
- run_times = []
- for _ in TQDM(range(num_runs), desc=engine_file):
- results = model(input_data, imgsz=self.imgsz, verbose=False)
-            run_times.append(results[0].speed["inference"])  # inference time, already in milliseconds
-
- run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=3) # sigma clipping
- return np.mean(run_times), np.std(run_times)
-
- def profile_onnx_model(self, onnx_file: str, eps: float = 1e-3):
- """
- Profile an ONNX model, measuring average inference time and standard deviation across multiple runs.
-
- Args:
- onnx_file (str): Path to the ONNX model file.
- eps (float): Small epsilon value to prevent division by zero.
-
- Returns:
- mean_time (float): Mean inference time in milliseconds.
- std_time (float): Standard deviation of inference time in milliseconds.
- """
- check_requirements("onnxruntime")
- import onnxruntime as ort
-
- # Session with either 'TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'
- sess_options = ort.SessionOptions()
- sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
- sess_options.intra_op_num_threads = 8 # Limit the number of threads
- sess = ort.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"])
-
- input_tensor = sess.get_inputs()[0]
- input_type = input_tensor.type
- dynamic = not all(isinstance(dim, int) and dim >= 0 for dim in input_tensor.shape) # dynamic input shape
- input_shape = (1, 3, self.imgsz, self.imgsz) if dynamic else input_tensor.shape
-
- # Mapping ONNX datatype to numpy datatype
- if "float16" in input_type:
- input_dtype = np.float16
- elif "float" in input_type:
- input_dtype = np.float32
- elif "double" in input_type:
- input_dtype = np.float64
- elif "int64" in input_type:
- input_dtype = np.int64
- elif "int32" in input_type:
- input_dtype = np.int32
- else:
- raise ValueError(f"Unsupported ONNX datatype {input_type}")
-
- input_data = np.random.rand(*input_shape).astype(input_dtype)
- input_name = input_tensor.name
- output_name = sess.get_outputs()[0].name
-
- # Warmup runs
- elapsed = 0.0
- for _ in range(3):
- start_time = time.time()
- for _ in range(self.num_warmup_runs):
- sess.run([output_name], {input_name: input_data})
- elapsed = time.time() - start_time
-
- # Compute number of runs as higher of min_time or num_timed_runs
- num_runs = max(round(self.min_time / (elapsed + eps) * self.num_warmup_runs), self.num_timed_runs)
-
- # Timed runs
- run_times = []
- for _ in TQDM(range(num_runs), desc=onnx_file):
- start_time = time.time()
- sess.run([output_name], {input_name: input_data})
- run_times.append((time.time() - start_time) * 1000) # Convert to milliseconds
-
- run_times = self.iterative_sigma_clipping(np.array(run_times), sigma=2, max_iters=5) # sigma clipping
- return np.mean(run_times), np.std(run_times)
-
- def generate_table_row(
- self,
- model_name: str,
- t_onnx: tuple[float, float],
- t_engine: tuple[float, float],
- model_info: tuple[float, float, float, float],
- ):
- """
- Generate a table row string with model performance metrics.
-
- Args:
- model_name (str): Name of the model.
- t_onnx (tuple): ONNX model inference time statistics (mean, std).
- t_engine (tuple): TensorRT engine inference time statistics (mean, std).
- model_info (tuple): Model information (layers, params, gradients, flops).
-
- Returns:
- (str): Formatted table row string with model metrics.
- """
- layers, params, gradients, flops = model_info
- return (
- f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
- f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
- )
-
- @staticmethod
- def generate_results_dict(
- model_name: str,
- t_onnx: tuple[float, float],
- t_engine: tuple[float, float],
- model_info: tuple[float, float, float, float],
- ):
- """
- Generate a dictionary of profiling results.
-
- Args:
- model_name (str): Name of the model.
- t_onnx (tuple): ONNX model inference time statistics (mean, std).
- t_engine (tuple): TensorRT engine inference time statistics (mean, std).
- model_info (tuple): Model information (layers, params, gradients, flops).
-
- Returns:
- (dict): Dictionary containing profiling results.
- """
- layers, params, gradients, flops = model_info
- return {
- "model/name": model_name,
- "model/parameters": params,
- "model/GFLOPs": round(flops, 3),
- "model/speed_ONNX(ms)": round(t_onnx[0], 3),
- "model/speed_TensorRT(ms)": round(t_engine[0], 3),
- }
-
- @staticmethod
- def print_table(table_rows: list[str]):
- """
- Print a formatted table of model profiling results.
-
- Args:
- table_rows (list[str]): List of formatted table row strings.
- """
- gpu = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "GPU"
- headers = [
- "Model",
- "size
(pixels)",
- "mAPval
50-95",
- f"Speed
CPU ({get_cpu_info()}) ONNX
(ms)",
- f"Speed
{gpu} TensorRT
(ms)",
- "params
(M)",
- "FLOPs
(B)",
- ]
- header = "|" + "|".join(f" {h} " for h in headers) + "|"
- separator = "|" + "|".join("-" * (len(h) + 2) for h in headers) + "|"
-
- LOGGER.info(f"\n\n{header}")
- LOGGER.info(separator)
- for row in table_rows:
- LOGGER.info(row)
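Aside on iterative_sigma_clipping() above: it repeatedly discards timings outside mean ± sigma·std until a pass removes nothing, which keeps a single slow outlier from skewing the reported mean. A stand-alone sketch with invented timings:

import numpy as np

def sigma_clip(data, sigma=2, max_iters=3):
    data = np.asarray(data, dtype=float)
    for _ in range(max_iters):
        mean, std = data.mean(), data.std()
        kept = data[(data > mean - sigma * std) & (data < mean + sigma * std)]
        if kept.size == data.size:  # converged: nothing was clipped this pass
            break
        data = kept
    return data

times_ms = [4.1, 4.0, 4.2, 4.1, 4.0, 4.2, 4.1, 25.0]  # one slow outlier
print(sigma_clip(times_ms))  # -> only the ~4 ms timings remain; 25.0 is dropped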
diff --git a/ultralytics/utils/callbacks/__init__.py b/ultralytics/utils/callbacks/__init__.py
deleted file mode 100644
index 920cc4f..0000000
--- a/ultralytics/utils/callbacks/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from .base import add_integration_callbacks, default_callbacks, get_default_callbacks
-
-__all__ = "add_integration_callbacks", "default_callbacks", "get_default_callbacks"
diff --git a/ultralytics/utils/callbacks/__pycache__/__init__.cpython-310.pyc b/ultralytics/utils/callbacks/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index b8e3ff5..0000000
Binary files a/ultralytics/utils/callbacks/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/callbacks/__pycache__/base.cpython-310.pyc b/ultralytics/utils/callbacks/__pycache__/base.cpython-310.pyc
deleted file mode 100644
index db97163..0000000
Binary files a/ultralytics/utils/callbacks/__pycache__/base.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/callbacks/__pycache__/hub.cpython-310.pyc b/ultralytics/utils/callbacks/__pycache__/hub.cpython-310.pyc
deleted file mode 100644
index 0a17e4c..0000000
Binary files a/ultralytics/utils/callbacks/__pycache__/hub.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/callbacks/__pycache__/platform.cpython-310.pyc b/ultralytics/utils/callbacks/__pycache__/platform.cpython-310.pyc
deleted file mode 100644
index e9e5217..0000000
Binary files a/ultralytics/utils/callbacks/__pycache__/platform.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/callbacks/base.py b/ultralytics/utils/callbacks/base.py
deleted file mode 100644
index 46e529b..0000000
--- a/ultralytics/utils/callbacks/base.py
+++ /dev/null
@@ -1,235 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""Base callbacks for Ultralytics training, validation, prediction, and export processes."""
-
-from collections import defaultdict
-from copy import deepcopy
-
-# Trainer callbacks ----------------------------------------------------------------------------------------------------
-
-
-def on_pretrain_routine_start(trainer):
- """Called before the pretraining routine starts."""
- pass
-
-
-def on_pretrain_routine_end(trainer):
- """Called after the pretraining routine ends."""
- pass
-
-
-def on_train_start(trainer):
- """Called when the training starts."""
- pass
-
-
-def on_train_epoch_start(trainer):
- """Called at the start of each training epoch."""
- pass
-
-
-def on_train_batch_start(trainer):
- """Called at the start of each training batch."""
- pass
-
-
-def optimizer_step(trainer):
- """Called when the optimizer takes a step."""
- pass
-
-
-def on_before_zero_grad(trainer):
- """Called before the gradients are set to zero."""
- pass
-
-
-def on_train_batch_end(trainer):
- """Called at the end of each training batch."""
- pass
-
-
-def on_train_epoch_end(trainer):
- """Called at the end of each training epoch."""
- pass
-
-
-def on_fit_epoch_end(trainer):
- """Called at the end of each fit epoch (train + val)."""
- pass
-
-
-def on_model_save(trainer):
- """Called when the model is saved."""
- pass
-
-
-def on_train_end(trainer):
- """Called when the training ends."""
- pass
-
-
-def on_params_update(trainer):
- """Called when the model parameters are updated."""
- pass
-
-
-def teardown(trainer):
- """Called during the teardown of the training process."""
- pass
-
-
-# Validator callbacks --------------------------------------------------------------------------------------------------
-
-
-def on_val_start(validator):
- """Called when the validation starts."""
- pass
-
-
-def on_val_batch_start(validator):
- """Called at the start of each validation batch."""
- pass
-
-
-def on_val_batch_end(validator):
- """Called at the end of each validation batch."""
- pass
-
-
-def on_val_end(validator):
- """Called when the validation ends."""
- pass
-
-
-# Predictor callbacks --------------------------------------------------------------------------------------------------
-
-
-def on_predict_start(predictor):
- """Called when the prediction starts."""
- pass
-
-
-def on_predict_batch_start(predictor):
- """Called at the start of each prediction batch."""
- pass
-
-
-def on_predict_batch_end(predictor):
- """Called at the end of each prediction batch."""
- pass
-
-
-def on_predict_postprocess_end(predictor):
- """Called after the post-processing of the prediction ends."""
- pass
-
-
-def on_predict_end(predictor):
- """Called when the prediction ends."""
- pass
-
-
-# Exporter callbacks ---------------------------------------------------------------------------------------------------
-
-
-def on_export_start(exporter):
- """Called when the model export starts."""
- pass
-
-
-def on_export_end(exporter):
- """Called when the model export ends."""
- pass
-
-
-default_callbacks = {
- # Run in trainer
- "on_pretrain_routine_start": [on_pretrain_routine_start],
- "on_pretrain_routine_end": [on_pretrain_routine_end],
- "on_train_start": [on_train_start],
- "on_train_epoch_start": [on_train_epoch_start],
- "on_train_batch_start": [on_train_batch_start],
- "optimizer_step": [optimizer_step],
- "on_before_zero_grad": [on_before_zero_grad],
- "on_train_batch_end": [on_train_batch_end],
- "on_train_epoch_end": [on_train_epoch_end],
- "on_fit_epoch_end": [on_fit_epoch_end], # fit = train + val
- "on_model_save": [on_model_save],
- "on_train_end": [on_train_end],
- "on_params_update": [on_params_update],
- "teardown": [teardown],
- # Run in validator
- "on_val_start": [on_val_start],
- "on_val_batch_start": [on_val_batch_start],
- "on_val_batch_end": [on_val_batch_end],
- "on_val_end": [on_val_end],
- # Run in predictor
- "on_predict_start": [on_predict_start],
- "on_predict_batch_start": [on_predict_batch_start],
- "on_predict_postprocess_end": [on_predict_postprocess_end],
- "on_predict_batch_end": [on_predict_batch_end],
- "on_predict_end": [on_predict_end],
- # Run in exporter
- "on_export_start": [on_export_start],
- "on_export_end": [on_export_end],
-}
-
-
-def get_default_callbacks():
- """
- Get the default callbacks for Ultralytics training, validation, prediction, and export processes.
-
- Returns:
- (dict): Dictionary of default callbacks for various training events. Each key represents an event during the
- training process, and the corresponding value is a list of callback functions executed when that event
- occurs.
-
- Examples:
- >>> callbacks = get_default_callbacks()
- >>> print(list(callbacks.keys())) # show all available callback events
- ['on_pretrain_routine_start', 'on_pretrain_routine_end', ...]
- """
- return defaultdict(list, deepcopy(default_callbacks))
-
-
-def add_integration_callbacks(instance):
- """
- Add integration callbacks to the instance's callbacks dictionary.
-
- This function loads and adds various integration callbacks to the provided instance. The specific callbacks added
- depend on the type of instance provided. All instances receive HUB callbacks, while Trainer instances also receive
- additional callbacks for various integrations like ClearML, Comet, DVC, MLflow, Neptune, Ray Tune, TensorBoard,
- and Weights & Biases.
-
- Args:
- instance (Trainer | Predictor | Validator | Exporter): The object instance to which callbacks will be added.
- The type of instance determines which callbacks are loaded.
-
- Examples:
- >>> from ultralytics.engine.trainer import BaseTrainer
- >>> trainer = BaseTrainer()
- >>> add_integration_callbacks(trainer)
- """
- from .hub import callbacks as hub_cb
- from .platform import callbacks as platform_cb
-
- # Load Ultralytics callbacks
- callbacks_list = [hub_cb, platform_cb]
-
- # Load training callbacks
- if "Trainer" in instance.__class__.__name__:
- from .clearml import callbacks as clear_cb
- from .comet import callbacks as comet_cb
- from .dvc import callbacks as dvc_cb
- from .mlflow import callbacks as mlflow_cb
- from .neptune import callbacks as neptune_cb
- from .raytune import callbacks as tune_cb
- from .tensorboard import callbacks as tb_cb
- from .wb import callbacks as wb_cb
-
- callbacks_list.extend([clear_cb, comet_cb, dvc_cb, mlflow_cb, neptune_cb, tune_cb, tb_cb, wb_cb])
-
- # Add the callbacks to the callbacks dictionary
- for callbacks in callbacks_list:
- for k, v in callbacks.items():
- if v not in instance.callbacks[k]:
- instance.callbacks[k].append(v)
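For context, the registry returned by get_default_callbacks() is a defaultdict of lists keyed by event name: integrations append functions under an event key, and the trainer fires everything registered for that event. A minimal sketch with a hypothetical trainer stand-in (the dict below is not the real Trainer API):

from collections import defaultdict

callbacks = defaultdict(list)  # stands in for get_default_callbacks()

def my_on_train_start(trainer):
    print(f"training started for {trainer['name']}")

callbacks["on_train_start"].append(my_on_train_start)  # what an integration does

trainer = {"name": "yolo11n"}  # hypothetical stand-in for a trainer instance
for cb in callbacks["on_train_start"]:  # fire every callback registered for this event
    cb(trainer)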
diff --git a/ultralytics/utils/callbacks/clearml.py b/ultralytics/utils/callbacks/clearml.py
deleted file mode 100644
index 446ee01..0000000
--- a/ultralytics/utils/callbacks/clearml.py
+++ /dev/null
@@ -1,154 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
-
-try:
- assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS["clearml"] is True # verify integration is enabled
- import clearml
- from clearml import Task
-
- assert hasattr(clearml, "__version__") # verify package is not directory
-
-except (ImportError, AssertionError):
- clearml = None
-
-
-def _log_debug_samples(files, title: str = "Debug Samples") -> None:
- """
- Log files (images) as debug samples in the ClearML task.
-
- Args:
- files (list[Path]): A list of file paths in PosixPath format.
- title (str): A title that groups together images with the same values.
- """
- import re
-
- if task := Task.current_task():
- for f in files:
- if f.exists():
- it = re.search(r"_batch(\d+)", f.name)
- iteration = int(it.groups()[0]) if it else 0
- task.get_logger().report_image(
- title=title, series=f.name.replace(it.group(), ""), local_path=str(f), iteration=iteration
- )
-
-
-def _log_plot(title: str, plot_path: str) -> None:
- """
- Log an image as a plot in the plot section of ClearML.
-
- Args:
- title (str): The title of the plot.
- plot_path (str): The path to the saved image file.
- """
- import matplotlib.image as mpimg
- import matplotlib.pyplot as plt
-
- img = mpimg.imread(plot_path)
- fig = plt.figure()
- ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
- ax.imshow(img)
-
- Task.current_task().get_logger().report_matplotlib_figure(
- title=title, series="", figure=fig, report_interactive=False
- )
-
-
-def on_pretrain_routine_start(trainer) -> None:
- """Initialize and connect ClearML task at the start of pretraining routine."""
- try:
- if task := Task.current_task():
- # WARNING: make sure the automatic pytorch and matplotlib bindings are disabled!
- # We are logging these plots and model files manually in the integration
- from clearml.binding.frameworks.pytorch_bind import PatchPyTorchModelIO
- from clearml.binding.matplotlib_bind import PatchedMatplotlib
-
- PatchPyTorchModelIO.update_current_task(None)
- PatchedMatplotlib.update_current_task(None)
- else:
- task = Task.init(
- project_name=trainer.args.project or "Ultralytics",
- task_name=trainer.args.name,
- tags=["Ultralytics"],
- output_uri=True,
- reuse_last_task_id=False,
- auto_connect_frameworks={"pytorch": False, "matplotlib": False},
- )
- LOGGER.warning(
- "ClearML Initialized a new task. If you want to run remotely, "
- "please add clearml-init and connect your arguments before initializing YOLO."
- )
- task.connect(vars(trainer.args), name="General")
- except Exception as e:
- LOGGER.warning(f"ClearML installed but not initialized correctly, not logging this run. {e}")
-
-
-def on_train_epoch_end(trainer) -> None:
- """Log debug samples for the first epoch and report current training progress."""
- if task := Task.current_task():
- # Log debug samples for first epoch only
- if trainer.epoch == 1:
- _log_debug_samples(sorted(trainer.save_dir.glob("train_batch*.jpg")), "Mosaic")
- # Report the current training progress
- for k, v in trainer.label_loss_items(trainer.tloss, prefix="train").items():
- task.get_logger().report_scalar("train", k, v, iteration=trainer.epoch)
- for k, v in trainer.lr.items():
- task.get_logger().report_scalar("lr", k, v, iteration=trainer.epoch)
-
-
-def on_fit_epoch_end(trainer) -> None:
- """Report model information and metrics to logger at the end of an epoch."""
- if task := Task.current_task():
- # Report epoch time and validation metrics
- task.get_logger().report_scalar(
- title="Epoch Time", series="Epoch Time", value=trainer.epoch_time, iteration=trainer.epoch
- )
- for k, v in trainer.metrics.items():
- title = k.split("/")[0]
- task.get_logger().report_scalar(title, k, v, iteration=trainer.epoch)
- if trainer.epoch == 0:
- from ultralytics.utils.torch_utils import model_info_for_loggers
-
- for k, v in model_info_for_loggers(trainer).items():
- task.get_logger().report_single_value(k, v)
-
-
-def on_val_end(validator) -> None:
- """Log validation results including labels and predictions."""
- if Task.current_task():
- # Log validation labels and predictions
- _log_debug_samples(sorted(validator.save_dir.glob("val*.jpg")), "Validation")
-
-
-def on_train_end(trainer) -> None:
- """Log final model and training results on training completion."""
- if task := Task.current_task():
- # Log final results, confusion matrix and PR plots
- files = [
- "results.png",
- "confusion_matrix.png",
- "confusion_matrix_normalized.png",
- *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
- ]
- files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter existing files
- for f in files:
- _log_plot(title=f.stem, plot_path=f)
- # Report final metrics
- for k, v in trainer.validator.metrics.results_dict.items():
- task.get_logger().report_single_value(k, v)
- # Log the final model
- task.update_output_model(model_path=str(trainer.best), model_name=trainer.args.name, auto_delete_file=False)
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_train_epoch_end": on_train_epoch_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_val_end": on_val_end,
- "on_train_end": on_train_end,
- }
- if clearml
- else {}
-)
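Small aside on _log_debug_samples() above: the batch iteration is recovered from the image filename with a regex, and the matched fragment is stripped to form the series name. A stand-alone sketch with an example filename:

import re

name = "train_batch2.jpg"  # example debug-image filename
it = re.search(r"_batch(\d+)", name)
iteration = int(it.groups()[0]) if it else 0            # -> 2
series = name.replace(it.group(), "") if it else name   # -> "train.jpg"
print(iteration, series)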
diff --git a/ultralytics/utils/callbacks/comet.py b/ultralytics/utils/callbacks/comet.py
deleted file mode 100644
index f094113..0000000
--- a/ultralytics/utils/callbacks/comet.py
+++ /dev/null
@@ -1,639 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from collections.abc import Callable
-from types import SimpleNamespace
-from typing import Any
-
-import cv2
-import numpy as np
-
-from ultralytics.utils import LOGGER, RANK, SETTINGS, TESTS_RUNNING, ops
-from ultralytics.utils.metrics import ClassifyMetrics, DetMetrics, OBBMetrics, PoseMetrics, SegmentMetrics
-
-try:
- assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS["comet"] is True # verify integration is enabled
- import comet_ml
-
- assert hasattr(comet_ml, "__version__") # verify package is not directory
-
- import os
- from pathlib import Path
-
- # Ensures certain logging functions only run for supported tasks
- COMET_SUPPORTED_TASKS = ["detect", "segment"]
-
- # Names of plots created by Ultralytics that are logged to Comet
- CONFUSION_MATRIX_PLOT_NAMES = "confusion_matrix", "confusion_matrix_normalized"
- EVALUATION_PLOT_NAMES = "F1_curve", "P_curve", "R_curve", "PR_curve"
- LABEL_PLOT_NAMES = ["labels"]
- SEGMENT_METRICS_PLOT_PREFIX = "Box", "Mask"
- POSE_METRICS_PLOT_PREFIX = "Box", "Pose"
- DETECTION_METRICS_PLOT_PREFIX = ["Box"]
- RESULTS_TABLE_NAME = "results.csv"
- ARGS_YAML_NAME = "args.yaml"
-
- _comet_image_prediction_count = 0
-
-except (ImportError, AssertionError):
- comet_ml = None
-
-
-def _get_comet_mode() -> str:
- """Return the Comet mode from environment variables, defaulting to 'online'."""
- comet_mode = os.getenv("COMET_MODE")
- if comet_mode is not None:
- LOGGER.warning(
- "The COMET_MODE environment variable is deprecated. "
- "Please use COMET_START_ONLINE to set the Comet experiment mode. "
- "To start an offline Comet experiment, use 'export COMET_START_ONLINE=0'. "
- "If COMET_START_ONLINE is not set or is set to '1', an online Comet experiment will be created."
- )
- return comet_mode
-
- return "online"
-
-
-def _get_comet_model_name() -> str:
- """Return the Comet model name from environment variable or default to 'Ultralytics'."""
- return os.getenv("COMET_MODEL_NAME", "Ultralytics")
-
-
-def _get_eval_batch_logging_interval() -> int:
- """Get the evaluation batch logging interval from environment variable or use default value 1."""
- return int(os.getenv("COMET_EVAL_BATCH_LOGGING_INTERVAL", 1))
-
-
-def _get_max_image_predictions_to_log() -> int:
- """Get the maximum number of image predictions to log from environment variables."""
- return int(os.getenv("COMET_MAX_IMAGE_PREDICTIONS", 100))
-
-
-def _scale_confidence_score(score: float) -> float:
- """Scale the confidence score by a factor specified in environment variable."""
- scale = float(os.getenv("COMET_MAX_CONFIDENCE_SCORE", 100.0))
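- # e.g. with the default COMET_MAX_CONFIDENCE_SCORE=100.0, a raw confidence of 0.85 is logged as 85.0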
- return score * scale
-
-
-def _should_log_confusion_matrix() -> bool:
- """Determine if the confusion matrix should be logged based on environment variable settings."""
- return os.getenv("COMET_EVAL_LOG_CONFUSION_MATRIX", "false").lower() == "true"
-
-
-def _should_log_image_predictions() -> bool:
- """Determine whether to log image predictions based on environment variable."""
- return os.getenv("COMET_EVAL_LOG_IMAGE_PREDICTIONS", "true").lower() == "true"
-
-
-def _resume_or_create_experiment(args: SimpleNamespace) -> None:
- """
- Resume CometML experiment or create a new experiment based on args.
-
- Ensures that the experiment object is only created in a single process during distributed training.
-
- Args:
- args (SimpleNamespace): Training arguments containing project configuration and other parameters.
- """
- if RANK not in {-1, 0}:
- return
-
- # Set environment variable (if not set by the user) to configure the Comet experiment's online mode under the hood.
- # If COMET_START_ONLINE is set by the user, it overrides the COMET_MODE value.
- if os.getenv("COMET_START_ONLINE") is None:
- comet_mode = _get_comet_mode()
- os.environ["COMET_START_ONLINE"] = "1" if comet_mode != "offline" else "0"
-
- try:
- _project_name = os.getenv("COMET_PROJECT_NAME", args.project)
- experiment = comet_ml.start(project_name=_project_name)
- experiment.log_parameters(vars(args))
- experiment.log_others(
- {
- "eval_batch_logging_interval": _get_eval_batch_logging_interval(),
- "log_confusion_matrix_on_eval": _should_log_confusion_matrix(),
- "log_image_predictions": _should_log_image_predictions(),
- "max_image_predictions": _get_max_image_predictions_to_log(),
- }
- )
- experiment.log_other("Created from", "ultralytics")
-
- except Exception as e:
- LOGGER.warning(f"Comet installed but not initialized correctly, not logging this run. {e}")
-
-
-def _fetch_trainer_metadata(trainer) -> dict:
- """
- Return metadata for YOLO training including epoch and asset saving status.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The YOLO trainer object containing training state and config.
-
- Returns:
- (dict): Dictionary containing current epoch, step, save assets flag, and final epoch flag.
- """
- curr_epoch = trainer.epoch + 1
-
- train_num_steps_per_epoch = len(trainer.train_loader.dataset) // trainer.batch_size
- curr_step = curr_epoch * train_num_steps_per_epoch
- final_epoch = curr_epoch == trainer.epochs
-
- save = trainer.args.save
- save_period = trainer.args.save_period
- save_interval = curr_epoch % save_period == 0
- save_assets = save and save_period > 0 and save_interval and not final_epoch
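- # Illustrative example: with save=True, save_period=5, epochs=100 and trainer.epoch=9, curr_epoch is 10 and
- # 10 % 5 == 0, so save_assets is True and model assets are logged for this epoch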
-
- return dict(curr_epoch=curr_epoch, curr_step=curr_step, save_assets=save_assets, final_epoch=final_epoch)
-
-
-def _scale_bounding_box_to_original_image_shape(
- box, resized_image_shape, original_image_shape, ratio_pad
-) -> list[float]:
- """
- Scale bounding box from resized image coordinates to original image coordinates.
-
- YOLO resizes images during training and the label values are normalized based on this resized shape.
- This function rescales the bounding box labels to the original image shape.
-
- Args:
- box (torch.Tensor): Bounding box in normalized xywh format.
- resized_image_shape (tuple): Shape of the resized image (height, width).
- original_image_shape (tuple): Shape of the original image (height, width).
- ratio_pad (tuple): Ratio and padding information for scaling.
-
- Returns:
- (list[float]): Scaled bounding box coordinates in xywh format with top-left corner adjustment.
- """
- resized_image_height, resized_image_width = resized_image_shape
-
- # Convert normalized xywh format predictions to xyxy in resized scale format
- box = ops.xywhn2xyxy(box, h=resized_image_height, w=resized_image_width)
- # Scale box predictions from resized image scale back to original image scale
- box = ops.scale_boxes(resized_image_shape, box, original_image_shape, ratio_pad)
- # Convert bounding box format from xyxy to xywh for Comet logging
- box = ops.xyxy2xywh(box)
- # Adjust xy center to correspond top-left corner
- box[:2] -= box[2:] / 2
- box = box.tolist()
-
- return box
-
-
-def _format_ground_truth_annotations_for_detection(img_idx, image_path, batch, class_name_map=None) -> dict | None:
- """
- Format ground truth annotations for object detection.
-
- This function processes ground truth annotations from a batch of images for object detection tasks. It extracts
- bounding boxes, class labels, and other metadata for a specific image in the batch, and formats them for
- visualization or evaluation.
-
- Args:
- img_idx (int): Index of the image in the batch to process.
- image_path (str | Path): Path to the image file.
- batch (dict): Batch dictionary containing detection data with keys:
- - 'batch_idx': Tensor of batch indices
- - 'bboxes': Tensor of bounding boxes in normalized xywh format
- - 'cls': Tensor of class labels
- - 'ori_shape': Original image shapes
- - 'resized_shape': Resized image shapes
- - 'ratio_pad': Ratio and padding information
- class_name_map (dict, optional): Mapping from class indices to class names.
-
- Returns:
- (dict | None): Formatted ground truth annotations with the following structure:
- - 'boxes': List of box coordinates [x, y, width, height]
- - 'label': Label string with format "gt_{class_name}"
- - 'score': Confidence score (always 1.0, scaled by _scale_confidence_score)
- Returns None if no bounding boxes are found for the image.
- """
- indices = batch["batch_idx"] == img_idx
- bboxes = batch["bboxes"][indices]
- if len(bboxes) == 0:
- LOGGER.debug(f"Comet Image: {image_path} has no bounding boxes labels")
- return None
-
- cls_labels = batch["cls"][indices].squeeze(1).tolist()
- if class_name_map:
- cls_labels = [str(class_name_map[label]) for label in cls_labels]
-
- original_image_shape = batch["ori_shape"][img_idx]
- resized_image_shape = batch["resized_shape"][img_idx]
- ratio_pad = batch["ratio_pad"][img_idx]
-
- data = []
- for box, label in zip(bboxes, cls_labels):
- box = _scale_bounding_box_to_original_image_shape(box, resized_image_shape, original_image_shape, ratio_pad)
- data.append(
- {
- "boxes": [box],
- "label": f"gt_{label}",
- "score": _scale_confidence_score(1.0),
- }
- )
-
- return {"name": "ground_truth", "data": data}
-
-
-def _format_prediction_annotations(image_path, metadata, class_label_map=None, class_map=None) -> dict | None:
- """
- Format YOLO predictions for object detection visualization.
-
- Args:
- image_path (Path): Path to the image file.
- metadata (dict): Prediction metadata containing bounding boxes and class information.
- class_label_map (dict, optional): Mapping from class indices to class names.
- class_map (dict, optional): Additional class mapping for label conversion.
-
- Returns:
- (dict | None): Formatted prediction annotations or None if no predictions exist.
- """
- stem = image_path.stem
- image_id = int(stem) if stem.isnumeric() else stem
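- # e.g. a COCO-style filename "000000397133.jpg" yields the integer id 397133, while "zidane.jpg" keeps the string "zidane"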
-
- predictions = metadata.get(image_id)
- if not predictions:
- LOGGER.debug(f"Comet Image: {image_path} has no bounding boxes predictions")
- return None
-
- # apply the class mapping that was used for the predicted classes when the JSON was created
- if class_label_map and class_map:
- class_label_map = {class_map[k]: v for k, v in class_label_map.items()}
- try:
- # import pycocotools-style utilities (from faster_coco_eval) to decode compressed annotations, e.g. segmentation masks
- from faster_coco_eval.core.mask import decode # noqa
- except ImportError:
- decode = None
-
- data = []
- for prediction in predictions:
- boxes = prediction["bbox"]
- score = _scale_confidence_score(prediction["score"])
- cls_label = prediction["category_id"]
- if class_label_map:
- cls_label = str(class_label_map[cls_label])
-
- annotation_data = {"boxes": [boxes], "label": cls_label, "score": score}
-
- if decode is not None:
- # do segmentation processing only if we are able to decode it
- segments = prediction.get("segmentation", None)
- if segments is not None:
- segments = _extract_segmentation_annotation(segments, decode)
- if segments is not None:
- annotation_data["points"] = segments
-
- data.append(annotation_data)
-
- return {"name": "prediction", "data": data}
-
-
-def _extract_segmentation_annotation(segmentation_raw: str, decode: Callable) -> list[list[Any]] | None:
- """
- Extract segmentation annotation from compressed segmentations as list of polygons.
-
- Args:
- segmentation_raw (str): Raw segmentation data in compressed format.
- decode (Callable): Function to decode the compressed segmentation data.
-
- Returns:
- (list[list[Any]] | None): List of polygon points or None if extraction fails.
- """
- try:
- mask = decode(segmentation_raw)
- contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
- annotations = [np.array(polygon).squeeze() for polygon in contours if len(polygon) >= 3]
- return [annotation.ravel().tolist() for annotation in annotations]
- except Exception as e:
- LOGGER.warning(f"Comet Failed to extract segmentation annotation: {e}")
- return None
-
-
-def _fetch_annotations(img_idx, image_path, batch, prediction_metadata_map, class_label_map, class_map) -> list | None:
- """
- Join the ground truth and prediction annotations if they exist.
-
- Args:
- img_idx (int): Index of the image in the batch.
- image_path (Path): Path to the image file.
- batch (dict): Batch data containing ground truth annotations.
- prediction_metadata_map (dict): Map of prediction metadata by image ID.
- class_label_map (dict): Mapping from class indices to class names.
- class_map (dict): Additional class mapping for label conversion.
-
- Returns:
- (list | None): List of annotation dictionaries or None if no annotations exist.
- """
- ground_truth_annotations = _format_ground_truth_annotations_for_detection(
- img_idx, image_path, batch, class_label_map
- )
- prediction_annotations = _format_prediction_annotations(
- image_path, prediction_metadata_map, class_label_map, class_map
- )
-
- annotations = [
- annotation for annotation in [ground_truth_annotations, prediction_annotations] if annotation is not None
- ]
- return [annotations] if annotations else None
-
-
-def _create_prediction_metadata_map(model_predictions) -> dict:
- """Create metadata map for model predictions by grouping them based on image ID."""
- pred_metadata_map = {}
- for prediction in model_predictions:
- pred_metadata_map.setdefault(prediction["image_id"], [])
- pred_metadata_map[prediction["image_id"]].append(prediction)
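- # e.g. two predictions sharing image_id 3 are grouped as {3: [pred_a, pred_b]}, so later per-image lookup is a single dict access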
-
- return pred_metadata_map
-
-
-def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:
- """Log the confusion matrix to Comet experiment."""
- conf_mat = trainer.validator.confusion_matrix.matrix
- names = list(trainer.data["names"].values()) + ["background"]
- experiment.log_confusion_matrix(
- matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
- )
-
-
-def _log_images(experiment, image_paths, curr_step: int | None, annotations=None) -> None:
- """
- Log images to the experiment with optional annotations.
-
- This function logs images to a Comet ML experiment, optionally including annotation data for visualization
- such as bounding boxes or segmentation masks.
-
- Args:
- experiment (comet_ml.CometExperiment): The Comet ML experiment to log images to.
- image_paths (list[Path]): List of paths to images that will be logged.
- curr_step (int): Current training step/iteration for tracking in the experiment timeline.
- annotations (list[list[dict]], optional): Nested list of annotation dictionaries for each image. Each
- annotation contains visualization data like bounding boxes, labels, and confidence scores.
- """
- if annotations:
- for image_path, annotation in zip(image_paths, annotations):
- experiment.log_image(image_path, name=image_path.stem, step=curr_step, annotations=annotation)
-
- else:
- for image_path in image_paths:
- experiment.log_image(image_path, name=image_path.stem, step=curr_step)
-
-
-def _log_image_predictions(experiment, validator, curr_step) -> None:
- """
- Log ground-truth and predicted boxes for validation images during training.
-
- This function logs image predictions to a Comet ML experiment during model validation. It processes
- validation data and formats both ground truth and prediction annotations for visualization in the Comet
- dashboard. The function respects configured limits on the number of images to log.
-
- Args:
- experiment (comet_ml.CometExperiment): The Comet ML experiment to log to.
- validator (BaseValidator): The validator instance containing validation data and predictions.
- curr_step (int): The current training step for logging timeline.
-
- Notes:
- This function uses global state to track the number of logged predictions across calls.
- It only logs predictions for supported tasks defined in COMET_SUPPORTED_TASKS.
- The number of logged images is limited by the COMET_MAX_IMAGE_PREDICTIONS environment variable.
- """
- global _comet_image_prediction_count
-
- task = validator.args.task
- if task not in COMET_SUPPORTED_TASKS:
- return
-
- jdict = validator.jdict
- if not jdict:
- return
-
- predictions_metadata_map = _create_prediction_metadata_map(jdict)
- dataloader = validator.dataloader
- class_label_map = validator.names
- class_map = getattr(validator, "class_map", None)
-
- batch_logging_interval = _get_eval_batch_logging_interval()
- max_image_predictions = _get_max_image_predictions_to_log()
-
- for batch_idx, batch in enumerate(dataloader):
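- # e.g. with COMET_EVAL_BATCH_LOGGING_INTERVAL=4, only every 4th validation batch (batch_idx 3, 7, 11, ...) passes this check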
- if (batch_idx + 1) % batch_logging_interval != 0:
- continue
-
- image_paths = batch["im_file"]
- for img_idx, image_path in enumerate(image_paths):
- if _comet_image_prediction_count >= max_image_predictions:
- return
-
- image_path = Path(image_path)
- annotations = _fetch_annotations(
- img_idx,
- image_path,
- batch,
- predictions_metadata_map,
- class_label_map,
- class_map=class_map,
- )
- _log_images(
- experiment,
- [image_path],
- curr_step,
- annotations=annotations,
- )
- _comet_image_prediction_count += 1
-
-
-def _log_plots(experiment, trainer) -> None:
- """
- Log evaluation plots and label plots for the experiment.
-
- This function logs various evaluation plots and confusion matrices to the experiment tracking system. It handles
- different types of metrics (SegmentMetrics, PoseMetrics, DetMetrics, OBBMetrics) and logs the appropriate plots
- for each type.
-
- Args:
- experiment (comet_ml.CometExperiment): The Comet ML experiment to log plots to.
- trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing validation metrics and save
- directory information.
-
- Examples:
- >>> from ultralytics.utils.callbacks.comet import _log_plots
- >>> _log_plots(experiment, trainer)
- """
- plot_filenames = None
- if isinstance(trainer.validator.metrics, SegmentMetrics):
- plot_filenames = [
- trainer.save_dir / f"{prefix}{plots}.png"
- for plots in EVALUATION_PLOT_NAMES
- for prefix in SEGMENT_METRICS_PLOT_PREFIX
- ]
- elif isinstance(trainer.validator.metrics, PoseMetrics):
- plot_filenames = [
- trainer.save_dir / f"{prefix}{plots}.png"
- for plots in EVALUATION_PLOT_NAMES
- for prefix in POSE_METRICS_PLOT_PREFIX
- ]
- elif isinstance(trainer.validator.metrics, (DetMetrics, OBBMetrics)):
- plot_filenames = [
- trainer.save_dir / f"{prefix}{plots}.png"
- for plots in EVALUATION_PLOT_NAMES
- for prefix in DETECTION_METRICS_PLOT_PREFIX
- ]
-
- if plot_filenames is not None:
- _log_images(experiment, plot_filenames, None)
-
- confusion_matrix_filenames = [trainer.save_dir / f"{plots}.png" for plots in CONFUSION_MATRIX_PLOT_NAMES]
- _log_images(experiment, confusion_matrix_filenames, None)
-
- if not isinstance(trainer.validator.metrics, ClassifyMetrics):
- label_plot_filenames = [trainer.save_dir / f"{labels}.jpg" for labels in LABEL_PLOT_NAMES]
- _log_images(experiment, label_plot_filenames, None)
-
-
-def _log_model(experiment, trainer) -> None:
- """Log the best-trained model to Comet.ml."""
- model_name = _get_comet_model_name()
- experiment.log_model(model_name, file_or_folder=str(trainer.best), file_name="best.pt", overwrite=True)
-
-
-def _log_image_batches(experiment, trainer, curr_step: int) -> None:
- """Log samples of image batches for train, validation, and test."""
- _log_images(experiment, trainer.save_dir.glob("train_batch*.jpg"), curr_step)
- _log_images(experiment, trainer.save_dir.glob("val_batch*.jpg"), curr_step)
-
-
-def _log_asset(experiment, asset_path) -> None:
- """
- Log a single asset file (e.g. a YAML config) to the given Comet experiment.
-
- Args:
- experiment (comet_ml.CometExperiment): The experiment instance to which the asset will be logged.
- asset_path (Path): The file path of the asset to log.
- """
- experiment.log_asset(asset_path)
-
-
-def _log_table(experiment, table_path) -> None:
- """
- Log a table file (e.g. results.csv) to the given Comet experiment.
-
- Args:
- experiment (comet_ml.CometExperiment): The experiment object where the table file will be logged.
- table_path (Path): The file path of the table to be logged.
- """
- experiment.log_table(str(table_path))
-
-
-def on_pretrain_routine_start(trainer) -> None:
- """Create or resume a CometML experiment at the start of a YOLO pre-training routine."""
- _resume_or_create_experiment(trainer.args)
-
-
-def on_train_epoch_end(trainer) -> None:
- """Log metrics and save batch images at the end of training epochs."""
- experiment = comet_ml.get_running_experiment()
- if not experiment:
- return
-
- metadata = _fetch_trainer_metadata(trainer)
- curr_epoch = metadata["curr_epoch"]
- curr_step = metadata["curr_step"]
-
- experiment.log_metrics(trainer.label_loss_items(trainer.tloss, prefix="train"), step=curr_step, epoch=curr_epoch)
-
-
-def on_fit_epoch_end(trainer) -> None:
- """
- Log model assets at the end of each epoch during training.
-
- This function is called at the end of each training epoch to log metrics, learning rates, and model information
- to a Comet ML experiment. It also logs model assets, confusion matrices, and image predictions based on
- configuration settings.
-
- The function retrieves the current Comet ML experiment and logs various training metrics. If it's the first epoch,
- it also logs model information. On specified save intervals, it logs the model, confusion matrix (if enabled),
- and image predictions (if enabled).
-
- Args:
- trainer (BaseTrainer): The YOLO trainer object containing training state, metrics, and configuration.
-
- Examples:
- >>> # Inside a training loop
- >>> on_fit_epoch_end(trainer) # Log metrics and assets to Comet ML
- """
- experiment = comet_ml.get_running_experiment()
- if not experiment:
- return
-
- metadata = _fetch_trainer_metadata(trainer)
- curr_epoch = metadata["curr_epoch"]
- curr_step = metadata["curr_step"]
- save_assets = metadata["save_assets"]
-
- experiment.log_metrics(trainer.metrics, step=curr_step, epoch=curr_epoch)
- experiment.log_metrics(trainer.lr, step=curr_step, epoch=curr_epoch)
- if curr_epoch == 1:
- from ultralytics.utils.torch_utils import model_info_for_loggers
-
- experiment.log_metrics(model_info_for_loggers(trainer), step=curr_step, epoch=curr_epoch)
-
- if not save_assets:
- return
-
- _log_model(experiment, trainer)
- if _should_log_confusion_matrix():
- _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
- if _should_log_image_predictions():
- _log_image_predictions(experiment, trainer.validator, curr_step)
-
-
-def on_train_end(trainer) -> None:
- """Perform operations at the end of training."""
- experiment = comet_ml.get_running_experiment()
- if not experiment:
- return
-
- metadata = _fetch_trainer_metadata(trainer)
- curr_epoch = metadata["curr_epoch"]
- curr_step = metadata["curr_step"]
- plots = trainer.args.plots
-
- _log_model(experiment, trainer)
- if plots:
- _log_plots(experiment, trainer)
-
- _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch)
- _log_image_predictions(experiment, trainer.validator, curr_step)
- _log_image_batches(experiment, trainer, curr_step)
- # log results table
- table_path = trainer.save_dir / RESULTS_TABLE_NAME
- if table_path.exists():
- _log_table(experiment, table_path)
-
- # log arguments YAML
- args_path = trainer.save_dir / ARGS_YAML_NAME
- if args_path.exists():
- _log_asset(experiment, args_path)
-
- experiment.end()
-
- global _comet_image_prediction_count
- _comet_image_prediction_count = 0
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_train_epoch_end": on_train_epoch_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_train_end": on_train_end,
- }
- if comet_ml
- else {}
-)
diff --git a/ultralytics/utils/callbacks/dvc.py b/ultralytics/utils/callbacks/dvc.py
deleted file mode 100644
index 35a16d7..0000000
--- a/ultralytics/utils/callbacks/dvc.py
+++ /dev/null
@@ -1,202 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from pathlib import Path
-
-from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, checks
-
-try:
- assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS["dvc"] is True # verify integration is enabled
- import dvclive
-
- assert checks.check_version("dvclive", "2.11.0", verbose=True)
-
- import os
- import re
-
- # DVCLive logger instance
- live = None
- _processed_plots = {}
-
- # `on_fit_epoch_end` is also called on the final validation (this probably needs to be fixed); for now this flag
- # is how we distinguish the final evaluation of the best model from last-epoch validation
- _training_epoch = False
-
-except (ImportError, AssertionError, TypeError):
- dvclive = None
-
-
-def _log_images(path: Path, prefix: str = "") -> None:
- """
- Log images at specified path with an optional prefix using DVCLive.
-
- This function logs images found at the given path to DVCLive, organizing them by batch to enable slider
- functionality in the UI. It processes image filenames to extract batch information and restructures the path
- accordingly.
-
- Args:
- path (Path): Path to the image file to be logged.
- prefix (str, optional): Optional prefix to add to the image name when logging.
-
- Examples:
- >>> from pathlib import Path
- >>> _log_images(Path("runs/train/exp/val_batch0_pred.jpg"), prefix="validation")
- """
- if live:
- name = path.name
-
- # Group images by batch to enable sliders in UI
- if m := re.search(r"_batch(\d+)", name):
- ni = m[1]
- new_stem = re.sub(r"_batch(\d+)", "_batch", path.stem)
- name = (Path(new_stem) / ni).with_suffix(path.suffix)
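- # e.g. "val_batch0_pred.jpg" is logged as "val_batch_pred/0.jpg", so DVCLive renders one slider entry per batch index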
-
- live.log_image(os.path.join(prefix, name), path)
-
-
-def _log_plots(plots: dict, prefix: str = "") -> None:
- """
- Log plot images for training progress if they have not been previously processed.
-
- Args:
- plots (dict): Dictionary containing plot information with timestamps.
- prefix (str, optional): Optional prefix to add to the logged image paths.
- """
- for name, params in plots.items():
- timestamp = params["timestamp"]
- if _processed_plots.get(name) != timestamp:
- _log_images(name, prefix)
- _processed_plots[name] = timestamp
-
-
-def _log_confusion_matrix(validator) -> None:
- """
- Log confusion matrix for a validator using DVCLive.
-
- This function processes the confusion matrix from a validator object and logs it to DVCLive by converting
- the matrix into lists of target and prediction labels.
-
- Args:
- validator (BaseValidator): The validator object containing the confusion matrix and class names. Must have
- attributes: confusion_matrix.matrix, confusion_matrix.task, and names.
- """
- targets = []
- preds = []
- matrix = validator.confusion_matrix.matrix
- names = list(validator.names.values())
- if validator.confusion_matrix.task == "detect":
- names += ["background"]
-
- for ti, pred in enumerate(matrix.T.astype(int)):
- for pi, num in enumerate(pred):
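- # expand each cell count into that many (target, prediction) label pairs for the sklearn-style plot below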
- targets.extend([names[ti]] * num)
- preds.extend([names[pi]] * num)
-
- live.log_sklearn_plot("confusion_matrix", targets, preds, name="cf.json", normalized=True)
-
-
-def on_pretrain_routine_start(trainer) -> None:
- """Initialize DVCLive logger for training metadata during pre-training routine."""
- try:
- global live
- live = dvclive.Live(save_dvc_exp=True, cache_images=True)
- LOGGER.info("DVCLive is detected and auto logging is enabled (run 'yolo settings dvc=False' to disable).")
- except Exception as e:
- LOGGER.warning(f"DVCLive installed but not initialized correctly, not logging this run. {e}")
-
-
-def on_pretrain_routine_end(trainer) -> None:
- """Log plots related to the training process at the end of the pretraining routine."""
- _log_plots(trainer.plots, "train")
-
-
-def on_train_start(trainer) -> None:
- """Log the training parameters if DVCLive logging is active."""
- if live:
- live.log_params(trainer.args)
-
-
-def on_train_epoch_start(trainer) -> None:
- """Set the global variable _training_epoch value to True at the start of training each epoch."""
- global _training_epoch
- _training_epoch = True
-
-
-def on_fit_epoch_end(trainer) -> None:
- """
- Log training metrics, model info, and advance to next step at the end of each fit epoch.
-
- This function is called at the end of each fit epoch during training. It logs various metrics including
- training loss items, validation metrics, and learning rates. On the first epoch, it also logs model
- information. Additionally, it logs training and validation plots and advances the DVCLive step counter.
-
- Args:
- trainer (BaseTrainer): The trainer object containing training state, metrics, and plots.
-
- Notes:
- This function only performs logging operations when DVCLive logging is active and during a training epoch.
- The global variable _training_epoch is used to track whether the current epoch is a training epoch.
- """
- global _training_epoch
- if live and _training_epoch:
- all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
- for metric, value in all_metrics.items():
- live.log_metric(metric, value)
-
- if trainer.epoch == 0:
- from ultralytics.utils.torch_utils import model_info_for_loggers
-
- for metric, value in model_info_for_loggers(trainer).items():
- live.log_metric(metric, value, plot=False)
-
- _log_plots(trainer.plots, "train")
- _log_plots(trainer.validator.plots, "val")
-
- live.next_step()
- _training_epoch = False
-
-
-def on_train_end(trainer) -> None:
- """
- Log best metrics, plots, and confusion matrix at the end of training.
-
- This function is called at the conclusion of the training process to log final metrics, visualizations, and
- model artifacts if DVCLive logging is active. It captures the best model performance metrics, training plots,
- validation plots, and confusion matrix for later analysis.
-
- Args:
- trainer (BaseTrainer): The trainer object containing training state, metrics, and validation results.
-
- Examples:
- >>> # Inside a custom training loop
- >>> from ultralytics.utils.callbacks.dvc import on_train_end
- >>> on_train_end(trainer) # Log final metrics and artifacts
- """
- if live:
- # At the end, log the best metrics; the trainer runs the validator on the best model internally.
- all_metrics = {**trainer.label_loss_items(trainer.tloss, prefix="train"), **trainer.metrics, **trainer.lr}
- for metric, value in all_metrics.items():
- live.log_metric(metric, value, plot=False)
-
- _log_plots(trainer.plots, "val")
- _log_plots(trainer.validator.plots, "val")
- _log_confusion_matrix(trainer.validator)
-
- if trainer.best.exists():
- live.log_artifact(trainer.best, copy=True, type="model")
-
- live.end()
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_pretrain_routine_end": on_pretrain_routine_end,
- "on_train_start": on_train_start,
- "on_train_epoch_start": on_train_epoch_start,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_train_end": on_train_end,
- }
- if dvclive
- else {}
-)
diff --git a/ultralytics/utils/callbacks/hub.py b/ultralytics/utils/callbacks/hub.py
deleted file mode 100644
index 2b57cd1..0000000
--- a/ultralytics/utils/callbacks/hub.py
+++ /dev/null
@@ -1,110 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import json
-from time import time
-
-from ultralytics.hub import HUB_WEB_ROOT, PREFIX, HUBTrainingSession
-from ultralytics.utils import LOGGER, RANK, SETTINGS
-from ultralytics.utils.events import events
-
-
-def on_pretrain_routine_start(trainer):
- """Create a remote Ultralytics HUB session to log local model training."""
- if RANK in {-1, 0} and SETTINGS["hub"] is True and SETTINGS["api_key"] and trainer.hub_session is None:
- trainer.hub_session = HUBTrainingSession.create_session(trainer.args.model, trainer.args)
-
-
-def on_pretrain_routine_end(trainer):
- """Initialize timers for upload rate limiting before training begins."""
- if session := getattr(trainer, "hub_session", None):
- # Start timer for upload rate limit
- session.timers = {"metrics": time(), "ckpt": time()} # start timer for session rate limiting
-
-
-def on_fit_epoch_end(trainer):
- """Upload training progress metrics to Ultralytics HUB at the end of each epoch."""
- if session := getattr(trainer, "hub_session", None):
- # Upload metrics after validation ends
- all_plots = {
- **trainer.label_loss_items(trainer.tloss, prefix="train"),
- **trainer.metrics,
- }
- if trainer.epoch == 0:
- from ultralytics.utils.torch_utils import model_info_for_loggers
-
- all_plots = {**all_plots, **model_info_for_loggers(trainer)}
-
- session.metrics_queue[trainer.epoch] = json.dumps(all_plots)
-
- # If any metrics failed to upload previously, add them to the queue to attempt uploading again
- if session.metrics_upload_failed_queue:
- session.metrics_queue.update(session.metrics_upload_failed_queue)
-
- if time() - session.timers["metrics"] > session.rate_limits["metrics"]:
- session.upload_metrics()
- session.timers["metrics"] = time() # reset timer
- session.metrics_queue = {} # reset queue
-
-
-def on_model_save(trainer):
- """Upload model checkpoints to Ultralytics HUB with rate limiting."""
- if session := getattr(trainer, "hub_session", None):
- # Upload checkpoints with rate limiting
- is_best = trainer.best_fitness == trainer.fitness
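- # rate_limits["ckpt"] is in seconds; checkpoints produced more frequently than this interval are skipped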
- if time() - session.timers["ckpt"] > session.rate_limits["ckpt"]:
- LOGGER.info(f"{PREFIX}Uploading checkpoint {HUB_WEB_ROOT}/models/{session.model.id}")
- session.upload_model(trainer.epoch, trainer.last, is_best)
- session.timers["ckpt"] = time() # reset timer
-
-
-def on_train_end(trainer):
- """Upload final model and metrics to Ultralytics HUB at the end of training."""
- if session := getattr(trainer, "hub_session", None):
- # Upload final model and metrics with exponential backoff
- LOGGER.info(f"{PREFIX}Syncing final model...")
- session.upload_model(
- trainer.epoch,
- trainer.best,
- map=trainer.metrics.get("metrics/mAP50-95(B)", 0),
- final=True,
- )
- session.alive = False # stop heartbeats
- LOGGER.info(f"{PREFIX}Done ✅\n{PREFIX}View model at {session.model_url} 🚀")
-
-
-def on_train_start(trainer):
- """Run events on train start."""
- events(trainer.args, trainer.device)
-
-
-def on_val_start(validator):
- """Run events on validation start."""
- if not validator.training:
- events(validator.args, validator.device)
-
-
-def on_predict_start(predictor):
- """Run events on predict start."""
- events(predictor.args, predictor.device)
-
-
-def on_export_start(exporter):
- """Run events on export start."""
- events(exporter.args, exporter.device)
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_pretrain_routine_end": on_pretrain_routine_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_model_save": on_model_save,
- "on_train_end": on_train_end,
- "on_train_start": on_train_start,
- "on_val_start": on_val_start,
- "on_predict_start": on_predict_start,
- "on_export_start": on_export_start,
- }
- if SETTINGS["hub"] is True
- else {}
-)
diff --git a/ultralytics/utils/callbacks/mlflow.py b/ultralytics/utils/callbacks/mlflow.py
deleted file mode 100644
index f570240..0000000
--- a/ultralytics/utils/callbacks/mlflow.py
+++ /dev/null
@@ -1,135 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""
-MLflow Logging for Ultralytics YOLO.
-
-This module enables MLflow logging for Ultralytics YOLO. It logs metrics, parameters, and model artifacts.
-For setting up, a tracking URI should be specified. The logging can be customized using environment variables.
-
-Commands:
- 1. To set a project name:
- `export MLFLOW_EXPERIMENT_NAME=<your_experiment_name>` or use the project=<project_name> argument
-
- 2. To set a run name:
- `export MLFLOW_RUN=<your_run_name>` or use the name=<run_name> argument
-
- 3. To start a local MLflow server:
- mlflow server --backend-store-uri runs/mlflow
- It will by default start a local server at http://127.0.0.1:5000.
- To specify a different URI, set the MLFLOW_TRACKING_URI environment variable.
-
- 4. To kill all running MLflow server instances:
- ps aux | grep 'mlflow' | grep -v 'grep' | awk '{print $2}' | xargs kill -9
-"""
-
-from ultralytics.utils import LOGGER, RUNS_DIR, SETTINGS, TESTS_RUNNING, colorstr
-
-try:
- import os
-
- assert not TESTS_RUNNING or "test_mlflow" in os.environ.get("PYTEST_CURRENT_TEST", "") # do not log pytest
- assert SETTINGS["mlflow"] is True # verify integration is enabled
- import mlflow
-
- assert hasattr(mlflow, "__version__") # verify package is not directory
- from pathlib import Path
-
- PREFIX = colorstr("MLflow: ")
-
-except (ImportError, AssertionError):
- mlflow = None
-
-
-def sanitize_dict(x: dict) -> dict:
- """Sanitize dictionary keys by removing parentheses and converting values to floats."""
- return {k.replace("(", "").replace(")", ""): float(v) for k, v in x.items()}
-
-
-def on_pretrain_routine_end(trainer):
- """
- Log training parameters to MLflow at the end of the pretraining routine.
-
- This function sets up MLflow logging based on environment variables and trainer arguments. It sets the tracking URI,
- experiment name, and run name, then starts the MLflow run if not already active. It finally logs the parameters
- from the trainer.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The training object with arguments and parameters to log.
-
- Environment Variables:
- MLFLOW_TRACKING_URI: The URI for MLflow tracking. If not set, defaults to 'runs/mlflow'.
- MLFLOW_EXPERIMENT_NAME: The name of the MLflow experiment. If not set, defaults to trainer.args.project.
- MLFLOW_RUN: The name of the MLflow run. If not set, defaults to trainer.args.name.
- MLFLOW_KEEP_RUN_ACTIVE: Boolean indicating whether to keep the MLflow run active after training ends.
- """
- global mlflow
-
- uri = os.environ.get("MLFLOW_TRACKING_URI") or str(RUNS_DIR / "mlflow")
- LOGGER.debug(f"{PREFIX} tracking uri: {uri}")
- mlflow.set_tracking_uri(uri)
-
- # Set experiment and run names
- experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME") or trainer.args.project or "/Shared/Ultralytics"
- run_name = os.environ.get("MLFLOW_RUN") or trainer.args.name
- mlflow.set_experiment(experiment_name)
-
- mlflow.autolog()
- try:
- active_run = mlflow.active_run() or mlflow.start_run(run_name=run_name)
- LOGGER.info(f"{PREFIX}logging run_id({active_run.info.run_id}) to {uri}")
- if Path(uri).is_dir():
- LOGGER.info(f"{PREFIX}view at http://127.0.0.1:5000 with 'mlflow server --backend-store-uri {uri}'")
- LOGGER.info(f"{PREFIX}disable with 'yolo settings mlflow=False'")
- mlflow.log_params(dict(trainer.args))
- except Exception as e:
- LOGGER.warning(f"{PREFIX}Failed to initialize: {e}")
- LOGGER.warning(f"{PREFIX}Not tracking this run")
-
-
-def on_train_epoch_end(trainer):
- """Log training metrics at the end of each train epoch to MLflow."""
- if mlflow:
- mlflow.log_metrics(
- metrics={
- **sanitize_dict(trainer.lr),
- **sanitize_dict(trainer.label_loss_items(trainer.tloss, prefix="train")),
- },
- step=trainer.epoch,
- )
-
-
-def on_fit_epoch_end(trainer):
- """Log training metrics at the end of each fit epoch to MLflow."""
- if mlflow:
- mlflow.log_metrics(metrics=sanitize_dict(trainer.metrics), step=trainer.epoch)
-
-
-def on_train_end(trainer):
- """Log model artifacts at the end of training."""
- if not mlflow:
- return
- mlflow.log_artifact(str(trainer.best.parent)) # log save_dir/weights directory with best.pt and last.pt
- for f in trainer.save_dir.glob("*"): # log all other files in save_dir
- if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
- mlflow.log_artifact(str(f))
- keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
- if keep_run_active:
- LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
- else:
- mlflow.end_run()
- LOGGER.debug(f"{PREFIX}mlflow run ended")
-
- LOGGER.info(
- f"{PREFIX}results logged to {mlflow.get_tracking_uri()}\n{PREFIX}disable with 'yolo settings mlflow=False'"
- )
-
-
-callbacks = (
- {
- "on_pretrain_routine_end": on_pretrain_routine_end,
- "on_train_epoch_end": on_train_epoch_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_train_end": on_train_end,
- }
- if mlflow
- else {}
-)
diff --git a/ultralytics/utils/callbacks/neptune.py b/ultralytics/utils/callbacks/neptune.py
deleted file mode 100644
index b27964b..0000000
--- a/ultralytics/utils/callbacks/neptune.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING
-
-try:
- assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS["neptune"] is True # verify integration is enabled
-
- import neptune
- from neptune.types import File
-
- assert hasattr(neptune, "__version__")
-
- run = None # NeptuneAI experiment logger instance
-
-except (ImportError, AssertionError):
- neptune = None
-
-
-def _log_scalars(scalars: dict, step: int = 0) -> None:
- """
- Log scalars to the NeptuneAI experiment logger.
-
- Args:
- scalars (dict): Dictionary of scalar values to log to NeptuneAI.
- step (int, optional): The current step or iteration number for logging.
-
- Examples:
- >>> metrics = {"mAP": 0.85, "loss": 0.32}
- >>> _log_scalars(metrics, step=100)
- """
- if run:
- for k, v in scalars.items():
- run[k].append(value=v, step=step)
-
-
-def _log_images(imgs_dict: dict, group: str = "") -> None:
- """
- Log images to the NeptuneAI experiment logger.
-
- This function logs image data to Neptune.ai when a valid Neptune run is active. Images are organized
- under the specified group name.
-
- Args:
- imgs_dict (dict): Dictionary of images to log, with keys as image names and values as image data.
- group (str, optional): Group name to organize images under in the Neptune UI.
-
- Examples:
- >>> # Log validation images
- >>> _log_images({"val_batch": img_tensor}, group="validation")
- """
- if run:
- for k, v in imgs_dict.items():
- run[f"{group}/{k}"].upload(File(v))
-
-
-def _log_plot(title: str, plot_path: str) -> None:
- """Log plots to the NeptuneAI experiment logger."""
- import matplotlib.image as mpimg
- import matplotlib.pyplot as plt
-
- img = mpimg.imread(plot_path)
- fig = plt.figure()
- ax = fig.add_axes([0, 0, 1, 1], frameon=False, aspect="auto", xticks=[], yticks=[]) # no ticks
- ax.imshow(img)
- run[f"Plots/{title}"].upload(fig)
-
-
-def on_pretrain_routine_start(trainer) -> None:
- """Initialize NeptuneAI run and log hyperparameters before training starts."""
- try:
- global run
- run = neptune.init_run(
- project=trainer.args.project or "Ultralytics",
- name=trainer.args.name,
- tags=["Ultralytics"],
- )
- run["Configuration/Hyperparameters"] = {k: "" if v is None else v for k, v in vars(trainer.args).items()}
- except Exception as e:
- LOGGER.warning(f"NeptuneAI installed but not initialized correctly, not logging this run. {e}")
-
-
-def on_train_epoch_end(trainer) -> None:
- """Log training metrics and learning rate at the end of each training epoch."""
- _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
- _log_scalars(trainer.lr, trainer.epoch + 1)
- if trainer.epoch == 1:
- _log_images({f.stem: str(f) for f in trainer.save_dir.glob("train_batch*.jpg")}, "Mosaic")
-
-
-def on_fit_epoch_end(trainer) -> None:
- """Log model info and validation metrics at the end of each fit epoch."""
- if run and trainer.epoch == 0:
- from ultralytics.utils.torch_utils import model_info_for_loggers
-
- run["Configuration/Model"] = model_info_for_loggers(trainer)
- _log_scalars(trainer.metrics, trainer.epoch + 1)
-
-
-def on_val_end(validator) -> None:
- """Log validation images at the end of validation."""
- if run:
- # Log val_labels and val_pred
- _log_images({f.stem: str(f) for f in validator.save_dir.glob("val*.jpg")}, "Validation")
-
-
-def on_train_end(trainer) -> None:
- """Log final results, plots, and model weights at the end of training."""
- if run:
- # Log final results, CM matrix + PR plots
- files = [
- "results.png",
- "confusion_matrix.png",
- "confusion_matrix_normalized.png",
- *(f"{x}_curve.png" for x in ("F1", "PR", "P", "R")),
- ]
- files = [(trainer.save_dir / f) for f in files if (trainer.save_dir / f).exists()] # filter
- for f in files:
- _log_plot(title=f.stem, plot_path=f)
- # Log the final model
- run[f"weights/{trainer.args.name or trainer.args.task}/{trainer.best.name}"].upload(File(str(trainer.best)))
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_train_epoch_end": on_train_epoch_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_val_end": on_val_end,
- "on_train_end": on_train_end,
- }
- if neptune
- else {}
-)
diff --git a/ultralytics/utils/callbacks/platform.py b/ultralytics/utils/callbacks/platform.py
deleted file mode 100644
index 8e983f3..0000000
--- a/ultralytics/utils/callbacks/platform.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import RANK, SETTINGS
-
-
-def on_pretrain_routine_start(trainer):
- """Initialize and start console logging immediately at the very beginning."""
- if RANK in {-1, 0}:
- from ultralytics.utils.logger import DEFAULT_LOG_PATH, ConsoleLogger, SystemLogger
-
- trainer.system_logger = SystemLogger()
- trainer.console_logger = ConsoleLogger(DEFAULT_LOG_PATH)
- trainer.console_logger.start_capture()
-
-
-def on_pretrain_routine_end(trainer):
- """Handle pre-training routine completion event."""
- pass
-
-
-def on_fit_epoch_end(trainer):
- """Handle end of training epoch event and collect system metrics."""
- if RANK in {-1, 0} and hasattr(trainer, "system_logger"):
- system_metrics = trainer.system_logger.get_metrics()
- print(system_metrics) # for debug
-
-
-def on_model_save(trainer):
- """Handle model checkpoint save event."""
- pass
-
-
-def on_train_end(trainer):
- """Stop console capture and finalize logs."""
- if logger := getattr(trainer, "console_logger", None):
- logger.stop_capture()
-
-
-def on_train_start(trainer):
- """Handle training start event."""
- pass
-
-
-def on_val_start(validator):
- """Handle validation start event."""
- pass
-
-
-def on_predict_start(predictor):
- """Handle prediction start event."""
- pass
-
-
-def on_export_start(exporter):
- """Handle model export start event."""
- pass
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_pretrain_routine_end": on_pretrain_routine_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_model_save": on_model_save,
- "on_train_end": on_train_end,
- "on_train_start": on_train_start,
- "on_val_start": on_val_start,
- "on_predict_start": on_predict_start,
- "on_export_start": on_export_start,
- }
- if SETTINGS.get("platform", False) is True # disabled for debugging
- else {}
-)
diff --git a/ultralytics/utils/callbacks/raytune.py b/ultralytics/utils/callbacks/raytune.py
deleted file mode 100644
index 4a75a70..0000000
--- a/ultralytics/utils/callbacks/raytune.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import SETTINGS
-
-try:
- assert SETTINGS["raytune"] is True # verify integration is enabled
- import ray
- from ray import tune
- from ray.air import session
-
-except (ImportError, AssertionError):
- tune = None
-
-
-def on_fit_epoch_end(trainer):
- """
- Report training metrics to Ray Tune at epoch end when a Ray session is active.
-
- Captures metrics from the trainer object and sends them to Ray Tune with the current epoch number,
- enabling hyperparameter optimization. It only executes within an active Ray Tune session.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The Ultralytics trainer object containing metrics and epochs.
-
- Examples:
- >>> # Called automatically by the Ultralytics training loop
- >>> on_fit_epoch_end(trainer)
-
- References:
- Ray Tune docs: https://docs.ray.io/en/latest/tune/index.html
- """
- if ray.train._internal.session.get_session(): # check if Ray Tune session is active
- metrics = trainer.metrics
- session.report({**metrics, **{"epoch": trainer.epoch + 1}})
-
-
-callbacks = (
- {
- "on_fit_epoch_end": on_fit_epoch_end,
- }
- if tune
- else {}
-)
diff --git a/ultralytics/utils/callbacks/tensorboard.py b/ultralytics/utils/callbacks/tensorboard.py
deleted file mode 100644
index 5dbe3e1..0000000
--- a/ultralytics/utils/callbacks/tensorboard.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import LOGGER, SETTINGS, TESTS_RUNNING, colorstr, torch_utils
-
-try:
- assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS["tensorboard"] is True # verify integration is enabled
- WRITER = None # TensorBoard SummaryWriter instance
- PREFIX = colorstr("TensorBoard: ")
-
- # Imports below only required if TensorBoard enabled
- import warnings
- from copy import deepcopy
-
- import torch
- from torch.utils.tensorboard import SummaryWriter
-
-except (ImportError, AssertionError, TypeError, AttributeError):
- # TypeError for handling 'Descriptors cannot not be created directly.' protobuf errors in Windows
- # AttributeError: module 'tensorflow' has no attribute 'io' if 'tensorflow' not installed
- SummaryWriter = None
-
-
-def _log_scalars(scalars: dict, step: int = 0) -> None:
- """
- Log scalar values to TensorBoard.
-
- Args:
- scalars (dict): Dictionary of scalar values to log to TensorBoard. Keys are scalar names and values are the
- corresponding scalar values.
- step (int): Global step value to record with the scalar values. Used for x-axis in TensorBoard graphs.
-
- Examples:
- Log training metrics
- >>> metrics = {"loss": 0.5, "accuracy": 0.95}
- >>> _log_scalars(metrics, step=100)
- """
- if WRITER:
- for k, v in scalars.items():
- WRITER.add_scalar(k, v, step)
-
-
-def _log_tensorboard_graph(trainer) -> None:
- """
- Log model graph to TensorBoard.
-
- This function attempts to visualize the model architecture in TensorBoard by tracing the model with a dummy input
- tensor. It first tries a simple method suitable for YOLO models, and if that fails, falls back to a more complex
- approach for models like RTDETR that may require special handling.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing the model to visualize.
- Must have attributes model and args with imgsz.
-
- Notes:
- This function requires TensorBoard integration to be enabled and the global WRITER to be initialized.
- It handles potential warnings from the PyTorch JIT tracer and attempts to gracefully handle different
- model architectures.
- """
- # Input image
- imgsz = trainer.args.imgsz
- imgsz = (imgsz, imgsz) if isinstance(imgsz, int) else imgsz
- p = next(trainer.model.parameters()) # for device, type
- im = torch.zeros((1, 3, *imgsz), device=p.device, dtype=p.dtype) # input image (must be zeros, not empty)
-
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", category=UserWarning) # suppress jit trace warning
- warnings.simplefilter("ignore", category=torch.jit.TracerWarning) # suppress jit trace warning
-
- # Try simple method first (YOLO)
- try:
- trainer.model.eval() # place in .eval() mode to avoid BatchNorm statistics changes
- WRITER.add_graph(torch.jit.trace(torch_utils.unwrap_model(trainer.model), im, strict=False), [])
- LOGGER.info(f"{PREFIX}model graph visualization added ✅")
- return
-
- except Exception:
- # Fallback to TorchScript export steps (RTDETR)
- try:
- model = deepcopy(torch_utils.unwrap_model(trainer.model))
- model.eval()
- model = model.fuse(verbose=False)
- for m in model.modules():
- if hasattr(m, "export"): # Detect, RTDETRDecoder (Segment and Pose use Detect base class)
- m.export = True
- m.format = "torchscript"
- model(im) # dry run
- WRITER.add_graph(torch.jit.trace(model, im, strict=False), [])
- LOGGER.info(f"{PREFIX}model graph visualization added ✅")
- except Exception as e:
- LOGGER.warning(f"{PREFIX}TensorBoard graph visualization failure {e}")
-
-
-def on_pretrain_routine_start(trainer) -> None:
- """Initialize TensorBoard logging with SummaryWriter."""
- if SummaryWriter:
- try:
- global WRITER
- WRITER = SummaryWriter(str(trainer.save_dir))
- LOGGER.info(f"{PREFIX}Start with 'tensorboard --logdir {trainer.save_dir}', view at http://localhost:6006/")
- except Exception as e:
- LOGGER.warning(f"{PREFIX}TensorBoard not initialized correctly, not logging this run. {e}")
-
-
-def on_train_start(trainer) -> None:
- """Log TensorBoard graph."""
- if WRITER:
- _log_tensorboard_graph(trainer)
-
-
-def on_train_epoch_end(trainer) -> None:
- """Log scalar statistics at the end of a training epoch."""
- _log_scalars(trainer.label_loss_items(trainer.tloss, prefix="train"), trainer.epoch + 1)
- _log_scalars(trainer.lr, trainer.epoch + 1)
-
-
-def on_fit_epoch_end(trainer) -> None:
- """Log epoch metrics at end of training epoch."""
- _log_scalars(trainer.metrics, trainer.epoch + 1)
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_train_start": on_train_start,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_train_epoch_end": on_train_epoch_end,
- }
- if SummaryWriter
- else {}
-)
diff --git a/ultralytics/utils/callbacks/wb.py b/ultralytics/utils/callbacks/wb.py
deleted file mode 100644
index d97de5d..0000000
--- a/ultralytics/utils/callbacks/wb.py
+++ /dev/null
@@ -1,191 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import SETTINGS, TESTS_RUNNING
-from ultralytics.utils.torch_utils import model_info_for_loggers
-
-try:
- assert not TESTS_RUNNING # do not log pytest
- assert SETTINGS["wandb"] is True # verify integration is enabled
- import wandb as wb
-
- assert hasattr(wb, "__version__") # verify package is not directory
- _processed_plots = {}
-
-except (ImportError, AssertionError):
- wb = None
-
-
-def _custom_table(x, y, classes, title="Precision Recall Curve", x_title="Recall", y_title="Precision"):
- """
- Create and log a custom metric visualization to wandb.plot.pr_curve.
-
- This function crafts a custom metric visualization that mimics the behavior of the default wandb precision-recall
- curve while allowing for enhanced customization. The visual metric is useful for monitoring model performance across
- different classes.
-
- Args:
- x (list): Values for the x-axis; expected to have length N.
- y (list): Corresponding values for the y-axis; also expected to have length N.
- classes (list): Labels identifying the class of each point; length N.
- title (str, optional): Title for the plot.
- x_title (str, optional): Label for the x-axis.
- y_title (str, optional): Label for the y-axis.
-
- Returns:
- (wandb.Object): A wandb object suitable for logging, showcasing the crafted metric visualization.
- """
- import polars as pl # scope for faster 'import ultralytics'
- import polars.selectors as cs
-
- df = pl.DataFrame({"class": classes, "y": y, "x": x}).with_columns(cs.numeric().round(3))
- data = df.select(["class", "y", "x"]).rows()
-
- fields = {"x": "x", "y": "y", "class": "class"}
- string_fields = {"title": title, "x-axis-title": x_title, "y-axis-title": y_title}
- return wb.plot_table(
- "wandb/area-under-curve/v0",
- wb.Table(data=data, columns=["class", "y", "x"]),
- fields=fields,
- string_fields=string_fields,
- )
-
-
-def _plot_curve(
- x,
- y,
- names=None,
- id="precision-recall",
- title="Precision Recall Curve",
- x_title="Recall",
- y_title="Precision",
- num_x=100,
- only_mean=False,
-):
- """
- Log a metric curve visualization.
-
- This function generates a metric curve based on input data and logs the visualization to wandb.
- The curve can represent aggregated data (mean) or individual class data, depending on the 'only_mean' flag.
-
- Args:
- x (np.ndarray): Data points for the x-axis with length N.
- y (np.ndarray): Corresponding data points for the y-axis with shape (C, N), where C is the number of classes.
- names (list, optional): Names of the classes corresponding to the y-axis data; length C.
- id (str, optional): Unique identifier for the logged data in wandb.
- title (str, optional): Title for the visualization plot.
- x_title (str, optional): Label for the x-axis.
- y_title (str, optional): Label for the y-axis.
- num_x (int, optional): Number of interpolated data points for visualization.
- only_mean (bool, optional): Flag to indicate if only the mean curve should be plotted.
-
- Notes:
- The function leverages the '_custom_table' function to generate the actual visualization.
- """
- import numpy as np
-
- # Create new x
- if names is None:
- names = []
- x_new = np.linspace(x[0], x[-1], num_x).round(5)
-
- # Create arrays for logging
- x_log = x_new.tolist()
- y_log = np.interp(x_new, x, np.mean(y, axis=0)).round(3).tolist()
-
- if only_mean:
- table = wb.Table(data=list(zip(x_log, y_log)), columns=[x_title, y_title])
- wb.run.log({title: wb.plot.line(table, x_title, y_title, title=title)})
- else:
- classes = ["mean"] * len(x_log)
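- # the first num_x rows of the table hold the mean curve; each class then appends its own num_x interpolated points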
- for i, yi in enumerate(y):
- x_log.extend(x_new) # add new x
- y_log.extend(np.interp(x_new, x, yi)) # interpolate y to new x
- classes.extend([names[i]] * len(x_new)) # add class names
- wb.log({id: _custom_table(x_log, y_log, classes, title, x_title, y_title)}, commit=False)
-
-
-def _log_plots(plots, step):
- """
- Log plots to WandB at a specific step if they haven't been logged already.
-
- This function checks each plot in the input dictionary against previously processed plots and logs
- new or updated plots to WandB at the specified step.
-
- Args:
- plots (dict): Dictionary of plots to log, where keys are plot names and values are dictionaries
- containing plot metadata including timestamps.
- step (int): The step/epoch at which to log the plots in the WandB run.
-
- Notes:
- The function uses a shallow copy of the plots dictionary to prevent modification during iteration.
- Plots are identified by their stem name (filename without extension).
- Each plot is logged as a WandB Image object.
- """
- for name, params in plots.copy().items(): # shallow copy to prevent plots dict changing during iteration
- timestamp = params["timestamp"]
- if _processed_plots.get(name) != timestamp:
- wb.run.log({name.stem: wb.Image(str(name))}, step=step)
- _processed_plots[name] = timestamp
-
-
-def on_pretrain_routine_start(trainer):
- """Initialize and start wandb project if module is present."""
- if not wb.run:
- wb.init(
- project=str(trainer.args.project).replace("/", "-") if trainer.args.project else "Ultralytics",
- name=str(trainer.args.name).replace("/", "-"),
- config=vars(trainer.args),
- )
-
-
-def on_fit_epoch_end(trainer):
- """Log training metrics and model information at the end of an epoch."""
- wb.run.log(trainer.metrics, step=trainer.epoch + 1)
- _log_plots(trainer.plots, step=trainer.epoch + 1)
- _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
- if trainer.epoch == 0:
- wb.run.log(model_info_for_loggers(trainer), step=trainer.epoch + 1)
-
-
-def on_train_epoch_end(trainer):
- """Log metrics and save images at the end of each training epoch."""
- wb.run.log(trainer.label_loss_items(trainer.tloss, prefix="train"), step=trainer.epoch + 1)
- wb.run.log(trainer.lr, step=trainer.epoch + 1)
- if trainer.epoch == 1:
- _log_plots(trainer.plots, step=trainer.epoch + 1)
-
-
-def on_train_end(trainer):
- """Save the best model as an artifact and log final plots at the end of training."""
- _log_plots(trainer.validator.plots, step=trainer.epoch + 1)
- _log_plots(trainer.plots, step=trainer.epoch + 1)
- art = wb.Artifact(type="model", name=f"run_{wb.run.id}_model")
- if trainer.best.exists():
- art.add_file(trainer.best)
- wb.run.log_artifact(art, aliases=["best"])
- # Check if we actually have plots to save
- if trainer.args.plots and hasattr(trainer.validator.metrics, "curves_results"):
- for curve_name, curve_values in zip(trainer.validator.metrics.curves, trainer.validator.metrics.curves_results):
- x, y, x_title, y_title = curve_values
- _plot_curve(
- x,
- y,
- names=list(trainer.validator.metrics.names.values()),
- id=f"curves/{curve_name}",
- title=curve_name,
- x_title=x_title,
- y_title=y_title,
- )
- wb.run.finish() # required or run continues on dashboard
-
-
-callbacks = (
- {
- "on_pretrain_routine_start": on_pretrain_routine_start,
- "on_train_epoch_end": on_train_epoch_end,
- "on_fit_epoch_end": on_fit_epoch_end,
- "on_train_end": on_train_end,
- }
- if wb
- else {}
-)
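The `if wb else {}` guard above is the import-gating pattern used throughout these callback modules: hooks are only registered when the optional dependency imported cleanly. A simplified sketch of the idea (not the exact guard from the top of this file):

try:
    import wandb as wb  # optional dependency
    assert hasattr(wb, "__version__")  # verify it is the real package, not a local stub
except (ImportError, AssertionError):
    wb = None

def on_train_end(trainer):  # stand-in hook for illustration
    print("training finished")

callbacks = {"on_train_end": on_train_end} if wb else {}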
diff --git a/ultralytics/utils/checks.py b/ultralytics/utils/checks.py
deleted file mode 100644
index d801d14..0000000
--- a/ultralytics/utils/checks.py
+++ /dev/null
@@ -1,964 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import functools
-import glob
-import inspect
-import math
-import os
-import platform
-import re
-import shutil
-import subprocess
-import time
-from importlib import metadata
-from pathlib import Path
-from types import SimpleNamespace
-
-import cv2
-import numpy as np
-import torch
-
-from ultralytics.utils import (
- ARM64,
- ASSETS,
- AUTOINSTALL,
- GIT,
- IS_COLAB,
- IS_JETSON,
- IS_KAGGLE,
- IS_PIP_PACKAGE,
- LINUX,
- LOGGER,
- MACOS,
- ONLINE,
- PYTHON_VERSION,
- RKNN_CHIPS,
- ROOT,
- TORCH_VERSION,
- TORCHVISION_VERSION,
- USER_CONFIG_DIR,
- WINDOWS,
- Retry,
- ThreadingLocked,
- TryExcept,
- clean_url,
- colorstr,
- downloads,
- is_github_action_running,
- url2file,
-)
-
-
-def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
- """
- Parse a requirements.txt file, ignoring lines that start with '#' and any text after '#'.
-
- Args:
- file_path (Path): Path to the requirements.txt file.
- package (str, optional): Python package to use instead of requirements.txt file.
-
- Returns:
- requirements (list[SimpleNamespace]): List of parsed requirements as SimpleNamespace objects with `name` and
- `specifier` attributes.
-
- Examples:
- >>> from ultralytics.utils.checks import parse_requirements
- >>> parse_requirements(package="ultralytics")
- """
- if package:
- requires = [x for x in metadata.distribution(package).requires if "extra == " not in x]
- else:
- requires = Path(file_path).read_text().splitlines()
-
- requirements = []
- for line in requires:
- line = line.strip()
- if line and not line.startswith("#"):
- line = line.partition("#")[0].strip() # ignore inline comments
- if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line):
- requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))
-
- return requirements
-
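The parsing loop above reduces to comment stripping plus one regex; a self-contained sketch on a few hypothetical requirement lines (nothing is read from a real requirements.txt):

import re
from types import SimpleNamespace

lines = ["numpy>=1.23.0", "opencv-python", "torch>=1.8.0  # inline comment", "# full-line comment"]
requirements = []
for line in lines:
    line = line.strip()
    if line and not line.startswith("#"):
        line = line.partition("#")[0].strip()  # drop inline comments
        if match := re.match(r"([a-zA-Z0-9-_]+)\s*([<>!=~]+.*)?", line):
            requirements.append(SimpleNamespace(name=match[1], specifier=match[2].strip() if match[2] else ""))

print([(r.name, r.specifier) for r in requirements])
# [('numpy', '>=1.23.0'), ('opencv-python', ''), ('torch', '>=1.8.0')]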
-
-@functools.lru_cache
-def parse_version(version="0.0.0") -> tuple:
- """
- Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
-
- Args:
- version (str): Version string, i.e. '2.0.1+cpu'
-
- Returns:
- (tuple): Tuple of integers representing the numeric part of the version, i.e. (2, 0, 1)
- """
- try:
- return tuple(map(int, re.findall(r"\d+", version)[:3])) # '2.0.1+cpu' -> (2, 0, 1)
- except Exception as e:
- LOGGER.warning(f"failure for parse_version({version}), returning (0, 0, 0): {e}")
- return 0, 0, 0
-
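A quick illustration of the digit extraction above, using only re (the version strings are examples):

import re

for v in ("2.0.1+cpu", "8.3.0", "1.13"):
    print(v, "->", tuple(map(int, re.findall(r"\d+", v)[:3])))
# 2.0.1+cpu -> (2, 0, 1)
# 8.3.0 -> (8, 3, 0)
# 1.13 -> (1, 13)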
-
-def is_ascii(s) -> bool:
- """
- Check if a string is composed of only ASCII characters.
-
- Args:
- s (str | list | tuple | dict): Input to be checked (all are converted to string for checking).
-
- Returns:
- (bool): True if the string is composed only of ASCII characters, False otherwise.
- """
- return all(ord(c) < 128 for c in str(s))
-
-
-def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
- """
- Verify image size is a multiple of the given stride in each dimension. If the image size is not a multiple of the
- stride, update it to the nearest multiple of the stride that is greater than or equal to the given floor value.
-
- Args:
- imgsz (int | list[int]): Image size.
- stride (int): Stride value.
- min_dim (int): Minimum number of dimensions.
- max_dim (int): Maximum number of dimensions.
- floor (int): Minimum allowed value for image size.
-
- Returns:
- (list[int] | int): Updated image size.
- """
- # Convert stride to integer if it is a tensor
- stride = int(stride.max() if isinstance(stride, torch.Tensor) else stride)
-
- # Convert image size to list if it is an integer
- if isinstance(imgsz, int):
- imgsz = [imgsz]
- elif isinstance(imgsz, (list, tuple)):
- imgsz = list(imgsz)
- elif isinstance(imgsz, str): # i.e. '640' or '[640,640]'
- imgsz = [int(imgsz)] if imgsz.isnumeric() else eval(imgsz)
- else:
- raise TypeError(
- f"'imgsz={imgsz}' is of invalid type {type(imgsz).__name__}. "
- f"Valid imgsz types are int i.e. 'imgsz=640' or list i.e. 'imgsz=[640,640]'"
- )
-
- # Apply max_dim
- if len(imgsz) > max_dim:
- msg = (
- "'train' and 'val' imgsz must be an integer, while 'predict' and 'export' imgsz may be a [h, w] list "
- "or an integer, i.e. 'yolo export imgsz=640,480' or 'yolo export imgsz=640'"
- )
- if max_dim != 1:
- raise ValueError(f"imgsz={imgsz} is not a valid image size. {msg}")
- LOGGER.warning(f"updating to 'imgsz={max(imgsz)}'. {msg}")
- imgsz = [max(imgsz)]
- # Make image size a multiple of the stride
- sz = [max(math.ceil(x / stride) * stride, floor) for x in imgsz]
-
- # Print warning message if image size was updated
- if sz != imgsz:
- LOGGER.warning(f"imgsz={imgsz} must be multiple of max stride {stride}, updating to {sz}")
-
- # Add missing dimensions if necessary
- sz = [sz[0], sz[0]] if min_dim == 2 and len(sz) == 1 else sz[0] if min_dim == 1 and len(sz) == 1 else sz
-
- return sz
-
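The core of the check is the stride rounding near the end above: each dimension is bumped up to the nearest multiple of the stride, never below the floor. A tiny sketch with illustrative values:

import math

def round_to_stride(imgsz, stride=32, floor=0):
    return [max(math.ceil(x / stride) * stride, floor) for x in imgsz]

print(round_to_stride([630, 480]))        # [640, 480]
print(round_to_stride([100], stride=64))  # [128]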
-
-@functools.lru_cache
-def check_uv():
- """Check if uv package manager is installed and can run successfully."""
- try:
- return subprocess.run(["uv", "-V"], capture_output=True).returncode == 0
- except FileNotFoundError:
- return False
-
-
-@functools.lru_cache
-def check_version(
- current: str = "0.0.0",
- required: str = "0.0.0",
- name: str = "version",
- hard: bool = False,
- verbose: bool = False,
- msg: str = "",
-) -> bool:
- """
- Check current version against the required version or range.
-
- Args:
- current (str): Current version or package name to get version from.
- required (str): Required version or range (in pip-style format).
- name (str): Name to be used in warning message.
- hard (bool): If True, raise an AssertionError if the requirement is not met.
- verbose (bool): If True, print warning message if requirement is not met.
- msg (str): Extra message to display if verbose.
-
- Returns:
- (bool): True if requirement is met, False otherwise.
-
- Examples:
- Check if current version is exactly 22.04
- >>> check_version(current="22.04", required="==22.04")
-
- Check if current version is greater than or equal to 22.04
- >>> check_version(current="22.10", required="22.04") # assumes '>=' inequality if none passed
-
- Check if current version is less than or equal to 22.04
- >>> check_version(current="22.04", required="<=22.04")
-
- Check if current version is between 20.04 (inclusive) and 22.04 (exclusive)
- >>> check_version(current="21.10", required=">20.04,<22.04")
- """
- if not current: # if current is '' or None
- LOGGER.warning(f"invalid check_version({current}, {required}) requested, please check values.")
- return True
- elif not current[0].isdigit(): # current is package name rather than version string, i.e. current='ultralytics'
- try:
- name = current # assigned package name to 'name' arg
- current = metadata.version(current) # get version string from package name
- except metadata.PackageNotFoundError as e:
- if hard:
- raise ModuleNotFoundError(f"{current} package is required but not installed") from e
- else:
- return False
-
- if not required: # if required is '' or None
- return True
-
- if "sys_platform" in required and ( # i.e. required='<2.4.0,>=1.8.0; sys_platform == "win32"'
- (WINDOWS and "win32" not in required)
- or (LINUX and "linux" not in required)
- or (MACOS and "macos" not in required and "darwin" not in required)
- ):
- return True
-
- op = ""
- version = ""
- result = True
- c = parse_version(current) # '1.2.3' -> (1, 2, 3)
- for r in required.strip(",").split(","):
- op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups() # split '>=22.04' -> ('>=', '22.04')
- if not op:
- op = ">=" # assume >= if no op passed
- v = parse_version(version) # '1.2.3' -> (1, 2, 3)
- if op == "==" and c != v:
- result = False
- elif op == "!=" and c == v:
- result = False
- elif op == ">=" and not (c >= v):
- result = False
- elif op == "<=" and not (c <= v):
- result = False
- elif op == ">" and not (c > v):
- result = False
- elif op == "<" and not (c < v):
- result = False
- if not result:
- warning = f"{name}{required} is required, but {name}=={current} is currently installed {msg}"
- if hard:
- raise ModuleNotFoundError(warning) # assert version requirements met
- if verbose:
- LOGGER.warning(warning)
- return result
-
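A minimal sketch of how a pip-style constraint such as '>=1.8.0,<2.4.0' is split into (op, version) pairs and checked against the current version tuple (the versions below are made up):

import re

def parse_version(version):
    return tuple(map(int, re.findall(r"\d+", version)[:3]))

current, required = "2.1.0", ">=1.8.0,<2.4.0"
c = parse_version(current)
ok = True
for r in required.strip(",").split(","):
    op, version = re.match(r"([^0-9]*)([\d.]+)", r).groups()
    op = op or ">="  # assume >= when no operator is given
    v = parse_version(version)
    ok &= {"==": c == v, "!=": c != v, ">=": c >= v, "<=": c <= v, ">": c > v, "<": c < v}[op]
print(ok)  # True: 2.1.0 satisfies >=1.8.0,<2.4.0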
-
-def check_latest_pypi_version(package_name="ultralytics"):
- """
- Return the latest version of a PyPI package without downloading or installing it.
-
- Args:
- package_name (str): The name of the package to find the latest version for.
-
- Returns:
- (str): The latest version of the package.
- """
- import requests # scoped as slow import
-
- try:
- requests.packages.urllib3.disable_warnings() # Disable the InsecureRequestWarning
- response = requests.get(f"https://pypi.org/pypi/{package_name}/json", timeout=3)
- if response.status_code == 200:
- return response.json()["info"]["version"]
- except Exception:
- return None
-
-
-def check_pip_update_available():
- """
- Check if a new version of the ultralytics package is available on PyPI.
-
- Returns:
- (bool): True if an update is available, False otherwise.
- """
- if ONLINE and IS_PIP_PACKAGE:
- try:
- from ultralytics import __version__
-
- latest = check_latest_pypi_version()
- if check_version(__version__, f"<{latest}"): # check if current version is < latest version
- LOGGER.info(
- f"New https://pypi.org/project/ultralytics/{latest} available 😃 "
- f"Update with 'pip install -U ultralytics'"
- )
- return True
- except Exception:
- pass
- return False
-
-
-@ThreadingLocked()
-@functools.lru_cache
-def check_font(font="Arial.ttf"):
- """
- Find font locally or download to user's configuration directory if it does not already exist.
-
- Args:
- font (str): Path or name of font.
-
- Returns:
- (Path): Resolved font file path.
- """
- from matplotlib import font_manager # scope for faster 'import ultralytics'
-
- # Check USER_CONFIG_DIR
- name = Path(font).name
- file = USER_CONFIG_DIR / name
- if file.exists():
- return file
-
- # Check system fonts
- matches = [s for s in font_manager.findSystemFonts() if font in s]
- if any(matches):
- return matches[0]
-
- # Download to USER_CONFIG_DIR if missing
- url = f"https://github.com/ultralytics/assets/releases/download/v0.0.0/{name}"
- if downloads.is_url(url, check=True):
- downloads.safe_download(url=url, file=file)
- return file
-
-
-def check_python(minimum: str = "3.8.0", hard: bool = True, verbose: bool = False) -> bool:
- """
- Check current python version against the required minimum version.
-
- Args:
- minimum (str): Required minimum version of python.
- hard (bool): If True, raise an AssertionError if the requirement is not met.
- verbose (bool): If True, print warning message if requirement is not met.
-
- Returns:
- (bool): Whether the installed Python version meets the minimum constraints.
- """
- return check_version(PYTHON_VERSION, minimum, name="Python", hard=hard, verbose=verbose)
-
-
-@TryExcept()
-def check_requirements(requirements=ROOT.parent / "requirements.txt", exclude=(), install=True, cmds=""):
- """
- Check if installed dependencies meet Ultralytics YOLO models requirements and attempt to auto-update if needed.
-
- Args:
- requirements (Path | str | list[str] | tuple[str]): Path to a requirements.txt file, a single package
- requirement as a string, or a list of package requirements as strings.
- exclude (tuple): Tuple of package names to exclude from checking.
- install (bool): If True, attempt to auto-update packages that don't meet requirements.
- cmds (str): Additional commands to pass to the pip install command when auto-updating.
-
- Examples:
- >>> from ultralytics.utils.checks import check_requirements
-
- Check a requirements.txt file
- >>> check_requirements("path/to/requirements.txt")
-
- Check a single package
- >>> check_requirements("ultralytics>=8.0.0")
-
- Check multiple packages
- >>> check_requirements(["numpy", "ultralytics>=8.0.0"])
- """
- prefix = colorstr("red", "bold", "requirements:")
- if isinstance(requirements, Path): # requirements.txt file
- file = requirements.resolve()
- assert file.exists(), f"{prefix} {file} not found, check failed."
- requirements = [f"{x.name}{x.specifier}" for x in parse_requirements(file) if x.name not in exclude]
- elif isinstance(requirements, str):
- requirements = [requirements]
-
- pkgs = []
- for r in requirements:
- r_stripped = r.rpartition("/")[-1].replace(".git", "") # replace git+https://org/repo.git -> 'repo'
- match = re.match(r"([a-zA-Z0-9-_]+)([<>!=~]+.*)?", r_stripped)
- name, required = match[1], match[2].strip() if match[2] else ""
- try:
- assert check_version(metadata.version(name), required) # exception if requirements not met
- except (AssertionError, metadata.PackageNotFoundError):
- pkgs.append(r)
-
- @Retry(times=2, delay=1)
- def attempt_install(packages, commands, use_uv):
- """Attempt package installation with uv if available, falling back to pip."""
- if use_uv:
- base = (
- f"uv pip install --no-cache-dir {packages} {commands} "
- f"--index-strategy=unsafe-best-match --break-system-packages --prerelease=allow"
- )
- try:
- return subprocess.check_output(base, shell=True, stderr=subprocess.PIPE, text=True)
- except subprocess.CalledProcessError as e:
- if e.stderr and "No virtual environment found" in e.stderr:
- return subprocess.check_output(
- base.replace("uv pip install", "uv pip install --system"),
- shell=True,
- stderr=subprocess.PIPE,
- text=True,
- )
- raise
- return subprocess.check_output(f"pip install --no-cache-dir {packages} {commands}", shell=True, text=True)
-
- s = " ".join(f'"{x}"' for x in pkgs) # console string
- if s:
- if install and AUTOINSTALL: # check environment variable
- # Note uv fails on arm64 macOS and Raspberry Pi runners
- n = len(pkgs) # number of packages updates
- LOGGER.info(f"{prefix} Ultralytics requirement{'s' * (n > 1)} {pkgs} not found, attempting AutoUpdate...")
- try:
- t = time.time()
- assert ONLINE, "AutoUpdate skipped (offline)"
- LOGGER.info(attempt_install(s, cmds, use_uv=not ARM64 and check_uv()))
- dt = time.time() - t
- LOGGER.info(f"{prefix} AutoUpdate success ✅ {dt:.1f}s")
- LOGGER.warning(
- f"{prefix} {colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"
- )
- except Exception as e:
- LOGGER.warning(f"{prefix} ❌ {e}")
- return False
- else:
- return False
-
- return True
-
-
-def check_torchvision():
- """
- Check the installed versions of PyTorch and Torchvision to ensure they're compatible.
-
- This function checks the installed versions of PyTorch and Torchvision, and warns if they're incompatible according
- to the compatibility table based on: https://github.com/pytorch/vision#installation.
- """
- compatibility_table = {
- "2.9": ["0.24"],
- "2.8": ["0.23"],
- "2.7": ["0.22"],
- "2.6": ["0.21"],
- "2.5": ["0.20"],
- "2.4": ["0.19"],
- "2.3": ["0.18"],
- "2.2": ["0.17"],
- "2.1": ["0.16"],
- "2.0": ["0.15"],
- "1.13": ["0.14"],
- "1.12": ["0.13"],
- }
-
- # Check major and minor versions
- v_torch = ".".join(TORCH_VERSION.split("+", 1)[0].split(".")[:2])
- if v_torch in compatibility_table:
- compatible_versions = compatibility_table[v_torch]
- v_torchvision = ".".join(TORCHVISION_VERSION.split("+", 1)[0].split(".")[:2])
- if all(v_torchvision != v for v in compatible_versions):
- LOGGER.warning(
- f"torchvision=={v_torchvision} is incompatible with torch=={v_torch}.\n"
- f"Run 'pip install torchvision=={compatible_versions[0]}' to fix torchvision or "
- "'pip install -U torch torchvision' to update both.\n"
- "For a full compatibility table see https://github.com/pytorch/vision#installation"
- )
-
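The lookup key is just the major.minor prefix of each version string, with any local build tag stripped; an illustration with example strings (not the installed versions):

for v in ("2.1.0+cu118", "0.16.2", "2.4.0"):
    print(v, "->", ".".join(v.split("+", 1)[0].split(".")[:2]))
# 2.1.0+cu118 -> 2.1
# 0.16.2 -> 0.16
# 2.4.0 -> 2.4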
-
-def check_suffix(file="yolo11n.pt", suffix=".pt", msg=""):
- """
- Check file(s) for acceptable suffix.
-
- Args:
- file (str | list[str]): File or list of files to check.
- suffix (str | tuple): Acceptable suffix or tuple of suffixes.
- msg (str): Additional message to display in case of error.
- """
- if file and suffix:
- if isinstance(suffix, str):
- suffix = {suffix}
- for f in file if isinstance(file, (list, tuple)) else [file]:
- if s := str(f).rpartition(".")[-1].lower().strip(): # file suffix
- assert f".{s}" in suffix, f"{msg}{f} acceptable suffix is {suffix}, not .{s}"
-
-
-def check_yolov5u_filename(file: str, verbose: bool = True):
- """
- Replace legacy YOLOv5 filenames with updated YOLOv5u filenames.
-
- Args:
- file (str): Filename to check and potentially update.
- verbose (bool): Whether to print information about the replacement.
-
- Returns:
- (str): Updated filename.
- """
- if "yolov3" in file or "yolov5" in file:
- if "u.yaml" in file:
- file = file.replace("u.yaml", ".yaml") # i.e. yolov5nu.yaml -> yolov5n.yaml
- elif ".pt" in file and "u" not in file:
- original_file = file
- file = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", file) # i.e. yolov5n.pt -> yolov5nu.pt
- file = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", file) # i.e. yolov5n6.pt -> yolov5n6u.pt
- file = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", file) # i.e. yolov3-spp.pt -> yolov3-sppu.pt
- if file != original_file and verbose:
- LOGGER.info(
- f"PRO TIP 💡 Replace 'model={original_file}' with new 'model={file}'.\nYOLOv5 'u' models are "
- f"trained with https://github.com/ultralytics/ultralytics and feature improved performance vs "
- f"standard YOLOv5 models trained with https://github.com/ultralytics/yolov5.\n"
- )
- return file
-
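A quick check of the three substitutions above on hypothetical legacy filenames:

import re

for name in ("yolov5n.pt", "yolov5s6.pt", "yolov3-spp.pt", "yolov5nu.pt"):
    out = re.sub(r"(.*yolov5([nsmlx]))\.pt", "\\1u.pt", name)
    out = re.sub(r"(.*yolov5([nsmlx])6)\.pt", "\\1u.pt", out)
    out = re.sub(r"(.*yolov3(|-tiny|-spp))\.pt", "\\1u.pt", out)
    print(name, "->", out)
# yolov5n.pt -> yolov5nu.pt, yolov5s6.pt -> yolov5s6u.pt,
# yolov3-spp.pt -> yolov3-sppu.pt, yolov5nu.pt -> yolov5nu.pt (already updated, unchanged)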
-
-def check_model_file_from_stem(model="yolo11n"):
- """
- Return a model filename from a valid model stem.
-
- Args:
- model (str): Model stem to check.
-
- Returns:
- (str | Path): Model filename with appropriate suffix.
- """
- path = Path(model)
- if not path.suffix and path.stem in downloads.GITHUB_ASSETS_STEMS:
- return path.with_suffix(".pt") # add suffix, i.e. yolo11n -> yolo11n.pt
- return model
-
-
-def check_file(file, suffix="", download=True, download_dir=".", hard=True):
- """
- Search/download file (if necessary), check suffix (if provided), and return path.
-
- Args:
- file (str): File name or path.
- suffix (str | tuple): Acceptable suffix or tuple of suffixes to validate against the file.
- download (bool): Whether to download the file if it doesn't exist locally.
- download_dir (str): Directory to download the file to.
- hard (bool): Whether to raise an error if the file is not found.
-
- Returns:
- (str): Path to the file.
- """
- check_suffix(file, suffix) # optional
- file = str(file).strip() # convert to string and strip spaces
- file = check_yolov5u_filename(file) # yolov5n -> yolov5nu
- if (
- not file
- or ("://" not in file and Path(file).exists()) # '://' check required in Windows Python<3.10
- or file.lower().startswith("grpc://")
- ): # file exists or gRPC Triton images
- return file
- elif download and file.lower().startswith(("https://", "http://", "rtsp://", "rtmp://", "tcp://")): # download
- url = file # warning: Pathlib turns :// -> :/
- file = Path(download_dir) / url2file(file) # '%2F' to '/', split https://url.com/file.txt?auth
- if file.exists():
- LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
- else:
- downloads.safe_download(url=url, file=file, unzip=False)
- return str(file)
- else: # search
- files = glob.glob(str(ROOT / "**" / file), recursive=True) or glob.glob(str(ROOT.parent / file)) # find file
- if not files and hard:
- raise FileNotFoundError(f"'{file}' does not exist")
- elif len(files) > 1 and hard:
- raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
- return files[0] if len(files) else [] # return file
-
-
-def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
- """
- Search/download YAML file (if necessary) and return path, checking suffix.
-
- Args:
- file (str | Path): File name or path.
- suffix (tuple): Tuple of acceptable YAML file suffixes.
- hard (bool): Whether to raise an error if the file is not found or multiple files are found.
-
- Returns:
- (str): Path to the YAML file.
- """
- return check_file(file, suffix, hard=hard)
-
-
-def check_is_path_safe(basedir, path):
- """
- Check if the resolved path is under the intended directory to prevent path traversal.
-
- Args:
- basedir (Path | str): The intended directory.
- path (Path | str): The path to check.
-
- Returns:
- (bool): True if the path is safe, False otherwise.
- """
- base_dir_resolved = Path(basedir).resolve()
- path_resolved = Path(path).resolve()
-
- return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts
-
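A small sketch of the prefix comparison above with hypothetical paths; the temporary directory exists only so the resolved paths can be checked for existence:

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp) / "datasets"
    (base / "coco8").mkdir(parents=True)  # a directory inside the base
    Path(tmp, "secrets").mkdir()          # a sibling outside the base

    def is_path_safe(basedir, path):
        b, p = Path(basedir).resolve(), Path(path).resolve()
        return p.exists() and p.parts[: len(b.parts)] == b.parts

    print(is_path_safe(base, base / "coco8"))           # True
    print(is_path_safe(base, base / ".." / "secrets"))  # False: resolves outside the base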
-
-@functools.lru_cache
-def check_imshow(warn=False):
- """
- Check if environment supports image displays.
-
- Args:
- warn (bool): Whether to warn if environment doesn't support image displays.
-
- Returns:
- (bool): True if environment supports image displays, False otherwise.
- """
- try:
- if LINUX:
- assert not IS_COLAB and not IS_KAGGLE
- assert "DISPLAY" in os.environ, "The DISPLAY environment variable isn't set."
- cv2.imshow("test", np.zeros((8, 8, 3), dtype=np.uint8)) # show a small 8x8 test image
- cv2.waitKey(1)
- cv2.destroyAllWindows()
- cv2.waitKey(1)
- return True
- except Exception as e:
- if warn:
- LOGGER.warning(f"Environment does not support cv2.imshow() or PIL Image.show()\n{e}")
- return False
-
-
-def check_yolo(verbose=True, device=""):
- """
- Return a human-readable YOLO software and hardware summary.
-
- Args:
- verbose (bool): Whether to print verbose information.
- device (str | torch.device): Device to use for YOLO.
- """
- import psutil # scoped as slow import
-
- from ultralytics.utils.torch_utils import select_device
-
- if IS_COLAB:
- shutil.rmtree("sample_data", ignore_errors=True) # remove colab /sample_data directory
-
- if verbose:
- # System info
- gib = 1 << 30 # bytes per GiB
- ram = psutil.virtual_memory().total
- total, used, free = shutil.disk_usage("/")
- s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
- try:
- from IPython import display
-
- display.clear_output() # clear display if notebook
- except ImportError:
- pass
- else:
- s = ""
-
- if GIT.is_repo:
- check_multiple_install() # check conflicting installation if using local clone
-
- select_device(device=device, newline=False)
- LOGGER.info(f"Setup complete ✅ {s}")
-
-
-def collect_system_info():
- """
- Collect and print relevant system information including OS, Python, RAM, CPU, and CUDA.
-
- Returns:
- (dict): Dictionary containing system information.
- """
- import psutil # scoped as slow import
-
- from ultralytics.utils import ENVIRONMENT # scope to avoid circular import
- from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info
-
- gib = 1 << 30 # bytes per GiB
- cuda = torch.cuda.is_available()
- check_yolo()
- total, used, free = shutil.disk_usage("/")
-
- info_dict = {
- "OS": platform.platform(),
- "Environment": ENVIRONMENT,
- "Python": PYTHON_VERSION,
- "Install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
- "Path": str(ROOT),
- "RAM": f"{psutil.virtual_memory().total / gib:.2f} GB",
- "Disk": f"{(total - free) / gib:.1f}/{total / gib:.1f} GB",
- "CPU": get_cpu_info(),
- "CPU count": os.cpu_count(),
- "GPU": get_gpu_info(index=0) if cuda else None,
- "GPU count": torch.cuda.device_count() if cuda else None,
- "CUDA": torch.version.cuda if cuda else None,
- }
- LOGGER.info("\n" + "\n".join(f"{k:<23}{v}" for k, v in info_dict.items()) + "\n")
-
- package_info = {}
- for r in parse_requirements(package="ultralytics"):
- try:
- current = metadata.version(r.name)
- is_met = "✅ " if check_version(current, str(r.specifier), name=r.name, hard=True) else "❌ "
- except metadata.PackageNotFoundError:
- current = "(not installed)"
- is_met = "❌ "
- package_info[r.name] = f"{is_met}{current}{r.specifier}"
- LOGGER.info(f"{r.name:<23}{package_info[r.name]}")
-
- info_dict["Package Info"] = package_info
-
- if is_github_action_running():
- github_info = {
- "RUNNER_OS": os.getenv("RUNNER_OS"),
- "GITHUB_EVENT_NAME": os.getenv("GITHUB_EVENT_NAME"),
- "GITHUB_WORKFLOW": os.getenv("GITHUB_WORKFLOW"),
- "GITHUB_ACTOR": os.getenv("GITHUB_ACTOR"),
- "GITHUB_REPOSITORY": os.getenv("GITHUB_REPOSITORY"),
- "GITHUB_REPOSITORY_OWNER": os.getenv("GITHUB_REPOSITORY_OWNER"),
- }
- LOGGER.info("\n" + "\n".join(f"{k}: {v}" for k, v in github_info.items()))
- info_dict["GitHub Info"] = github_info
-
- return info_dict
-
-
-def check_amp(model):
- """
- Check the PyTorch Automatic Mixed Precision (AMP) functionality of a YOLO model.
-
- If the checks fail, it means there are anomalies with AMP on the system that may cause NaN losses or zero-mAP
- results, so AMP will be disabled during training.
-
- Args:
- model (torch.nn.Module): A YOLO model instance.
-
- Returns:
- (bool): Returns True if the AMP functionality works correctly with YOLO11 model, else False.
-
- Examples:
- >>> from ultralytics import YOLO
- >>> from ultralytics.utils.checks import check_amp
- >>> model = YOLO("yolo11n.pt").model.cuda()
- >>> check_amp(model)
- """
- from ultralytics.utils.torch_utils import autocast
-
- device = next(model.parameters()).device # get model device
- prefix = colorstr("AMP: ")
- if device.type in {"cpu", "mps"}:
- return False # AMP only used on CUDA devices
- else:
- # GPUs that have issues with AMP
- pattern = re.compile(
- r"(nvidia|geforce|quadro|tesla).*?(1660|1650|1630|t400|t550|t600|t1000|t1200|t2000|k40m)", re.IGNORECASE
- )
-
- gpu = torch.cuda.get_device_name(device)
- if bool(pattern.search(gpu)):
- LOGGER.warning(
- f"{prefix}checks failed ❌. AMP training on {gpu} GPU may cause "
- f"NaN losses or zero-mAP results, so AMP will be disabled during training."
- )
- return False
-
- def amp_allclose(m, im):
- """All close FP32 vs AMP results."""
- batch = [im] * 8
- imgsz = max(256, int(model.stride.max() * 4)) # max stride P5-32 and P6-64
- a = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # FP32 inference
- with autocast(enabled=True):
- b = m(batch, imgsz=imgsz, device=device, verbose=False)[0].boxes.data # AMP inference
- del m
- return a.shape == b.shape and torch.allclose(a, b.float(), atol=0.5) # close to 0.5 absolute tolerance
-
- im = ASSETS / "bus.jpg" # image to check
- LOGGER.info(f"{prefix}running Automatic Mixed Precision (AMP) checks...")
- warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
- try:
- from ultralytics import YOLO
-
- assert amp_allclose(YOLO("yolo11n.pt"), im)
- LOGGER.info(f"{prefix}checks passed ✅")
- except ConnectionError:
- LOGGER.warning(f"{prefix}checks skipped. Offline and unable to download YOLO11n for AMP checks. {warning_msg}")
- except (AttributeError, ModuleNotFoundError):
- LOGGER.warning(
- f"{prefix}checks skipped. "
- f"Unable to load YOLO11n for AMP checks due to possible Ultralytics package modifications. {warning_msg}"
- )
- except AssertionError:
- LOGGER.error(
- f"{prefix}checks failed. Anomalies were detected with AMP on your system that may lead to "
- f"NaN losses or zero-mAP results, so AMP will be disabled during training."
- )
- return False
- return True
-
-
-def check_multiple_install():
- """Check if there are multiple Ultralytics installations."""
- import sys
-
- try:
- result = subprocess.run([sys.executable, "-m", "pip", "show", "ultralytics"], capture_output=True, text=True)
- install_msg = (
- f"Install your local copy in editable mode with 'pip install -e {ROOT.parent}' to avoid "
- "issues. See https://docs.ultralytics.com/quickstart/"
- )
- if result.returncode != 0:
- if "not found" in result.stderr.lower(): # Package not pip-installed but locally imported
- LOGGER.warning(f"Ultralytics not found via pip but importing from: {ROOT}. {install_msg}")
- return
- yolo_path = (Path(re.findall(r"location:\s+(.+)", result.stdout, flags=re.I)[-1]) / "ultralytics").resolve()
- if not yolo_path.samefile(ROOT.resolve()):
- LOGGER.warning(
- f"Multiple Ultralytics installations detected. The `yolo` command uses: {yolo_path}, "
- f"but current session imports from: {ROOT}. This may cause version conflicts. {install_msg}"
- )
- except Exception:
- return
-
-
-def print_args(args: dict | None = None, show_file=True, show_func=False):
- """
- Print function arguments (optional args dict).
-
- Args:
- args (dict, optional): Arguments to print.
- show_file (bool): Whether to show the file name.
- show_func (bool): Whether to show the function name.
- """
-
- def strip_auth(v):
- """Clean longer Ultralytics HUB URLs by stripping potential authentication information."""
- return clean_url(v) if (isinstance(v, str) and v.startswith("http") and len(v) > 100) else v
-
- x = inspect.currentframe().f_back # previous frame
- file, _, func, _, _ = inspect.getframeinfo(x)
- if args is None: # get args automatically
- args, _, _, frm = inspect.getargvalues(x)
- args = {k: v for k, v in frm.items() if k in args}
- try:
- file = Path(file).resolve().relative_to(ROOT).with_suffix("")
- except ValueError:
- file = Path(file).stem
- s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
- LOGGER.info(colorstr(s) + ", ".join(f"{k}={strip_auth(v)}" for k, v in sorted(args.items())))
-
-
-def cuda_device_count() -> int:
- """
- Get the number of NVIDIA GPUs available in the environment.
-
- Returns:
- (int): The number of NVIDIA GPUs available.
- """
- if IS_JETSON:
- # NVIDIA Jetson does not fully support nvidia-smi, so count devices with PyTorch instead
- return torch.cuda.device_count()
- else:
- try:
- # Run the nvidia-smi command and capture its output
- output = subprocess.check_output(
- ["nvidia-smi", "--query-gpu=count", "--format=csv,noheader,nounits"], encoding="utf-8"
- )
-
- # Take the first line and strip any leading/trailing white space
- first_line = output.strip().split("\n", 1)[0]
-
- return int(first_line)
- except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
- # If the command fails, nvidia-smi is not found, or output is not an integer, assume no GPUs are available
- return 0
-
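The parsing step is easiest to see with a hard-coded sample rather than a live nvidia-smi call (so this runs on machines without NVIDIA GPUs); the count line is typically repeated once per GPU, hence only the first line is used:

sample_output = "2\n2\n"  # hypothetical nvidia-smi --query-gpu=count output on a 2-GPU machine
first_line = sample_output.strip().split("\n", 1)[0]
print(int(first_line))    # 2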
-
-def cuda_is_available() -> bool:
- """
- Check if CUDA is available in the environment.
-
- Returns:
- (bool): True if one or more NVIDIA GPUs are available, False otherwise.
- """
- return cuda_device_count() > 0
-
-
-def is_rockchip():
- """
- Check if the current environment is running on a Rockchip SoC.
-
- Returns:
- (bool): True if running on a Rockchip SoC, False otherwise.
- """
- if LINUX and ARM64:
- try:
- with open("/proc/device-tree/compatible") as f:
- dev_str = f.read()
- *_, soc = dev_str.split(",")
- if soc.replace("\x00", "") in RKNN_CHIPS:
- return True
- except OSError:
- return False
- else:
- return False
-
-
-def is_intel():
- """
- Check if the system has Intel hardware (CPU or GPU).
-
- Returns:
- (bool): True if Intel hardware is detected, False otherwise.
- """
- from ultralytics.utils.torch_utils import get_cpu_info
-
- # Check CPU
- if "intel" in get_cpu_info().lower():
- return True
-
- # Check GPU via xpu-smi
- try:
- result = subprocess.run(["xpu-smi", "discovery"], capture_output=True, text=True, timeout=5)
- return "intel" in result.stdout.lower()
- except Exception: # broad clause, as xpu-smi may be missing or fail on non-Intel systems
- return False
-
-
-def is_sudo_available() -> bool:
- """
- Check if the sudo command is available in the environment.
-
- Returns:
- (bool): True if the sudo command is available, False otherwise.
- """
- if WINDOWS:
- return False
- cmd = "sudo --version"
- return subprocess.run(cmd, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0
-
-
-# Run checks and define constants
-check_python("3.8", hard=False, verbose=True) # check python version
-check_torchvision() # check torch-torchvision compatibility
-
-# Define constants
-IS_PYTHON_3_8 = PYTHON_VERSION.startswith("3.8")
-IS_PYTHON_3_12 = PYTHON_VERSION.startswith("3.12")
-IS_PYTHON_3_13 = PYTHON_VERSION.startswith("3.13")
-
-IS_PYTHON_MINIMUM_3_10 = check_python("3.10", hard=False)
-IS_PYTHON_MINIMUM_3_12 = check_python("3.12", hard=False)
diff --git a/ultralytics/utils/cpu.py b/ultralytics/utils/cpu.py
deleted file mode 100644
index 0915df8..0000000
--- a/ultralytics/utils/cpu.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import platform
-import re
-import subprocess
-import sys
-from pathlib import Path
-
-
-class CPUInfo:
- """
- Provide cross-platform CPU brand and model information.
-
- Query platform-specific sources to retrieve a human-readable CPU descriptor and normalize it for consistent
- presentation across macOS, Linux, and Windows. If platform-specific probing fails, generic platform identifiers are
- used to ensure a stable string is always returned.
-
- Methods:
- name: Return the normalized CPU name using platform-specific sources with robust fallbacks.
- _clean: Normalize and prettify common vendor brand strings and frequency patterns.
- __str__: Return the normalized CPU name for string contexts.
-
- Examples:
- >>> CPUInfo.name()
- 'Apple M4 Pro'
- >>> str(CPUInfo())
- 'Intel Core i7-9750H 2.60GHz'
- """
-
- @staticmethod
- def name() -> str:
- """Return a normalized CPU model string from platform-specific sources."""
- try:
- if sys.platform == "darwin":
- # Query macOS sysctl for the CPU brand string
- s = subprocess.run(
- ["sysctl", "-n", "machdep.cpu.brand_string"], capture_output=True, text=True
- ).stdout.strip()
- if s:
- return CPUInfo._clean(s)
- elif sys.platform.startswith("linux"):
- # Parse /proc/cpuinfo for the first "model name" entry
- p = Path("/proc/cpuinfo")
- if p.exists():
- for line in p.read_text(errors="ignore").splitlines():
- if "model name" in line:
- return CPUInfo._clean(line.split(":", 1)[1])
- elif sys.platform.startswith("win"):
- try:
- import winreg as wr
-
- with wr.OpenKey(wr.HKEY_LOCAL_MACHINE, r"HARDWARE\DESCRIPTION\System\CentralProcessor\0") as k:
- val, _ = wr.QueryValueEx(k, "ProcessorNameString")
- if val:
- return CPUInfo._clean(val)
- except Exception:
- # Fall through to generic platform fallbacks on Windows registry access failure
- pass
- # Generic platform fallbacks
- s = platform.processor() or getattr(platform.uname(), "processor", "") or platform.machine()
- return CPUInfo._clean(s or "Unknown CPU")
- except Exception:
- # Ensure a string is always returned even on unexpected failures
- s = platform.processor() or platform.machine() or ""
- return CPUInfo._clean(s or "Unknown CPU")
-
- @staticmethod
- def _clean(s: str) -> str:
- """Normalize and prettify a raw CPU descriptor string."""
- s = re.sub(r"\s+", " ", s.strip())
- s = s.replace("(TM)", "").replace("(tm)", "").replace("(R)", "").replace("(r)", "").strip()
- # Normalize common Intel pattern to 'Model Freq'
- m = re.search(r"(Intel.*?i\d[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
- if m:
- return f"{m.group(1)} {m.group(2)}"
- # Normalize common AMD Ryzen pattern to 'Model Freq'
- m = re.search(r"(AMD.*?Ryzen.*?[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
- if m:
- return f"{m.group(1)} {m.group(2)}"
- return s
-
- def __str__(self) -> str:
- """Return the normalized CPU name."""
- return self.name()
-
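A small demonstration of the normalization in CPUInfo._clean above, applied to a made-up raw descriptor:

import re

def clean(s):
    s = re.sub(r"\s+", " ", s.strip())
    s = s.replace("(TM)", "").replace("(tm)", "").replace("(R)", "").replace("(r)", "").strip()
    m = re.search(r"(Intel.*?i\d[\w-]*) CPU @ ([\d.]+GHz)", s, re.I)
    return f"{m.group(1)} {m.group(2)}" if m else s

print(clean("Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz"))  # Intel Core i7-9750H 2.60GHz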
-
-if __name__ == "__main__":
- print(CPUInfo.name())
diff --git a/ultralytics/utils/dist.py b/ultralytics/utils/dist.py
deleted file mode 100644
index 30d7c04..0000000
--- a/ultralytics/utils/dist.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import os
-import shutil
-import sys
-import tempfile
-
-from . import USER_CONFIG_DIR
-from .torch_utils import TORCH_1_9
-
-
-def find_free_network_port() -> int:
- """
- Find a free port on localhost.
-
- It is useful in single-node training when we don't want to connect to a real main node but have to set the
- `MASTER_PORT` environment variable.
-
- Returns:
- (int): The available network port number.
- """
- import socket
-
- with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
- s.bind(("127.0.0.1", 0))
- return s.getsockname()[1] # port
-
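A hedged usage sketch: grab a free local port the same way and export it for single-node DDP (MASTER_PORT is the environment variable torch.distributed reads for rendezvous):

import os
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
    s.bind(("127.0.0.1", 0))  # port 0 asks the OS for any free port
    port = s.getsockname()[1]

os.environ["MASTER_PORT"] = str(port)
print(port)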
-
-def generate_ddp_file(trainer):
- """
- Generate a DDP (Distributed Data Parallel) file for multi-GPU training.
-
- This function creates a temporary Python file that enables distributed training across multiple GPUs.
- The file contains the necessary configuration to initialize the trainer in a distributed environment.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing training configuration and arguments.
- Must have args attribute and be a class instance.
-
- Returns:
- (str): Path to the generated temporary DDP file.
-
- Notes:
- The generated file is saved in the USER_CONFIG_DIR/DDP directory and includes:
- - Trainer class import
- - Configuration overrides from the trainer arguments
- - Model path configuration
- - Training initialization code
- """
- module, name = f"{trainer.__class__.__module__}.{trainer.__class__.__name__}".rsplit(".", 1)
-
- content = f"""
-# Ultralytics Multi-GPU training temp file (should be automatically deleted after use)
-overrides = {vars(trainer.args)}
-
-if __name__ == "__main__":
- from {module} import {name}
- from ultralytics.utils import DEFAULT_CFG_DICT
-
- cfg = DEFAULT_CFG_DICT.copy()
- cfg.update(save_dir='') # handle the extra key 'save_dir'
- trainer = {name}(cfg=cfg, overrides=overrides)
- trainer.args.model = "{getattr(trainer.hub_session, "model_url", trainer.args.model)}"
- results = trainer.train()
-"""
- (USER_CONFIG_DIR / "DDP").mkdir(exist_ok=True)
- with tempfile.NamedTemporaryFile(
- prefix="_temp_",
- suffix=f"{id(trainer)}.py",
- mode="w+",
- encoding="utf-8",
- dir=USER_CONFIG_DIR / "DDP",
- delete=False,
- ) as file:
- file.write(content)
- return file.name
-
-
-def generate_ddp_command(trainer):
- """
- Generate command for distributed training.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The trainer containing configuration for distributed training.
-
- Returns:
- cmd (list[str]): The command to execute for distributed training.
- file (str): Path to the temporary file created for DDP training.
- """
- import __main__ # noqa local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218
-
- if not trainer.resume:
- shutil.rmtree(trainer.save_dir) # remove the save_dir
- file = generate_ddp_file(trainer)
- dist_cmd = "torch.distributed.run" if TORCH_1_9 else "torch.distributed.launch"
- port = find_free_network_port()
- cmd = [
- sys.executable,
- "-m",
- dist_cmd,
- "--nproc_per_node",
- f"{trainer.world_size}",
- "--master_port",
- f"{port}",
- file,
- ]
- return cmd, file
-
-
-def ddp_cleanup(trainer, file):
- """
- Delete temporary file if created during distributed data parallel (DDP) training.
-
- This function checks if the provided file contains the trainer's ID in its name, indicating it was created
- as a temporary file for DDP training, and deletes it if so.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The trainer used for distributed training.
- file (str): Path to the file that might need to be deleted.
-
- Examples:
- >>> trainer = YOLOTrainer()
- >>> file = "/tmp/ddp_temp_123456789.py"
- >>> ddp_cleanup(trainer, file)
- """
- if f"{id(trainer)}.py" in file: # if temp_file suffix in file
- os.remove(file)
diff --git a/ultralytics/utils/downloads.py b/ultralytics/utils/downloads.py
deleted file mode 100644
index 6257d21..0000000
--- a/ultralytics/utils/downloads.py
+++ /dev/null
@@ -1,541 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import re
-import shutil
-import subprocess
-from itertools import repeat
-from multiprocessing.pool import ThreadPool
-from pathlib import Path
-from urllib import parse, request
-
-from ultralytics.utils import LOGGER, TQDM, checks, clean_url, emojis, is_online, url2file
-
-# Define Ultralytics GitHub assets maintained at https://github.com/ultralytics/assets
-GITHUB_ASSETS_REPO = "ultralytics/assets"
-GITHUB_ASSETS_NAMES = frozenset(
- [f"yolov8{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb", "-oiv7")]
- + [f"yolo11{k}{suffix}.pt" for k in "nsmlx" for suffix in ("", "-cls", "-seg", "-pose", "-obb")]
- + [f"yolo12{k}{suffix}.pt" for k in "nsmlx" for suffix in ("",)] # detect models only currently
- + [f"yolov5{k}{resolution}u.pt" for k in "nsmlx" for resolution in ("", "6")]
- + [f"yolov3{k}u.pt" for k in ("", "-spp", "-tiny")]
- + [f"yolov8{k}-world.pt" for k in "smlx"]
- + [f"yolov8{k}-worldv2.pt" for k in "smlx"]
- + [f"yoloe-v8{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
- + [f"yoloe-11{k}{suffix}.pt" for k in "sml" for suffix in ("-seg", "-seg-pf")]
- + [f"yolov9{k}.pt" for k in "tsmce"]
- + [f"yolov10{k}.pt" for k in "nsmblx"]
- + [f"yolo_nas_{k}.pt" for k in "sml"]
- + [f"sam_{k}.pt" for k in "bl"]
- + [f"sam2_{k}.pt" for k in "blst"]
- + [f"sam2.1_{k}.pt" for k in "blst"]
- + [f"FastSAM-{k}.pt" for k in "sx"]
- + [f"rtdetr-{k}.pt" for k in "lx"]
- + [
- "mobile_sam.pt",
- "mobileclip_blt.ts",
- "yolo11n-grayscale.pt",
- "calibration_image_sample_data_20x128x128x3_float32.npy.zip",
- ]
-)
-GITHUB_ASSETS_STEMS = frozenset(k.rpartition(".")[0] for k in GITHUB_ASSETS_NAMES)
-
-
-def is_url(url: str | Path, check: bool = False) -> bool:
- """
- Validate if the given string is a URL and optionally check if the URL exists online.
-
- Args:
- url (str): The string to be validated as a URL.
- check (bool, optional): If True, performs an additional check to see if the URL exists online.
-
- Returns:
- (bool): True for a valid URL. If 'check' is True, also returns True if the URL exists online.
-
- Examples:
- >>> valid = is_url("https://www.example.com")
- >>> valid_and_exists = is_url("https://www.example.com", check=True)
- """
- try:
- url = str(url)
- result = parse.urlparse(url)
- assert all([result.scheme, result.netloc]) # check if is url
- if check:
- with request.urlopen(url) as response:
- return response.getcode() == 200 # check if exists online
- return True
- except Exception:
- return False
-
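The check=False path boils down to validating that both a scheme and a netloc are present; a sketch with example strings and no network access:

from urllib import parse

for candidate in ("https://ultralytics.com/images/bus.jpg", "not a url", "file.txt"):
    r = parse.urlparse(str(candidate))
    print(candidate, "->", all([r.scheme, r.netloc]))
# only the first prints True; the others lack a scheme and netloc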
-
-def delete_dsstore(path: str | Path, files_to_delete: tuple[str, ...] = (".DS_Store", "__MACOSX")) -> None:
- """
- Delete all specified system files in a directory.
-
- Args:
- path (str | Path): The directory path where the files should be deleted.
- files_to_delete (tuple): The files to be deleted.
-
- Examples:
- >>> from ultralytics.utils.downloads import delete_dsstore
- >>> delete_dsstore("path/to/dir")
-
- Notes:
- ".DS_Store" files are created by the Apple operating system and contain metadata about folders and files. They
- are hidden system files and can cause issues when transferring files between different operating systems.
- """
- for file in files_to_delete:
- matches = list(Path(path).rglob(file))
- LOGGER.info(f"Deleting {file} files: {matches}")
- for f in matches:
- f.unlink()
-
-
-def zip_directory(
- directory: str | Path,
- compress: bool = True,
- exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
- progress: bool = True,
-) -> Path:
- """
- Zip the contents of a directory, excluding specified files.
-
- The resulting zip file is named after the directory and placed alongside it.
-
- Args:
- directory (str | Path): The path to the directory to be zipped.
- compress (bool): Whether to compress the files while zipping.
- exclude (tuple, optional): A tuple of filename strings to be excluded.
- progress (bool, optional): Whether to display a progress bar.
-
- Returns:
- (Path): The path to the resulting zip file.
-
- Examples:
- >>> from ultralytics.utils.downloads import zip_directory
- >>> file = zip_directory("path/to/dir")
- """
- from zipfile import ZIP_DEFLATED, ZIP_STORED, ZipFile
-
- delete_dsstore(directory)
- directory = Path(directory)
- if not directory.is_dir():
- raise FileNotFoundError(f"Directory '{directory}' does not exist.")
-
- # Zip with progress bar
- files = [f for f in directory.rglob("*") if f.is_file() and all(x not in f.name for x in exclude)] # files to zip
- zip_file = directory.with_suffix(".zip")
- compression = ZIP_DEFLATED if compress else ZIP_STORED
- with ZipFile(zip_file, "w", compression) as f:
- for file in TQDM(files, desc=f"Zipping {directory} to {zip_file}...", unit="files", disable=not progress):
- f.write(file, file.relative_to(directory))
-
- return zip_file # return path to zip file
-
-
-def unzip_file(
- file: str | Path,
- path: str | Path | None = None,
- exclude: tuple[str, ...] = (".DS_Store", "__MACOSX"),
- exist_ok: bool = False,
- progress: bool = True,
-) -> Path:
- """
- Unzip a *.zip file to the specified path, excluding specified files.
-
- If the zipfile does not contain a single top-level directory, the function will create a new
- directory with the same name as the zipfile (without the extension) to extract its contents.
- If a path is not provided, the function will use the parent directory of the zipfile as the default path.
-
- Args:
- file (str | Path): The path to the zipfile to be extracted.
- path (str | Path, optional): The path to extract the zipfile to.
- exclude (tuple, optional): A tuple of filename strings to be excluded.
- exist_ok (bool, optional): Whether to overwrite existing contents if they exist.
- progress (bool, optional): Whether to display a progress bar.
-
- Returns:
- (Path): The path to the directory where the zipfile was extracted.
-
- Raises:
- BadZipFile: If the provided file does not exist or is not a valid zipfile.
-
- Examples:
- >>> from ultralytics.utils.downloads import unzip_file
- >>> directory = unzip_file("path/to/file.zip")
- """
- from zipfile import BadZipFile, ZipFile, is_zipfile
-
- if not (Path(file).exists() and is_zipfile(file)):
- raise BadZipFile(f"File '{file}' does not exist or is a bad zip file.")
- if path is None:
- path = Path(file).parent # default path
-
- # Unzip the file contents
- with ZipFile(file) as zipObj:
- files = [f for f in zipObj.namelist() if all(x not in f for x in exclude)]
- top_level_dirs = {Path(f).parts[0] for f in files}
-
- # Decide to unzip directly or unzip into a directory
- unzip_as_dir = len(top_level_dirs) == 1 # (len(files) > 1 and not files[0].endswith("/"))
- if unzip_as_dir:
- # Zip has 1 top-level directory
- extract_path = path # i.e. ../datasets
- path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/
- else:
- # Zip has multiple files at top level
- path = extract_path = Path(path) / Path(file).stem # i.e. extract multiple files to ../datasets/coco8/
-
- # Check if destination directory already exists and contains files
- if path.exists() and any(path.iterdir()) and not exist_ok:
- # If it exists and is not empty, return the path without unzipping
- LOGGER.warning(f"Skipping {file} unzip as destination directory {path} is not empty.")
- return path
-
- for f in TQDM(files, desc=f"Unzipping {file} to {Path(path).resolve()}...", unit="files", disable=not progress):
- # Ensure the file is within the extract_path to avoid path traversal security vulnerability
- if ".." in Path(f).parts:
- LOGGER.warning(f"Potentially insecure file path: {f}, skipping extraction.")
- continue
- zipObj.extract(f, extract_path)
-
- return path # return unzip dir
-
-
-def check_disk_space(
- file_bytes: int,
- path: str | Path = Path.cwd(),
- sf: float = 1.5,
- hard: bool = True,
-) -> bool:
- """
- Check if there is sufficient disk space to download and store a file.
-
- Args:
- file_bytes (int): The file size in bytes.
- path (str | Path, optional): The path or drive to check the available free space on.
- sf (float, optional): Safety factor, the multiplier for the required free space.
- hard (bool, optional): Whether to raise an error on insufficient disk space.
-
- Returns:
- (bool): True if there is sufficient disk space, False otherwise.
- """
- total, used, free = shutil.disk_usage(path) # bytes
- if file_bytes * sf < free:
- return True # sufficient space
-
- # Insufficient space
- text = (
- f"Insufficient free disk space {free / (1 << 30):.3f} GB < {file_bytes * sf / (1 << 30):.3f} GB required. "
- f"Please free {(file_bytes * sf - free) / (1 << 30):.3f} GB additional disk space and try again."
- )
- if hard:
- raise MemoryError(text)
- LOGGER.warning(text)
- return False
-
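The arithmetic above is just a safety-factor comparison; an illustration with made-up numbers and no real disk query:

file_bytes = 2 * (1 << 30)     # hypothetical 2 GiB download
sf = 1.5                       # require 1.5x the file size to be free
free = 10 * (1 << 30)          # pretend 10 GiB are free
print(file_bytes * sf < free)  # True -> enough space, download proceeds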
-
-def get_google_drive_file_info(link: str) -> tuple[str, str | None]:
- """
- Retrieve the direct download link and filename for a shareable Google Drive file link.
-
- Args:
- link (str): The shareable link of the Google Drive file.
-
- Returns:
- url (str): Direct download URL for the Google Drive file.
- filename (str | None): Original filename of the Google Drive file. If filename extraction fails, returns None.
-
- Examples:
- >>> from ultralytics.utils.downloads import get_google_drive_file_info
- >>> link = "https://drive.google.com/file/d/1cqT-cJgANNrhIHCrEufUYhQ4RqiWG_lJ/view?usp=drive_link"
- >>> url, filename = get_google_drive_file_info(link)
- """
- import requests # scoped as slow import
-
- file_id = link.split("/d/")[1].split("/view", 1)[0]
- drive_url = f"https://drive.google.com/uc?export=download&id={file_id}"
- filename = None
-
- # Start session
- with requests.Session() as session:
- response = session.get(drive_url, stream=True)
- if "quota exceeded" in str(response.content.lower()):
- raise ConnectionError(
- emojis(
- f"❌ Google Drive file download quota exceeded. "
- f"Please try again later or download this file manually at {link}."
- )
- )
- for k, v in response.cookies.items():
- if k.startswith("download_warning"):
- drive_url += f"&confirm={v}" # v is token
- if cd := response.headers.get("content-disposition"):
- filename = re.findall('filename="(.+)"', cd)[0]
- return drive_url, filename
-
-
-def safe_download(
- url: str | Path,
- file: str | Path | None = None,
- dir: str | Path | None = None,
- unzip: bool = True,
- delete: bool = False,
- curl: bool = False,
- retry: int = 3,
- min_bytes: float = 1e0,
- exist_ok: bool = False,
- progress: bool = True,
-) -> Path | str:
- """
- Download files from a URL with options for retrying, unzipping, and deleting the downloaded file. Enhanced with
- robust partial download detection using Content-Length validation.
-
- Args:
- url (str): The URL of the file to be downloaded.
- file (str, optional): The filename of the downloaded file.
- If not provided, the file will be saved with the same name as the URL.
- dir (str | Path, optional): The directory to save the downloaded file.
- If not provided, the file will be saved in the current working directory.
- unzip (bool, optional): Whether to unzip the downloaded file.
- delete (bool, optional): Whether to delete the downloaded file after unzipping.
- curl (bool, optional): Whether to use curl command line tool for downloading.
- retry (int, optional): The number of times to retry the download in case of failure.
- min_bytes (float, optional): The minimum number of bytes that the downloaded file should have, to be considered
- a successful download.
- exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
- progress (bool, optional): Whether to display a progress bar during the download.
-
- Returns:
- (Path | str): The path to the downloaded file or extracted directory.
-
- Examples:
- >>> from ultralytics.utils.downloads import safe_download
- >>> link = "https://ultralytics.com/assets/bus.jpg"
- >>> path = safe_download(link)
- """
- gdrive = url.startswith("https://drive.google.com/") # check if the URL is a Google Drive link
- if gdrive:
- url, file = get_google_drive_file_info(url)
-
- f = Path(dir or ".") / (file or url2file(url)) # URL converted to filename
- if "://" not in str(url) and Path(url).is_file(): # URL exists ('://' check required in Windows Python<3.10)
- f = Path(url) # filename
- elif not f.is_file(): # URL and file do not exist
- uri = (url if gdrive else clean_url(url)).replace( # cleaned and aliased url
- "https://github.com/ultralytics/assets/releases/download/v0.0.0/",
- "https://ultralytics.com/assets/", # assets alias
- )
- desc = f"Downloading {uri} to '{f}'"
- f.parent.mkdir(parents=True, exist_ok=True) # make directory if missing
- curl_installed = shutil.which("curl")
- for i in range(retry + 1):
- try:
- if (curl or i > 0) and curl_installed: # curl download with retry, continue
- s = "sS" * (not progress) # silent
- r = subprocess.run(["curl", "-#", f"-{s}L", url, "-o", f, "--retry", "3", "-C", "-"]).returncode
- assert r == 0, f"Curl return value {r}"
- expected_size = None # Can't get size with curl
- else: # urllib download
- with request.urlopen(url) as response:
- expected_size = int(response.getheader("Content-Length", 0))
- if i == 0 and expected_size > 1048576:
- check_disk_space(expected_size, path=f.parent)
- buffer_size = max(8192, min(1048576, expected_size // 1000)) if expected_size else 8192
- with TQDM(
- total=expected_size,
- desc=desc,
- disable=not progress,
- unit="B",
- unit_scale=True,
- unit_divisor=1024,
- ) as pbar:
- with open(f, "wb") as f_opened:
- while True:
- data = response.read(buffer_size)
- if not data:
- break
- f_opened.write(data)
- pbar.update(len(data))
-
- if f.exists():
- file_size = f.stat().st_size
- if file_size > min_bytes:
- # Check if download is complete (only if we have expected_size)
- if expected_size and file_size != expected_size:
- LOGGER.warning(
- f"Partial download: {file_size}/{expected_size} bytes ({file_size / expected_size * 100:.1f}%)"
- )
- else:
- break # success
- f.unlink() # remove partial downloads
- except MemoryError:
- raise # Re-raise immediately - no point retrying if insufficient disk space
- except Exception as e:
- if i == 0 and not is_online():
- raise ConnectionError(emojis(f"❌ Download failure for {uri}. Environment is not online.")) from e
- elif i >= retry:
- raise ConnectionError(emojis(f"❌ Download failure for {uri}. Retry limit reached.")) from e
- LOGGER.warning(f"Download failure, retrying {i + 1}/{retry} {uri}...")
-
- if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
- from zipfile import is_zipfile
-
- unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
- if is_zipfile(f):
- unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
- elif f.suffix in {".tar", ".gz"}:
- LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
- subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
- if delete:
- f.unlink() # remove zip
- return unzip_dir
- return f
-
-
-def get_github_assets(
- repo: str = "ultralytics/assets",
- version: str = "latest",
- retry: bool = False,
-) -> tuple[str, list[str]]:
- """
- Retrieve the specified version's tag and assets from a GitHub repository.
-
- If the version is not specified, the function fetches the latest release assets.
-
- Args:
- repo (str, optional): The GitHub repository in the format 'owner/repo'.
- version (str, optional): The release version to fetch assets from.
- retry (bool, optional): Flag to retry the request in case of a failure.
-
- Returns:
- tag (str): The release tag.
- assets (list[str]): A list of asset names.
-
- Examples:
- >>> tag, assets = get_github_assets(repo="ultralytics/assets", version="latest")
- """
- import requests # scoped as slow import
-
- if version != "latest":
- version = f"tags/{version}" # i.e. tags/v6.2
- url = f"https://api.github.com/repos/{repo}/releases/{version}"
- r = requests.get(url) # github api
- if r.status_code != 200 and r.reason != "rate limit exceeded" and retry: # failed and not 403 rate limit exceeded
- r = requests.get(url) # try again
- if r.status_code != 200:
- LOGGER.warning(f"GitHub assets check failure for {url}: {r.status_code} {r.reason}")
- return "", []
- data = r.json()
- return data["tag_name"], [x["name"] for x in data["assets"]] # tag, assets i.e. ['yolo11n.pt', 'yolov8s.pt', ...]
-
-
-def attempt_download_asset(
- file: str | Path,
- repo: str = "ultralytics/assets",
- release: str = "v8.3.0",
- **kwargs,
-) -> str:
- """
- Attempt to download a file from GitHub release assets if it is not found locally.
-
- Args:
- file (str | Path): The filename or file path to be downloaded.
- repo (str, optional): The GitHub repository in the format 'owner/repo'.
- release (str, optional): The specific release version to be downloaded.
- **kwargs (Any): Additional keyword arguments for the download process.
-
- Returns:
- (str): The path to the downloaded file.
-
- Examples:
- >>> file_path = attempt_download_asset("yolo11n.pt", repo="ultralytics/assets", release="latest")
- """
- from ultralytics.utils import SETTINGS # scoped for circular import
-
- # YOLOv3/5u updates
- file = str(file)
- file = checks.check_yolov5u_filename(file)
- file = Path(file.strip().replace("'", ""))
- if file.exists():
- return str(file)
- elif (SETTINGS["weights_dir"] / file).exists():
- return str(SETTINGS["weights_dir"] / file)
- else:
- # URL specified
- name = Path(parse.unquote(str(file))).name # decode '%2F' to '/' etc.
- download_url = f"https://github.com/{repo}/releases/download"
- if str(file).startswith(("http:/", "https:/")): # download
- url = str(file).replace(":/", "://") # Pathlib turns :// -> :/
- file = url2file(name) # parse authentication https://url.com/file.txt?auth...
- if Path(file).is_file():
- LOGGER.info(f"Found {clean_url(url)} locally at {file}") # file already exists
- else:
- safe_download(url=url, file=file, min_bytes=1e5, **kwargs)
-
- elif repo == GITHUB_ASSETS_REPO and name in GITHUB_ASSETS_NAMES:
- safe_download(url=f"{download_url}/{release}/{name}", file=file, min_bytes=1e5, **kwargs)
-
- else:
- tag, assets = get_github_assets(repo, release)
- if not assets:
- tag, assets = get_github_assets(repo) # latest release
- if name in assets:
- safe_download(url=f"{download_url}/{tag}/{name}", file=file, min_bytes=1e5, **kwargs)
-
- return str(file)
-
-
-def download(
- url: str | list[str] | Path,
- dir: Path = Path.cwd(),
- unzip: bool = True,
- delete: bool = False,
- curl: bool = False,
- threads: int = 1,
- retry: int = 3,
- exist_ok: bool = False,
-) -> None:
- """
- Download files from specified URLs to a given directory.
-
- Supports concurrent downloads if multiple threads are specified.
-
- Args:
-        url (str | list[str] | Path): The URL or list of URLs of the files to be downloaded.
- dir (Path, optional): The directory where the files will be saved.
- unzip (bool, optional): Flag to unzip the files after downloading.
- delete (bool, optional): Flag to delete the zip files after extraction.
- curl (bool, optional): Flag to use curl for downloading.
- threads (int, optional): Number of threads to use for concurrent downloads.
- retry (int, optional): Number of retries in case of download failure.
- exist_ok (bool, optional): Whether to overwrite existing contents during unzipping.
-
- Examples:
- >>> download("https://ultralytics.com/assets/example.zip", dir="path/to/dir", unzip=True)
- """
- dir = Path(dir)
- dir.mkdir(parents=True, exist_ok=True) # make directory
- urls = [url] if isinstance(url, (str, Path)) else url
- if threads > 1:
- LOGGER.info(f"Downloading {len(urls)} file(s) with {threads} threads to {dir}...")
- with ThreadPool(threads) as pool:
- pool.map(
- lambda x: safe_download(
- url=x[0],
- dir=x[1],
- unzip=unzip,
- delete=delete,
- curl=curl,
- retry=retry,
- exist_ok=exist_ok,
- progress=True,
- ),
- zip(urls, repeat(dir)),
- )
- pool.close()
- pool.join()
- else:
- for u in urls:
- safe_download(url=u, dir=dir, unzip=unzip, delete=delete, curl=curl, retry=retry, exist_ok=exist_ok)
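
A minimal usage sketch of the two removed entry points, assuming the public asset URLs referenced in the docstrings above; the URLs and target directories are illustrative only. `download()` simply dispatches each URL to `safe_download()`, optionally through a thread pool.

# Sketch: single-file and multi-threaded use of the removed download helpers (URLs/paths illustrative).
from pathlib import Path

from ultralytics.utils.downloads import download, safe_download

# One file: download, unzip if it is an archive, keep the archive afterwards
path = safe_download("https://ultralytics.com/assets/coco8.zip", dir=Path("datasets"), unzip=True, delete=False)

# Several files: each URL is handed to safe_download() through a thread pool
download(
    ["https://ultralytics.com/assets/bus.jpg", "https://ultralytics.com/assets/zidane.jpg"],
    dir=Path("assets"),
    threads=2,
)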
diff --git a/ultralytics/utils/errors.py b/ultralytics/utils/errors.py
deleted file mode 100644
index 036c23e..0000000
--- a/ultralytics/utils/errors.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.utils import emojis
-
-
-class HUBModelError(Exception):
- """
- Exception raised when a model cannot be found or retrieved from ultralytics HUB.
-
- This custom exception is used specifically for handling errors related to model fetching in Ultralytics YOLO.
- The error message is processed to include emojis for better user experience.
-
- Attributes:
- message (str): The error message displayed when the exception is raised.
-
- Methods:
- __init__: Initialize the HUBModelError with a custom message.
-
- Examples:
- >>> try:
- ... # Code that might fail to find a model
- ... raise HUBModelError("Custom model not found message")
- ... except HUBModelError as e:
- ... print(e) # Displays the emoji-enhanced error message
- """
-
- def __init__(self, message: str = "Model not found. Please check model URL and try again."):
- """
- Initialize a HUBModelError exception.
-
- This exception is raised when a requested model is not found or cannot be retrieved from ultralytics HUB.
- The message is processed to include emojis for better user experience.
-
- Args:
- message (str, optional): The error message to display when the exception is raised.
-
- Examples:
- >>> try:
- ... raise HUBModelError("Custom model error message")
- ... except HUBModelError as e:
- ... print(e)
- """
- super().__init__(emojis(message))
diff --git a/ultralytics/utils/events.py b/ultralytics/utils/events.py
deleted file mode 100644
index d267911..0000000
--- a/ultralytics/utils/events.py
+++ /dev/null
@@ -1,115 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import json
-import random
-import time
-from pathlib import Path
-from threading import Thread
-from urllib.request import Request, urlopen
-
-from ultralytics import SETTINGS, __version__
-from ultralytics.utils import ARGV, ENVIRONMENT, GIT, IS_PIP_PACKAGE, ONLINE, PYTHON_VERSION, RANK, TESTS_RUNNING
-from ultralytics.utils.downloads import GITHUB_ASSETS_NAMES
-from ultralytics.utils.torch_utils import get_cpu_info
-
-
-def _post(url: str, data: dict, timeout: float = 5.0) -> None:
- """Send a one-shot JSON POST request."""
- try:
- body = json.dumps(data, separators=(",", ":")).encode() # compact JSON
- req = Request(url, data=body, headers={"Content-Type": "application/json"})
- urlopen(req, timeout=timeout).close()
- except Exception:
- pass
-
-
-class Events:
- """
- Collect and send anonymous usage analytics with rate-limiting.
-
- Event collection and transmission are enabled when sync is enabled in settings, the current process is rank -1 or 0,
- tests are not running, the environment is online, and the installation source is either pip or the official
- Ultralytics GitHub repository.
-
- Attributes:
- url (str): Measurement Protocol endpoint for receiving anonymous events.
- events (list[dict]): In-memory queue of event payloads awaiting transmission.
- rate_limit (float): Minimum time in seconds between POST requests.
- t (float): Timestamp of the last transmission in seconds since the epoch.
- metadata (dict): Static metadata describing runtime, installation source, and environment.
- enabled (bool): Flag indicating whether analytics collection is active.
-
- Methods:
- __init__: Initialize the event queue, rate limiter, and runtime metadata.
- __call__: Queue an event and trigger a non-blocking send when the rate limit elapses.
- """
-
- url = "https://www.google-analytics.com/mp/collect?measurement_id=G-X8NCJYTQXM&api_secret=QLQrATrNSwGRFRLE-cbHJw"
-
- def __init__(self) -> None:
- """Initialize the Events instance with queue, rate limiter, and environment metadata."""
- self.events = [] # pending events
- self.rate_limit = 30.0 # rate limit (seconds)
- self.t = 0.0 # last send timestamp (seconds)
- self.metadata = {
- "cli": Path(ARGV[0]).name == "yolo",
- "install": "git" if GIT.is_repo else "pip" if IS_PIP_PACKAGE else "other",
- "python": PYTHON_VERSION.rsplit(".", 1)[0], # i.e. 3.13
- "CPU": get_cpu_info(),
- # "GPU": get_gpu_info(index=0) if cuda else None,
- "version": __version__,
- "env": ENVIRONMENT,
- "session_id": round(random.random() * 1e15),
- "engagement_time_msec": 1000,
- }
- self.enabled = (
- SETTINGS["sync"]
- and RANK in {-1, 0}
- and not TESTS_RUNNING
- and ONLINE
- and (IS_PIP_PACKAGE or GIT.origin == "https://github.com/ultralytics/ultralytics.git")
- )
-
- def __call__(self, cfg, device=None) -> None:
- """
- Queue an event and flush the queue asynchronously when the rate limit elapses.
-
- Args:
- cfg (IterableSimpleNamespace): The configuration object containing mode and task information.
- device (torch.device | str, optional): The device type (e.g., 'cpu', 'cuda').
- """
- if not self.enabled:
- # Events disabled, do nothing
- return
-
- # Attempt to enqueue a new event
- if len(self.events) < 25: # Queue limited to 25 events to bound memory and traffic
- params = {
- **self.metadata,
- "task": cfg.task,
- "model": cfg.model if cfg.model in GITHUB_ASSETS_NAMES else "custom",
- "device": str(device),
- }
- if cfg.mode == "export":
- params["format"] = cfg.format
- self.events.append({"name": cfg.mode, "params": params})
-
- # Check rate limit and return early if under limit
- t = time.time()
- if (t - self.t) < self.rate_limit:
- return
-
-        # Rate limit elapsed: send a snapshot of queued events in a background thread
- payload_events = list(self.events) # snapshot to avoid race with queue reset
- Thread(
- target=_post,
- args=(self.url, {"client_id": SETTINGS["uuid"], "events": payload_events}), # SHA-256 anonymized
- daemon=True,
- ).start()
-
- # Reset queue and rate limit timer
- self.events = []
- self.t = t
-
-
-events = Events()
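
The queue-and-flush pattern implemented by `Events.__call__` above can be isolated in a small sketch. The class below is a simplified stand-in, not the Ultralytics API, and the `send` callable is a hypothetical placeholder for the real Measurement Protocol POST.

# Simplified stand-in for the rate-limited event queue used by Events above (not the Ultralytics API).
import time
from threading import Thread

class RateLimitedQueue:
    def __init__(self, rate_limit=30.0, max_events=25):
        self.events, self.rate_limit, self.max_events, self.t = [], rate_limit, max_events, 0.0

    def __call__(self, event: dict, send) -> None:
        if len(self.events) < self.max_events:
            self.events.append(event)  # cap the queue to bound memory and traffic
        now = time.time()
        if now - self.t < self.rate_limit:
            return  # under the rate limit: keep queueing
        Thread(target=send, args=(list(self.events),), daemon=True).start()  # send a snapshot, non-blocking
        self.events, self.t = [], now  # reset queue and timer

queue = RateLimitedQueue()
queue({"name": "train", "params": {"task": "detect"}}, send=print)  # first call flushes immediately (t starts at 0)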
diff --git a/ultralytics/utils/export/__init__.py b/ultralytics/utils/export/__init__.py
deleted file mode 100644
index 5e028e6..0000000
--- a/ultralytics/utils/export/__init__.py
+++ /dev/null
@@ -1,239 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import json
-from pathlib import Path
-
-import torch
-
-from ultralytics.utils import IS_JETSON, LOGGER
-
-from .imx import torch2imx # noqa
-
-
-def torch2onnx(
- torch_model: torch.nn.Module,
- im: torch.Tensor,
- onnx_file: str,
- opset: int = 14,
- input_names: list[str] = ["images"],
- output_names: list[str] = ["output0"],
- dynamic: bool | dict = False,
-) -> None:
- """
- Export a PyTorch model to ONNX format.
-
- Args:
- torch_model (torch.nn.Module): The PyTorch model to export.
- im (torch.Tensor): Example input tensor for the model.
- onnx_file (str): Path to save the exported ONNX file.
- opset (int): ONNX opset version to use for export.
- input_names (list[str]): List of input tensor names.
- output_names (list[str]): List of output tensor names.
- dynamic (bool | dict, optional): Whether to enable dynamic axes.
-
- Notes:
- Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
- """
- torch.onnx.export(
- torch_model,
- im,
- onnx_file,
- verbose=False,
- opset_version=opset,
- do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
- input_names=input_names,
- output_names=output_names,
- dynamic_axes=dynamic or None,
- )
-
-
-def onnx2engine(
- onnx_file: str,
- engine_file: str | None = None,
- workspace: int | None = None,
- half: bool = False,
- int8: bool = False,
- dynamic: bool = False,
- shape: tuple[int, int, int, int] = (1, 3, 640, 640),
- dla: int | None = None,
- dataset=None,
- metadata: dict | None = None,
- verbose: bool = False,
- prefix: str = "",
-) -> None:
- """
- Export a YOLO model to TensorRT engine format.
-
- Args:
- onnx_file (str): Path to the ONNX file to be converted.
- engine_file (str, optional): Path to save the generated TensorRT engine file.
- workspace (int, optional): Workspace size in GB for TensorRT.
- half (bool, optional): Enable FP16 precision.
- int8 (bool, optional): Enable INT8 precision.
- dynamic (bool, optional): Enable dynamic input shapes.
- shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
- dla (int, optional): DLA core to use (Jetson devices only).
- dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
- metadata (dict, optional): Metadata to include in the engine file.
- verbose (bool, optional): Enable verbose logging.
- prefix (str, optional): Prefix for log messages.
-
- Raises:
- ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
- RuntimeError: If the ONNX file cannot be parsed.
-
- Notes:
- TensorRT version compatibility is handled for workspace size and engine building.
- INT8 calibration requires a dataset and generates a calibration cache.
- Metadata is serialized and written to the engine file if provided.
- """
- import tensorrt as trt # noqa
-
- engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
-
- logger = trt.Logger(trt.Logger.INFO)
- if verbose:
- logger.min_severity = trt.Logger.Severity.VERBOSE
-
- # Engine builder
- builder = trt.Builder(logger)
- config = builder.create_builder_config()
- workspace_bytes = int((workspace or 0) * (1 << 30))
- is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
- if is_trt10 and workspace_bytes > 0:
- config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
- elif workspace_bytes > 0: # TensorRT versions 7, 8
- config.max_workspace_size = workspace_bytes
- flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
- network = builder.create_network(flag)
- half = builder.platform_has_fast_fp16 and half
- int8 = builder.platform_has_fast_int8 and int8
-
- # Optionally switch to DLA if enabled
- if dla is not None:
- if not IS_JETSON:
- raise ValueError("DLA is only available on NVIDIA Jetson devices")
- LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
- if not half and not int8:
- raise ValueError(
- "DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
- )
- config.default_device_type = trt.DeviceType.DLA
- config.DLA_core = int(dla)
- config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
-
- # Read ONNX file
- parser = trt.OnnxParser(network, logger)
- if not parser.parse_from_file(onnx_file):
- raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
-
- # Network inputs
- inputs = [network.get_input(i) for i in range(network.num_inputs)]
- outputs = [network.get_output(i) for i in range(network.num_outputs)]
- for inp in inputs:
- LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
- for out in outputs:
- LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
-
- if dynamic:
- profile = builder.create_optimization_profile()
- min_shape = (1, shape[1], 32, 32) # minimum input shape
- max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape
- for inp in inputs:
- profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
- config.add_optimization_profile(profile)
- if int8:
- config.set_calibration_profile(profile)
-
- LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
- if int8:
- config.set_flag(trt.BuilderFlag.INT8)
- config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
-
- class EngineCalibrator(trt.IInt8Calibrator):
- """
- Custom INT8 calibrator for TensorRT engine optimization.
-
- This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
- using a dataset. It handles batch generation, caching, and calibration algorithm selection.
-
- Attributes:
- dataset: Dataset for calibration.
- data_iter: Iterator over the calibration dataset.
- algo (trt.CalibrationAlgoType): Calibration algorithm type.
- batch (int): Batch size for calibration.
- cache (Path): Path to save the calibration cache.
-
- Methods:
- get_algorithm: Get the calibration algorithm to use.
- get_batch_size: Get the batch size to use for calibration.
- get_batch: Get the next batch to use for calibration.
- read_calibration_cache: Use existing cache instead of calibrating again.
- write_calibration_cache: Write calibration cache to disk.
- """
-
- def __init__(
- self,
- dataset, # ultralytics.data.build.InfiniteDataLoader
- cache: str = "",
- ) -> None:
- """Initialize the INT8 calibrator with dataset and cache path."""
- trt.IInt8Calibrator.__init__(self)
- self.dataset = dataset
- self.data_iter = iter(dataset)
- self.algo = (
- trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2
- if dla is not None
- else trt.CalibrationAlgoType.MINMAX_CALIBRATION
- )
- self.batch = dataset.batch_size
- self.cache = Path(cache)
-
- def get_algorithm(self) -> trt.CalibrationAlgoType:
- """Get the calibration algorithm to use."""
- return self.algo
-
- def get_batch_size(self) -> int:
- """Get the batch size to use for calibration."""
- return self.batch or 1
-
- def get_batch(self, names) -> list[int] | None:
- """Get the next batch to use for calibration, as a list of device memory pointers."""
- try:
- im0s = next(self.data_iter)["img"] / 255.0
- im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
- return [int(im0s.data_ptr())]
- except StopIteration:
- # Return None to signal to TensorRT there is no calibration data remaining
- return None
-
- def read_calibration_cache(self) -> bytes | None:
- """Use existing cache instead of calibrating again, otherwise, implicitly return None."""
- if self.cache.exists() and self.cache.suffix == ".cache":
- return self.cache.read_bytes()
-
- def write_calibration_cache(self, cache: bytes) -> None:
- """Write calibration cache to disk."""
- _ = self.cache.write_bytes(cache)
-
- # Load dataset w/ builder (for batching) and calibrate
- config.int8_calibrator = EngineCalibrator(
- dataset=dataset,
- cache=str(Path(onnx_file).with_suffix(".cache")),
- )
-
- elif half:
- config.set_flag(trt.BuilderFlag.FP16)
-
- # Write file
- build = builder.build_serialized_network if is_trt10 else builder.build_engine
- with build(network, config) as engine, open(engine_file, "wb") as t:
- # Metadata
- if metadata is not None:
- meta = json.dumps(metadata)
- t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
- t.write(meta.encode())
- # Model
- t.write(engine if is_trt10 else engine.serialize())
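
A hedged sketch of chaining the two helpers above: the model and file names are illustrative, and a CUDA machine with TensorRT installed is assumed. The real Exporter performs additional preparation (layer fusion, export-mode flags) that is omitted here.

# Sketch: ONNX export followed by TensorRT engine build (illustrative names; TensorRT + CUDA assumed).
import torch

from ultralytics import YOLO
from ultralytics.utils.export import onnx2engine, torch2onnx

model = YOLO("yolo11n.pt").model.eval()  # underlying nn.Module
im = torch.zeros(1, 3, 640, 640)  # example input used for tracing
torch2onnx(model, im, "yolo11n.onnx", opset=14)
onnx2engine("yolo11n.onnx", "yolo11n.engine", half=True, shape=(1, 3, 640, 640))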
diff --git a/ultralytics/utils/export/__pycache__/__init__.cpython-310.pyc b/ultralytics/utils/export/__pycache__/__init__.cpython-310.pyc
deleted file mode 100644
index bb56cb5..0000000
Binary files a/ultralytics/utils/export/__pycache__/__init__.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/export/__pycache__/imx.cpython-310.pyc b/ultralytics/utils/export/__pycache__/imx.cpython-310.pyc
deleted file mode 100644
index 7bb0716..0000000
Binary files a/ultralytics/utils/export/__pycache__/imx.cpython-310.pyc and /dev/null differ
diff --git a/ultralytics/utils/export/imx.py b/ultralytics/utils/export/imx.py
deleted file mode 100644
index a72ea31..0000000
--- a/ultralytics/utils/export/imx.py
+++ /dev/null
@@ -1,289 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import subprocess
-import types
-from pathlib import Path
-
-import torch
-
-from ultralytics.nn.modules import Detect, Pose
-from ultralytics.utils import LOGGER
-from ultralytics.utils.tal import make_anchors
-from ultralytics.utils.torch_utils import copy_attr
-
-
-class FXModel(torch.nn.Module):
- """
- A custom model class for torch.fx compatibility.
-
- This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
- manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
- copying.
-
- Attributes:
- model (nn.Module): The original model's layers.
- """
-
- def __init__(self, model, imgsz=(640, 640)):
- """
- Initialize the FXModel.
-
- Args:
- model (nn.Module): The original model to wrap for torch.fx compatibility.
- imgsz (tuple[int, int]): The input image size (height, width). Default is (640, 640).
- """
- super().__init__()
- copy_attr(self, model)
- # Explicitly set `model` since `copy_attr` somehow does not copy it.
- self.model = model.model
- self.imgsz = imgsz
-
- def forward(self, x):
- """
- Forward pass through the model.
-
- This method performs the forward pass through the model, handling the dependencies between layers and saving
- intermediate outputs.
-
- Args:
- x (torch.Tensor): The input tensor to the model.
-
- Returns:
- (torch.Tensor): The output tensor from the model.
- """
- y = [] # outputs
- for m in self.model:
- if m.f != -1: # if not from previous layer
- # from earlier layers
- x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
- if isinstance(m, Detect):
- m._inference = types.MethodType(_inference, m) # bind method to Detect
- m.anchors, m.strides = (
- x.transpose(0, 1)
- for x in make_anchors(
- torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
- )
- )
- if type(m) is Pose:
-                m.forward = types.MethodType(pose_forward, m)  # bind method to Pose
- x = m(x) # run
- y.append(x) # save output
- return x
-
-
-def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
- """Decode boxes and cls scores for imx object detection."""
- x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
- box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
- dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
- return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
-
-
-def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """Forward pass for imx pose estimation, including keypoint decoding."""
- bs = x[0].shape[0] # batch size
- kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
- x = Detect.forward(self, x)
- pred_kpt = self.kpts_decode(bs, kpt)
- return (*x, pred_kpt.permute(0, 2, 1))
-
-
-class NMSWrapper(torch.nn.Module):
- """Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
-
- def __init__(
- self,
- model: torch.nn.Module,
- score_threshold: float = 0.001,
- iou_threshold: float = 0.7,
- max_detections: int = 300,
- task: str = "detect",
- ):
- """
- Initialize NMSWrapper with PyTorch Module and NMS parameters.
-
- Args:
- model (torch.nn.Module): Model instance.
- score_threshold (float): Score threshold for non-maximum suppression.
- iou_threshold (float): Intersection over union threshold for non-maximum suppression.
- max_detections (int): The number of detections to return.
- task (str): Task type, either 'detect' or 'pose'.
- """
- super().__init__()
- self.model = model
- self.score_threshold = score_threshold
- self.iou_threshold = iou_threshold
- self.max_detections = max_detections
- self.task = task
-
- def forward(self, images):
- """Forward pass with model inference and NMS post-processing."""
- from sony_custom_layers.pytorch import multiclass_nms_with_indices
-
- # model inference
- outputs = self.model(images)
- boxes, scores = outputs[0], outputs[1]
- nms_outputs = multiclass_nms_with_indices(
- boxes=boxes,
- scores=scores,
- score_threshold=self.score_threshold,
- iou_threshold=self.iou_threshold,
- max_detections=self.max_detections,
- )
- if self.task == "pose":
- kpts = outputs[2] # (bs, max_detections, kpts 17*3)
- out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
- return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
- return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
-
-
-def torch2imx(
- model: torch.nn.Module,
- file: Path | str,
- conf: float,
- iou: float,
- max_det: int,
- metadata: dict | None = None,
- gptq: bool = False,
- dataset=None,
- prefix: str = "",
-):
- """
- Export YOLO model to IMX format for deployment on Sony IMX500 devices.
-
- This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it
- to IMX format compatible with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n
- models for detection and pose estimation tasks.
-
- Args:
- model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
- file (Path | str): Output file path for the exported model.
- conf (float): Confidence threshold for NMS post-processing.
- iou (float): IoU threshold for NMS post-processing.
- max_det (int): Maximum number of detections to return.
- metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
- gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization.
- If False, uses standard Post Training Quantization. Defaults to False.
- dataset (optional): Representative dataset for quantization calibration. Defaults to None.
- prefix (str, optional): Logging prefix string. Defaults to "".
-
- Returns:
- f (Path): Path to the exported IMX model directory
-
- Raises:
- ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
-
-    Examples:
-        >>> from pathlib import Path
-        >>> from ultralytics import YOLO
-        >>> model = YOLO("yolo11n.pt")
-        >>> path = torch2imx(model.model, Path("model.imx"), conf=0.25, iou=0.45, max_det=300)
-
-    Notes:
- - Requires model_compression_toolkit, onnx, edgemdt_tpc, and sony_custom_layers packages
- - Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
- - Output includes quantized ONNX model, IMX binary, and labels.txt file
- """
- import model_compression_toolkit as mct
- import onnx
- from edgemdt_tpc import get_target_platform_capabilities
-
- LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
-
- def representative_dataset_gen(dataloader=dataset):
- for batch in dataloader:
- img = batch["img"]
- img = img / 255.0
- yield [img]
-
- tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
-
- bit_cfg = mct.core.BitWidthConfig()
- if "C2PSA" in model.__str__(): # YOLO11
- if model.task == "detect":
- layer_names = ["sub", "mul_2", "add_14", "cat_21"]
- weights_memory = 2585350.2439
- n_layers = 238 # 238 layers for fused YOLO11n
- elif model.task == "pose":
- layer_names = ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"]
- weights_memory = 2437771.67
- n_layers = 257 # 257 layers for fused YOLO11n-pose
- else: # YOLOv8
- if model.task == "detect":
- layer_names = ["sub", "mul", "add_6", "cat_17"]
- weights_memory = 2550540.8
- n_layers = 168 # 168 layers for fused YOLOv8n
- elif model.task == "pose":
- layer_names = ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"]
- weights_memory = 2482451.85
-            n_layers = 187  # 187 layers for fused YOLOv8n-pose
-
- # Check if the model has the expected number of layers
- if len(list(model.modules())) != n_layers:
- raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
-
- for layer_name in layer_names:
- bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
-
- config = mct.core.CoreConfig(
- mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
- quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
- bit_width_config=bit_cfg,
- )
-
- resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
-
- quant_model = (
- mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization
- model=model,
- representative_data_gen=representative_dataset_gen,
- target_resource_utilization=resource_utilization,
- gptq_config=mct.gptq.get_pytorch_gptq_config(
- n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
- ),
- core_config=config,
- target_platform_capabilities=tpc,
- )[0]
- if gptq
- else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization
- in_module=model,
- representative_data_gen=representative_dataset_gen,
- target_resource_utilization=resource_utilization,
- core_config=config,
- target_platform_capabilities=tpc,
- )[0]
- )
-
- quant_model = NMSWrapper(
- model=quant_model,
- score_threshold=conf or 0.001,
- iou_threshold=iou,
- max_detections=max_det,
- task=model.task,
- )
-
- f = Path(str(file).replace(file.suffix, "_imx_model"))
- f.mkdir(exist_ok=True)
-    onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx"))  # quantized ONNX path inside the IMX output directory
- mct.exporter.pytorch_export_model(
- model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
- )
-
- model_onnx = onnx.load(onnx_model) # load onnx model
- for k, v in metadata.items():
- meta = model_onnx.metadata_props.add()
- meta.key, meta.value = k, str(v)
-
- onnx.save(model_onnx, onnx_model)
-
- subprocess.run(
- ["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
- check=True,
- )
-
- # Needed for imx models.
- with open(f / "labels.txt", "w", encoding="utf-8") as file:
- file.writelines([f"{name}\n" for _, name in model.names.items()])
-
- return f
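
To make the keypoint handling in `NMSWrapper.forward` above concrete: for the pose task, NMS returns per-image indices into the original predictions, and the matching keypoints are selected with `torch.gather` so their order matches the surviving boxes. The shapes below are illustrative.

# Sketch of the keypoint gather used by NMSWrapper for the pose task (shapes illustrative).
import torch

bs, n_preds, max_det, kpt_dim = 2, 100, 10, 51  # 51 = 17 keypoints * 3 (x, y, visibility)
kpts = torch.randn(bs, n_preds, kpt_dim)  # keypoints for every prediction
indices = torch.randint(0, n_preds, (bs, max_det))  # indices kept by NMS, per image
out_kpts = torch.gather(kpts, 1, indices.unsqueeze(-1).expand(-1, -1, kpt_dim))
print(out_kpts.shape)  # torch.Size([2, 10, 51])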
diff --git a/ultralytics/utils/files.py b/ultralytics/utils/files.py
deleted file mode 100644
index e7bce39..0000000
--- a/ultralytics/utils/files.py
+++ /dev/null
@@ -1,223 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import contextlib
-import glob
-import os
-import shutil
-import tempfile
-from contextlib import contextmanager
-from datetime import datetime
-from pathlib import Path
-
-
-class WorkingDirectory(contextlib.ContextDecorator):
- """
- A context manager and decorator for temporarily changing the working directory.
-
- This class allows for the temporary change of the working directory using a context manager or decorator.
- It ensures that the original working directory is restored after the context or decorated function completes.
-
- Attributes:
- dir (Path | str): The new directory to switch to.
- cwd (Path): The original current working directory before the switch.
-
- Methods:
- __enter__: Changes the current directory to the specified directory.
- __exit__: Restores the original working directory on context exit.
-
- Examples:
- Using as a context manager:
- >>> with WorkingDirectory('/path/to/new/dir'):
- >>> # Perform operations in the new directory
- >>> pass
-
- Using as a decorator:
- >>> @WorkingDirectory('/path/to/new/dir')
- >>> def some_function():
- >>> # Perform operations in the new directory
- >>> pass
- """
-
- def __init__(self, new_dir: str | Path):
- """Initialize the WorkingDirectory context manager with the target directory."""
- self.dir = new_dir # new dir
- self.cwd = Path.cwd().resolve() # current dir
-
- def __enter__(self):
- """Change the current working directory to the specified directory upon entering the context."""
- os.chdir(self.dir)
-
- def __exit__(self, exc_type, exc_val, exc_tb): # noqa
- """Restore the original working directory when exiting the context."""
- os.chdir(self.cwd)
-
-
-@contextmanager
-def spaces_in_path(path: str | Path):
- """
- Context manager to handle paths with spaces in their names.
-
- If a path contains spaces, it replaces them with underscores, copies the file/directory to the new path, executes
- the context code block, then copies the file/directory back to its original location.
-
- Args:
- path (str | Path): The original path that may contain spaces.
-
- Yields:
- (Path | str): Temporary path with spaces replaced by underscores if spaces were present, otherwise the
- original path.
-
- Examples:
- >>> with spaces_in_path('/path/with spaces') as new_path:
- >>> # Your code here
- >>> pass
- """
- # If path has spaces, replace them with underscores
- if " " in str(path):
- string = isinstance(path, str) # input type
- path = Path(path)
-
- # Create a temporary directory and construct the new path
- with tempfile.TemporaryDirectory() as tmp_dir:
- tmp_path = Path(tmp_dir) / path.name.replace(" ", "_")
-
- # Copy file/directory
- if path.is_dir():
- shutil.copytree(path, tmp_path)
- elif path.is_file():
- tmp_path.parent.mkdir(parents=True, exist_ok=True)
- shutil.copy2(path, tmp_path)
-
- try:
- # Yield the temporary path
- yield str(tmp_path) if string else tmp_path
-
- finally:
- # Copy file/directory back
- if tmp_path.is_dir():
- shutil.copytree(tmp_path, path, dirs_exist_ok=True)
- elif tmp_path.is_file():
- shutil.copy2(tmp_path, path) # Copy back the file
-
- else:
- # If there are no spaces, just yield the original path
- yield path
-
-
-def increment_path(path: str | Path, exist_ok: bool = False, sep: str = "", mkdir: bool = False) -> Path:
- """
- Increment a file or directory path, i.e., runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
-
- If the path exists and `exist_ok` is not True, the path will be incremented by appending a number and `sep` to
- the end of the path. If the path is a file, the file extension will be preserved. If the path is a directory, the
- number will be appended directly to the end of the path.
-
- Args:
- path (str | Path): Path to increment.
- exist_ok (bool, optional): If True, the path will not be incremented and returned as-is.
- sep (str, optional): Separator to use between the path and the incrementation number.
- mkdir (bool, optional): Create a directory if it does not exist.
-
- Returns:
- (Path): Incremented path.
-
- Examples:
- Increment a directory path:
- >>> from pathlib import Path
- >>> path = Path("runs/exp")
- >>> new_path = increment_path(path)
- >>> print(new_path)
- runs/exp2
-
- Increment a file path:
- >>> path = Path("runs/exp/results.txt")
- >>> new_path = increment_path(path)
- >>> print(new_path)
- runs/exp/results2.txt
- """
- path = Path(path) # os-agnostic
- if path.exists() and not exist_ok:
- path, suffix = (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
-
-        # Append an incrementing suffix until an unused path is found
- for n in range(2, 9999):
- p = f"{path}{sep}{n}{suffix}" # increment path
- if not os.path.exists(p):
- break
- path = Path(p)
-
- if mkdir:
- path.mkdir(parents=True, exist_ok=True) # make directory
-
- return path
-
-
-def file_age(path: str | Path = __file__) -> int:
- """Return days since the last modification of the specified file."""
- dt = datetime.now() - datetime.fromtimestamp(Path(path).stat().st_mtime) # delta
- return dt.days # + dt.seconds / 86400 # fractional days
-
-
-def file_date(path: str | Path = __file__) -> str:
- """Return the file modification date in 'YYYY-M-D' format."""
- t = datetime.fromtimestamp(Path(path).stat().st_mtime)
- return f"{t.year}-{t.month}-{t.day}"
-
-
-def file_size(path: str | Path) -> float:
- """Return the size of a file or directory in megabytes (MB)."""
- if isinstance(path, (str, Path)):
- mb = 1 << 20 # bytes to MiB (1024 ** 2)
- path = Path(path)
- if path.is_file():
- return path.stat().st_size / mb
- elif path.is_dir():
- return sum(f.stat().st_size for f in path.glob("**/*") if f.is_file()) / mb
- return 0.0
-
-
-def get_latest_run(search_dir: str = ".") -> str:
- """Return the path to the most recent 'last.pt' file in the specified directory for resuming training."""
- last_list = glob.glob(f"{search_dir}/**/last*.pt", recursive=True)
- return max(last_list, key=os.path.getctime) if last_list else ""
-
-
-def update_models(model_names: tuple = ("yolo11n.pt",), source_dir: Path = Path("."), update_names: bool = False):
- """
- Update and re-save specified YOLO models in an 'updated_models' subdirectory.
-
- Args:
- model_names (tuple, optional): Model filenames to update.
- source_dir (Path, optional): Directory containing models and target subdirectory.
- update_names (bool, optional): Update model names from a data YAML.
-
- Examples:
- Update specified YOLO models and save them in 'updated_models' subdirectory:
- >>> from ultralytics.utils.files import update_models
- >>> model_names = ("yolo11n.pt", "yolov8s.pt")
- >>> update_models(model_names, source_dir=Path("/models"), update_names=True)
- """
- from ultralytics import YOLO
- from ultralytics.nn.autobackend import default_class_names
-
- target_dir = source_dir / "updated_models"
- target_dir.mkdir(parents=True, exist_ok=True) # Ensure target directory exists
-
- for model_name in model_names:
- model_path = source_dir / model_name
- print(f"Loading model from {model_path}")
-
- # Load model
- model = YOLO(model_path)
- model.half()
- if update_names: # update model names from a dataset YAML
- model.model.names = default_class_names("coco8.yaml")
-
- # Define new save path
- save_path = target_dir / model_name
-
- # Save model using model.save()
- print(f"Re-saving {model_name} model to {save_path}")
- model.save(save_path)
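
A short sketch of the removed path utilities used together; the paths are illustrative.

# Sketch: the removed path helpers in combination (paths illustrative).
from pathlib import Path

from ultralytics.utils.files import WorkingDirectory, increment_path, spaces_in_path

save_dir = increment_path(Path("runs/exp"), mkdir=True)  # runs/exp, then runs/exp2, runs/exp3, ...

with WorkingDirectory(save_dir):  # temporarily chdir into the run directory
    Path("results.txt").write_text("done")

with spaces_in_path("datasets/my data") as safe_path:  # operate on a space-free temporary copy
    print(safe_path)  # e.g. /tmp/.../my_data; changes are copied back on exit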
diff --git a/ultralytics/utils/git.py b/ultralytics/utils/git.py
deleted file mode 100644
index 9cfc951..0000000
--- a/ultralytics/utils/git.py
+++ /dev/null
@@ -1,139 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from functools import cached_property
-from pathlib import Path
-
-
-class GitRepo:
- """
- Represent a local Git repository and expose branch, commit, and remote metadata.
-
- This class discovers the repository root by searching for a .git entry from the given path upward, resolves the
- actual .git directory (including worktrees), and reads Git metadata directly from on-disk files. It does not
- invoke the git binary and therefore works in restricted environments. All metadata properties are resolved
- lazily and cached; construct a new instance to refresh state.
-
- Attributes:
- root (Path | None): Repository root directory containing the .git entry; None if not in a repository.
- gitdir (Path | None): Resolved .git directory path; handles worktrees; None if unresolved.
- head (str | None): Raw contents of HEAD; a SHA for detached HEAD or "ref: " for branch heads.
- is_repo (bool): Whether the provided path resides inside a Git repository.
- branch (str | None): Current branch name when HEAD points to a branch; None for detached HEAD or non-repo.
- commit (str | None): Current commit SHA for HEAD; None if not determinable.
- origin (str | None): URL of the "origin" remote as read from gitdir/config; None if unset or unavailable.
-
- Examples:
- Initialize from the current working directory and read metadata
- >>> from pathlib import Path
- >>> repo = GitRepo(Path.cwd())
- >>> repo.is_repo
- True
- >>> repo.branch, repo.commit[:7], repo.origin
- ('main', '1a2b3c4', 'https://example.com/owner/repo.git')
-
- Notes:
- - Resolves metadata by reading files: HEAD, packed-refs, and config; no subprocess calls are used.
- - Caches properties on first access using cached_property; recreate the object to reflect repository changes.
- """
-
- def __init__(self, path: Path = Path(__file__).resolve()):
- """
- Initialize a Git repository context by discovering the repository root from a starting path.
-
- Args:
- path (Path, optional): File or directory path used as the starting point to locate the repository root.
- """
- self.root = self._find_root(path)
- self.gitdir = self._gitdir(self.root) if self.root else None
-
- @staticmethod
- def _find_root(p: Path) -> Path | None:
- """Return repo root or None."""
- return next((d for d in [p] + list(p.parents) if (d / ".git").exists()), None)
-
- @staticmethod
- def _gitdir(root: Path) -> Path | None:
- """Resolve actual .git directory (handles worktrees)."""
- g = root / ".git"
- if g.is_dir():
- return g
- if g.is_file():
- t = g.read_text(errors="ignore").strip()
- if t.startswith("gitdir:"):
- return (root / t.split(":", 1)[1].strip()).resolve()
- return None
-
- def _read(self, p: Path | None) -> str | None:
- """Read and strip file if exists."""
- return p.read_text(errors="ignore").strip() if p and p.exists() else None
-
- @cached_property
- def head(self) -> str | None:
- """HEAD file contents."""
- return self._read(self.gitdir / "HEAD" if self.gitdir else None)
-
- def _ref_commit(self, ref: str) -> str | None:
- """Commit for ref (handles packed-refs)."""
- rf = self.gitdir / ref
- s = self._read(rf)
- if s:
- return s
- pf = self.gitdir / "packed-refs"
- b = pf.read_bytes().splitlines() if pf.exists() else []
- tgt = ref.encode()
- for line in b:
- if line[:1] in (b"#", b"^") or b" " not in line:
- continue
- sha, name = line.split(b" ", 1)
- if name.strip() == tgt:
- return sha.decode()
- return None
-
- @property
- def is_repo(self) -> bool:
- """True if inside a git repo."""
- return self.gitdir is not None
-
- @cached_property
- def branch(self) -> str | None:
- """Current branch or None."""
- if not self.is_repo or not self.head or not self.head.startswith("ref: "):
- return None
- ref = self.head[5:].strip()
- return ref[len("refs/heads/") :] if ref.startswith("refs/heads/") else ref
-
- @cached_property
- def commit(self) -> str | None:
- """Current commit SHA or None."""
- if not self.is_repo or not self.head:
- return None
- return self._ref_commit(self.head[5:].strip()) if self.head.startswith("ref: ") else self.head
-
- @cached_property
- def origin(self) -> str | None:
- """Origin URL or None."""
- if not self.is_repo:
- return None
- cfg = self.gitdir / "config"
- remote, url = None, None
- for s in (self._read(cfg) or "").splitlines():
- t = s.strip()
- if t.startswith("[") and t.endswith("]"):
- remote = t.lower()
- elif t.lower().startswith("url =") and remote == '[remote "origin"]':
- url = t.split("=", 1)[1].strip()
- break
- return url
-
-
-if __name__ == "__main__":
- import time
-
- g = GitRepo()
- if g.is_repo:
- t0 = time.perf_counter()
- print(f"repo={g.root}\nbranch={g.branch}\ncommit={g.commit}\norigin={g.origin}")
- dt = (time.perf_counter() - t0) * 1000
- print(f"\n⏱️ Profiling: total {dt:.3f} ms")
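
To make the packed-refs handling in `GitRepo._ref_commit` above concrete, here is a self-contained sketch with a fabricated packed-refs payload: loose refs live as files under .git/refs/..., while packed refs are "<sha> <refname>" lines in .git/packed-refs.

# Sketch of the packed-refs lookup in GitRepo._ref_commit ('#' marks comments, '^' marks peeled-tag lines).
packed = b"""# pack-refs with: peeled fully-peeled sorted
1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0b refs/heads/main
^deadbeefdeadbeefdeadbeefdeadbeefdeadbeef
"""
target = b"refs/heads/main"
for line in packed.splitlines():
    if line[:1] in (b"#", b"^") or b" " not in line:
        continue  # skip comments and peeled-tag entries
    sha, name = line.split(b" ", 1)
    if name.strip() == target:
        print(sha.decode())  # 1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d7e8f9a0b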
diff --git a/ultralytics/utils/instance.py b/ultralytics/utils/instance.py
deleted file mode 100644
index bfc1d54..0000000
--- a/ultralytics/utils/instance.py
+++ /dev/null
@@ -1,505 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from collections import abc
-from itertools import repeat
-from numbers import Number
-
-import numpy as np
-
-from .ops import ltwh2xywh, ltwh2xyxy, resample_segments, xywh2ltwh, xywh2xyxy, xyxy2ltwh, xyxy2xywh
-
-
-def _ntuple(n):
- """Create a function that converts input to n-tuple by repeating singleton values."""
-
- def parse(x):
- """Parse input to return n-tuple by repeating singleton values n times."""
- return x if isinstance(x, abc.Iterable) else tuple(repeat(x, n))
-
- return parse
-
-
-to_2tuple = _ntuple(2)
-to_4tuple = _ntuple(4)
-
-# `xyxy` means left top and right bottom
-# `xywh` means center x, center y and width, height (YOLO format)
-# `ltwh` means left top and width, height (COCO format)
-_formats = ["xyxy", "xywh", "ltwh"]
-
-__all__ = ("Bboxes", "Instances") # tuple or list
-
-
-class Bboxes:
- """
- A class for handling bounding boxes in multiple formats.
-
- The class supports various bounding box formats like 'xyxy', 'xywh', and 'ltwh' and provides methods for format
- conversion, scaling, and area calculation. Bounding box data should be provided as numpy arrays.
-
- Attributes:
- bboxes (np.ndarray): The bounding boxes stored in a 2D numpy array with shape (N, 4).
- format (str): The format of the bounding boxes ('xyxy', 'xywh', or 'ltwh').
-
- Methods:
- convert: Convert bounding box format from one type to another.
- areas: Calculate the area of bounding boxes.
- mul: Multiply bounding box coordinates by scale factor(s).
- add: Add offset to bounding box coordinates.
- concatenate: Concatenate multiple Bboxes objects.
-
- Examples:
- Create bounding boxes in YOLO format
- >>> bboxes = Bboxes(np.array([[100, 50, 150, 100]]), format="xywh")
- >>> bboxes.convert("xyxy")
- >>> print(bboxes.areas())
-
- Notes:
- This class does not handle normalization or denormalization of bounding boxes.
- """
-
- def __init__(self, bboxes: np.ndarray, format: str = "xyxy") -> None:
- """
- Initialize the Bboxes class with bounding box data in a specified format.
-
- Args:
- bboxes (np.ndarray): Array of bounding boxes with shape (N, 4) or (4,).
- format (str): Format of the bounding boxes, one of 'xyxy', 'xywh', or 'ltwh'.
- """
- assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
- bboxes = bboxes[None, :] if bboxes.ndim == 1 else bboxes
- assert bboxes.ndim == 2
- assert bboxes.shape[1] == 4
- self.bboxes = bboxes
- self.format = format
-
- def convert(self, format: str) -> None:
- """
- Convert bounding box format from one type to another.
-
- Args:
- format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
- """
- assert format in _formats, f"Invalid bounding box format: {format}, format must be one of {_formats}"
- if self.format == format:
- return
- elif self.format == "xyxy":
- func = xyxy2xywh if format == "xywh" else xyxy2ltwh
- elif self.format == "xywh":
- func = xywh2xyxy if format == "xyxy" else xywh2ltwh
- else:
- func = ltwh2xyxy if format == "xyxy" else ltwh2xywh
- self.bboxes = func(self.bboxes)
- self.format = format
-
- def areas(self) -> np.ndarray:
- """Calculate the area of bounding boxes."""
- return (
- (self.bboxes[:, 2] - self.bboxes[:, 0]) * (self.bboxes[:, 3] - self.bboxes[:, 1]) # format xyxy
- if self.format == "xyxy"
- else self.bboxes[:, 3] * self.bboxes[:, 2] # format xywh or ltwh
- )
-
- def mul(self, scale: int | tuple | list) -> None:
- """
- Multiply bounding box coordinates by scale factor(s).
-
- Args:
- scale (int | tuple | list): Scale factor(s) for four coordinates. If int, the same scale is applied to
- all coordinates.
- """
- if isinstance(scale, Number):
- scale = to_4tuple(scale)
- assert isinstance(scale, (tuple, list))
- assert len(scale) == 4
- self.bboxes[:, 0] *= scale[0]
- self.bboxes[:, 1] *= scale[1]
- self.bboxes[:, 2] *= scale[2]
- self.bboxes[:, 3] *= scale[3]
-
- def add(self, offset: int | tuple | list) -> None:
- """
- Add offset to bounding box coordinates.
-
- Args:
- offset (int | tuple | list): Offset(s) for four coordinates. If int, the same offset is applied to
- all coordinates.
- """
- if isinstance(offset, Number):
- offset = to_4tuple(offset)
- assert isinstance(offset, (tuple, list))
- assert len(offset) == 4
- self.bboxes[:, 0] += offset[0]
- self.bboxes[:, 1] += offset[1]
- self.bboxes[:, 2] += offset[2]
- self.bboxes[:, 3] += offset[3]
-
- def __len__(self) -> int:
- """Return the number of bounding boxes."""
- return len(self.bboxes)
-
- @classmethod
- def concatenate(cls, boxes_list: list[Bboxes], axis: int = 0) -> Bboxes:
- """
- Concatenate a list of Bboxes objects into a single Bboxes object.
-
- Args:
- boxes_list (list[Bboxes]): A list of Bboxes objects to concatenate.
- axis (int, optional): The axis along which to concatenate the bounding boxes.
-
- Returns:
- (Bboxes): A new Bboxes object containing the concatenated bounding boxes.
-
- Notes:
- The input should be a list or tuple of Bboxes objects.
- """
- assert isinstance(boxes_list, (list, tuple))
- if not boxes_list:
- return cls(np.empty(0))
- assert all(isinstance(box, Bboxes) for box in boxes_list)
-
- if len(boxes_list) == 1:
- return boxes_list[0]
- return cls(np.concatenate([b.bboxes for b in boxes_list], axis=axis))
-
- def __getitem__(self, index: int | np.ndarray | slice) -> Bboxes:
- """
- Retrieve a specific bounding box or a set of bounding boxes using indexing.
-
- Args:
- index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired bounding boxes.
-
- Returns:
- (Bboxes): A new Bboxes object containing the selected bounding boxes.
-
- Notes:
- When using boolean indexing, make sure to provide a boolean array with the same length as the number of
- bounding boxes.
- """
- if isinstance(index, int):
- return Bboxes(self.bboxes[index].reshape(1, -1))
- b = self.bboxes[index]
- assert b.ndim == 2, f"Indexing on Bboxes with {index} failed to return a matrix!"
- return Bboxes(b)
-
-
-class Instances:
- """
- Container for bounding boxes, segments, and keypoints of detected objects in an image.
-
- This class provides a unified interface for handling different types of object annotations including bounding
- boxes, segmentation masks, and keypoints. It supports various operations like scaling, normalization, clipping,
- and format conversion.
-
- Attributes:
- _bboxes (Bboxes): Internal object for handling bounding box operations.
- keypoints (np.ndarray): Keypoints with shape (N, 17, 3) in format (x, y, visible).
- normalized (bool): Flag indicating whether the bounding box coordinates are normalized.
- segments (np.ndarray): Segments array with shape (N, M, 2) after resampling.
-
- Methods:
- convert_bbox: Convert bounding box format.
- scale: Scale coordinates by given factors.
- denormalize: Convert normalized coordinates to absolute coordinates.
- normalize: Convert absolute coordinates to normalized coordinates.
- add_padding: Add padding to coordinates.
- flipud: Flip coordinates vertically.
- fliplr: Flip coordinates horizontally.
- clip: Clip coordinates to stay within image boundaries.
- remove_zero_area_boxes: Remove boxes with zero area.
- update: Update instance variables.
- concatenate: Concatenate multiple Instances objects.
-
- Examples:
- Create instances with bounding boxes and segments
- >>> instances = Instances(
- ... bboxes=np.array([[10, 10, 30, 30], [20, 20, 40, 40]]),
- ... segments=[np.array([[5, 5], [10, 10]]), np.array([[15, 15], [20, 20]])],
- ... keypoints=np.array([[[5, 5, 1], [10, 10, 1]], [[15, 15, 1], [20, 20, 1]]]),
- ... )
- """
-
- def __init__(
- self,
- bboxes: np.ndarray,
- segments: np.ndarray = None,
- keypoints: np.ndarray = None,
- bbox_format: str = "xywh",
- normalized: bool = True,
- ) -> None:
- """
- Initialize the Instances object with bounding boxes, segments, and keypoints.
-
- Args:
- bboxes (np.ndarray): Bounding boxes with shape (N, 4).
- segments (np.ndarray, optional): Segmentation masks.
- keypoints (np.ndarray, optional): Keypoints with shape (N, 17, 3) in format (x, y, visible).
- bbox_format (str): Format of bboxes.
- normalized (bool): Whether the coordinates are normalized.
- """
- self._bboxes = Bboxes(bboxes=bboxes, format=bbox_format)
- self.keypoints = keypoints
- self.normalized = normalized
- self.segments = segments
-
- def convert_bbox(self, format: str) -> None:
- """
- Convert bounding box format.
-
- Args:
- format (str): Target format for conversion, one of 'xyxy', 'xywh', or 'ltwh'.
- """
- self._bboxes.convert(format=format)
-
- @property
- def bbox_areas(self) -> np.ndarray:
- """Calculate the area of bounding boxes."""
- return self._bboxes.areas()
-
- def scale(self, scale_w: float, scale_h: float, bbox_only: bool = False):
- """
- Scale coordinates by given factors.
-
- Args:
- scale_w (float): Scale factor for width.
- scale_h (float): Scale factor for height.
- bbox_only (bool, optional): Whether to scale only bounding boxes.
- """
- self._bboxes.mul(scale=(scale_w, scale_h, scale_w, scale_h))
- if bbox_only:
- return
- self.segments[..., 0] *= scale_w
- self.segments[..., 1] *= scale_h
- if self.keypoints is not None:
- self.keypoints[..., 0] *= scale_w
- self.keypoints[..., 1] *= scale_h
-
- def denormalize(self, w: int, h: int) -> None:
- """
- Convert normalized coordinates to absolute coordinates.
-
- Args:
- w (int): Image width.
- h (int): Image height.
- """
- if not self.normalized:
- return
- self._bboxes.mul(scale=(w, h, w, h))
- self.segments[..., 0] *= w
- self.segments[..., 1] *= h
- if self.keypoints is not None:
- self.keypoints[..., 0] *= w
- self.keypoints[..., 1] *= h
- self.normalized = False
-
- def normalize(self, w: int, h: int) -> None:
- """
- Convert absolute coordinates to normalized coordinates.
-
- Args:
- w (int): Image width.
- h (int): Image height.
- """
- if self.normalized:
- return
- self._bboxes.mul(scale=(1 / w, 1 / h, 1 / w, 1 / h))
- self.segments[..., 0] /= w
- self.segments[..., 1] /= h
- if self.keypoints is not None:
- self.keypoints[..., 0] /= w
- self.keypoints[..., 1] /= h
- self.normalized = True
-
- def add_padding(self, padw: int, padh: int) -> None:
- """
- Add padding to coordinates.
-
- Args:
- padw (int): Padding width.
- padh (int): Padding height.
- """
- assert not self.normalized, "you should add padding with absolute coordinates."
- self._bboxes.add(offset=(padw, padh, padw, padh))
- self.segments[..., 0] += padw
- self.segments[..., 1] += padh
- if self.keypoints is not None:
- self.keypoints[..., 0] += padw
- self.keypoints[..., 1] += padh
-
- def __getitem__(self, index: int | np.ndarray | slice) -> Instances:
- """
- Retrieve a specific instance or a set of instances using indexing.
-
- Args:
- index (int | slice | np.ndarray): The index, slice, or boolean array to select the desired instances.
-
- Returns:
- (Instances): A new Instances object containing the selected boxes, segments, and keypoints if present.
-
- Notes:
- When using boolean indexing, make sure to provide a boolean array with the same length as the number of
- instances.
- """
- segments = self.segments[index] if len(self.segments) else self.segments
- keypoints = self.keypoints[index] if self.keypoints is not None else None
- bboxes = self.bboxes[index]
- bbox_format = self._bboxes.format
- return Instances(
- bboxes=bboxes,
- segments=segments,
- keypoints=keypoints,
- bbox_format=bbox_format,
- normalized=self.normalized,
- )
-
- def flipud(self, h: int) -> None:
- """
- Flip coordinates vertically.
-
- Args:
- h (int): Image height.
- """
- if self._bboxes.format == "xyxy":
- y1 = self.bboxes[:, 1].copy()
- y2 = self.bboxes[:, 3].copy()
- self.bboxes[:, 1] = h - y2
- self.bboxes[:, 3] = h - y1
- else:
- self.bboxes[:, 1] = h - self.bboxes[:, 1]
- self.segments[..., 1] = h - self.segments[..., 1]
- if self.keypoints is not None:
- self.keypoints[..., 1] = h - self.keypoints[..., 1]
-
- def fliplr(self, w: int) -> None:
- """
- Flip coordinates horizontally.
-
- Args:
- w (int): Image width.
- """
- if self._bboxes.format == "xyxy":
- x1 = self.bboxes[:, 0].copy()
- x2 = self.bboxes[:, 2].copy()
- self.bboxes[:, 0] = w - x2
- self.bboxes[:, 2] = w - x1
- else:
- self.bboxes[:, 0] = w - self.bboxes[:, 0]
- self.segments[..., 0] = w - self.segments[..., 0]
- if self.keypoints is not None:
- self.keypoints[..., 0] = w - self.keypoints[..., 0]
-
- def clip(self, w: int, h: int) -> None:
- """
- Clip coordinates to stay within image boundaries.
-
- Args:
- w (int): Image width.
- h (int): Image height.
- """
- ori_format = self._bboxes.format
- self.convert_bbox(format="xyxy")
- self.bboxes[:, [0, 2]] = self.bboxes[:, [0, 2]].clip(0, w)
- self.bboxes[:, [1, 3]] = self.bboxes[:, [1, 3]].clip(0, h)
- if ori_format != "xyxy":
- self.convert_bbox(format=ori_format)
- self.segments[..., 0] = self.segments[..., 0].clip(0, w)
- self.segments[..., 1] = self.segments[..., 1].clip(0, h)
- if self.keypoints is not None:
- # Set out of bounds visibility to zero
- self.keypoints[..., 2][
- (self.keypoints[..., 0] < 0)
- | (self.keypoints[..., 0] > w)
- | (self.keypoints[..., 1] < 0)
- | (self.keypoints[..., 1] > h)
- ] = 0.0
- self.keypoints[..., 0] = self.keypoints[..., 0].clip(0, w)
- self.keypoints[..., 1] = self.keypoints[..., 1].clip(0, h)
-
- def remove_zero_area_boxes(self) -> np.ndarray:
- """
-        Remove zero-area boxes, i.e. boxes that may have zero width or height after clipping.
-
- Returns:
- (np.ndarray): Boolean array indicating which boxes were kept.
- """
- good = self.bbox_areas > 0
- if not all(good):
- self._bboxes = self._bboxes[good]
- if len(self.segments):
- self.segments = self.segments[good]
- if self.keypoints is not None:
- self.keypoints = self.keypoints[good]
- return good
-
-    def update(self, bboxes: np.ndarray, segments: np.ndarray | None = None, keypoints: np.ndarray | None = None) -> None:
- """
- Update instance variables.
-
- Args:
- bboxes (np.ndarray): New bounding boxes.
- segments (np.ndarray, optional): New segments.
- keypoints (np.ndarray, optional): New keypoints.
- """
- self._bboxes = Bboxes(bboxes, format=self._bboxes.format)
- if segments is not None:
- self.segments = segments
- if keypoints is not None:
- self.keypoints = keypoints
-
- def __len__(self) -> int:
- """Return the number of instances."""
- return len(self.bboxes)
-
- @classmethod
- def concatenate(cls, instances_list: list[Instances], axis=0) -> Instances:
- """
- Concatenate a list of Instances objects into a single Instances object.
-
- Args:
- instances_list (list[Instances]): A list of Instances objects to concatenate.
- axis (int, optional): The axis along which the arrays will be concatenated.
-
- Returns:
- (Instances): A new Instances object containing the concatenated bounding boxes, segments, and keypoints
- if present.
-
- Notes:
- The `Instances` objects in the list should have the same properties, such as the format of the bounding
- boxes, whether keypoints are present, and if the coordinates are normalized.
- """
- assert isinstance(instances_list, (list, tuple))
- if not instances_list:
- return cls(np.empty(0))
- assert all(isinstance(instance, Instances) for instance in instances_list)
-
- if len(instances_list) == 1:
- return instances_list[0]
-
- use_keypoint = instances_list[0].keypoints is not None
- bbox_format = instances_list[0]._bboxes.format
- normalized = instances_list[0].normalized
-
- cat_boxes = np.concatenate([ins.bboxes for ins in instances_list], axis=axis)
- seg_len = [b.segments.shape[1] for b in instances_list]
-        if len(frozenset(seg_len)) > 1:  # resample segments if their lengths differ
- max_len = max(seg_len)
- cat_segments = np.concatenate(
- [
- resample_segments(list(b.segments), max_len)
- if len(b.segments)
- else np.zeros((0, max_len, 2), dtype=np.float32) # re-generating empty segments
- for b in instances_list
- ],
- axis=axis,
- )
- else:
- cat_segments = np.concatenate([b.segments for b in instances_list], axis=axis)
- cat_keypoints = np.concatenate([b.keypoints for b in instances_list], axis=axis) if use_keypoint else None
- return cls(cat_boxes, cat_segments, cat_keypoints, bbox_format, normalized)
-
- @property
- def bboxes(self) -> np.ndarray:
- """Return bounding boxes."""
- return self._bboxes.bboxes
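
As a quick illustration of the removed `Instances` API, here is a minimal usage sketch. It assumes the constructor accepts the same keyword arguments used in `__getitem__` above and that `segments`/`keypoints` may be omitted; the box values are made up for the example.

```python
import numpy as np

# Hypothetical usage of the removed Instances container (values are illustrative).
boxes = np.array([[0.5, 0.5, 0.2, 0.4]], dtype=np.float32)  # one normalized xywh box
inst = Instances(bboxes=boxes, bbox_format="xywh", normalized=True)

inst.convert_bbox(format="xyxy")  # switch to corner format
inst.denormalize(w=640, h=480)    # scale to absolute pixel coordinates
inst.fliplr(w=640)                # horizontal-flip augmentation
inst.clip(w=640, h=480)           # keep coordinates inside the image
kept = inst.remove_zero_area_boxes()
print(len(inst), inst.bboxes, kept)
```
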
diff --git a/ultralytics/utils/logger.py b/ultralytics/utils/logger.py
deleted file mode 100644
index 6494ec5..0000000
--- a/ultralytics/utils/logger.py
+++ /dev/null
@@ -1,408 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import logging
-import queue
-import shutil
-import sys
-import threading
-import time
-from datetime import datetime
-from pathlib import Path
-
-from ultralytics.utils import MACOS, RANK
-from ultralytics.utils.checks import check_requirements
-
-# Initialize default log file
-DEFAULT_LOG_PATH = Path("train.log")
-if RANK in {-1, 0} and DEFAULT_LOG_PATH.exists():
- DEFAULT_LOG_PATH.unlink(missing_ok=True)
-
-
-class ConsoleLogger:
- """
- Console output capture with API/file streaming and deduplication.
-
- Captures stdout/stderr output and streams it to either an API endpoint or local file, with intelligent
- deduplication to reduce noise from repetitive console output.
-
- Attributes:
- destination (str | Path): Target destination for streaming (URL or Path object).
- is_api (bool): Whether destination is an API endpoint (True) or local file (False).
- original_stdout: Reference to original sys.stdout for restoration.
- original_stderr: Reference to original sys.stderr for restoration.
- log_queue (queue.Queue): Thread-safe queue for buffering log messages.
- active (bool): Whether console capture is currently active.
- worker_thread (threading.Thread): Background thread for processing log queue.
- last_line (str): Last processed line for deduplication.
- last_time (float): Timestamp of last processed line.
- last_progress_line (str): Last progress bar line for progress deduplication.
- last_was_progress (bool): Whether the last line was a progress bar.
-
- Examples:
- Basic file logging:
- >>> logger = ConsoleLogger("training.log")
- >>> logger.start_capture()
- >>> print("This will be logged")
- >>> logger.stop_capture()
-
- API streaming:
- >>> logger = ConsoleLogger("https://api.example.com/logs")
- >>> logger.start_capture()
- >>> # All output streams to API
- >>> logger.stop_capture()
- """
-
- def __init__(self, destination):
- """
- Initialize with API endpoint or local file path.
-
- Args:
- destination (str | Path): API endpoint URL (http/https) or local file path for streaming output.
- """
- self.destination = destination
- self.is_api = isinstance(destination, str) and destination.startswith(("http://", "https://"))
- if not self.is_api:
- self.destination = Path(destination)
-
- # Console capture
- self.original_stdout = sys.stdout
- self.original_stderr = sys.stderr
- self.log_queue = queue.Queue(maxsize=1000)
- self.active = False
- self.worker_thread = None
-
- # State tracking
- self.last_line = ""
- self.last_time = 0.0
- self.last_progress_line = "" # Track last progress line for deduplication
- self.last_was_progress = False # Track if last line was a progress bar
-
- def start_capture(self):
- """Start capturing console output and redirect stdout/stderr to custom capture objects."""
- if self.active:
- return
-
- self.active = True
- sys.stdout = self._ConsoleCapture(self.original_stdout, self._queue_log)
- sys.stderr = self._ConsoleCapture(self.original_stderr, self._queue_log)
-
- # Hook Ultralytics logger
- try:
- handler = self._LogHandler(self._queue_log)
- logging.getLogger("ultralytics").addHandler(handler)
- except Exception:
- pass
-
- self.worker_thread = threading.Thread(target=self._stream_worker, daemon=True)
- self.worker_thread.start()
-
- def stop_capture(self):
- """Stop capturing console output and restore original stdout/stderr."""
- if not self.active:
- return
-
- self.active = False
- sys.stdout = self.original_stdout
- sys.stderr = self.original_stderr
- self.log_queue.put(None)
-
- def _queue_log(self, text):
- """Queue console text with deduplication and timestamp processing."""
- if not self.active:
- return
-
- current_time = time.time()
-
- # Handle carriage returns and process lines
- if "\r" in text:
- text = text.split("\r")[-1]
-
- lines = text.split("\n")
- if lines and lines[-1] == "":
- lines.pop()
-
- for line in lines:
- line = line.rstrip()
-
-            # Skip partially drawn progress bars (lines still containing thin bar characters)
-            if "─" in line:
- continue
-
- # Deduplicate completed progress bars only if they match the previous progress line
- if " ━━" in line:
- progress_core = line.split(" ━━")[0].strip()
- if progress_core == self.last_progress_line and self.last_was_progress:
- continue
- self.last_progress_line = progress_core
- self.last_was_progress = True
- else:
- # Skip empty line after progress bar
- if not line and self.last_was_progress:
- self.last_was_progress = False
- continue
- self.last_was_progress = False
-
- # General deduplication
- if line == self.last_line and current_time - self.last_time < 0.1:
- continue
-
- self.last_line = line
- self.last_time = current_time
-
- # Add timestamp if needed
- if not line.startswith("[20"):
- timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- line = f"[{timestamp}] {line}"
-
-            # Queue with overflow protection; the line is dropped if the queue cannot accept it
-            self._safe_put(f"{line}\n")
-
- def _safe_put(self, item):
- """Safely put item in queue with overflow handling."""
- try:
- self.log_queue.put_nowait(item)
- return True
- except queue.Full:
- try:
- self.log_queue.get_nowait() # Drop oldest
- self.log_queue.put_nowait(item)
- return True
- except queue.Empty:
- return False
-
- def _stream_worker(self):
- """Background worker for streaming logs to destination."""
- while self.active:
- try:
- log_text = self.log_queue.get(timeout=1)
- if log_text is None:
- break
- self._write_log(log_text)
- except queue.Empty:
- continue
-
- def _write_log(self, text):
- """Write log to API endpoint or local file destination."""
- try:
- if self.is_api:
- import requests # scoped as slow import
-
- payload = {"timestamp": datetime.now().isoformat(), "message": text.strip()}
- requests.post(str(self.destination), json=payload, timeout=5)
- else:
- self.destination.parent.mkdir(parents=True, exist_ok=True)
- with self.destination.open("a", encoding="utf-8") as f:
- f.write(text)
- except Exception as e:
- print(f"Platform logging error: {e}", file=self.original_stderr)
-
- class _ConsoleCapture:
- """Lightweight stdout/stderr capture."""
-
- __slots__ = ("original", "callback")
-
- def __init__(self, original, callback):
- self.original = original
- self.callback = callback
-
- def write(self, text):
- self.original.write(text)
- self.callback(text)
-
- def flush(self):
- self.original.flush()
-
- class _LogHandler(logging.Handler):
- """Lightweight logging handler."""
-
- __slots__ = ("callback",)
-
- def __init__(self, callback):
- super().__init__()
- self.callback = callback
-
- def emit(self, record):
- self.callback(self.format(record) + "\n")
-
-
-class SystemLogger:
- """
- Log dynamic system metrics for training monitoring.
-
- Captures real-time system metrics including CPU, RAM, disk I/O, network I/O, and NVIDIA GPU statistics for
- training performance monitoring and analysis.
-
- Attributes:
- pynvml: NVIDIA pynvml module instance if successfully imported, None otherwise.
- nvidia_initialized (bool): Whether NVIDIA GPU monitoring is available and initialized.
- net_start: Initial network I/O counters for calculating cumulative usage.
- disk_start: Initial disk I/O counters for calculating cumulative usage.
-
- Examples:
- Basic usage:
- >>> logger = SystemLogger()
- >>> metrics = logger.get_metrics()
- >>> print(f"CPU: {metrics['cpu']}%, RAM: {metrics['ram']}%")
- >>> if metrics["gpus"]:
- ... gpu0 = metrics["gpus"]["0"]
- ... print(f"GPU0: {gpu0['usage']}% usage, {gpu0['temp']}°C")
-
- Training loop integration:
- >>> system_logger = SystemLogger()
- >>> for epoch in range(epochs):
- ... # Training code here
- ... metrics = system_logger.get_metrics()
- ... # Log to database/file
- """
-
- def __init__(self):
- """Initialize the system logger."""
- import psutil # scoped as slow import
-
- self.pynvml = None
- self.nvidia_initialized = self._init_nvidia()
- self.net_start = psutil.net_io_counters()
- self.disk_start = psutil.disk_io_counters()
-
- def _init_nvidia(self):
- """Initialize NVIDIA GPU monitoring with pynvml."""
- try:
- assert not MACOS
- check_requirements("nvidia-ml-py>=12.0.0")
- self.pynvml = __import__("pynvml")
- self.pynvml.nvmlInit()
- return True
- except Exception:
- return False
-
- def get_metrics(self):
- """
- Get current system metrics.
-
- Collects comprehensive system metrics including CPU usage, RAM usage, disk I/O statistics,
- network I/O statistics, and GPU metrics (if available). Example output:
-
- ```python
- metrics = {
- "cpu": 45.2,
- "ram": 78.9,
- "disk": {"read_mb": 156.7, "write_mb": 89.3, "used_gb": 256.8},
- "network": {"recv_mb": 157.2, "sent_mb": 89.1},
- "gpus": {
-                "0": {"usage": 95.6, "memory": 85.4, "temp": 72, "power": 285},
-                "1": {"usage": 94.1, "memory": 82.7, "temp": 70, "power": 278},
- },
- }
- ```
-
- - cpu (float): CPU usage percentage (0-100%)
- - ram (float): RAM usage percentage (0-100%)
- - disk (dict):
- - read_mb (float): Cumulative disk read in MB since initialization
- - write_mb (float): Cumulative disk write in MB since initialization
- - used_gb (float): Total disk space used in GB
- - network (dict):
- - recv_mb (float): Cumulative network received in MB since initialization
- - sent_mb (float): Cumulative network sent in MB since initialization
-        - gpus (dict): GPU metrics keyed by device index string (e.g., "0", "1"), each containing:
- - usage (int): GPU utilization percentage (0-100%)
- - memory (float): CUDA memory usage percentage (0-100%)
- - temp (int): GPU temperature in degrees Celsius
- - power (int): GPU power consumption in watts
-
- Returns:
- metrics (dict): System metrics containing 'cpu', 'ram', 'disk', 'network', 'gpus' with respective usage data.
- """
- import psutil # scoped as slow import
-
- net = psutil.net_io_counters()
- disk = psutil.disk_io_counters()
- memory = psutil.virtual_memory()
- disk_usage = shutil.disk_usage("/")
-
- metrics = {
- "cpu": round(psutil.cpu_percent(), 3),
- "ram": round(memory.percent, 3),
- "disk": {
- "read_mb": round((disk.read_bytes - self.disk_start.read_bytes) / (1 << 20), 3),
- "write_mb": round((disk.write_bytes - self.disk_start.write_bytes) / (1 << 20), 3),
- "used_gb": round(disk_usage.used / (1 << 30), 3),
- },
- "network": {
- "recv_mb": round((net.bytes_recv - self.net_start.bytes_recv) / (1 << 20), 3),
- "sent_mb": round((net.bytes_sent - self.net_start.bytes_sent) / (1 << 20), 3),
- },
- "gpus": {},
- }
-
- # Add GPU metrics (NVIDIA only)
- if self.nvidia_initialized:
- metrics["gpus"].update(self._get_nvidia_metrics())
-
- return metrics
-
- def _get_nvidia_metrics(self):
- """Get NVIDIA GPU metrics including utilization, memory, temperature, and power."""
- gpus = {}
- if not self.nvidia_initialized or not self.pynvml:
- return gpus
- try:
- device_count = self.pynvml.nvmlDeviceGetCount()
- for i in range(device_count):
- handle = self.pynvml.nvmlDeviceGetHandleByIndex(i)
- util = self.pynvml.nvmlDeviceGetUtilizationRates(handle)
- memory = self.pynvml.nvmlDeviceGetMemoryInfo(handle)
- temp = self.pynvml.nvmlDeviceGetTemperature(handle, self.pynvml.NVML_TEMPERATURE_GPU)
- power = self.pynvml.nvmlDeviceGetPowerUsage(handle) // 1000
-
- gpus[str(i)] = {
- "usage": round(util.gpu, 3),
- "memory": round((memory.used / memory.total) * 100, 3),
- "temp": temp,
- "power": power,
- }
- except Exception:
- pass
- return gpus
-
-
-if __name__ == "__main__":
- print("SystemLogger Real-time Metrics Monitor")
- print("Press Ctrl+C to stop\n")
-
- logger = SystemLogger()
-
- try:
- while True:
- metrics = logger.get_metrics()
-
- # Clear screen (works on most terminals)
- print("\033[H\033[J", end="")
-
- # Display system metrics
- print(f"CPU: {metrics['cpu']:5.1f}%")
- print(f"RAM: {metrics['ram']:5.1f}%")
- print(f"Disk Read: {metrics['disk']['read_mb']:8.1f} MB")
- print(f"Disk Write: {metrics['disk']['write_mb']:7.1f} MB")
- print(f"Disk Used: {metrics['disk']['used_gb']:8.1f} GB")
- print(f"Net Recv: {metrics['network']['recv_mb']:9.1f} MB")
- print(f"Net Sent: {metrics['network']['sent_mb']:9.1f} MB")
-
- # Display GPU metrics if available
- if metrics["gpus"]:
- print("\nGPU Metrics:")
- for gpu_id, gpu_data in metrics["gpus"].items():
- print(
- f" GPU {gpu_id}: {gpu_data['usage']:3}% | "
- f"Mem: {gpu_data['memory']:5.1f}% | "
- f"Temp: {gpu_data['temp']:2}°C | "
- f"Power: {gpu_data['power']:3}W"
- )
- else:
- print("\nGPU: No NVIDIA GPUs detected")
-
- time.sleep(1)
-
- except KeyboardInterrupt:
- print("\n\nStopped monitoring.")
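
For context, a hedged sketch of how the two removed classes could be combined in a training script: console output is mirrored to a log file while system metrics are appended as JSON lines. The file names and the loop are illustrative, not part of the original module.

```python
import json
import time

console = ConsoleLogger("train.log")  # capture stdout/stderr to a local file
console.start_capture()
try:
    sys_logger = SystemLogger()
    for step in range(3):  # stand-in for a real training loop
        print(f"step {step}: training...")  # captured and streamed to train.log
        with open("system_metrics.jsonl", "a", encoding="utf-8") as f:
            f.write(json.dumps(sys_logger.get_metrics()) + "\n")
        time.sleep(1)
finally:
    console.stop_capture()
```
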
diff --git a/ultralytics/utils/loss.py b/ultralytics/utils/loss.py
deleted file mode 100644
index 95628da..0000000
--- a/ultralytics/utils/loss.py
+++ /dev/null
@@ -1,857 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from typing import Any
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ultralytics.utils.metrics import OKS_SIGMA
-from ultralytics.utils.ops import crop_mask, xywh2xyxy, xyxy2xywh
-from ultralytics.utils.tal import RotatedTaskAlignedAssigner, TaskAlignedAssigner, dist2bbox, dist2rbox, make_anchors
-from ultralytics.utils.torch_utils import autocast
-
-from .metrics import bbox_iou, probiou
-from .tal import bbox2dist
-
-
-class VarifocalLoss(nn.Module):
- """
- Varifocal loss by Zhang et al.
-
- Implements the Varifocal Loss function for addressing class imbalance in object detection by focusing on
- hard-to-classify examples and balancing positive/negative samples.
-
- Attributes:
- gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
- alpha (float): The balancing factor used to address class imbalance.
-
- References:
- https://arxiv.org/abs/2008.13367
- """
-
- def __init__(self, gamma: float = 2.0, alpha: float = 0.75):
- """Initialize the VarifocalLoss class with focusing and balancing parameters."""
- super().__init__()
- self.gamma = gamma
- self.alpha = alpha
-
- def forward(self, pred_score: torch.Tensor, gt_score: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
- """Compute varifocal loss between predictions and ground truth."""
- weight = self.alpha * pred_score.sigmoid().pow(self.gamma) * (1 - label) + gt_score * label
- with autocast(enabled=False):
- loss = (
- (F.binary_cross_entropy_with_logits(pred_score.float(), gt_score.float(), reduction="none") * weight)
- .mean(1)
- .sum()
- )
- return loss
-
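
For reference, with $p = \sigma(\text{pred\_score})$ and $q$ the IoU-aware target score (`gt_score`), the weighting above yields the per-element Varifocal Loss

$$\mathrm{VFL}(p, q) = \begin{cases} -q\,\big(q\log p + (1-q)\log(1-p)\big) & q > 0 \ \text{(foreground)} \\ -\alpha\,p^{\gamma}\,\log(1-p) & q = 0 \ \text{(background)} \end{cases}$$

which is then averaged over classes and summed over the batch, matching `loss.mean(1).sum()` in the code above.
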
-
-class FocalLoss(nn.Module):
- """
-    Focal loss computed on raw logits via binary cross-entropy, i.e. criterion = FocalLoss(gamma=1.5, alpha=0.25).
-
- Implements the Focal Loss function for addressing class imbalance by down-weighting easy examples and focusing
- on hard negatives during training.
-
- Attributes:
- gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
- alpha (torch.Tensor): The balancing factor used to address class imbalance.
- """
-
- def __init__(self, gamma: float = 1.5, alpha: float = 0.25):
- """Initialize FocalLoss class with focusing and balancing parameters."""
- super().__init__()
- self.gamma = gamma
- self.alpha = torch.tensor(alpha)
-
- def forward(self, pred: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
- """Calculate focal loss with modulating factors for class imbalance."""
- loss = F.binary_cross_entropy_with_logits(pred, label, reduction="none")
- # p_t = torch.exp(-loss)
- # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
-
- # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
- pred_prob = pred.sigmoid() # prob from logits
- p_t = label * pred_prob + (1 - label) * (1 - pred_prob)
- modulating_factor = (1.0 - p_t) ** self.gamma
- loss *= modulating_factor
- if (self.alpha > 0).any():
- self.alpha = self.alpha.to(device=pred.device, dtype=pred.dtype)
- alpha_factor = label * self.alpha + (1 - label) * (1 - self.alpha)
- loss *= alpha_factor
- return loss.mean(1).sum()
-
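
Written out, the modulating-factor form above is the standard focal loss (Lin et al., 2017), applied elementwise to sigmoid probabilities and then averaged over classes and summed over the batch:

$$\mathrm{FL}(p_t) = -\,\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t), \qquad p_t = y\,p + (1-y)(1-p), \quad \alpha_t = y\,\alpha + (1-y)(1-\alpha)$$

where $p = \sigma(\text{pred})$ and $y \in \{0, 1\}$ is the label.
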
-
-class DFLoss(nn.Module):
- """Criterion class for computing Distribution Focal Loss (DFL)."""
-
- def __init__(self, reg_max: int = 16) -> None:
- """Initialize the DFL module with regularization maximum."""
- super().__init__()
- self.reg_max = reg_max
-
- def __call__(self, pred_dist: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
- """Return sum of left and right DFL losses from https://ieeexplore.ieee.org/document/9792391."""
- target = target.clamp_(0, self.reg_max - 1 - 0.01)
- tl = target.long() # target left
- tr = tl + 1 # target right
- wl = tr - target # weight left
- wr = 1 - wl # weight right
- return (
- F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
- + F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
- ).mean(-1, keepdim=True)
-
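
A quick numeric check of the bin interpolation performed in `DFLoss.__call__` above (the target value is made up for illustration):

```python
import torch

target = torch.tensor([3.7])  # continuous regression target, illustrative value
tl = target.long()            # left bin  -> 3
tr = tl + 1                   # right bin -> 4
wl = tr - target              # weight on left bin  -> 0.3
wr = 1 - wl                   # weight on right bin -> 0.7
# DFL = wl * CE(pred, tl) + wr * CE(pred, tr); the weighted bin indices recover the
# continuous target: 3 * 0.3 + 4 * 0.7 = 3.7
```
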
-
-class BboxLoss(nn.Module):
- """Criterion class for computing training losses for bounding boxes."""
-
- def __init__(self, reg_max: int = 16):
- """Initialize the BboxLoss module with regularization maximum and DFL settings."""
- super().__init__()
- self.dfl_loss = DFLoss(reg_max) if reg_max > 1 else None
-
- def forward(
- self,
- pred_dist: torch.Tensor,
- pred_bboxes: torch.Tensor,
- anchor_points: torch.Tensor,
- target_bboxes: torch.Tensor,
- target_scores: torch.Tensor,
- target_scores_sum: torch.Tensor,
- fg_mask: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Compute IoU and DFL losses for bounding boxes."""
- weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
- iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)
- loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
-
- # DFL loss
- if self.dfl_loss:
- target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
- loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
- loss_dfl = loss_dfl.sum() / target_scores_sum
- else:
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
-
- return loss_iou, loss_dfl
-
-
-class RotatedBboxLoss(BboxLoss):
- """Criterion class for computing training losses for rotated bounding boxes."""
-
- def __init__(self, reg_max: int):
- """Initialize the RotatedBboxLoss module with regularization maximum and DFL settings."""
- super().__init__(reg_max)
-
- def forward(
- self,
- pred_dist: torch.Tensor,
- pred_bboxes: torch.Tensor,
- anchor_points: torch.Tensor,
- target_bboxes: torch.Tensor,
- target_scores: torch.Tensor,
- target_scores_sum: torch.Tensor,
- fg_mask: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """Compute IoU and DFL losses for rotated bounding boxes."""
- weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
- iou = probiou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
- loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
-
- # DFL loss
- if self.dfl_loss:
- target_ltrb = bbox2dist(anchor_points, xywh2xyxy(target_bboxes[..., :4]), self.dfl_loss.reg_max - 1)
- loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
- loss_dfl = loss_dfl.sum() / target_scores_sum
- else:
- loss_dfl = torch.tensor(0.0).to(pred_dist.device)
-
- return loss_iou, loss_dfl
-
-
-class KeypointLoss(nn.Module):
- """Criterion class for computing keypoint losses."""
-
- def __init__(self, sigmas: torch.Tensor) -> None:
- """Initialize the KeypointLoss class with keypoint sigmas."""
- super().__init__()
- self.sigmas = sigmas
-
- def forward(
- self, pred_kpts: torch.Tensor, gt_kpts: torch.Tensor, kpt_mask: torch.Tensor, area: torch.Tensor
- ) -> torch.Tensor:
- """Calculate keypoint loss factor and Euclidean distance loss for keypoints."""
- d = (pred_kpts[..., 0] - gt_kpts[..., 0]).pow(2) + (pred_kpts[..., 1] - gt_kpts[..., 1]).pow(2)
- kpt_loss_factor = kpt_mask.shape[1] / (torch.sum(kpt_mask != 0, dim=1) + 1e-9)
- # e = d / (2 * (area * self.sigmas) ** 2 + 1e-9) # from formula
- e = d / ((2 * self.sigmas).pow(2) * (area + 1e-9) * 2) # from cocoeval
- return (kpt_loss_factor.view(-1, 1) * ((1 - torch.exp(-e)) * kpt_mask)).mean()
-
-
-class v8DetectionLoss:
- """Criterion class for computing training losses for YOLOv8 object detection."""
-
- def __init__(self, model, tal_topk: int = 10): # model must be de-paralleled
- """Initialize v8DetectionLoss with model parameters and task-aligned assignment settings."""
- device = next(model.parameters()).device # get model device
- h = model.args # hyperparameters
-
- m = model.model[-1] # Detect() module
- self.bce = nn.BCEWithLogitsLoss(reduction="none")
- self.hyp = h
- self.stride = m.stride # model strides
- self.nc = m.nc # number of classes
- self.no = m.nc + m.reg_max * 4
- self.reg_max = m.reg_max
- self.device = device
-
- self.use_dfl = m.reg_max > 1
-
- self.assigner = TaskAlignedAssigner(topk=tal_topk, num_classes=self.nc, alpha=0.5, beta=6.0)
- self.bbox_loss = BboxLoss(m.reg_max).to(device)
- self.proj = torch.arange(m.reg_max, dtype=torch.float, device=device)
-
- def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
- """Preprocess targets by converting to tensor format and scaling coordinates."""
- nl, ne = targets.shape
- if nl == 0:
- out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
- else:
- i = targets[:, 0] # image index
- _, counts = i.unique(return_counts=True)
- counts = counts.to(dtype=torch.int32)
- out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)
- for j in range(batch_size):
- matches = i == j
- if n := matches.sum():
- out[j, :n] = targets[matches, 1:]
- out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor))
- return out
-
- def bbox_decode(self, anchor_points: torch.Tensor, pred_dist: torch.Tensor) -> torch.Tensor:
- """Decode predicted object bounding box coordinates from anchor points and distribution."""
- if self.use_dfl:
- b, a, c = pred_dist.shape # batch, anchors, channels
- pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
- # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
- # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
- return dist2bbox(pred_dist, anchor_points, xywh=False)
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
- loss = torch.zeros(3, device=self.device) # box, cls, dfl
- feats = preds[1] if isinstance(preds, tuple) else preds
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
- )
-
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- batch_size = pred_scores.shape[0]
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- targets = torch.cat((batch["batch_idx"].view(-1, 1), batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-
- # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
- # dfl_conf = pred_distri.view(batch_size, -1, 4, self.reg_max).detach().softmax(-1)
- # dfl_conf = (dfl_conf.amax(-1).mean(-1) + dfl_conf.amax(-1).amin(-1)) / 2
-
- _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- # pred_scores.detach().sigmoid() * 0.8 + dfl_conf.unsqueeze(-1) * 0.2,
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
-
- # Bbox loss
- if fg_mask.sum():
- loss[0], loss[2] = self.bbox_loss(
- pred_distri,
- pred_bboxes,
- anchor_points,
- target_bboxes / stride_tensor,
- target_scores,
- target_scores_sum,
- fg_mask,
- )
-
- loss[0] *= self.hyp.box # box gain
- loss[1] *= self.hyp.cls # cls gain
- loss[2] *= self.hyp.dfl # dfl gain
-
- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
-
-
-class v8SegmentationLoss(v8DetectionLoss):
- """Criterion class for computing training losses for YOLOv8 segmentation."""
-
- def __init__(self, model): # model must be de-paralleled
- """Initialize the v8SegmentationLoss class with model parameters and mask overlap setting."""
- super().__init__(model)
- self.overlap = model.args.overlap_mask
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate and return the combined loss for detection and segmentation."""
- loss = torch.zeros(4, device=self.device) # box, seg, cls, dfl
- feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
- batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
- )
-
- # B, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_masks = pred_masks.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- try:
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
- except RuntimeError as e:
- raise TypeError(
- "ERROR ❌ segment dataset incorrectly formatted or not a segment dataset.\n"
- "This error can occur when incorrectly training a 'segment' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolo11n-seg.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'segment' dataset using 'data=coco8-seg.yaml' "
- "as an example.\nSee https://docs.ultralytics.com/datasets/segment/ for help."
- ) from e
-
- # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
-
- _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
-
- if fg_mask.sum():
- # Bbox loss
- loss[0], loss[3] = self.bbox_loss(
- pred_distri,
- pred_bboxes,
- anchor_points,
- target_bboxes / stride_tensor,
- target_scores,
- target_scores_sum,
- fg_mask,
- )
- # Masks loss
- masks = batch["masks"].to(self.device).float()
- if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
- masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
-
- loss[1] = self.calculate_segmentation_loss(
- fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
- )
-
- # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
- else:
- loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
-
- loss[0] *= self.hyp.box # box gain
- loss[1] *= self.hyp.box # seg gain
- loss[2] *= self.hyp.cls # cls gain
- loss[3] *= self.hyp.dfl # dfl gain
-
- return loss * batch_size, loss.detach() # loss(box, seg, cls, dfl)
-
- @staticmethod
- def single_mask_loss(
- gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
- ) -> torch.Tensor:
- """
- Compute the instance segmentation loss for a single image.
-
- Args:
- gt_mask (torch.Tensor): Ground truth mask of shape (N, H, W), where N is the number of objects.
- pred (torch.Tensor): Predicted mask coefficients of shape (N, 32).
- proto (torch.Tensor): Prototype masks of shape (32, H, W).
- xyxy (torch.Tensor): Ground truth bounding boxes in xyxy format, normalized to [0, 1], of shape (N, 4).
- area (torch.Tensor): Area of each ground truth bounding box of shape (N,).
-
- Returns:
- (torch.Tensor): The calculated mask loss for a single image.
-
- Notes:
- The function uses the equation pred_mask = torch.einsum('in,nhw->ihw', pred, proto) to produce the
- predicted masks from the prototype masks and predicted mask coefficients.
- """
- pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 80, 80) -> (n, 80, 80)
- loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
- return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
-
- def calculate_segmentation_loss(
- self,
- fg_mask: torch.Tensor,
- masks: torch.Tensor,
- target_gt_idx: torch.Tensor,
- target_bboxes: torch.Tensor,
- batch_idx: torch.Tensor,
- proto: torch.Tensor,
- pred_masks: torch.Tensor,
- imgsz: torch.Tensor,
- overlap: bool,
- ) -> torch.Tensor:
- """
- Calculate the loss for instance segmentation.
-
- Args:
- fg_mask (torch.Tensor): A binary tensor of shape (BS, N_anchors) indicating which anchors are positive.
- masks (torch.Tensor): Ground truth masks of shape (BS, H, W) if `overlap` is False, otherwise (BS, ?, H, W).
- target_gt_idx (torch.Tensor): Indexes of ground truth objects for each anchor of shape (BS, N_anchors).
- target_bboxes (torch.Tensor): Ground truth bounding boxes for each anchor of shape (BS, N_anchors, 4).
- batch_idx (torch.Tensor): Batch indices of shape (N_labels_in_batch, 1).
- proto (torch.Tensor): Prototype masks of shape (BS, 32, H, W).
- pred_masks (torch.Tensor): Predicted masks for each anchor of shape (BS, N_anchors, 32).
- imgsz (torch.Tensor): Size of the input image as a tensor of shape (2), i.e., (H, W).
- overlap (bool): Whether the masks in `masks` tensor overlap.
-
- Returns:
- (torch.Tensor): The calculated loss for instance segmentation.
-
- Notes:
- The batch loss can be computed for improved speed at higher memory usage.
- For example, pred_mask can be computed as follows:
- pred_mask = torch.einsum('in,nhw->ihw', pred, proto) # (i, 32) @ (32, 160, 160) -> (i, 160, 160)
- """
- _, _, mask_h, mask_w = proto.shape
- loss = 0
-
- # Normalize to 0-1
- target_bboxes_normalized = target_bboxes / imgsz[[1, 0, 1, 0]]
-
- # Areas of target bboxes
- marea = xyxy2xywh(target_bboxes_normalized)[..., 2:].prod(2)
-
- # Normalize to mask size
- mxyxy = target_bboxes_normalized * torch.tensor([mask_w, mask_h, mask_w, mask_h], device=proto.device)
-
- for i, single_i in enumerate(zip(fg_mask, target_gt_idx, pred_masks, proto, mxyxy, marea, masks)):
- fg_mask_i, target_gt_idx_i, pred_masks_i, proto_i, mxyxy_i, marea_i, masks_i = single_i
- if fg_mask_i.any():
- mask_idx = target_gt_idx_i[fg_mask_i]
- if overlap:
- gt_mask = masks_i == (mask_idx + 1).view(-1, 1, 1)
- gt_mask = gt_mask.float()
- else:
- gt_mask = masks[batch_idx.view(-1) == i][mask_idx]
-
- loss += self.single_mask_loss(
- gt_mask, pred_masks_i[fg_mask_i], proto_i, mxyxy_i[fg_mask_i], marea_i[fg_mask_i]
- )
-
-            # WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
- else:
- loss += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
-
- return loss / fg_mask.sum()
-
-
-class v8PoseLoss(v8DetectionLoss):
- """Criterion class for computing training losses for YOLOv8 pose estimation."""
-
- def __init__(self, model): # model must be de-paralleled
- """Initialize v8PoseLoss with model parameters and keypoint-specific loss functions."""
- super().__init__(model)
- self.kpt_shape = model.model[-1].kpt_shape
- self.bce_pose = nn.BCEWithLogitsLoss()
- is_pose = self.kpt_shape == [17, 3]
- nkpt = self.kpt_shape[0] # number of keypoints
- sigmas = torch.from_numpy(OKS_SIGMA).to(self.device) if is_pose else torch.ones(nkpt, device=self.device) / nkpt
- self.keypoint_loss = KeypointLoss(sigmas=sigmas)
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the total loss and detach it for pose estimation."""
-        loss = torch.zeros(5, device=self.device)  # box, pose (kpt_location), kobj (kpt_visibility), cls, dfl
- feats, pred_kpts = preds if isinstance(preds[0], list) else preds[1]
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
- )
-
- # B, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_kpts = pred_kpts.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # Targets
- batch_size = pred_scores.shape[0]
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
-
- # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
- pred_kpts = self.kpts_decode(anchor_points, pred_kpts.view(batch_size, -1, *self.kpt_shape)) # (b, h*w, 17, 3)
-
- _, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
- pred_scores.detach().sigmoid(),
- (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[3] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
-
- # Bbox loss
- if fg_mask.sum():
- target_bboxes /= stride_tensor
- loss[0], loss[4] = self.bbox_loss(
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
- )
- keypoints = batch["keypoints"].to(self.device).float().clone()
- keypoints[..., 0] *= imgsz[1]
- keypoints[..., 1] *= imgsz[0]
-
- loss[1], loss[2] = self.calculate_keypoints_loss(
- fg_mask, target_gt_idx, keypoints, batch_idx, stride_tensor, target_bboxes, pred_kpts
- )
-
- loss[0] *= self.hyp.box # box gain
- loss[1] *= self.hyp.pose # pose gain
- loss[2] *= self.hyp.kobj # kobj gain
- loss[3] *= self.hyp.cls # cls gain
- loss[4] *= self.hyp.dfl # dfl gain
-
-        return loss * batch_size, loss.detach()  # loss(box, pose, kobj, cls, dfl)
-
- @staticmethod
- def kpts_decode(anchor_points: torch.Tensor, pred_kpts: torch.Tensor) -> torch.Tensor:
- """Decode predicted keypoints to image coordinates."""
- y = pred_kpts.clone()
- y[..., :2] *= 2.0
- y[..., 0] += anchor_points[:, [0]] - 0.5
- y[..., 1] += anchor_points[:, [1]] - 0.5
- return y
-
- def calculate_keypoints_loss(
- self,
- masks: torch.Tensor,
- target_gt_idx: torch.Tensor,
- keypoints: torch.Tensor,
- batch_idx: torch.Tensor,
- stride_tensor: torch.Tensor,
- target_bboxes: torch.Tensor,
- pred_kpts: torch.Tensor,
- ) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Calculate the keypoints loss for the model.
-
- This function calculates the keypoints loss and keypoints object loss for a given batch. The keypoints loss is
- based on the difference between the predicted keypoints and ground truth keypoints. The keypoints object loss is
- a binary classification loss that classifies whether a keypoint is present or not.
-
- Args:
- masks (torch.Tensor): Binary mask tensor indicating object presence, shape (BS, N_anchors).
- target_gt_idx (torch.Tensor): Index tensor mapping anchors to ground truth objects, shape (BS, N_anchors).
- keypoints (torch.Tensor): Ground truth keypoints, shape (N_kpts_in_batch, N_kpts_per_object, kpts_dim).
- batch_idx (torch.Tensor): Batch index tensor for keypoints, shape (N_kpts_in_batch, 1).
- stride_tensor (torch.Tensor): Stride tensor for anchors, shape (N_anchors, 1).
- target_bboxes (torch.Tensor): Ground truth boxes in (x1, y1, x2, y2) format, shape (BS, N_anchors, 4).
- pred_kpts (torch.Tensor): Predicted keypoints, shape (BS, N_anchors, N_kpts_per_object, kpts_dim).
-
- Returns:
- kpts_loss (torch.Tensor): The keypoints loss.
- kpts_obj_loss (torch.Tensor): The keypoints object loss.
- """
- batch_idx = batch_idx.flatten()
- batch_size = len(masks)
-
-        # Find the maximum number of labeled objects in a single image
- max_kpts = torch.unique(batch_idx, return_counts=True)[1].max()
-
- # Create a tensor to hold batched keypoints
- batched_keypoints = torch.zeros(
- (batch_size, max_kpts, keypoints.shape[1], keypoints.shape[2]), device=keypoints.device
- )
-
- # TODO: any idea how to vectorize this?
- # Fill batched_keypoints with keypoints based on batch_idx
- for i in range(batch_size):
- keypoints_i = keypoints[batch_idx == i]
- batched_keypoints[i, : keypoints_i.shape[0]] = keypoints_i
-
- # Expand dimensions of target_gt_idx to match the shape of batched_keypoints
- target_gt_idx_expanded = target_gt_idx.unsqueeze(-1).unsqueeze(-1)
-
- # Use target_gt_idx_expanded to select keypoints from batched_keypoints
- selected_keypoints = batched_keypoints.gather(
- 1, target_gt_idx_expanded.expand(-1, -1, keypoints.shape[1], keypoints.shape[2])
- )
-
- # Divide coordinates by stride
- selected_keypoints[..., :2] /= stride_tensor.view(1, -1, 1, 1)
-
- kpts_loss = 0
- kpts_obj_loss = 0
-
- if masks.any():
- gt_kpt = selected_keypoints[masks]
- area = xyxy2xywh(target_bboxes[masks])[:, 2:].prod(1, keepdim=True)
- pred_kpt = pred_kpts[masks]
- kpt_mask = gt_kpt[..., 2] != 0 if gt_kpt.shape[-1] == 3 else torch.full_like(gt_kpt[..., 0], True)
- kpts_loss = self.keypoint_loss(pred_kpt, gt_kpt, kpt_mask, area) # pose loss
-
- if pred_kpt.shape[-1] == 3:
- kpts_obj_loss = self.bce_pose(pred_kpt[..., 2], kpt_mask.float()) # keypoint obj loss
-
- return kpts_loss, kpts_obj_loss
-
-
-class v8ClassificationLoss:
- """Criterion class for computing training losses for classification."""
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Compute the classification loss between predictions and true labels."""
- preds = preds[1] if isinstance(preds, (list, tuple)) else preds
- loss = F.cross_entropy(preds, batch["cls"], reduction="mean")
- return loss, loss.detach()
-
-
-class v8OBBLoss(v8DetectionLoss):
- """Calculates losses for object detection, classification, and box distribution in rotated YOLO models."""
-
- def __init__(self, model):
- """Initialize v8OBBLoss with model, assigner, and rotated bbox loss; model must be de-paralleled."""
- super().__init__(model)
- self.assigner = RotatedTaskAlignedAssigner(topk=10, num_classes=self.nc, alpha=0.5, beta=6.0)
- self.bbox_loss = RotatedBboxLoss(self.reg_max).to(self.device)
-
- def preprocess(self, targets: torch.Tensor, batch_size: int, scale_tensor: torch.Tensor) -> torch.Tensor:
- """Preprocess targets for oriented bounding box detection."""
- if targets.shape[0] == 0:
- out = torch.zeros(batch_size, 0, 6, device=self.device)
- else:
- i = targets[:, 0] # image index
- _, counts = i.unique(return_counts=True)
- counts = counts.to(dtype=torch.int32)
- out = torch.zeros(batch_size, counts.max(), 6, device=self.device)
- for j in range(batch_size):
- matches = i == j
- if n := matches.sum():
- bboxes = targets[matches, 2:]
- bboxes[..., :4].mul_(scale_tensor)
- out[j, :n] = torch.cat([targets[matches, 1:2], bboxes], dim=-1)
- return out
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate and return the loss for oriented bounding box detection."""
- loss = torch.zeros(3, device=self.device) # box, cls, dfl
- feats, pred_angle = preds if isinstance(preds[0], list) else preds[1]
-        batch_size = pred_angle.shape[0]  # batch size
- pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
- (self.reg_max * 4, self.nc), 1
- )
-
- # b, grids, ..
- pred_scores = pred_scores.permute(0, 2, 1).contiguous()
- pred_distri = pred_distri.permute(0, 2, 1).contiguous()
- pred_angle = pred_angle.permute(0, 2, 1).contiguous()
-
- dtype = pred_scores.dtype
- imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
- anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
-
- # targets
- try:
- batch_idx = batch["batch_idx"].view(-1, 1)
- targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"].view(-1, 5)), 1)
- rw, rh = targets[:, 4] * imgsz[0].item(), targets[:, 5] * imgsz[1].item()
- targets = targets[(rw >= 2) & (rh >= 2)] # filter rboxes of tiny size to stabilize training
- targets = self.preprocess(targets, batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])
- gt_labels, gt_bboxes = targets.split((1, 5), 2) # cls, xywhr
- mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
- except RuntimeError as e:
- raise TypeError(
-                "ERROR ❌ OBB dataset incorrectly formatted or not an OBB dataset.\n"
-                "This error can occur when incorrectly training an 'OBB' model on a 'detect' dataset, "
- "i.e. 'yolo train model=yolo11n-obb.pt data=coco8.yaml'.\nVerify your dataset is a "
- "correctly formatted 'OBB' dataset using 'data=dota8.yaml' "
- "as an example.\nSee https://docs.ultralytics.com/datasets/obb/ for help."
- ) from e
-
- # Pboxes
- pred_bboxes = self.bbox_decode(anchor_points, pred_distri, pred_angle) # xyxy, (b, h*w, 4)
-
- bboxes_for_assigner = pred_bboxes.clone().detach()
- # Only the first four elements need to be scaled
- bboxes_for_assigner[..., :4] *= stride_tensor
- _, target_bboxes, target_scores, fg_mask, _ = self.assigner(
- pred_scores.detach().sigmoid(),
- bboxes_for_assigner.type(gt_bboxes.dtype),
- anchor_points * stride_tensor,
- gt_labels,
- gt_bboxes,
- mask_gt,
- )
-
- target_scores_sum = max(target_scores.sum(), 1)
-
- # Cls loss
- # loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
- loss[1] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE
-
- # Bbox loss
- if fg_mask.sum():
- target_bboxes[..., :4] /= stride_tensor
- loss[0], loss[2] = self.bbox_loss(
- pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
- )
- else:
- loss[0] += (pred_angle * 0).sum()
-
- loss[0] *= self.hyp.box # box gain
- loss[1] *= self.hyp.cls # cls gain
- loss[2] *= self.hyp.dfl # dfl gain
-
- return loss * batch_size, loss.detach() # loss(box, cls, dfl)
-
- def bbox_decode(
- self, anchor_points: torch.Tensor, pred_dist: torch.Tensor, pred_angle: torch.Tensor
- ) -> torch.Tensor:
- """
- Decode predicted object bounding box coordinates from anchor points and distribution.
-
- Args:
- anchor_points (torch.Tensor): Anchor points, (h*w, 2).
- pred_dist (torch.Tensor): Predicted rotated distance, (bs, h*w, 4).
- pred_angle (torch.Tensor): Predicted angle, (bs, h*w, 1).
-
- Returns:
- (torch.Tensor): Predicted rotated bounding boxes with angles, (bs, h*w, 5).
- """
- if self.use_dfl:
- b, a, c = pred_dist.shape # batch, anchors, channels
- pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
- return torch.cat((dist2rbox(pred_dist, pred_angle, anchor_points), pred_angle), dim=-1)
-
-
-class E2EDetectLoss:
- """Criterion class for computing training losses for end-to-end detection."""
-
- def __init__(self, model):
- """Initialize E2EDetectLoss with one-to-many and one-to-one detection losses using the provided model."""
- self.one2many = v8DetectionLoss(model, tal_topk=10)
- self.one2one = v8DetectionLoss(model, tal_topk=1)
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the sum of the loss for box, cls and dfl multiplied by batch size."""
- preds = preds[1] if isinstance(preds, tuple) else preds
- one2many = preds["one2many"]
- loss_one2many = self.one2many(one2many, batch)
- one2one = preds["one2one"]
- loss_one2one = self.one2one(one2one, batch)
- return loss_one2many[0] + loss_one2one[0], loss_one2many[1] + loss_one2one[1]
-
-
-class TVPDetectLoss:
- """Criterion class for computing training losses for text-visual prompt detection."""
-
- def __init__(self, model):
- """Initialize TVPDetectLoss with task-prompt and visual-prompt criteria using the provided model."""
- self.vp_criterion = v8DetectionLoss(model)
-        # NOTE: store the following info since it is modified in __call__
- self.ori_nc = self.vp_criterion.nc
- self.ori_no = self.vp_criterion.no
- self.ori_reg_max = self.vp_criterion.reg_max
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the loss for text-visual prompt detection."""
- feats = preds[1] if isinstance(preds, tuple) else preds
- assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
-
- if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
- loss = torch.zeros(3, device=self.vp_criterion.device, requires_grad=True)
- return loss, loss.detach()
-
- vp_feats = self._get_vp_features(feats)
- vp_loss = self.vp_criterion(vp_feats, batch)
- box_loss = vp_loss[0][1]
- return box_loss, vp_loss[1]
-
- def _get_vp_features(self, feats: list[torch.Tensor]) -> list[torch.Tensor]:
- """Extract visual-prompt features from the model output."""
- vnc = feats[0].shape[1] - self.ori_reg_max * 4 - self.ori_nc
-
- self.vp_criterion.nc = vnc
- self.vp_criterion.no = vnc + self.vp_criterion.reg_max * 4
- self.vp_criterion.assigner.num_classes = vnc
-
- return [
- torch.cat((box, cls_vp), dim=1)
- for box, _, cls_vp in [xi.split((self.ori_reg_max * 4, self.ori_nc, vnc), dim=1) for xi in feats]
- ]
-
-
-class TVPSegmentLoss(TVPDetectLoss):
- """Criterion class for computing training losses for text-visual prompt segmentation."""
-
- def __init__(self, model):
- """Initialize TVPSegmentLoss with task-prompt and visual-prompt criteria using the provided model."""
- super().__init__(model)
- self.vp_criterion = v8SegmentationLoss(model)
-
- def __call__(self, preds: Any, batch: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
- """Calculate the loss for text-visual prompt segmentation."""
- feats, pred_masks, proto = preds if len(preds) == 3 else preds[1]
- assert self.ori_reg_max == self.vp_criterion.reg_max # TODO: remove it
-
- if self.ori_reg_max * 4 + self.ori_nc == feats[0].shape[1]:
- loss = torch.zeros(4, device=self.vp_criterion.device, requires_grad=True)
- return loss, loss.detach()
-
- vp_feats = self._get_vp_features(feats)
- vp_loss = self.vp_criterion((vp_feats, pred_masks, proto), batch)
- cls_loss = vp_loss[0][2]
- return cls_loss, vp_loss[1]
diff --git a/ultralytics/utils/metrics.py b/ultralytics/utils/metrics.py
deleted file mode 100644
index dd1feb3..0000000
--- a/ultralytics/utils/metrics.py
+++ /dev/null
@@ -1,1592 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""Model validation metrics."""
-
-from __future__ import annotations
-
-import math
-import warnings
-from collections import defaultdict
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import torch
-
-from ultralytics.utils import LOGGER, DataExportMixin, SimpleClass, TryExcept, checks, plt_settings
-
-OKS_SIGMA = (
- np.array([0.26, 0.25, 0.25, 0.35, 0.35, 0.79, 0.79, 0.72, 0.72, 0.62, 0.62, 1.07, 1.07, 0.87, 0.87, 0.89, 0.89])
- / 10.0
-)
-
-
-def bbox_ioa(box1: np.ndarray, box2: np.ndarray, iou: bool = False, eps: float = 1e-7) -> np.ndarray:
- """
- Calculate the intersection over box2 area given box1 and box2.
-
- Args:
- box1 (np.ndarray): A numpy array of shape (N, 4) representing N bounding boxes in x1y1x2y2 format.
- box2 (np.ndarray): A numpy array of shape (M, 4) representing M bounding boxes in x1y1x2y2 format.
- iou (bool, optional): Calculate the standard IoU if True else return inter_area/box2_area.
- eps (float, optional): A small value to avoid division by zero.
-
- Returns:
- (np.ndarray): A numpy array of shape (N, M) representing the intersection over box2 area.
- """
- # Get the coordinates of bounding boxes
- b1_x1, b1_y1, b1_x2, b1_y2 = box1.T
- b2_x1, b2_y1, b2_x2, b2_y2 = box2.T
-
- # Intersection area
- inter_area = (np.minimum(b1_x2[:, None], b2_x2) - np.maximum(b1_x1[:, None], b2_x1)).clip(0) * (
- np.minimum(b1_y2[:, None], b2_y2) - np.maximum(b1_y1[:, None], b2_y1)
- ).clip(0)
-
- # Box2 area
- area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
- if iou:
- box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
- area = area + box1_area[:, None] - inter_area
-
- # Intersection over box2 area
- return inter_area / (area + eps)
-
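
A small sanity check of `bbox_ioa` with made-up boxes: when box1 fully contains box2, the intersection over the box2 area is 1, while the symmetric IoU is much smaller.

```python
import numpy as np

box1 = np.array([[0.0, 0.0, 10.0, 10.0]])  # 10 x 10 box
box2 = np.array([[2.0, 2.0, 4.0, 4.0]])    # 2 x 2 box inside box1
print(bbox_ioa(box1, box2))            # ~[[1.0]]  intersection / box2 area
print(bbox_ioa(box1, box2, iou=True))  # ~[[0.04]] intersection / union = 4 / 100
```
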
-
-def box_iou(box1: torch.Tensor, box2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
- """
- Calculate intersection-over-union (IoU) of boxes.
-
- Args:
- box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes in (x1, y1, x2, y2) format.
- box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes in (x1, y1, x2, y2) format.
- eps (float, optional): A small value to avoid division by zero.
-
- Returns:
- (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
-
- References:
- https://github.com/pytorch/vision/blob/main/torchvision/ops/boxes.py
- """
- # NOTE: Need .float() to get accurate iou values
- # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
- (a1, a2), (b1, b2) = box1.float().unsqueeze(1).chunk(2, 2), box2.float().unsqueeze(0).chunk(2, 2)
- inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)
-
- # IoU = inter / (area1 + area2 - inter)
- return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)
-
-
-def bbox_iou(
- box1: torch.Tensor,
- box2: torch.Tensor,
- xywh: bool = True,
- GIoU: bool = False,
- DIoU: bool = False,
- CIoU: bool = False,
- eps: float = 1e-7,
-) -> torch.Tensor:
- """
- Calculate the Intersection over Union (IoU) between bounding boxes.
-
- This function supports various shapes for `box1` and `box2` as long as the last dimension is 4.
- For instance, you may pass tensors shaped like (4,), (N, 4), (B, N, 4), or (B, N, 1, 4).
- Internally, the code will split the last dimension into (x, y, w, h) if `xywh=True`,
- or (x1, y1, x2, y2) if `xywh=False`.
-
- Args:
- box1 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
- box2 (torch.Tensor): A tensor representing one or more bounding boxes, with the last dimension being 4.
- xywh (bool, optional): If True, input boxes are in (x, y, w, h) format. If False, input boxes are in
- (x1, y1, x2, y2) format.
- GIoU (bool, optional): If True, calculate Generalized IoU.
- DIoU (bool, optional): If True, calculate Distance IoU.
- CIoU (bool, optional): If True, calculate Complete IoU.
- eps (float, optional): A small value to avoid division by zero.
-
- Returns:
- (torch.Tensor): IoU, GIoU, DIoU, or CIoU values depending on the specified flags.
- """
- # Get the coordinates of bounding boxes
- if xywh: # transform from xywh to xyxy
- (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
- w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
- b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
- b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
- else: # x1, y1, x2, y2 = box1
- b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
- b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
- w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
- w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
-
- # Intersection area
- inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp_(0) * (
- b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
- ).clamp_(0)
-
- # Union Area
- union = w1 * h1 + w2 * h2 - inter + eps
-
- # IoU
- iou = inter / union
- if CIoU or DIoU or GIoU:
- cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
- ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
- if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
- c2 = cw.pow(2) + ch.pow(2) + eps # convex diagonal squared
- rho2 = (
- (b2_x1 + b2_x2 - b1_x1 - b1_x2).pow(2) + (b2_y1 + b2_y2 - b1_y1 - b1_y2).pow(2)
- ) / 4 # center dist**2
- if CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
- v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
- with torch.no_grad():
- alpha = v / (v - iou + (1 + eps))
- return iou - (rho2 / c2 + v * alpha) # CIoU
- return iou - rho2 / c2 # DIoU
- c_area = cw * ch + eps # convex area
- return iou - (c_area - union) / c_area # GIoU https://arxiv.org/pdf/1902.09630.pdf
- return iou # IoU
-
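A sketch of how bbox_iou is typically used as a box-regression loss term, with the function above in scope (illustrative values; CIoU is just one of the supported variants):

    import torch

    pred = torch.tensor([[50.0, 50.0, 20.0, 20.0]])      # (1, 4) in xywh (center-x, center-y, w, h)
    target = torch.tensor([[55.0, 50.0, 20.0, 20.0]])    # same size, center shifted by 5 px
    ciou = bbox_iou(pred, target, xywh=True, CIoU=True)  # element-wise, shape (1, 1)
    loss = (1.0 - ciou).mean()                           # CIoU < plain IoU due to the center-distance penalty
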
-
-def mask_iou(mask1: torch.Tensor, mask2: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
- """
- Calculate masks IoU.
-
- Args:
- mask1 (torch.Tensor): A tensor of shape (N, n) where N is the number of ground truth objects and n is the
- product of image width and height.
- mask2 (torch.Tensor): A tensor of shape (M, n) where M is the number of predicted objects and n is the
- product of image width and height.
- eps (float, optional): A small value to avoid division by zero.
-
- Returns:
- (torch.Tensor): A tensor of shape (N, M) representing masks IoU.
- """
- intersection = torch.matmul(mask1, mask2.T).clamp_(0)
- union = (mask1.sum(1)[:, None] + mask2.sum(1)[None]) - intersection # (area1 + area2) - intersection
- return intersection / (union + eps)
-
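A small sketch for mask_iou with flattened binary masks, assuming a 4x4 image so n = 16 (illustrative values only):

    import torch

    h, w = 4, 4
    gt = torch.zeros(1, h * w)
    gt[0, :8] = 1.0           # ground-truth mask covers the top two rows
    pred = torch.zeros(2, h * w)
    pred[0, :8] = 1.0         # perfect prediction
    pred[1, 4:12] = 1.0       # overlaps the ground truth by one row
    iou = mask_iou(gt, pred)  # (1, 2), roughly [1.0, 0.3333]
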
-
-def kpt_iou(
- kpt1: torch.Tensor, kpt2: torch.Tensor, area: torch.Tensor, sigma: list[float], eps: float = 1e-7
-) -> torch.Tensor:
- """
- Calculate Object Keypoint Similarity (OKS).
-
- Args:
- kpt1 (torch.Tensor): A tensor of shape (N, 17, 3) representing ground truth keypoints.
- kpt2 (torch.Tensor): A tensor of shape (M, 17, 3) representing predicted keypoints.
- area (torch.Tensor): A tensor of shape (N,) representing areas from ground truth.
- sigma (list): A list containing 17 values representing keypoint scales.
- eps (float, optional): A small value to avoid division by zero.
-
- Returns:
- (torch.Tensor): A tensor of shape (N, M) representing keypoint similarities.
- """
- d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)
- sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
- kpt_mask = kpt1[..., 2] != 0 # (N, 17)
- e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval
- # e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula
- return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)
-
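A sketch of OKS with kpt_iou, with the function above in scope; the uniform sigma of 0.025 below is a stand-in for the COCO keypoint-specific values:

    import torch

    sigma = [0.025] * 17                          # per-keypoint scales, one value per COCO keypoint
    kpt_gt = torch.rand(1, 17, 3)
    kpt_gt[..., 2] = 2.0                          # mark every ground-truth keypoint as labeled
    kpt_pred = kpt_gt.clone()
    kpt_pred[..., :2] += 0.01                     # small localization error on every keypoint
    area = torch.tensor([1.0])                    # ground-truth box areas, shape (N,)
    oks = kpt_iou(kpt_gt, kpt_pred, area, sigma)  # (1, 1), close to 1.0 for this near-perfect prediction
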
-
-def _get_covariance_matrix(boxes: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Generate covariance matrix from oriented bounding boxes.
-
- Args:
- boxes (torch.Tensor): A tensor of shape (N, 5) representing rotated bounding boxes, with xywhr format.
-
- Returns:
- (torch.Tensor): Covariance matrices corresponding to original rotated bounding boxes.
- """
- # Gaussian bounding boxes, ignore the center points (the first two columns) because they are not needed here.
- gbbs = torch.cat((boxes[:, 2:4].pow(2) / 12, boxes[:, 4:]), dim=-1)
- a, b, c = gbbs.split(1, dim=-1)
- cos = c.cos()
- sin = c.sin()
- cos2 = cos.pow(2)
- sin2 = sin.pow(2)
- return a * cos2 + b * sin2, a * sin2 + b * cos2, (a - b) * cos * sin
-
-
-def probiou(obb1: torch.Tensor, obb2: torch.Tensor, CIoU: bool = False, eps: float = 1e-7) -> torch.Tensor:
- """
- Calculate probabilistic IoU between oriented bounding boxes.
-
- Args:
- obb1 (torch.Tensor): Ground truth OBBs, shape (N, 5), format xywhr.
- obb2 (torch.Tensor): Predicted OBBs, shape (N, 5), format xywhr.
- CIoU (bool, optional): If True, calculate CIoU.
- eps (float, optional): Small value to avoid division by zero.
-
- Returns:
- (torch.Tensor): OBB similarities, shape (N,).
-
- Notes:
- OBB format: [center_x, center_y, width, height, rotation_angle].
-
- References:
- https://arxiv.org/pdf/2106.06072v1.pdf
- """
- x1, y1 = obb1[..., :2].split(1, dim=-1)
- x2, y2 = obb2[..., :2].split(1, dim=-1)
- a1, b1, c1 = _get_covariance_matrix(obb1)
- a2, b2, c2 = _get_covariance_matrix(obb2)
-
- t1 = (
- ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
- ) * 0.25
- t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
- t3 = (
- ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
- / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
- + eps
- ).log() * 0.5
- bd = (t1 + t2 + t3).clamp(eps, 100.0)
- hd = (1.0 - (-bd).exp() + eps).sqrt()
- iou = 1 - hd
- if CIoU: # only include the wh aspect ratio part
- w1, h1 = obb1[..., 2:4].split(1, dim=-1)
- w2, h2 = obb2[..., 2:4].split(1, dim=-1)
- v = (4 / math.pi**2) * ((w2 / h2).atan() - (w1 / h1).atan()).pow(2)
- with torch.no_grad():
- alpha = v / (v - iou + (1 + eps))
- return iou - v * alpha # CIoU
- return iou
-
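A sketch for probiou on element-wise pairs of rotated boxes in xywhr format (illustrative values; the rotation angle is in radians):

    import math

    import torch

    obb_gt = torch.tensor([[50.0, 50.0, 20.0, 10.0, 0.0]])            # axis-aligned box
    obb_pred = torch.tensor([[50.0, 50.0, 20.0, 10.0, math.pi / 8]])  # same box rotated by 22.5 degrees
    sim = probiou(obb_gt, obb_pred)  # per-pair similarity; reaches ~1.0 only when the boxes coincide
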
-
-def batch_probiou(obb1: torch.Tensor | np.ndarray, obb2: torch.Tensor | np.ndarray, eps: float = 1e-7) -> torch.Tensor:
- """
- Calculate the probabilistic IoU between oriented bounding boxes.
-
- Args:
- obb1 (torch.Tensor | np.ndarray): A tensor of shape (N, 5) representing ground truth obbs, with xywhr format.
- obb2 (torch.Tensor | np.ndarray): A tensor of shape (M, 5) representing predicted obbs, with xywhr format.
- eps (float, optional): A small value to avoid division by zero.
-
- Returns:
- (torch.Tensor): A tensor of shape (N, M) representing obb similarities.
-
- References:
- https://arxiv.org/pdf/2106.06072v1.pdf
- """
- obb1 = torch.from_numpy(obb1) if isinstance(obb1, np.ndarray) else obb1
- obb2 = torch.from_numpy(obb2) if isinstance(obb2, np.ndarray) else obb2
-
- x1, y1 = obb1[..., :2].split(1, dim=-1)
- x2, y2 = (x.squeeze(-1)[None] for x in obb2[..., :2].split(1, dim=-1))
- a1, b1, c1 = _get_covariance_matrix(obb1)
- a2, b2, c2 = (x.squeeze(-1)[None] for x in _get_covariance_matrix(obb2))
-
- t1 = (
- ((a1 + a2) * (y1 - y2).pow(2) + (b1 + b2) * (x1 - x2).pow(2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)
- ) * 0.25
- t2 = (((c1 + c2) * (x2 - x1) * (y1 - y2)) / ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2) + eps)) * 0.5
- t3 = (
- ((a1 + a2) * (b1 + b2) - (c1 + c2).pow(2))
- / (4 * ((a1 * b1 - c1.pow(2)).clamp_(0) * (a2 * b2 - c2.pow(2)).clamp_(0)).sqrt() + eps)
- + eps
- ).log() * 0.5
- bd = (t1 + t2 + t3).clamp(eps, 100.0)
- hd = (1.0 - (-bd).exp() + eps).sqrt()
- return 1 - hd
-
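A sketch for batch_probiou, which accepts numpy or torch input and returns the full NxM similarity matrix (illustrative values):

    import numpy as np

    gt = np.array([[50.0, 50.0, 20.0, 10.0, 0.0]], dtype=np.float32)     # (N=1, 5) xywhr
    preds = np.array([[50.0, 50.0, 20.0, 10.0, 0.0],
                      [80.0, 80.0, 20.0, 10.0, 0.5]], dtype=np.float32)  # (M=2, 5) xywhr
    sim = batch_probiou(gt, preds)  # torch.Tensor of shape (1, 2); first column ~1.0, second much lower
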
-
-def smooth_bce(eps: float = 0.1) -> tuple[float, float]:
- """
- Compute smoothed positive and negative Binary Cross-Entropy targets.
-
- Args:
- eps (float, optional): The epsilon value for label smoothing.
-
- Returns:
- pos (float): Positive label smoothing BCE target.
- neg (float): Negative label smoothing BCE target.
-
- References:
- https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
- """
- return 1.0 - 0.5 * eps, 0.5 * eps
-
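A short sketch of how the smoothed targets are used (illustrative; the 0.1 smoothing factor is arbitrary):

    cp, cn = smooth_bce(eps=0.1)  # cp == 0.95 for positive targets, cn == 0.05 for negative targets
    # A classification-loss builder would fill its target tensor with cn and set matched entries to cp.
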
-
-class ConfusionMatrix(DataExportMixin):
- """
- A class for calculating and updating a confusion matrix for object detection and classification tasks.
-
- Attributes:
- task (str): The type of task, either 'detect' or 'classify'.
- matrix (np.ndarray): The confusion matrix, with dimensions depending on the task.
-        nc (int): The number of classes.
- names (list[str]): The names of the classes, used as labels on the plot.
- matches (dict): Contains the indices of ground truths and predictions categorized into TP, FP and FN.
- """
-
-    def __init__(self, names: dict[int, str] = {}, task: str = "detect", save_matches: bool = False):
- """
- Initialize a ConfusionMatrix instance.
-
- Args:
- names (dict[int, str], optional): Names of classes, used as labels on the plot.
- task (str, optional): Type of task, either 'detect' or 'classify'.
- save_matches (bool, optional): Save the indices of GTs, TPs, FPs, FNs for visualization.
- """
- self.task = task
- self.nc = len(names) # number of classes
- self.matrix = np.zeros((self.nc, self.nc)) if self.task == "classify" else np.zeros((self.nc + 1, self.nc + 1))
- self.names = names # name of classes
- self.matches = {} if save_matches else None
-
- def _append_matches(self, mtype: str, batch: dict[str, Any], idx: int) -> None:
- """
- Append the matches to TP, FP, FN or GT list for the last batch.
-
- This method updates the matches dictionary by appending specific batch data
- to the appropriate match type (True Positive, False Positive, or False Negative).
-
- Args:
- mtype (str): Match type identifier ('TP', 'FP', 'FN' or 'GT').
- batch (dict[str, Any]): Batch data containing detection results with keys
- like 'bboxes', 'cls', 'conf', 'keypoints', 'masks'.
- idx (int): Index of the specific detection to append from the batch.
-
- Note:
- For masks, handles both overlap and non-overlap cases. When masks.max() > 1.0,
- it indicates overlap_mask=True with shape (1, H, W), otherwise uses direct indexing.
- """
- if self.matches is None:
- return
- for k, v in batch.items():
- if k in {"bboxes", "cls", "conf", "keypoints"}:
- self.matches[mtype][k] += v[[idx]]
- elif k == "masks":
- # NOTE: masks.max() > 1.0 means overlap_mask=True with (1, H, W) shape
- self.matches[mtype][k] += [v[0] == idx + 1] if v.max() > 1.0 else [v[idx]]
-
- def process_cls_preds(self, preds: list[torch.Tensor], targets: list[torch.Tensor]) -> None:
- """
- Update confusion matrix for classification task.
-
- Args:
-            preds (list[torch.Tensor]): Predicted class indices, each of shape (N, min(nc, 5)).
-            targets (list[torch.Tensor]): Ground truth class indices.
- """
- preds, targets = torch.cat(preds)[:, 0], torch.cat(targets)
- for p, t in zip(preds.cpu().numpy(), targets.cpu().numpy()):
- self.matrix[p][t] += 1
-
- def process_batch(
- self,
- detections: dict[str, torch.Tensor],
- batch: dict[str, Any],
- conf: float = 0.25,
- iou_thres: float = 0.45,
- ) -> None:
- """
- Update confusion matrix for object detection task.
-
- Args:
- detections (dict[str, torch.Tensor]): Dictionary containing detected bounding boxes and their associated information.
- Should contain 'cls', 'conf', and 'bboxes' keys, where 'bboxes' can be
- Array[N, 4] for regular boxes or Array[N, 5] for OBB with angle.
-            batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' (Array[M, 4] | Array[M, 5]) and
- 'cls' (Array[M]) keys, where M is the number of ground truth objects.
- conf (float, optional): Confidence threshold for detections.
- iou_thres (float, optional): IoU threshold for matching detections to ground truth.
- """
- gt_cls, gt_bboxes = batch["cls"], batch["bboxes"]
- if self.matches is not None: # only if visualization is enabled
- self.matches = {k: defaultdict(list) for k in {"TP", "FP", "FN", "GT"}}
- for i in range(gt_cls.shape[0]):
- self._append_matches("GT", batch, i) # store GT
- is_obb = gt_bboxes.shape[1] == 5 # check if boxes contains angle for OBB
- conf = 0.25 if conf in {None, 0.01 if is_obb else 0.001} else conf # apply 0.25 if default val conf is passed
- no_pred = detections["cls"].shape[0] == 0
- if gt_cls.shape[0] == 0: # Check if labels is empty
- if not no_pred:
- detections = {k: detections[k][detections["conf"] > conf] for k in detections}
- detection_classes = detections["cls"].int().tolist()
- for i, dc in enumerate(detection_classes):
- self.matrix[dc, self.nc] += 1 # FP
- self._append_matches("FP", detections, i)
- return
- if no_pred:
- gt_classes = gt_cls.int().tolist()
- for i, gc in enumerate(gt_classes):
- self.matrix[self.nc, gc] += 1 # FN
- self._append_matches("FN", batch, i)
- return
-
- detections = {k: detections[k][detections["conf"] > conf] for k in detections}
- gt_classes = gt_cls.int().tolist()
- detection_classes = detections["cls"].int().tolist()
- bboxes = detections["bboxes"]
- iou = batch_probiou(gt_bboxes, bboxes) if is_obb else box_iou(gt_bboxes, bboxes)
-
- x = torch.where(iou > iou_thres)
- if x[0].shape[0]:
- matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
- if x[0].shape[0] > 1:
- matches = matches[matches[:, 2].argsort()[::-1]]
- matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
- matches = matches[matches[:, 2].argsort()[::-1]]
- matches = matches[np.unique(matches[:, 0], return_index=True)[1]]
- else:
- matches = np.zeros((0, 3))
-
- n = matches.shape[0] > 0
- m0, m1, _ = matches.transpose().astype(int)
- for i, gc in enumerate(gt_classes):
- j = m0 == i
- if n and sum(j) == 1:
- dc = detection_classes[m1[j].item()]
- self.matrix[dc, gc] += 1 # TP if class is correct else both an FP and an FN
- if dc == gc:
- self._append_matches("TP", detections, m1[j].item())
- else:
- self._append_matches("FP", detections, m1[j].item())
- self._append_matches("FN", batch, i)
- else:
- self.matrix[self.nc, gc] += 1 # FN
- self._append_matches("FN", batch, i)
-
- for i, dc in enumerate(detection_classes):
- if not any(m1 == i):
- self.matrix[dc, self.nc] += 1 # FP
- self._append_matches("FP", detections, i)
-
- def matrix(self):
- """Return the confusion matrix."""
- return self.matrix
-
- def tp_fp(self) -> tuple[np.ndarray, np.ndarray]:
- """
- Return true positives and false positives.
-
- Returns:
- tp (np.ndarray): True positives.
- fp (np.ndarray): False positives.
- """
- tp = self.matrix.diagonal() # true positives
- fp = self.matrix.sum(1) - tp # false positives
- # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
- return (tp, fp) if self.task == "classify" else (tp[:-1], fp[:-1]) # remove background class if task=detect
-
- def plot_matches(self, img: torch.Tensor, im_file: str, save_dir: Path) -> None:
- """
- Plot grid of GT, TP, FP, FN for each image.
-
- Args:
- img (torch.Tensor): Image to plot onto.
- im_file (str): Image filename to save visualizations.
- save_dir (Path): Location to save the visualizations to.
- """
- if not self.matches:
- return
- from .ops import xyxy2xywh
- from .plotting import plot_images
-
- # Create batch of 4 (GT, TP, FP, FN)
- labels = defaultdict(list)
- for i, mtype in enumerate(["GT", "FP", "TP", "FN"]):
- mbatch = self.matches[mtype]
- if "conf" not in mbatch:
- mbatch["conf"] = torch.tensor([1.0] * len(mbatch["bboxes"]), device=img.device)
- mbatch["batch_idx"] = torch.ones(len(mbatch["bboxes"]), device=img.device) * i
- for k in mbatch.keys():
- labels[k] += mbatch[k]
-
- labels = {k: torch.stack(v, 0) if len(v) else torch.empty(0) for k, v in labels.items()}
- if self.task != "obb" and labels["bboxes"].shape[0]:
- labels["bboxes"] = xyxy2xywh(labels["bboxes"])
- (save_dir / "visualizations").mkdir(parents=True, exist_ok=True)
- plot_images(
- labels,
- img.repeat(4, 1, 1, 1),
- paths=["Ground Truth", "False Positives", "True Positives", "False Negatives"],
- fname=save_dir / "visualizations" / Path(im_file).name,
- names=self.names,
- max_subplots=4,
- conf_thres=0.001,
- )
-
- @TryExcept(msg="ConfusionMatrix plot failure")
- @plt_settings()
- def plot(self, normalize: bool = True, save_dir: str = "", on_plot=None):
- """
- Plot the confusion matrix using matplotlib and save it to a file.
-
- Args:
- normalize (bool, optional): Whether to normalize the confusion matrix.
- save_dir (str, optional): Directory where the plot will be saved.
- on_plot (callable, optional): An optional callback to pass plots path and data when they are rendered.
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
-
- array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
- array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
-
- fig, ax = plt.subplots(1, 1, figsize=(12, 9))
- names, n = list(self.names.values()), self.nc
- if self.nc >= 100: # downsample for large class count
- k = max(2, self.nc // 60) # step size for downsampling, always > 1
- keep_idx = slice(None, None, k) # create slice instead of array
- names = names[keep_idx] # slice class names
- array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
- n = (self.nc + k - 1) // k # number of retained classes
- nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
- ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
- xy_ticks = np.arange(len(ticklabels))
- tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
- label_fontsize = max(6, 12 - 0.1 * nc)
- title_fontsize = max(6, 12 - 0.1 * nc)
- btm = max(0.1, 0.25 - 0.001 * nc) # Minimum value is 0.1
- with warnings.catch_warnings():
- warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
- im = ax.imshow(array, cmap="Blues", vmin=0.0, interpolation="none")
- ax.xaxis.set_label_position("bottom")
- if nc < 30: # Add score for each cell of confusion matrix
- color_threshold = 0.45 * (1 if normalize else np.nanmax(array)) # text color threshold
- for i, row in enumerate(array[:nc]):
- for j, val in enumerate(row[:nc]):
- val = array[i, j]
- if np.isnan(val):
- continue
- ax.text(
- j,
- i,
- f"{val:.2f}" if normalize else f"{int(val)}",
- ha="center",
- va="center",
- fontsize=10,
- color="white" if val > color_threshold else "black",
- )
- cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)
- title = "Confusion Matrix" + " Normalized" * normalize
- ax.set_xlabel("True", fontsize=label_fontsize, labelpad=10)
- ax.set_ylabel("Predicted", fontsize=label_fontsize, labelpad=10)
- ax.set_title(title, fontsize=title_fontsize, pad=20)
- ax.set_xticks(xy_ticks)
- ax.set_yticks(xy_ticks)
- ax.tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False)
- ax.tick_params(axis="y", left=True, right=False, labelleft=True, labelright=False)
- if ticklabels != "auto":
- ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha="center")
- ax.set_yticklabels(ticklabels, fontsize=tick_fontsize)
- for s in {"left", "right", "bottom", "top", "outline"}:
- if s != "outline":
-                ax.spines[s].set_visible(False)  # Confusion matrix plot doesn't have an outline
- cbar.ax.spines[s].set_visible(False)
- fig.subplots_adjust(left=0, right=0.84, top=0.94, bottom=btm) # Adjust layout to ensure equal margins
- plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
- fig.savefig(plot_fname, dpi=250)
- plt.close(fig)
- if on_plot:
- on_plot(plot_fname)
-
- def print(self):
- """Print the confusion matrix to the console."""
- for i in range(self.matrix.shape[0]):
- LOGGER.info(" ".join(map(str, self.matrix[i])))
-
- def summary(self, normalize: bool = False, decimals: int = 5) -> list[dict[str, float]]:
- """
- Generate a summarized representation of the confusion matrix as a list of dictionaries, with optional
- normalization. This is useful for exporting the matrix to various formats such as CSV, XML, HTML, JSON, or SQL.
-
- Args:
- normalize (bool): Whether to normalize the confusion matrix values.
- decimals (int): Number of decimal places to round the output values to.
-
- Returns:
- (list[dict[str, float]]): A list of dictionaries, each representing one predicted class with corresponding values for all actual classes.
-
- Examples:
- >>> results = model.val(data="coco8.yaml", plots=True)
- >>> cm_dict = results.confusion_matrix.summary(normalize=True, decimals=5)
- >>> print(cm_dict)
- """
- import re
-
- names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
- clean_names, seen = [], set()
- for name in names:
- clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
- original_clean = clean_name
- counter = 1
- while clean_name.lower() in seen:
- clean_name = f"{original_clean}_{counter}"
- counter += 1
- seen.add(clean_name.lower())
- clean_names.append(clean_name)
- array = (self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1)).round(decimals)
- return [
- dict({"Predicted": clean_names[i]}, **{clean_names[j]: array[i, j] for j in range(len(clean_names))})
- for i in range(len(clean_names))
- ]
-
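A minimal sketch of driving ConfusionMatrix directly; the class names and single-detection tensors below are made up, whereas in practice the validator supplies them per batch:

    import torch

    cm = ConfusionMatrix(names={0: "person", 1: "car"}, task="detect")
    detections = {
        "bboxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),  # xyxy
        "conf": torch.tensor([0.9]),
        "cls": torch.tensor([0.0]),
    }
    batch = {
        "bboxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),  # one matching ground-truth box
        "cls": torch.tensor([0.0]),
    }
    cm.process_batch(detections, batch, conf=0.25, iou_thres=0.45)
    tp, fp = cm.tp_fp()  # per-class counts; here tp == [1, 0] and fp == [0, 0]
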
-
-def smooth(y: np.ndarray, f: float = 0.05) -> np.ndarray:
- """Box filter of fraction f."""
- nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
- p = np.ones(nf // 2) # ones padding
- yp = np.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
- return np.convolve(yp, np.ones(nf) / nf, mode="valid") # y-smoothed
-
-
-@plt_settings()
-def plot_pr_curve(
- px: np.ndarray,
- py: np.ndarray,
- ap: np.ndarray,
- save_dir: Path = Path("pr_curve.png"),
- names: dict[int, str] = {},
- on_plot=None,
-):
- """
- Plot precision-recall curve.
-
- Args:
- px (np.ndarray): X values for the PR curve.
- py (np.ndarray): Y values for the PR curve.
- ap (np.ndarray): Average precision values.
- save_dir (Path, optional): Path to save the plot.
- names (dict[int, str], optional): Dictionary mapping class indices to class names.
- on_plot (callable, optional): Function to call after plot is saved.
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
-
- fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
- py = np.stack(py, axis=1)
-
- if 0 < len(names) < 21: # display per-class legend if < 21 classes
- for i, y in enumerate(py.T):
- ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
- else:
- ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
-
- ax.plot(px, py.mean(1), linewidth=3, color="blue", label=f"all classes {ap[:, 0].mean():.3f} mAP@0.5")
- ax.set_xlabel("Recall")
- ax.set_ylabel("Precision")
- ax.set_xlim(0, 1)
- ax.set_ylim(0, 1)
- ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
- ax.set_title("Precision-Recall Curve")
- fig.savefig(save_dir, dpi=250)
- plt.close(fig)
- if on_plot:
- on_plot(save_dir)
-
-
-@plt_settings()
-def plot_mc_curve(
- px: np.ndarray,
- py: np.ndarray,
- save_dir: Path = Path("mc_curve.png"),
- names: dict[int, str] = {},
- xlabel: str = "Confidence",
- ylabel: str = "Metric",
- on_plot=None,
-):
- """
- Plot metric-confidence curve.
-
- Args:
- px (np.ndarray): X values for the metric-confidence curve.
- py (np.ndarray): Y values for the metric-confidence curve.
- save_dir (Path, optional): Path to save the plot.
- names (dict[int, str], optional): Dictionary mapping class indices to class names.
- xlabel (str, optional): X-axis label.
- ylabel (str, optional): Y-axis label.
- on_plot (callable, optional): Function to call after plot is saved.
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
-
- fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
-
- if 0 < len(names) < 21: # display per-class legend if < 21 classes
- for i, y in enumerate(py):
- ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
- else:
- ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
-
- y = smooth(py.mean(0), 0.1)
- ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.2f} at {px[y.argmax()]:.3f}")
- ax.set_xlabel(xlabel)
- ax.set_ylabel(ylabel)
- ax.set_xlim(0, 1)
- ax.set_ylim(0, 1)
- ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
- ax.set_title(f"{ylabel}-Confidence Curve")
- fig.savefig(save_dir, dpi=250)
- plt.close(fig)
- if on_plot:
- on_plot(save_dir)
-
-
-def compute_ap(recall: list[float], precision: list[float]) -> tuple[float, np.ndarray, np.ndarray]:
- """
- Compute the average precision (AP) given the recall and precision curves.
-
- Args:
- recall (list): The recall curve.
- precision (list): The precision curve.
-
- Returns:
- ap (float): Average precision.
- mpre (np.ndarray): Precision envelope curve.
- mrec (np.ndarray): Modified recall curve with sentinel values added at the beginning and end.
- """
- # Append sentinel values to beginning and end
- mrec = np.concatenate(([0.0], recall, [1.0]))
- mpre = np.concatenate(([1.0], precision, [0.0]))
-
- # Compute the precision envelope
- mpre = np.flip(np.maximum.accumulate(np.flip(mpre)))
-
- # Integrate area under curve
- method = "interp" # methods: 'continuous', 'interp'
- if method == "interp":
- x = np.linspace(0, 1, 101) # 101-point interp (COCO)
- func = np.trapezoid if checks.check_version(np.__version__, ">=2.0") else np.trapz # np.trapz deprecated
- ap = func(np.interp(x, mrec, mpre), x) # integrate
- else: # 'continuous'
- i = np.where(mrec[1:] != mrec[:-1])[0] # points where x-axis (recall) changes
- ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve
-
- return ap, mpre, mrec
-
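A short sketch of compute_ap on a toy precision/recall curve (illustrative values):

    recall = [0.0, 0.5, 1.0]
    precision = [1.0, 0.8, 0.6]
    ap, mpre, mrec = compute_ap(recall, precision)
    # ap is the 101-point interpolated area under the precision envelope, roughly 0.8 here
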
-
-def ap_per_class(
- tp: np.ndarray,
- conf: np.ndarray,
- pred_cls: np.ndarray,
- target_cls: np.ndarray,
- plot: bool = False,
- on_plot=None,
- save_dir: Path = Path(),
- names: dict[int, str] = {},
- eps: float = 1e-16,
- prefix: str = "",
-) -> tuple:
- """
- Compute the average precision per class for object detection evaluation.
-
- Args:
- tp (np.ndarray): Binary array indicating whether the detection is correct (True) or not (False).
- conf (np.ndarray): Array of confidence scores of the detections.
- pred_cls (np.ndarray): Array of predicted classes of the detections.
- target_cls (np.ndarray): Array of true classes of the detections.
- plot (bool, optional): Whether to plot PR curves or not.
- on_plot (callable, optional): A callback to pass plots path and data when they are rendered.
- save_dir (Path, optional): Directory to save the PR curves.
- names (dict[int, str], optional): Dictionary of class names to plot PR curves.
- eps (float, optional): A small value to avoid division by zero.
- prefix (str, optional): A prefix string for saving the plot files.
-
- Returns:
- tp (np.ndarray): True positive counts at threshold given by max F1 metric for each class.
- fp (np.ndarray): False positive counts at threshold given by max F1 metric for each class.
- p (np.ndarray): Precision values at threshold given by max F1 metric for each class.
- r (np.ndarray): Recall values at threshold given by max F1 metric for each class.
- f1 (np.ndarray): F1-score values at threshold given by max F1 metric for each class.
- ap (np.ndarray): Average precision for each class at different IoU thresholds.
- unique_classes (np.ndarray): An array of unique classes that have data.
- p_curve (np.ndarray): Precision curves for each class.
- r_curve (np.ndarray): Recall curves for each class.
- f1_curve (np.ndarray): F1-score curves for each class.
- x (np.ndarray): X-axis values for the curves.
- prec_values (np.ndarray): Precision values at mAP@0.5 for each class.
- """
- # Sort by objectness
- i = np.argsort(-conf)
- tp, conf, pred_cls = tp[i], conf[i], pred_cls[i]
-
- # Find unique classes
- unique_classes, nt = np.unique(target_cls, return_counts=True)
-    nc = unique_classes.shape[0]  # number of classes that appear in the labels
-
- # Create Precision-Recall curve and compute AP for each class
- x, prec_values = np.linspace(0, 1, 1000), []
-
- # Average precision, precision and recall curves
- ap, p_curve, r_curve = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000))
- for ci, c in enumerate(unique_classes):
- i = pred_cls == c
- n_l = nt[ci] # number of labels
- n_p = i.sum() # number of predictions
- if n_p == 0 or n_l == 0:
- continue
-
- # Accumulate FPs and TPs
- fpc = (1 - tp[i]).cumsum(0)
- tpc = tp[i].cumsum(0)
-
- # Recall
- recall = tpc / (n_l + eps) # recall curve
- r_curve[ci] = np.interp(-x, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases
-
- # Precision
- precision = tpc / (tpc + fpc) # precision curve
- p_curve[ci] = np.interp(-x, -conf[i], precision[:, 0], left=1) # p at pr_score
-
- # AP from recall-precision curve
- for j in range(tp.shape[1]):
- ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j])
- if j == 0:
- prec_values.append(np.interp(x, mrec, mpre)) # precision at mAP@0.5
-
- prec_values = np.array(prec_values) if prec_values else np.zeros((1, 1000)) # (nc, 1000)
-
- # Compute F1 (harmonic mean of precision and recall)
- f1_curve = 2 * p_curve * r_curve / (p_curve + r_curve + eps)
- names = {i: names[k] for i, k in enumerate(unique_classes) if k in names} # dict: only classes that have data
- if plot:
- plot_pr_curve(x, prec_values, ap, save_dir / f"{prefix}PR_curve.png", names, on_plot=on_plot)
- plot_mc_curve(x, f1_curve, save_dir / f"{prefix}F1_curve.png", names, ylabel="F1", on_plot=on_plot)
- plot_mc_curve(x, p_curve, save_dir / f"{prefix}P_curve.png", names, ylabel="Precision", on_plot=on_plot)
- plot_mc_curve(x, r_curve, save_dir / f"{prefix}R_curve.png", names, ylabel="Recall", on_plot=on_plot)
-
- i = smooth(f1_curve.mean(0), 0.1).argmax() # max F1 index
- p, r, f1 = p_curve[:, i], r_curve[:, i], f1_curve[:, i] # max-F1 precision, recall, F1 values
- tp = (r * nt).round() # true positives
- fp = (tp / (p + eps) - tp).round() # false positives
- return tp, fp, p, r, f1, ap, unique_classes.astype(int), p_curve, r_curve, f1_curve, x, prec_values
-
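A sketch of calling ap_per_class on synthetic matching results: three detections evaluated at the ten IoU thresholds 0.5:0.95, with made-up class names:

    import numpy as np

    tp = np.array([[True] * 10,               # detection 1: correct at every IoU threshold
                   [True] * 6 + [False] * 4,  # detection 2: correct only at the looser thresholds
                   [False] * 10])             # detection 3: a false positive
    conf = np.array([0.9, 0.8, 0.3])
    pred_cls = np.array([0, 0, 1])
    target_cls = np.array([0, 0, 1])
    out = ap_per_class(tp, conf, pred_cls, target_cls, names={0: "person", 1: "car"})
    tp_c, fp_c, p, r, f1, ap = out[:6]        # ap has shape (classes_with_data, 10)
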
-
-class Metric(SimpleClass):
- """
- Class for computing evaluation metrics for Ultralytics YOLO models.
-
- Attributes:
- p (list): Precision for each class. Shape: (nc,).
- r (list): Recall for each class. Shape: (nc,).
- f1 (list): F1 score for each class. Shape: (nc,).
- all_ap (list): AP scores for all classes and all IoU thresholds. Shape: (nc, 10).
- ap_class_index (list): Index of class for each AP score. Shape: (nc,).
- nc (int): Number of classes.
-
- Methods:
- ap50: AP at IoU threshold of 0.5 for all classes.
- ap: AP at IoU thresholds from 0.5 to 0.95 for all classes.
- mp: Mean precision of all classes.
- mr: Mean recall of all classes.
- map50: Mean AP at IoU threshold of 0.5 for all classes.
- map75: Mean AP at IoU threshold of 0.75 for all classes.
- map: Mean AP at IoU thresholds from 0.5 to 0.95 for all classes.
- mean_results: Mean of results, returns mp, mr, map50, map.
- class_result: Class-aware result, returns p[i], r[i], ap50[i], ap[i].
- maps: mAP of each class.
- fitness: Model fitness as a weighted combination of metrics.
- update: Update metric attributes with new evaluation results.
- curves: Provides a list of curves for accessing specific metrics like precision, recall, F1, etc.
- curves_results: Provide a list of results for accessing specific metrics like precision, recall, F1, etc.
- """
-
- def __init__(self) -> None:
-        """Initialize a Metric instance for computing evaluation metrics for Ultralytics YOLO models."""
- self.p = [] # (nc, )
- self.r = [] # (nc, )
- self.f1 = [] # (nc, )
- self.all_ap = [] # (nc, 10)
- self.ap_class_index = [] # (nc, )
- self.nc = 0
-
- @property
- def ap50(self) -> np.ndarray | list:
- """
- Return the Average Precision (AP) at an IoU threshold of 0.5 for all classes.
-
- Returns:
- (np.ndarray | list): Array of shape (nc,) with AP50 values per class, or an empty list if not available.
- """
- return self.all_ap[:, 0] if len(self.all_ap) else []
-
- @property
- def ap(self) -> np.ndarray | list:
- """
- Return the Average Precision (AP) at an IoU threshold of 0.5-0.95 for all classes.
-
- Returns:
- (np.ndarray | list): Array of shape (nc,) with AP50-95 values per class, or an empty list if not available.
- """
- return self.all_ap.mean(1) if len(self.all_ap) else []
-
- @property
- def mp(self) -> float:
- """
- Return the Mean Precision of all classes.
-
- Returns:
- (float): The mean precision of all classes.
- """
- return self.p.mean() if len(self.p) else 0.0
-
- @property
- def mr(self) -> float:
- """
- Return the Mean Recall of all classes.
-
- Returns:
- (float): The mean recall of all classes.
- """
- return self.r.mean() if len(self.r) else 0.0
-
- @property
- def map50(self) -> float:
- """
- Return the mean Average Precision (mAP) at an IoU threshold of 0.5.
-
- Returns:
- (float): The mAP at an IoU threshold of 0.5.
- """
- return self.all_ap[:, 0].mean() if len(self.all_ap) else 0.0
-
- @property
- def map75(self) -> float:
- """
- Return the mean Average Precision (mAP) at an IoU threshold of 0.75.
-
- Returns:
- (float): The mAP at an IoU threshold of 0.75.
- """
- return self.all_ap[:, 5].mean() if len(self.all_ap) else 0.0
-
- @property
- def map(self) -> float:
- """
- Return the mean Average Precision (mAP) over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
-
- Returns:
- (float): The mAP over IoU thresholds of 0.5 - 0.95 in steps of 0.05.
- """
- return self.all_ap.mean() if len(self.all_ap) else 0.0
-
- def mean_results(self) -> list[float]:
- """Return mean of results, mp, mr, map50, map."""
- return [self.mp, self.mr, self.map50, self.map]
-
- def class_result(self, i: int) -> tuple[float, float, float, float]:
- """Return class-aware result, p[i], r[i], ap50[i], ap[i]."""
- return self.p[i], self.r[i], self.ap50[i], self.ap[i]
-
- @property
- def maps(self) -> np.ndarray:
- """Return mAP of each class."""
- maps = np.zeros(self.nc) + self.map
- for i, c in enumerate(self.ap_class_index):
- maps[c] = self.ap[i]
- return maps
-
- def fitness(self) -> float:
- """Return model fitness as a weighted combination of metrics."""
- w = [0.0, 0.0, 0.0, 1.0] # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
- return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
-
- def update(self, results: tuple):
- """
- Update the evaluation metrics with a new set of results.
-
- Args:
- results (tuple): A tuple containing evaluation metrics:
- - p (list): Precision for each class.
- - r (list): Recall for each class.
- - f1 (list): F1 score for each class.
- - all_ap (list): AP scores for all classes and all IoU thresholds.
- - ap_class_index (list): Index of class for each AP score.
- - p_curve (list): Precision curve for each class.
- - r_curve (list): Recall curve for each class.
- - f1_curve (list): F1 curve for each class.
- - px (list): X values for the curves.
- - prec_values (list): Precision values for each class.
- """
- (
- self.p,
- self.r,
- self.f1,
- self.all_ap,
- self.ap_class_index,
- self.p_curve,
- self.r_curve,
- self.f1_curve,
- self.px,
- self.prec_values,
- ) = results
-
- @property
- def curves(self) -> list:
- """Return a list of curves for accessing specific metrics curves."""
- return []
-
- @property
- def curves_results(self) -> list[list]:
- """Return a list of curves for accessing specific metrics curves."""
- return [
- [self.px, self.prec_values, "Recall", "Precision"],
- [self.px, self.f1_curve, "Confidence", "F1"],
- [self.px, self.p_curve, "Confidence", "Precision"],
- [self.px, self.r_curve, "Confidence", "Recall"],
- ]
-
-
-class DetMetrics(SimpleClass, DataExportMixin):
- """
- Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP).
-
- Attributes:
- names (dict[int, str]): A dictionary of class names.
- box (Metric): An instance of the Metric class for storing detection results.
- speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
- task (str): The task type, set to 'detect'.
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
- nt_per_class: Number of targets per class.
- nt_per_image: Number of targets per image.
-
- Methods:
- update_stats: Update statistics by appending new values to existing stat collections.
- process: Process predicted results for object detection and update metrics.
- clear_stats: Clear the stored statistics.
- keys: Return a list of keys for accessing specific metrics.
-        mean_results: Return the mean precision, recall, mAP50, and mAP50-95 over all classes.
- class_result: Return the result of evaluating the performance of an object detection model on a specific class.
- maps: Return mean Average Precision (mAP) scores per class.
- fitness: Return the fitness of box object.
- ap_class_index: Return the average precision index per class.
- results_dict: Return dictionary of computed performance metrics and statistics.
- curves: Return a list of curves for accessing specific metrics curves.
- curves_results: Return a list of computed performance metrics and statistics.
- summary: Generate a summarized representation of per-class detection metrics as a list of dictionaries.
- """
-
- def __init__(self, names: dict[int, str] = {}) -> None:
- """
-        Initialize a DetMetrics instance with class names.
-
- Args:
- names (dict[int, str], optional): Dictionary of class names.
- """
- self.names = names
- self.box = Metric()
- self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
- self.task = "detect"
- self.stats = dict(tp=[], conf=[], pred_cls=[], target_cls=[], target_img=[])
- self.nt_per_class = None
- self.nt_per_image = None
-
- def update_stats(self, stat: dict[str, Any]) -> None:
- """
- Update statistics by appending new values to existing stat collections.
-
- Args:
- stat (dict[str, any]): Dictionary containing new statistical values to append.
- Keys should match existing keys in self.stats.
- """
- for k in self.stats.keys():
- self.stats[k].append(stat[k])
-
- def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
- """
- Process predicted results for object detection and update metrics.
-
- Args:
- save_dir (Path): Directory to save plots. Defaults to Path(".").
- plot (bool): Whether to plot precision-recall curves. Defaults to False.
- on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
-
- Returns:
- (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
- """
- stats = {k: np.concatenate(v, 0) for k, v in self.stats.items()} # to numpy
- if not stats:
- return stats
- results = ap_per_class(
- stats["tp"],
- stats["conf"],
- stats["pred_cls"],
- stats["target_cls"],
- plot=plot,
- save_dir=save_dir,
- names=self.names,
- on_plot=on_plot,
- prefix="Box",
- )[2:]
- self.box.nc = len(self.names)
- self.box.update(results)
- self.nt_per_class = np.bincount(stats["target_cls"].astype(int), minlength=len(self.names))
- self.nt_per_image = np.bincount(stats["target_img"].astype(int), minlength=len(self.names))
- return stats
-
- def clear_stats(self):
- """Clear the stored statistics."""
- for v in self.stats.values():
- v.clear()
-
- @property
- def keys(self) -> list[str]:
- """Return a list of keys for accessing specific metrics."""
- return ["metrics/precision(B)", "metrics/recall(B)", "metrics/mAP50(B)", "metrics/mAP50-95(B)"]
-
- def mean_results(self) -> list[float]:
-        """Return the mean precision, recall, mAP50, and mAP50-95 over all classes."""
- return self.box.mean_results()
-
- def class_result(self, i: int) -> tuple[float, float, float, float]:
- """Return the result of evaluating the performance of an object detection model on a specific class."""
- return self.box.class_result(i)
-
- @property
- def maps(self) -> np.ndarray:
- """Return mean Average Precision (mAP) scores per class."""
- return self.box.maps
-
- @property
- def fitness(self) -> float:
- """Return the fitness of box object."""
- return self.box.fitness()
-
- @property
- def ap_class_index(self) -> list:
- """Return the average precision index per class."""
- return self.box.ap_class_index
-
- @property
- def results_dict(self) -> dict[str, float]:
- """Return dictionary of computed performance metrics and statistics."""
- keys = self.keys + ["fitness"]
- values = ((float(x) if hasattr(x, "item") else x) for x in (self.mean_results() + [self.fitness]))
- return dict(zip(keys, values))
-
- @property
- def curves(self) -> list[str]:
- """Return a list of curves for accessing specific metrics curves."""
- return ["Precision-Recall(B)", "F1-Confidence(B)", "Precision-Confidence(B)", "Recall-Confidence(B)"]
-
- @property
- def curves_results(self) -> list[list]:
- """Return a list of computed performance metrics and statistics."""
- return self.box.curves_results
-
- def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
- """
- Generate a summarized representation of per-class detection metrics as a list of dictionaries. Includes shared
- scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
-
- Args:
- normalize (bool): For Detect metrics, everything is normalized by default [0-1].
- decimals (int): Number of decimal places to round the metrics values to.
-
- Returns:
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
-
- Examples:
- >>> results = model.val(data="coco8.yaml")
- >>> detection_summary = results.summary()
- >>> print(detection_summary)
- """
- per_class = {
- "Box-P": self.box.p,
- "Box-R": self.box.r,
- "Box-F1": self.box.f1,
- }
- return [
- {
- "Class": self.names[self.ap_class_index[i]],
- "Images": self.nt_per_image[self.ap_class_index[i]],
- "Instances": self.nt_per_class[self.ap_class_index[i]],
- **{k: round(v[i], decimals) for k, v in per_class.items()},
- "mAP50": round(self.class_result(i)[2], decimals),
- "mAP50-95": round(self.class_result(i)[3], decimals),
- }
- for i in range(len(per_class["Box-P"]))
- ]
-
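A minimal sketch of feeding DetMetrics directly; in the real pipeline the validator accumulates these per-batch stats, and the class names below are made up:

    import numpy as np

    dm = DetMetrics(names={0: "person", 1: "car"})
    dm.update_stats({
        "tp": np.array([[True] * 10, [False] * 10]),  # (num_detections, num_iou_thresholds)
        "conf": np.array([0.9, 0.4]),
        "pred_cls": np.array([0, 1]),
        "target_cls": np.array([0, 1]),
        "target_img": np.array([0, 1]),               # classes present in each image
    })
    dm.process(plot=False)
    print(dm.results_dict)  # precision/recall/mAP50/mAP50-95 plus the fitness score
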
-
-class SegmentMetrics(DetMetrics):
- """
- Calculate and aggregate detection and segmentation metrics over a given set of classes.
-
- Attributes:
- names (dict[int, str]): Dictionary of class names.
- box (Metric): An instance of the Metric class for storing detection results.
- seg (Metric): An instance of the Metric class to calculate mask segmentation metrics.
- speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
- task (str): The task type, set to 'segment'.
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
- nt_per_class: Number of targets per class.
- nt_per_image: Number of targets per image.
-
- Methods:
- process: Process the detection and segmentation metrics over the given set of predictions.
- keys: Return a list of keys for accessing metrics.
- mean_results: Return the mean metrics for bounding box and segmentation results.
- class_result: Return classification results for a specified class index.
- maps: Return mAP scores for object detection and semantic segmentation models.
- fitness: Return the fitness score for both segmentation and bounding box models.
- curves: Return a list of curves for accessing specific metrics curves.
- curves_results: Provide a list of computed performance metrics and statistics.
- summary: Generate a summarized representation of per-class segmentation metrics as a list of dictionaries.
- """
-
- def __init__(self, names: dict[int, str] = {}) -> None:
- """
-        Initialize a SegmentMetrics instance with class names.
-
- Args:
- names (dict[int, str], optional): Dictionary of class names.
- """
- DetMetrics.__init__(self, names)
- self.seg = Metric()
- self.task = "segment"
- self.stats["tp_m"] = [] # add additional stats for masks
-
- def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
- """
- Process the detection and segmentation metrics over the given set of predictions.
-
- Args:
- save_dir (Path): Directory to save plots. Defaults to Path(".").
- plot (bool): Whether to plot precision-recall curves. Defaults to False.
- on_plot (callable, optional): Function to call after plots are generated. Defaults to None.
-
- Returns:
- (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
- """
- stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats
- results_mask = ap_per_class(
- stats["tp_m"],
- stats["conf"],
- stats["pred_cls"],
- stats["target_cls"],
- plot=plot,
- on_plot=on_plot,
- save_dir=save_dir,
- names=self.names,
- prefix="Mask",
- )[2:]
- self.seg.nc = len(self.names)
- self.seg.update(results_mask)
- return stats
-
- @property
- def keys(self) -> list[str]:
- """Return a list of keys for accessing metrics."""
- return DetMetrics.keys.fget(self) + [
- "metrics/precision(M)",
- "metrics/recall(M)",
- "metrics/mAP50(M)",
- "metrics/mAP50-95(M)",
- ]
-
- def mean_results(self) -> list[float]:
- """Return the mean metrics for bounding box and segmentation results."""
- return DetMetrics.mean_results(self) + self.seg.mean_results()
-
- def class_result(self, i: int) -> list[float]:
- """Return classification results for a specified class index."""
- return DetMetrics.class_result(self, i) + self.seg.class_result(i)
-
- @property
- def maps(self) -> np.ndarray:
- """Return mAP scores for object detection and semantic segmentation models."""
- return DetMetrics.maps.fget(self) + self.seg.maps
-
- @property
- def fitness(self) -> float:
- """Return the fitness score for both segmentation and bounding box models."""
- return self.seg.fitness() + DetMetrics.fitness.fget(self)
-
- @property
- def curves(self) -> list[str]:
- """Return a list of curves for accessing specific metrics curves."""
- return DetMetrics.curves.fget(self) + [
- "Precision-Recall(M)",
- "F1-Confidence(M)",
- "Precision-Confidence(M)",
- "Recall-Confidence(M)",
- ]
-
- @property
- def curves_results(self) -> list[list]:
- """Return a list of computed performance metrics and statistics."""
- return DetMetrics.curves_results.fget(self) + self.seg.curves_results
-
- def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
- """
- Generate a summarized representation of per-class segmentation metrics as a list of dictionaries. Includes both
- box and mask scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
-
- Args:
- normalize (bool): For Segment metrics, everything is normalized by default [0-1].
- decimals (int): Number of decimal places to round the metrics values to.
-
- Returns:
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
-
- Examples:
- >>> results = model.val(data="coco8-seg.yaml")
- >>> seg_summary = results.summary(decimals=4)
- >>> print(seg_summary)
- """
- per_class = {
- "Mask-P": self.seg.p,
- "Mask-R": self.seg.r,
- "Mask-F1": self.seg.f1,
- }
- summary = DetMetrics.summary(self, normalize, decimals) # get box summary
- for i, s in enumerate(summary):
- s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})
- return summary
-
-
-class PoseMetrics(DetMetrics):
- """
- Calculate and aggregate detection and pose metrics over a given set of classes.
-
- Attributes:
- names (dict[int, str]): Dictionary of class names.
- pose (Metric): An instance of the Metric class to calculate pose metrics.
- box (Metric): An instance of the Metric class for storing detection results.
- speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
- task (str): The task type, set to 'pose'.
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
- nt_per_class: Number of targets per class.
- nt_per_image: Number of targets per image.
-
- Methods:
-        process: Process the detection and pose metrics over the given set of predictions.
- keys: Return a list of keys for accessing metrics.
- mean_results: Return the mean results of box and pose.
- class_result: Return the class-wise detection results for a specific class i.
- maps: Return the mean average precision (mAP) per class for both box and pose detections.
- fitness: Return combined fitness score for pose and box detection.
- curves: Return a list of curves for accessing specific metrics curves.
- curves_results: Provide a list of computed performance metrics and statistics.
- summary: Generate a summarized representation of per-class pose metrics as a list of dictionaries.
- """
-
- def __init__(self, names: dict[int, str] = {}) -> None:
- """
-        Initialize a PoseMetrics instance with class names.
-
- Args:
- names (dict[int, str], optional): Dictionary of class names.
- """
- super().__init__(names)
- self.pose = Metric()
- self.task = "pose"
- self.stats["tp_p"] = [] # add additional stats for pose
-
- def process(self, save_dir: Path = Path("."), plot: bool = False, on_plot=None) -> dict[str, np.ndarray]:
- """
- Process the detection and pose metrics over the given set of predictions.
-
- Args:
- save_dir (Path): Directory to save plots. Defaults to Path(".").
- plot (bool): Whether to plot precision-recall curves. Defaults to False.
- on_plot (callable, optional): Function to call after plots are generated.
-
- Returns:
- (dict[str, np.ndarray]): Dictionary containing concatenated statistics arrays.
- """
- stats = DetMetrics.process(self, save_dir, plot, on_plot=on_plot) # process box stats
- results_pose = ap_per_class(
- stats["tp_p"],
- stats["conf"],
- stats["pred_cls"],
- stats["target_cls"],
- plot=plot,
- on_plot=on_plot,
- save_dir=save_dir,
- names=self.names,
- prefix="Pose",
- )[2:]
- self.pose.nc = len(self.names)
- self.pose.update(results_pose)
- return stats
-
- @property
- def keys(self) -> list[str]:
- """Return a list of evaluation metric keys."""
- return DetMetrics.keys.fget(self) + [
- "metrics/precision(P)",
- "metrics/recall(P)",
- "metrics/mAP50(P)",
- "metrics/mAP50-95(P)",
- ]
-
- def mean_results(self) -> list[float]:
- """Return the mean results of box and pose."""
- return DetMetrics.mean_results(self) + self.pose.mean_results()
-
- def class_result(self, i: int) -> list[float]:
- """Return the class-wise detection results for a specific class i."""
- return DetMetrics.class_result(self, i) + self.pose.class_result(i)
-
- @property
- def maps(self) -> np.ndarray:
- """Return the mean average precision (mAP) per class for both box and pose detections."""
- return DetMetrics.maps.fget(self) + self.pose.maps
-
- @property
- def fitness(self) -> float:
- """Return combined fitness score for pose and box detection."""
- return self.pose.fitness() + DetMetrics.fitness.fget(self)
-
- @property
- def curves(self) -> list[str]:
- """Return a list of curves for accessing specific metrics curves."""
- return DetMetrics.curves.fget(self) + [
- "Precision-Recall(P)",
- "F1-Confidence(P)",
- "Precision-Confidence(P)",
- "Recall-Confidence(P)",
- ]
-
- @property
- def curves_results(self) -> list[list]:
- """Return a list of computed performance metrics and statistics."""
- return DetMetrics.curves_results.fget(self) + self.pose.curves_results
-
- def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, Any]]:
- """
- Generate a summarized representation of per-class pose metrics as a list of dictionaries. Includes both box and
- pose scalar metrics (mAP, mAP50, mAP75) alongside precision, recall, and F1-score for each class.
-
- Args:
- normalize (bool): For Pose metrics, everything is normalized by default [0-1].
- decimals (int): Number of decimal places to round the metrics values to.
-
- Returns:
- (list[dict[str, Any]]): A list of dictionaries, each representing one class with corresponding metric values.
-
- Examples:
- >>> results = model.val(data="coco8-pose.yaml")
- >>> pose_summary = results.summary(decimals=4)
- >>> print(pose_summary)
- """
- per_class = {
- "Pose-P": self.pose.p,
- "Pose-R": self.pose.r,
- "Pose-F1": self.pose.f1,
- }
- summary = DetMetrics.summary(self, normalize, decimals) # get box summary
- for i, s in enumerate(summary):
- s.update({**{k: round(v[i], decimals) for k, v in per_class.items()}})
- return summary
-
-
-class ClassifyMetrics(SimpleClass, DataExportMixin):
- """
- Class for computing classification metrics including top-1 and top-5 accuracy.
-
- Attributes:
- top1 (float): The top-1 accuracy.
- top5 (float): The top-5 accuracy.
- speed (dict): A dictionary containing the time taken for each step in the pipeline.
- task (str): The task type, set to 'classify'.
-
- Methods:
- process: Process target classes and predicted classes to compute metrics.
- fitness: Return mean of top-1 and top-5 accuracies as fitness score.
- results_dict: Return a dictionary with model's performance metrics and fitness score.
- keys: Return a list of keys for the results_dict property.
- curves: Return a list of curves for accessing specific metrics curves.
- curves_results: Provide a list of computed performance metrics and statistics.
- summary: Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
- """
-
- def __init__(self) -> None:
- """Initialize a ClassifyMetrics instance."""
- self.top1 = 0
- self.top5 = 0
- self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
- self.task = "classify"
-
- def process(self, targets: torch.Tensor, pred: torch.Tensor):
- """
- Process target classes and predicted classes to compute metrics.
-
- Args:
- targets (torch.Tensor): Target classes.
- pred (torch.Tensor): Predicted classes.
- """
- pred, targets = torch.cat(pred), torch.cat(targets)
- correct = (targets[:, None] == pred).float()
- acc = torch.stack((correct[:, 0], correct.max(1).values), dim=1) # (top1, top5) accuracy
- self.top1, self.top5 = acc.mean(0).tolist()
-
- @property
- def fitness(self) -> float:
- """Return mean of top-1 and top-5 accuracies as fitness score."""
- return (self.top1 + self.top5) / 2
-
- @property
- def results_dict(self) -> dict[str, float]:
- """Return a dictionary with model's performance metrics and fitness score."""
- return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
-
- @property
- def keys(self) -> list[str]:
- """Return a list of keys for the results_dict property."""
- return ["metrics/accuracy_top1", "metrics/accuracy_top5"]
-
- @property
- def curves(self) -> list:
- """Return a list of curves for accessing specific metrics curves."""
- return []
-
- @property
- def curves_results(self) -> list:
- """Return a list of curves for accessing specific metrics curves."""
- return []
-
- def summary(self, normalize: bool = True, decimals: int = 5) -> list[dict[str, float]]:
- """
- Generate a single-row summary of classification metrics (Top-1 and Top-5 accuracy).
-
- Args:
- normalize (bool): For Classify metrics, everything is normalized by default [0-1].
- decimals (int): Number of decimal places to round the metrics values to.
-
- Returns:
- (list[dict[str, float]]): A list with one dictionary containing Top-1 and Top-5 classification accuracy.
-
- Examples:
- >>> results = model.val(data="imagenet10")
- >>> classify_summary = results.summary(decimals=4)
- >>> print(classify_summary)
- """
- return [{"top1_acc": round(self.top1, decimals), "top5_acc": round(self.top5, decimals)}]
-
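A short sketch for ClassifyMetrics; the tensors mimic the per-batch lists of targets and top-5 predictions that the classification validator collects:

    import torch

    targets = [torch.tensor([0, 1, 2])]     # per-batch ground-truth class indices
    pred = [torch.tensor([[0, 2, 1, 3, 4],  # per-batch top-5 predictions, best guess first
                          [2, 1, 0, 3, 4],
                          [2, 0, 1, 3, 4]])]
    clf = ClassifyMetrics()
    clf.process(targets, pred)
    print(clf.top1, clf.top5, clf.fitness)  # top-1 is 2/3 here, top-5 is 1.0
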
-
-class OBBMetrics(DetMetrics):
- """
- Metrics for evaluating oriented bounding box (OBB) detection.
-
- Attributes:
- names (dict[int, str]): Dictionary of class names.
- box (Metric): An instance of the Metric class for storing detection results.
- speed (dict[str, float]): A dictionary for storing execution times of different parts of the detection process.
- task (str): The task type, set to 'obb'.
- stats (dict[str, list]): A dictionary containing lists for true positives, confidence scores, predicted classes, target classes, and target images.
- nt_per_class: Number of targets per class.
- nt_per_image: Number of targets per image.
-
- References:
- https://arxiv.org/pdf/2106.06072.pdf
- """
-
- def __init__(self, names: dict[int, str] = {}) -> None:
- """
-        Initialize an OBBMetrics instance with class names.
-
- Args:
- names (dict[int, str], optional): Dictionary of class names.
- """
- DetMetrics.__init__(self, names)
- # TODO: probably remove task as well
- self.task = "obb"
diff --git a/ultralytics/utils/nms.py b/ultralytics/utils/nms.py
deleted file mode 100644
index b638640..0000000
--- a/ultralytics/utils/nms.py
+++ /dev/null
@@ -1,340 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import sys
-import time
-
-import torch
-
-from ultralytics.utils import LOGGER
-from ultralytics.utils.metrics import batch_probiou, box_iou
-from ultralytics.utils.ops import xywh2xyxy
-
-
-def non_max_suppression(
- prediction,
- conf_thres: float = 0.25,
- iou_thres: float = 0.45,
- classes=None,
- agnostic: bool = False,
- multi_label: bool = False,
- labels=(),
- max_det: int = 300,
- nc: int = 0, # number of classes (optional)
- max_time_img: float = 0.05,
- max_nms: int = 30000,
- max_wh: int = 7680,
- rotated: bool = False,
- end2end: bool = False,
- return_idxs: bool = False,
-):
- """
- Perform non-maximum suppression (NMS) on prediction results.
-
- Applies NMS to filter overlapping bounding boxes based on confidence and IoU thresholds. Supports multiple
- detection formats including standard boxes, rotated boxes, and masks.
-
- Args:
- prediction (torch.Tensor): Predictions with shape (batch_size, num_classes + 4 + num_masks, num_boxes)
- containing boxes, classes, and optional masks.
- conf_thres (float): Confidence threshold for filtering detections. Valid values are between 0.0 and 1.0.
- iou_thres (float): IoU threshold for NMS filtering. Valid values are between 0.0 and 1.0.
- classes (list[int], optional): List of class indices to consider. If None, all classes are considered.
- agnostic (bool): Whether to perform class-agnostic NMS.
- multi_label (bool): Whether each box can have multiple labels.
- labels (list[list[Union[int, float, torch.Tensor]]]): A priori labels for each image.
- max_det (int): Maximum number of detections to keep per image.
- nc (int): Number of classes. Indices after this are considered masks.
- max_time_img (float): Maximum time in seconds for processing one image.
- max_nms (int): Maximum number of boxes for NMS.
- max_wh (int): Maximum box width and height in pixels.
- rotated (bool): Whether to handle Oriented Bounding Boxes (OBB).
- end2end (bool): Whether the model is end-to-end and doesn't require NMS.
- return_idxs (bool): Whether to return the indices of kept detections.
-
- Returns:
- output (list[torch.Tensor]): List of detections per image with shape (num_boxes, 6 + num_masks)
- containing (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
- keepi (list[torch.Tensor]): Indices of kept detections if return_idxs=True.
- """
- # Checks
- assert 0 <= conf_thres <= 1, f"Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0"
- assert 0 <= iou_thres <= 1, f"Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0"
- if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
- prediction = prediction[0] # select only inference output
- if classes is not None:
- classes = torch.tensor(classes, device=prediction.device)
-
- if prediction.shape[-1] == 6 or end2end: # end-to-end model (BNC, i.e. 1,300,6)
- output = [pred[pred[:, 4] > conf_thres][:max_det] for pred in prediction]
- if classes is not None:
- output = [pred[(pred[:, 5:6] == classes).any(1)] for pred in output]
- return output
-
- bs = prediction.shape[0] # batch size (BCN, i.e. 1,84,6300)
- nc = nc or (prediction.shape[1] - 4) # number of classes
- extra = prediction.shape[1] - nc - 4 # number of extra info
- mi = 4 + nc # mask start index
- xc = prediction[:, 4:mi].amax(1) > conf_thres # candidates
- xinds = torch.arange(prediction.shape[-1], device=prediction.device).expand(bs, -1)[..., None] # to track idxs
-
- # Settings
- # min_wh = 2 # (pixels) minimum box width and height
- time_limit = 2.0 + max_time_img * bs # seconds to quit after
- multi_label &= nc > 1 # multiple labels per box (adds 0.5ms/img)
-
- prediction = prediction.transpose(-1, -2) # shape(1,84,6300) to shape(1,6300,84)
- if not rotated:
- prediction[..., :4] = xywh2xyxy(prediction[..., :4]) # xywh to xyxy
-
- t = time.time()
- output = [torch.zeros((0, 6 + extra), device=prediction.device)] * bs
- keepi = [torch.zeros((0, 1), device=prediction.device)] * bs # to store the kept idxs
- for xi, (x, xk) in enumerate(zip(prediction, xinds)): # image index, (preds, preds indices)
- # Apply constraints
- # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0 # width-height
- filt = xc[xi] # confidence
- x = x[filt]
- if return_idxs:
- xk = xk[filt]
-
- # Cat apriori labels if autolabelling
- if labels and len(labels[xi]) and not rotated:
- lb = labels[xi]
- v = torch.zeros((len(lb), nc + extra + 4), device=x.device)
- v[:, :4] = xywh2xyxy(lb[:, 1:5]) # box
- v[range(len(lb)), lb[:, 0].long() + 4] = 1.0 # cls
- x = torch.cat((x, v), 0)
-
- # If none remain process next image
- if not x.shape[0]:
- continue
-
- # Detections matrix nx6 (xyxy, conf, cls)
- box, cls, mask = x.split((4, nc, extra), 1)
-
- if multi_label:
- i, j = torch.where(cls > conf_thres)
- x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
- if return_idxs:
- xk = xk[i]
- else: # best class only
- conf, j = cls.max(1, keepdim=True)
- filt = conf.view(-1) > conf_thres
- x = torch.cat((box, conf, j.float(), mask), 1)[filt]
- if return_idxs:
- xk = xk[filt]
-
- # Filter by class
- if classes is not None:
- filt = (x[:, 5:6] == classes).any(1)
- x = x[filt]
- if return_idxs:
- xk = xk[filt]
-
- # Check shape
- n = x.shape[0] # number of boxes
- if not n: # no boxes
- continue
- if n > max_nms: # excess boxes
- filt = x[:, 4].argsort(descending=True)[:max_nms] # sort by confidence and remove excess boxes
- x = x[filt]
- if return_idxs:
- xk = xk[filt]
-
- c = x[:, 5:6] * (0 if agnostic else max_wh) # classes
- scores = x[:, 4] # scores
- if rotated:
- boxes = torch.cat((x[:, :2] + c, x[:, 2:4], x[:, -1:]), dim=-1) # xywhr
- i = TorchNMS.fast_nms(boxes, scores, iou_thres, iou_func=batch_probiou)
- else:
- boxes = x[:, :4] + c # boxes (offset by class)
- # Speed strategy: torchvision for val or already loaded (faster), TorchNMS for predict (lower latency)
- if "torchvision" in sys.modules:
- import torchvision # scope as slow import
-
- i = torchvision.ops.nms(boxes, scores, iou_thres)
- else:
- i = TorchNMS.nms(boxes, scores, iou_thres)
- i = i[:max_det] # limit detections
-
- output[xi] = x[i]
- if return_idxs:
- keepi[xi] = xk[i].view(-1)
- if (time.time() - t) > time_limit:
- LOGGER.warning(f"NMS time limit {time_limit:.3f}s exceeded")
- break # time limit exceeded
-
- return (output, keepi) if return_idxs else output
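
The per-class offset applied above (adding class_index * max_wh to the box coordinates before a single NMS call) is what keeps boxes of different classes from suppressing each other. A minimal sketch of that trick, assuming only torch and torchvision are available:

import torch
import torchvision

# Two heavily overlapping boxes of different classes, plus one distant box (xyxy format)
boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.0, 10.0], [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
classes = torch.tensor([0, 1, 0])
max_wh = 7680  # larger than any expected image dimension, so class offsets never overlap

agnostic = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)  # drops the overlapping second box
class_aware = torchvision.ops.nms(boxes + classes[:, None] * max_wh, scores, iou_threshold=0.5)
print(agnostic.tolist(), class_aware.tolist())  # [0, 2] [0, 1, 2]
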
-
-
-class TorchNMS:
- """
- Ultralytics custom NMS implementation optimized for YOLO.
-
- This class provides static methods for performing non-maximum suppression (NMS) operations on bounding boxes,
- including both standard NMS and batched NMS for multi-class scenarios.
-
- Methods:
- nms: Optimized NMS with early termination that matches torchvision behavior exactly.
- batched_nms: Batched NMS for class-aware suppression.
-
- Examples:
- Perform standard NMS on boxes and scores
- >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
- >>> scores = torch.tensor([0.9, 0.8])
- >>> keep = TorchNMS.nms(boxes, scores, 0.5)
- """
-
- @staticmethod
- def fast_nms(
- boxes: torch.Tensor,
- scores: torch.Tensor,
- iou_threshold: float,
- use_triu: bool = True,
- iou_func=box_iou,
- exit_early: bool = True,
- ) -> torch.Tensor:
- """
- Fast-NMS implementation from https://arxiv.org/pdf/1904.02689 using upper triangular matrix operations.
-
- Args:
- boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
- scores (torch.Tensor): Confidence scores with shape (N,).
- iou_threshold (float): IoU threshold for suppression.
- use_triu (bool): Whether to use torch.triu operator for upper triangular matrix operations.
- iou_func (callable): Function to compute IoU between boxes.
- exit_early (bool): Whether to exit early if there are no boxes.
-
- Returns:
- (torch.Tensor): Indices of boxes to keep after NMS.
-
- Examples:
- Apply Fast-NMS to a set of boxes
- >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
- >>> scores = torch.tensor([0.9, 0.8])
- >>> keep = TorchNMS.fast_nms(boxes, scores, 0.5)
- """
- if boxes.numel() == 0 and exit_early:
- return torch.empty((0,), dtype=torch.int64, device=boxes.device)
-
- sorted_idx = torch.argsort(scores, descending=True)
- boxes = boxes[sorted_idx]
- ious = iou_func(boxes, boxes)
- if use_triu:
- ious = ious.triu_(diagonal=1)
- # NOTE: avoid branching on the number of boxes here so the graph remains exportable
- pick = torch.nonzero((ious >= iou_threshold).sum(0) <= 0).squeeze_(-1)
- else:
- n = boxes.shape[0]
- row_idx = torch.arange(n, device=boxes.device).view(-1, 1).expand(-1, n)
- col_idx = torch.arange(n, device=boxes.device).view(1, -1).expand(n, -1)
- upper_mask = row_idx < col_idx
- ious = ious * upper_mask
- # Zeroing these scores ensures the additional indices would not affect the final results
- scores[~((ious >= iou_threshold).sum(0) <= 0)] = 0
- # NOTE: return indices with fixed length to avoid TFLite reshape error
- pick = torch.topk(scores, scores.shape[0]).indices
- return sorted_idx[pick]
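
The use_triu branch above encodes greedy suppression as a matrix test: after sorting by score, a box survives only if no higher-scoring box overlaps it above the threshold. A small self-contained sketch of that formulation, with torchvision's box_iou standing in as the IoU function purely for illustration:

import torch
from torchvision.ops import box_iou

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 10.0, 10.0], [50.0, 50.0, 60.0, 60.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
iou_threshold = 0.5

order = scores.argsort(descending=True)                        # highest score first
ious = box_iou(boxes[order], boxes[order]).triu_(diagonal=1)   # keep only IoU against higher-scoring boxes
keep = (ious >= iou_threshold).sum(0) == 0                     # suppressed if any such IoU is too high
print(order[keep].tolist())  # [0, 2]
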
-
- @staticmethod
- def nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float) -> torch.Tensor:
- """
- Optimized NMS with early termination that matches torchvision behavior exactly.
-
- Args:
- boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
- scores (torch.Tensor): Confidence scores with shape (N,).
- iou_threshold (float): IoU threshold for suppression.
-
- Returns:
- (torch.Tensor): Indices of boxes to keep after NMS.
-
- Examples:
- Apply NMS to a set of boxes
- >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
- >>> scores = torch.tensor([0.9, 0.8])
- >>> keep = TorchNMS.nms(boxes, scores, 0.5)
- """
- if boxes.numel() == 0:
- return torch.empty((0,), dtype=torch.int64, device=boxes.device)
-
- # Pre-allocate and extract coordinates once
- x1, y1, x2, y2 = boxes.unbind(1)
- areas = (x2 - x1) * (y2 - y1)
-
- # Sort by scores descending
- order = scores.argsort(0, descending=True)
-
- # Pre-allocate keep list with maximum possible size
- keep = torch.zeros(order.numel(), dtype=torch.int64, device=boxes.device)
- keep_idx = 0
- while order.numel() > 0:
- i = order[0]
- keep[keep_idx] = i
- keep_idx += 1
-
- if order.numel() == 1:
- break
- # Vectorized IoU calculation for remaining boxes
- rest = order[1:]
- xx1 = torch.maximum(x1[i], x1[rest])
- yy1 = torch.maximum(y1[i], y1[rest])
- xx2 = torch.minimum(x2[i], x2[rest])
- yy2 = torch.minimum(y2[i], y2[rest])
-
- # Fast intersection and IoU
- w = (xx2 - xx1).clamp_(min=0)
- h = (yy2 - yy1).clamp_(min=0)
- inter = w * h
- # Early exit: skip IoU calculation if no intersection
- if inter.sum() == 0:
- # No overlaps with current box, keep all remaining boxes
- order = rest
- continue
- iou = inter / (areas[i] + areas[rest] - inter)
- # Keep boxes with IoU <= threshold
- order = rest[iou <= iou_threshold]
-
- return keep[:keep_idx]
-
- @staticmethod
- def batched_nms(
- boxes: torch.Tensor,
- scores: torch.Tensor,
- idxs: torch.Tensor,
- iou_threshold: float,
- use_fast_nms: bool = False,
- ) -> torch.Tensor:
- """
- Batched NMS for class-aware suppression.
-
- Args:
- boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
- scores (torch.Tensor): Confidence scores with shape (N,).
- idxs (torch.Tensor): Class indices with shape (N,).
- iou_threshold (float): IoU threshold for suppression.
- use_fast_nms (bool): Whether to use the Fast-NMS implementation.
-
- Returns:
- (torch.Tensor): Indices of boxes to keep after NMS.
-
- Examples:
- Apply batched NMS across multiple classes
- >>> boxes = torch.tensor([[0, 0, 10, 10], [5, 5, 15, 15]])
- >>> scores = torch.tensor([0.9, 0.8])
- >>> idxs = torch.tensor([0, 1])
- >>> keep = TorchNMS.batched_nms(boxes, scores, idxs, 0.5)
- """
- if boxes.numel() == 0:
- return torch.empty((0,), dtype=torch.int64, device=boxes.device)
-
- # Strategy: offset boxes by class index to prevent cross-class suppression
- max_coordinate = boxes.max()
- offsets = idxs.to(boxes) * (max_coordinate + 1)
- boxes_for_nms = boxes + offsets[:, None]
-
- return (
- TorchNMS.fast_nms(boxes_for_nms, scores, iou_threshold)
- if use_fast_nms
- else TorchNMS.nms(boxes_for_nms, scores, iou_threshold)
- )
diff --git a/ultralytics/utils/ops.py b/ultralytics/utils/ops.py
deleted file mode 100644
index 43a0574..0000000
--- a/ultralytics/utils/ops.py
+++ /dev/null
@@ -1,722 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import contextlib
-import math
-import re
-import time
-
-import cv2
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from ultralytics.utils import NOT_MACOS14
-
-
-class Profile(contextlib.ContextDecorator):
- """
- Ultralytics Profile class for timing code execution.
-
- Use as a decorator with @Profile() or as a context manager with 'with Profile():'. Provides accurate timing
- measurements with CUDA synchronization support for GPU operations.
-
- Attributes:
- t (float): Accumulated time in seconds.
- device (torch.device): Device used for model inference.
- cuda (bool): Whether CUDA is being used for timing synchronization.
-
- Examples:
- Use as a context manager to time code execution
- >>> with Profile(device=device) as dt:
- ... pass # slow operation here
- >>> print(dt) # prints "Elapsed time is 9.5367431640625e-07 s"
-
- Use as a decorator to time function execution
- >>> @Profile()
- ... def slow_function():
- ... time.sleep(0.1)
- """
-
- def __init__(self, t: float = 0.0, device: torch.device | None = None):
- """
- Initialize the Profile class.
-
- Args:
- t (float): Initial accumulated time in seconds.
- device (torch.device, optional): Device used for model inference to enable CUDA synchronization.
- """
- self.t = t
- self.device = device
- self.cuda = bool(device and str(device).startswith("cuda"))
-
- def __enter__(self):
- """Start timing."""
- self.start = self.time()
- return self
-
- def __exit__(self, type, value, traceback): # noqa
- """Stop timing."""
- self.dt = self.time() - self.start # delta-time
- self.t += self.dt # accumulate dt
-
- def __str__(self):
- """Return a human-readable string representing the accumulated elapsed time."""
- return f"Elapsed time is {self.t} s"
-
- def time(self):
- """Get current time with CUDA synchronization if applicable."""
- if self.cuda:
- torch.cuda.synchronize(self.device)
- return time.perf_counter()
-
-
-def segment2box(segment, width: int = 640, height: int = 640):
- """
- Convert segment coordinates to bounding box coordinates.
-
- Converts a single segment label to a box label by finding the minimum and maximum x and y coordinates.
- Applies inside-image constraint and clips coordinates when necessary.
-
- Args:
- segment (torch.Tensor): Segment coordinates in format (N, 2) where N is number of points.
- width (int): Width of the image in pixels.
- height (int): Height of the image in pixels.
-
- Returns:
- (np.ndarray): Bounding box coordinates in xyxy format [x1, y1, x2, y2].
- """
- x, y = segment.T # segment xy
- # Clip coordinates if 3 out of 4 sides are outside the image
- if np.array([x.min() < 0, y.min() < 0, x.max() > width, y.max() > height]).sum() >= 3:
- x = x.clip(0, width)
- y = y.clip(0, height)
- inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)
- x = x[inside]
- y = y[inside]
- return (
- np.array([x.min(), y.min(), x.max(), y.max()], dtype=segment.dtype)
- if any(x)
- else np.zeros(4, dtype=segment.dtype)
- ) # xyxy
-
-
-def scale_boxes(img1_shape, boxes, img0_shape, ratio_pad=None, padding: bool = True, xywh: bool = False):
- """
- Rescale bounding boxes from one image shape to another.
-
- Rescales bounding boxes from img1_shape to img0_shape, accounting for padding and aspect ratio changes.
- Supports both xyxy and xywh box formats.
-
- Args:
- img1_shape (tuple): Shape of the source image (height, width).
- boxes (torch.Tensor): Bounding boxes to rescale in format (N, 4).
- img0_shape (tuple): Shape of the target image (height, width).
- ratio_pad (tuple, optional): Tuple of (ratio, pad) for scaling. If None, calculated from image shapes.
- padding (bool): Whether boxes are based on YOLO-style augmented images with padding.
- xywh (bool): Whether box format is xywh (True) or xyxy (False).
-
- Returns:
- (torch.Tensor): Rescaled bounding boxes in the same format as input.
- """
- if ratio_pad is None: # calculate from img0_shape
- gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
- pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)
- pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)
- else:
- gain = ratio_pad[0][0]
- pad_x, pad_y = ratio_pad[1]
-
- if padding:
- boxes[..., 0] -= pad_x # x padding
- boxes[..., 1] -= pad_y # y padding
- if not xywh:
- boxes[..., 2] -= pad_x # x padding
- boxes[..., 3] -= pad_y # y padding
- boxes[..., :4] /= gain
- return boxes if xywh else clip_boxes(boxes, img0_shape)
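
A quick numeric walk-through of the gain/pad arithmetic above, assuming a 640x640 letterboxed input and an original 1280x720 frame (the two lines of math are repeated inline here rather than imported):

import torch

img1_shape, img0_shape = (640, 640), (720, 1280)  # letterboxed (h, w) and original (h, w)
gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # 0.5
pad_x = round((img1_shape[1] - img0_shape[1] * gain) / 2 - 0.1)  # 0
pad_y = round((img1_shape[0] - img0_shape[0] * gain) / 2 - 0.1)  # 140

box = torch.tensor([[100.0, 200.0, 300.0, 400.0]])  # xyxy in the 640x640 letterboxed image
box[..., [0, 2]] -= pad_x
box[..., [1, 3]] -= pad_y
box /= gain
print(box.tolist())  # [[200.0, 120.0, 600.0, 520.0]] (clipping to the original frame omitted)
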
-
-
-def make_divisible(x: int, divisor):
- """
- Return the nearest number that is divisible by the given divisor.
-
- Args:
- x (int): The number to make divisible.
- divisor (int | torch.Tensor): The divisor.
-
- Returns:
- (int): The nearest number divisible by the divisor.
- """
- if isinstance(divisor, torch.Tensor):
- divisor = int(divisor.max()) # to int
- return math.ceil(x / divisor) * divisor
-
-
-def clip_boxes(boxes, shape):
- """
- Clip bounding boxes to image boundaries.
-
- Args:
- boxes (torch.Tensor | np.ndarray): Bounding boxes to clip.
- shape (tuple): Image shape as HWC or HW (supports both).
-
- Returns:
- (torch.Tensor | np.ndarray): Clipped bounding boxes.
- """
- h, w = shape[:2] # supports both HWC or HW shapes
- if isinstance(boxes, torch.Tensor): # faster individually
- if NOT_MACOS14:
- boxes[..., 0].clamp_(0, w) # x1
- boxes[..., 1].clamp_(0, h) # y1
- boxes[..., 2].clamp_(0, w) # x2
- boxes[..., 3].clamp_(0, h) # y2
- else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
- boxes[..., 0] = boxes[..., 0].clamp(0, w)
- boxes[..., 1] = boxes[..., 1].clamp(0, h)
- boxes[..., 2] = boxes[..., 2].clamp(0, w)
- boxes[..., 3] = boxes[..., 3].clamp(0, h)
- else: # np.array (faster grouped)
- boxes[..., [0, 2]] = boxes[..., [0, 2]].clip(0, w) # x1, x2
- boxes[..., [1, 3]] = boxes[..., [1, 3]].clip(0, h) # y1, y2
- return boxes
-
-
-def clip_coords(coords, shape):
- """
- Clip line coordinates to image boundaries.
-
- Args:
- coords (torch.Tensor | np.ndarray): Line coordinates to clip.
- shape (tuple): Image shape as HWC or HW (supports both).
-
- Returns:
- (torch.Tensor | np.ndarray): Clipped coordinates.
- """
- h, w = shape[:2] # supports both HWC or HW shapes
- if isinstance(coords, torch.Tensor):
- if NOT_MACOS14:
- coords[..., 0].clamp_(0, w) # x
- coords[..., 1].clamp_(0, h) # y
- else: # Apple macOS14 MPS bug https://github.com/ultralytics/ultralytics/pull/21878
- coords[..., 0] = coords[..., 0].clamp(0, w)
- coords[..., 1] = coords[..., 1].clamp(0, h)
- else: # np.array
- coords[..., 0] = coords[..., 0].clip(0, w) # x
- coords[..., 1] = coords[..., 1].clip(0, h) # y
- return coords
-
-
-def scale_image(masks, im0_shape, ratio_pad=None):
- """
- Rescale masks to original image size.
-
- Takes resized and padded masks and rescales them back to the original image dimensions, removing any padding
- that was applied during preprocessing.
-
- Args:
- masks (np.ndarray): Resized and padded masks with shape [H, W, N] or [H, W, 3].
- im0_shape (tuple): Original image shape as HWC or HW (supports both).
- ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
-
- Returns:
- (np.ndarray): Rescaled masks with shape [H, W, N] matching original image dimensions.
- """
- # Rescale masks from im1_shape back to im0_shape
- im0_h, im0_w = im0_shape[:2] # supports both HWC or HW shapes
- im1_h, im1_w, _ = masks.shape
- if im1_h == im0_h and im1_w == im0_w:
- return masks
-
- if ratio_pad is None: # calculate from im0_shape
- gain = min(im1_h / im0_h, im1_w / im0_w) # gain = old / new
- pad = (im1_w - im0_w * gain) / 2, (im1_h - im0_h * gain) / 2 # wh padding
- else:
- pad = ratio_pad[1]
-
- pad_w, pad_h = pad
- top = int(round(pad_h - 0.1))
- left = int(round(pad_w - 0.1))
- bottom = im1_h - int(round(pad_h + 0.1))
- right = im1_w - int(round(pad_w + 0.1))
-
- if len(masks.shape) < 2:
- raise ValueError(f"masks should have 2 or 3 dimensions, but got {len(masks.shape)}")
- masks = masks[top:bottom, left:right]
- # handle the cv2.resize 512 channels limitation: https://github.com/ultralytics/ultralytics/pull/21947
- masks = [cv2.resize(array, (im0_w, im0_h)) for array in np.array_split(masks, masks.shape[-1] // 512 + 1, axis=-1)]
- masks = np.concatenate(masks, axis=-1) if len(masks) > 1 else masks[0]
- if len(masks.shape) == 2:
- masks = masks[:, :, None]
-
- return masks
-
-
-def xyxy2xywh(x):
- """
- Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height) format where (x1, y1) is the
- top-left corner and (x2, y2) is the bottom-right corner.
-
- Args:
- x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
-
- Returns:
- (np.ndarray | torch.Tensor): Bounding box coordinates in (x, y, width, height) format.
- """
- assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
- y = empty_like(x) # faster than clone/copy
- x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
- y[..., 0] = (x1 + x2) / 2 # x center
- y[..., 1] = (y1 + y2) / 2 # y center
- y[..., 2] = x2 - x1 # width
- y[..., 3] = y2 - y1 # height
- return y
-
-
-def xywh2xyxy(x):
- """
- Convert bounding box coordinates from (x, y, width, height) format to (x1, y1, x2, y2) format where (x1, y1) is the
- top-left corner and (x2, y2) is the bottom-right corner. Note: operating on two channels at a time is faster than per channel.
-
- Args:
- x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x, y, width, height) format.
-
- Returns:
- (np.ndarray | torch.Tensor): Bounding box coordinates in (x1, y1, x2, y2) format.
- """
- assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
- y = empty_like(x) # faster than clone/copy
- xy = x[..., :2] # centers
- wh = x[..., 2:] / 2 # half width-height
- y[..., :2] = xy - wh # top left xy
- y[..., 2:] = xy + wh # bottom right xy
- return y
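
For intuition, a round trip between the two box formats on a single center-format box, written inline with plain tensor ops:

import torch

xywh = torch.tensor([[50.0, 50.0, 20.0, 10.0]])  # center x, center y, width, height
xy, half_wh = xywh[..., :2], xywh[..., 2:] / 2
xyxy = torch.cat((xy - half_wh, xy + half_wh), dim=-1)  # [[40., 45., 60., 55.]]
back = torch.cat(((xyxy[..., :2] + xyxy[..., 2:]) / 2, xyxy[..., 2:] - xyxy[..., :2]), dim=-1)
assert torch.equal(back, xywh)  # converting back recovers the original box exactly
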
-
-
-def xywhn2xyxy(x, w: int = 640, h: int = 640, padw: int = 0, padh: int = 0):
- """
- Convert normalized bounding box coordinates to pixel coordinates.
-
- Args:
- x (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, w, h) format.
- w (int): Image width in pixels.
- h (int): Image height in pixels.
- padw (int): Padding width in pixels.
- padh (int): Padding height in pixels.
-
- Returns:
- y (np.ndarray | torch.Tensor): The coordinates of the bounding box in the format [x1, y1, x2, y2] where
- x1,y1 is the top-left corner, x2,y2 is the bottom-right corner of the bounding box.
- """
- assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
- y = empty_like(x) # faster than clone/copy
- xc, yc, xw, xh = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
- half_w, half_h = xw / 2, xh / 2
- y[..., 0] = w * (xc - half_w) + padw # top left x
- y[..., 1] = h * (yc - half_h) + padh # top left y
- y[..., 2] = w * (xc + half_w) + padw # bottom right x
- y[..., 3] = h * (yc + half_h) + padh # bottom right y
- return y
-
-
-def xyxy2xywhn(x, w: int = 640, h: int = 640, clip: bool = False, eps: float = 0.0):
- """
- Convert bounding box coordinates from (x1, y1, x2, y2) format to (x, y, width, height, normalized) format. x, y,
- width and height are normalized to image dimensions.
-
- Args:
- x (np.ndarray | torch.Tensor): Input bounding box coordinates in (x1, y1, x2, y2) format.
- w (int): Image width in pixels.
- h (int): Image height in pixels.
- clip (bool): Whether to clip boxes to image boundaries.
- eps (float): Minimum value for box width and height.
-
- Returns:
- (np.ndarray | torch.Tensor): Normalized bounding box coordinates in (x, y, width, height) format.
- """
- if clip:
- x = clip_boxes(x, (h - eps, w - eps))
- assert x.shape[-1] == 4, f"input shape last dimension expected 4 but input shape is {x.shape}"
- y = empty_like(x) # faster than clone/copy
- x1, y1, x2, y2 = x[..., 0], x[..., 1], x[..., 2], x[..., 3]
- y[..., 0] = ((x1 + x2) / 2) / w # x center
- y[..., 1] = ((y1 + y2) / 2) / h # y center
- y[..., 2] = (x2 - x1) / w # width
- y[..., 3] = (y2 - y1) / h # height
- return y
-
-
-def xywh2ltwh(x):
- """
- Convert bounding box format from [x, y, w, h] to [x1, y1, w, h] where x1, y1 are top-left coordinates.
-
- Args:
- x (np.ndarray | torch.Tensor): Input bounding box coordinates in xywh format.
-
- Returns:
- (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
- """
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
- y[..., 0] = x[..., 0] - x[..., 2] / 2 # top left x
- y[..., 1] = x[..., 1] - x[..., 3] / 2 # top left y
- return y
-
-
-def xyxy2ltwh(x):
- """
- Convert bounding boxes from [x1, y1, x2, y2] to [x1, y1, w, h] format.
-
- Args:
- x (np.ndarray | torch.Tensor): Input bounding box coordinates in xyxy format.
-
- Returns:
- (np.ndarray | torch.Tensor): Bounding box coordinates in xyltwh format.
- """
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
- y[..., 2] = x[..., 2] - x[..., 0] # width
- y[..., 3] = x[..., 3] - x[..., 1] # height
- return y
-
-
-def ltwh2xywh(x):
- """
- Convert bounding boxes from [x1, y1, w, h] to [x, y, w, h] where xy1=top-left, xy=center.
-
- Args:
- x (torch.Tensor): Input bounding box coordinates.
-
- Returns:
- (np.ndarray | torch.Tensor): Bounding box coordinates in xywh format.
- """
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
- y[..., 0] = x[..., 0] + x[..., 2] / 2 # center x
- y[..., 1] = x[..., 1] + x[..., 3] / 2 # center y
- return y
-
-
-def xyxyxyxy2xywhr(x):
- """
- Convert batched Oriented Bounding Boxes (OBB) from [xy1, xy2, xy3, xy4] to [xywh, rotation] format.
-
- Args:
- x (np.ndarray | torch.Tensor): Input box corners with shape (N, 8) in [xy1, xy2, xy3, xy4] format.
-
- Returns:
- (np.ndarray | torch.Tensor): Converted data in [cx, cy, w, h, rotation] format with shape (N, 5).
- Rotation values are in radians from 0 to pi/2.
- """
- is_torch = isinstance(x, torch.Tensor)
- points = x.cpu().numpy() if is_torch else x
- points = points.reshape(len(x), -1, 2)
- rboxes = []
- for pts in points:
- # NOTE: Use cv2.minAreaRect to get accurate xywhr,
- # especially some objects are cut off by augmentations in dataloader.
- (cx, cy), (w, h), angle = cv2.minAreaRect(pts)
- rboxes.append([cx, cy, w, h, angle / 180 * np.pi])
- return torch.tensor(rboxes, device=x.device, dtype=x.dtype) if is_torch else np.asarray(rboxes)
-
-
-def xywhr2xyxyxyxy(x):
- """
- Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4] format.
-
- Args:
- x (np.ndarray | torch.Tensor): Boxes in [cx, cy, w, h, rotation] format with shape (N, 5) or (B, N, 5).
- Rotation values should be in radians from 0 to pi/2.
-
- Returns:
- (np.ndarray | torch.Tensor): Converted corner points with shape (N, 4, 2) or (B, N, 4, 2).
- """
- cos, sin, cat, stack = (
- (torch.cos, torch.sin, torch.cat, torch.stack)
- if isinstance(x, torch.Tensor)
- else (np.cos, np.sin, np.concatenate, np.stack)
- )
-
- ctr = x[..., :2]
- w, h, angle = (x[..., i : i + 1] for i in range(2, 5))
- cos_value, sin_value = cos(angle), sin(angle)
- vec1 = [w / 2 * cos_value, w / 2 * sin_value]
- vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
- vec1 = cat(vec1, -1)
- vec2 = cat(vec2, -1)
- pt1 = ctr + vec1 + vec2
- pt2 = ctr + vec1 - vec2
- pt3 = ctr - vec1 - vec2
- pt4 = ctr - vec1 + vec2
- return stack([pt1, pt2, pt3, pt4], -2)
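
A tiny numeric check of the corner construction above: a 4x2 box rotated by 90 degrees should end up spanning 2 units in x and 4 in y. Written with plain NumPy for illustration:

import numpy as np

cx, cy, w, h, angle = 0.0, 0.0, 4.0, 2.0, np.pi / 2
ctr = np.array([cx, cy])
vec1 = np.array([w / 2 * np.cos(angle), w / 2 * np.sin(angle)])   # half-width direction, rotated
vec2 = np.array([-h / 2 * np.sin(angle), h / 2 * np.cos(angle)])  # half-height direction, rotated
corners = np.stack([ctr + vec1 + vec2, ctr + vec1 - vec2, ctr - vec1 - vec2, ctr - vec1 + vec2])
print(np.round(corners, 6))  # (-1, 2), (1, 2), (1, -2), (-1, -2): width and height have swapped
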
-
-
-def ltwh2xyxy(x):
- """
- Convert bounding box from [x1, y1, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right.
-
- Args:
- x (np.ndarray | torch.Tensor): Input bounding box coordinates.
-
- Returns:
- (np.ndarray | torch.Tensor): Bounding box coordinates in xyxy format.
- """
- y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
- y[..., 2] = x[..., 2] + x[..., 0]  # bottom right x
- y[..., 3] = x[..., 3] + x[..., 1]  # bottom right y
- return y
-
-
-def segments2boxes(segments):
- """
- Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh).
-
- Args:
- segments (list): List of segments where each segment is a list of points, each point is [x, y] coordinates.
-
- Returns:
- (np.ndarray): Bounding box coordinates in xywh format.
- """
- boxes = []
- for s in segments:
- x, y = s.T # segment xy
- boxes.append([x.min(), y.min(), x.max(), y.max()]) # cls, xyxy
- return xyxy2xywh(np.array(boxes)) # cls, xywh
-
-
-def resample_segments(segments, n: int = 1000):
- """
- Resample segments to n points each using linear interpolation.
-
- Args:
- segments (list): List of (N, 2) arrays where N is the number of points in each segment.
- n (int): Number of points to resample each segment to.
-
- Returns:
- (list): Resampled segments with n points each.
- """
- for i, s in enumerate(segments):
- if len(s) == n:
- continue
- s = np.concatenate((s, s[0:1, :]), axis=0)
- x = np.linspace(0, len(s) - 1, n - len(s) if len(s) < n else n)
- xp = np.arange(len(s))
- x = np.insert(x, np.searchsorted(x, xp), xp) if len(s) < n else x
- segments[i] = (
- np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)], dtype=np.float32).reshape(2, -1).T
- ) # segment xy
- return segments
-
-
-def crop_mask(masks, boxes):
- """
- Crop masks to bounding box regions.
-
- Args:
- masks (torch.Tensor): Masks with shape (N, H, W).
- boxes (torch.Tensor): Bounding box coordinates with shape (N, 4) in relative point form.
-
- Returns:
- (torch.Tensor): Cropped masks.
- """
- _, h, w = masks.shape
- x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
- r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
- c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
-
- return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
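
The return expression of crop_mask builds an inside-the-box indicator from broadcast row/column index grids; a toy run on one 4x4 mask (re-implemented inline) makes the effect visible:

import torch

masks = torch.ones(1, 4, 4)                            # one all-ones mask, shape (N, H, W)
boxes = torch.tensor([[1.0, 1.0, 3.0, 3.0]])           # crop region in xyxy
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # each has shape (N, 1, 1)
r = torch.arange(4, dtype=x1.dtype)[None, None, :]     # column indices, shape (1, 1, W)
c = torch.arange(4, dtype=x1.dtype)[None, :, None]     # row indices, shape (1, H, 1)
cropped = masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
print(cropped[0])  # only the 2x2 block at rows/columns 1-2 remains non-zero
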
-
-
-def process_mask(protos, masks_in, bboxes, shape, upsample: bool = False):
- """
- Apply masks to bounding boxes using mask head output.
-
- Args:
- protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
- masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
- bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
- shape (tuple): Input image size as (height, width).
- upsample (bool): Whether to upsample masks to original image size.
-
- Returns:
- (torch.Tensor): A binary mask tensor of shape [n, h, w], where n is the number of masks after NMS, and h and w
- are the height and width of the input image. The mask is applied to the bounding boxes.
- """
- c, mh, mw = protos.shape # CHW
- ih, iw = shape
- masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw) # CHW
- width_ratio = mw / iw
- height_ratio = mh / ih
-
- downsampled_bboxes = bboxes.clone()
- downsampled_bboxes[:, 0] *= width_ratio
- downsampled_bboxes[:, 2] *= width_ratio
- downsampled_bboxes[:, 3] *= height_ratio
- downsampled_bboxes[:, 1] *= height_ratio
-
- masks = crop_mask(masks, downsampled_bboxes) # CHW
- if upsample:
- masks = F.interpolate(masks[None], shape, mode="bilinear", align_corners=False)[0] # CHW
- return masks.gt_(0.0)
-
-
-def process_mask_native(protos, masks_in, bboxes, shape):
- """
- Apply masks to bounding boxes using mask head output with native upsampling.
-
- Args:
- protos (torch.Tensor): Mask prototypes with shape (mask_dim, mask_h, mask_w).
- masks_in (torch.Tensor): Mask coefficients with shape (N, mask_dim) where N is number of masks after NMS.
- bboxes (torch.Tensor): Bounding boxes with shape (N, 4) where N is number of masks after NMS.
- shape (tuple): Input image size as (height, width).
-
- Returns:
- (torch.Tensor): Binary mask tensor with shape (H, W, N).
- """
- c, mh, mw = protos.shape # CHW
- masks = (masks_in @ protos.float().view(c, -1)).view(-1, mh, mw)
- masks = scale_masks(masks[None], shape)[0] # CHW
- masks = crop_mask(masks, bboxes) # CHW
- return masks.gt_(0.0)
-
-
-def scale_masks(masks, shape, padding: bool = True):
- """
- Rescale segment masks to target shape.
-
- Args:
- masks (torch.Tensor): Masks with shape (N, C, H, W).
- shape (tuple): Target height and width as (height, width).
- padding (bool): Whether masks are based on YOLO-style augmented images with padding.
-
- Returns:
- (torch.Tensor): Rescaled masks.
- """
- mh, mw = masks.shape[2:]
- gain = min(mh / shape[0], mw / shape[1]) # gain = old / new
- pad_w = mw - shape[1] * gain
- pad_h = mh - shape[0] * gain
- if padding:
- pad_w /= 2
- pad_h /= 2
- top, left = (int(round(pad_h - 0.1)), int(round(pad_w - 0.1))) if padding else (0, 0)
- bottom = mh - int(round(pad_h + 0.1))
- right = mw - int(round(pad_w + 0.1))
- return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear", align_corners=False) # NCHW masks
-
-
-def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None, normalize: bool = False, padding: bool = True):
- """
- Rescale segment coordinates from img1_shape to img0_shape.
-
- Args:
- img1_shape (tuple): Source image shape as HWC or HW (supports both).
- coords (torch.Tensor): Coordinates to scale with shape (N, 2).
- img0_shape (tuple): Image 0 shape as HWC or HW (supports both).
- ratio_pad (tuple, optional): Ratio and padding values as ((ratio_h, ratio_w), (pad_h, pad_w)).
- normalize (bool): Whether to normalize coordinates to range [0, 1].
- padding (bool): Whether coordinates are based on YOLO-style augmented images with padding.
-
- Returns:
- (torch.Tensor): Scaled coordinates.
- """
- img0_h, img0_w = img0_shape[:2] # supports both HWC or HW shapes
- if ratio_pad is None: # calculate from img0_shape
- img1_h, img1_w = img1_shape[:2] # supports both HWC or HW shapes
- gain = min(img1_h / img0_h, img1_w / img0_w) # gain = old / new
- pad = (img1_w - img0_w * gain) / 2, (img1_h - img0_h * gain) / 2 # wh padding
- else:
- gain = ratio_pad[0][0]
- pad = ratio_pad[1]
-
- if padding:
- coords[..., 0] -= pad[0] # x padding
- coords[..., 1] -= pad[1] # y padding
- coords[..., 0] /= gain
- coords[..., 1] /= gain
- coords = clip_coords(coords, img0_shape)
- if normalize:
- coords[..., 0] /= img0_w # width
- coords[..., 1] /= img0_h # height
- return coords
-
-
-def regularize_rboxes(rboxes):
- """
- Regularize rotated bounding boxes to range [0, pi/2].
-
- Args:
- rboxes (torch.Tensor): Input rotated boxes with shape (N, 5) in xywhr format.
-
- Returns:
- (torch.Tensor): Regularized rotated boxes.
- """
- x, y, w, h, t = rboxes.unbind(dim=-1)
- # Swap edge if t >= pi/2 while not being symmetrically opposite
- swap = t % math.pi >= math.pi / 2
- w_ = torch.where(swap, h, w)
- h_ = torch.where(swap, w, h)
- t = t % (math.pi / 2)
- return torch.stack([x, y, w_, h_, t], dim=-1) # regularized boxes
-
-
-def masks2segments(masks, strategy: str = "all"):
- """
- Convert masks to segments using contour detection.
-
- Args:
- masks (torch.Tensor): Binary masks with shape (batch_size, 160, 160).
- strategy (str): Segmentation strategy, either 'all' or 'largest'.
-
- Returns:
- (list): List of segment masks as float32 arrays.
- """
- from ultralytics.data.converter import merge_multi_segment
-
- segments = []
- for x in masks.int().cpu().numpy().astype("uint8"):
- c = cv2.findContours(x, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[0]
- if c:
- if strategy == "all": # merge and concatenate all segments
- c = (
- np.concatenate(merge_multi_segment([x.reshape(-1, 2) for x in c]))
- if len(c) > 1
- else c[0].reshape(-1, 2)
- )
- elif strategy == "largest": # select largest segment
- c = np.array(c[np.array([len(x) for x in c]).argmax()]).reshape(-1, 2)
- else:
- c = np.zeros((0, 2)) # no segments found
- segments.append(c.astype("float32"))
- return segments
-
-
-def convert_torch2numpy_batch(batch: torch.Tensor) -> np.ndarray:
- """
- Convert a batch of FP32 torch tensors to NumPy uint8 arrays, changing from BCHW to BHWC layout.
-
- Args:
- batch (torch.Tensor): Input tensor batch with shape (Batch, Channels, Height, Width) and dtype torch.float32.
-
- Returns:
- (np.ndarray): Output NumPy array batch with shape (Batch, Height, Width, Channels) and dtype uint8.
- """
- return (batch.permute(0, 2, 3, 1).contiguous() * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
-
-
-def clean_str(s):
- """
- Clean a string by replacing special characters with '_' character.
-
- Args:
- s (str): A string needing special characters replaced.
-
- Returns:
- (str): A string with special characters replaced by an underscore _.
- """
- return re.sub(pattern="[|@#!¡·$€%&()=?¿^*;:,¨´><+]", repl="_", string=s)
-
-
-def empty_like(x):
- """Create empty torch.Tensor or np.ndarray with same shape as input and float32 dtype."""
- return (
- torch.empty_like(x, dtype=torch.float32) if isinstance(x, torch.Tensor) else np.empty_like(x, dtype=np.float32)
- )
diff --git a/ultralytics/utils/patches.py b/ultralytics/utils/patches.py
deleted file mode 100644
index 4527dae..0000000
--- a/ultralytics/utils/patches.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-"""Monkey patches to update/extend functionality of existing functions."""
-
-from __future__ import annotations
-
-import time
-from contextlib import contextmanager
-from copy import copy
-from pathlib import Path
-from typing import Any
-
-import cv2
-import numpy as np
-import torch
-
-# OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
-_imshow = cv2.imshow # copy to avoid recursion errors
-
-
-def imread(filename: str, flags: int = cv2.IMREAD_COLOR) -> np.ndarray | None:
- """
- Read an image from a file with multilanguage filename support.
-
- Args:
- filename (str): Path to the file to read.
- flags (int, optional): Flag that can take values of cv2.IMREAD_*. Controls how the image is read.
-
- Returns:
- (np.ndarray | None): The read image array, or None if reading fails.
-
- Examples:
- >>> img = imread("path/to/image.jpg")
- >>> img = imread("path/to/image.jpg", cv2.IMREAD_GRAYSCALE)
- """
- file_bytes = np.fromfile(filename, np.uint8)
- if filename.endswith((".tiff", ".tif")):
- success, frames = cv2.imdecodemulti(file_bytes, cv2.IMREAD_UNCHANGED)
- if success:
- # Handle RGB images in tif/tiff format
- return frames[0] if len(frames) == 1 and frames[0].ndim == 3 else np.stack(frames, axis=2)
- return None
- else:
- im = cv2.imdecode(file_bytes, flags)
- return im[..., None] if im is not None and im.ndim == 2 else im # Always ensure 3 dimensions
-
-
-def imwrite(filename: str, img: np.ndarray, params: list[int] | None = None) -> bool:
- """
- Write an image to a file with multilanguage filename support.
-
- Args:
- filename (str): Path to the file to write.
- img (np.ndarray): Image to write.
- params (list[int], optional): Additional parameters for image encoding.
-
- Returns:
- (bool): True if the file was written successfully, False otherwise.
-
- Examples:
- >>> import numpy as np
- >>> img = np.zeros((100, 100, 3), dtype=np.uint8) # Create a black image
- >>> success = imwrite("output.jpg", img) # Write image to file
- >>> print(success)
- True
- """
- try:
- cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
- return True
- except Exception:
- return False
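
The reason for encoding in memory and then writing with NumPy is that plain cv2.imwrite/cv2.imread can fail on non-ASCII paths (notably on Windows). A minimal sketch of the same round trip; the filename is illustrative only:

import cv2
import numpy as np

path = "изображение.jpg"  # non-ASCII filename chosen purely for illustration
img = np.zeros((64, 64, 3), dtype=np.uint8)

ok, buf = cv2.imencode(".jpg", img)  # encode in memory ...
if ok:
    buf.tofile(path)  # ... then let NumPy handle the unicode path

img_back = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_COLOR)
print(img_back.shape)  # (64, 64, 3)
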
-
-
-def imshow(winname: str, mat: np.ndarray) -> None:
- """
- Display an image in the specified window with multilanguage window name support.
-
- This function is a wrapper around OpenCV's imshow function that displays an image in a named window. It handles
- multilanguage window names by encoding them properly for OpenCV compatibility.
-
- Args:
- winname (str): Name of the window where the image will be displayed. If a window with this name already
- exists, the image will be displayed in that window.
- mat (np.ndarray): Image to be shown. Should be a valid numpy array representing an image.
-
- Examples:
- >>> import numpy as np
- >>> img = np.zeros((300, 300, 3), dtype=np.uint8) # Create a black image
- >>> img[:100, :100] = [255, 0, 0] # Add a blue square
- >>> imshow("Example Window", img) # Display the image
- """
- _imshow(winname.encode("unicode_escape").decode(), mat)
-
-
-# PyTorch functions ----------------------------------------------------------------------------------------------------
-_torch_save = torch.save
-
-
-def torch_load(*args, **kwargs):
- """
- Load a PyTorch model with updated arguments to avoid warnings.
-
- This function wraps torch.load and adds the 'weights_only' argument for PyTorch 1.13.0+ to prevent warnings.
-
- Args:
- *args (Any): Variable length argument list to pass to torch.load.
- **kwargs (Any): Arbitrary keyword arguments to pass to torch.load.
-
- Returns:
- (Any): The loaded PyTorch object.
-
- Notes:
- For PyTorch versions 1.13 and above, this function automatically sets 'weights_only=False'
- if the argument is not provided, to avoid deprecation warnings.
- """
- from ultralytics.utils.torch_utils import TORCH_1_13
-
- if TORCH_1_13 and "weights_only" not in kwargs:
- kwargs["weights_only"] = False
-
- return torch.load(*args, **kwargs)
-
-
-def torch_save(*args, **kwargs):
- """
- Save PyTorch objects with retry mechanism for robustness.
-
- This function wraps torch.save with 3 retries and exponential backoff in case of save failures, which can occur
- due to device flushing delays or antivirus scanning.
-
- Args:
- *args (Any): Positional arguments to pass to torch.save.
- **kwargs (Any): Keyword arguments to pass to torch.save.
-
- Examples:
- >>> model = torch.nn.Linear(10, 1)
- >>> torch_save(model.state_dict(), "model.pt")
- """
- for i in range(4): # 3 retries
- try:
- return _torch_save(*args, **kwargs)
- except RuntimeError as e: # Unable to save, possibly waiting for device to flush or antivirus scan
- if i == 3:
- raise e
- time.sleep((2**i) / 2) # Exponential backoff: 0.5s, 1.0s, 2.0s
-
-
-@contextmanager
-def arange_patch(args):
- """
- Workaround for ONNX torch.arange incompatibility with FP16.
-
- https://github.com/pytorch/pytorch/issues/148041.
- """
- if args.dynamic and args.half and args.format == "onnx":
- func = torch.arange
-
- def arange(*args, dtype=None, **kwargs):
- """Return a 1-D tensor of size with values from the interval and common difference."""
- return func(*args, **kwargs).to(dtype) # cast to dtype instead of passing dtype
-
- torch.arange = arange # patch
- yield
- torch.arange = func # unpatch
- else:
- yield
-
-
-@contextmanager
-def override_configs(args, overrides: dict[str, Any] | None = None):
- """
- Context manager to temporarily override configurations in args.
-
- Args:
- args (IterableSimpleNamespace): Original configuration arguments.
- overrides (dict[str, Any]): Dictionary of overrides to apply.
-
- Yields:
- (IterableSimpleNamespace): Configuration arguments with overrides applied.
- """
- if overrides:
- original_args = copy(args)
- for key, value in overrides.items():
- setattr(args, key, value)
- try:
- yield args
- finally:
- args.__dict__.update(original_args.__dict__)
- else:
- yield args
diff --git a/ultralytics/utils/plotting.py b/ultralytics/utils/plotting.py
deleted file mode 100644
index 627160a..0000000
--- a/ultralytics/utils/plotting.py
+++ /dev/null
@@ -1,1031 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import math
-import warnings
-from pathlib import Path
-from typing import Any, Callable
-
-import cv2
-import numpy as np
-import torch
-from PIL import Image, ImageDraw, ImageFont
-from PIL import __version__ as pil_version
-
-from ultralytics.utils import IS_COLAB, IS_KAGGLE, LOGGER, TryExcept, ops, plt_settings, threaded
-from ultralytics.utils.checks import check_font, check_version, is_ascii
-from ultralytics.utils.files import increment_path
-
-
-class Colors:
- """
- Ultralytics color palette for visualization and plotting.
-
- This class provides methods to work with the Ultralytics color palette, including converting hex color codes to
- RGB values and accessing predefined color schemes for object detection and pose estimation.
-
- Attributes:
- palette (list[tuple]): List of RGB color tuples for general use.
- n (int): The number of colors in the palette.
- pose_palette (np.ndarray): A specific color palette array for pose estimation with dtype np.uint8.
-
- Examples:
- >>> from ultralytics.utils.plotting import Colors
- >>> colors = Colors()
- >>> colors(5, True) # Returns BGR format: (221, 111, 255)
- >>> colors(5, False) # Returns RGB format: (255, 111, 221)
-
- ## Ultralytics Color Palette
-
- | Index | Color | HEX | RGB |
- |-------|-------------------------------------------------------------------|-----------|-------------------|
- | 0 | | `#042aff` | (4, 42, 255) |
- | 1 | | `#0bdbeb` | (11, 219, 235) |
- | 2 | | `#f3f3f3` | (243, 243, 243) |
- | 3 | | `#00dfb7` | (0, 223, 183) |
- | 4 | | `#111f68` | (17, 31, 104) |
- | 5 | | `#ff6fdd` | (255, 111, 221) |
- | 6 | | `#ff444f` | (255, 68, 79) |
- | 7 | | `#cced00` | (204, 237, 0) |
- | 8 | | `#00f344` | (0, 243, 68) |
- | 9 | | `#bd00ff` | (189, 0, 255) |
- | 10 | | `#00b4ff` | (0, 180, 255) |
- | 11 | | `#dd00ba` | (221, 0, 186) |
- | 12 | | `#00ffff` | (0, 255, 255) |
- | 13 | | `#26c000` | (38, 192, 0) |
- | 14 | | `#01ffb3` | (1, 255, 179) |
- | 15 | | `#7d24ff` | (125, 36, 255) |
- | 16 | | `#7b0068` | (123, 0, 104) |
- | 17 | | `#ff1b6c` | (255, 27, 108) |
- | 18 | | `#fc6d2f` | (252, 109, 47) |
- | 19 | | `#a2ff0b` | (162, 255, 11) |
-
- ## Pose Color Palette
-
- | Index | Color | HEX | RGB |
- |-------|-------------------------------------------------------------------|-----------|-------------------|
- | 0 | | `#ff8000` | (255, 128, 0) |
- | 1 | | `#ff9933` | (255, 153, 51) |
- | 2 | | `#ffb266` | (255, 178, 102) |
- | 3 | | `#e6e600` | (230, 230, 0) |
- | 4 | | `#ff99ff` | (255, 153, 255) |
- | 5 | | `#99ccff` | (153, 204, 255) |
- | 6 | | `#ff66ff` | (255, 102, 255) |
- | 7 | | `#ff33ff` | (255, 51, 255) |
- | 8 | | `#66b2ff` | (102, 178, 255) |
- | 9 | | `#3399ff` | (51, 153, 255) |
- | 10 | | `#ff9999` | (255, 153, 153) |
- | 11 | | `#ff6666` | (255, 102, 102) |
- | 12 | | `#ff3333` | (255, 51, 51) |
- | 13 | | `#99ff99` | (153, 255, 153) |
- | 14 | | `#66ff66` | (102, 255, 102) |
- | 15 | | `#33ff33` | (51, 255, 51) |
- | 16 | | `#00ff00` | (0, 255, 0) |
- | 17 | | `#0000ff` | (0, 0, 255) |
- | 18 | | `#ff0000` | (255, 0, 0) |
- | 19 | | `#ffffff` | (255, 255, 255) |
-
- !!! note "Ultralytics Brand Colors"
-
- For Ultralytics brand colors see [https://www.ultralytics.com/brand](https://www.ultralytics.com/brand).
- Please use the official Ultralytics colors for all marketing materials.
- """
-
- def __init__(self):
- """Initialize colors as hex = matplotlib.colors.TABLEAU_COLORS.values()."""
- hexs = (
- "042AFF",
- "0BDBEB",
- "F3F3F3",
- "00DFB7",
- "111F68",
- "FF6FDD",
- "FF444F",
- "CCED00",
- "00F344",
- "BD00FF",
- "00B4FF",
- "DD00BA",
- "00FFFF",
- "26C000",
- "01FFB3",
- "7D24FF",
- "7B0068",
- "FF1B6C",
- "FC6D2F",
- "A2FF0B",
- )
- self.palette = [self.hex2rgb(f"#{c}") for c in hexs]
- self.n = len(self.palette)
- self.pose_palette = np.array(
- [
- [255, 128, 0],
- [255, 153, 51],
- [255, 178, 102],
- [230, 230, 0],
- [255, 153, 255],
- [153, 204, 255],
- [255, 102, 255],
- [255, 51, 255],
- [102, 178, 255],
- [51, 153, 255],
- [255, 153, 153],
- [255, 102, 102],
- [255, 51, 51],
- [153, 255, 153],
- [102, 255, 102],
- [51, 255, 51],
- [0, 255, 0],
- [0, 0, 255],
- [255, 0, 0],
- [255, 255, 255],
- ],
- dtype=np.uint8,
- )
-
- def __call__(self, i: int | torch.Tensor, bgr: bool = False) -> tuple:
- """
- Return a palette color by index as an RGB tuple, optionally converted to BGR.
-
- Args:
- i (int | torch.Tensor): Color index.
- bgr (bool, optional): Whether to return BGR format instead of RGB.
-
- Returns:
- (tuple): RGB or BGR color tuple.
- """
- c = self.palette[int(i) % self.n]
- return (c[2], c[1], c[0]) if bgr else c
-
- @staticmethod
- def hex2rgb(h: str) -> tuple:
- """Convert hex color codes to RGB values (i.e. default PIL order)."""
- return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
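
The hex parsing above just reads the string two characters at a time; inline, for reference:

h = "#FF6FDD"
rgb = tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
print(rgb)  # (255, 111, 221)
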
-
-
-colors = Colors() # create instance for 'from utils.plots import colors'
-
-
-class Annotator:
- """
- Ultralytics Annotator for train/val mosaics and JPGs, and for prediction annotations.
-
- Attributes:
- im (Image.Image | np.ndarray): The image to annotate.
- pil (bool): Whether to use PIL or cv2 for drawing annotations.
- font (ImageFont.truetype | ImageFont.load_default): Font used for text annotations.
- lw (float): Line width for drawing.
- skeleton (list[list[int]]): Skeleton structure for keypoints.
- limb_color (list[int]): Color palette for limbs.
- kpt_color (list[int]): Color palette for keypoints.
- dark_colors (set): Set of colors considered dark for text contrast.
- light_colors (set): Set of colors considered light for text contrast.
-
- Examples:
- >>> from ultralytics.utils.plotting import Annotator
- >>> im0 = cv2.imread("test.png")
- >>> annotator = Annotator(im0, line_width=10)
- >>> annotator.box_label([10, 10, 100, 100], "person", (255, 0, 0))
- """
-
- def __init__(
- self,
- im,
- line_width: int | None = None,
- font_size: int | None = None,
- font: str = "Arial.ttf",
- pil: bool = False,
- example: str = "abc",
- ):
- """Initialize the Annotator class with image and line width along with color palette for keypoints and limbs."""
- non_ascii = not is_ascii(example) # non-latin labels, i.e. asian, arabic, cyrillic
- input_is_pil = isinstance(im, Image.Image)
- self.pil = pil or non_ascii or input_is_pil
- self.lw = line_width or max(round(sum(im.size if input_is_pil else im.shape) / 2 * 0.003), 2)
- if not input_is_pil:
- if im.shape[2] == 1: # handle grayscale
- im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
- elif im.shape[2] > 3: # multispectral
- im = np.ascontiguousarray(im[..., :3])
- if self.pil: # use PIL
- self.im = im if input_is_pil else Image.fromarray(im)
- if self.im.mode not in {"RGB", "RGBA"}: # multispectral
- self.im = self.im.convert("RGB")
- self.draw = ImageDraw.Draw(self.im, "RGBA")
- try:
- font = check_font("Arial.Unicode.ttf" if non_ascii else font)
- size = font_size or max(round(sum(self.im.size) / 2 * 0.035), 12)
- self.font = ImageFont.truetype(str(font), size)
- except Exception:
- self.font = ImageFont.load_default()
- # Deprecation fix for w, h = getsize(string) -> _, _, w, h = getbbox(string)
- if check_version(pil_version, "9.2.0"):
- self.font.getsize = lambda x: self.font.getbbox(x)[2:4] # text width, height
- else: # use cv2
- assert im.data.contiguous, "Image not contiguous. Apply np.ascontiguousarray(im) to Annotator input images."
- self.im = im if im.flags.writeable else im.copy()
- self.tf = max(self.lw - 1, 1) # font thickness
- self.sf = self.lw / 3 # font scale
- # Pose
- self.skeleton = [
- [16, 14],
- [14, 12],
- [17, 15],
- [15, 13],
- [12, 13],
- [6, 12],
- [7, 13],
- [6, 7],
- [6, 8],
- [7, 9],
- [8, 10],
- [9, 11],
- [2, 3],
- [1, 2],
- [1, 3],
- [2, 4],
- [3, 5],
- [4, 6],
- [5, 7],
- ]
-
- self.limb_color = colors.pose_palette[[9, 9, 9, 9, 7, 7, 7, 0, 0, 0, 0, 0, 16, 16, 16, 16, 16, 16, 16]]
- self.kpt_color = colors.pose_palette[[16, 16, 16, 16, 16, 0, 0, 0, 0, 0, 0, 9, 9, 9, 9, 9, 9]]
- self.dark_colors = {
- (235, 219, 11),
- (243, 243, 243),
- (183, 223, 0),
- (221, 111, 255),
- (0, 237, 204),
- (68, 243, 0),
- (255, 255, 0),
- (179, 255, 1),
- (11, 255, 162),
- }
- self.light_colors = {
- (255, 42, 4),
- (79, 68, 255),
- (255, 0, 189),
- (255, 180, 0),
- (186, 0, 221),
- (0, 192, 38),
- (255, 36, 125),
- (104, 0, 123),
- (108, 27, 255),
- (47, 109, 252),
- (104, 31, 17),
- }
-
- def get_txt_color(self, color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)) -> tuple:
- """
- Assign text color based on background color.
-
- Args:
- color (tuple, optional): The background color of the rectangle for text (B, G, R).
- txt_color (tuple, optional): The color of the text (R, G, B).
-
- Returns:
- (tuple): Text color for label.
-
- Examples:
- >>> from ultralytics.utils.plotting import Annotator
- >>> im0 = cv2.imread("test.png")
- >>> annotator = Annotator(im0, line_width=10)
- >>> annotator.get_txt_color(color=(104, 31, 17)) # return (255, 255, 255)
- """
- if color in self.dark_colors:
- return 104, 31, 17
- elif color in self.light_colors:
- return 255, 255, 255
- else:
- return txt_color
-
- def box_label(self, box, label: str = "", color: tuple = (128, 128, 128), txt_color: tuple = (255, 255, 255)):
- """
- Draw a bounding box on an image with a given label.
-
- Args:
- box (tuple): The bounding box coordinates (x1, y1, x2, y2).
- label (str, optional): The text label to be displayed.
- color (tuple, optional): The background color of the rectangle (B, G, R).
- txt_color (tuple, optional): The color of the text (R, G, B).
-
- Examples:
- >>> from ultralytics.utils.plotting import Annotator
- >>> im0 = cv2.imread("test.png")
- >>> annotator = Annotator(im0, line_width=10)
- >>> annotator.box_label(box=[10, 20, 30, 40], label="person")
- """
- txt_color = self.get_txt_color(color, txt_color)
- if isinstance(box, torch.Tensor):
- box = box.tolist()
-
- multi_points = isinstance(box[0], list) # multiple points with shape (n, 2)
- p1 = [int(b) for b in box[0]] if multi_points else (int(box[0]), int(box[1]))
- if self.pil:
- self.draw.polygon(
- [tuple(b) for b in box], width=self.lw, outline=color
- ) if multi_points else self.draw.rectangle(box, width=self.lw, outline=color)
- if label:
- w, h = self.font.getsize(label) # text width, height
- outside = p1[1] >= h # label fits outside box
- if p1[0] > self.im.size[0] - w: # size is (w, h), check if label extend beyond right side of image
- p1 = self.im.size[0] - w, p1[1]
- self.draw.rectangle(
- (p1[0], p1[1] - h if outside else p1[1], p1[0] + w + 1, p1[1] + 1 if outside else p1[1] + h + 1),
- fill=color,
- )
- # self.draw.text([box[0], box[1]], label, fill=txt_color, font=self.font, anchor='ls') # for PIL>8.0
- self.draw.text((p1[0], p1[1] - h if outside else p1[1]), label, fill=txt_color, font=self.font)
- else: # cv2
- cv2.polylines(
- self.im, [np.asarray(box, dtype=int)], True, color, self.lw
- ) if multi_points else cv2.rectangle(
- self.im, p1, (int(box[2]), int(box[3])), color, thickness=self.lw, lineType=cv2.LINE_AA
- )
- if label:
- w, h = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0] # text width, height
- h += 3 # add pixels to pad text
- outside = p1[1] >= h # label fits outside box
- if p1[0] > self.im.shape[1] - w: # shape is (h, w), check if label extend beyond right side of image
- p1 = self.im.shape[1] - w, p1[1]
- p2 = p1[0] + w, p1[1] - h if outside else p1[1] + h
- cv2.rectangle(self.im, p1, p2, color, -1, cv2.LINE_AA) # filled
- cv2.putText(
- self.im,
- label,
- (p1[0], p1[1] - 2 if outside else p1[1] + h - 1),
- 0,
- self.sf,
- txt_color,
- thickness=self.tf,
- lineType=cv2.LINE_AA,
- )
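
The cv2 branch above sizes the label background with cv2.getTextSize and flips it below the box when there is no room above. A self-contained sketch of that placement logic using only OpenCV primitives (all values are illustrative):

import cv2
import numpy as np

im = np.zeros((200, 200, 3), dtype=np.uint8)
box, label, color = (40, 60, 160, 140), "person 0.91", (255, 111, 221)
lw = 2
sf, tf = lw / 3, max(lw - 1, 1)  # font scale and thickness derived from line width

cv2.rectangle(im, box[:2], box[2:], color, lw, cv2.LINE_AA)
(w, h), _ = cv2.getTextSize(label, 0, fontScale=sf, thickness=tf)
h += 3  # pad the text height
p1 = (box[0], box[1])
outside = p1[1] >= h  # draw the label above the box only if it fits inside the image
p2 = (p1[0] + w, p1[1] - h if outside else p1[1] + h)
cv2.rectangle(im, p1, p2, color, -1, cv2.LINE_AA)  # filled label background
cv2.putText(im, label, (p1[0], p1[1] - 2 if outside else p1[1] + h - 1), 0, sf, (255, 255, 255), tf, cv2.LINE_AA)
cv2.imwrite("label_demo.jpg", im)
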
-
- def masks(self, masks, colors, im_gpu: torch.Tensor = None, alpha: float = 0.5, retina_masks: bool = False):
- """
- Plot masks on image.
-
- Args:
- masks (torch.Tensor | np.ndarray): Predicted masks with shape: [n, h, w]
- colors (list[list[int]]): Colors for predicted masks, [[r, g, b] * n]
- im_gpu (torch.Tensor | None): Image is in cuda, shape: [3, h, w], range: [0, 1]
- alpha (float, optional): Mask transparency: 0.0 fully transparent, 1.0 opaque.
- retina_masks (bool, optional): Whether to use high resolution masks or not.
- """
- if self.pil:
- # Convert to numpy first
- self.im = np.asarray(self.im).copy()
- if im_gpu is None:
- assert isinstance(masks, np.ndarray), "`masks` must be a np.ndarray if `im_gpu` is not provided."
- overlay = self.im.copy()
- for i, mask in enumerate(masks):
- overlay[mask.astype(bool)] = colors[i]
- self.im = cv2.addWeighted(self.im, 1 - alpha, overlay, alpha, 0)
- else:
- assert isinstance(masks, torch.Tensor), "`masks` must be a torch.Tensor if `im_gpu` is provided."
- if len(masks) == 0:
- self.im[:] = im_gpu.permute(1, 2, 0).contiguous().cpu().numpy() * 255
- if im_gpu.device != masks.device:
- im_gpu = im_gpu.to(masks.device)
- colors = torch.tensor(colors, device=masks.device, dtype=torch.float32) / 255.0 # shape(n,3)
- colors = colors[:, None, None] # shape(n,1,1,3)
- masks = masks.unsqueeze(3) # shape(n,h,w,1)
- masks_color = masks * (colors * alpha) # shape(n,h,w,3)
-
- inv_alpha_masks = (1 - masks * alpha).cumprod(0) # shape(n,h,w,1)
- mcs = masks_color.max(dim=0).values  # shape(h,w,3)
-
- im_gpu = im_gpu.flip(dims=[0]) # flip channel
- im_gpu = im_gpu.permute(1, 2, 0).contiguous() # shape(h,w,3)
- im_gpu = im_gpu * inv_alpha_masks[-1] + mcs
- im_mask = im_gpu * 255
- im_mask_np = im_mask.byte().cpu().numpy()
- self.im[:] = im_mask_np if retina_masks else ops.scale_image(im_mask_np, self.im.shape)
- if self.pil:
- # Convert im back to PIL and update draw
- self.fromarray(self.im)
-
- def kpts(
- self,
- kpts,
- shape: tuple = (640, 640),
- radius: int | None = None,
- kpt_line: bool = True,
- conf_thres: float = 0.25,
- kpt_color: tuple | None = None,
- ):
- """
- Plot keypoints on the image.
-
- Args:
- kpts (torch.Tensor): Keypoints, shape [17, 3] (x, y, confidence).
- shape (tuple, optional): Image shape (h, w).
- radius (int, optional): Keypoint radius.
- kpt_line (bool, optional): Draw lines between keypoints.
- conf_thres (float, optional): Confidence threshold.
- kpt_color (tuple, optional): Keypoint color (B, G, R).
-
- Note:
- - `kpt_line=True` currently only supports human pose plotting.
- - Modifies self.im in-place.
- - If self.pil is True, converts image to numpy array and back to PIL.
- """
- radius = radius if radius is not None else self.lw
- if self.pil:
- # Convert to numpy first
- self.im = np.asarray(self.im).copy()
- nkpt, ndim = kpts.shape
- is_pose = nkpt == 17 and ndim in {2, 3}
- kpt_line &= is_pose # `kpt_line=True` for now only supports human pose plotting
- for i, k in enumerate(kpts):
- color_k = kpt_color or (self.kpt_color[i].tolist() if is_pose else colors(i))
- x_coord, y_coord = k[0], k[1]
- if x_coord % shape[1] != 0 and y_coord % shape[0] != 0:
- if len(k) == 3:
- conf = k[2]
- if conf < conf_thres:
- continue
- cv2.circle(self.im, (int(x_coord), int(y_coord)), radius, color_k, -1, lineType=cv2.LINE_AA)
-
- if kpt_line:
- ndim = kpts.shape[-1]
- for i, sk in enumerate(self.skeleton):
- pos1 = (int(kpts[(sk[0] - 1), 0]), int(kpts[(sk[0] - 1), 1]))
- pos2 = (int(kpts[(sk[1] - 1), 0]), int(kpts[(sk[1] - 1), 1]))
- if ndim == 3:
- conf1 = kpts[(sk[0] - 1), 2]
- conf2 = kpts[(sk[1] - 1), 2]
- if conf1 < conf_thres or conf2 < conf_thres:
- continue
- if pos1[0] % shape[1] == 0 or pos1[1] % shape[0] == 0 or pos1[0] < 0 or pos1[1] < 0:
- continue
- if pos2[0] % shape[1] == 0 or pos2[1] % shape[0] == 0 or pos2[0] < 0 or pos2[1] < 0:
- continue
- cv2.line(
- self.im,
- pos1,
- pos2,
- kpt_color or self.limb_color[i].tolist(),
- thickness=int(np.ceil(self.lw / 2)),
- lineType=cv2.LINE_AA,
- )
- if self.pil:
- # Convert im back to PIL and update draw
- self.fromarray(self.im)
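-
- # Pose-plotting sketch with synthetic keypoints (values are random, not real predictions):
- # >>> import numpy as np, torch
- # >>> annotator = Annotator(np.zeros((640, 640, 3), dtype=np.uint8), line_width=2)
- # >>> kpts = torch.rand(17, 3) * torch.tensor([640.0, 640.0, 1.0])  # (x, y, conf) per keypoint
- # >>> annotator.kpts(kpts, shape=(640, 640), conf_thres=0.25)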
-
- def rectangle(self, xy, fill=None, outline=None, width: int = 1):
- """Add rectangle to image (PIL-only)."""
- self.draw.rectangle(xy, fill, outline, width)
-
- def text(self, xy, text: str, txt_color: tuple = (255, 255, 255), anchor: str = "top", box_color: tuple = ()):
- """
- Add text to an image using PIL or cv2.
-
- Args:
- xy (list[int]): Top-left coordinates for text placement.
- text (str): Text to be drawn.
- txt_color (tuple, optional): Text color (R, G, B).
- anchor (str, optional): Text anchor position ('top' or 'bottom').
- box_color (tuple, optional): Box color (R, G, B, A) with optional alpha.
- """
- if self.pil:
- w, h = self.font.getsize(text)
- if anchor == "bottom": # start y from font bottom
- xy[1] += 1 - h
- for line in text.split("\n"):
- if box_color:
- # Draw rectangle for each line
- w, h = self.font.getsize(line)
- self.draw.rectangle((xy[0], xy[1], xy[0] + w + 1, xy[1] + h + 1), fill=box_color)
- self.draw.text(xy, line, fill=txt_color, font=self.font)
- xy[1] += h
- else:
- if box_color:
- w, h = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
- h += 3 # add pixels to pad text
- outside = xy[1] >= h # label fits outside box
- p2 = xy[0] + w, xy[1] - h if outside else xy[1] + h
- cv2.rectangle(self.im, xy, p2, box_color, -1, cv2.LINE_AA) # filled
- cv2.putText(self.im, text, xy, 0, self.sf, txt_color, thickness=self.tf, lineType=cv2.LINE_AA)
-
- def fromarray(self, im):
- """Update self.im from a numpy array."""
- self.im = im if isinstance(im, Image.Image) else Image.fromarray(im)
- self.draw = ImageDraw.Draw(self.im)
-
- def result(self):
- """Return annotated image as array."""
- return np.asarray(self.im)
-
- def show(self, title: str | None = None):
- """Show the annotated image."""
- im = Image.fromarray(np.asarray(self.im)[..., ::-1]) # Convert BGR numpy array to RGB PIL Image
- if IS_COLAB or IS_KAGGLE: # can not use IS_JUPYTER as will run for all ipython environments
- try:
- display(im) # noqa - display() function only available in ipython environments
- except ImportError as e:
- LOGGER.warning(f"Unable to display image in Jupyter notebooks: {e}")
- else:
- im.show(title=title)
-
- def save(self, filename: str = "image.jpg"):
- """Save the annotated image to 'filename'."""
- cv2.imwrite(filename, np.asarray(self.im))
-
- @staticmethod
- def get_bbox_dimension(bbox: tuple | None = None):
- """
- Calculate the dimensions and area of a bounding box.
-
- Args:
- bbox (tuple): Bounding box coordinates in the format (x_min, y_min, x_max, y_max).
-
- Returns:
- width (float): Width of the bounding box.
- height (float): Height of the bounding box.
- area (float): Area enclosed by the bounding box.
-
- Examples:
- >>> from ultralytics.utils.plotting import Annotator
- >>> im0 = cv2.imread("test.png")
- >>> annotator = Annotator(im0, line_width=10)
- >>> annotator.get_bbox_dimension(bbox=[10, 20, 30, 40])
- """
- x_min, y_min, x_max, y_max = bbox
- width = x_max - x_min
- height = y_max - y_min
- return width, height, width * height
-
-
-@TryExcept()
-@plt_settings()
-def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
- """
- Plot training labels including class histograms and box statistics.
-
- Args:
- boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height].
- cls (np.ndarray): Class indices.
- names (dict, optional): Dictionary mapping class indices to class names.
- save_dir (Path, optional): Directory to save the plot.
- on_plot (Callable, optional): Function to call after plot is saved.
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import polars
- from matplotlib.colors import LinearSegmentedColormap
-
- # Filter matplotlib>=3.7.2 warning
- warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
- warnings.filterwarnings("ignore", category=FutureWarning)
-
- # Plot dataset labels
- LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
- nc = int(cls.max() + 1) # number of classes
- boxes = boxes[:1000000] # limit to 1M boxes
- x = polars.DataFrame(boxes, schema=["x", "y", "width", "height"])
-
- # Matplotlib labels
- subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
- ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
- y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
- for i in range(nc):
- y[2].patches[i].set_color([x / 255 for x in colors(i)])
- ax[0].set_ylabel("instances")
- if 0 < len(names) < 30:
- ax[0].set_xticks(range(len(names)))
- ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
- ax[0].bar_label(y[2])
- else:
- ax[0].set_xlabel("classes")
- boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
- img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
- for cls, box in zip(cls[:500], boxes[:500]):
- ImageDraw.Draw(img).rectangle(box.tolist(), width=1, outline=colors(cls)) # plot
- ax[1].imshow(img)
- ax[1].axis("off")
-
- ax[2].hist2d(x["x"], x["y"], bins=50, cmap=subplot_3_4_color)
- ax[2].set_xlabel("x")
- ax[2].set_ylabel("y")
- ax[3].hist2d(x["width"], x["height"], bins=50, cmap=subplot_3_4_color)
- ax[3].set_xlabel("width")
- ax[3].set_ylabel("height")
- for a in {0, 1, 2, 3}:
- for s in {"top", "right", "left", "bottom"}:
- ax[a].spines[s].set_visible(False)
-
- fname = save_dir / "labels.jpg"
- plt.savefig(fname, dpi=200)
- plt.close()
- if on_plot:
- on_plot(fname)
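-
-# Hypothetical call with random normalized xywh labels (purely illustrative data; writes
-# labels.jpg into the current directory):
-# >>> import numpy as np
-# >>> boxes = np.random.uniform(0.2, 0.8, (100, 4)).astype(np.float32)  # [x, y, w, h], normalized
-# >>> cls = np.random.randint(0, 3, 100)
-# >>> plot_labels(boxes, cls, names={0: "person", 1: "car", 2: "dog"}, save_dir=Path("."))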
-
-
-def save_one_box(
- xyxy,
- im,
- file: Path = Path("im.jpg"),
- gain: float = 1.02,
- pad: int = 10,
- square: bool = False,
- BGR: bool = False,
- save: bool = True,
-):
- """
- Save an image crop as {file}, with the crop size scaled by {gain} and padded by {pad} pixels. Save and/or return the crop.
-
- This function takes a bounding box and an image, and then saves a cropped portion of the image according
- to the bounding box. Optionally, the crop can be squared, and the function allows for gain and padding
- adjustments to the bounding box.
-
- Args:
- xyxy (torch.Tensor | list): A tensor or list representing the bounding box in xyxy format.
- im (np.ndarray): The input image.
- file (Path, optional): The path where the cropped image will be saved.
- gain (float, optional): A multiplicative factor to increase the size of the bounding box.
- pad (int, optional): The number of pixels to add to the width and height of the bounding box.
- square (bool, optional): If True, the bounding box will be transformed into a square.
- BGR (bool, optional): If True, the image will be returned in BGR format, otherwise in RGB.
- save (bool, optional): If True, the cropped image will be saved to disk.
-
- Returns:
- (np.ndarray): The cropped image.
-
- Examples:
- >>> from ultralytics.utils.plotting import save_one_box
- >>> xyxy = [50, 50, 150, 150]
- >>> im = cv2.imread("image.jpg")
- >>> cropped_im = save_one_box(xyxy, im, file="cropped.jpg", square=True)
- """
- if not isinstance(xyxy, torch.Tensor): # may be list
- xyxy = torch.stack(xyxy)
- b = ops.xyxy2xywh(xyxy.view(-1, 4)) # boxes
- if square:
- b[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1) # attempt rectangle to square
- b[:, 2:] = b[:, 2:] * gain + pad # box wh * gain + pad
- xyxy = ops.xywh2xyxy(b).long()
- xyxy = ops.clip_boxes(xyxy, im.shape)
- grayscale = im.shape[2] == 1 # grayscale image
- crop = im[int(xyxy[0, 1]) : int(xyxy[0, 3]), int(xyxy[0, 0]) : int(xyxy[0, 2]), :: (1 if BGR or grayscale else -1)]
- if save:
- file.parent.mkdir(parents=True, exist_ok=True) # make directory
- f = str(increment_path(file).with_suffix(".jpg"))
- # cv2.imwrite(f, crop) # save BGR, https://github.com/ultralytics/yolov5/issues/7007 chroma subsampling issue
- crop = crop.squeeze(-1) if grayscale else crop[..., ::-1] if BGR else crop
- Image.fromarray(crop).save(f, quality=95, subsampling=0) # save RGB
- return crop
-
-
-@threaded
-def plot_images(
- labels: dict[str, Any],
- images: torch.Tensor | np.ndarray = np.zeros((0, 3, 640, 640), dtype=np.float32),
- paths: list[str] | None = None,
- fname: str = "images.jpg",
- names: dict[int, str] | None = None,
- on_plot: Callable | None = None,
- max_size: int = 1920,
- max_subplots: int = 16,
- save: bool = True,
- conf_thres: float = 0.25,
-) -> np.ndarray | None:
- """
- Plot image grid with labels, bounding boxes, masks, and keypoints.
-
- Args:
- labels (dict[str, Any]): Dictionary containing detection data with keys like 'cls', 'bboxes', 'conf', 'masks', 'keypoints', 'batch_idx', 'img'.
- images (torch.Tensor | np.ndarray): Batch of images to plot. Shape: (batch_size, channels, height, width).
- paths (Optional[list[str]]): List of file paths for each image in the batch.
- fname (str): Output filename for the plotted image grid.
- names (Optional[dict[int, str]]): Dictionary mapping class indices to class names.
- on_plot (Optional[Callable]): Optional callback function to be called after saving the plot.
- max_size (int): Maximum size of the output image grid.
- max_subplots (int): Maximum number of subplots in the image grid.
- save (bool): Whether to save the plotted image grid to a file.
- conf_thres (float): Confidence threshold for displaying detections.
-
- Returns:
- (np.ndarray): Plotted image grid as a numpy array if save is False, None otherwise.
-
- Note:
- This function supports both tensor and numpy array inputs. It will automatically
- convert tensor inputs to numpy arrays for processing.
- """
- for k in {"cls", "bboxes", "conf", "masks", "keypoints", "batch_idx", "images"}:
- if k not in labels:
- continue
- if k == "cls" and labels[k].ndim == 2:
- labels[k] = labels[k].squeeze(1) # squeeze if shape is (n, 1)
- if isinstance(labels[k], torch.Tensor):
- labels[k] = labels[k].cpu().numpy()
-
- cls = labels.get("cls", np.zeros(0, dtype=np.int64))
- batch_idx = labels.get("batch_idx", np.zeros(cls.shape, dtype=np.int64))
- bboxes = labels.get("bboxes", np.zeros(0, dtype=np.float32))
- confs = labels.get("conf", None)
- masks = labels.get("masks", np.zeros(0, dtype=np.uint8))
- kpts = labels.get("keypoints", np.zeros(0, dtype=np.float32))
- images = labels.get("img", images) # default to input images
-
- if len(images) and isinstance(images, torch.Tensor):
- images = images.cpu().float().numpy()
- if images.shape[1] > 3:
- images = images[:, :3] # crop multispectral images to first 3 channels
-
- bs, _, h, w = images.shape # batch size, _, height, width
- bs = min(bs, max_subplots) # limit plot images
- ns = np.ceil(bs**0.5) # number of subplots (square)
- if np.max(images[0]) <= 1:
- images *= 255 # de-normalise (optional)
-
- # Build Image
- mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init
- for i in range(bs):
- x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
- mosaic[y : y + h, x : x + w, :] = images[i].transpose(1, 2, 0)
-
- # Resize (optional)
- scale = max_size / ns / max(h, w)
- if scale < 1:
- h = math.ceil(scale * h)
- w = math.ceil(scale * w)
- mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))
-
- # Annotate
- fs = int((h + w) * ns * 0.01) # font size
- fs = max(fs, 18) # ensure that the font size is large enough to be easily readable.
- annotator = Annotator(mosaic, line_width=round(fs / 10), font_size=fs, pil=True, example=str(names))
- for i in range(bs):
- x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
- annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
- if paths:
- annotator.text([x + 5, y + 5], text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
- if len(cls) > 0:
- idx = batch_idx == i
- classes = cls[idx].astype("int")
- labels = confs is None
-
- if len(bboxes):
- boxes = bboxes[idx]
- conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
- if len(boxes):
- if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
- boxes[..., [0, 2]] *= w # scale to pixels
- boxes[..., [1, 3]] *= h
- elif scale < 1: # absolute coords need scale if image scales
- boxes[..., :4] *= scale
- boxes[..., 0] += x
- boxes[..., 1] += y
- is_obb = boxes.shape[-1] == 5 # xywhr
- # TODO: this transformation might be unnecessary
- boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
- for j, box in enumerate(boxes.astype(np.int64).tolist()):
- c = classes[j]
- color = colors(c)
- c = names.get(c, c) if names else c
- if labels or conf[j] > conf_thres:
- label = f"{c}" if labels else f"{c} {conf[j]:.1f}"
- annotator.box_label(box, label, color=color)
-
- elif len(classes):
- for c in classes:
- color = colors(c)
- c = names.get(c, c) if names else c
- annotator.text([x, y], f"{c}", txt_color=color, box_color=(64, 64, 64, 128))
-
- # Plot keypoints
- if len(kpts):
- kpts_ = kpts[idx].copy()
- if len(kpts_):
- if kpts_[..., 0].max() <= 1.01 or kpts_[..., 1].max() <= 1.01: # if normalized with tolerance .01
- kpts_[..., 0] *= w # scale to pixels
- kpts_[..., 1] *= h
- elif scale < 1: # absolute coords need scale if image scales
- kpts_ *= scale
- kpts_[..., 0] += x
- kpts_[..., 1] += y
- for j in range(len(kpts_)):
- if labels or conf[j] > conf_thres:
- annotator.kpts(kpts_[j], conf_thres=conf_thres)
-
- # Plot masks
- if len(masks):
- if idx.shape[0] == masks.shape[0] and masks.max() <= 1: # overlap_mask=False
- image_masks = masks[idx]
- else: # overlap_mask=True
- image_masks = masks[[i]] # (1, 640, 640)
- nl = idx.sum()
- index = np.arange(1, nl + 1).reshape((nl, 1, 1))
- image_masks = (image_masks == index).astype(np.float32)
-
- im = np.asarray(annotator.im).copy()
- for j in range(len(image_masks)):
- if labels or conf[j] > conf_thres:
- color = colors(classes[j])
- mh, mw = image_masks[j].shape
- if mh != h or mw != w:
- mask = image_masks[j].astype(np.uint8)
- mask = cv2.resize(mask, (w, h))
- mask = mask.astype(bool)
- else:
- mask = image_masks[j].astype(bool)
- try:
- im[y : y + h, x : x + w, :][mask] = (
- im[y : y + h, x : x + w, :][mask] * 0.4 + np.array(color) * 0.6
- )
- except Exception:
- pass
- annotator.fromarray(im)
- if not save:
- return np.asarray(annotator.im)
- annotator.im.save(fname) # save
- if on_plot:
- on_plot(fname)
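-
-# Hypothetical call on a random two-image batch; every value below is dummy data and the
-# output filename is arbitrary:
-# >>> import numpy as np, torch
-# >>> labels = {
-# ...     "img": torch.rand(2, 3, 640, 640),  # 0-1 images are de-normalized internally
-# ...     "batch_idx": np.array([0, 1]),
-# ...     "cls": np.array([0, 1]),
-# ...     "bboxes": np.array([[0.5, 0.5, 0.4, 0.4], [0.3, 0.3, 0.2, 0.2]], dtype=np.float32),
-# ... }
-# >>> plot_images(labels, names={0: "person", 1: "car"}, fname="demo_batch.jpg")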
-
-
-@plt_settings()
-def plot_results(file: str = "path/to/results.csv", dir: str = "", on_plot: Callable | None = None):
- """
- Plot training results from a results CSV file. The function supports various types of data including segmentation,
- pose estimation, and classification. Plots are saved as 'results.png' in the directory where the CSV is located.
-
- Args:
- file (str, optional): Path to the CSV file containing the training results.
- dir (str, optional): Directory where the CSV file is located if 'file' is not provided.
- on_plot (callable, optional): Callback function to be executed after plotting. Takes filename as an argument.
-
- Examples:
- >>> from ultralytics.utils.plotting import plot_results
- >>> plot_results("path/to/results.csv")
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import polars as pl
- from scipy.ndimage import gaussian_filter1d
-
- save_dir = Path(file).parent if file else Path(dir)
- files = list(save_dir.glob("results*.csv"))
- assert len(files), f"No results.csv files found in {save_dir.resolve()}, nothing to plot."
-
- loss_keys, metric_keys = [], []
- for i, f in enumerate(files):
- try:
- data = pl.read_csv(f, infer_schema_length=None)
- if i == 0:
- for c in data.columns:
- if "loss" in c:
- loss_keys.append(c)
- elif "metric" in c:
- metric_keys.append(c)
- loss_mid, metric_mid = len(loss_keys) // 2, len(metric_keys) // 2
- columns = (
- loss_keys[:loss_mid] + metric_keys[:metric_mid] + loss_keys[loss_mid:] + metric_keys[metric_mid:]
- )
- fig, ax = plt.subplots(2, len(columns) // 2, figsize=(len(columns) + 2, 6), tight_layout=True)
- ax = ax.ravel()
- x = data.select(data.columns[0]).to_numpy().flatten()
- for i, j in enumerate(columns):
- y = data.select(j).to_numpy().flatten().astype("float")
- ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
- ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
- ax[i].set_title(j, fontsize=12)
- except Exception as e:
- LOGGER.error(f"Plotting error for {f}: {e}")
- ax[1].legend()
- fname = save_dir / "results.png"
- fig.savefig(fname, dpi=200)
- plt.close()
- if on_plot:
- on_plot(fname)
-
-
-def plt_color_scatter(v, f, bins: int = 20, cmap: str = "viridis", alpha: float = 0.8, edgecolors: str = "none"):
- """
- Plot a scatter plot with points colored based on a 2D histogram.
-
- Args:
- v (array-like): Values for the x-axis.
- f (array-like): Values for the y-axis.
- bins (int, optional): Number of bins for the histogram.
- cmap (str, optional): Colormap for the scatter plot.
- alpha (float, optional): Alpha for the scatter plot.
- edgecolors (str, optional): Edge colors for the scatter plot.
-
- Examples:
- >>> v = np.random.rand(100)
- >>> f = np.random.rand(100)
- >>> plt_color_scatter(v, f)
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
-
- # Calculate 2D histogram and corresponding colors
- hist, xedges, yedges = np.histogram2d(v, f, bins=bins)
- colors = [
- hist[
- min(np.digitize(v[i], xedges, right=True) - 1, hist.shape[0] - 1),
- min(np.digitize(f[i], yedges, right=True) - 1, hist.shape[1] - 1),
- ]
- for i in range(len(v))
- ]
-
- # Scatter plot
- plt.scatter(v, f, c=colors, cmap=cmap, alpha=alpha, edgecolors=edgecolors)
-
-
-@plt_settings()
-def plot_tune_results(csv_file: str = "tune_results.csv", exclude_zero_fitness_points: bool = True):
- """
- Plot the evolution results stored in a 'tune_results.csv' file. The function generates a scatter plot for each key
- in the CSV, color-coded based on fitness scores. The best-performing configurations are highlighted on the plots.
-
- Args:
- csv_file (str, optional): Path to the CSV file containing the tuning results.
- exclude_zero_fitness_points (bool, optional): Don't include points with zero fitness in tuning plots.
-
- Examples:
- >>> plot_tune_results("path/to/tune_results.csv")
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import polars as pl
- from scipy.ndimage import gaussian_filter1d
-
- def _save_one_file(file):
- """Save one matplotlib plot to 'file'."""
- plt.savefig(file, dpi=200)
- plt.close()
- LOGGER.info(f"Saved {file}")
-
- # Scatter plots for each hyperparameter
- csv_file = Path(csv_file)
- data = pl.read_csv(csv_file, infer_schema_length=None)
- num_metrics_columns = 1
- keys = [x.strip() for x in data.columns][num_metrics_columns:]
- x = data.to_numpy()
- fitness = x[:, 0] # fitness
- if exclude_zero_fitness_points:
- mask = fitness > 0 # exclude zero-fitness points
- x, fitness = x[mask], fitness[mask]
- j = np.argmax(fitness) # max fitness index
- n = math.ceil(len(keys) ** 0.5) # columns and rows in plot
- plt.figure(figsize=(10, 10), tight_layout=True)
- for i, k in enumerate(keys):
- v = x[:, i + num_metrics_columns]
- mu = v[j] # best single result
- plt.subplot(n, n, i + 1)
- plt_color_scatter(v, fitness, cmap="viridis", alpha=0.8, edgecolors="none")
- plt.plot(mu, fitness.max(), "k+", markersize=15)
- plt.title(f"{k} = {mu:.3g}", fontdict={"size": 9}) # mark best result for this hyperparameter
- plt.tick_params(axis="both", labelsize=8) # Set axis label size to 8
- if i % n != 0:
- plt.yticks([])
- _save_one_file(csv_file.with_name("tune_scatter_plots.png"))
-
- # Fitness vs iteration
- x = range(1, len(fitness) + 1)
- plt.figure(figsize=(10, 6), tight_layout=True)
- plt.plot(x, fitness, marker="o", linestyle="none", label="fitness")
- plt.plot(x, gaussian_filter1d(fitness, sigma=3), ":", label="smoothed", linewidth=2) # smoothing line
- plt.title("Fitness vs Iteration")
- plt.xlabel("Iteration")
- plt.ylabel("Fitness")
- plt.grid(True)
- plt.legend()
- _save_one_file(csv_file.with_name("tune_fitness.png"))
-
-
-@plt_settings()
-def feature_visualization(x, module_type: str, stage: int, n: int = 32, save_dir: Path = Path("runs/detect/exp")):
- """
- Visualize feature maps of a given model module during inference.
-
- Args:
- x (torch.Tensor): Features to be visualized.
- module_type (str): Module type.
- stage (int): Module stage within the model.
- n (int, optional): Maximum number of feature maps to plot.
- save_dir (Path, optional): Directory to save results.
- """
- import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
-
- for m in {"Detect", "Segment", "Pose", "Classify", "OBB", "RTDETRDecoder"}: # all model heads
- if m in module_type:
- return
- if isinstance(x, torch.Tensor):
- _, channels, height, width = x.shape # batch, channels, height, width
- if height > 1 and width > 1:
- f = save_dir / f"stage{stage}_{module_type.rsplit('.', 1)[-1]}_features.png" # filename
-
- blocks = torch.chunk(x[0].cpu(), channels, dim=0) # select batch index 0, block by channels
- n = min(n, channels) # number of plots
- _, ax = plt.subplots(math.ceil(n / 8), 8, tight_layout=True) # n/8 rows x 8 cols
- ax = ax.ravel()
- plt.subplots_adjust(wspace=0.05, hspace=0.05)
- for i in range(n):
- ax[i].imshow(blocks[i].squeeze()) # cmap='gray'
- ax[i].axis("off")
-
- LOGGER.info(f"Saving {f}... ({n}/{channels})")
- plt.savefig(f, dpi=300, bbox_inches="tight")
- plt.close()
- np.save(str(f.with_suffix(".npy")), x[0].cpu().numpy()) # npy save
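-
-# Standalone sketch on a random feature map (normally this is called internally during
-# inference; the module_type string and save_dir here are illustrative):
-# >>> x = torch.rand(1, 64, 80, 80)  # batch of one, 64 channels
-# >>> feature_visualization(x, module_type="model.4.C2f", stage=4, n=16, save_dir=Path("."))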
diff --git a/ultralytics/utils/tal.py b/ultralytics/utils/tal.py
deleted file mode 100644
index 580ce2a..0000000
--- a/ultralytics/utils/tal.py
+++ /dev/null
@@ -1,417 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-import torch
-import torch.nn as nn
-
-from . import LOGGER
-from .metrics import bbox_iou, probiou
-from .ops import xywhr2xyxyxyxy
-from .torch_utils import TORCH_1_11
-
-
-class TaskAlignedAssigner(nn.Module):
- """
- A task-aligned assigner for object detection.
-
- This class assigns ground-truth (gt) objects to anchors based on the task-aligned metric, which combines both
- classification and localization information.
-
- Attributes:
- topk (int): The number of top candidates to consider.
- num_classes (int): The number of object classes.
- alpha (float): The alpha parameter for the classification component of the task-aligned metric.
- beta (float): The beta parameter for the localization component of the task-aligned metric.
- eps (float): A small value to prevent division by zero.
- """
-
- def __init__(self, topk: int = 13, num_classes: int = 80, alpha: float = 1.0, beta: float = 6.0, eps: float = 1e-9):
- """
- Initialize a TaskAlignedAssigner object with customizable hyperparameters.
-
- Args:
- topk (int, optional): The number of top candidates to consider.
- num_classes (int, optional): The number of object classes.
- alpha (float, optional): The alpha parameter for the classification component of the task-aligned metric.
- beta (float, optional): The beta parameter for the localization component of the task-aligned metric.
- eps (float, optional): A small value to prevent division by zero.
- """
- super().__init__()
- self.topk = topk
- self.num_classes = num_classes
- self.alpha = alpha
- self.beta = beta
- self.eps = eps
-
- @torch.no_grad()
- def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
- """
- Compute the task-aligned assignment.
-
- Args:
- pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
- pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
- anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
- gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
- gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
- mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
-
- Returns:
- target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
- target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
- target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
- fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
- target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
-
- References:
- https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
- """
- self.bs = pd_scores.shape[0]
- self.n_max_boxes = gt_bboxes.shape[1]
- device = gt_bboxes.device
-
- if self.n_max_boxes == 0:
- return (
- torch.full_like(pd_scores[..., 0], self.num_classes),
- torch.zeros_like(pd_bboxes),
- torch.zeros_like(pd_scores),
- torch.zeros_like(pd_scores[..., 0]),
- torch.zeros_like(pd_scores[..., 0]),
- )
-
- try:
- return self._forward(pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)
- except torch.cuda.OutOfMemoryError:
- # Move tensors to CPU, compute, then move back to original device
- LOGGER.warning("CUDA OutOfMemoryError in TaskAlignedAssigner, using CPU")
- cpu_tensors = [t.cpu() for t in (pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt)]
- result = self._forward(*cpu_tensors)
- return tuple(t.to(device) for t in result)
-
- def _forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
- """
- Compute the task-aligned assignment.
-
- Args:
- pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
- pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
- anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
- gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
- gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
- mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
-
- Returns:
- target_labels (torch.Tensor): Target labels with shape (bs, num_total_anchors).
- target_bboxes (torch.Tensor): Target bounding boxes with shape (bs, num_total_anchors, 4).
- target_scores (torch.Tensor): Target scores with shape (bs, num_total_anchors, num_classes).
- fg_mask (torch.Tensor): Foreground mask with shape (bs, num_total_anchors).
- target_gt_idx (torch.Tensor): Target ground truth indices with shape (bs, num_total_anchors).
- """
- mask_pos, align_metric, overlaps = self.get_pos_mask(
- pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt
- )
-
- target_gt_idx, fg_mask, mask_pos = self.select_highest_overlaps(mask_pos, overlaps, self.n_max_boxes)
-
- # Assigned target
- target_labels, target_bboxes, target_scores = self.get_targets(gt_labels, gt_bboxes, target_gt_idx, fg_mask)
-
- # Normalize
- align_metric *= mask_pos
- pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) # b, max_num_obj
- pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) # b, max_num_obj
- norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
- target_scores = target_scores * norm_align_metric
-
- return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
-
- def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points, mask_gt):
- """
- Get positive mask for each ground truth box.
-
- Args:
- pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
- pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
- gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
- gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
- anc_points (torch.Tensor): Anchor points with shape (num_total_anchors, 2).
- mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, 1).
-
- Returns:
- mask_pos (torch.Tensor): Positive mask with shape (bs, max_num_obj, h*w).
- align_metric (torch.Tensor): Alignment metric with shape (bs, max_num_obj, h*w).
- overlaps (torch.Tensor): Overlaps between predicted and ground truth boxes with shape (bs, max_num_obj, h*w).
- """
- mask_in_gts = self.select_candidates_in_gts(anc_points, gt_bboxes)
- # Get anchor_align metric, (b, max_num_obj, h*w)
- align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts * mask_gt)
- # Get topk_metric mask, (b, max_num_obj, h*w)
- mask_topk = self.select_topk_candidates(align_metric, topk_mask=mask_gt.expand(-1, -1, self.topk).bool())
- # Merge all mask to a final mask, (b, max_num_obj, h*w)
- mask_pos = mask_topk * mask_in_gts * mask_gt
-
- return mask_pos, align_metric, overlaps
-
- def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_gt):
- """
- Compute alignment metric given predicted and ground truth bounding boxes.
-
- Args:
- pd_scores (torch.Tensor): Predicted classification scores with shape (bs, num_total_anchors, num_classes).
- pd_bboxes (torch.Tensor): Predicted bounding boxes with shape (bs, num_total_anchors, 4).
- gt_labels (torch.Tensor): Ground truth labels with shape (bs, n_max_boxes, 1).
- gt_bboxes (torch.Tensor): Ground truth boxes with shape (bs, n_max_boxes, 4).
- mask_gt (torch.Tensor): Mask for valid ground truth boxes with shape (bs, n_max_boxes, h*w).
-
- Returns:
- align_metric (torch.Tensor): Alignment metric combining classification and localization.
- overlaps (torch.Tensor): IoU overlaps between predicted and ground truth boxes.
- """
- na = pd_bboxes.shape[-2]
- mask_gt = mask_gt.bool() # b, max_num_obj, h*w
- overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
- bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
-
- ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj
- ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes) # b, max_num_obj
- ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
- # Get the scores of each grid for each gt cls
- bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt] # b, max_num_obj, h*w
-
- # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
- pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
- gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
- overlaps[mask_gt] = self.iou_calculation(gt_boxes, pd_boxes)
-
- align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
- return align_metric, overlaps
-
- def iou_calculation(self, gt_bboxes, pd_bboxes):
- """
- Calculate IoU for horizontal bounding boxes.
-
- Args:
- gt_bboxes (torch.Tensor): Ground truth boxes.
- pd_bboxes (torch.Tensor): Predicted boxes.
-
- Returns:
- (torch.Tensor): IoU values between each pair of boxes.
- """
- return bbox_iou(gt_bboxes, pd_bboxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
-
- def select_topk_candidates(self, metrics, topk_mask=None):
- """
- Select the top-k candidates based on the given metrics.
-
- Args:
- metrics (torch.Tensor): A tensor of shape (b, max_num_obj, h*w), where b is the batch size, max_num_obj is
- the maximum number of objects, and h*w represents the total number of anchor points.
- topk_mask (torch.Tensor, optional): An optional boolean tensor of shape (b, max_num_obj, topk), where
- topk is the number of top candidates to consider. If not provided, the top-k values are automatically
- computed based on the given metrics.
-
- Returns:
- (torch.Tensor): A tensor of shape (b, max_num_obj, h*w) containing the selected top-k candidates.
- """
- # (b, max_num_obj, topk)
- topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=True)
- if topk_mask is None:
- topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
- # (b, max_num_obj, topk)
- topk_idxs.masked_fill_(~topk_mask, 0)
-
- # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
- count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
- ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
- for k in range(self.topk):
- # Expand topk_idxs for each value of k and add 1 at the specified positions
- count_tensor.scatter_add_(-1, topk_idxs[:, :, k : k + 1], ones)
- # Filter invalid bboxes
- count_tensor.masked_fill_(count_tensor > 1, 0)
-
- return count_tensor.to(metrics.dtype)
-
- def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
- """
- Compute target labels, target bounding boxes, and target scores for the positive anchor points.
-
- Args:
- gt_labels (torch.Tensor): Ground truth labels of shape (b, max_num_obj, 1), where b is the
- batch size and max_num_obj is the maximum number of objects.
- gt_bboxes (torch.Tensor): Ground truth bounding boxes of shape (b, max_num_obj, 4).
- target_gt_idx (torch.Tensor): Indices of the assigned ground truth objects for positive
- anchor points, with shape (b, h*w), where h*w is the total
- number of anchor points.
- fg_mask (torch.Tensor): A boolean tensor of shape (b, h*w) indicating the positive
- (foreground) anchor points.
-
- Returns:
- target_labels (torch.Tensor): Target labels for positive anchor points with shape (b, h*w).
- target_bboxes (torch.Tensor): Target bounding boxes for positive anchor points with shape (b, h*w, 4).
- target_scores (torch.Tensor): Target scores for positive anchor points with shape (b, h*w, num_classes).
- """
- # Assigned target labels, (b, 1)
- batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
- target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes # (b, h*w)
- target_labels = gt_labels.long().flatten()[target_gt_idx] # (b, h*w)
-
- # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
- target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_gt_idx]
-
- # Assigned target scores
- target_labels.clamp_(0)
-
- # 10x faster than F.one_hot()
- target_scores = torch.zeros(
- (target_labels.shape[0], target_labels.shape[1], self.num_classes),
- dtype=torch.int64,
- device=target_labels.device,
- ) # (b, h*w, 80)
- target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
-
- fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes) # (b, h*w, 80)
- target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
-
- return target_labels, target_bboxes, target_scores
-
- @staticmethod
- def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
- """
- Select positive anchor centers within ground truth bounding boxes.
-
- Args:
- xy_centers (torch.Tensor): Anchor center coordinates, shape (h*w, 2).
- gt_bboxes (torch.Tensor): Ground truth bounding boxes, shape (b, n_boxes, 4).
- eps (float, optional): Small value for numerical stability.
-
- Returns:
- (torch.Tensor): Boolean mask of positive anchors, shape (b, n_boxes, h*w).
-
- Note:
- b: batch size, n_boxes: number of ground truth boxes, h: height, w: width.
- Bounding box format: [x_min, y_min, x_max, y_max].
- """
- n_anchors = xy_centers.shape[0]
- bs, n_boxes, _ = gt_bboxes.shape
- lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) # left-top, right-bottom
- bbox_deltas = torch.cat((xy_centers[None] - lt, rb - xy_centers[None]), dim=2).view(bs, n_boxes, n_anchors, -1)
- return bbox_deltas.amin(3).gt_(eps)
-
- @staticmethod
- def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
- """
- Select anchor boxes with highest IoU when assigned to multiple ground truths.
-
- Args:
- mask_pos (torch.Tensor): Positive mask, shape (b, n_max_boxes, h*w).
- overlaps (torch.Tensor): IoU overlaps, shape (b, n_max_boxes, h*w).
- n_max_boxes (int): Maximum number of ground truth boxes.
-
- Returns:
- target_gt_idx (torch.Tensor): Indices of assigned ground truths, shape (b, h*w).
- fg_mask (torch.Tensor): Foreground mask, shape (b, h*w).
- mask_pos (torch.Tensor): Updated positive mask, shape (b, n_max_boxes, h*w).
- """
- # Convert (b, n_max_boxes, h*w) -> (b, h*w)
- fg_mask = mask_pos.sum(-2)
- if fg_mask.max() > 1: # one anchor is assigned to multiple gt_bboxes
- mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1) # (b, n_max_boxes, h*w)
- max_overlaps_idx = overlaps.argmax(1) # (b, h*w)
-
- is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
- is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
-
- mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() # (b, n_max_boxes, h*w)
- fg_mask = mask_pos.sum(-2)
- # Find which gt (index) each grid cell serves
- target_gt_idx = mask_pos.argmax(-2) # (b, h*w)
- return target_gt_idx, fg_mask, mask_pos
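-
- # Standalone sketch with random, synthetic predictions and targets; shapes follow the
- # docstrings above and every value is a placeholder:
- # >>> assigner = TaskAlignedAssigner(topk=10, num_classes=5)
- # >>> bs, na, n_max = 2, 400, 3
- # >>> xy = torch.rand(bs, na, 2) * 600
- # >>> pd_bboxes = torch.cat([xy, xy + 40], dim=-1)  # valid xyxy predictions
- # >>> gxy = torch.rand(bs, n_max, 2) * 560
- # >>> gt_bboxes = torch.cat([gxy, gxy + 80], dim=-1)  # valid xyxy ground truth
- # >>> out = assigner(
- # ...     torch.rand(bs, na, 5),  # pd_scores
- # ...     pd_bboxes,
- # ...     torch.rand(na, 2) * 640,  # anc_points
- # ...     torch.randint(0, 5, (bs, n_max, 1)),  # gt_labels
- # ...     gt_bboxes,
- # ...     torch.ones(bs, n_max, 1),  # mask_gt
- # ... )
- # >>> [t.shape for t in out]  # labels, boxes, scores, fg_mask, target_gt_idx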
-
-
-class RotatedTaskAlignedAssigner(TaskAlignedAssigner):
- """Assigns ground-truth objects to rotated bounding boxes using a task-aligned metric."""
-
- def iou_calculation(self, gt_bboxes, pd_bboxes):
- """Calculate IoU for rotated bounding boxes."""
- return probiou(gt_bboxes, pd_bboxes).squeeze(-1).clamp_(0)
-
- @staticmethod
- def select_candidates_in_gts(xy_centers, gt_bboxes):
- """
- Select the positive anchor center in gt for rotated bounding boxes.
-
- Args:
- xy_centers (torch.Tensor): Anchor center coordinates with shape (h*w, 2).
- gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (b, n_boxes, 5).
-
- Returns:
- (torch.Tensor): Boolean mask of positive anchors with shape (b, n_boxes, h*w).
- """
- # (b, n_boxes, 5) --> (b, n_boxes, 4, 2)
- corners = xywhr2xyxyxyxy(gt_bboxes)
- # (b, n_boxes, 1, 2)
- a, b, _, d = corners.split(1, dim=-2)
- ab = b - a
- ad = d - a
-
- # (b, n_boxes, h*w, 2)
- ap = xy_centers - a
- norm_ab = (ab * ab).sum(dim=-1)
- norm_ad = (ad * ad).sum(dim=-1)
- ap_dot_ab = (ap * ab).sum(dim=-1)
- ap_dot_ad = (ap * ad).sum(dim=-1)
- return (ap_dot_ab >= 0) & (ap_dot_ab <= norm_ab) & (ap_dot_ad >= 0) & (ap_dot_ad <= norm_ad) # is_in_box
-
-
-def make_anchors(feats, strides, grid_cell_offset=0.5):
- """Generate anchors from features."""
- anchor_points, stride_tensor = [], []
- assert feats is not None
- dtype, device = feats[0].dtype, feats[0].device
- for i, stride in enumerate(strides):
- h, w = feats[i].shape[2:] if isinstance(feats, list) else (int(feats[i][0]), int(feats[i][1]))
- sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
- sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
- sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_11 else torch.meshgrid(sy, sx)
- anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
- stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
- return torch.cat(anchor_points), torch.cat(stride_tensor)
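-
-# Quick sketch with three dummy feature maps at strides 8/16/32 (shapes chosen arbitrarily):
-# >>> feats = [torch.zeros(1, 64, 80, 80), torch.zeros(1, 64, 40, 40), torch.zeros(1, 64, 20, 20)]
-# >>> anchor_points, stride_tensor = make_anchors(feats, strides=[8, 16, 32])
-# >>> anchor_points.shape, stride_tensor.shape  # (8400, 2) and (8400, 1)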
-
-
-def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
- """Transform distance(ltrb) to box(xywh or xyxy)."""
- lt, rb = distance.chunk(2, dim)
- x1y1 = anchor_points - lt
- x2y2 = anchor_points + rb
- if xywh:
- c_xy = (x1y1 + x2y2) / 2
- wh = x2y2 - x1y1
- return torch.cat([c_xy, wh], dim) # xywh bbox
- return torch.cat((x1y1, x2y2), dim) # xyxy bbox
-
-
-def bbox2dist(anchor_points, bbox, reg_max):
- """Transform bbox(xyxy) to dist(ltrb)."""
- x1y1, x2y2 = bbox.chunk(2, -1)
- return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01) # dist (lt, rb)
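-
-# Round-trip sketch: one box converted to ltrb distances and back (values picked arbitrarily):
-# >>> anchor = torch.tensor([[50.0, 50.0]])
-# >>> box = torch.tensor([[40.0, 30.0, 70.0, 90.0]])  # xyxy
-# >>> bbox2dist(anchor, box, reg_max=100)  # tensor([[10., 20., 20., 40.]])
-# >>> dist2bbox(bbox2dist(anchor, box, reg_max=100), anchor, xywh=False)  # original xyxy box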
-
-
-def dist2rbox(pred_dist, pred_angle, anchor_points, dim=-1):
- """
- Decode predicted rotated bounding box coordinates from anchor points and distribution.
-
- Args:
- pred_dist (torch.Tensor): Predicted rotated distance with shape (bs, h*w, 4).
- pred_angle (torch.Tensor): Predicted angle with shape (bs, h*w, 1).
- anchor_points (torch.Tensor): Anchor points with shape (h*w, 2).
- dim (int, optional): Dimension along which to split.
-
- Returns:
- (torch.Tensor): Predicted rotated bounding boxes with shape (bs, h*w, 4).
- """
- lt, rb = pred_dist.split(2, dim=dim)
- cos, sin = torch.cos(pred_angle), torch.sin(pred_angle)
- # (bs, h*w, 1)
- xf, yf = ((rb - lt) / 2).split(1, dim=dim)
- x, y = xf * cos - yf * sin, xf * sin + yf * cos
- xy = torch.cat([x, y], dim=dim) + anchor_points
- return torch.cat([xy, lt + rb], dim=dim)
diff --git a/ultralytics/utils/torch_utils.py b/ultralytics/utils/torch_utils.py
deleted file mode 100644
index 2b757fb..0000000
--- a/ultralytics/utils/torch_utils.py
+++ /dev/null
@@ -1,1010 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import functools
-import gc
-import math
-import os
-import random
-import time
-from contextlib import contextmanager
-from copy import deepcopy
-from datetime import datetime
-from pathlib import Path
-from typing import Any
-
-import numpy as np
-import torch
-import torch.distributed as dist
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ultralytics import __version__
-from ultralytics.utils import (
- DEFAULT_CFG_DICT,
- DEFAULT_CFG_KEYS,
- LOGGER,
- NUM_THREADS,
- PYTHON_VERSION,
- TORCH_VERSION,
- TORCHVISION_VERSION,
- WINDOWS,
- colorstr,
-)
-from ultralytics.utils.checks import check_version
-from ultralytics.utils.cpu import CPUInfo
-from ultralytics.utils.patches import torch_load
-
-# Version checks (all default to version>=min_version)
-TORCH_1_9 = check_version(TORCH_VERSION, "1.9.0")
-TORCH_1_10 = check_version(TORCH_VERSION, "1.10.0")
-TORCH_1_11 = check_version(TORCH_VERSION, "1.11.0")
-TORCH_1_13 = check_version(TORCH_VERSION, "1.13.0")
-TORCH_2_0 = check_version(TORCH_VERSION, "2.0.0")
-TORCH_2_1 = check_version(TORCH_VERSION, "2.1.0")
-TORCH_2_4 = check_version(TORCH_VERSION, "2.4.0")
-TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
-TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
-TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
-TORCHVISION_0_18 = check_version(TORCHVISION_VERSION, "0.18.0")
-if WINDOWS and check_version(TORCH_VERSION, "==2.4.0"): # reject version 2.4.0 on Windows
- LOGGER.warning(
- "Known issue with torch==2.4.0 on Windows with CPU, recommend upgrading to torch>=2.4.1 to resolve "
- "https://github.com/ultralytics/ultralytics/issues/15049"
- )
-
-
-@contextmanager
-def torch_distributed_zero_first(local_rank: int):
- """Ensure all processes in distributed training wait for the local master (rank 0) to complete a task first."""
- initialized = dist.is_available() and dist.is_initialized()
- use_ids = initialized and dist.get_backend() == "nccl"
-
- if initialized and local_rank not in {-1, 0}:
- dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
- yield
- if initialized and local_rank == 0:
- dist.barrier(device_ids=[local_rank]) if use_ids else dist.barrier()
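-
-# Typical usage sketch: rank 0 prepares a shared resource while other ranks wait; the
-# LOCAL_RANK lookup and prepare_dataset() helper are illustrative, not part of this module:
-# >>> local_rank = int(os.environ.get("LOCAL_RANK", -1))
-# >>> with torch_distributed_zero_first(local_rank):
-# ...     dataset = prepare_dataset()  # hypothetical download/cache step run once per node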
-
-
-def smart_inference_mode():
- """Apply torch.inference_mode() decorator if torch>=1.9.0 else torch.no_grad() decorator."""
-
- def decorate(fn):
- """Apply appropriate torch decorator for inference mode based on torch version."""
- if TORCH_1_9 and torch.is_inference_mode_enabled():
- return fn # already in inference_mode, act as a pass-through
- else:
- return (torch.inference_mode if TORCH_1_9 else torch.no_grad)()(fn)
-
- return decorate
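-
-# Usage sketch: decorate any inference helper so it runs under inference_mode() (or no_grad()
-# on torch<1.9); the function below is illustrative:
-# >>> @smart_inference_mode()
-# ... def run_inference(model, im):
-# ...     return model(im)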
-
-
-def autocast(enabled: bool, device: str = "cuda"):
- """
- Get the appropriate autocast context manager based on PyTorch version and AMP setting.
-
- This function returns a context manager for automatic mixed precision (AMP) training that is compatible with both
- older and newer versions of PyTorch. It handles the differences in the autocast API between PyTorch versions.
-
- Args:
- enabled (bool): Whether to enable automatic mixed precision.
- device (str, optional): The device to use for autocast.
-
- Returns:
- (torch.amp.autocast): The appropriate autocast context manager.
-
- Notes:
- - For PyTorch versions 1.13 and newer, it uses `torch.amp.autocast`.
- - For older versions, it uses `torch.cuda.amp.autocast`.
-
- Examples:
- >>> with autocast(enabled=True):
- ... # Your mixed precision operations here
- ... pass
- """
- if TORCH_1_13:
- return torch.amp.autocast(device, enabled=enabled)
- else:
- return torch.cuda.amp.autocast(enabled)
-
-
-@functools.lru_cache
-def get_cpu_info():
- """Return a string with system CPU information, i.e. 'Apple M2'."""
- from ultralytics.utils import PERSISTENT_CACHE # avoid circular import error
-
- if "cpu_info" not in PERSISTENT_CACHE:
- try:
- PERSISTENT_CACHE["cpu_info"] = CPUInfo.name()
- except Exception:
- pass
- return PERSISTENT_CACHE.get("cpu_info", "unknown")
-
-
-@functools.lru_cache
-def get_gpu_info(index):
- """Return a string with system GPU information, i.e. 'Tesla T4, 15102MiB'."""
- properties = torch.cuda.get_device_properties(index)
- return f"{properties.name}, {properties.total_memory / (1 << 20):.0f}MiB"
-
-
-def select_device(device="", newline=False, verbose=True):
- """
- Select the appropriate PyTorch device based on the provided arguments.
-
- The function takes a string specifying the device or a torch.device object and returns a torch.device object
- representing the selected device. The function also validates the number of available devices and raises an
- exception if the requested device(s) are not available.
-
- Args:
- device (str | torch.device, optional): Device string or torch.device object. Options are 'None', 'cpu', or
- 'cuda', or '0' or '0,1,2,3'. Auto-selects the first available GPU, or CPU if no GPU is available.
- newline (bool, optional): If True, adds a newline at the end of the log string.
- verbose (bool, optional): If True, logs the device information.
-
- Returns:
- (torch.device): Selected device.
-
- Examples:
- >>> select_device("cuda:0")
- device(type='cuda', index=0)
-
- >>> select_device("cpu")
- device(type='cpu')
-
- Notes:
- Sets the 'CUDA_VISIBLE_DEVICES' environment variable for specifying which GPUs to use.
- """
- if isinstance(device, torch.device) or str(device).startswith(("tpu", "intel")):
- return device
-
- s = f"Ultralytics {__version__} 🚀 Python-{PYTHON_VERSION} torch-{TORCH_VERSION} "
- device = str(device).lower()
- for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
- device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
-
- # Auto-select GPUs
- if "-1" in device:
- from ultralytics.utils.autodevice import GPUInfo
-
- # Replace each -1 with a selected GPU or remove it
- parts = device.split(",")
- selected = GPUInfo().select_idle_gpu(count=parts.count("-1"), min_memory_fraction=0.2)
- for i in range(len(parts)):
- if parts[i] == "-1":
- parts[i] = str(selected.pop(0)) if selected else ""
- device = ",".join(p for p in parts if p)
-
- cpu = device == "cpu"
- mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
- if cpu or mps:
- os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
- elif device: # non-cpu device requested
- if device == "cuda":
- device = "0"
- if "," in device:
- device = ",".join([x for x in device.split(",") if x]) # remove sequential commas, i.e. "0,,1" -> "0,1"
- visible = os.environ.get("CUDA_VISIBLE_DEVICES", None)
- os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
- if not (torch.cuda.is_available() and torch.cuda.device_count() >= len(device.split(","))):
- LOGGER.info(s)
- install = (
- "See https://pytorch.org/get-started/locally/ for up-to-date torch install instructions if no "
- "CUDA devices are seen by torch.\n"
- if torch.cuda.device_count() == 0
- else ""
- )
- raise ValueError(
- f"Invalid CUDA 'device={device}' requested."
- f" Use 'device=cpu' or pass valid CUDA device(s) if available,"
- f" i.e. 'device=0' or 'device=0,1,2,3' for Multi-GPU.\n"
- f"\ntorch.cuda.is_available(): {torch.cuda.is_available()}"
- f"\ntorch.cuda.device_count(): {torch.cuda.device_count()}"
- f"\nos.environ['CUDA_VISIBLE_DEVICES']: {visible}\n"
- f"{install}"
- )
-
- if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
- devices = device.split(",") if device else "0" # i.e. "0,1" -> ["0", "1"]
- space = " " * len(s)
- for i, d in enumerate(devices):
- s += f"{'' if i == 0 else space}CUDA:{d} ({get_gpu_info(i)})\n" # bytes to MB
- arg = "cuda:0"
- elif mps and TORCH_2_0 and torch.backends.mps.is_available():
- # Prefer MPS if available
- s += f"MPS ({get_cpu_info()})\n"
- arg = "mps"
- else: # revert to CPU
- s += f"CPU ({get_cpu_info()})\n"
- arg = "cpu"
-
- if arg in {"cpu", "mps"}:
- torch.set_num_threads(NUM_THREADS) # reset OMP_NUM_THREADS for cpu training
- if verbose:
- LOGGER.info(s if newline else s.rstrip())
- return torch.device(arg)
-
-
-def time_sync():
- """Return PyTorch-accurate time."""
- if torch.cuda.is_available():
- torch.cuda.synchronize()
- return time.time()
-
-
-def fuse_conv_and_bn(conv, bn):
- """
- Fuse Conv2d and BatchNorm2d layers for inference optimization.
-
- Args:
- conv (nn.Conv2d): Convolutional layer to fuse.
- bn (nn.BatchNorm2d): Batch normalization layer to fuse.
-
- Returns:
- (nn.Conv2d): The fused convolutional layer with gradients disabled.
-
- Example:
- >>> conv = nn.Conv2d(3, 16, 3)
- >>> bn = nn.BatchNorm2d(16)
- >>> fused_conv = fuse_conv_and_bn(conv, bn)
- """
- # Compute fused weights
- w_conv = conv.weight.view(conv.out_channels, -1)
- w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
- conv.weight.data = torch.mm(w_bn, w_conv).view(conv.weight.shape)
-
- # Compute fused bias
- b_conv = torch.zeros(conv.out_channels, device=conv.weight.device) if conv.bias is None else conv.bias
- b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
- fused_bias = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn
-
- if conv.bias is None:
- conv.register_parameter("bias", nn.Parameter(fused_bias))
- else:
- conv.bias.data = fused_bias
-
- return conv.requires_grad_(False)
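-
-# Sanity-check sketch: in eval mode the fused layer should match Conv2d followed by
-# BatchNorm2d (random weights; tolerance chosen loosely):
-# >>> conv, bn = nn.Conv2d(3, 16, 3, bias=False), nn.BatchNorm2d(16)
-# >>> _ = bn.eval()
-# >>> x = torch.rand(1, 3, 32, 32)
-# >>> fused = fuse_conv_and_bn(deepcopy(conv), bn)
-# >>> torch.allclose(bn(conv(x)), fused(x), atol=1e-5)  # expected: True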
-
-
-def fuse_deconv_and_bn(deconv, bn):
- """
- Fuse ConvTranspose2d and BatchNorm2d layers for inference optimization.
-
- Args:
- deconv (nn.ConvTranspose2d): Transposed convolutional layer to fuse.
- bn (nn.BatchNorm2d): Batch normalization layer to fuse.
-
- Returns:
- (nn.ConvTranspose2d): The fused transposed convolutional layer with gradients disabled.
-
- Example:
- >>> deconv = nn.ConvTranspose2d(16, 3, 3)
- >>> bn = nn.BatchNorm2d(3)
- >>> fused_deconv = fuse_deconv_and_bn(deconv, bn)
- """
- # Compute fused weights
- w_deconv = deconv.weight.view(deconv.out_channels, -1)
- w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
- deconv.weight.data = torch.mm(w_bn, w_deconv).view(deconv.weight.shape)
-
- # Compute fused bias
- b_conv = torch.zeros(deconv.out_channels, device=deconv.weight.device) if deconv.bias is None else deconv.bias
- b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
- fused_bias = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn
-
- if deconv.bias is None:
- deconv.register_parameter("bias", nn.Parameter(fused_bias))
- else:
- deconv.bias.data = fused_bias
-
- return deconv.requires_grad_(False)
-
-
-def model_info(model, detailed=False, verbose=True, imgsz=640):
- """
- Print and return detailed model information layer by layer.
-
- Args:
- model (nn.Module): Model to analyze.
- detailed (bool, optional): Whether to print detailed layer information.
- verbose (bool, optional): Whether to print model information.
- imgsz (int | list, optional): Input image size.
-
- Returns:
- n_l (int): Number of layers.
- n_p (int): Number of parameters.
- n_g (int): Number of gradients.
- flops (float): GFLOPs.
- """
- if not verbose:
- return
- n_p = get_num_params(model) # number of parameters
- n_g = get_num_gradients(model) # number of gradients
- layers = __import__("collections").OrderedDict((n, m) for n, m in model.named_modules() if len(m._modules) == 0)
- n_l = len(layers) # number of layers
- if detailed:
- h = f"{'layer':>5}{'name':>40}{'type':>20}{'gradient':>10}{'parameters':>12}{'shape':>20}{'mu':>10}{'sigma':>10}"
- LOGGER.info(h)
- for i, (mn, m) in enumerate(layers.items()):
- mn = mn.replace("module_list.", "")
- mt = m.__class__.__name__
- if len(m._parameters):
- for pn, p in m.named_parameters():
- LOGGER.info(
- f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
- )
- else: # layers with no learnable params
- LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
-
- flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
- fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
- fs = f", {flops:.1f} GFLOPs" if flops else ""
- yaml_file = getattr(model, "yaml_file", "") or getattr(model, "yaml", {}).get("yaml_file", "")
- model_name = Path(yaml_file).stem.replace("yolo", "YOLO") or "Model"
- LOGGER.info(f"{model_name} summary{fused}: {n_l:,} layers, {n_p:,} parameters, {n_g:,} gradients{fs}")
- return n_l, n_p, n_g, flops
-
-
-def get_num_params(model):
- """Return the total number of parameters in a YOLO model."""
- return sum(x.numel() for x in model.parameters())
-
-
-def get_num_gradients(model):
- """Return the total number of parameters with gradients in a YOLO model."""
- return sum(x.numel() for x in model.parameters() if x.requires_grad)
-
-
-def model_info_for_loggers(trainer):
- """
- Return model info dict with useful model information.
-
- Args:
- trainer (ultralytics.engine.trainer.BaseTrainer): The trainer object containing model and validation data.
-
- Returns:
- (dict): Dictionary containing model parameters, GFLOPs, and inference speeds.
-
- Examples:
- YOLOv8n info for loggers
- >>> results = {
- ... "model/parameters": 3151904,
- ... "model/GFLOPs": 8.746,
- ... "model/speed_ONNX(ms)": 41.244,
- ... "model/speed_TensorRT(ms)": 3.211,
- ... "model/speed_PyTorch(ms)": 18.755,
- ...}
- """
- if trainer.args.profile: # profile ONNX and TensorRT times
- from ultralytics.utils.benchmarks import ProfileModels
-
- results = ProfileModels([trainer.last], device=trainer.device).run()[0]
- results.pop("model/name")
- else: # only return PyTorch times from most recent validation
- results = {
- "model/parameters": get_num_params(trainer.model),
- "model/GFLOPs": round(get_flops(trainer.model), 3),
- }
- results["model/speed_PyTorch(ms)"] = round(trainer.validator.speed["inference"], 3)
- return results
-
-
-def get_flops(model, imgsz=640):
- """
- Calculate FLOPs (floating point operations) for a model in billions.
-
- Attempts two calculation methods: first with a stride-based tensor for efficiency,
- then falls back to full image size if needed (e.g., for RTDETR models). Returns 0.0
- if thop library is unavailable or calculation fails.
-
- Args:
- model (nn.Module): The model to calculate FLOPs for.
- imgsz (int | list, optional): Input image size.
-
- Returns:
- (float): The model FLOPs in billions.
- """
- try:
- import thop
- except ImportError:
- thop = None # conda support without 'ultralytics-thop' installed
-
- if not thop:
- return 0.0 # if not installed return 0.0 GFLOPs
-
- try:
- model = unwrap_model(model)
- p = next(model.parameters())
- if not isinstance(imgsz, list):
- imgsz = [imgsz, imgsz] # expand if int/float
- try:
- # Method 1: Use stride-based input tensor
- stride = max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32 # max stride
- im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
- flops = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # stride GFLOPs
- return flops * imgsz[0] / stride * imgsz[1] / stride # imgsz GFLOPs
- except Exception:
- # Method 2: Use actual image size (required for RTDETR models)
- im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
- return thop.profile(deepcopy(model), inputs=[im], verbose=False)[0] / 1e9 * 2 # imgsz GFLOPs
- except Exception:
- return 0.0
-
-
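The stride-based shortcut works because convolutional FLOPs grow roughly linearly with input area, so profiling a tiny stride-sized input and scaling by (imgsz/stride)^2 approximates the full-size cost. A minimal standalone sketch of that arithmetic, assuming the optional thop dependency is installed (the helper name is illustrative, not part of the module):

import torch
import torch.nn as nn
from copy import deepcopy


def approx_gflops(model: nn.Module, imgsz: int = 640, stride: int = 32) -> float:
    """Estimate GFLOPs by profiling a stride-sized input and scaling to imgsz (illustrative sketch)."""
    try:
        import thop  # provided by the optional 'ultralytics-thop' package
    except ImportError:
        return 0.0
    p = next(model.parameters())
    im = torch.empty((1, p.shape[1], stride, stride), device=p.device)  # smallest valid BCHW input
    macs = thop.profile(deepcopy(model), inputs=[im], verbose=False)[0]  # multiply-accumulate count
    flops = macs / 1e9 * 2  # 1 MAC ~= 2 FLOPs, reported in billions
    return flops * (imgsz / stride) ** 2  # cost scales with input area


# e.g. approx_gflops(nn.Sequential(nn.Conv2d(3, 16, 3, padding=1)), imgsz=640)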
-def get_flops_with_torch_profiler(model, imgsz=640):
- """
- Compute model FLOPs using torch profiler (alternative to thop package, but 2-10x slower).
-
- Args:
- model (nn.Module): The model to calculate FLOPs for.
- imgsz (int | list, optional): Input image size.
-
- Returns:
- (float): The model's FLOPs in billions.
- """
- if not TORCH_2_0: # torch profiler implemented in torch>=2.0
- return 0.0
- model = unwrap_model(model)
- p = next(model.parameters())
- if not isinstance(imgsz, list):
- imgsz = [imgsz, imgsz] # expand if int/float
- try:
- # Use stride size for input tensor
- stride = (max(int(model.stride.max()), 32) if hasattr(model, "stride") else 32) * 2 # max stride
- im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
- with torch.profiler.profile(with_flops=True) as prof:
- model(im)
- flops = sum(x.flops for x in prof.key_averages()) / 1e9
-        flops = flops * imgsz[0] / stride * imgsz[1] / stride  # imgsz GFLOPs
- except Exception:
- # Use actual image size for input tensor (i.e. required for RTDETR models)
- im = torch.empty((1, p.shape[1], *imgsz), device=p.device) # input image in BCHW format
- with torch.profiler.profile(with_flops=True) as prof:
- model(im)
- flops = sum(x.flops for x in prof.key_averages()) / 1e9
- return flops
-
-
-def initialize_weights(model):
- """Initialize model weights to random values."""
- for m in model.modules():
- t = type(m)
- if t is nn.Conv2d:
- pass # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
- elif t is nn.BatchNorm2d:
- m.eps = 1e-3
- m.momentum = 0.03
- elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
- m.inplace = True
-
-
-def scale_img(img, ratio=1.0, same_shape=False, gs=32):
- """
- Scale and pad an image tensor, optionally maintaining aspect ratio and padding to gs multiple.
-
- Args:
- img (torch.Tensor): Input image tensor.
- ratio (float, optional): Scaling ratio.
- same_shape (bool, optional): Whether to maintain the same shape.
- gs (int, optional): Grid size for padding.
-
- Returns:
- (torch.Tensor): Scaled and padded image tensor.
- """
- if ratio == 1.0:
- return img
- h, w = img.shape[2:]
- s = (int(h * ratio), int(w * ratio)) # new size
- img = F.interpolate(img, size=s, mode="bilinear", align_corners=False) # resize
- if not same_shape: # pad/crop img
- h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
- return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447) # value = imagenet mean
-
-
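A quick usage sketch of the function above: the resized tensor is padded up to the next multiple of gs with the ImageNet-mean fill value (shapes worked out from the arithmetic in the function):

import torch

x = torch.zeros(1, 3, 640, 640)
y = scale_img(x, ratio=0.67, gs=32)  # resized to int(640*0.67)=428, padded to ceil(13.4)*32=448
print(y.shape)  # torch.Size([1, 3, 448, 448])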
-def copy_attr(a, b, include=(), exclude=()):
- """
- Copy attributes from object 'b' to object 'a', with options to include/exclude certain attributes.
-
- Args:
- a (Any): Destination object to copy attributes to.
- b (Any): Source object to copy attributes from.
- include (tuple, optional): Attributes to include. If empty, all attributes are included.
- exclude (tuple, optional): Attributes to exclude.
- """
- for k, v in b.__dict__.items():
- if (len(include) and k not in include) or k.startswith("_") or k in exclude:
- continue
- else:
- setattr(a, k, v)
-
-
-def intersect_dicts(da, db, exclude=()):
- """
- Return a dictionary of intersecting keys with matching shapes, excluding 'exclude' keys, using da values.
-
- Args:
- da (dict): First dictionary.
- db (dict): Second dictionary.
- exclude (tuple, optional): Keys to exclude.
-
- Returns:
- (dict): Dictionary of intersecting keys with matching shapes.
- """
- return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
-
-
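This is the usual building block for partial transfer of pretrained weights: only keys whose shapes match are loaded, so a model with a different head still inherits its backbone. A small self-contained sketch, assuming intersect_dicts as defined above:

import torch.nn as nn

src = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 80, 1))  # e.g. pretrained, 80-class head
dst = nn.Sequential(nn.Conv2d(3, 16, 3), nn.Conv2d(16, 10, 1))  # new model, 10-class head

csd = intersect_dicts(src.state_dict(), dst.state_dict())  # keeps only matching-shape tensors
dst.load_state_dict(csd, strict=False)  # backbone transferred, mismatched head left untouched
print(f"Transferred {len(csd)}/{len(dst.state_dict())} items")  # Transferred 2/4 items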
-def is_parallel(model):
- """
- Return True if model is of type DP or DDP.
-
- Args:
- model (nn.Module): Model to check.
-
- Returns:
- (bool): True if model is DataParallel or DistributedDataParallel.
- """
- return isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel))
-
-
-def unwrap_model(m: nn.Module) -> nn.Module:
- """
- Unwrap compiled and parallel models to get the base model.
-
- Args:
- m (nn.Module): A model that may be wrapped by torch.compile (._orig_mod) or parallel wrappers such as
- DataParallel/DistributedDataParallel (.module).
-
- Returns:
- m (nn.Module): The unwrapped base model without compile or parallel wrappers.
- """
- while True:
- if hasattr(m, "_orig_mod") and isinstance(m._orig_mod, nn.Module):
- m = m._orig_mod
- elif hasattr(m, "module") and isinstance(m.module, nn.Module):
- m = m.module
- else:
- return m
-
-
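A tiny sketch of the unwrapping behaviour, assuming unwrap_model as defined above (DataParallel is constructed only to show the .module attribute; no GPU is needed for this):

import torch.nn as nn

base = nn.Linear(8, 4)
wrapped = nn.DataParallel(base)  # parallel wrappers expose the original model as .module
assert unwrap_model(wrapped) is base
# Models wrapped by torch.compile expose ._orig_mod and are unwrapped the same way.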
-def one_cycle(y1=0.0, y2=1.0, steps=100):
- """
- Return a lambda function for sinusoidal ramp from y1 to y2 https://arxiv.org/pdf/1812.01187.pdf.
-
- Args:
- y1 (float, optional): Initial value.
- y2 (float, optional): Final value.
- steps (int, optional): Number of steps.
-
- Returns:
- (function): Lambda function for computing the sinusoidal ramp.
- """
- return lambda x: max((1 - math.cos(x * math.pi / steps)) / 2, 0) * (y2 - y1) + y1
-
-
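A minimal numeric sketch of the returned lambda, assuming one_cycle as defined above (values follow directly from the cosine formula):

lf = one_cycle(y1=1.0, y2=0.01, steps=100)  # cosine ramp from 1.0 down to 0.01
print(round(lf(0), 3), round(lf(50), 3), round(lf(100), 3))  # 1.0 0.505 0.01
# Typically plugged into a scheduler, e.g. torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)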
-def init_seeds(seed=0, deterministic=False):
- """
- Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html.
-
- Args:
- seed (int, optional): Random seed.
- deterministic (bool, optional): Whether to set deterministic algorithms.
- """
- random.seed(seed)
- np.random.seed(seed)
- torch.manual_seed(seed)
- torch.cuda.manual_seed(seed)
- torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
- # torch.backends.cudnn.benchmark = True # AutoBatch problem https://github.com/ultralytics/yolov5/issues/9287
- if deterministic:
- if TORCH_2_0:
- torch.use_deterministic_algorithms(True, warn_only=True) # warn if deterministic is not possible
- torch.backends.cudnn.deterministic = True
- os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
- os.environ["PYTHONHASHSEED"] = str(seed)
- else:
- LOGGER.warning("Upgrade to torch>=2.0.0 for deterministic training.")
- else:
- unset_deterministic()
-
-
-def unset_deterministic():
- """Unset all the configurations applied for deterministic training."""
- torch.use_deterministic_algorithms(False)
- torch.backends.cudnn.deterministic = False
- os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)
- os.environ.pop("PYTHONHASHSEED", None)
-
-
-class ModelEMA:
- """
- Updated Exponential Moving Average (EMA) implementation.
-
- Keeps a moving average of everything in the model state_dict (parameters and buffers).
- For EMA details see References.
-
- To disable EMA set the `enabled` attribute to `False`.
-
- Attributes:
- ema (nn.Module): Copy of the model in evaluation mode.
- updates (int): Number of EMA updates.
- decay (function): Decay function that determines the EMA weight.
- enabled (bool): Whether EMA is enabled.
-
- References:
- - https://github.com/rwightman/pytorch-image-models
- - https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
- """
-
- def __init__(self, model, decay=0.9999, tau=2000, updates=0):
- """
- Initialize EMA for 'model' with given arguments.
-
- Args:
- model (nn.Module): Model to create EMA for.
- decay (float, optional): Maximum EMA decay rate.
- tau (int, optional): EMA decay time constant.
- updates (int, optional): Initial number of updates.
- """
- self.ema = deepcopy(unwrap_model(model)).eval() # FP32 EMA
- self.updates = updates # number of EMA updates
- self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
- for p in self.ema.parameters():
- p.requires_grad_(False)
- self.enabled = True
-
- def update(self, model):
- """
- Update EMA parameters.
-
- Args:
- model (nn.Module): Model to update EMA from.
- """
- if self.enabled:
- self.updates += 1
- d = self.decay(self.updates)
-
- msd = unwrap_model(model).state_dict() # model state_dict
- for k, v in self.ema.state_dict().items():
- if v.dtype.is_floating_point: # true for FP16 and FP32
- v *= d
- v += (1 - d) * msd[k].detach()
- # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype}, model {msd[k].dtype}'
-
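The decay lambda above ramps exponentially so that early updates track the raw model closely while late updates barely move the average; a small numeric sketch with the default decay=0.9999 and tau=2000:

import math

decay, tau = 0.9999, 2000
d = lambda updates: decay * (1 - math.exp(-updates / tau))
print(round(d(10), 4), round(d(2000), 4), round(d(20000), 4))  # 0.005 0.6321 0.9999
# Each update keeps a fraction d of the EMA value and blends in (1 - d) of the current model weights.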
- def update_attr(self, model, include=(), exclude=("process_group", "reducer")):
- """
-        Update EMA attributes by copying selected attributes from the model.
-
- Args:
- model (nn.Module): Model to update attributes from.
- include (tuple, optional): Attributes to include.
- exclude (tuple, optional): Attributes to exclude.
- """
- if self.enabled:
- copy_attr(self.ema, model, include, exclude)
-
-
-def strip_optimizer(f: str | Path = "best.pt", s: str = "", updates: dict[str, Any] = None) -> dict[str, Any]:
- """
- Strip optimizer from 'f' to finalize training, optionally save as 's'.
-
- Args:
- f (str | Path): File path to model to strip the optimizer from.
- s (str, optional): File path to save the model with stripped optimizer to. If not provided, 'f' will be
- overwritten.
- updates (dict, optional): A dictionary of updates to overlay onto the checkpoint before saving.
-
- Returns:
- (dict): The combined checkpoint dictionary.
-
- Examples:
- >>> from pathlib import Path
- >>> from ultralytics.utils.torch_utils import strip_optimizer
- >>> for f in Path("path/to/model/checkpoints").rglob("*.pt"):
- >>> strip_optimizer(f)
- """
- try:
- x = torch_load(f, map_location=torch.device("cpu"))
- assert isinstance(x, dict), "checkpoint is not a Python dictionary"
- assert "model" in x, "'model' missing from checkpoint"
- except Exception as e:
- LOGGER.warning(f"Skipping {f}, not a valid Ultralytics model: {e}")
- return {}
-
- metadata = {
- "date": datetime.now().isoformat(),
- "version": __version__,
- "license": "AGPL-3.0 License (https://ultralytics.com/license)",
- "docs": "https://docs.ultralytics.com",
- }
-
- # Update model
- if x.get("ema"):
- x["model"] = x["ema"] # replace model with EMA
- if hasattr(x["model"], "args"):
- x["model"].args = dict(x["model"].args) # convert from IterableSimpleNamespace to dict
- if hasattr(x["model"], "criterion"):
- x["model"].criterion = None # strip loss criterion
- x["model"].half() # to FP16
- for p in x["model"].parameters():
- p.requires_grad = False
-
- # Update other keys
- args = {**DEFAULT_CFG_DICT, **x.get("train_args", {})} # combine args
- for k in "optimizer", "best_fitness", "ema", "updates", "scaler": # keys
- x[k] = None
- x["epoch"] = -1
- x["train_args"] = {k: v for k, v in args.items() if k in DEFAULT_CFG_KEYS} # strip non-default keys
- # x['model'].args = x['train_args']
-
- # Save
- combined = {**metadata, **x, **(updates or {})}
- torch.save(combined, s or f) # combine dicts (prefer to the right)
- mb = os.path.getsize(s or f) / 1e6 # file size
- LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
- return combined
-
-
-def convert_optimizer_state_dict_to_fp16(state_dict):
- """
- Convert the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
-
- Args:
- state_dict (dict): Optimizer state dictionary.
-
- Returns:
- (dict): Converted optimizer state dictionary with FP16 tensors.
- """
- for state in state_dict["state"].values():
- for k, v in state.items():
- if k != "step" and isinstance(v, torch.Tensor) and v.dtype is torch.float32:
- state[k] = v.half()
-
- return state_dict
-
-
-@contextmanager
-def cuda_memory_usage(device=None):
- """
- Monitor and manage CUDA memory usage.
-
- This function checks if CUDA is available and, if so, empties the CUDA cache to free up unused memory.
- It then yields a dictionary containing memory usage information, which can be updated by the caller.
- Finally, it updates the dictionary with the amount of memory reserved by CUDA on the specified device.
-
- Args:
- device (torch.device, optional): The CUDA device to query memory usage for.
-
- Yields:
- (dict): A dictionary with a key 'memory' initialized to 0, which will be updated with the reserved memory.
- """
- cuda_info = dict(memory=0)
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- try:
- yield cuda_info
- finally:
- cuda_info["memory"] = torch.cuda.memory_reserved(device)
- else:
- yield cuda_info
-
-
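A usage sketch of the yield-a-dict pattern above: the caller allocates inside the with-block and reads the reserved-memory figure afterwards, which stays 0 on CPU-only machines:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
with cuda_memory_usage() as info:
    x = torch.zeros(256, 1024, 1024, device=device)  # ~1 GB of float32 when on GPU
print(f"Reserved after block: {info['memory'] / 1e9:.2f} GB")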
-def profile_ops(input, ops, n=10, device=None, max_num_obj=0):
- """
- Ultralytics speed, memory and FLOPs profiler.
-
- Args:
- input (torch.Tensor | list): Input tensor(s) to profile.
- ops (nn.Module | list): Model or list of operations to profile.
- n (int, optional): Number of iterations to average.
- device (str | torch.device, optional): Device to profile on.
- max_num_obj (int, optional): Maximum number of objects for simulation.
-
- Returns:
- (list): Profile results for each operation.
-
- Examples:
- >>> from ultralytics.utils.torch_utils import profile_ops
- >>> input = torch.randn(16, 3, 640, 640)
- >>> m1 = lambda x: x * torch.sigmoid(x)
- >>> m2 = nn.SiLU()
- >>> profile_ops(input, [m1, m2], n=100) # profile over 100 iterations
- """
- try:
- import thop
- except ImportError:
- thop = None # conda support without 'ultralytics-thop' installed
-
- results = []
- if not isinstance(device, torch.device):
- device = select_device(device)
- LOGGER.info(
- f"{'Params':>12s}{'GFLOPs':>12s}{'GPU_mem (GB)':>14s}{'forward (ms)':>14s}{'backward (ms)':>14s}"
- f"{'input':>24s}{'output':>24s}"
- )
- gc.collect() # attempt to free unused memory
- torch.cuda.empty_cache()
- for x in input if isinstance(input, list) else [input]:
- x = x.to(device)
- x.requires_grad = True
- for m in ops if isinstance(ops, list) else [ops]:
- m = m.to(device) if hasattr(m, "to") else m # device
- m = m.half() if hasattr(m, "half") and isinstance(x, torch.Tensor) and x.dtype is torch.float16 else m
- tf, tb, t = 0, 0, [0, 0, 0] # dt forward, backward
- try:
- flops = thop.profile(deepcopy(m), inputs=[x], verbose=False)[0] / 1e9 * 2 if thop else 0 # GFLOPs
- except Exception:
- flops = 0
-
- try:
- mem = 0
- for _ in range(n):
- with cuda_memory_usage(device) as cuda_info:
- t[0] = time_sync()
- y = m(x)
- t[1] = time_sync()
- try:
- (sum(yi.sum() for yi in y) if isinstance(y, list) else y).sum().backward()
- t[2] = time_sync()
- except Exception: # no backward method
- # print(e) # for debug
- t[2] = float("nan")
- mem += cuda_info["memory"] / 1e9 # (GB)
- tf += (t[1] - t[0]) * 1000 / n # ms per op forward
- tb += (t[2] - t[1]) * 1000 / n # ms per op backward
- if max_num_obj: # simulate training with predictions per image grid (for AutoBatch)
- with cuda_memory_usage(device) as cuda_info:
- torch.randn(
- x.shape[0],
- max_num_obj,
- int(sum((x.shape[-1] / s) * (x.shape[-2] / s) for s in m.stride.tolist())),
- device=device,
- dtype=torch.float32,
- )
- mem += cuda_info["memory"] / 1e9 # (GB)
- s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
- p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
- LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
- results.append([p, flops, mem, tf, tb, s_in, s_out])
- except Exception as e:
- LOGGER.info(e)
- results.append(None)
- finally:
- gc.collect() # attempt to free unused memory
- torch.cuda.empty_cache()
- return results
-
-
-class EarlyStopping:
- """
- Early stopping class that stops training when a specified number of epochs have passed without improvement.
-
- Attributes:
- best_fitness (float): Best fitness value observed.
- best_epoch (int): Epoch where best fitness was observed.
- patience (int): Number of epochs to wait after fitness stops improving before stopping.
- possible_stop (bool): Flag indicating if stopping may occur next epoch.
- """
-
- def __init__(self, patience=50):
- """
- Initialize early stopping object.
-
- Args:
- patience (int, optional): Number of epochs to wait after fitness stops improving before stopping.
- """
- self.best_fitness = 0.0 # i.e. mAP
- self.best_epoch = 0
- self.patience = patience or float("inf") # epochs to wait after fitness stops improving to stop
- self.possible_stop = False # possible stop may occur next epoch
-
- def __call__(self, epoch, fitness):
- """
- Check whether to stop training.
-
- Args:
- epoch (int): Current epoch of training
- fitness (float): Fitness value of current epoch
-
- Returns:
- (bool): True if training should stop, False otherwise
- """
- if fitness is None: # check if fitness=None (happens when val=False)
- return False
-
- if fitness > self.best_fitness or self.best_fitness == 0: # allow for early zero-fitness stage of training
- self.best_epoch = epoch
- self.best_fitness = fitness
- delta = epoch - self.best_epoch # epochs without improvement
- self.possible_stop = delta >= (self.patience - 1) # possible stop may occur next epoch
- stop = delta >= self.patience # stop training if patience exceeded
- if stop:
- prefix = colorstr("EarlyStopping: ")
- LOGGER.info(
- f"{prefix}Training stopped early as no improvement observed in last {self.patience} epochs. "
- f"Best results observed at epoch {self.best_epoch}, best model saved as best.pt.\n"
- f"To update EarlyStopping(patience={self.patience}) pass a new patience value, "
- f"i.e. `patience=300` or use `patience=0` to disable EarlyStopping."
- )
- return stop
-
-
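A minimal training-loop sketch showing how the stopper is polled each epoch, assuming the class above in its module context (the fitness values are made up for illustration):

stopper = EarlyStopping(patience=3)
fitness_per_epoch = [0.10, 0.20, 0.25, 0.25, 0.24, 0.23, 0.22]  # plateaus after epoch 2
for epoch, fitness in enumerate(fitness_per_epoch):
    if stopper(epoch, fitness):
        print(f"Stopped at epoch {epoch}, best was epoch {stopper.best_epoch}")  # stops at 5, best 2
        break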
-def attempt_compile(
- model: torch.nn.Module,
- device: torch.device,
- imgsz: int = 640,
- use_autocast: bool = False,
- warmup: bool = False,
- mode: bool | str = "default",
-) -> torch.nn.Module:
- """
- Compile a model with torch.compile and optionally warm up the graph to reduce first-iteration latency.
-
-    This utility attempts to compile the provided model using the inductor backend with the requested compile mode.
-    If compilation is unavailable or fails, the original model is returned unchanged. An optional warmup performs a
-    single forward pass on a dummy input to prime the compiled graph and measure compile/warmup time.
-
- Args:
- model (torch.nn.Module): Model to compile.
- device (torch.device): Inference device used for warmup and autocast decisions.
- imgsz (int, optional): Square input size to create a dummy tensor with shape (1, 3, imgsz, imgsz) for warmup.
- use_autocast (bool, optional): Whether to run warmup under autocast on CUDA or MPS devices.
- warmup (bool, optional): Whether to execute a single dummy forward pass to warm up the compiled model.
- mode (bool | str, optional): torch.compile mode. True → "default", False → no compile, or a string like
- "default", "reduce-overhead", "max-autotune-no-cudagraphs".
-
- Returns:
- model (torch.nn.Module): Compiled model if compilation succeeds, otherwise the original unmodified model.
-
- Notes:
- - If the current PyTorch build does not provide torch.compile, the function returns the input model immediately.
- - Warmup runs under torch.inference_mode and may use torch.autocast for CUDA/MPS to align compute precision.
- - CUDA devices are synchronized after warmup to account for asynchronous kernel execution.
-
- Examples:
- >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- >>> # Try to compile and warm up a model with a 640x640 input
- >>> model = attempt_compile(model, device=device, imgsz=640, use_autocast=True, warmup=True)
- """
- if not hasattr(torch, "compile") or not mode:
- return model
-
- if mode is True:
- mode = "default"
- prefix = colorstr("compile:")
- LOGGER.info(f"{prefix} starting torch.compile with '{mode}' mode...")
- if mode == "max-autotune":
- LOGGER.warning(f"{prefix} mode='{mode}' not recommended, using mode='max-autotune-no-cudagraphs' instead")
- mode = "max-autotune-no-cudagraphs"
- t0 = time.perf_counter()
- try:
- model = torch.compile(model, mode=mode, backend="inductor")
- except Exception as e:
- LOGGER.warning(f"{prefix} torch.compile failed, continuing uncompiled: {e}")
- return model
- t_compile = time.perf_counter() - t0
-
- t_warm = 0.0
- if warmup:
- # Use a single dummy tensor to build the graph shape state and reduce first-iteration latency
- dummy = torch.zeros(1, 3, imgsz, imgsz, device=device)
- if use_autocast and device.type == "cuda":
- dummy = dummy.half()
- t1 = time.perf_counter()
- with torch.inference_mode():
- if use_autocast and device.type in {"cuda", "mps"}:
- with torch.autocast(device.type):
- _ = model(dummy)
- else:
- _ = model(dummy)
- if device.type == "cuda":
- torch.cuda.synchronize(device)
- t_warm = time.perf_counter() - t1
-
- total = t_compile + t_warm
- if warmup:
- LOGGER.info(f"{prefix} complete in {total:.1f}s (compile {t_compile:.1f}s + warmup {t_warm:.1f}s)")
- else:
- LOGGER.info(f"{prefix} compile complete in {t_compile:.1f}s (no warmup)")
- return model
diff --git a/ultralytics/utils/tqdm.py b/ultralytics/utils/tqdm.py
deleted file mode 100644
index b6f1fc7..0000000
--- a/ultralytics/utils/tqdm.py
+++ /dev/null
@@ -1,440 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-import os
-import sys
-import time
-from functools import lru_cache
-from typing import IO, Any
-
-
-@lru_cache(maxsize=1)
-def is_noninteractive_console() -> bool:
- """Check for known non-interactive console environments."""
- return "GITHUB_ACTIONS" in os.environ or "RUNPOD_POD_ID" in os.environ
-
-
-class TQDM:
- """
- Lightweight zero-dependency progress bar for Ultralytics.
-
- Provides clean, rich-style progress bars suitable for various environments including Weights & Biases,
- console outputs, and other logging systems. Features zero external dependencies, clean single-line output,
- rich-style progress bars with Unicode block characters, context manager support, iterator protocol support,
- and dynamic description updates.
-
- Attributes:
- iterable (object): Iterable to wrap with progress bar.
- desc (str): Prefix description for the progress bar.
- total (int): Expected number of iterations.
- disable (bool): Whether to disable the progress bar.
- unit (str): String for units of iteration.
- unit_scale (bool): Auto-scale units flag.
- unit_divisor (int): Divisor for unit scaling.
- leave (bool): Whether to leave the progress bar after completion.
- mininterval (float): Minimum time interval between updates.
- initial (int): Initial counter value.
- n (int): Current iteration count.
- closed (bool): Whether the progress bar is closed.
- bar_format (str): Custom bar format string.
- file (object): Output file stream.
-
- Methods:
- update: Update progress by n steps.
- set_description: Set or update the description.
- set_postfix: Set postfix for the progress bar.
- close: Close the progress bar and clean up.
- refresh: Refresh the progress bar display.
- clear: Clear the progress bar from display.
- write: Write a message without breaking the progress bar.
-
- Examples:
- Basic usage with iterator:
- >>> for i in TQDM(range(100)):
- ... time.sleep(0.01)
-
- With custom description:
- >>> pbar = TQDM(range(100), desc="Processing")
- >>> for i in pbar:
- ... pbar.set_description(f"Processing item {i}")
-
- Context manager usage:
- >>> with TQDM(total=100, unit="B", unit_scale=True) as pbar:
- ... for i in range(100):
- ... pbar.update(1)
-
- Manual updates:
- >>> pbar = TQDM(total=100, desc="Training")
- >>> for epoch in range(100):
- ... # Do work
- ... pbar.update(1)
- >>> pbar.close()
- """
-
- # Constants
- MIN_RATE_CALC_INTERVAL = 0.01 # Minimum time interval for rate calculation
- RATE_SMOOTHING_FACTOR = 0.3 # Factor for exponential smoothing of rates
- MAX_SMOOTHED_RATE = 1000000 # Maximum rate to apply smoothing to
- NONINTERACTIVE_MIN_INTERVAL = 60.0 # Minimum interval for non-interactive environments
-
- def __init__(
- self,
- iterable: Any = None,
- desc: str | None = None,
- total: int | None = None,
- leave: bool = True,
- file: IO[str] | None = None,
- mininterval: float = 0.1,
- disable: bool | None = None,
- unit: str = "it",
- unit_scale: bool = True,
- unit_divisor: int = 1000,
- bar_format: str | None = None, # kept for API compatibility; not used for formatting
- initial: int = 0,
- **kwargs,
- ) -> None:
- """
- Initialize the TQDM progress bar with specified configuration options.
-
- Args:
- iterable (object, optional): Iterable to wrap with progress bar.
- desc (str, optional): Prefix description for the progress bar.
- total (int, optional): Expected number of iterations.
- leave (bool, optional): Whether to leave the progress bar after completion.
- file (object, optional): Output file stream for progress display.
- mininterval (float, optional): Minimum time interval between updates (default 0.1s, 60s in GitHub Actions).
- disable (bool, optional): Whether to disable the progress bar. Auto-detected if None.
- unit (str, optional): String for units of iteration (default "it" for items).
- unit_scale (bool, optional): Auto-scale units for bytes/data units.
- unit_divisor (int, optional): Divisor for unit scaling (default 1000).
- bar_format (str, optional): Custom bar format string.
- initial (int, optional): Initial counter value.
- **kwargs (Any): Additional keyword arguments for compatibility (ignored).
-
- Examples:
- >>> pbar = TQDM(range(100), desc="Processing")
- >>> with TQDM(total=1000, unit="B", unit_scale=True) as pbar:
- ... pbar.update(1024) # Updates by 1KB
- """
- # Disable if not verbose
- if disable is None:
- try:
- from ultralytics.utils import LOGGER, VERBOSE
-
- disable = not VERBOSE or LOGGER.getEffectiveLevel() > 20
- except ImportError:
- disable = False
-
- self.iterable = iterable
- self.desc = desc or ""
- self.total = total or (len(iterable) if hasattr(iterable, "__len__") else None) or None # prevent total=0
- self.disable = disable
- self.unit = unit
- self.unit_scale = unit_scale
- self.unit_divisor = unit_divisor
- self.leave = leave
- self.noninteractive = is_noninteractive_console()
- self.mininterval = max(mininterval, self.NONINTERACTIVE_MIN_INTERVAL) if self.noninteractive else mininterval
- self.initial = initial
-
- # Kept for API compatibility (unused for f-string formatting)
- self.bar_format = bar_format
-
- self.file = file or sys.stdout
-
- # Internal state
- self.n = self.initial
- self.last_print_n = self.initial
- self.last_print_t = time.time()
- self.start_t = time.time()
- self.last_rate = 0.0
- self.closed = False
- self.is_bytes = unit_scale and unit in ("B", "bytes")
- self.scales = (
- [(1073741824, "GB/s"), (1048576, "MB/s"), (1024, "KB/s")]
- if self.is_bytes
- else [(1e9, f"G{self.unit}/s"), (1e6, f"M{self.unit}/s"), (1e3, f"K{self.unit}/s")]
- )
-
- if not self.disable and self.total and not self.noninteractive:
- self._display()
-
- def _format_rate(self, rate: float) -> str:
- """Format rate with units."""
- if rate <= 0:
- return ""
- fallback = f"{rate:.1f}B/s" if self.is_bytes else f"{rate:.1f}{self.unit}/s"
- return next((f"{rate / t:.1f}{u}" for t, u in self.scales if rate >= t), fallback)
-
- def _format_num(self, num: int | float) -> str:
- """Format number with optional unit scaling."""
- if not self.unit_scale or not self.is_bytes:
- return str(num)
-
- for unit in ("", "K", "M", "G", "T"):
- if abs(num) < self.unit_divisor:
- return f"{num:3.1f}{unit}B" if unit else f"{num:.0f}B"
- num /= self.unit_divisor
- return f"{num:.1f}PB"
-
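The byte formatting above divides by unit_divisor until the value fits the next prefix; the same idea as a standalone function (illustrative only, not the class method itself):

def human_bytes(num: float, divisor: int = 1000) -> str:
    """Format a byte count with K/M/G/T prefixes, mirroring the helper above (sketch)."""
    for unit in ("", "K", "M", "G", "T"):
        if abs(num) < divisor:
            return f"{num:3.1f}{unit}B" if unit else f"{num:.0f}B"
        num /= divisor
    return f"{num:.1f}PB"


print(human_bytes(532), human_bytes(5_432_100))  # 532B 5.4MB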
- def _format_time(self, seconds: float) -> str:
- """Format time duration."""
- if seconds < 60:
- return f"{seconds:.1f}s"
- elif seconds < 3600:
- return f"{int(seconds // 60)}:{seconds % 60:02.0f}"
- else:
- h, m = int(seconds // 3600), int((seconds % 3600) // 60)
- return f"{h}:{m:02d}:{seconds % 60:02.0f}"
-
- def _generate_bar(self, width: int = 12) -> str:
- """Generate progress bar."""
- if self.total is None:
- return "━" * width if self.closed else "─" * width
-
- frac = min(1.0, self.n / self.total)
- filled = int(frac * width)
- bar = "━" * filled + "─" * (width - filled)
- if filled < width and frac * width - filled > 0.5:
- bar = f"{bar[:filled]}╸{bar[filled + 1 :]}"
- return bar
-
- def _should_update(self, dt: float, dn: int) -> bool:
- """Check if display should update."""
- if self.noninteractive:
- return False
- return (self.total is not None and self.n >= self.total) or (dt >= self.mininterval)
-
- def _display(self, final: bool = False) -> None:
- """Display progress bar."""
- if self.disable or (self.closed and not final):
- return
-
- current_time = time.time()
- dt = current_time - self.last_print_t
- dn = self.n - self.last_print_n
-
- if not final and not self._should_update(dt, dn):
- return
-
- # Calculate rate (avoid crazy numbers)
- if dt > self.MIN_RATE_CALC_INTERVAL:
- rate = dn / dt if dt else 0.0
- # Smooth rate for reasonable values, use raw rate for very high values
- if rate < self.MAX_SMOOTHED_RATE:
- self.last_rate = self.RATE_SMOOTHING_FACTOR * rate + (1 - self.RATE_SMOOTHING_FACTOR) * self.last_rate
- rate = self.last_rate
- else:
- rate = self.last_rate
-
- # At completion, use overall rate
- if self.total and self.n >= self.total:
- overall_elapsed = current_time - self.start_t
- if overall_elapsed > 0:
- rate = self.n / overall_elapsed
-
- # Update counters
- self.last_print_n = self.n
- self.last_print_t = current_time
- elapsed = current_time - self.start_t
-
- # Remaining time
- remaining_str = ""
- if self.total and 0 < self.n < self.total and elapsed > 0:
- est_rate = rate or (self.n / elapsed)
- remaining_str = f"<{self._format_time((self.total - self.n) / est_rate)}"
-
- # Numbers and percent
- if self.total:
- percent = (self.n / self.total) * 100
- n_str = self._format_num(self.n)
- t_str = self._format_num(self.total)
- if self.is_bytes:
- # Collapse suffix only when identical (e.g. "5.4/5.4MB")
- if n_str[-2] == t_str[-2]:
-                    n_str = n_str.rstrip("KMGTPB")  # Remove unit suffix from current since it matches the total's
- else:
- percent = 0.0
- n_str, t_str = self._format_num(self.n), "?"
-
- elapsed_str = self._format_time(elapsed)
- rate_str = self._format_rate(rate) or (self._format_rate(self.n / elapsed) if elapsed > 0 else "")
-
- bar = self._generate_bar()
-
- # Compose progress line via f-strings (two shapes: with/without total)
- if self.total:
- if self.is_bytes and self.n >= self.total:
- # Completed bytes: show only final size
- progress_str = f"{self.desc}: {percent:.0f}% {bar} {t_str} {rate_str} {elapsed_str}"
- else:
- progress_str = (
- f"{self.desc}: {percent:.0f}% {bar} {n_str}/{t_str} {rate_str} {elapsed_str}{remaining_str}"
- )
- else:
- progress_str = f"{self.desc}: {bar} {n_str} {rate_str} {elapsed_str}"
-
- # Write to output
- try:
- if self.noninteractive:
- # In non-interactive environments, avoid carriage return which creates empty lines
- self.file.write(progress_str)
- else:
- # In interactive terminals, use carriage return and clear line for updating display
- self.file.write(f"\r\033[K{progress_str}")
- self.file.flush()
- except Exception:
- pass
-
- def update(self, n: int = 1) -> None:
- """Update progress by n steps."""
- if not self.disable and not self.closed:
- self.n += n
- self._display()
-
- def set_description(self, desc: str | None) -> None:
- """Set description."""
- self.desc = desc or ""
- if not self.disable:
- self._display()
-
- def set_postfix(self, **kwargs: Any) -> None:
- """Set postfix (appends to description)."""
- if kwargs:
- postfix = ", ".join(f"{k}={v}" for k, v in kwargs.items())
- base_desc = self.desc.split(" | ")[0] if " | " in self.desc else self.desc
- self.set_description(f"{base_desc} | {postfix}")
-
- def close(self) -> None:
- """Close progress bar."""
- if self.closed:
- return
-
- self.closed = True
-
- if not self.disable:
- # Final display
- if self.total and self.n >= self.total:
- self.n = self.total
- self._display(final=True)
-
- # Cleanup
- if self.leave:
- self.file.write("\n")
- else:
- self.file.write("\r\033[K")
-
- try:
- self.file.flush()
- except Exception:
- pass
-
- def __enter__(self) -> TQDM:
- """Enter context manager."""
- return self
-
- def __exit__(self, *args: Any) -> None:
- """Exit context manager and close progress bar."""
- self.close()
-
- def __iter__(self) -> Any:
- """Iterate over the wrapped iterable with progress updates."""
- if self.iterable is None:
- raise TypeError("'NoneType' object is not iterable")
-
- try:
- for item in self.iterable:
- yield item
- self.update(1)
- finally:
- self.close()
-
- def __del__(self) -> None:
- """Destructor to ensure cleanup."""
- try:
- self.close()
- except Exception:
- pass
-
- def refresh(self) -> None:
- """Refresh display."""
- if not self.disable:
- self._display()
-
- def clear(self) -> None:
- """Clear progress bar."""
- if not self.disable:
- try:
- self.file.write("\r\033[K")
- self.file.flush()
- except Exception:
- pass
-
- @staticmethod
- def write(s: str, file: IO[str] | None = None, end: str = "\n") -> None:
- """Static method to write without breaking progress bar."""
- file = file or sys.stdout
- try:
- file.write(s + end)
- file.flush()
- except Exception:
- pass
-
-
-if __name__ == "__main__":
- import time
-
- print("1. Basic progress bar with known total:")
- for i in TQDM(range(3), desc="Known total"):
- time.sleep(0.05)
-
- print("\n2. Manual updates with known total:")
- pbar = TQDM(total=300, desc="Manual updates", unit="files")
- for i in range(300):
- time.sleep(0.03)
- pbar.update(1)
- if i % 10 == 9:
- pbar.set_description(f"Processing batch {i // 10 + 1}")
- pbar.close()
-
- print("\n3. Progress bar with unknown total:")
- pbar = TQDM(desc="Unknown total", unit="items")
- for i in range(25):
- time.sleep(0.08)
- pbar.update(1)
- if i % 5 == 4:
- pbar.set_postfix(processed=i + 1, status="OK")
- pbar.close()
-
- print("\n4. Context manager with unknown total:")
- with TQDM(desc="Processing stream", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
- for i in range(30):
- time.sleep(0.1)
- pbar.update(1024 * 1024 * i) # Simulate processing MB of data
-
- print("\n5. Iterator with unknown length:")
-
- def data_stream():
- """Simulate a data stream of unknown length."""
- import random
-
- for i in range(random.randint(10, 20)):
- yield f"data_chunk_{i}"
-
- for chunk in TQDM(data_stream(), desc="Stream processing", unit="chunks"):
- time.sleep(0.1)
-
- print("\n6. File processing simulation (unknown size):")
-
- def process_files():
- """Simulate processing files of unknown count."""
- return [f"file_{i}.txt" for i in range(18)]
-
- pbar = TQDM(desc="Scanning files", unit="files")
- files = process_files()
- for i, filename in enumerate(files):
- time.sleep(0.06)
- pbar.update(1)
- pbar.set_description(f"Processing {filename}")
- pbar.close()
diff --git a/ultralytics/utils/triton.py b/ultralytics/utils/triton.py
deleted file mode 100644
index 6c122f3..0000000
--- a/ultralytics/utils/triton.py
+++ /dev/null
@@ -1,118 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from __future__ import annotations
-
-from urllib.parse import urlsplit
-
-import numpy as np
-
-
-class TritonRemoteModel:
- """
- Client for interacting with a remote Triton Inference Server model.
-
- This class provides a convenient interface for sending inference requests to a Triton Inference Server
- and processing the responses. Supports both HTTP and gRPC communication protocols.
-
- Attributes:
- endpoint (str): The name of the model on the Triton server.
- url (str): The URL of the Triton server.
- triton_client: The Triton client (either HTTP or gRPC).
- InferInput: The input class for the Triton client.
- InferRequestedOutput: The output request class for the Triton client.
- input_formats (list[str]): The data types of the model inputs.
- np_input_formats (list[type]): The numpy data types of the model inputs.
- input_names (list[str]): The names of the model inputs.
- output_names (list[str]): The names of the model outputs.
- metadata: The metadata associated with the model.
-
- Methods:
- __call__: Call the model with the given inputs and return the outputs.
-
- Examples:
- Initialize a Triton client with HTTP
- >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
-
- Make inference with numpy arrays
- >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
- """
-
- def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
- """
- Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.
-
- Arguments may be provided individually or parsed from a collective 'url' argument of the form
-            <scheme>://<netloc>/<endpoint>/<task_name>
-
- Args:
- url (str): The URL of the Triton server.
- endpoint (str, optional): The name of the model on the Triton server.
- scheme (str, optional): The communication scheme ('http' or 'grpc').
-
- Examples:
- >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
- >>> model = TritonRemoteModel(url="http://localhost:8000/yolov8")
- """
- if not endpoint and not scheme: # Parse all args from URL string
- splits = urlsplit(url)
- endpoint = splits.path.strip("/").split("/", 1)[0]
- scheme = splits.scheme
- url = splits.netloc
-
- self.endpoint = endpoint
- self.url = url
-
- # Choose the Triton client based on the communication scheme
- if scheme == "http":
- import tritonclient.http as client # noqa
-
- self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
- config = self.triton_client.get_model_config(endpoint)
- else:
- import tritonclient.grpc as client # noqa
-
- self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
- config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
-
- # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
- config["output"] = sorted(config["output"], key=lambda x: x.get("name"))
-
- # Define model attributes
- type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
- self.InferRequestedOutput = client.InferRequestedOutput
- self.InferInput = client.InferInput
- self.input_formats = [x["data_type"] for x in config["input"]]
- self.np_input_formats = [type_map[x] for x in self.input_formats]
- self.input_names = [x["name"] for x in config["input"]]
- self.output_names = [x["name"] for x in config["output"]]
- self.metadata = eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))
-
- def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]:
- """
- Call the model with the given inputs and return inference results.
-
- Args:
- *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type
- for the corresponding model input.
-
- Returns:
- (list[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list
- corresponds to one of the model's output tensors.
-
- Examples:
- >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
- >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
- """
- infer_inputs = []
- input_format = inputs[0].dtype
- for i, x in enumerate(inputs):
- if x.dtype != self.np_input_formats[i]:
- x = x.astype(self.np_input_formats[i])
- infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
- infer_input.set_data_from_numpy(x)
- infer_inputs.append(infer_input)
-
- infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
- outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)
-
- return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
diff --git a/ultralytics/utils/tuner.py b/ultralytics/utils/tuner.py
deleted file mode 100644
index 6b025b5..0000000
--- a/ultralytics/utils/tuner.py
+++ /dev/null
@@ -1,159 +0,0 @@
-# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
-
-from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
-from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr
-
-
-def run_ray_tune(
- model,
- space: dict = None,
- grace_period: int = 10,
- gpu_per_trial: int = None,
- max_samples: int = 10,
- **train_args,
-):
- """
- Run hyperparameter tuning using Ray Tune.
-
- Args:
- model (YOLO): Model to run the tuner on.
- space (dict, optional): The hyperparameter search space. If not provided, uses default space.
- grace_period (int, optional): The grace period in epochs of the ASHA scheduler.
- gpu_per_trial (int, optional): The number of GPUs to allocate per trial.
- max_samples (int, optional): The maximum number of trials to run.
- **train_args (Any): Additional arguments to pass to the `train()` method.
-
- Returns:
- (ray.tune.ResultGrid): A ResultGrid containing the results of the hyperparameter search.
-
- Examples:
- >>> from ultralytics import YOLO
- >>> model = YOLO("yolo11n.pt") # Load a YOLO11n model
-
- Start tuning hyperparameters for YOLO11n training on the COCO8 dataset
- >>> result_grid = model.tune(data="coco8.yaml", use_ray=True)
- """
- LOGGER.info("💡 Learn about RayTune at https://docs.ultralytics.com/integrations/ray-tune")
- if train_args is None:
- train_args = {}
-
- try:
- checks.check_requirements("ray[tune]")
-
- import ray
- from ray import tune
- from ray.air import RunConfig
- from ray.air.integrations.wandb import WandbLoggerCallback
- from ray.tune.schedulers import ASHAScheduler
- except ImportError:
- raise ModuleNotFoundError('Ray Tune required but not found. To install run: pip install "ray[tune]"')
-
- try:
- import wandb
-
- assert hasattr(wandb, "__version__")
- except (ImportError, AssertionError):
- wandb = False
-
- checks.check_version(ray.__version__, ">=2.0.0", "ray")
- default_space = {
- # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
- "lr0": tune.uniform(1e-5, 1e-1),
- "lrf": tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf)
- "momentum": tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1
- "weight_decay": tune.uniform(0.0, 0.001), # optimizer weight decay
- "warmup_epochs": tune.uniform(0.0, 5.0), # warmup epochs (fractions ok)
- "warmup_momentum": tune.uniform(0.0, 0.95), # warmup initial momentum
- "box": tune.uniform(0.02, 0.2), # box loss gain
- "cls": tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels)
- "hsv_h": tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction)
- "hsv_s": tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction)
- "hsv_v": tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction)
- "degrees": tune.uniform(0.0, 45.0), # image rotation (+/- deg)
- "translate": tune.uniform(0.0, 0.9), # image translation (+/- fraction)
- "scale": tune.uniform(0.0, 0.9), # image scale (+/- gain)
- "shear": tune.uniform(0.0, 10.0), # image shear (+/- deg)
- "perspective": tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001
- "flipud": tune.uniform(0.0, 1.0), # image flip up-down (probability)
- "fliplr": tune.uniform(0.0, 1.0), # image flip left-right (probability)
- "bgr": tune.uniform(0.0, 1.0), # image channel BGR (probability)
- "mosaic": tune.uniform(0.0, 1.0), # image mosaic (probability)
- "mixup": tune.uniform(0.0, 1.0), # image mixup (probability)
- "cutmix": tune.uniform(0.0, 1.0), # image cutmix (probability)
- "copy_paste": tune.uniform(0.0, 1.0), # segment copy-paste (probability)
- }
-
- # Put the model in ray store
- task = model.task
- model_in_store = ray.put(model)
-
- def _tune(config):
- """Train the YOLO model with the specified hyperparameters and return results."""
- model_to_train = ray.get(model_in_store) # get the model from ray store for tuning
- model_to_train.reset_callbacks()
- config.update(train_args)
- results = model_to_train.train(**config)
- return results.results_dict
-
- # Get search space
- if not space and not train_args.get("resume"):
- space = default_space
- LOGGER.warning("Search space not provided, using default search space.")
-
- # Get dataset
- data = train_args.get("data", TASK2DATA[task])
- space["data"] = data
- if "data" not in train_args:
- LOGGER.warning(f'Data not provided, using default "data={data}".')
-
- # Define the trainable function with allocated resources
- trainable_with_resources = tune.with_resources(_tune, {"cpu": NUM_THREADS, "gpu": gpu_per_trial or 0})
-
- # Define the ASHA scheduler for hyperparameter search
- asha_scheduler = ASHAScheduler(
- time_attr="epoch",
- metric=TASK2METRIC[task],
- mode="max",
- max_t=train_args.get("epochs") or DEFAULT_CFG_DICT["epochs"] or 100,
- grace_period=grace_period,
- reduction_factor=3,
- )
-
- # Define the callbacks for the hyperparameter search
- tuner_callbacks = [WandbLoggerCallback(project="YOLOv8-tune")] if wandb else []
-
- # Create the Ray Tune hyperparameter search tuner
- tune_dir = get_save_dir(
- get_cfg(
- DEFAULT_CFG,
- {**train_args, **{"exist_ok": train_args.pop("resume", False)}}, # resume w/ same tune_dir
- ),
- name=train_args.pop("name", "tune"), # runs/{task}/{tune_dir}
- ) # must be absolute dir
- tune_dir.mkdir(parents=True, exist_ok=True)
- if tune.Tuner.can_restore(tune_dir):
- LOGGER.info(f"{colorstr('Tuner: ')} Resuming tuning run {tune_dir}...")
- tuner = tune.Tuner.restore(str(tune_dir), trainable=trainable_with_resources, resume_errored=True)
- else:
- tuner = tune.Tuner(
- trainable_with_resources,
- param_space=space,
- tune_config=tune.TuneConfig(
- scheduler=asha_scheduler,
- num_samples=max_samples,
- trial_name_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
- trial_dirname_creator=lambda trial: f"{trial.trainable_name}_{trial.trial_id}",
- ),
- run_config=RunConfig(callbacks=tuner_callbacks, storage_path=tune_dir.parent, name=tune_dir.name),
- )
-
- # Run the hyperparameter search
- tuner.fit()
-
- # Get the results of the hyperparameter search
- results = tuner.get_results()
-
- # Shut down Ray to clean up workers
- ray.shutdown()
-
- return results
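For a custom search space, the function above accepts a dict of ray.tune distributions in place of the defaults; a hedged sketch of invoking it directly, with the dataset and ranges chosen purely for illustration:

from ray import tune

from ultralytics import YOLO
from ultralytics.utils.tuner import run_ray_tune

model = YOLO("yolo11n.pt")
space = {"lr0": tune.uniform(1e-5, 1e-2), "momentum": tune.uniform(0.8, 0.98)}
result_grid = run_ray_tune(model, space=space, max_samples=5, data="coco8.yaml", epochs=10)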