init commit

41 ultralytics/solutions/__init__.py Normal file
@@ -0,0 +1,41 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .ai_gym import AIGym
from .analytics import Analytics
from .distance_calculation import DistanceCalculation
from .heatmap import Heatmap
from .instance_segmentation import InstanceSegmentation
from .object_blurrer import ObjectBlurrer
from .object_counter import ObjectCounter
from .object_cropper import ObjectCropper
from .parking_management import ParkingManagement, ParkingPtsSelection
from .queue_management import QueueManager
from .region_counter import RegionCounter
from .security_alarm import SecurityAlarm
from .similarity_search import SearchApp, VisualAISearch
from .speed_estimation import SpeedEstimator
from .streamlit_inference import Inference
from .trackzone import TrackZone
from .vision_eye import VisionEye

__all__ = (
    "ObjectCounter",
    "ObjectCropper",
    "ObjectBlurrer",
    "AIGym",
    "RegionCounter",
    "SecurityAlarm",
    "Heatmap",
    "InstanceSegmentation",
    "VisionEye",
    "SpeedEstimator",
    "DistanceCalculation",
    "QueueManager",
    "ParkingManagement",
    "ParkingPtsSelection",
    "Analytics",
    "Inference",
    "TrackZone",
    "SearchApp",
    "VisualAISearch",
)
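Every class exported here follows the same two-step interface: construct with keyword configuration, then call process() once per frame. A minimal sketch of that shared pattern, assuming yolo11n.pt weights and a local frame.jpg test image:

import cv2

from ultralytics.solutions import ObjectCounter

# Build a counter with a two-point (line) region; process() returns a SolutionResults object
counter = ObjectCounter(model="yolo11n.pt", region=[(0, 360), (1280, 360)], show=False)
frame = cv2.imread("frame.jpg")  # assumed local test image
results = counter.process(frame)
print(results.in_count, results.out_count)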
114 ultralytics/solutions/ai_gym.py Normal file
@@ -0,0 +1,114 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from collections import defaultdict
from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults


class AIGym(BaseSolution):
    """
    A class to manage gym steps of people in a real-time video stream based on their poses.

    This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
    repetitions of exercises based on predefined angle thresholds for up and down positions.

    Attributes:
        states (dict[int, dict[str, Any]]): Per-track state storing the current angle, repetition count, and stage.
        up_angle (float): Angle threshold for considering the 'up' position of an exercise.
        down_angle (float): Angle threshold for considering the 'down' position of an exercise.
        kpts (list[int]): Indices of keypoints used for angle calculation.

    Methods:
        process: Process a frame to detect poses, calculate angles, and count repetitions.

    Examples:
        >>> gym = AIGym(model="yolo11n-pose.pt")
        >>> image = cv2.imread("gym_scene.jpg")
        >>> results = gym.process(image)
        >>> processed_image = results.plot_im
        >>> cv2.imshow("Processed Image", processed_image)
        >>> cv2.waitKey(0)
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize AIGym for workout monitoring using pose estimation and predefined angles.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent class constructor.
                model (str): Model name or path, defaults to "yolo11n-pose.pt".
        """
        kwargs["model"] = kwargs.get("model", "yolo11n-pose.pt")
        super().__init__(**kwargs)
        self.states = defaultdict(lambda: {"angle": 0, "count": 0, "stage": "-"})  # Dict for count, angle and stage

        # Extract details from CFG a single time for later use
        self.up_angle = float(self.CFG["up_angle"])  # Predefined angle to consider the 'up' pose
        self.down_angle = float(self.CFG["down_angle"])  # Predefined angle to consider the 'down' pose
        self.kpts = self.CFG["kpts"]  # User-selected workout keypoints, stored for further use

    def process(self, im0) -> SolutionResults:
        """
        Monitor workouts using Ultralytics YOLO Pose Model.

        This function processes an input image to track and analyze human poses for workout monitoring. It uses
        the YOLO Pose model to detect keypoints, estimate angles, and count repetitions based on predefined
        angle thresholds.

        Args:
            im0 (np.ndarray): Input image for processing.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'workout_count' (list of completed reps),
                'workout_stage' (list of current stages), 'workout_angle' (list of angles), and
                'total_tracks' (total number of tracked individuals).

        Examples:
            >>> gym = AIGym()
            >>> image = cv2.imread("workout.jpg")
            >>> results = gym.process(image)
            >>> processed_image = results.plot_im
        """
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator
        self.extract_tracks(im0)  # Extract tracks (bounding boxes, classes, and masks)

        if len(self.boxes):
            kpt_data = self.tracks.keypoints.data

            for i, k in enumerate(kpt_data):
                state = self.states[self.track_ids[i]]  # Get state details
                # Get keypoints and estimate the angle
                state["angle"] = annotator.estimate_pose_angle(*[k[int(idx)] for idx in self.kpts])
                annotator.draw_specific_kpts(k, self.kpts, radius=self.line_width * 3)

                # Determine stage and count logic based on angle thresholds
                if state["angle"] < self.down_angle:
                    if state["stage"] == "up":
                        state["count"] += 1
                    state["stage"] = "down"
                elif state["angle"] > self.up_angle:
                    state["stage"] = "up"

                # Display angle, count, and stage text
                if self.show_labels:
                    annotator.plot_angle_and_count_and_stage(
                        angle_text=state["angle"],  # angle text for display
                        count_text=state["count"],  # count text for workouts
                        stage_text=state["stage"],  # stage position text
                        center_kpt=k[int(self.kpts[1])],  # center keypoint for display
                    )
        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output image, if the environment supports display

        # Return SolutionResults
        return SolutionResults(
            plot_im=plot_im,
            workout_count=[v["count"] for v in self.states.values()],
            workout_stage=[v["stage"] for v in self.states.values()],
            workout_angle=[v["angle"] for v in self.states.values()],
            total_tracks=len(self.track_ids),
        )
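A minimal video-loop sketch for the rep counter above, assuming a local workout.mp4 file; kpts=[6, 8, 10] (shoulder, elbow, and wrist in COCO keypoint order) matches the SolutionConfig default used for angle estimation:

import cv2

from ultralytics.solutions import AIGym

gym = AIGym(model="yolo11n-pose.pt", kpts=[6, 8, 10], show=False)
cap = cv2.VideoCapture("workout.mp4")  # assumed local video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = gym.process(frame)  # counts reps via the up/down angle thresholds
    print(results.workout_count, results.workout_stage)
cap.release()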
265 ultralytics/solutions/analytics.py Normal file
@@ -0,0 +1,265 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from itertools import cycle
from typing import Any

import cv2
import numpy as np

from ultralytics.solutions.solutions import BaseSolution, SolutionResults  # Import a parent class


class Analytics(BaseSolution):
    """
    A class for creating and updating various types of charts for visual analytics.

    This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
    based on object detection and tracking data.

    Attributes:
        type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
        x_label (str): Label for the x-axis.
        y_label (str): Label for the y-axis.
        bg_color (str): Background color of the chart frame.
        fg_color (str): Foreground color of the chart frame.
        title (str): Title of the chart window.
        max_points (int): Maximum number of data points to display on the chart.
        fontsize (int): Font size for text display.
        color_cycle (cycle): Cyclic iterator for chart colors.
        total_counts (int): Total count of detected objects (used for line charts).
        clswise_count (dict[str, int]): Dictionary for class-wise object counts.
        fig (Figure): Matplotlib figure object for the chart.
        ax (Axes): Matplotlib axes object for the chart.
        canvas (FigureCanvasAgg): Canvas for rendering the chart.
        lines (dict): Dictionary to store line objects for area charts.
        color_mapping (dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.

    Methods:
        process: Process image data and update the chart.
        update_graph: Update the chart with new data points.

    Examples:
        >>> analytics = Analytics(analytics_type="line")
        >>> frame = cv2.imread("image.jpg")
        >>> results = analytics.process(frame, frame_number=1)
        >>> cv2.imshow("Analytics", results.plot_im)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize Analytics class with various chart types for visual data representation."""
        super().__init__(**kwargs)

        import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
        from matplotlib.backends.backend_agg import FigureCanvasAgg
        from matplotlib.figure import Figure

        self.type = self.CFG["analytics_type"]  # type of analytics, e.g. "line", "pie", "bar" or "area" charts
        self.x_label = "Classes" if self.type in {"bar", "pie"} else "Frame#"
        self.y_label = "Total Counts"

        # Predefined data
        self.bg_color = "#F3F3F3"  # background color of frame
        self.fg_color = "#111E68"  # foreground color of frame
        self.title = "Ultralytics Solutions"  # window name
        self.max_points = 45  # maximum points to be drawn on window
        self.fontsize = 25  # text font size for display
        figsize = self.CFG["figsize"]  # set output image size, e.g. (12.8, 7.2) -> w = 1280, h = 720
        self.color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])

        self.total_counts = 0  # count variable for storing total counts, i.e. for line charts
        self.clswise_count = {}  # dictionary for class-wise counts
        self.update_every = kwargs.get("update_every", 30)  # Only update graph every 30 frames by default
        self.last_plot_im = None  # Cache of the last rendered chart

        # Ensure line and area chart
        if self.type in {"line", "area"}:
            self.lines = {}
            self.fig = Figure(facecolor=self.bg_color, figsize=figsize)
            self.canvas = FigureCanvasAgg(self.fig)  # Set common axis properties
            self.ax = self.fig.add_subplot(111, facecolor=self.bg_color)
            if self.type == "line":
                (self.line,) = self.ax.plot([], [], color="cyan", linewidth=self.line_width)
        elif self.type in {"bar", "pie"}:
            # Initialize bar or pie plot
            self.fig, self.ax = plt.subplots(figsize=figsize, facecolor=self.bg_color)
            self.canvas = FigureCanvasAgg(self.fig)  # Set common axis properties
            self.ax.set_facecolor(self.bg_color)
            self.color_mapping = {}

            if self.type == "pie":  # Ensure pie chart is circular
                self.ax.axis("equal")

    def process(self, im0: np.ndarray, frame_number: int) -> SolutionResults:
        """
        Process image data and run object tracking to update analytics charts.

        Args:
            im0 (np.ndarray): Input image for processing.
            frame_number (int): Video frame number for plotting the data.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked
                objects) and 'classwise_count' (dict, per-class object count).

        Raises:
            ValueError: If an unsupported chart type is specified.

        Examples:
            >>> analytics = Analytics(analytics_type="line")
            >>> frame = np.zeros((480, 640, 3), dtype=np.uint8)
            >>> results = analytics.process(frame, frame_number=1)
        """
        self.extract_tracks(im0)  # Extract tracks
        if self.type == "line":
            for _ in self.boxes:
                self.total_counts += 1
            update_required = frame_number % self.update_every == 0 or self.last_plot_im is None
            if update_required:
                self.last_plot_im = self.update_graph(frame_number=frame_number)
            plot_im = self.last_plot_im
            self.total_counts = 0
        elif self.type in {"pie", "bar", "area"}:
            from collections import Counter

            self.clswise_count = Counter(self.names[int(cls)] for cls in self.clss)
            update_required = frame_number % self.update_every == 0 or self.last_plot_im is None
            if update_required:
                self.last_plot_im = self.update_graph(
                    frame_number=frame_number, count_dict=self.clswise_count, plot=self.type
                )
            plot_im = self.last_plot_im
        else:
            raise ValueError(f"{self.type} chart is not supported ❌")

        # return output dictionary with summary for more usage
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), classwise_count=self.clswise_count)

    def update_graph(
        self, frame_number: int, count_dict: dict[str, int] | None = None, plot: str = "line"
    ) -> np.ndarray:
        """
        Update the graph with new data for single or multiple classes.

        Args:
            frame_number (int): The current frame number.
            count_dict (dict[str, int], optional): Dictionary with class names as keys and counts as values for
                multiple classes. If None, updates a single line graph.
            plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.

        Returns:
            (np.ndarray): Updated image containing the graph.

        Examples:
            >>> analytics = Analytics(analytics_type="bar")
            >>> frame_num = 10
            >>> results_dict = {"person": 5, "car": 3}
            >>> updated_image = analytics.update_graph(frame_num, results_dict, plot="bar")
        """
        if count_dict is None:
            # Single line update
            x_data = np.append(self.line.get_xdata(), float(frame_number))
            y_data = np.append(self.line.get_ydata(), float(self.total_counts))

            if len(x_data) > self.max_points:
                x_data, y_data = x_data[-self.max_points :], y_data[-self.max_points :]

            self.line.set_data(x_data, y_data)
            self.line.set_label("Counts")
            self.line.set_color("#7b0068")  # Pink color
            self.line.set_marker("*")
            self.line.set_markersize(self.line_width * 5)
        else:
            labels = list(count_dict.keys())
            counts = list(count_dict.values())
            if plot == "area":
                color_cycle = cycle(["#DD00BA", "#042AFF", "#FF4447", "#7D24FF", "#BD00FF"])
                # Multiple lines or area update
                x_data = self.ax.lines[0].get_xdata() if self.ax.lines else np.array([])
                y_data_dict = {key: np.array([]) for key in count_dict.keys()}
                if self.ax.lines:
                    for line, key in zip(self.ax.lines, count_dict.keys()):
                        y_data_dict[key] = line.get_ydata()

                x_data = np.append(x_data, float(frame_number))
                max_length = len(x_data)
                for key in count_dict.keys():
                    y_data_dict[key] = np.append(y_data_dict[key], float(count_dict[key]))
                    if len(y_data_dict[key]) < max_length:
                        y_data_dict[key] = np.pad(y_data_dict[key], (0, max_length - len(y_data_dict[key])))
                if len(x_data) > self.max_points:
                    x_data = x_data[1:]
                    for key in count_dict.keys():
                        y_data_dict[key] = y_data_dict[key][1:]

                self.ax.clear()
                for key, y_data in y_data_dict.items():
                    color = next(color_cycle)
                    self.ax.fill_between(x_data, y_data, color=color, alpha=0.55)
                    self.ax.plot(
                        x_data,
                        y_data,
                        color=color,
                        linewidth=self.line_width,
                        marker="o",
                        markersize=self.line_width * 5,
                        label=f"{key} Data Points",
                    )
            elif plot == "bar":
                self.ax.clear()  # clear bar data
                for label in labels:  # Map labels to colors
                    if label not in self.color_mapping:
                        self.color_mapping[label] = next(self.color_cycle)
                colors = [self.color_mapping[label] for label in labels]
                bars = self.ax.bar(labels, counts, color=colors)
                for bar, count in zip(bars, counts):
                    self.ax.text(
                        bar.get_x() + bar.get_width() / 2,
                        bar.get_height(),
                        str(count),
                        ha="center",
                        va="bottom",
                        color=self.fg_color,
                    )
                # Create the legend using labels from the bars
                for bar, label in zip(bars, labels):
                    bar.set_label(label)  # Assign label to each bar
                self.ax.legend(loc="upper left", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color)
            elif plot == "pie":
                total = sum(counts)
                percentages = [size / total * 100 for size in counts]
                self.ax.clear()

                start_angle = 90
                # Create pie chart and create legend labels with percentages
                wedges, _ = self.ax.pie(
                    counts, labels=labels, startangle=start_angle, textprops={"color": self.fg_color}, autopct=None
                )
                legend_labels = [f"{label} ({percentage:.1f}%)" for label, percentage in zip(labels, percentages)]

                # Assign the legend using the wedges and manually created labels
                self.ax.legend(wedges, legend_labels, title="Classes", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
                self.fig.subplots_adjust(left=0.1, right=0.75)  # Adjust layout to fit the legend

        # Common plot settings
        self.ax.set_facecolor("#f0f0f0")  # Set to light gray or any other color you like
        self.ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5)  # Display grid for more data insights
        self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
        self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
        self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)

        # Add and format legend
        legend = self.ax.legend(loc="upper left", fontsize=13, facecolor=self.bg_color, edgecolor=self.bg_color)
        for text in legend.get_texts():
            text.set_color(self.fg_color)

        # Redraw graph, update view, capture, and display the updated plot
        self.ax.relim()
        self.ax.autoscale_view()
        self.canvas.draw()
        im0 = np.array(self.canvas.renderer.buffer_rgba())
        im0 = cv2.cvtColor(im0[:, :, :3], cv2.COLOR_RGB2BGR)  # alpha channel already dropped by the slice
        self.display_output(im0)

        return im0  # Return the image
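The image returned by process() has the chart's dimensions rather than the input frame's, so a video writer has to be opened lazily from the first result. A sketch, assuming a local traffic.mp4; with the default figsize=(12.8, 7.2) the chart renders at 1280x720 under matplotlib's default 100 dpi:

import cv2

from ultralytics.solutions import Analytics

analytics = Analytics(analytics_type="line", figsize=(12.8, 7.2), show=False)
cap = cv2.VideoCapture("traffic.mp4")  # assumed local video
writer, frame_number = None, 0
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    frame_number += 1
    results = analytics.process(frame, frame_number)
    if writer is None:  # chart size is fixed by figsize, so open the writer from the first chart
        h, w = results.plot_im.shape[:2]
        writer = cv2.VideoWriter("analytics.avi", cv2.VideoWriter_fourcc(*"MJPG"), 30, (w, h))
    writer.write(results.plot_im)
cap.release()
if writer is not None:
    writer.release()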
108 ultralytics/solutions/config.py Normal file
@@ -0,0 +1,108 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any

import cv2


@dataclass
class SolutionConfig:
    """
    Manages configuration parameters for Ultralytics Vision AI solutions.

    The SolutionConfig class serves as a centralized configuration container for all the
    Ultralytics solution modules: https://docs.ultralytics.com/solutions/#solutions.
    It leverages Python `dataclass` for clear, type-safe, and maintainable parameter definitions.

    Attributes:
        source (str, optional): Path to the input source (video, RTSP, etc.). Only usable with the Solutions CLI.
        model (str, optional): Path to the Ultralytics YOLO model to be used for inference.
        classes (list[int], optional): List of class indices to filter detections.
        show_conf (bool): Whether to show confidence scores on the visual output.
        show_labels (bool): Whether to display class labels on visual output.
        region (list[tuple[int, int]], optional): Polygonal region or line for object counting.
        colormap (int, optional): OpenCV colormap constant for visual overlays (e.g., cv2.COLORMAP_JET).
        show_in (bool): Whether to display count numbers for objects entering the region.
        show_out (bool): Whether to display count numbers for objects leaving the region.
        up_angle (float): Upper angle threshold used in pose-based workout monitoring.
        down_angle (int): Lower angle threshold used in pose-based workout monitoring.
        kpts (list[int]): Keypoint indices to monitor, e.g., for pose analytics.
        analytics_type (str): Type of analytics to perform ("line", "area", "bar", "pie", etc.).
        figsize (tuple[float, float], optional): Size of the matplotlib figure used for analytical plots (width, height).
        blur_ratio (float): Ratio used to blur objects in the video frames (0.0 to 1.0).
        vision_point (tuple[int, int]): Reference point for directional tracking or perspective drawing.
        crop_dir (str): Directory path to save cropped detection images.
        json_file (str): Path to a JSON file containing data for parking areas.
        line_width (int): Width for visual display, e.g. bounding boxes, keypoints, counts.
        records (int): Number of detection records required to send email alerts.
        fps (float): Frame rate (Frames Per Second) for speed estimation calculation.
        max_hist (int): Maximum number of historical points or states stored per tracked object for speed estimation.
        meter_per_pixel (float): Scale for real-world measurement, used in speed or distance calculations.
        max_speed (int): Maximum speed limit (e.g., km/h or mph) used in visual alerts or constraints.
        show (bool): Whether to display the visual output on screen.
        iou (float): Intersection-over-Union threshold for detection filtering.
        conf (float): Confidence threshold for keeping predictions.
        device (str, optional): Device to run inference on (e.g., 'cpu', '0' for CUDA GPU).
        max_det (int): Maximum number of detections allowed per video frame.
        half (bool): Whether to use FP16 precision (requires a supported CUDA device).
        tracker (str): Path to the tracking configuration YAML file (e.g., 'botsort.yaml').
        verbose (bool): Enable verbose logging output for debugging or diagnostics.
        data (str): Path to the image directory used for similarity search.

    Methods:
        update: Update the configuration with user-defined keyword arguments and raise an error on invalid keys.

    Examples:
        >>> from ultralytics.solutions.config import SolutionConfig
        >>> cfg = SolutionConfig(model="yolo11n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> cfg.update(show=False, conf=0.3)
        >>> print(cfg.model)
    """

    source: str | None = None
    model: str | None = None
    classes: list[int] | None = None
    show_conf: bool = True
    show_labels: bool = True
    region: list[tuple[int, int]] | None = None
    colormap: int | None = cv2.COLORMAP_DEEPGREEN
    show_in: bool = True
    show_out: bool = True
    up_angle: float = 145.0
    down_angle: int = 90
    kpts: list[int] = field(default_factory=lambda: [6, 8, 10])
    analytics_type: str = "line"
    figsize: tuple[float, float] | None = (12.8, 7.2)
    blur_ratio: float = 0.5
    vision_point: tuple[int, int] = (20, 20)
    crop_dir: str = "cropped-detections"
    json_file: str | None = None
    line_width: int = 2
    records: int = 5
    fps: float = 30.0
    max_hist: int = 5
    meter_per_pixel: float = 0.05
    max_speed: int = 120
    show: bool = False
    iou: float = 0.7
    conf: float = 0.25
    device: str | None = None
    max_det: int = 300
    half: bool = False
    tracker: str = "botsort.yaml"
    verbose: bool = True
    data: str = "images"

    def update(self, **kwargs: Any):
        """Update configuration parameters with new values provided as keyword arguments."""
        for key, value in kwargs.items():
            if hasattr(self, key):
                setattr(self, key, value)
            else:
                url = "https://docs.ultralytics.com/solutions/#solutions-arguments"
                raise ValueError(f"{key} is not a valid solution argument, see {url}")

        return self
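A short sketch of the update() contract defined above: known keys overwrite the dataclass fields and the instance is returned for chaining, while any unknown key raises ValueError:

from ultralytics.solutions.config import SolutionConfig

cfg = SolutionConfig(model="yolo11n.pt")
cfg.update(conf=0.3, show=True)  # valid keys update the dataclass fields in place
print(cfg.conf, cfg.show)  # 0.3 True

try:
    cfg.update(confidence=0.3)  # not a dataclass field
except ValueError as e:
    print(e)  # confidence is not a valid solution argument, see https://docs.ultralytics.com/solutions/#solutions-arguments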
126 ultralytics/solutions/distance_calculation.py Normal file
@@ -0,0 +1,126 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import math
from typing import Any

import cv2

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class DistanceCalculation(BaseSolution):
    """
    A class to calculate distance between two objects in a real-time video stream based on their tracks.

    This class extends BaseSolution to provide functionality for selecting objects and calculating the distance
    between them in a video stream using YOLO object detection and tracking.

    Attributes:
        left_mouse_count (int): Counter for left mouse button clicks.
        selected_boxes (dict[int, list[float]]): Dictionary to store selected bounding boxes and their track IDs.
        centroids (list[list[int]]): List to store centroids of selected bounding boxes.

    Methods:
        mouse_event_for_distance: Handle mouse events for selecting objects in the video stream.
        process: Process video frames and calculate the distance between selected objects.

    Examples:
        >>> distance_calc = DistanceCalculation()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = distance_calc.process(frame)
        >>> cv2.imshow("Distance Calculation", results.plot_im)
        >>> cv2.waitKey(0)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the DistanceCalculation class for measuring object distances in video streams."""
        super().__init__(**kwargs)

        # Mouse event information
        self.left_mouse_count = 0
        self.selected_boxes: dict[int, list[float]] = {}
        self.centroids: list[list[int]] = []  # Store centroids of selected objects

    def mouse_event_for_distance(self, event: int, x: int, y: int, flags: int, param: Any) -> None:
        """
        Handle mouse events to select regions in a real-time video stream for distance calculation.

        Args:
            event (int): Type of mouse event (e.g., cv2.EVENT_MOUSEMOVE, cv2.EVENT_LBUTTONDOWN).
            x (int): X-coordinate of the mouse pointer.
            y (int): Y-coordinate of the mouse pointer.
            flags (int): Flags associated with the event (e.g., cv2.EVENT_FLAG_CTRLKEY, cv2.EVENT_FLAG_SHIFTKEY).
            param (Any): Additional parameters passed to the function.

        Examples:
            >>> # Assuming 'dc' is an instance of DistanceCalculation
            >>> cv2.setMouseCallback("window_name", dc.mouse_event_for_distance)
        """
        if event == cv2.EVENT_LBUTTONDOWN:
            self.left_mouse_count += 1
            if self.left_mouse_count <= 2:
                for box, track_id in zip(self.boxes, self.track_ids):
                    if box[0] < x < box[2] and box[1] < y < box[3] and track_id not in self.selected_boxes:
                        self.selected_boxes[track_id] = box

        elif event == cv2.EVENT_RBUTTONDOWN:
            self.selected_boxes = {}
            self.left_mouse_count = 0

    def process(self, im0) -> SolutionResults:
        """
        Process a video frame and calculate the distance between two selected bounding boxes.

        This method extracts tracks from the input frame, annotates bounding boxes, and calculates the distance
        between two user-selected objects if they have been chosen.

        Args:
            im0 (np.ndarray): The input image frame to process.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, `total_tracks` (int) representing the total number
                of tracked objects, and `pixels_distance` (float) representing the distance between selected objects
                in pixels.

        Examples:
            >>> import numpy as np
            >>> from ultralytics.solutions import DistanceCalculation
            >>> dc = DistanceCalculation()
            >>> frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            >>> results = dc.process(frame)
            >>> print(f"Distance: {results.pixels_distance:.2f} pixels")
        """
        self.extract_tracks(im0)  # Extract tracks
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator

        pixels_distance = 0
        # Iterate over bounding boxes, track IDs, and class indices
        for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):
            annotator.box_label(box, color=colors(int(cls), True), label=self.adjust_box_label(cls, conf, track_id))

            # Update selected boxes if they're being tracked
            if len(self.selected_boxes) == 2:
                for trk_id in self.selected_boxes.keys():
                    if trk_id == track_id:
                        self.selected_boxes[track_id] = box

        if len(self.selected_boxes) == 2:
            # Calculate centroids of selected boxes
            self.centroids.extend(
                [[int((box[0] + box[2]) // 2), int((box[1] + box[3]) // 2)] for box in self.selected_boxes.values()]
            )
            # Calculate Euclidean distance between centroids
            pixels_distance = math.sqrt(
                (self.centroids[0][0] - self.centroids[1][0]) ** 2 + (self.centroids[0][1] - self.centroids[1][1]) ** 2
            )
            annotator.plot_distance_and_line(pixels_distance, self.centroids)

        self.centroids = []  # Reset centroids for next frame
        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function
        if self.CFG.get("show") and self.env_check:
            cv2.setMouseCallback("Ultralytics Solutions", self.mouse_event_for_distance)

        # Return SolutionResults with processed image and calculated metrics
        return SolutionResults(plot_im=plot_im, pixels_distance=pixels_distance, total_tracks=len(self.track_ids))
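Since object selection runs through an OpenCV mouse callback, this solution is interactive: it needs show=True and a display-capable environment so the "Ultralytics Solutions" window exists for setMouseCallback. A usage sketch, assuming a local street.mp4:

import cv2

from ultralytics.solutions import DistanceCalculation

dc = DistanceCalculation(model="yolo11n.pt", show=True)  # show=True enables the mouse callback
cap = cv2.VideoCapture("street.mp4")  # assumed local video
while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    # Left-click two boxes to measure between them; right-click to reset the selection
    results = dc.process(frame)
    if results.pixels_distance:
        print(f"{results.pixels_distance:.1f} px")
cap.release()
cv2.destroyAllWindows()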
131 ultralytics/solutions/heatmap.py Normal file
@@ -0,0 +1,131 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from typing import Any

import cv2
import numpy as np

from ultralytics.solutions.object_counter import ObjectCounter
from ultralytics.solutions.solutions import SolutionAnnotator, SolutionResults


class Heatmap(ObjectCounter):
    """
    A class to draw heatmaps in real-time video streams based on object tracks.

    This class extends the ObjectCounter class to generate and visualize heatmaps of object movements in video
    streams. It uses tracked object positions to create a cumulative heatmap effect over time.

    Attributes:
        initialized (bool): Flag indicating whether the heatmap has been initialized.
        colormap (int): OpenCV colormap used for heatmap visualization.
        heatmap (np.ndarray): Array storing the cumulative heatmap data.
        annotator (SolutionAnnotator): Object for drawing annotations on the image.

    Methods:
        heatmap_effect: Calculate and update the heatmap effect for a given bounding box.
        process: Generate and apply the heatmap effect to each frame.

    Examples:
        >>> from ultralytics.solutions import Heatmap
        >>> heatmap = Heatmap(model="yolo11n.pt", colormap=cv2.COLORMAP_JET)
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_frame = heatmap.process(frame)
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the Heatmap class for real-time video stream heatmap generation based on object tracks.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent ObjectCounter class.
        """
        super().__init__(**kwargs)

        self.initialized = False  # Flag for heatmap initialization
        if self.region is not None:  # Check if user provided the region coordinates
            self.initialize_region()

        # Store colormap
        self.colormap = self.CFG["colormap"]
        self.heatmap = None

    def heatmap_effect(self, box: list[float]) -> None:
        """
        Efficiently calculate heatmap area and effect location for applying colormap.

        Args:
            box (list[float]): Bounding box coordinates [x0, y0, x1, y1].
        """
        x0, y0, x1, y1 = map(int, box)
        radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2

        # Create a meshgrid with region of interest (ROI) for vectorized distance calculations
        xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))

        # Calculate squared distances from the center
        dist_squared = (xv - ((x0 + x1) // 2)) ** 2 + (yv - ((y0 + y1) // 2)) ** 2

        # Create a mask of points within the radius
        within_radius = dist_squared <= radius_squared

        # Update only the values within the bounding box in a single vectorized operation
        self.heatmap[y0:y1, x0:x1][within_radius] += 2

    def process(self, im0: np.ndarray) -> SolutionResults:
        """
        Generate heatmap for each frame using Ultralytics tracking.

        Args:
            im0 (np.ndarray): Input image array for processing.

        Returns:
            (SolutionResults): Contains processed image `plot_im`,
                'in_count' (int, count of objects entering the region),
                'out_count' (int, count of objects exiting the region),
                'classwise_count' (dict, per-class object count), and
                'total_tracks' (int, total number of tracked objects).
        """
        if not self.initialized:
            self.heatmap = np.zeros_like(im0, dtype=np.float32)
            self.initialized = True  # Initialize heatmap only once

        self.extract_tracks(im0)  # Extract tracks
        self.annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator

        # Iterate over bounding boxes, track IDs, and class indices
        for box, track_id, cls in zip(self.boxes, self.track_ids, self.clss):
            # Apply heatmap effect for the bounding box
            self.heatmap_effect(box)

            if self.region is not None:
                self.annotator.draw_region(reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2)
                self.store_tracking_history(track_id, box)  # Store track history
                # Get previous position if available
                prev_position = None
                if len(self.track_history[track_id]) > 1:
                    prev_position = self.track_history[track_id][-2]
                self.count_objects(self.track_history[track_id][-1], track_id, prev_position, cls)  # object counting

        plot_im = self.annotator.result()
        if self.region is not None:
            self.display_counts(plot_im)  # Display the counts on the frame

        # Normalize, apply colormap to heatmap, and combine with the original image
        if self.track_data.is_track:
            normalized_heatmap = cv2.normalize(self.heatmap, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
            colored_heatmap = cv2.applyColorMap(normalized_heatmap, self.colormap)
            plot_im = cv2.addWeighted(plot_im, 0.5, colored_heatmap, 0.5, 0)

        self.display_output(plot_im)  # Display output with base class function

        # Return SolutionResults
        return SolutionResults(
            plot_im=plot_im,
            in_count=self.in_count,
            out_count=self.out_count,
            classwise_count=dict(self.classwise_count),
            total_tracks=len(self.track_ids),
        )
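The masking trick in heatmap_effect can be shown standalone: build a coordinate grid over the box, keep only the pixels inside the circle inscribed in the box, and increment those in one vectorized step. A self-contained NumPy sketch:

import numpy as np

heatmap = np.zeros((480, 640), dtype=np.float32)
x0, y0, x1, y1 = 100, 100, 200, 180  # example bounding box

radius_squared = (min(x1 - x0, y1 - y0) // 2) ** 2
xv, yv = np.meshgrid(np.arange(x0, x1), np.arange(y0, y1))  # pixel coordinates inside the box
dist_squared = (xv - (x0 + x1) // 2) ** 2 + (yv - (y0 + y1) // 2) ** 2
within_radius = dist_squared <= radius_squared  # boolean disk mask inside the box

heatmap[y0:y1, x0:x1][within_radius] += 2  # one vectorized update, no per-pixel Python loop
print(within_radius.sum(), "pixels updated")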
89 ultralytics/solutions/instance_segmentation.py Normal file
@@ -0,0 +1,89 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import Any

from ultralytics.engine.results import Results
from ultralytics.solutions.solutions import BaseSolution, SolutionResults


class InstanceSegmentation(BaseSolution):
    """
    A class to manage instance segmentation in images or video streams.

    This class extends the BaseSolution class and provides functionality for performing instance segmentation,
    including drawing segmented masks with bounding boxes and labels.

    Attributes:
        model (str): The segmentation model to use for inference.
        line_width (int): Width of the bounding box and text lines.
        names (dict[int, str]): Dictionary mapping class indices to class names.
        clss (list[int]): List of detected class indices.
        track_ids (list[int]): List of track IDs for detected instances.
        masks (list[np.ndarray]): List of segmentation masks for detected instances.
        show_conf (bool): Whether to display confidence scores.
        show_labels (bool): Whether to display class labels.
        show_boxes (bool): Whether to display bounding boxes.

    Methods:
        process: Process the input image to perform instance segmentation and annotate results.
        extract_tracks: Extract tracks including bounding boxes, classes, and masks from model predictions.

    Examples:
        >>> segmenter = InstanceSegmentation()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = segmenter.process(frame)
        >>> print(f"Total segmented instances: {results.total_tracks}")
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the InstanceSegmentation class for detecting and annotating segmented instances.

        Args:
            **kwargs (Any): Keyword arguments passed to the BaseSolution parent class.
                model (str): Model name or path, defaults to "yolo11n-seg.pt".
        """
        kwargs["model"] = kwargs.get("model", "yolo11n-seg.pt")
        super().__init__(**kwargs)

        self.show_conf = self.CFG.get("show_conf", True)
        self.show_labels = self.CFG.get("show_labels", True)
        self.show_boxes = self.CFG.get("show_boxes", True)

    def process(self, im0) -> SolutionResults:
        """
        Perform instance segmentation on the input image and annotate the results.

        Args:
            im0 (np.ndarray): The input image for segmentation.

        Returns:
            (SolutionResults): Object containing the annotated image and total number of tracked instances.

        Examples:
            >>> segmenter = InstanceSegmentation()
            >>> frame = cv2.imread("image.jpg")
            >>> summary = segmenter.process(frame)
            >>> print(summary)
        """
        self.extract_tracks(im0)  # Extract tracks (bounding boxes, classes, and masks)
        self.masks = getattr(self.tracks, "masks", None)

        # Iterate over detected classes, track IDs, and segmentation masks
        if self.masks is None:
            self.LOGGER.warning("No masks detected! Ensure you're using a supported Ultralytics segmentation model.")
            plot_im = im0
        else:
            results = Results(im0, path=None, names=self.names, boxes=self.track_data.data, masks=self.masks.data)
            plot_im = results.plot(
                line_width=self.line_width,
                boxes=self.show_boxes,
                conf=self.show_conf,
                labels=self.show_labels,
                color_mode="instance",
            )

        self.display_output(plot_im)  # Display the annotated output using the base class function

        # Return SolutionResults
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
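A minimal single-image sketch, assuming a local bus.jpg; the Results.plot() call above uses color_mode="instance" so each detected instance gets its own color:

import cv2

from ultralytics.solutions import InstanceSegmentation

segmenter = InstanceSegmentation(model="yolo11n-seg.pt", show=False)
frame = cv2.imread("bus.jpg")  # assumed local image
results = segmenter.process(frame)
print(f"Total segmented instances: {results.total_tracks}")
cv2.imwrite("segmented.jpg", results.plot_im)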
92 ultralytics/solutions/object_blurrer.py Normal file
@@ -0,0 +1,92 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import Any

import cv2

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils import LOGGER
from ultralytics.utils.plotting import colors


class ObjectBlurrer(BaseSolution):
    """
    A class to manage the blurring of detected objects in a real-time video stream.

    This class extends the BaseSolution class and provides functionality for blurring objects based on detected
    bounding boxes. The blurred areas are updated directly in the input image, allowing for privacy preservation
    or other effects.

    Attributes:
        blur_ratio (int): The intensity of the blur effect applied to detected objects (higher values create more blur).
        iou (float): Intersection over Union threshold for object detection.
        conf (float): Confidence threshold for object detection.

    Methods:
        process: Apply a blurring effect to detected objects in the input image.
        extract_tracks: Extract tracking information from detected objects.
        display_output: Display the processed output image.

    Examples:
        >>> blurrer = ObjectBlurrer()
        >>> frame = cv2.imread("frame.jpg")
        >>> processed_results = blurrer.process(frame)
        >>> print(f"Total blurred objects: {processed_results.total_tracks}")
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the ObjectBlurrer class for applying a blur effect to objects detected in video streams or images.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent class and for configuration.
                blur_ratio (float): Intensity of the blur effect (0.1-1.0, default=0.5).
        """
        super().__init__(**kwargs)
        blur_ratio = self.CFG["blur_ratio"]
        if blur_ratio < 0.1:
            LOGGER.warning("blur ratio cannot be less than 0.1, updating it to default value 0.5")
            blur_ratio = 0.5
        self.blur_ratio = int(blur_ratio * 100)

    def process(self, im0) -> SolutionResults:
        """
        Apply a blurring effect to detected objects in the input image.

        This method extracts tracking information, applies blur to regions corresponding to detected objects,
        and annotates the image with bounding boxes.

        Args:
            im0 (np.ndarray): The input image containing detected objects.

        Returns:
            (SolutionResults): Object containing the processed image and number of tracked objects.
                - plot_im (np.ndarray): The annotated output image with blurred objects.
                - total_tracks (int): The total number of tracked objects in the frame.

        Examples:
            >>> blurrer = ObjectBlurrer()
            >>> frame = cv2.imread("image.jpg")
            >>> results = blurrer.process(frame)
            >>> print(f"Blurred {results.total_tracks} objects")
        """
        self.extract_tracks(im0)  # Extract tracks
        annotator = SolutionAnnotator(im0, self.line_width)

        # Iterate over bounding boxes and classes
        for box, cls, conf in zip(self.boxes, self.clss, self.confs):
            # Crop and blur the detected object
            blur_obj = cv2.blur(
                im0[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])],
                (self.blur_ratio, self.blur_ratio),
            )
            # Update the blurred area in the original image
            im0[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])] = blur_obj
            annotator.box_label(box, label=self.adjust_box_label(cls, conf), color=colors(cls, True))  # Annotate bounding box

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display the output using the base class function

        # Return a SolutionResults
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
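The blur strength maps the 0.1-1.0 blur_ratio onto a box-filter kernel of int(blur_ratio * 100) pixels, so the default 0.5 means a 50x50 averaging kernel. A standalone sketch of that mapping applied to one region of interest:

import cv2
import numpy as np

frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)  # stand-in frame
x0, y0, x1, y1 = 200, 150, 400, 350  # example detection box

blur_ratio = 0.5
k = int(blur_ratio * 100)  # kernel size used by ObjectBlurrer, here 50x50
frame[y0:y1, x0:x1] = cv2.blur(frame[y0:y1, x0:x1], (k, k))  # blur only the detected region in place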
197 ultralytics/solutions/object_counter.py Normal file
@@ -0,0 +1,197 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from collections import defaultdict
from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class ObjectCounter(BaseSolution):
    """
    A class to manage the counting of objects in a real-time video stream based on their tracks.

    This class extends the BaseSolution class and provides functionality for counting objects moving in and out of a
    specified region in a video stream. It supports both polygonal and linear regions for counting.

    Attributes:
        in_count (int): Counter for objects moving inward.
        out_count (int): Counter for objects moving outward.
        counted_ids (list[int]): List of IDs of objects that have been counted.
        classwise_count (dict[str, dict[str, int]]): Dictionary for counts, categorized by object class.
        region_initialized (bool): Flag indicating whether the counting region has been initialized.
        show_in (bool): Flag to control display of inward count.
        show_out (bool): Flag to control display of outward count.
        margin (int): Margin for background rectangle size to display counts properly.

    Methods:
        count_objects: Count objects within a polygonal or linear region based on their tracks.
        display_counts: Display object counts on the frame.
        process: Process input data and update counts.

    Examples:
        >>> counter = ObjectCounter()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = counter.process(frame)
        >>> print(f"Inward count: {counter.in_count}, Outward count: {counter.out_count}")
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the ObjectCounter class for real-time object counting in video streams."""
        super().__init__(**kwargs)

        self.in_count = 0  # Counter for objects moving inward
        self.out_count = 0  # Counter for objects moving outward
        self.counted_ids = []  # List of IDs of objects that have been counted
        self.classwise_count = defaultdict(lambda: {"IN": 0, "OUT": 0})  # Dictionary for counts, categorized by class
        self.region_initialized = False  # Flag indicating whether the region has been initialized

        self.show_in = self.CFG["show_in"]
        self.show_out = self.CFG["show_out"]
        self.margin = self.line_width * 2  # Scales the background rectangle size to display counts properly

    def count_objects(
        self,
        current_centroid: tuple[float, float],
        track_id: int,
        prev_position: tuple[float, float] | None,
        cls: int,
    ) -> None:
        """
        Count objects within a polygonal or linear region based on their tracks.

        Args:
            current_centroid (tuple[float, float]): Current centroid coordinates (x, y) in the current frame.
            track_id (int): Unique identifier for the tracked object.
            prev_position (tuple[float, float], optional): Last frame position coordinates (x, y) of the track.
            cls (int): Class index for classwise count updates.

        Examples:
            >>> counter = ObjectCounter()
            >>> track_line = {1: [100, 200], 2: [110, 210], 3: [120, 220]}
            >>> box = [130, 230, 150, 250]
            >>> track_id_num = 1
            >>> previous_position = (120, 220)
            >>> class_to_count = 0  # In COCO model, class 0 = person
            >>> counter.count_objects((140, 240), track_id_num, previous_position, class_to_count)
        """
        if prev_position is None or track_id in self.counted_ids:
            return

        if len(self.region) == 2:  # Linear region (defined as a line segment)
            if self.r_s.intersects(self.LineString([prev_position, current_centroid])):
                # Determine orientation of the region (vertical or horizontal)
                if abs(self.region[0][0] - self.region[1][0]) < abs(self.region[0][1] - self.region[1][1]):
                    # Vertical region: Compare x-coordinates to determine direction
                    if current_centroid[0] > prev_position[0]:  # Moving right
                        self.in_count += 1
                        self.classwise_count[self.names[cls]]["IN"] += 1
                    else:  # Moving left
                        self.out_count += 1
                        self.classwise_count[self.names[cls]]["OUT"] += 1
                # Horizontal region: Compare y-coordinates to determine direction
                elif current_centroid[1] > prev_position[1]:  # Moving downward
                    self.in_count += 1
                    self.classwise_count[self.names[cls]]["IN"] += 1
                else:  # Moving upward
                    self.out_count += 1
                    self.classwise_count[self.names[cls]]["OUT"] += 1
                self.counted_ids.append(track_id)

        elif len(self.region) > 2:  # Polygonal region
            if self.r_s.contains(self.Point(current_centroid)):
                # Determine motion direction for vertical or horizontal polygons
                region_width = max(p[0] for p in self.region) - min(p[0] for p in self.region)
                region_height = max(p[1] for p in self.region) - min(p[1] for p in self.region)

                if (
                    region_width < region_height
                    and current_centroid[0] > prev_position[0]
                    or region_width >= region_height
                    and current_centroid[1] > prev_position[1]
                ):  # Moving right or downward
                    self.in_count += 1
                    self.classwise_count[self.names[cls]]["IN"] += 1
                else:  # Moving left or upward
                    self.out_count += 1
                    self.classwise_count[self.names[cls]]["OUT"] += 1
                self.counted_ids.append(track_id)

    def display_counts(self, plot_im) -> None:
        """
        Display object counts on the input image or frame.

        Args:
            plot_im (np.ndarray): The image or frame to display counts on.

        Examples:
            >>> counter = ObjectCounter()
            >>> frame = cv2.imread("image.jpg")
            >>> counter.display_counts(frame)
        """
        labels_dict = {
            str.capitalize(key): f"{'IN ' + str(value['IN']) if self.show_in else ''} "
            f"{'OUT ' + str(value['OUT']) if self.show_out else ''}".strip()
            for key, value in self.classwise_count.items()
            if (value["IN"] != 0 or value["OUT"] != 0) and (self.show_in or self.show_out)
        }
        if labels_dict:
            self.annotator.display_analytics(plot_im, labels_dict, (104, 31, 17), (255, 255, 255), self.margin)

    def process(self, im0) -> SolutionResults:
        """
        Process input data (frames or object tracks) and update object counts.

        This method initializes the counting region, extracts tracks, draws bounding boxes and regions, updates
        object counts, and displays the results on the input image.

        Args:
            im0 (np.ndarray): The input image or frame to be processed.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'in_count' (int, count of objects entering the
                region), 'out_count' (int, count of objects exiting the region), 'classwise_count' (dict, per-class
                object count), and 'total_tracks' (int, total number of tracked objects).

        Examples:
            >>> counter = ObjectCounter()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> results = counter.process(frame)
        """
        if not self.region_initialized:
            self.initialize_region()
            self.region_initialized = True

        self.extract_tracks(im0)  # Extract tracks
        self.annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator

        self.annotator.draw_region(
            reg_pts=self.region, color=(104, 0, 123), thickness=self.line_width * 2
        )  # Draw region

        # Iterate over bounding boxes, track IDs, and class indices
        for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):
            # Draw bounding box and counting region
            self.annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(cls, True))
            self.store_tracking_history(track_id, box)  # Store track history

            # Store previous position of track for object counting
            prev_position = None
            if len(self.track_history[track_id]) > 1:
                prev_position = self.track_history[track_id][-2]
            self.count_objects(self.track_history[track_id][-1], track_id, prev_position, cls)  # object counting

        plot_im = self.annotator.result()
        self.display_counts(plot_im)  # Display the counts on the frame
        self.display_output(plot_im)  # Display output with base class function

        # Return SolutionResults
        return SolutionResults(
            plot_im=plot_im,
            in_count=self.in_count,
            out_count=self.out_count,
            classwise_count=dict(self.classwise_count),
            total_tracks=len(self.track_ids),
        )
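BaseSolution is not part of this commit, but the r_s, LineString, and Point attributes used in count_objects suggest shapely geometries. Under that assumption, a standalone sketch of the line-crossing test for a horizontal counting line:

from shapely.geometry import LineString  # assumed geometry backing self.r_s / self.LineString

region = LineString([(0, 360), (1280, 360)])  # horizontal counting line
prev_position, current_centroid = (640, 350), (640, 370)  # track moved downward between frames

# The segment between the previous and current centroid crosses the counting line
if region.intersects(LineString([prev_position, current_centroid])):
    # Horizontal region: compare y-coordinates to determine direction
    direction = "IN" if current_centroid[1] > prev_position[1] else "OUT"
    print(direction)  # IN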
93
ultralytics/solutions/object_cropper.py
Normal file
93
ultralytics/solutions/object_cropper.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ultralytics.solutions.solutions import BaseSolution, SolutionResults
|
||||
from ultralytics.utils.plotting import save_one_box
|
||||
|
||||
|
||||
class ObjectCropper(BaseSolution):
|
||||
"""
|
||||
A class to manage the cropping of detected objects in a real-time video stream or images.
|
||||
|
||||
This class extends the BaseSolution class and provides functionality for cropping objects based on detected bounding
|
||||
boxes. The cropped images are saved to a specified directory for further analysis or usage.
|
||||
|
||||
Attributes:
|
||||
crop_dir (str): Directory where cropped object images are stored.
|
||||
crop_idx (int): Counter for the total number of cropped objects.
|
||||
iou (float): IoU (Intersection over Union) threshold for non-maximum suppression.
|
||||
conf (float): Confidence threshold for filtering detections.
|
||||
|
||||
Methods:
|
||||
process: Crop detected objects from the input image and save them to the output directory.
|
||||
|
||||
Examples:
|
||||
>>> cropper = ObjectCropper()
|
||||
>>> frame = cv2.imread("frame.jpg")
|
||||
>>> processed_results = cropper.process(frame)
|
||||
>>> print(f"Total cropped objects: {cropper.crop_idx}")
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""
|
||||
Initialize the ObjectCropper class for cropping objects from detected bounding boxes.
|
||||
|
||||
Args:
|
||||
**kwargs (Any): Keyword arguments passed to the parent class and used for configuration.
|
||||
crop_dir (str): Path to the directory for saving cropped object images.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.crop_dir = self.CFG["crop_dir"] # Directory for storing cropped detections
|
||||
if not os.path.exists(self.crop_dir):
|
||||
os.mkdir(self.crop_dir) # Create directory if it does not exist
|
||||
if self.CFG["show"]:
|
||||
self.LOGGER.warning(
|
||||
f"show=True disabled for crop solution, results will be saved in the directory named: {self.crop_dir}"
            )
        self.crop_idx = 0  # Initialize counter for total cropped objects
        self.iou = self.CFG["iou"]
        self.conf = self.CFG["conf"]

    def process(self, im0) -> SolutionResults:
        """
        Crop detected objects from the input image and save them as separate images.

        Args:
            im0 (np.ndarray): The input image containing detected objects.

        Returns:
            (SolutionResults): A SolutionResults object containing the total number of cropped objects and processed
                image.

        Examples:
            >>> cropper = ObjectCropper()
            >>> frame = cv2.imread("image.jpg")
            >>> results = cropper.process(frame)
            >>> print(f"Total cropped objects: {results.total_crop_objects}")
        """
        with self.profilers[0]:
            results = self.model.predict(
                im0,
                classes=self.classes,
                conf=self.conf,
                iou=self.iou,
                device=self.CFG["device"],
                verbose=False,
            )[0]
        self.clss = results.boxes.cls.tolist()  # required for logging only.

        for box in results.boxes:
            self.crop_idx += 1
            save_one_box(
                box.xyxy,
                im0,
                file=Path(self.crop_dir) / f"crop_{self.crop_idx}.jpg",
                BGR=True,
            )

        # Return SolutionResults
        return SolutionResults(plot_im=im0, total_crop_objects=self.crop_idx)
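
# Usage sketch (illustrative, not part of the module above): cropping every detection
# from a video file. "video.mp4" and the crop_dir value are placeholder inputs.
import cv2

from ultralytics import solutions

cropper = solutions.ObjectCropper(model="yolo11n.pt", crop_dir="cropped-detections")
cap = cv2.VideoCapture("video.mp4")
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    results = cropper.process(frame)  # Saves crop_1.jpg, crop_2.jpg, ... under crop_dir
print(f"Total cropped objects: {cropper.crop_idx}")
cap.release()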
278
ultralytics/solutions/parking_management.py
Normal file
@@ -0,0 +1,278 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import json
from typing import Any

import cv2
import numpy as np

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_imshow


class ParkingPtsSelection:
    """
    A class for selecting and managing parking zone points on images using a Tkinter-based UI.

    This class provides functionality to upload an image, select points to define parking zones, and save the
    selected points to a JSON file. It uses Tkinter for the graphical user interface.

    Attributes:
        tk (module): The Tkinter module for GUI operations.
        filedialog (module): Tkinter's filedialog module for file selection operations.
        messagebox (module): Tkinter's messagebox module for displaying message boxes.
        master (tk.Tk): The main Tkinter window.
        canvas (tk.Canvas): The canvas widget for displaying the image and drawing bounding boxes.
        image (PIL.Image.Image): The uploaded image.
        canvas_image (ImageTk.PhotoImage): The image displayed on the canvas.
        rg_data (list[list[tuple[int, int]]]): List of bounding boxes, each defined by 4 points.
        current_box (list[tuple[int, int]]): Temporary storage for the points of the current bounding box.
        imgw (int): Original width of the uploaded image.
        imgh (int): Original height of the uploaded image.
        canvas_max_width (int): Maximum width of the canvas.
        canvas_max_height (int): Maximum height of the canvas.

    Methods:
        initialize_properties: Initialize properties for image, canvas, bounding boxes, and dimensions.
        upload_image: Upload and display an image on the canvas, resizing it to fit within specified dimensions.
        on_canvas_click: Handle mouse clicks to add points for bounding boxes on the canvas.
        draw_box: Draw a bounding box on the canvas using the provided coordinates.
        remove_last_bounding_box: Remove the last bounding box from the list and redraw the canvas.
        redraw_canvas: Redraw the canvas with the image and all bounding boxes.
        save_to_json: Save the selected parking zone points to a JSON file with scaled coordinates.

    Examples:
        >>> parking_selector = ParkingPtsSelection()
        >>> # Use the GUI to upload an image, select parking zones, and save the data
    """

    def __init__(self) -> None:
        """Initialize the ParkingPtsSelection class, setting up UI and properties for parking zone point selection."""
        try:  # Check if tkinter is installed
            import tkinter as tk
            from tkinter import filedialog, messagebox
        except ImportError:  # Display error with recommendations
            import platform

            install_cmd = {
                "Linux": "sudo apt install python3-tk (Debian/Ubuntu) | sudo dnf install python3-tkinter (Fedora) | "
                "sudo pacman -S tk (Arch)",
                "Windows": "reinstall Python and enable the checkbox `tcl/tk and IDLE` on **Optional Features** during installation",
                "Darwin": "reinstall Python from https://www.python.org/downloads/macos/ or `brew install python-tk`",
            }.get(platform.system(), "Unknown OS. Check your Python installation.")

LOGGER.warning(f" Tkinter is not configured or supported. Potential fix: {install_cmd}")
            return

        if not check_imshow(warn=True):
            return

        self.tk, self.filedialog, self.messagebox = tk, filedialog, messagebox
        self.master = self.tk.Tk()  # Reference to the main application window
        self.master.title("Ultralytics Parking Zones Points Selector")
        self.master.resizable(False, False)

        self.canvas = self.tk.Canvas(self.master, bg="white")  # Canvas widget for displaying images
        self.canvas.pack(side=self.tk.BOTTOM)

        self.image = None  # Variable to store the loaded image
        self.canvas_image = None  # Reference to the image displayed on the canvas
        self.canvas_max_width = None  # Maximum allowed width for the canvas
        self.canvas_max_height = None  # Maximum allowed height for the canvas
        self.rg_data = None  # Data for region annotation management
        self.current_box = None  # Stores the currently selected bounding box
        self.imgh = None  # Height of the current image
        self.imgw = None  # Width of the current image

        # Button frame with buttons
        button_frame = self.tk.Frame(self.master)
        button_frame.pack(side=self.tk.TOP)

        for text, cmd in [
            ("Upload Image", self.upload_image),
            ("Remove Last BBox", self.remove_last_bounding_box),
            ("Save", self.save_to_json),
        ]:
            self.tk.Button(button_frame, text=text, command=cmd).pack(side=self.tk.LEFT)

        self.initialize_properties()
        self.master.mainloop()

    def initialize_properties(self) -> None:
        """Initialize properties for image, canvas, bounding boxes, and dimensions."""
        self.image = self.canvas_image = None
        self.rg_data, self.current_box = [], []
        self.imgw = self.imgh = 0
        self.canvas_max_width, self.canvas_max_height = 1280, 720

    def upload_image(self) -> None:
        """Upload and display an image on the canvas, resizing it to fit within specified dimensions."""
        from PIL import Image, ImageTk  # Scoped import because ImageTk requires tkinter package

        file = self.filedialog.askopenfilename(filetypes=[("Image Files", "*.png *.jpg *.jpeg")])
        if not file:
            LOGGER.info("No image selected.")
            return

        self.image = Image.open(file)
        self.imgw, self.imgh = self.image.size
        aspect_ratio = self.imgw / self.imgh
        canvas_width = (
            min(self.canvas_max_width, self.imgw) if aspect_ratio > 1 else int(self.canvas_max_height * aspect_ratio)
        )
        canvas_height = (
            min(self.canvas_max_height, self.imgh) if aspect_ratio <= 1 else int(canvas_width / aspect_ratio)
        )

        self.canvas.config(width=canvas_width, height=canvas_height)
        self.canvas_image = ImageTk.PhotoImage(self.image.resize((canvas_width, canvas_height)))
        self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
        self.canvas.bind("<Button-1>", self.on_canvas_click)

        self.rg_data.clear(), self.current_box.clear()

    def on_canvas_click(self, event) -> None:
        """Handle mouse clicks to add points for bounding boxes on the canvas."""
        self.current_box.append((event.x, event.y))
        self.canvas.create_oval(event.x - 3, event.y - 3, event.x + 3, event.y + 3, fill="red")
        if len(self.current_box) == 4:
            self.rg_data.append(self.current_box.copy())
            self.draw_box(self.current_box)
            self.current_box.clear()

    def draw_box(self, box: list[tuple[int, int]]) -> None:
        """Draw a bounding box on the canvas using the provided coordinates."""
        for i in range(4):
            self.canvas.create_line(box[i], box[(i + 1) % 4], fill="blue", width=2)

    def remove_last_bounding_box(self) -> None:
        """Remove the last bounding box from the list and redraw the canvas."""
        if not self.rg_data:
            self.messagebox.showwarning("Warning", "No bounding boxes to remove.")
            return
        self.rg_data.pop()
        self.redraw_canvas()

    def redraw_canvas(self) -> None:
        """Redraw the canvas with the image and all bounding boxes."""
        self.canvas.delete("all")
        self.canvas.create_image(0, 0, anchor=self.tk.NW, image=self.canvas_image)
        for box in self.rg_data:
            self.draw_box(box)

    def save_to_json(self) -> None:
        """Save the selected parking zone points to a JSON file with scaled coordinates."""
        scale_w, scale_h = self.imgw / self.canvas.winfo_width(), self.imgh / self.canvas.winfo_height()
        data = [{"points": [(int(x * scale_w), int(y * scale_h)) for x, y in box]} for box in self.rg_data]

        from io import StringIO  # Function level import, as it's only required to store coordinates

        write_buffer = StringIO()
        json.dump(data, write_buffer, indent=4)
        with open("bounding_boxes.json", "w", encoding="utf-8") as f:
            f.write(write_buffer.getvalue())
        self.messagebox.showinfo("Success", "Bounding boxes saved to bounding_boxes.json")
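
# Output format sketch (illustrative): save_to_json() writes one dict per selected zone to
# bounding_boxes.json, with points scaled back to original-image pixels, e.g.:
#
#     [
#         {"points": [[1019, 98], [1169, 98], [1169, 232], [1019, 232]]},
#         {"points": [[844, 91], [993, 91], [993, 228], [844, 228]]}
#     ]
#
# The coordinate values shown are placeholders; ParkingManagement below consumes this file.
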
class ParkingManagement(BaseSolution):
    """
    Manages parking occupancy and availability using YOLO model for real-time monitoring and visualization.

    This class extends BaseSolution to provide functionality for parking lot management, including detection of
    occupied spaces, visualization of parking regions, and display of occupancy statistics.

    Attributes:
        json_file (str): Path to the JSON file containing parking region details.
        json (list[dict]): Loaded JSON data containing parking region information.
        pr_info (dict[str, int]): Dictionary storing parking information (Occupancy and Available spaces).
        arc (tuple[int, int, int]): BGR color tuple for available region visualization.
        occ (tuple[int, int, int]): BGR color tuple for occupied region visualization.
        dc (tuple[int, int, int]): BGR color tuple for centroid visualization of detected objects.

    Methods:
        process: Process the input image for parking lot management and visualization.

    Examples:
        >>> from ultralytics.solutions import ParkingManagement
        >>> parking_manager = ParkingManagement(model="yolo11n.pt", json_file="parking_regions.json")
        >>> print(f"Occupied spaces: {parking_manager.pr_info['Occupancy']}")
        >>> print(f"Available spaces: {parking_manager.pr_info['Available']}")
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the parking management system with a YOLO model and visualization settings."""
        super().__init__(**kwargs)

        self.json_file = self.CFG["json_file"]  # Load parking regions JSON data
        if self.json_file is None:
LOGGER.warning("json_file argument missing. Parking region details required.")
|
||||
raise ValueError("❌ Json file path can not be empty")

        with open(self.json_file) as f:
            self.json = json.load(f)

        self.pr_info = {"Occupancy": 0, "Available": 0}  # Dictionary for parking information

        self.arc = (0, 0, 255)  # Available region color
        self.occ = (0, 255, 0)  # Occupied region color
        self.dc = (255, 0, 189)  # Centroid color for each box

    def process(self, im0: np.ndarray) -> SolutionResults:
        """
        Process the input image for parking lot management and visualization.

        This function analyzes the input image, extracts tracks, and determines the occupancy status of parking
        regions defined in the JSON file. It annotates the image with occupied and available parking spots,
        and updates the parking information.

        Args:
            im0 (np.ndarray): The input inference image.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'filled_slots' (number of occupied parking slots),
                'available_slots' (number of available parking slots), and 'total_tracks' (total number of tracked objects).

        Examples:
            >>> parking_manager = ParkingManagement(json_file="parking_regions.json")
            >>> image = cv2.imread("parking_lot.jpg")
            >>> results = parking_manager.process(image)
        """
        self.extract_tracks(im0)  # Extract tracks from im0
        es, fs = len(self.json), 0  # Empty slots, filled slots
        annotator = SolutionAnnotator(im0, self.line_width)  # Initialize annotator

        for region in self.json:
            # Convert points to a NumPy array with the correct dtype and reshape properly
            pts_array = np.array(region["points"], dtype=np.int32).reshape((-1, 1, 2))
            rg_occupied = False  # Occupied region initialization
            for box, cls in zip(self.boxes, self.clss):
                xc, yc = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
                dist = cv2.pointPolygonTest(pts_array, (xc, yc), False)
                if dist >= 0:
                    # cv2.circle(im0, (xc, yc), radius=self.line_width * 4, color=self.dc, thickness=-1)
                    annotator.display_objects_labels(
                        im0, self.model.names[int(cls)], (104, 31, 17), (255, 255, 255), xc, yc, 10
                    )
                    rg_occupied = True
                    break
            fs, es = (fs + 1, es - 1) if rg_occupied else (fs, es)
            # Plot regions
            cv2.polylines(im0, [pts_array], isClosed=True, color=self.occ if rg_occupied else self.arc, thickness=2)

        self.pr_info["Occupancy"], self.pr_info["Available"] = fs, es

        annotator.display_analytics(im0, self.pr_info, (104, 31, 17), (255, 255, 255), 10)

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function

        # Return SolutionResults
        return SolutionResults(
            plot_im=plot_im,
            filled_slots=self.pr_info["Occupancy"],
            available_slots=self.pr_info["Available"],
            total_tracks=len(self.track_ids),
        )
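
# Usage sketch (illustrative, not part of the module above): monitoring a parking lot video
# with regions produced by ParkingPtsSelection. File names here are placeholders.
import cv2

from ultralytics import solutions

parking_manager = solutions.ParkingManagement(model="yolo11n.pt", json_file="bounding_boxes.json")
cap = cv2.VideoCapture("parking_lot.mp4")
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    results = parking_manager.process(frame)
    print(f"Occupied: {results.filled_slots}, Available: {results.available_slots}")
cap.release()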
95
ultralytics/solutions/queue_management.py
Normal file
@@ -0,0 +1,95 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class QueueManager(BaseSolution):
    """
    Manages queue counting in real-time video streams based on object tracks.

    This class extends BaseSolution to provide functionality for tracking and counting objects within a specified
    region in video frames.

    Attributes:
        counts (int): The current count of objects in the queue.
        rect_color (tuple[int, int, int]): RGB color tuple for drawing the queue region rectangle.
        region_length (int): The number of points defining the queue region.
        track_line (list[tuple[int, int]]): List of track line coordinates.
        track_history (dict[int, list[tuple[int, int]]]): Dictionary storing tracking history for each object.

    Methods:
        initialize_region: Initialize the queue region.
        process: Process a single frame for queue management.
        extract_tracks: Extract object tracks from the current frame.
        store_tracking_history: Store the tracking history for an object.
        display_output: Display the processed output.

    Examples:
        >>> cap = cv2.VideoCapture("path/to/video.mp4")
        >>> queue_manager = QueueManager(region=[(100, 100), (200, 200), (300, 300)])
        >>> while cap.isOpened():
        >>>     success, im0 = cap.read()
        >>>     if not success:
        >>>         break
        >>>     results = queue_manager.process(im0)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the QueueManager with parameters for tracking and counting objects in a video stream."""
        super().__init__(**kwargs)
        self.initialize_region()
        self.counts = 0  # Queue counts information
        self.rect_color = (255, 255, 255)  # Rectangle color for visualization
        self.region_length = len(self.region)  # Store region length for further usage

    def process(self, im0) -> SolutionResults:
        """
        Process queue management for a single frame of video.

        Args:
            im0 (np.ndarray): Input image for processing, typically a frame from a video stream.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'queue_count' (int, number of objects in the queue) and
                'total_tracks' (int, total number of tracked objects).

        Examples:
            >>> queue_manager = QueueManager()
            >>> frame = cv2.imread("frame.jpg")
            >>> results = queue_manager.process(frame)
        """
        self.counts = 0  # Reset counts every frame
        self.extract_tracks(im0)  # Extract tracks from the current frame
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator
        annotator.draw_region(reg_pts=self.region, color=self.rect_color, thickness=self.line_width * 2)  # Draw region

        for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):
            # Draw bounding box and counting region
            annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))
            self.store_tracking_history(track_id, box)  # Store track history

            # Cache frequently accessed attributes
            track_history = self.track_history.get(track_id, [])

            # Store previous position of track and check if the object is inside the counting region
            prev_position = None
            if len(track_history) > 1:
                prev_position = track_history[-2]
            if self.region_length >= 3 and prev_position and self.r_s.contains(self.Point(self.track_line[-1])):
                self.counts += 1

        # Display queue counts
        annotator.queue_counts_display(
            f"Queue Counts : {self.counts}",
            points=self.region,
            region_color=self.rect_color,
            txt_color=(104, 31, 17),
        )
        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function

        # Return a SolutionResults object with processed data
        return SolutionResults(plot_im=plot_im, queue_count=self.counts, total_tracks=len(self.track_ids))
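
# Usage sketch (illustrative, not part of the module above): counting people waiting inside
# a polygonal queue zone. Counting requires a region of at least 3 points (see process()).
# Coordinates and file names are placeholders.
import cv2

from ultralytics import solutions

queue_region = [(100, 100), (500, 100), (500, 400), (100, 400)]
queue_manager = solutions.QueueManager(model="yolo11n.pt", region=queue_region, classes=[0])
cap = cv2.VideoCapture("queue.mp4")
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    results = queue_manager.process(frame)
    print(f"Objects in queue: {results.queue_count}")
cap.release()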
136
ultralytics/solutions/region_counter.py
Normal file
@@ -0,0 +1,136 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from typing import Any

import numpy as np

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class RegionCounter(BaseSolution):
    """
    A class for real-time counting of objects within user-defined regions in a video stream.

    This class inherits from `BaseSolution` and provides functionality to define polygonal regions in a video frame,
    track objects, and count those objects that pass through each defined region. Useful for applications requiring
    counting in specified areas, such as monitoring zones or segmented sections.

    Attributes:
        region_template (dict): Template for creating new counting regions with default attributes including name,
            polygon coordinates, and display colors.
        counting_regions (list): List storing all defined regions, where each entry is based on `region_template`
            and includes specific region settings like name, coordinates, and color.
        region_counts (dict): Dictionary storing the count of objects for each named region.

    Methods:
        add_region: Add a new counting region with specified attributes.
        process: Process video frames to count objects in each region.
        initialize_regions: Initialize zones to count the objects in each one. Zones could be multiple as well.

    Examples:
        Initialize a RegionCounter and add a counting region
        >>> counter = RegionCounter()
        >>> counter.add_region("Zone1", [(100, 100), (200, 100), (200, 200), (100, 200)], (255, 0, 0), (255, 255, 255))
        >>> results = counter.process(frame)
        >>> print(f"Total tracks: {results.total_tracks}")
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the RegionCounter for real-time object counting in user-defined regions."""
        super().__init__(**kwargs)
        self.region_template = {
            "name": "Default Region",
            "polygon": None,
            "counts": 0,
            "region_color": (255, 255, 255),
            "text_color": (0, 0, 0),
        }
        self.region_counts = {}
        self.counting_regions = []
        self.initialize_regions()

    def add_region(
        self,
        name: str,
        polygon_points: list[tuple],
        region_color: tuple[int, int, int],
        text_color: tuple[int, int, int],
    ) -> dict[str, Any]:
        """
        Add a new region to the counting list based on the provided template with specific attributes.

        Args:
            name (str): Name assigned to the new region.
            polygon_points (list[tuple]): List of (x, y) coordinates defining the region's polygon.
            region_color (tuple[int, int, int]): BGR color for region visualization.
            text_color (tuple[int, int, int]): BGR color for the text within the region.

        Returns:
            (dict[str, Any]): Dictionary including the region information, i.e. name, region_color, etc.
        """
        region = self.region_template.copy()
        region.update(
            {
                "name": name,
                "polygon": self.Polygon(polygon_points),
                "region_color": region_color,
                "text_color": text_color,
            }
        )
        self.counting_regions.append(region)
        return region

    def initialize_regions(self):
        """Initialize regions only once."""
        if self.region is None:
            self.initialize_region()
        if not isinstance(self.region, dict):  # Ensure self.region is initialized and structured as a dictionary
            self.region = {"Region#01": self.region}
        for i, (name, pts) in enumerate(self.region.items()):
            region = self.add_region(name, pts, colors(i, True), (255, 255, 255))
            region["prepared_polygon"] = self.prep(region["polygon"])

    def process(self, im0: np.ndarray) -> SolutionResults:
        """
        Process the input frame to detect and count objects within each defined region.

        Args:
            im0 (np.ndarray): Input image frame where objects and regions are annotated.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (int, total number of tracked objects),
                and 'region_counts' (dict, counts of objects per region).
        """
        self.extract_tracks(im0)
        annotator = SolutionAnnotator(im0, line_width=self.line_width)

        for box, cls, track_id, conf in zip(self.boxes, self.clss, self.track_ids, self.confs):
            annotator.box_label(box, label=self.adjust_box_label(cls, conf, track_id), color=colors(track_id, True))
            center = self.Point(((box[0] + box[2]) / 2, (box[1] + box[3]) / 2))
            for region in self.counting_regions:
                if region["prepared_polygon"].contains(center):
                    region["counts"] += 1
                    self.region_counts[region["name"]] = region["counts"]

        # Display region counts
        for region in self.counting_regions:
            poly = region["polygon"]
            pts = list(map(tuple, np.array(poly.exterior.coords, dtype=np.int32)))
            (x1, y1), (x2, y2) = [(int(poly.centroid.x), int(poly.centroid.y))] * 2
            annotator.draw_region(pts, region["region_color"], self.line_width * 2)
            annotator.adaptive_label(
                [x1, y1, x2, y2],
                label=str(region["counts"]),
                color=region["region_color"],
                txt_color=region["text_color"],
                margin=self.line_width * 4,
                shape="rect",
            )
            region["counts"] = 0  # Reset for next frame
        plot_im = annotator.result()
        self.display_output(plot_im)

        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), region_counts=self.region_counts)
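
# Usage sketch (illustrative, not part of the module above): passing multiple named zones as
# a dict, which initialize_regions() consumes directly. Names and coordinates are placeholders.
import cv2

from ultralytics import solutions

zones = {
    "Entrance": [(50, 50), (300, 50), (300, 300), (50, 300)],
    "Checkout": [(400, 50), (700, 50), (700, 300), (400, 300)],
}
counter = solutions.RegionCounter(model="yolo11n.pt", region=zones)
frame = cv2.imread("store.jpg")
results = counter.process(frame)
print(results.region_counts)  # e.g. {"Entrance": 3, "Checkout": 1}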
156
ultralytics/solutions/security_alarm.py
Normal file
@@ -0,0 +1,156 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils import LOGGER
from ultralytics.utils.plotting import colors


class SecurityAlarm(BaseSolution):
    """
    A class to manage security alarm functionalities for real-time monitoring.

    This class extends the BaseSolution class and provides features to monitor objects in a frame, send email
    notifications when specific thresholds are exceeded for total detections, and annotate the output frame for
    visualization.

    Attributes:
        email_sent (bool): Flag to track if an email has already been sent for the current event.
        records (int): Threshold for the number of detected objects to trigger an alert.
        server (smtplib.SMTP): SMTP server connection for sending email alerts.
        to_email (str): Recipient's email address for alerts.
        from_email (str): Sender's email address for alerts.

    Methods:
        authenticate: Set up email server authentication for sending alerts.
        send_email: Send an email notification with details and an image attachment.
        process: Monitor the frame, process detections, and trigger alerts if thresholds are crossed.

    Examples:
        >>> security = SecurityAlarm()
        >>> security.authenticate("abc@gmail.com", "1111222233334444", "xyz@gmail.com")
        >>> frame = cv2.imread("frame.jpg")
        >>> results = security.process(frame)
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the SecurityAlarm class with parameters for real-time object monitoring.

        Args:
            **kwargs (Any): Additional keyword arguments passed to the parent class.
        """
        super().__init__(**kwargs)
        self.email_sent = False
        self.records = self.CFG["records"]
        self.server = None
        self.to_email = ""
        self.from_email = ""

    def authenticate(self, from_email: str, password: str, to_email: str) -> None:
        """
        Authenticate the email server for sending alert notifications.

        Args:
            from_email (str): Sender's email address.
            password (str): Password for the sender's email account.
            to_email (str): Recipient's email address.

        This method initializes a secure connection with the SMTP server and logs in using the provided credentials.

        Examples:
            >>> alarm = SecurityAlarm()
            >>> alarm.authenticate("sender@example.com", "password123", "recipient@example.com")
        """
        import smtplib

        self.server = smtplib.SMTP("smtp.gmail.com", 587)
        self.server.starttls()
        self.server.login(from_email, password)
        self.to_email = to_email
        self.from_email = from_email

    def send_email(self, im0, records: int = 5) -> None:
        """
        Send an email notification with an image attachment indicating the number of objects detected.

        Args:
            im0 (np.ndarray): The input image or frame to be attached to the email.
            records (int, optional): The number of detected objects to be included in the email message.

        This method encodes the input image, composes the email message with details about the detection, and sends it
        to the specified recipient.

        Examples:
            >>> alarm = SecurityAlarm()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> alarm.send_email(frame, records=10)
        """
        from email.mime.image import MIMEImage
        from email.mime.multipart import MIMEMultipart
        from email.mime.text import MIMEText

        import cv2

        img_bytes = cv2.imencode(".jpg", im0)[1].tobytes()  # Encode the image as JPEG

        # Create the email
        message = MIMEMultipart()
        message["From"] = self.from_email
        message["To"] = self.to_email
        message["Subject"] = "Security Alert"

        # Add the text message body
        message_body = f"Ultralytics ALERT!!! {records} objects have been detected!!"
        message.attach(MIMEText(message_body))

        # Attach the image
        image_attachment = MIMEImage(img_bytes, name="ultralytics.jpg")
        message.attach(image_attachment)

        # Send the email
        try:
            self.server.send_message(message)
            LOGGER.info("Email sent successfully!")
        except Exception as e:
            LOGGER.error(f"Failed to send email: {e}")

    def process(self, im0) -> SolutionResults:
        """
        Monitor the frame, process object detections, and trigger alerts if thresholds are exceeded.

        Args:
            im0 (np.ndarray): The input image or frame to be processed and annotated.

        Returns:
            (SolutionResults): Contains processed image `plot_im`, 'total_tracks' (total number of tracked objects) and
                'email_sent' (whether an email alert was triggered).

        This method processes the input frame, extracts detections, annotates the frame with bounding boxes, and sends
        an email notification if the number of detected objects surpasses the specified threshold and an alert has not
        already been sent.

        Examples:
            >>> alarm = SecurityAlarm()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> results = alarm.process(frame)
        """
        self.extract_tracks(im0)  # Extract tracks
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator

        # Iterate over bounding boxes and classes index
        for box, cls in zip(self.boxes, self.clss):
            # Draw bounding box
            annotator.box_label(box, label=self.names[cls], color=colors(cls, True))

        total_det = len(self.clss)
        if total_det >= self.records and not self.email_sent:  # Only send email if not sent before
            self.send_email(im0, total_det)
            self.email_sent = True

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function

        # Return a SolutionResults
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), email_sent=self.email_sent)
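
# Usage sketch (illustrative, not part of the module above): alert when 5+ objects appear.
# Credentials and file names are placeholders; Gmail typically requires an app password.
import cv2

from ultralytics import solutions

alarm = solutions.SecurityAlarm(model="yolo11n.pt", records=5)
alarm.authenticate("sender@example.com", "app-password", "recipient@example.com")
cap = cv2.VideoCapture("cctv.mp4")
while cap.isOpened():
    success, frame = cap.read()
    if not success:
        break
    results = alarm.process(frame)  # Emails once, then email_sent stays True
cap.release()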
224
ultralytics/solutions/similarity_search.py
Normal file
@@ -0,0 +1,224 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image

from ultralytics.data.utils import IMG_FORMATS
from ultralytics.utils import LOGGER, TORCH_VERSION
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.torch_utils import TORCH_2_4, select_device

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # Avoid OpenMP conflict on some systems


class VisualAISearch:
    """
    A semantic image search system that leverages OpenCLIP for generating high-quality image and text embeddings and
    FAISS for fast similarity-based retrieval.

    This class aligns image and text embeddings in a shared semantic space, enabling users to search large collections
    of images using natural language queries with high accuracy and speed.

    Attributes:
        data (str): Directory containing images.
        device (str): Computation device, e.g., 'cpu' or 'cuda'.
        faiss_index (str): Path to the FAISS index file.
        data_path_npy (str): Path to the numpy file storing image paths.
        data_dir (Path): Path object for the data directory.
        model: Loaded CLIP model.
        index: FAISS index for similarity search.
        image_paths (list[str]): List of image file paths.

    Methods:
        extract_image_feature: Extract CLIP embedding from an image.
        extract_text_feature: Extract CLIP embedding from text.
        load_or_build_index: Load existing FAISS index or build new one.
        search: Perform semantic search for similar images.

    Examples:
        Initialize and search for images
        >>> searcher = VisualAISearch(data="path/to/images", device="cuda")
        >>> results = searcher.search("a cat sitting on a chair", k=10)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the VisualAISearch class with FAISS index and CLIP model."""
        assert TORCH_2_4, f"VisualAISearch requires torch>=2.4 (found torch=={TORCH_VERSION})"
        from ultralytics.nn.text_model import build_text_model

        check_requirements("faiss-cpu")

        self.faiss = __import__("faiss")
        self.faiss_index = "faiss.index"
        self.data_path_npy = "paths.npy"
        self.data_dir = Path(kwargs.get("data", "images"))
        self.device = select_device(kwargs.get("device", "cpu"))

        if not self.data_dir.exists():
            from ultralytics.utils import ASSETS_URL

            LOGGER.warning(f"{self.data_dir} not found. Downloading images.zip from {ASSETS_URL}/images.zip")
            from ultralytics.utils.downloads import safe_download

            safe_download(url=f"{ASSETS_URL}/images.zip", unzip=True, retry=3)
            self.data_dir = Path("images")

        self.model = build_text_model("clip:ViT-B/32", device=self.device)

        self.index = None
        self.image_paths = []

        self.load_or_build_index()

    def extract_image_feature(self, path: Path) -> np.ndarray:
        """Extract CLIP image embedding from the given image path."""
        return self.model.encode_image(Image.open(path)).cpu().numpy()

    def extract_text_feature(self, text: str) -> np.ndarray:
        """Extract CLIP text embedding from the given text query."""
        return self.model.encode_text(self.model.tokenize([text])).cpu().numpy()

    def load_or_build_index(self) -> None:
        """
        Load existing FAISS index or build a new one from image features.

        Checks if FAISS index and image paths exist on disk. If found, loads them directly. Otherwise, builds a new
        index by extracting features from all images in the data directory, normalizes the features, and saves both the
        index and image paths for future use.
        """
        # Check if the FAISS index and corresponding image paths already exist
        if Path(self.faiss_index).exists() and Path(self.data_path_npy).exists():
            LOGGER.info("Loading existing FAISS index...")
            self.index = self.faiss.read_index(self.faiss_index)  # Load the FAISS index from disk
            self.image_paths = np.load(self.data_path_npy)  # Load the saved image path list
            return  # Exit the function as the index is successfully loaded

        # If the index doesn't exist, start building it from scratch
        LOGGER.info("Building FAISS index from images...")
        vectors = []  # List to store feature vectors of images

        # Iterate over all image files in the data directory
        for file in self.data_dir.iterdir():
            # Skip files that are not valid image formats
            if file.suffix.lower().lstrip(".") not in IMG_FORMATS:
                continue
            try:
                # Extract feature vector for the image and add to the list
                vectors.append(self.extract_image_feature(file))
                self.image_paths.append(file.name)  # Store the corresponding image name
            except Exception as e:
                LOGGER.warning(f"Skipping {file.name}: {e}")

        # If no vectors were successfully created, raise an error
        if not vectors:
            raise RuntimeError("No image embeddings could be generated.")

        vectors = np.vstack(vectors).astype("float32")  # Stack all vectors into a NumPy array and convert to float32
        self.faiss.normalize_L2(vectors)  # Normalize vectors to unit length for cosine similarity

        self.index = self.faiss.IndexFlatIP(vectors.shape[1])  # Create a new FAISS index using inner product
        self.index.add(vectors)  # Add the normalized vectors to the FAISS index
        self.faiss.write_index(self.index, self.faiss_index)  # Save the newly built FAISS index to disk
        np.save(self.data_path_npy, np.array(self.image_paths))  # Save the list of image paths to disk

        LOGGER.info(f"Indexed {len(self.image_paths)} images.")

    def search(self, query: str, k: int = 30, similarity_thresh: float = 0.1) -> list[str]:
        """
        Return top-k semantically similar images to the given query.

        Args:
            query (str): Natural language text query to search for.
            k (int, optional): Maximum number of results to return.
            similarity_thresh (float, optional): Minimum similarity threshold for filtering results.

        Returns:
            (list[str]): List of image filenames ranked by similarity score.

        Examples:
            Search for images matching a query
            >>> searcher = VisualAISearch(data="images")
            >>> results = searcher.search("red car", k=5, similarity_thresh=0.2)
        """
        text_feat = self.extract_text_feature(query).astype("float32")
        self.faiss.normalize_L2(text_feat)

        D, index = self.index.search(text_feat, k)
        results = [
            (self.image_paths[i], float(D[0][idx])) for idx, i in enumerate(index[0]) if D[0][idx] >= similarity_thresh
        ]
        results.sort(key=lambda x: x[1], reverse=True)

        LOGGER.info("\nRanked Results:")
        for name, score in results:
            LOGGER.info(f" - {name} | Similarity: {score:.4f}")

        return [r[0] for r in results]

    def __call__(self, query: str) -> list[str]:
        """Direct call interface for the search function."""
        return self.search(query)


class SearchApp:
    """
    A Flask-based web interface for semantic image search with natural language queries.

    This class provides a clean, responsive frontend that enables users to input natural language queries and
    instantly view the most relevant images retrieved from the indexed database.

    Attributes:
        render_template: Flask template rendering function.
        request: Flask request object.
        searcher (VisualAISearch): Instance of the VisualAISearch class.
        app (Flask): Flask application instance.

    Methods:
        index: Process user queries and display search results.
        run: Start the Flask web application.

    Examples:
        Start a search application
        >>> app = SearchApp(data="path/to/images", device="cuda")
        >>> app.run(debug=True)
    """

    def __init__(self, data: str = "images", device: str | None = None) -> None:
        """
        Initialize the SearchApp with VisualAISearch backend.

        Args:
            data (str, optional): Path to directory containing images to index and search.
            device (str, optional): Device to run inference on (e.g. 'cpu', 'cuda').
        """
        check_requirements("flask>=3.0.1")
        from flask import Flask, render_template, request

        self.render_template = render_template
        self.request = request
        self.searcher = VisualAISearch(data=data, device=device)
        self.app = Flask(
            __name__,
            template_folder="templates",
            static_folder=Path(data).resolve(),  # Absolute path to serve images
            static_url_path="/images",  # URL prefix for images
        )
        self.app.add_url_rule("/", view_func=self.index, methods=["GET", "POST"])

    def index(self) -> str:
        """Process user query and display search results in the web interface."""
        results = []
        if self.request.method == "POST":
            query = self.request.form.get("query", "").strip()
            results = self.searcher(query)
        return self.render_template("similarity-search.html", results=results)

    def run(self, debug: bool = False) -> None:
        """Start the Flask web application server."""
        self.app.run(debug=debug)
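
# Usage sketch (illustrative, not part of the module above): launching the search UI.
# Assumes a "templates/similarity-search.html" template is available next to the app,
# since index() renders that file; the images directory is a placeholder.
from ultralytics import solutions

app = solutions.SearchApp(data="images", device="cpu")
app.run(debug=True)  # Serves the query form at http://127.0.0.1:5000/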
827
ultralytics/solutions/solutions.py
Normal file
@@ -0,0 +1,827 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import math
from collections import Counter, defaultdict
from functools import lru_cache
from typing import Any

import cv2
import numpy as np

from ultralytics import YOLO
from ultralytics.solutions.config import SolutionConfig
from ultralytics.utils import ASSETS_URL, LOGGER, ops
from ultralytics.utils.checks import check_imshow, check_requirements
from ultralytics.utils.plotting import Annotator


class BaseSolution:
    """
    A base class for managing Ultralytics Solutions.

    This class provides core functionality for various Ultralytics Solutions, including model loading, object tracking,
    and region initialization. It serves as the foundation for implementing specific computer vision solutions such as
    object counting, pose estimation, and analytics.

    Attributes:
        LineString: Class for creating line string geometries from shapely.
        Polygon: Class for creating polygon geometries from shapely.
        Point: Class for creating point geometries from shapely.
        prep: Prepared geometry function from shapely for optimized spatial operations.
        CFG (dict[str, Any]): Configuration dictionary loaded from YAML file and updated with kwargs.
        LOGGER: Logger instance for solution-specific logging.
        annotator: Annotator instance for drawing on images.
        tracks: YOLO tracking results from the latest inference.
        track_data: Extracted tracking data (boxes or OBB) from tracks.
        boxes (list): Bounding box coordinates from tracking results.
        clss (list[int]): Class indices from tracking results.
        track_ids (list[int]): Track IDs from tracking results.
        confs (list[float]): Confidence scores from tracking results.
        track_line: Current track line for storing tracking history.
        masks: Segmentation masks from tracking results.
        r_s: Region or line geometry object for spatial operations.
        frame_no (int): Current frame number for logging purposes.
        region (list[tuple[int, int]]): List of coordinate tuples defining region of interest.
        line_width (int): Width of lines used in visualizations.
        model (YOLO): Loaded YOLO model instance.
        names (dict[int, str]): Dictionary mapping class indices to class names.
        classes (list[int]): List of class indices to track.
        show_conf (bool): Flag to show confidence scores in annotations.
        show_labels (bool): Flag to show class labels in annotations.
        device (str): Device for model inference.
        track_add_args (dict[str, Any]): Additional arguments for tracking configuration.
        env_check (bool): Flag indicating whether environment supports image display.
        track_history (defaultdict): Dictionary storing tracking history for each object.
        profilers (tuple): Profiler instances for performance monitoring.

    Methods:
        adjust_box_label: Generate formatted label for bounding box.
        extract_tracks: Apply object tracking and extract tracks from input image.
        store_tracking_history: Store object tracking history for given track ID and bounding box.
        initialize_region: Initialize counting region and line segment based on configuration.
        display_output: Display processing results including frames or saved results.
        process: Process method to be implemented by each Solution subclass.

    Examples:
        >>> solution = BaseSolution(model="yolo11n.pt", region=[(0, 0), (100, 0), (100, 100), (0, 100)])
        >>> solution.initialize_region()
        >>> image = cv2.imread("image.jpg")
        >>> solution.extract_tracks(image)
        >>> solution.display_output(image)
    """

    def __init__(self, is_cli: bool = False, **kwargs: Any) -> None:
        """
        Initialize the BaseSolution class with configuration settings and YOLO model.

        Args:
            is_cli (bool): Enable CLI mode if set to True.
            **kwargs (Any): Additional configuration parameters that override defaults.
        """
        self.CFG = vars(SolutionConfig().update(**kwargs))
        self.LOGGER = LOGGER  # Store logger object to be used in multiple solution classes

        check_requirements("shapely>=2.0.0")
        from shapely.geometry import LineString, Point, Polygon
        from shapely.prepared import prep

        self.LineString = LineString
        self.Polygon = Polygon
        self.Point = Point
        self.prep = prep
        self.annotator = None  # Initialize annotator
        self.tracks = None
        self.track_data = None
        self.boxes = []
        self.clss = []
        self.track_ids = []
        self.track_line = None
        self.masks = None
        self.r_s = None
        self.frame_no = -1  # Only for logging

        self.LOGGER.info(f"Ultralytics Solutions: ✅ {self.CFG}")
        self.region = self.CFG["region"]  # Store region data for other classes usage
        self.line_width = self.CFG["line_width"]

        # Load Model and store additional information (classes, show_conf, show_label)
        if self.CFG["model"] is None:
            self.CFG["model"] = "yolo11n.pt"
        self.model = YOLO(self.CFG["model"])
        self.names = self.model.names
        self.classes = self.CFG["classes"]
        self.show_conf = self.CFG["show_conf"]
        self.show_labels = self.CFG["show_labels"]
        self.device = self.CFG["device"]

        self.track_add_args = {  # Tracker additional arguments for advance configuration
            k: self.CFG[k] for k in {"iou", "conf", "device", "max_det", "half", "tracker"}
        }  # verbose must be passed to track method; setting it False in YOLO still logs the track information.

        if is_cli and self.CFG["source"] is None:
            d_s = "solutions_ci_demo.mp4" if "-pose" not in self.CFG["model"] else "solution_ci_pose_demo.mp4"
            self.LOGGER.warning(f"source not provided. using default source {ASSETS_URL}/{d_s}")
            from ultralytics.utils.downloads import safe_download

            safe_download(f"{ASSETS_URL}/{d_s}")  # download source from ultralytics assets
            self.CFG["source"] = d_s  # set default source

        # Initialize environment and region setup
        self.env_check = check_imshow(warn=True)
        self.track_history = defaultdict(list)

        self.profilers = (
            ops.Profile(device=self.device),  # track
            ops.Profile(device=self.device),  # solution
        )

    def adjust_box_label(self, cls: int, conf: float, track_id: int | None = None) -> str | None:
        """
        Generate a formatted label for a bounding box.

        This method constructs a label string for a bounding box using the class index and confidence score.
        Optionally includes the track ID if provided. The label format adapts based on the display settings
        defined in `self.show_conf` and `self.show_labels`.

        Args:
            cls (int): The class index of the detected object.
            conf (float): The confidence score of the detection.
            track_id (int, optional): The unique identifier for the tracked object.

        Returns:
            (str | None): The formatted label string if `self.show_labels` is True; otherwise, None.
        """
        name = ("" if track_id is None else f"{track_id} ") + self.names[cls]
        return (f"{name} {conf:.2f}" if self.show_conf else name) if self.show_labels else None
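
    # Label formats produced above (illustrative; assumes COCO-style names where 0 -> "person"):
    #   adjust_box_label(0, 0.87)              -> "person 0.87"
    #   adjust_box_label(0, 0.87, track_id=3)  -> "3 person 0.87"
    #   with show_conf=False -> "3 person"; with show_labels=False -> None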

    def extract_tracks(self, im0: np.ndarray) -> None:
        """
        Apply object tracking and extract tracks from an input image or frame.

        Args:
            im0 (np.ndarray): The input image or frame.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.extract_tracks(frame)
        """
        with self.profilers[0]:
            self.tracks = self.model.track(
                source=im0, persist=True, classes=self.classes, verbose=False, **self.track_add_args
            )[0]
        is_obb = self.tracks.obb is not None
        self.track_data = self.tracks.obb if is_obb else self.tracks.boxes  # Extract tracks for OBB or object detection

        if self.track_data and self.track_data.is_track:
            self.boxes = (self.track_data.xyxyxyxy if is_obb else self.track_data.xyxy).cpu()
            self.clss = self.track_data.cls.cpu().tolist()
            self.track_ids = self.track_data.id.int().cpu().tolist()
            self.confs = self.track_data.conf.cpu().tolist()
        else:
            self.LOGGER.warning("no tracks found!")
            self.boxes, self.clss, self.track_ids, self.confs = [], [], [], []

    def store_tracking_history(self, track_id: int, box) -> None:
        """
        Store the tracking history of an object.

        This method updates the tracking history for a given object by appending the center point of its
        bounding box to the track line. It maintains a maximum of 30 points in the tracking history.

        Args:
            track_id (int): The unique identifier for the tracked object.
            box (list[float]): The bounding box coordinates of the object in the format [x1, y1, x2, y2].

        Examples:
            >>> solution = BaseSolution()
            >>> solution.store_tracking_history(1, [100, 200, 300, 400])
        """
        # Store tracking history
        self.track_line = self.track_history[track_id]
        self.track_line.append(tuple(box.mean(dim=0)) if box.numel() > 4 else (box[:4:2].mean(), box[1:4:2].mean()))
        if len(self.track_line) > 30:
            self.track_line.pop(0)

    def initialize_region(self) -> None:
        """Initialize the counting region and line segment based on configuration settings."""
        if self.region is None:
            self.region = [(10, 200), (540, 200), (540, 180), (10, 180)]
        self.r_s = (
            self.Polygon(self.region) if len(self.region) >= 3 else self.LineString(self.region)
        )  # region or line

    def display_output(self, plot_im: np.ndarray) -> None:
        """
        Display the results of the processing, which could involve showing frames, printing counts, or saving results.

        This method is responsible for visualizing the output of the object detection and tracking process. It displays
        the processed frame with annotations, and allows for user interaction to close the display.

        Args:
            plot_im (np.ndarray): The image or frame that has been processed and annotated.

        Examples:
            >>> solution = BaseSolution()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> solution.display_output(frame)

        Notes:
            - This method will only display output if the 'show' configuration is set to True and the environment
              supports image display.
            - The display can be closed by pressing the 'q' key.
        """
        if self.CFG.get("show") and self.env_check:
            cv2.imshow("Ultralytics Solutions", plot_im)
            if cv2.waitKey(1) & 0xFF == ord("q"):
                cv2.destroyAllWindows()  # Closes current frame window
                return

    def process(self, *args: Any, **kwargs: Any):
        """Process method should be implemented by each Solution subclass."""

    def __call__(self, *args: Any, **kwargs: Any):
        """Allow instances to be called like a function with flexible arguments."""
        with self.profilers[1]:
            result = self.process(*args, **kwargs)  # Call the subclass-specific process method
        track_or_predict = "predict" if type(self).__name__ == "ObjectCropper" else "track"
        track_or_predict_speed = self.profilers[0].dt * 1e3
        solution_speed = (self.profilers[1].dt - self.profilers[0].dt) * 1e3  # solution time = process - track
        result.speed = {track_or_predict: track_or_predict_speed, "solution": solution_speed}
        if self.CFG["verbose"]:
            self.frame_no += 1
            counts = Counter(self.clss)  # Only for logging.
            LOGGER.info(
                f"{self.frame_no}: {result.plot_im.shape[0]}x{result.plot_im.shape[1]} {solution_speed:.1f}ms,"
                f" {', '.join([f'{v} {self.names[k]}' for k, v in counts.items()])}\n"
                f"Speed: {track_or_predict_speed:.1f}ms {track_or_predict}, "
                f"{solution_speed:.1f}ms solution per image at shape "
                f"(1, {getattr(self.model, 'ch', 3)}, {result.plot_im.shape[0]}, {result.plot_im.shape[1]})\n"
            )
        return result
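
# Subclassing sketch (illustrative): custom solutions override process() and reuse the
# BaseSolution helpers; __call__ above then adds profiling and optional per-frame logging.
#
#     class MyCounter(BaseSolution):  # Hypothetical subclass name
#         def process(self, im0):
#             self.extract_tracks(im0)  # Fills self.boxes, self.clss, self.track_ids
#             return SolutionResults(plot_im=im0, total_tracks=len(self.track_ids))
#
#     counter = MyCounter(model="yolo11n.pt")
#     results = counter(frame)  # Invokes process() via __call__ and records speed stats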
|
||||
|
||||
class SolutionAnnotator(Annotator):
|
||||
"""
|
||||
A specialized annotator class for visualizing and analyzing computer vision tasks.
|
||||
|
||||
This class extends the base Annotator class, providing additional methods for drawing regions, centroids, tracking
|
||||
trails, and visual annotations for Ultralytics Solutions. It offers comprehensive visualization capabilities for
|
||||
various computer vision applications including object detection, tracking, pose estimation, and analytics.
|
||||
|
||||
Attributes:
|
||||
im (np.ndarray): The image being annotated.
|
||||
line_width (int): Thickness of lines used in annotations.
|
||||
font_size (int): Size of the font used for text annotations.
|
||||
font (str): Path to the font file used for text rendering.
|
||||
pil (bool): Whether to use PIL for text rendering.
|
||||
example (str): An example attribute for demonstration purposes.
|
||||
|
||||
Methods:
|
||||
draw_region: Draw a region using specified points, colors, and thickness.
|
||||
queue_counts_display: Display queue counts in the specified region.
|
||||
display_analytics: Display overall statistics for parking lot management.
|
||||
estimate_pose_angle: Calculate the angle between three points in an object pose.
|
||||
draw_specific_kpts: Draw specific keypoints on the image.
|
||||
plot_workout_information: Draw a labeled text box on the image.
|
||||
plot_angle_and_count_and_stage: Visualize angle, step count, and stage for workout monitoring.
|
||||
plot_distance_and_line: Display the distance between centroids and connect them with a line.
|
||||
display_objects_labels: Annotate bounding boxes with object class labels.
|
||||
sweep_annotator: Visualize a vertical sweep line and optional label.
|
||||
visioneye: Map and connect object centroids to a visual "eye" point.
|
||||
adaptive_label: Draw a circular or rectangle background shape label in center of a bounding box.
|
||||
|
||||
Examples:
|
||||
>>> annotator = SolutionAnnotator(image)
|
||||
>>> annotator.draw_region([(0, 0), (100, 100)], color=(0, 255, 0), thickness=5)
|
||||
>>> annotator.display_analytics(
|
||||
... image, text={"Available Spots": 5}, txt_color=(0, 0, 0), bg_color=(255, 255, 255), margin=10
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
im: np.ndarray,
|
||||
line_width: int | None = None,
|
||||
font_size: int | None = None,
|
||||
font: str = "Arial.ttf",
|
||||
pil: bool = False,
|
||||
example: str = "abc",
|
||||
):
|
||||
"""
|
||||
Initialize the SolutionAnnotator class with an image for annotation.
|
||||
|
||||
Args:
|
||||
im (np.ndarray): The image to be annotated.
|
||||
line_width (int, optional): Line thickness for drawing on the image.
|
||||
font_size (int, optional): Font size for text annotations.
|
||||
font (str): Path to the font file.
|
||||
pil (bool): Indicates whether to use PIL for rendering text.
|
||||
example (str): An example parameter for demonstration purposes.
|
||||
"""
|
||||
super().__init__(im, line_width, font_size, font, pil, example)
|
||||
|
||||
def draw_region(
|
||||
self,
|
||||
reg_pts: list[tuple[int, int]] | None = None,
|
||||
color: tuple[int, int, int] = (0, 255, 0),
|
||||
thickness: int = 5,
|
||||
):
|
||||
"""
|
||||
Draw a region or line on the image.
|
||||
|
||||
Args:
|
||||
reg_pts (list[tuple[int, int]], optional): Region points (for line 2 points, for region 4+ points).
|
||||
color (tuple[int, int, int]): RGB color value for the region.
|
||||
thickness (int): Line thickness for drawing the region.
|
||||
"""
|
||||
cv2.polylines(self.im, [np.array(reg_pts, dtype=np.int32)], isClosed=True, color=color, thickness=thickness)
|
||||
|
||||
# Draw small circles at the corner points
|
||||
for point in reg_pts:
|
||||
cv2.circle(self.im, (point[0], point[1]), thickness * 2, color, -1) # -1 fills the circle
|
||||
|
||||
def queue_counts_display(
|
||||
self,
|
||||
label: str,
|
||||
points: list[tuple[int, int]] | None = None,
|
||||
region_color: tuple[int, int, int] = (255, 255, 255),
|
||||
txt_color: tuple[int, int, int] = (0, 0, 0),
|
||||
):
|
||||
"""
|
||||
Display queue counts on an image centered at the points with customizable font size and colors.
|
||||
|
||||
Args:
|
||||
label (str): Queue counts label.
|
||||
points (list[tuple[int, int]], optional): Region points for center point calculation to display text.
|
||||
region_color (tuple[int, int, int]): RGB queue region color.
|
||||
txt_color (tuple[int, int, int]): RGB text display color.
|
||||
"""
|
||||
x_values = [point[0] for point in points]
|
||||
y_values = [point[1] for point in points]
|
||||
center_x = sum(x_values) // len(points)
|
||||
center_y = sum(y_values) // len(points)
|
||||
|
||||
text_size = cv2.getTextSize(label, 0, fontScale=self.sf, thickness=self.tf)[0]
|
||||
text_width = text_size[0]
|
||||
text_height = text_size[1]
|
||||
|
||||
rect_width = text_width + 20
|
||||
rect_height = text_height + 20
|
||||
rect_top_left = (center_x - rect_width // 2, center_y - rect_height // 2)
|
||||
rect_bottom_right = (center_x + rect_width // 2, center_y + rect_height // 2)
|
||||
cv2.rectangle(self.im, rect_top_left, rect_bottom_right, region_color, -1)
|
||||
|
||||
text_x = center_x - text_width // 2
|
||||
text_y = center_y + text_height // 2
|
||||
|
||||
# Draw text
|
||||
cv2.putText(
|
||||
self.im,
|
||||
label,
|
||||
(text_x, text_y),
|
||||
0,
|
||||
fontScale=self.sf,
|
||||
color=txt_color,
|
||||
thickness=self.tf,
|
||||
lineType=cv2.LINE_AA,
|
||||
)
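
A minimal usage sketch for the two region helpers above; the region points, label text, and output path are illustrative:

import cv2
import numpy as np

from ultralytics.solutions.solutions import SolutionAnnotator

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in frame; any BGR image works
queue_region = [(100, 100), (540, 100), (540, 380), (100, 380)]  # 4+ points define a closed region

annotator = SolutionAnnotator(frame, line_width=2)
annotator.draw_region(reg_pts=queue_region, color=(0, 255, 0), thickness=5)  # outline the queue area
annotator.queue_counts_display("Queue Counts : 7", points=queue_region)  # label centered in the region
cv2.imwrite("queue_region.jpg", annotator.result())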

    def display_analytics(
        self,
        im0: np.ndarray,
        text: dict[str, Any],
        txt_color: tuple[int, int, int],
        bg_color: tuple[int, int, int],
        margin: int,
    ):
        """
        Display the overall statistics for parking lots, object counters, etc.

        Args:
            im0 (np.ndarray): Inference image.
            text (dict[str, Any]): Labels dictionary.
            txt_color (tuple[int, int, int]): Display color for text foreground.
            bg_color (tuple[int, int, int]): Display color for text background.
            margin (int): Gap between text and rectangle for better display.
        """
        horizontal_gap = int(im0.shape[1] * 0.02)
        vertical_gap = int(im0.shape[0] * 0.01)
        text_y_offset = 0
        for label, value in text.items():
            txt = f"{label}: {value}"
            text_size = cv2.getTextSize(txt, 0, self.sf, self.tf)[0]
            if text_size[0] < 5 or text_size[1] < 5:
                text_size = (5, 5)
            text_x = im0.shape[1] - text_size[0] - margin * 2 - horizontal_gap
            text_y = text_y_offset + text_size[1] + margin * 2 + vertical_gap
            rect_x1 = text_x - margin * 2
            rect_y1 = text_y - text_size[1] - margin * 2
            rect_x2 = text_x + text_size[0] + margin * 2
            rect_y2 = text_y + margin * 2
            cv2.rectangle(im0, (rect_x1, rect_y1), (rect_x2, rect_y2), bg_color, -1)
            cv2.putText(im0, txt, (text_x, text_y), 0, self.sf, txt_color, self.tf, lineType=cv2.LINE_AA)
            text_y_offset = rect_y2

    @staticmethod
    @lru_cache(maxsize=256)
    def estimate_pose_angle(a: list[float], b: list[float], c: list[float]) -> float:
        """
        Calculate the angle between three points for workout monitoring.

        Args:
            a (list[float]): The coordinates of the first point.
            b (list[float]): The coordinates of the second point (vertex).
            c (list[float]): The coordinates of the third point.

        Returns:
            (float): The angle in degrees between the three points.
        """
        radians = math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])
        angle = abs(radians * 180.0 / math.pi)
        return angle if angle <= 180.0 else (360 - angle)
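
As a worked example of the vertex-angle formula above: the two atan2 terms measure the directions of the vertex-to-third-point and vertex-to-first-point rays, and their difference, folded into [0, 180], is the joint angle. The shoulder/elbow/wrist coordinates below are illustrative; passing tuples rather than lists also keeps the lru_cache happy, since cached arguments must be hashable:

from ultralytics.solutions.solutions import SolutionAnnotator

shoulder, elbow, wrist = (300, 100), (300, 200), (400, 200)  # illustrative pixel coordinates
angle = SolutionAnnotator.estimate_pose_angle(shoulder, elbow, wrist)
print(f"{angle:.1f}")  # 90.0 -> forearm perpendicular to the upper arm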

    def draw_specific_kpts(
        self,
        keypoints: list[list[float]],
        indices: list[int] | None = None,
        radius: int = 2,
        conf_thresh: float = 0.25,
    ) -> np.ndarray:
        """
        Draw specific keypoints for gym step counting.

        Args:
            keypoints (list[list[float]]): Keypoints data to be plotted, each in format [x, y, confidence].
            indices (list[int], optional): Keypoint indices to be plotted.
            radius (int): Keypoint radius.
            conf_thresh (float): Confidence threshold for keypoints.

        Returns:
            (np.ndarray): Image with drawn keypoints.

        Notes:
            Keypoint format: [x, y] or [x, y, confidence].
            Modifies self.im in-place.
        """
        indices = indices or [2, 5, 7]
        points = [(int(k[0]), int(k[1])) for i, k in enumerate(keypoints) if i in indices and k[2] >= conf_thresh]

        # Draw lines between consecutive points
        for start, end in zip(points[:-1], points[1:]):
            cv2.line(self.im, start, end, (0, 255, 0), 2, lineType=cv2.LINE_AA)

        # Draw circles for keypoints
        for pt in points:
            cv2.circle(self.im, pt, radius, (0, 0, 255), -1, lineType=cv2.LINE_AA)

        return self.im

    def plot_workout_information(
        self,
        display_text: str,
        position: tuple[int, int],
        color: tuple[int, int, int] = (104, 31, 17),
        txt_color: tuple[int, int, int] = (255, 255, 255),
    ) -> int:
        """
        Draw workout text with a background on the image.

        Args:
            display_text (str): The text to be displayed.
            position (tuple[int, int]): Coordinates (x, y) on the image where the text will be placed.
            color (tuple[int, int, int]): Text background color.
            txt_color (tuple[int, int, int]): Text foreground color.

        Returns:
            (int): The height of the text.
        """
        (text_width, text_height), _ = cv2.getTextSize(display_text, 0, fontScale=self.sf, thickness=self.tf)

        # Draw background rectangle
        cv2.rectangle(
            self.im,
            (position[0], position[1] - text_height - 5),
            (position[0] + text_width + 10, position[1] - text_height - 5 + text_height + 10 + self.tf),
            color,
            -1,
        )
        # Draw text
        cv2.putText(self.im, display_text, position, 0, self.sf, txt_color, self.tf)

        return text_height

    def plot_angle_and_count_and_stage(
        self,
        angle_text: str,
        count_text: str,
        stage_text: str,
        center_kpt: list[int],
        color: tuple[int, int, int] = (104, 31, 17),
        txt_color: tuple[int, int, int] = (255, 255, 255),
    ):
        """
        Plot the pose angle, count value, and step stage for workout monitoring.

        Args:
            angle_text (str): Angle value for workout monitoring.
            count_text (str): Counts value for workout monitoring.
            stage_text (str): Stage decision for workout monitoring.
            center_kpt (list[int]): Centroid pose index for workout monitoring.
            color (tuple[int, int, int]): Text background color.
            txt_color (tuple[int, int, int]): Text foreground color.
        """
        # Format text
        angle_text, count_text, stage_text = f" {angle_text:.2f}", f"Steps : {count_text}", f" {stage_text}"

        # Draw angle, count and stage text
        angle_height = self.plot_workout_information(
            angle_text, (int(center_kpt[0]), int(center_kpt[1])), color, txt_color
        )
        count_height = self.plot_workout_information(
            count_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + 20), color, txt_color
        )
        self.plot_workout_information(
            stage_text, (int(center_kpt[0]), int(center_kpt[1]) + angle_height + count_height + 40), color, txt_color
        )

    def plot_distance_and_line(
        self,
        pixels_distance: float,
        centroids: list[tuple[int, int]],
        line_color: tuple[int, int, int] = (104, 31, 17),
        centroid_color: tuple[int, int, int] = (255, 0, 255),
    ):
        """
        Plot the distance and line between two centroids on the frame.

        Args:
            pixels_distance (float): Pixel distance between two bbox centroids.
            centroids (list[tuple[int, int]]): Bounding box centroids data.
            line_color (tuple[int, int, int]): Distance line color.
            centroid_color (tuple[int, int, int]): Bounding box centroid color.
        """
        # Get the text size
        text = f"Pixels Distance: {pixels_distance:.2f}"
        (text_width_m, text_height_m), _ = cv2.getTextSize(text, 0, self.sf, self.tf)

        # Define corners with 10-pixel margin and draw rectangle
        cv2.rectangle(self.im, (15, 25), (15 + text_width_m + 20, 25 + text_height_m + 20), line_color, -1)

        # Calculate the position for the text with a 10-pixel margin and draw text
        text_position = (25, 25 + text_height_m + 10)
        cv2.putText(
            self.im,
            text,
            text_position,
            0,
            self.sf,
            (255, 255, 255),
            self.tf,
            cv2.LINE_AA,
        )

        cv2.line(self.im, centroids[0], centroids[1], line_color, 3)
        cv2.circle(self.im, centroids[0], 6, centroid_color, -1)
        cv2.circle(self.im, centroids[1], 6, centroid_color, -1)

    def display_objects_labels(
        self,
        im0: np.ndarray,
        text: str,
        txt_color: tuple[int, int, int],
        bg_color: tuple[int, int, int],
        x_center: float,
        y_center: float,
        margin: int,
    ):
        """
        Display the bounding box labels in the parking management app.

        Args:
            im0 (np.ndarray): Inference image.
            text (str): Object/class name.
            txt_color (tuple[int, int, int]): Display color for text foreground.
            bg_color (tuple[int, int, int]): Display color for text background.
            x_center (float): The x position center point for bounding box.
            y_center (float): The y position center point for bounding box.
            margin (int): The gap between text and rectangle for better display.
        """
        text_size = cv2.getTextSize(text, 0, fontScale=self.sf, thickness=self.tf)[0]
        text_x = x_center - text_size[0] // 2
        text_y = y_center + text_size[1] // 2

        rect_x1 = text_x - margin
        rect_y1 = text_y - text_size[1] - margin
        rect_x2 = text_x + text_size[0] + margin
        rect_y2 = text_y + margin
        cv2.rectangle(
            im0,
            (int(rect_x1), int(rect_y1)),
            (int(rect_x2), int(rect_y2)),
            tuple(map(int, bg_color)),  # Ensure color values are int
            -1,
        )

        cv2.putText(
            im0,
            text,
            (int(text_x), int(text_y)),
            0,
            self.sf,
            tuple(map(int, txt_color)),  # Ensure color values are int
            self.tf,
            lineType=cv2.LINE_AA,
        )

    def sweep_annotator(
        self,
        line_x: int = 0,
        line_y: int = 0,
        label: str | None = None,
        color: tuple[int, int, int] = (221, 0, 186),
        txt_color: tuple[int, int, int] = (255, 255, 255),
    ):
        """
        Draw a sweep annotation line and an optional label.

        Args:
            line_x (int): The x-coordinate of the sweep line.
            line_y (int): The y-coordinate limit of the sweep line.
            label (str, optional): Text label to be drawn at the center of the sweep line. If None, no label is drawn.
            color (tuple[int, int, int]): RGB color for the line and label background.
            txt_color (tuple[int, int, int]): RGB color for the label text.
        """
        # Draw the sweep line
        cv2.line(self.im, (line_x, 0), (line_x, line_y), color, self.tf * 2)

        # Draw label, if provided
        if label:
            (text_width, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf, self.tf)
            cv2.rectangle(
                self.im,
                (line_x - text_width // 2 - 10, line_y // 2 - text_height // 2 - 10),
                (line_x + text_width // 2 + 10, line_y // 2 + text_height // 2 + 10),
                color,
                -1,
            )
            cv2.putText(
                self.im,
                label,
                (line_x - text_width // 2, line_y // 2 + text_height // 2),
                cv2.FONT_HERSHEY_SIMPLEX,
                self.sf,
                txt_color,
                self.tf,
            )

    def visioneye(
        self,
        box: list[float],
        center_point: tuple[int, int],
        color: tuple[int, int, int] = (235, 219, 11),
        pin_color: tuple[int, int, int] = (255, 0, 255),
    ):
        """
        Perform pinpoint human-vision eye mapping and plotting.

        Args:
            box (list[float]): Bounding box coordinates in format [x1, y1, x2, y2].
            center_point (tuple[int, int]): Center point for vision eye view.
            color (tuple[int, int, int]): Object centroid and line color.
            pin_color (tuple[int, int, int]): Visioneye point color.
        """
        center_bbox = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)
        cv2.circle(self.im, center_point, self.tf * 2, pin_color, -1)
        cv2.circle(self.im, center_bbox, self.tf * 2, color, -1)
        cv2.line(self.im, center_point, center_bbox, color, self.tf)

    def adaptive_label(
        self,
        box: tuple[float, float, float, float],
        label: str = "",
        color: tuple[int, int, int] = (128, 128, 128),
        txt_color: tuple[int, int, int] = (255, 255, 255),
        shape: str = "rect",
        margin: int = 5,
    ):
        """
        Draw a label with a background rectangle or circle centered within a given bounding box.

        Args:
            box (tuple[float, float, float, float]): The bounding box coordinates (x1, y1, x2, y2).
            label (str): The text label to be displayed.
            color (tuple[int, int, int]): The background color of the shape (B, G, R).
            txt_color (tuple[int, int, int]): The color of the text (B, G, R).
            shape (str): The shape of the label, i.e. "circle" or "rect".
            margin (int): The margin between the text and the rectangle border.
        """
        if shape == "circle" and len(label) > 3:
            LOGGER.warning(f"Length of label is {len(label)}, only first 3 letters will be used for circle annotation.")
            label = label[:3]

        x_center, y_center = int((box[0] + box[2]) / 2), int((box[1] + box[3]) / 2)  # Calculate center of the bbox
        text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, self.sf - 0.15, self.tf)[0]  # Get size of the text
        text_x, text_y = x_center - text_size[0] // 2, y_center + text_size[1] // 2  # Calculate text origin

        if shape == "circle":
            cv2.circle(
                self.im,
                (x_center, y_center),
                int(((text_size[0] ** 2 + text_size[1] ** 2) ** 0.5) / 2) + margin,  # Calculate the radius
                color,
                -1,
            )
        else:
            cv2.rectangle(
                self.im,
                (text_x - margin, text_y - text_size[1] - margin),  # Top-left corner of the rectangle
                (text_x + text_size[0] + margin, text_y + margin),  # Bottom-right corner of the rectangle
                color,
                -1,
            )

        # Draw the text on top of the background shape
        cv2.putText(
            self.im,
            label,
            (text_x, text_y),  # Bottom-left corner of the text
            cv2.FONT_HERSHEY_SIMPLEX,
            self.sf - 0.15,
            self.get_txt_color(color, txt_color),
            self.tf,
            lineType=cv2.LINE_AA,
        )
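
A short sketch of the two adaptive_label modes; the canvas, boxes, and labels are illustrative:

import numpy as np

from ultralytics.solutions.solutions import SolutionAnnotator

frame = np.full((300, 640, 3), 255, dtype=np.uint8)  # blank white canvas
annotator = SolutionAnnotator(frame, line_width=2)
annotator.adaptive_label((50, 60, 220, 240), label="car", shape="circle")  # circle labels keep at most 3 chars
annotator.adaptive_label((300, 60, 560, 240), label="person", shape="rect")  # rectangle fits the full label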


class SolutionResults:
    """
    A class to encapsulate the results of Ultralytics Solutions.

    This class is designed to store and manage various outputs generated by the solution pipeline, including counts,
    angles, workout stages, and other analytics data. It provides a structured way to access and manipulate results
    from different computer vision solutions such as object counting, pose estimation, and tracking analytics.

    Attributes:
        plot_im (np.ndarray): Processed image with counts, blurred, or other effects from solutions.
        in_count (int): The total number of "in" counts in a video stream.
        out_count (int): The total number of "out" counts in a video stream.
        classwise_count (dict[str, int]): A dictionary containing counts of objects categorized by class.
        queue_count (int): The count of objects in a queue or waiting area.
        workout_count (int): The count of workout repetitions.
        workout_angle (float): The angle calculated during a workout exercise.
        workout_stage (str): The current stage of the workout.
        pixels_distance (float): The calculated distance in pixels between two points or objects.
        available_slots (int): The number of available slots in a monitored area.
        filled_slots (int): The number of filled slots in a monitored area.
        email_sent (bool): A flag indicating whether an email notification was sent.
        total_tracks (int): The total number of tracked objects.
        region_counts (dict[str, int]): The count of objects within a specific region.
        speed_dict (dict[str, float]): A dictionary containing speed information for tracked objects.
        total_crop_objects (int): Total number of cropped objects using ObjectCropper class.
        speed (dict[str, float]): Performance timing information for tracking and solution processing.
    """

    def __init__(self, **kwargs):
        """
        Initialize a SolutionResults object with default or user-specified values.

        Args:
            **kwargs (Any): Optional arguments to override default attribute values.
        """
        self.plot_im = None
        self.in_count = 0
        self.out_count = 0
        self.classwise_count = {}
        self.queue_count = 0
        self.workout_count = 0
        self.workout_angle = 0.0
        self.workout_stage = None
        self.pixels_distance = 0.0
        self.available_slots = 0
        self.filled_slots = 0
        self.email_sent = False
        self.total_tracks = 0
        self.region_counts = {}
        self.speed_dict = {}  # for speed estimation
        self.total_crop_objects = 0
        self.speed = {}

        # Override with user-defined values
        self.__dict__.update(kwargs)

    def __str__(self) -> str:
        """
        Return a formatted string representation of the SolutionResults object.

        Returns:
            (str): A string representation listing non-null attributes.
        """
        attrs = {
            k: v
            for k, v in self.__dict__.items()
            if k != "plot_im" and v not in [None, {}, 0, 0.0, False]  # Exclude `plot_im` explicitly
        }
        return ", ".join(f"{k}={v}" for k, v in attrs.items())
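
Because __init__ simply overlays keyword arguments on the defaults, any solution can return a results object carrying only the fields it populates, and __str__ hides everything left at a default value. A quick illustration with made-up counts:

from ultralytics.solutions.solutions import SolutionResults

results = SolutionResults(in_count=12, out_count=7, classwise_count={"car": 15, "bus": 4})
print(results)  # in_count=12, out_count=7, classwise_count={'car': 15, 'bus': 4}
print(results.queue_count)  # 0 -> untouched defaults remain accessible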
117
ultralytics/solutions/speed_estimation.py
Normal file
@@ -0,0 +1,117 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from collections import deque
from math import sqrt
from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class SpeedEstimator(BaseSolution):
    """
    A class to estimate the speed of objects in a real-time video stream based on their tracks.

    This class extends the BaseSolution class and provides functionality for estimating object speeds using
    tracking data in video streams. Speed is calculated based on pixel displacement over time and converted
    to real-world units using a configurable meters-per-pixel scale factor.

    Attributes:
        fps (float): Video frame rate for time calculations.
        frame_count (int): Global frame counter for tracking temporal information.
        trk_frame_ids (dict): Maps track IDs to their first frame index.
        spd (dict): Final speed per object in km/h once locked.
        trk_hist (dict): Maps track IDs to a deque of position history.
        locked_ids (set): Track IDs whose speed has been finalized.
        max_hist (int): Required frame history before computing speed.
        meter_per_pixel (float): Real-world meters represented by one pixel for scene scale conversion.
        max_speed (int): Maximum allowed object speed; values above this will be capped.

    Methods:
        process: Process input frames to estimate object speeds based on tracking data.
        store_tracking_history: Store the tracking history for an object.
        extract_tracks: Extract tracks from the current frame.
        display_output: Display the output with annotations.

    Examples:
        Initialize speed estimator and process a frame
        >>> estimator = SpeedEstimator(meter_per_pixel=0.04, max_speed=120)
        >>> frame = cv2.imread("frame.jpg")
        >>> results = estimator.process(frame)
        >>> cv2.imshow("Speed Estimation", results.plot_im)
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the SpeedEstimator object with speed estimation parameters and data structures.

        Args:
            **kwargs (Any): Additional keyword arguments passed to the parent class.
        """
        super().__init__(**kwargs)

        self.fps = self.CFG["fps"]  # Video frame rate for time calculations
        self.frame_count = 0  # Global frame counter
        self.trk_frame_ids = {}  # Track ID → first frame index
        self.spd = {}  # Final speed per object (km/h), once locked
        self.trk_hist = {}  # Track ID → deque of (time, position)
        self.locked_ids = set()  # Track IDs whose speed has been finalized
        self.max_hist = self.CFG["max_hist"]  # Required frame history before computing speed
        self.meter_per_pixel = self.CFG["meter_per_pixel"]  # Scene scale, depends on camera details
        self.max_speed = self.CFG["max_speed"]  # Maximum speed adjustment

    def process(self, im0) -> SolutionResults:
        """
        Process an input frame to estimate object speeds based on tracking data.

        Args:
            im0 (np.ndarray): Input image for processing with shape (H, W, C) for RGB images.

        Returns:
            (SolutionResults): Contains processed image `plot_im` and `total_tracks` (number of tracked objects).

        Examples:
            Process a frame for speed estimation
            >>> estimator = SpeedEstimator()
            >>> image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
            >>> results = estimator.process(image)
        """
        self.frame_count += 1
        self.extract_tracks(im0)
        annotator = SolutionAnnotator(im0, line_width=self.line_width)

        for box, track_id, _, _ in zip(self.boxes, self.track_ids, self.clss, self.confs):
            self.store_tracking_history(track_id, box)

            if track_id not in self.trk_hist:  # Initialize history if new track found
                self.trk_hist[track_id] = deque(maxlen=self.max_hist)
                self.trk_frame_ids[track_id] = self.frame_count

            if track_id not in self.locked_ids:  # Update history until speed is locked
                trk_hist = self.trk_hist[track_id]
                trk_hist.append(self.track_line[-1])

                # Compute and lock speed once enough history is collected
                if len(trk_hist) == self.max_hist:
                    p0, p1 = trk_hist[0], trk_hist[-1]  # First and last points of track
                    dt = (self.frame_count - self.trk_frame_ids[track_id]) / self.fps  # Time in seconds
                    if dt > 0:
                        dx, dy = p1[0] - p0[0], p1[1] - p0[1]  # Pixel displacement
                        pixel_distance = sqrt(dx * dx + dy * dy)  # Calculate pixel distance
                        meters = pixel_distance * self.meter_per_pixel  # Convert to meters
                        self.spd[track_id] = int(
                            min((meters / dt) * 3.6, self.max_speed)
                        )  # Convert to km/h and store final speed
                        self.locked_ids.add(track_id)  # Prevent further updates
                        self.trk_hist.pop(track_id, None)  # Free memory
                        self.trk_frame_ids.pop(track_id, None)  # Remove frame start reference

            if track_id in self.spd:
                speed_label = f"{self.spd[track_id]} km/h"
                annotator.box_label(box, label=speed_label, color=colors(track_id, True))  # Draw bounding box

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function

        # Return results with processed image and tracking summary
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
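
To make the unit conversion concrete: an object that moves 100 px over 25 frames at 25 FPS with meter_per_pixel=0.05 covers 5 m in 1 s, i.e. (5 m / 1 s) × 3.6 = 18 km/h. A minimal video-loop sketch, where the file name, model weight, and scale factor are illustrative:

import cv2

from ultralytics.solutions import SpeedEstimator

cap = cv2.VideoCapture("traffic.mp4")  # hypothetical input clip
estimator = SpeedEstimator(model="yolo11n.pt", fps=cap.get(cv2.CAP_PROP_FPS), meter_per_pixel=0.05, max_speed=120)

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = estimator.process(frame)  # annotated frame in results.plot_im, track count in results.total_tracks
cap.release()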
262
ultralytics/solutions/streamlit_inference.py
Normal file
@@ -0,0 +1,262 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import io
import os
from typing import Any

import cv2
import torch

from ultralytics import YOLO
from ultralytics.utils import LOGGER
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.downloads import GITHUB_ASSETS_STEMS

torch.classes.__path__ = []  # Torch module __path__._path issue: https://github.com/datalab-to/marker/issues/442


class Inference:
    """
    A class to perform object detection, image classification, image segmentation, and pose estimation inference.

    This class provides functionalities for loading models, configuring settings, uploading video files, and performing
    real-time inference using Streamlit and Ultralytics YOLO models.

    Attributes:
        st (module): Streamlit module for UI creation.
        temp_dict (dict): Temporary dictionary to store the model path and other configuration.
        model_path (str): Path to the loaded model.
        model (YOLO): The YOLO model instance.
        source (str): Selected video source (webcam or video file).
        enable_trk (bool): Enable tracking option.
        conf (float): Confidence threshold for detection.
        iou (float): IoU threshold for non-maximum suppression.
        org_frame (Any): Container for the original frame to be displayed.
        ann_frame (Any): Container for the annotated frame to be displayed.
        vid_file_name (str | int): Name of the uploaded video file or webcam index.
        selected_ind (list[int]): List of selected class indices for detection.

    Methods:
        web_ui: Set up the Streamlit web interface with custom HTML elements.
        sidebar: Configure the Streamlit sidebar for model and inference settings.
        source_upload: Handle video file uploads through the Streamlit interface.
        configure: Configure the model and load selected classes for inference.
        inference: Perform real-time object detection inference.

    Examples:
        Create an Inference instance with a custom model
        >>> inf = Inference(model="path/to/model.pt")
        >>> inf.inference()

        Create an Inference instance with default settings
        >>> inf = Inference()
        >>> inf.inference()
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the Inference class, checking Streamlit requirements and setting up the model path.

        Args:
            **kwargs (Any): Additional keyword arguments for model configuration.
        """
        check_requirements("streamlit>=1.29.0")  # scope imports for faster ultralytics package load speeds
        import streamlit as st

        self.st = st  # Reference to the Streamlit module
        self.source = None  # Video source selection (webcam or video file)
        self.img_file_names = []  # List of image file names
        self.enable_trk = False  # Flag to toggle object tracking
        self.conf = 0.25  # Confidence threshold for detection
        self.iou = 0.45  # Intersection-over-Union (IoU) threshold for non-maximum suppression
        self.org_frame = None  # Container for the original frame display
        self.ann_frame = None  # Container for the annotated frame display
        self.vid_file_name = None  # Video file name or webcam index
        self.selected_ind: list[int] = []  # List of selected class indices for detection
        self.model = None  # YOLO model instance

        self.temp_dict = {"model": None, **kwargs}
        self.model_path = None  # Model file path
        if self.temp_dict["model"] is not None:
            self.model_path = self.temp_dict["model"]

        LOGGER.info(f"Ultralytics Solutions: ✅ {self.temp_dict}")

    def web_ui(self) -> None:
        """Set up the Streamlit web interface with custom HTML elements."""
        menu_style_cfg = """<style>MainMenu {visibility: hidden;}</style>"""  # Hide main menu style

        # Main title of streamlit application
        main_title_cfg = """<div><h1 style="color:#111F68; text-align:center; font-size:40px; margin-top:-50px;
        font-family: 'Archivo', sans-serif; margin-bottom:20px;">Ultralytics YOLO Streamlit Application</h1></div>"""

        # Subtitle of streamlit application
        sub_title_cfg = """<div><h5 style="color:#042AFF; text-align:center; font-family: 'Archivo', sans-serif;
        margin-top:-15px; margin-bottom:50px;">Experience real-time object detection on your webcam, videos, and images
        with the power of Ultralytics YOLO! 🚀</h5></div>"""

        # Set html page configuration and append custom HTML
        self.st.set_page_config(page_title="Ultralytics Streamlit App", layout="wide")
        self.st.markdown(menu_style_cfg, unsafe_allow_html=True)
        self.st.markdown(main_title_cfg, unsafe_allow_html=True)
        self.st.markdown(sub_title_cfg, unsafe_allow_html=True)

    def sidebar(self) -> None:
        """Configure the Streamlit sidebar for model and inference settings."""
        with self.st.sidebar:  # Add Ultralytics LOGO
            logo = "https://raw.githubusercontent.com/ultralytics/assets/main/logo/Ultralytics_Logotype_Original.svg"
            self.st.image(logo, width=250)

        self.st.sidebar.title("User Configuration")  # Add elements to vertical setting menu
        self.source = self.st.sidebar.selectbox(
            "Source",
            ("webcam", "video", "image"),
        )  # Add source selection dropdown
        if self.source in ["webcam", "video"]:
            self.enable_trk = self.st.sidebar.radio("Enable Tracking", ("Yes", "No")) == "Yes"  # Enable object tracking
        self.conf = float(
            self.st.sidebar.slider("Confidence Threshold", 0.0, 1.0, self.conf, 0.01)
        )  # Slider for confidence
        self.iou = float(self.st.sidebar.slider("IoU Threshold", 0.0, 1.0, self.iou, 0.01))  # Slider for NMS threshold

        if self.source != "image":  # Only create columns for video/webcam
            col1, col2 = self.st.columns(2)  # Create two columns for displaying frames
            self.org_frame = col1.empty()  # Container for original frame
            self.ann_frame = col2.empty()  # Container for annotated frame

    def source_upload(self) -> None:
        """Handle video file uploads through the Streamlit interface."""
        from ultralytics.data.utils import IMG_FORMATS, VID_FORMATS  # scope import

        self.vid_file_name = ""
        if self.source == "video":
            vid_file = self.st.sidebar.file_uploader("Upload Video File", type=VID_FORMATS)
            if vid_file is not None:
                g = io.BytesIO(vid_file.read())  # BytesIO object
                with open("ultralytics.mp4", "wb") as out:  # Open temporary file as bytes
                    out.write(g.read())  # Read bytes into file
                self.vid_file_name = "ultralytics.mp4"
        elif self.source == "webcam":
            self.vid_file_name = 0  # Use webcam index 0
        elif self.source == "image":
            import tempfile  # scope import

            if imgfiles := self.st.sidebar.file_uploader(
                "Upload Image Files", type=IMG_FORMATS, accept_multiple_files=True
            ):
                for imgfile in imgfiles:  # Save each uploaded image to a temporary file
                    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{imgfile.name.split('.')[-1]}") as tf:
                        tf.write(imgfile.read())
                    self.img_file_names.append({"path": tf.name, "name": imgfile.name})

    def configure(self) -> None:
        """Configure the model and load selected classes for inference."""
        # Add dropdown menu for model selection
        M_ORD, T_ORD = ["yolo11n", "yolo11s", "yolo11m", "yolo11l", "yolo11x"], ["", "-seg", "-pose", "-obb", "-cls"]
        available_models = sorted(
            [
                x.replace("yolo", "YOLO")
                for x in GITHUB_ASSETS_STEMS
                if any(x.startswith(b) for b in M_ORD) and "grayscale" not in x
            ],
            key=lambda x: (M_ORD.index(x[:7].lower()), T_ORD.index(x[7:].lower() or "")),
        )
        if self.model_path:  # Insert user-provided custom model in available_models
            available_models.insert(0, self.model_path)
        selected_model = self.st.sidebar.selectbox("Model", available_models)

        with self.st.spinner("Model is downloading..."):
            if selected_model.endswith((".pt", ".onnx", ".torchscript", ".mlpackage", ".engine")) or any(
                fmt in selected_model for fmt in ("openvino_model", "rknn_model")
            ):
                model_path = selected_model
            else:
                model_path = f"{selected_model.lower()}.pt"  # Default to .pt if no model provided during function call
            self.model = YOLO(model_path)  # Load the YOLO model
            class_names = list(self.model.names.values())  # Convert dictionary to list of class names
        self.st.success("Model loaded successfully!")

        # Multiselect box with class names and get indices of selected classes
        selected_classes = self.st.sidebar.multiselect("Classes", class_names, default=class_names[:3])
        self.selected_ind = [class_names.index(option) for option in selected_classes]

        if not isinstance(self.selected_ind, list):  # Ensure selected_options is a list
            self.selected_ind = list(self.selected_ind)

    def image_inference(self) -> None:
        """Perform inference on uploaded images."""
        for img_info in self.img_file_names:
            img_path = img_info["path"]
            image = cv2.imread(img_path)  # Load and display the original image
            if image is not None:
                self.st.markdown(f"#### Processed: {img_info['name']}")
                col1, col2 = self.st.columns(2)
                with col1:
                    self.st.image(image, channels="BGR", caption="Original Image")
                results = self.model(image, conf=self.conf, iou=self.iou, classes=self.selected_ind)
                annotated_image = results[0].plot()
                with col2:
                    self.st.image(annotated_image, channels="BGR", caption="Predicted Image")
                try:  # Clean up temporary file
                    os.unlink(img_path)
                except FileNotFoundError:
                    pass  # File doesn't exist, ignore
            else:
                self.st.error("Could not load the uploaded image.")

    def inference(self) -> None:
        """Perform real-time object detection inference on video or webcam feed."""
        self.web_ui()  # Initialize the web interface
        self.sidebar()  # Create the sidebar
        self.source_upload()  # Upload the video source
        self.configure()  # Configure the app

        if self.st.sidebar.button("Start"):
            if self.source == "image":
                if self.img_file_names:
                    self.image_inference()
                else:
                    self.st.info("Please upload an image file to perform inference.")
                return

            stop_button = self.st.sidebar.button("Stop")  # Button to stop the inference
            cap = cv2.VideoCapture(self.vid_file_name)  # Capture the video
            if not cap.isOpened():
                self.st.error("Could not open webcam or video source.")
                return

            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    self.st.warning("Failed to read frame from webcam. Please verify the webcam is connected properly.")
                    break

                # Process frame with model
                if self.enable_trk:
                    results = self.model.track(
                        frame, conf=self.conf, iou=self.iou, classes=self.selected_ind, persist=True
                    )
                else:
                    results = self.model(frame, conf=self.conf, iou=self.iou, classes=self.selected_ind)

                annotated_frame = results[0].plot()  # Add annotations on frame

                if stop_button:
                    cap.release()  # Release the capture
                    self.st.stop()  # Stop streamlit app

                self.org_frame.image(frame, channels="BGR", caption="Original Frame")  # Display original frame
                self.ann_frame.image(annotated_frame, channels="BGR", caption="Predicted Frame")  # Display processed frame

            cap.release()  # Release the capture
            cv2.destroyAllWindows()  # Destroy all OpenCV windows


if __name__ == "__main__":
    import sys  # Import the sys module for accessing command-line arguments

    # Check if a model name is provided as a command-line argument
    args = len(sys.argv)
    model = sys.argv[1] if args > 1 else None  # Assign first argument as the model name if provided
    # Create an instance of the Inference class and run inference
    Inference(model=model).inference()
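
Note that because the UI is built with Streamlit, this file is meant to be launched through the Streamlit CLI rather than executed as a plain Python script; typically something like `streamlit run path/to/streamlit_inference.py`, where an optional model weight passed after the script path would arrive as `sys.argv[1]` in the `__main__` block above (exact invocation depends on your Streamlit version and setup).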
167
ultralytics/solutions/templates/similarity-search.html
Normal file
@@ -0,0 +1,167 @@
<!-- Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license -->

<!-- Similarity search webpage -->
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>Semantic Image Search</title>
    <link
      href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap"
      rel="stylesheet"
    />
    <style>
      body {
        background: linear-gradient(135deg, #f0f4ff, #f9fbff);
        font-family: "Inter", sans-serif;
        color: #111e68;
        padding: 2rem;
        margin: 0;
        min-height: 100vh;
      }

      h1 {
        text-align: center;
        margin-bottom: 2rem;
        font-size: 2.5rem;
        font-weight: 600;
      }

      form {
        display: flex;
        flex-wrap: wrap;
        justify-content: center;
        align-items: center;
        gap: 1rem;
        margin-bottom: 3rem;
      }

      input[type="text"] {
        width: 300px;
        padding: 0.75rem 1rem;
        font-size: 1rem;
        border-radius: 10px;
        border: 1px solid #ccc;
        box-shadow: 0 2px 6px rgba(0, 0, 0, 0.05);
        transition: box-shadow 0.3s ease;
      }

      input[type="text"]:focus {
        outline: none;
        box-shadow: 0 0 0 3px rgba(17, 30, 104, 0.2);
      }

      button {
        background-color: #111e68;
        color: white;
        font-weight: 600;
        font-size: 1rem;
        padding: 0.75rem 1.5rem;
        border-radius: 10px;
        border: none;
        cursor: pointer;
        transition:
          background-color 0.3s ease,
          transform 0.2s ease;
      }

      button:hover {
        background-color: #1f2e9f;
        transform: translateY(-2px);
      }

      .grid {
        display: grid;
        grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
        gap: 1.5rem;
        max-width: 1600px;
        margin: auto;
      }

      .card {
        background: white;
        border-radius: 16px;
        overflow: hidden;
        box-shadow: 0 6px 14px rgba(0, 0, 0, 0.08);
        transition:
          transform 0.3s ease,
          box-shadow 0.3s ease;
      }

      .card:hover {
        transform: translateY(-6px);
        box-shadow: 0 10px 20px rgba(0, 0, 0, 0.1);
      }

      .card img {
        width: 100%;
        height: 100%;
        object-fit: cover;
        display: block;
      }
    </style>
  </head>
  <script>
    function filterResults(k) {
      const cards = document.querySelectorAll(".grid .card");
      cards.forEach((card, idx) => {
        card.style.display = idx < k ? "block" : "none";
      });
      const buttons = document.querySelectorAll(".topk-btn");
      buttons.forEach((btn) => btn.classList.remove("active"));
      event.target.classList.add("active");
    }
    document.addEventListener("DOMContentLoaded", () => {
      filterResults(10);
    });
  </script>
  <body>
    <div style="text-align: center; margin-bottom: 1rem">
      <img
        src="https://raw.githubusercontent.com/ultralytics/assets/main/logo/favicon.png"
        alt="Ultralytics Logo"
        style="height: 40px"
      />
    </div>
    <h1>Semantic Image Search with AI</h1>

    <!-- Search box -->
    <form method="POST">
      <input
        type="text"
        name="query"
        placeholder="Describe the scene (e.g., man walking)"
        value="{{ request.form['query'] }}"
        required
      />
      <button type="submit">Search</button>
      {% if results %}
      <div class="top-k-buttons">
        <button type="button" class="topk-btn" onclick="filterResults(5)">
          Top 5
        </button>
        <button
          type="button"
          class="topk-btn active"
          onclick="filterResults(10)"
        >
          Top 10
        </button>
        <button type="button" class="topk-btn" onclick="filterResults(30)">
          Top 30
        </button>
      </div>
      {% endif %}
    </form>

    <!-- Search results grid -->
    <div class="grid">
      {% for img in results %}
      <div class="card">
        <img src="{{ url_for('static', filename=img) }}" alt="Result Image" />
      </div>
      {% endfor %}
    </div>
  </body>
</html>
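
The template expects a Flask-style context: a POST form field named `query`, a `results` list of file names resolvable under the app's `static/` directory, and `url_for` for building image URLs. A minimal sketch of that contract follows; `run_similarity_search` is a hypothetical stand-in for the package's visual-search backend, not its real API:

from flask import Flask, render_template, request

app = Flask(__name__)  # expects templates/similarity-search.html and images under static/


def run_similarity_search(query: str) -> list[str]:
    """Hypothetical placeholder; would return names of matching images stored under static/."""
    return []


@app.route("/", methods=["GET", "POST"])
def index():
    results = run_similarity_search(request.form["query"]) if request.method == "POST" else []
    return render_template("similarity-search.html", results=results)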
91
ultralytics/solutions/trackzone.py
Normal file
@@ -0,0 +1,91 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import Any

import cv2
import numpy as np

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class TrackZone(BaseSolution):
    """
    A class to manage region-based object tracking in a video stream.

    This class extends the BaseSolution class and provides functionality for tracking objects within a specific region
    defined by a polygonal area. Objects outside the region are excluded from tracking.

    Attributes:
        region (np.ndarray): The polygonal region for tracking, represented as a convex hull of points.
        line_width (int): Width of the lines used for drawing bounding boxes and region boundaries.
        names (list[str]): List of class names that the model can detect.
        boxes (list[np.ndarray]): Bounding boxes of tracked objects.
        track_ids (list[int]): Unique identifiers for each tracked object.
        clss (list[int]): Class indices of tracked objects.

    Methods:
        process: Process each frame of the video, applying region-based tracking.
        extract_tracks: Extract tracking information from the input frame.
        display_output: Display the processed output.

    Examples:
        >>> tracker = TrackZone()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = tracker.process(frame)
        >>> cv2.imshow("Tracked Frame", results.plot_im)
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the TrackZone class for tracking objects within a defined region in video streams.

        Args:
            **kwargs (Any): Additional keyword arguments passed to the parent class.
        """
        super().__init__(**kwargs)
        default_region = [(75, 75), (565, 75), (565, 285), (75, 285)]
        self.region = cv2.convexHull(np.array(self.region or default_region, dtype=np.int32))
        self.mask = None

    def process(self, im0: np.ndarray) -> SolutionResults:
        """
        Process the input frame to track objects within a defined region.

        This method initializes the annotator, creates a mask for the specified region, extracts tracks
        only from the masked area, and updates tracking information. Objects outside the region are ignored.

        Args:
            im0 (np.ndarray): The input image or frame to be processed.

        Returns:
            (SolutionResults): Contains processed image `plot_im` and `total_tracks` (int) representing the
                total number of tracked objects within the defined region.

        Examples:
            >>> tracker = TrackZone()
            >>> frame = cv2.imread("path/to/image.jpg")
            >>> results = tracker.process(frame)
        """
        annotator = SolutionAnnotator(im0, line_width=self.line_width)  # Initialize annotator

        if self.mask is None:  # Create a mask for the region
            self.mask = np.zeros_like(im0[:, :, 0])
            cv2.fillPoly(self.mask, [self.region], 255)
        masked_frame = cv2.bitwise_and(im0, im0, mask=self.mask)
        self.extract_tracks(masked_frame)

        # Draw the region boundary
        cv2.polylines(im0, [self.region], isClosed=True, color=(255, 255, 255), thickness=self.line_width * 2)

        # Iterate over boxes, track IDs, and class indices, and draw bounding boxes
        for box, track_id, cls, conf in zip(self.boxes, self.track_ids, self.clss, self.confs):
            annotator.box_label(
                box, label=self.adjust_box_label(cls, conf, track_id=track_id), color=colors(track_id, True)
            )

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display output with base class function

        # Return a SolutionResults object
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
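
A minimal sketch of running TrackZone on a video with a custom zone; the polygon, model weight, and file name are illustrative, and any 3+ points work since the region is reduced to its convex hull:

import cv2

from ultralytics.solutions import TrackZone

cap = cv2.VideoCapture("street.mp4")  # hypothetical input clip
zone = [(150, 150), (1130, 150), (1130, 570), (150, 570)]  # pixel coordinates of the tracking zone
tracker = TrackZone(model="yolo11n.pt", region=zone)

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = tracker.process(frame)  # objects outside `zone` are masked out before tracking
cap.release()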
70
ultralytics/solutions/vision_eye.py
Normal file
@@ -0,0 +1,70 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from typing import Any

from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
from ultralytics.utils.plotting import colors


class VisionEye(BaseSolution):
    """
    A class to manage object detection and vision mapping in images or video streams.

    This class extends the BaseSolution class and provides functionality for detecting objects,
    mapping vision points, and annotating results with bounding boxes and labels.

    Attributes:
        vision_point (tuple[int, int]): Coordinates (x, y) where vision will view objects and draw tracks.

    Methods:
        process: Process the input image to detect objects, annotate them, and apply vision mapping.

    Examples:
        >>> vision_eye = VisionEye()
        >>> frame = cv2.imread("frame.jpg")
        >>> results = vision_eye.process(frame)
        >>> print(f"Total detected instances: {results.total_tracks}")
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initialize the VisionEye class for detecting objects and applying vision mapping.

        Args:
            **kwargs (Any): Keyword arguments passed to the parent class and for configuring vision_point.
        """
        super().__init__(**kwargs)
        # Set the vision point where the system will view objects and draw tracks
        self.vision_point = self.CFG["vision_point"]

    def process(self, im0) -> SolutionResults:
        """
        Perform object detection, vision mapping, and annotation on the input image.

        Args:
            im0 (np.ndarray): The input image for detection and annotation.

        Returns:
            (SolutionResults): Object containing the annotated image and tracking statistics.
                - plot_im: Annotated output image with bounding boxes and vision mapping
                - total_tracks: Number of tracked objects in the frame

        Examples:
            >>> vision_eye = VisionEye()
            >>> frame = cv2.imread("image.jpg")
            >>> results = vision_eye.process(frame)
            >>> print(f"Detected {results.total_tracks} objects")
        """
        self.extract_tracks(im0)  # Extract tracks (bounding boxes, classes, and masks)
        annotator = SolutionAnnotator(im0, self.line_width)

        for cls, t_id, box, conf in zip(self.clss, self.track_ids, self.boxes, self.confs):
            # Annotate the image with bounding boxes, labels, and vision mapping
            annotator.box_label(box, label=self.adjust_box_label(cls, conf, t_id), color=colors(int(t_id), True))
            annotator.visioneye(box, self.vision_point)

        plot_im = annotator.result()
        self.display_output(plot_im)  # Display the annotated output using the base class function

        # Return a SolutionResults object with the annotated image and tracking statistics
        return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids))
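
A minimal sketch that anchors the "eye" near the bottom-center of the frame, so every tracked centroid is connected back to that point; the coordinates, model weight, and file name are illustrative:

import cv2

from ultralytics.solutions import VisionEye

cap = cv2.VideoCapture("lobby.mp4")  # hypothetical input clip
w, h = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
eye = VisionEye(model="yolo11n.pt", vision_point=(w // 2, h - 10))  # draw lines from just above the bottom edge

while cap.isOpened():
    ok, frame = cap.read()
    if not ok:
        break
    results = eye.process(frame)  # annotated frame with vision mapping in results.plot_im
cap.release()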