init commit

2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions


@@ -0,0 +1,9 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .fastsam import FastSAM
from .nas import NAS
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO, YOLOE, YOLOWorld
__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE" # allow simpler import


@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import FastSAM
from .predict import FastSAMPredictor
from .val import FastSAMValidator
__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"


@@ -0,0 +1,81 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
from ultralytics.engine.model import Model
from .predict import FastSAMPredictor
from .val import FastSAMValidator
class FastSAM(Model):
"""
FastSAM model interface for segment anything tasks.
This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.
Attributes:
model (str): Path to the pre-trained FastSAM model file.
task (str): The task type, set to "segment" for FastSAM models.
Methods:
predict: Perform segmentation prediction on image or video source with optional prompts.
task_map: Returns mapping of segment task to predictor and validator classes.
Examples:
Initialize FastSAM model and run prediction
>>> from ultralytics import FastSAM
>>> model = FastSAM("FastSAM-x.pt")
>>> results = model.predict("ultralytics/assets/bus.jpg")
Run prediction with bounding box prompts
>>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
"""
def __init__(self, model: str = "FastSAM-x.pt"):
"""Initialize the FastSAM model with the specified pre-trained weights."""
if str(model) == "FastSAM.pt":
model = "FastSAM-x.pt"
assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
super().__init__(model=model, task="segment")
def predict(
self,
source,
stream: bool = False,
bboxes: list | None = None,
points: list | None = None,
labels: list | None = None,
texts: list | None = None,
**kwargs: Any,
):
"""
Perform segmentation prediction on image or video source.
Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
prompts and passes them to the parent class predict method for processing.
Args:
source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
or numpy array.
stream (bool): Whether to enable real-time streaming mode for video inputs.
bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
labels (list, optional): Class labels for prompted segmentation.
texts (list, optional): Text prompts for segmentation guidance.
**kwargs (Any): Additional keyword arguments passed to the predictor.
Returns:
(list): List of Results objects containing the prediction results.
"""
prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
return super().predict(source, stream, prompts=prompts, **kwargs)
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}


@@ -0,0 +1,181 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from PIL import Image
from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils import DEFAULT_CFG, checks
from ultralytics.utils.metrics import box_iou
from ultralytics.utils.ops import scale_masks
from ultralytics.utils.torch_utils import TORCH_1_10
from .utils import adjust_bboxes_to_image_border
class FastSAMPredictor(SegmentationPredictor):
"""
FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.
This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
single-class segmentation.
Attributes:
prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
device (torch.device): Device on which model and tensors are processed.
clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.
Methods:
postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
prompt: Perform image segmentation inference based on various prompt types.
set_prompts: Set prompts to be used during inference.
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the FastSAMPredictor with configuration and callbacks.
This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
optimized for single-class segmentation.
Args:
cfg (dict): Configuration for the predictor.
overrides (dict, optional): Configuration overrides.
_callbacks (list, optional): List of callback functions.
"""
super().__init__(cfg, overrides, _callbacks)
self.prompts = {}
def postprocess(self, preds, img, orig_imgs):
"""
Apply postprocessing to FastSAM predictions and handle prompts.
Args:
preds (list[torch.Tensor]): Raw predictions from the model.
img (torch.Tensor): Input image tensor that was fed to the model.
orig_imgs (list[np.ndarray]): Original images before preprocessing.
Returns:
(list[Results]): Processed results with prompts applied.
"""
bboxes = self.prompts.pop("bboxes", None)
points = self.prompts.pop("points", None)
labels = self.prompts.pop("labels", None)
texts = self.prompts.pop("texts", None)
results = super().postprocess(preds, img, orig_imgs)
for result in results:
full_box = torch.tensor(
[0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
)
boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
if idx.numel() != 0:
result.boxes.xyxy[idx] = full_box
return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)
def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
"""
Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.
Args:
results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
texts (str | list[str], optional): Textual prompts, a list containing string objects.
Returns:
(list[Results]): Output results filtered and determined by the provided prompts.
"""
if bboxes is None and points is None and texts is None:
return results
prompt_results = []
if not isinstance(results, list):
results = [results]
for result in results:
if len(result) == 0:
prompt_results.append(result)
continue
masks = result.masks.data
if masks.shape[1:] != result.orig_shape:
masks = scale_masks(masks[None], result.orig_shape)[0]
# bboxes prompt
idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
if bboxes is not None:
bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
full_mask_areas = torch.sum(masks, dim=(1, 2))
union = bbox_areas[:, None] + full_mask_areas - mask_areas
idx[torch.argmax(mask_areas / union, dim=1)] = True
if points is not None:
points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
points = points[None] if points.ndim == 1 else points
if labels is None:
labels = torch.ones(points.shape[0])
labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
assert len(labels) == len(points), (
f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
)
point_idx = (
torch.ones(len(result), dtype=torch.bool, device=self.device)
if labels.sum() == 0 # all negative points
else torch.zeros(len(result), dtype=torch.bool, device=self.device)
)
for point, label in zip(points, labels):
point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
idx |= point_idx
if texts is not None:
if isinstance(texts, str):
texts = [texts]
crop_ims, filter_idx = [], []
for i, b in enumerate(result.boxes.xyxy.tolist()):
x1, y1, x2, y2 = (int(x) for x in b)
if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100: # torch 1.9 bug workaround
filter_idx.append(i)
continue
crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
similarity = self._clip_inference(crop_ims, texts)
text_idx = torch.argmax(similarity, dim=-1) # (M, )
if len(filter_idx):
text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
idx[text_idx] = True
prompt_results.append(result[idx])
return prompt_results
def _clip_inference(self, images, texts):
"""
Perform CLIP inference to calculate similarity between images and text prompts.
Args:
images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
texts (list[str]): List of prompt texts, each should be a string object.
Returns:
(torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
"""
try:
import clip
except ImportError:
checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
import clip
if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
tokenized_text = clip.tokenize(texts).to(self.device)
image_features = self.clip_model.encode_image(images)
text_features = self.clip_model.encode_text(tokenized_text)
image_features /= image_features.norm(dim=-1, keepdim=True) # (N, 512)
text_features /= text_features.norm(dim=-1, keepdim=True) # (M, 512)
return (image_features * text_features[:, None]).sum(-1) # (M, N)
def set_prompts(self, prompts):
"""Set prompts to be used during inference."""
self.prompts = prompts


@@ -0,0 +1,24 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
image_shape (tuple): Image dimensions as (height, width).
threshold (int): Pixel threshold for considering a box close to the border.
Returns:
(torch.Tensor): Adjusted bounding boxes with shape (N, 4).
"""
# Image dimensions
h, w = image_shape
# Adjust boxes that are close to image borders
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
return boxes
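A small worked example of the snapping behavior above, using a hypothetical 480x640 image and the default 20-pixel threshold (module path assumed from the relative import in the predictor file):

import torch
from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor([[5.0, 12.0, 630.0, 470.0], [100.0, 100.0, 200.0, 200.0]])
print(adjust_bboxes_to_image_border(boxes, image_shape=(480, 640)))
# tensor([[  0.,   0., 640., 480.],   <- edges within 20 px of the border are snapped to it
#         [100., 100., 200., 200.]])  <- interior box is left unchanged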


@@ -0,0 +1,40 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.segment import SegmentationValidator
class FastSAMValidator(SegmentationValidator):
"""
Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
to avoid errors during validation.
Attributes:
dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
save_dir (Path): The directory where validation results will be saved.
args (SimpleNamespace): Additional arguments for customization of the validation process.
_callbacks (list): List of callback functions to be invoked during validation.
metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.
Methods:
__init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
"""
Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
save_dir (Path, optional): Directory to save results.
args (SimpleNamespace, optional): Configuration for the validator.
_callbacks (list, optional): List of callback functions to be invoked during validation.
Notes:
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.args.task = "segment"
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors


@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import NAS
from .predict import NASPredictor
from .val import NASValidator
__all__ = "NASPredictor", "NASValidator", "NAS"

Binary file not shown.


@@ -0,0 +1,101 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.engine.model import Model
from ultralytics.utils import DEFAULT_CFG_DICT
from ultralytics.utils.downloads import attempt_download_asset
from ultralytics.utils.patches import torch_load
from ultralytics.utils.torch_utils import model_info
from .predict import NASPredictor
from .val import NASValidator
class NAS(Model):
"""
YOLO-NAS model for object detection.
This class provides an interface for the YOLO-NAS models and extends the `Model` class from ultralytics engine.
It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
Attributes:
model (torch.nn.Module): The loaded YOLO-NAS model.
task (str): The task type for the model, defaults to 'detect'.
predictor (NASPredictor): The predictor instance for making predictions.
validator (NASValidator): The validator instance for model validation.
Methods:
info: Log model information and return model details.
Examples:
>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> results = model.predict("ultralytics/assets/bus.jpg")
Notes:
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
"""
def __init__(self, model: str = "yolo_nas_s.pt") -> None:
"""Initialize the NAS model with the provided or default model."""
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
super().__init__(model, task="detect")
def _load(self, weights: str, task=None) -> None:
"""
Load an existing NAS model weights or create a new NAS model with pretrained weights.
Args:
weights (str): Path to the model weights file or model name.
task (str, optional): Task type for the model.
"""
import super_gradients
suffix = Path(weights).suffix
if suffix == ".pt":
self.model = torch_load(attempt_download_asset(weights))
elif suffix == "":
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
# Override the forward method to ignore additional arguments
def new_forward(x, *args, **kwargs):
"""Ignore additional __call__ arguments."""
return self.model._original_forward(x)
self.model._original_forward = self.model.forward
self.model.forward = new_forward
# Standardize model attributes for compatibility
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = weights # for export()
self.model.task = "detect" # for export()
self.model.args = {**DEFAULT_CFG_DICT, **self.overrides} # for export()
self.model.eval()
def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
"""
Log model information.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
Returns:
(dict[str, Any]): Model information dictionary.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Return a dictionary mapping tasks to respective predictor and validator classes."""
return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}


@@ -0,0 +1,58 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ops
class NASPredictor(DetectionPredictor):
"""
Ultralytics YOLO NAS Predictor for object detection.
This class extends the DetectionPredictor from ultralytics engine and is responsible for post-processing the
raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
scaling the bounding boxes to fit the original image dimensions.
Attributes:
args (Namespace): Namespace containing various configurations for post-processing including confidence
threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
model (torch.nn.Module): The YOLO NAS model used for inference.
batch (list): Batch of inputs for processing.
Examples:
>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> predictor = model.predictor
>>> # Assume that raw_preds, img, orig_imgs are available
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)
Notes:
Typically, this class is not instantiated directly. It is used internally within the NAS class.
"""
def postprocess(self, preds_in, img, orig_imgs):
"""
Postprocess NAS model predictions to generate final detection results.
This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
post-processing operations to generate the final detection results compatible with Ultralytics
result visualization and analysis tools.
Args:
preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).
orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling
coordinates back to original dimensions.
Returns:
(list): List of Results objects containing the processed predictions for each image in the batch.
Examples:
>>> predictor = NAS("yolo_nas_s").predictor
>>> results = predictor.postprocess(raw_preds, img, orig_imgs)
"""
boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding boxes from xyxy to xywh format
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with class scores
return super().postprocess(preds, img, orig_imgs)
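The format conversion in postprocess can be illustrated with dummy tensors; the 300 proposals and 80 classes below are illustrative, not fixed by the model:

import torch
from ultralytics.utils import ops

boxes_xyxy = torch.rand(1, 300, 4) * 640  # stand-in for preds_in[0][0]
class_scores = torch.rand(1, 300, 80)     # stand-in for preds_in[0][1]
preds = torch.cat((ops.xyxy2xywh(boxes_xyxy), class_scores), -1).permute(0, 2, 1)
print(preds.shape)  # torch.Size([1, 84, 300]), i.e. (B, 4 + nc, N) as consumed by the parent postprocess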


@@ -0,0 +1,39 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
__all__ = ["NASValidator"]
class NASValidator(DetectionValidator):
"""
Ultralytics YOLO NAS Validator for object detection.
Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
ultimately producing the final detections.
Attributes:
args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
thresholds.
lb (torch.Tensor): Optional tensor for multilabel NMS.
Examples:
>>> from ultralytics import NAS
>>> model = NAS("yolo_nas_s")
>>> validator = model.validator
>>> # Assumes that raw_preds are available
>>> final_preds = validator.postprocess(raw_preds)
Notes:
This class is generally not instantiated directly but is used internally within the NAS class.
"""
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = ops.xyxy2xywh(preds_in[0][0]) # Convert bounding box format from xyxy to xywh
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1) # Concatenate boxes with scores and permute
return super().postprocess(preds)


@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator
__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"


@@ -0,0 +1,66 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.
RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
References:
https://arxiv.org/pdf/2304.08069.pdf
"""
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils.torch_utils import TORCH_1_11
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
class RTDETR(Model):
"""
Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.
This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
query selection, and adaptable inference speed.
Attributes:
model (str): Path to the pre-trained model.
Methods:
task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
Examples:
Initialize RT-DETR with a pre-trained model
>>> from ultralytics import RTDETR
>>> model = RTDETR("rtdetr-l.pt")
>>> results = model("image.jpg")
"""
def __init__(self, model: str = "rtdetr-l.pt") -> None:
"""
Initialize the RT-DETR model with the given pre-trained model file.
Args:
model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
"""
assert TORCH_1_11, "RTDETR requires torch>=1.11"
super().__init__(model=model, task="detect")
@property
def task_map(self) -> dict:
"""
Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
Returns:
(dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
"""
return {
"detect": {
"predictor": RTDETRPredictor,
"validator": RTDETRValidator,
"trainer": RTDETRTrainer,
"model": RTDETRDetectionModel,
}
}


@@ -0,0 +1,92 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
"""
RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.
This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
It supports key features like efficient hybrid encoding and IoU-aware query selection.
Attributes:
imgsz (int): Image size for inference (must be square and scale-filled).
args (dict): Argument overrides for the predictor.
model (torch.nn.Module): The loaded RT-DETR model.
batch (list): Current batch of processed inputs.
Methods:
postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.
pre_transform: Pre-transform input images before feeding them into the model for inference.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.rtdetr import RTDETRPredictor
>>> args = dict(model="rtdetr-l.pt", source=ASSETS)
>>> predictor = RTDETRPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def postprocess(self, preds, img, orig_imgs):
"""
Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
The method filters detections based on confidence and class if specified in `self.args`. It converts
model predictions to Results objects containing properly scaled bounding boxes.
Args:
preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
bounding boxes and scores.
img (torch.Tensor): Processed input images with shape (N, 3, H, W).
orig_imgs (list | torch.Tensor): Original, unprocessed images.
Returns:
results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
confidence scores, and class labels.
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
max_score, cls = score.max(-1, keepdim=True) # (300, 1)
idx = max_score.squeeze(-1) > self.args.conf # (300, )
if self.args.classes is not None:
idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
pred = torch.cat([bbox, max_score, cls], dim=-1)[idx] # filter
pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
oh, ow = orig_img.shape[:2]
pred[..., [0, 2]] *= ow # scale x coordinates to original width
pred[..., [1, 3]] *= oh # scale y coordinates to original height
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
def pre_transform(self, im):
"""
Pre-transform input images before feeding them into the model for inference.
The input images are letterboxed to a square shape and scale-filled, since RT-DETR expects square (e.g., 640x640)
scale-filled inputs.
Args:
im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
[(H, W, 3) x N] for list.
Returns:
(list): List of pre-transformed images ready for model inference.
"""
letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)
return [letterbox(image=x) for x in im]
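A worked sketch of the scaling step in postprocess above: RT-DETR outputs normalized xywh boxes, which are converted to xyxy and stretched back to the original image size (the values below are made up):

import torch
from ultralytics.utils import ops

bbox = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized (x_center, y_center, w, h)
pred = ops.xywh2xyxy(bbox)                   # -> [[0.4, 0.3, 0.6, 0.7]] normalized xyxy
oh, ow = 480, 640                            # hypothetical original image height/width
pred[..., [0, 2]] *= ow
pred[..., [1, 3]] *= oh
print(pred)  # tensor([[256., 144., 384., 336.]])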


@@ -0,0 +1,92 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
"""
Trainer class for the RT-DETR model developed by Baidu for real-time object detection.
This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of RT-DETR.
The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable inference
speed.
Attributes:
loss_names (tuple): Names of the loss components used for training.
data (dict): Dataset configuration containing class count and other parameters.
args (dict): Training arguments and hyperparameters.
save_dir (Path): Directory to save training results.
test_loader (DataLoader): DataLoader for validation/testing data.
Methods:
get_model: Initialize and return an RT-DETR model for object detection tasks.
build_dataset: Build and return an RT-DETR dataset for training or validation.
get_validator: Return a DetectionValidator suitable for RT-DETR model validation.
Notes:
- F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Examples:
>>> from ultralytics.models.rtdetr.train import RTDETRTrainer
>>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
>>> trainer = RTDETRTrainer(overrides=args)
>>> trainer.train()
"""
def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
"""
Initialize and return an RT-DETR model for object detection tasks.
Args:
cfg (dict, optional): Model configuration.
weights (str, optional): Path to pre-trained model weights.
verbose (bool): Verbose logging if True.
Returns:
(RTDETRDetectionModel): Initialized model.
"""
model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
"""
Build and return an RT-DETR dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
mode (str): Dataset mode, either 'train' or 'val'.
batch (int, optional): Batch size for rectangle training.
Returns:
(RTDETRDataset): Dataset object for the specific mode.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == "train",
hyp=self.args,
rect=False,
cache=self.args.cache or None,
single_cls=self.args.single_cls or False,
prefix=colorstr(f"{mode}: "),
classes=self.args.classes,
data=self.data,
fraction=self.args.fraction if mode == "train" else 1.0,
)
def get_validator(self):
"""Return a DetectionValidator suitable for RT-DETR model validation."""
self.loss_names = "giou_loss", "cls_loss", "l1_loss"
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
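An alternative to instantiating the trainer directly as in the Examples above is to go through the RTDETR model class, which dispatches to this trainer and validator through its task_map; a hedged sketch with illustrative settings:

from ultralytics import RTDETR

model = RTDETR("rtdetr-l.pt")
model.train(data="coco8.yaml", epochs=3, imgsz=640)  # dispatched to RTDETRTrainer
metrics = model.val()                                # dispatched to RTDETRValidator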


@@ -0,0 +1,218 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
__all__ = ("RTDETRValidator",) # tuple or list
class RTDETRDataset(YOLODataset):
"""
Real-Time DEtection TRansformer (RT-DETR) dataset class extending the base YOLODataset class.
This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
real-time detection and tracking tasks.
Attributes:
augment (bool): Whether to apply data augmentation.
rect (bool): Whether to use rectangular training.
use_segments (bool): Whether to use segmentation masks.
use_keypoints (bool): Whether to use keypoint annotations.
imgsz (int): Target image size for training.
Methods:
load_image: Load one image from dataset index.
build_transforms: Build transformation pipeline for the dataset.
Examples:
Initialize an RT-DETR dataset
>>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
>>> image, hw = dataset.load_image(0)
"""
def __init__(self, *args, data=None, **kwargs):
"""
Initialize the RTDETRDataset class by inheriting from the YOLODataset class.
This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection TRansformer)
model, building upon the base YOLODataset functionality.
Args:
*args (Any): Variable length argument list passed to the parent YOLODataset class.
data (dict | None): Dictionary containing dataset information. If None, default values will be used.
**kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
"""
super().__init__(*args, data=data, **kwargs)
def load_image(self, i, rect_mode=False):
"""
Load one image from dataset index 'i'.
Args:
i (int): Index of the image to load.
rect_mode (bool, optional): Whether to use rectangular mode for batch inference.
Returns:
im (torch.Tensor): The loaded image.
resized_hw (tuple): Height and width of the resized image with shape (2,).
Examples:
Load an image from the dataset
>>> dataset = RTDETRDataset(img_path="path/to/images")
>>> image, hw = dataset.load_image(0)
"""
return super().load_image(i=i, rect_mode=rect_mode)
def build_transforms(self, hyp=None):
"""
Build transformation pipeline for the dataset.
Args:
hyp (dict, optional): Hyperparameters for transformations.
Returns:
(Compose): Composition of transformation functions.
"""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
else:
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
transforms = Compose([])
transforms.append(
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
)
return transforms
class RTDETRValidator(DetectionValidator):
"""
RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
the RT-DETR (Real-Time DETR) object detection model.
The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
post-processing, and updates evaluation metrics accordingly.
Attributes:
args (Namespace): Configuration arguments for validation.
data (dict): Dataset configuration dictionary.
Methods:
build_dataset: Build an RTDETR Dataset for validation.
postprocess: Apply Non-maximum suppression to prediction outputs.
Examples:
Initialize and run RT-DETR validation
>>> from ultralytics.models.rtdetr import RTDETRValidator
>>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
>>> validator = RTDETRValidator(args=args)
>>> validator()
Notes:
For further details on the attributes and methods, refer to the parent DetectionValidator class.
"""
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build an RTDETR Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
each mode.
batch (int, optional): Size of batches, this is for `rect`.
Returns:
(RTDETRDataset): Dataset configured for RT-DETR validation.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=False, # no augmentation
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f"{mode}: "),
data=self.data,
)
def postprocess(
self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
) -> list[dict[str, torch.Tensor]]:
"""
Apply Non-maximum suppression to prediction outputs.
Args:
preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
(batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class scores.
Returns:
(list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
- 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
- 'conf': Tensor of shape (N,) with confidence scores
- 'cls': Tensor of shape (N,) with class indices
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
score, cls = scores[i].max(-1) # (300, )
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
# Sort by confidence to correctly get internal metrics
pred = pred[score.argsort(descending=True)]
outputs[i] = pred[score > self.args.conf]
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Serialize YOLO predictions to COCO json format.
Args:
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
"""
path = Path(pbatch["im_file"])
stem = path.stem
image_id = int(stem) if stem.isnumeric() else stem
box = predn["bboxes"].clone()
box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
box = ops.xyxy2xywh(box) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
self.jdict.append(
{
"image_id": image_id,
"file_name": path.name,
"category_id": self.class_map[int(c)],
"bbox": [round(x, 3) for x in b],
"score": round(s, 5),
}
)
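A worked sketch of the coordinate conversion in pred_to_json, with a hypothetical 640 validation size and a 480x640 original image:

import torch
from ultralytics.utils import ops

imgsz, ori_shape = 640, (480, 640)                  # validation size and original (h, w)
box = torch.tensor([[100.0, 200.0, 300.0, 400.0]])  # xyxy prediction in the square 640x640 space
box[..., [0, 2]] *= ori_shape[1] / imgsz            # x scale = 1.0
box[..., [1, 3]] *= ori_shape[0] / imgsz            # y scale = 0.75
box = ops.xyxy2xywh(box)                            # center-based xywh
box[:, :2] -= box[:, 2:] / 2                        # shift to the top-left xy expected by COCO
print(box)  # tensor([[100., 150., 200., 150.]])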


@@ -0,0 +1,12 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import SAM
from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor
__all__ = (
"SAM",
"Predictor",
"SAM2Predictor",
"SAM2VideoPredictor",
"SAM2DynamicInteractivePredictor",
) # tuple or list of exportable items

Binary file not shown.


@@ -0,0 +1,281 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import math
from collections.abc import Generator
from itertools import product
from typing import Any
import numpy as np
import torch
def is_box_near_crop_edge(
boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
) -> torch.Tensor:
"""
Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
Args:
boxes (torch.Tensor): Bounding boxes in XYXY format.
crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
atol (float, optional): Absolute tolerance for edge proximity detection.
Returns:
(torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
Examples:
>>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
>>> crop_box = [0, 0, 200, 200]
>>> orig_box = [0, 0, 300, 300]
>>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
"""
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
return torch.any(near_crop_edge, dim=1)
def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
"""
Yield batches of data from input arguments with specified batch size for efficient processing.
This function takes a batch size and any number of iterables, then yields batches of elements from those
iterables. All input iterables must have the same length.
Args:
batch_size (int): Size of each batch to yield.
*args (Any): Variable length input iterables to batch. All iterables must have the same length.
Yields:
(list[Any]): A list of batched elements from each input iterable.
Examples:
>>> data = [1, 2, 3, 4, 5]
>>> labels = ["a", "b", "c", "d", "e"]
>>> for batch in batch_iterator(2, data, labels):
... print(batch)
[[1, 2], ['a', 'b']]
[[3, 4], ['c', 'd']]
[[5], ['e']]
"""
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
for b in range(n_batches):
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
"""
Compute the stability score for a batch of masks.
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
high and low values.
Args:
masks (torch.Tensor): Batch of predicted mask logits.
mask_threshold (float): Threshold value for creating binary masks.
threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
Returns:
(torch.Tensor): Stability scores for each mask in the batch.
Notes:
- One mask is always contained inside the other.
- Memory is saved by preventing unnecessary cast to torch.int64.
Examples:
>>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
>>> mask_threshold = 0.5
>>> threshold_offset = 0.1
>>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
"""
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
return intersections / unions
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
"""Generate point grids for multiple crop layers with varying scales and densities."""
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
def generate_crop_boxes(
im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
) -> tuple[list[list[int]], list[int]]:
"""
Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
Args:
im_size (tuple[int, ...]): Height and width of the input image.
n_layers (int): Number of layers to generate crop boxes for.
overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
Returns:
crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
layer_idxs (list[int]): List of layer indices corresponding to each crop box.
Examples:
>>> im_size = (800, 1200) # Height, width
>>> n_layers = 3
>>> overlap_ratio = 0.25
>>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
"""
crop_boxes, layer_idxs = [], []
im_h, im_w = im_size
short_side = min(im_h, im_w)
# Original image
crop_boxes.append([0, 0, im_w, im_h])
layer_idxs.append(0)
def crop_len(orig_len, n_crops, overlap):
"""Calculate the length of each crop given the original length, number of crops, and overlap."""
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
for i_layer in range(n_layers):
n_crops_per_side = 2 ** (i_layer + 1)
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
crop_w = crop_len(im_w, n_crops_per_side, overlap)
crop_h = crop_len(im_h, n_crops_per_side, overlap)
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
# Crops in XYXY format
for x0, y0 in product(crop_box_x0, crop_box_y0):
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
crop_boxes.append(box)
layer_idxs.append(i_layer + 1)
return crop_boxes, layer_idxs
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
"""Uncrop bounding boxes by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
# Check if boxes has a channel dimension
if len(boxes.shape) == 3:
offset = offset.unsqueeze(1)
return boxes + offset
def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
"""Uncrop points by adding the crop box offset to their coordinates."""
x0, y0, _, _ = crop_box
offset = torch.tensor([[x0, y0]], device=points.device)
# Check if points has a channel dimension
if len(points.shape) == 3:
offset = offset.unsqueeze(1)
return points + offset
def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
"""Uncrop masks by padding them to the original image size, handling coordinate transformations."""
x0, y0, x1, y1 = crop_box
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
return masks
# Coordinate transform masks
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
pad = (x0, pad_x - x0, y0, pad_y - y0)
return torch.nn.functional.pad(masks, pad, value=0)
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
"""
Remove small disconnected regions or holes in a mask based on area threshold and mode.
Args:
mask (np.ndarray): Binary mask to process.
area_thresh (float): Area threshold below which regions will be removed.
mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
regions.
Returns:
processed_mask (np.ndarray): Processed binary mask with small regions removed.
modified (bool): Whether any regions were modified.
Examples:
>>> mask = np.zeros((100, 100), dtype=np.bool_)
>>> mask[40:60, 40:60] = True # Create a square
>>> mask[45:55, 45:55] = False # Create a hole
>>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
"""
import cv2 # type: ignore
assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
correct_holes = mode == "holes"
working_mask = (correct_holes ^ mask).astype(np.uint8)
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
sizes = stats[:, -1][1:] # Row 0 is background label
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
if not small_regions:
return mask, False
fill_labels = [0] + small_regions
if not correct_holes:
# If every region is below threshold, keep largest
fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
mask = np.isin(regions, fill_labels)
return mask, True
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
"""
Calculate bounding boxes in XYXY format around binary masks.
Args:
masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
Returns:
(torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).
Notes:
- Handles empty masks by returning zero boxes.
- Preserves input tensor dimensions in the output.
"""
# torch.max below raises an error on empty inputs, just skip in this case
if torch.numel(masks) == 0:
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
# Normalize shape to CxHxW
shape = masks.shape
h, w = shape[-2:]
masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
# Get top and bottom edges
in_height, _ = torch.max(masks, dim=-1)
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
in_height_coords = in_height_coords + h * (~in_height)
top_edges, _ = torch.min(in_height_coords, dim=-1)
# Get left and right edges
in_width, _ = torch.max(masks, dim=-2)
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
right_edges, _ = torch.max(in_width_coords, dim=-1)
in_width_coords = in_width_coords + w * (~in_width)
left_edges, _ = torch.min(in_width_coords, dim=-1)
# If the mask is empty the right edge will be to the left of the left edge.
# Replace these boxes with [0, 0, 0, 0]
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
out = out * (~empty_filter).unsqueeze(-1)
# Return to original shape
return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
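A short sketch exercising two of the helpers above (batched_mask_to_box and generate_crop_boxes); the ultralytics.models.sam.amg module path is assumed from this commit's package layout:

import torch
from ultralytics.models.sam.amg import batched_mask_to_box, generate_crop_boxes

masks = torch.zeros(2, 64, 64, dtype=torch.bool)  # two hypothetical binary masks
masks[0, 10:20, 30:40] = True
masks[1, 0:5, 0:5] = True
print(batched_mask_to_box(masks))
# tensor([[30, 10, 39, 19],
#         [ 0,  0,  4,  4]])

crop_boxes, layer_idxs = generate_crop_boxes((480, 640), n_layers=1, overlap_ratio=0.25)
print(crop_boxes[0], layer_idxs)  # [0, 0, 640, 480] [0, 1, 1, 1, 1]: full image plus four overlapping layer-1 crops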


@@ -0,0 +1,358 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from functools import partial
import torch
from ultralytics.utils.downloads import attempt_download_asset
from .modules.decoders import MaskDecoder
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
from .modules.sam import SAM2Model, SAMModel
from .modules.tiny_encoder import TinyViT
from .modules.transformer import TwoWayTransformer
def build_sam_vit_h(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_global_attn_indexes=[7, 15, 23, 31],
checkpoint=checkpoint,
)
def build_sam_vit_l(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_global_attn_indexes=[5, 11, 17, 23],
checkpoint=checkpoint,
)
def build_sam_vit_b(checkpoint=None):
"""Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
return _build_sam(
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_global_attn_indexes=[2, 5, 8, 11],
checkpoint=checkpoint,
)
def build_mobile_sam(checkpoint=None):
"""Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
return _build_sam(
encoder_embed_dim=[64, 128, 160, 320],
encoder_depth=[2, 2, 6, 2],
encoder_num_heads=[2, 4, 5, 10],
encoder_global_attn_indexes=None,
mobile_sam=True,
checkpoint=checkpoint,
)
def build_sam2_t(checkpoint=None):
"""Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=96,
encoder_stages=[1, 2, 7, 2],
encoder_num_heads=1,
encoder_global_att_blocks=[5, 7, 9],
encoder_window_spec=[8, 4, 14, 7],
encoder_backbone_channel_list=[768, 384, 192, 96],
checkpoint=checkpoint,
)
def build_sam2_s(checkpoint=None):
"""Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=96,
encoder_stages=[1, 2, 11, 2],
encoder_num_heads=1,
encoder_global_att_blocks=[7, 10, 13],
encoder_window_spec=[8, 4, 14, 7],
encoder_backbone_channel_list=[768, 384, 192, 96],
checkpoint=checkpoint,
)
def build_sam2_b(checkpoint=None):
"""Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=112,
encoder_stages=[2, 3, 16, 3],
encoder_num_heads=2,
encoder_global_att_blocks=[12, 16, 20],
encoder_window_spec=[8, 4, 14, 7],
encoder_window_spatial_size=[14, 14],
encoder_backbone_channel_list=[896, 448, 224, 112],
checkpoint=checkpoint,
)
def build_sam2_l(checkpoint=None):
"""Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
return _build_sam2(
encoder_embed_dim=144,
encoder_stages=[2, 6, 36, 4],
encoder_num_heads=2,
encoder_global_att_blocks=[23, 33, 43],
encoder_window_spec=[8, 4, 16, 8],
encoder_backbone_channel_list=[1152, 576, 288, 144],
checkpoint=checkpoint,
)
def _build_sam(
encoder_embed_dim,
encoder_depth,
encoder_num_heads,
encoder_global_attn_indexes,
checkpoint=None,
mobile_sam=False,
):
"""
Build a Segment Anything Model (SAM) with specified encoder parameters.
Args:
encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
encoder_depth (int | list[int]): Depth of the encoder.
encoder_num_heads (int | list[int]): Number of attention heads in the encoder.
encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the encoder.
checkpoint (str | None, optional): Path to the model checkpoint file.
mobile_sam (bool, optional): Whether to build a Mobile-SAM model.
Returns:
(SAMModel): A Segment Anything Model instance with the specified architecture.
Examples:
>>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
>>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
"""
prompt_embed_dim = 256
image_size = 1024
vit_patch_size = 16
image_embedding_size = image_size // vit_patch_size
image_encoder = (
TinyViT(
img_size=1024,
in_chans=3,
num_classes=1000,
embed_dims=encoder_embed_dim,
depths=encoder_depth,
num_heads=encoder_num_heads,
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.0,
drop_rate=0.0,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8,
)
if mobile_sam
else ImageEncoderViT(
depth=encoder_depth,
embed_dim=encoder_embed_dim,
img_size=image_size,
mlp_ratio=4,
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
num_heads=encoder_num_heads,
patch_size=vit_patch_size,
qkv_bias=True,
use_rel_pos=True,
global_attn_indexes=encoder_global_attn_indexes,
window_size=14,
out_chans=prompt_embed_dim,
)
)
sam = SAMModel(
image_encoder=image_encoder,
prompt_encoder=PromptEncoder(
embed_dim=prompt_embed_dim,
image_embedding_size=(image_embedding_size, image_embedding_size),
input_image_size=(image_size, image_size),
mask_in_chans=16,
),
mask_decoder=MaskDecoder(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=prompt_embed_dim,
mlp_dim=2048,
num_heads=8,
),
transformer_dim=prompt_embed_dim,
iou_head_depth=3,
iou_head_hidden_dim=256,
),
pixel_mean=[123.675, 116.28, 103.53],
pixel_std=[58.395, 57.12, 57.375],
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)
sam.load_state_dict(state_dict)
sam.eval()
return sam
def _build_sam2(
encoder_embed_dim=1280,
encoder_stages=[2, 6, 36, 4],
encoder_num_heads=2,
encoder_global_att_blocks=[7, 15, 23, 31],
encoder_backbone_channel_list=[1152, 576, 288, 144],
encoder_window_spatial_size=[7, 7],
encoder_window_spec=[8, 4, 16, 8],
checkpoint=None,
):
"""
Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
Args:
encoder_embed_dim (int, optional): Embedding dimension for the encoder.
encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
encoder_num_heads (int, optional): Number of attention heads in the encoder.
encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.
Returns:
(SAM2Model): A configured and initialized SAM2 model.
Examples:
>>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
>>> sam2_model.eval()
"""
image_encoder = ImageEncoder(
trunk=Hiera(
embed_dim=encoder_embed_dim,
num_heads=encoder_num_heads,
stages=encoder_stages,
global_att_blocks=encoder_global_att_blocks,
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
window_spec=encoder_window_spec,
),
neck=FpnNeck(
d_model=256,
backbone_channel_list=encoder_backbone_channel_list,
fpn_top_down_levels=[2, 3],
fpn_interp_model="nearest",
),
scalp=1,
)
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
memory_encoder = MemoryEncoder(out_dim=64)
is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
sam2 = SAM2Model(
image_encoder=image_encoder,
memory_attention=memory_attention,
memory_encoder=memory_encoder,
num_maskmem=7,
image_size=1024,
sigmoid_scale_for_mem_enc=20.0,
sigmoid_bias_for_mem_enc=-10.0,
use_mask_input_as_output_without_sam=True,
directly_add_no_mem_embed=True,
use_high_res_features_in_sam=True,
multimask_output_in_sam=True,
iou_prediction_use_sigmoid=True,
use_obj_ptrs_in_encoder=True,
add_tpos_enc_to_obj_ptrs=True,
only_obj_ptrs_in_the_past_for_eval=True,
pred_obj_scores=True,
pred_obj_scores_mlp=True,
fixed_no_obj_ptr=True,
multimask_output_for_tracking=True,
use_multimask_token_for_obj_ptr=True,
multimask_min_pt_num=0,
multimask_max_pt_num=1,
use_mlp_for_obj_ptr_proj=True,
compile_image_encoder=False,
no_obj_embed_spatial=is_sam2_1,
proj_tpos_enc_in_obj_ptrs=is_sam2_1,
use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
sam_mask_decoder_extra_args=dict(
dynamic_multimask_via_stability=True,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
),
)
if checkpoint is not None:
checkpoint = attempt_download_asset(checkpoint)
with open(checkpoint, "rb") as f:
state_dict = torch.load(f)["model"]
sam2.load_state_dict(state_dict)
sam2.eval()
return sam2
sam_model_map = {
"sam_h.pt": build_sam_vit_h,
"sam_l.pt": build_sam_vit_l,
"sam_b.pt": build_sam_vit_b,
"mobile_sam.pt": build_mobile_sam,
"sam2_t.pt": build_sam2_t,
"sam2_s.pt": build_sam2_s,
"sam2_b.pt": build_sam2_b,
"sam2_l.pt": build_sam2_l,
"sam2.1_t.pt": build_sam2_t,
"sam2.1_s.pt": build_sam2_s,
"sam2.1_b.pt": build_sam2_b,
"sam2.1_l.pt": build_sam2_l,
}
def build_sam(ckpt="sam_b.pt"):
"""
Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
Args:
ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.
Returns:
(SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
Raises:
FileNotFoundError: If the provided checkpoint is not a supported SAM model.
Examples:
>>> sam_model = build_sam("sam_b.pt")
>>> sam_model = build_sam("path/to/custom_checkpoint.pt")
Notes:
Supported pre-defined models include:
- SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
- SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
- SAM2.1: 'sam2.1_t.pt', 'sam2.1_s.pt', 'sam2.1_b.pt', 'sam2.1_l.pt'
"""
model_builder = None
ckpt = str(ckpt) # to allow Path ckpt types
for k in sam_model_map.keys():
if ckpt.endswith(k):
model_builder = sam_model_map.get(k)
if not model_builder:
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
return model_builder(ckpt)
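
# Minimal smoke-test sketch for build_sam (checkpoint names follow sam_model_map above; the
# weights file is fetched via attempt_download_asset when it is not already present locally):
if __name__ == "__main__":
    sam_b = build_sam("sam_b.pt")  # eval-mode SAMModel with a ViT-B image encoder
    sam21_b = build_sam("sam2.1_b.pt")  # eval-mode SAM2Model with a Hiera backbone
    print(f"SAM ViT-B parameters: {sum(p.numel() for p in sam_b.parameters()):,}")
    print(f"SAM2.1-B parameters: {sum(p.numel() for p in sam21_b.parameters()):,}")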

View File

@@ -0,0 +1,172 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
SAM model interface.
This module provides an interface to the Segment Anything Model (SAM) from ultralytics, designed for real-time image
segmentation tasks. The SAM model allows for promptable segmentation with unparalleled versatility in image analysis,
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
image distributions and tasks without prior knowledge.
Key Features:
- Promptable segmentation
- Real-time performance
- Zero-shot transfer capabilities
- Trained on SA-1B dataset
"""
from __future__ import annotations
from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .predict import Predictor, SAM2Predictor
class SAM(Model):
"""
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
This class provides an interface to the Segment Anything Model (SAM) from ultralytics, designed for
promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
boxes, points, or labels, and features zero-shot performance capabilities.
Attributes:
model (torch.nn.Module): The loaded SAM model.
is_sam2 (bool): Indicates whether the model is SAM2 variant.
task (str): The task type, set to "segment" for SAM models.
Methods:
predict: Perform segmentation prediction on the given image or video source.
info: Log information about the SAM model.
Examples:
>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
...     print(f"Detected {len(r.masks)} masks")
"""
def __init__(self, model: str = "sam_b.pt") -> None:
"""
Initialize the SAM (Segment Anything Model) instance.
Args:
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
Raises:
NotImplementedError: If the model file extension is not .pt or .pth.
Examples:
>>> sam = SAM("sam_b.pt")
>>> print(sam.is_sam2)
"""
if model and Path(model).suffix not in {".pt", ".pth"}:
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
self.is_sam2 = "sam2" in Path(model).stem
super().__init__(model=model, task="segment")
def _load(self, weights: str, task=None):
"""
Load the specified weights into the SAM model.
Args:
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
Examples:
>>> sam = SAM("sam_b.pt")
>>> sam._load("path/to/custom_weights.pt")
"""
from .build import build_sam # slow import
self.model = build_sam(weights)
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""
Perform segmentation prediction on the given image or video source.
Args:
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
a np.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
points (list[list[float]] | None): List of points for prompted segmentation.
labels (list[int] | None): List of labels for prompted segmentation.
**kwargs (Any): Additional keyword arguments for prediction.
Returns:
(list): The model predictions.
Examples:
>>> sam = SAM("sam_b.pt")
>>> results = sam.predict("image.jpg", points=[[500, 375]])
>>> for r in results:
... print(f"Detected {len(r.masks)} masks")
"""
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
kwargs = {**overrides, **kwargs}
prompts = dict(bboxes=bboxes, points=points, labels=labels)
return super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
"""
Perform segmentation prediction on the given image or video source.
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
for segmentation tasks.
Args:
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
object, or a np.ndarray object.
stream (bool): If True, enables real-time streaming.
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
points (list[list[float]] | None): List of points for prompted segmentation.
labels (list[int] | None): List of labels for prompted segmentation.
**kwargs (Any): Additional keyword arguments to be passed to the predict method.
Returns:
(list): The model predictions, typically containing segmentation masks and other relevant information.
Examples:
>>> sam = SAM("sam_b.pt")
>>> results = sam("image.jpg", points=[[500, 375]])
>>> print(f"Detected {len(results[0].masks)} masks")
"""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def info(self, detailed: bool = False, verbose: bool = True):
"""
Log information about the SAM model.
Args:
detailed (bool): If True, displays detailed information about the model layers and operations.
verbose (bool): If True, prints the information to the console.
Returns:
(tuple): A tuple containing the model's information (string representations of the model).
Examples:
>>> sam = SAM("sam_b.pt")
>>> info = sam.info()
>>> print(info[0]) # Print summary information
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@property
def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
"""
Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
Returns:
(dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
Examples:
>>> sam = SAM("sam_b.pt")
>>> task_map = sam.task_map
>>> print(task_map)
{'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
"""
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
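
# Minimal usage sketch tying the pieces together (hypothetical local image path; label 1 is assumed
# to mark a foreground point, and boxes are assumed to be [x1, y1, x2, y2] pixel coordinates):
if __name__ == "__main__":
    sam = SAM("sam2_b.pt")  # "sam2" in the stem routes task_map to SAM2Predictor
    sam.info(detailed=False, verbose=True)
    results = sam.predict("image.jpg", points=[[500, 375]], labels=[1])  # point prompt
    results = sam.predict("image.jpg", bboxes=[[100, 100, 400, 600]])  # box prompt
    print(f"Detected {len(results[0].masks)} masks")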

View File

@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

File diff suppressed because it is too large

View File

@@ -0,0 +1,513 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import torch
from torch import nn
from ultralytics.nn.modules import MLP, LayerNorm2d
class MaskDecoder(nn.Module):
"""
Decoder module for generating masks and their associated quality scores using a transformer architecture.
This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
generate mask predictions along with their quality scores.
Attributes:
transformer_dim (int): Channel dimension for the transformer module.
transformer (nn.Module): Transformer module used for mask prediction.
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
iou_token (nn.Embedding): Embedding for the IoU token.
num_mask_tokens (int): Number of mask tokens.
mask_tokens (nn.Embedding): Embedding for the mask tokens.
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
iou_prediction_head (nn.Module): MLP for predicting mask quality.
Methods:
forward: Predict masks given image and prompt embeddings.
predict_masks: Internal method for mask prediction.
Examples:
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
>>> masks, iou_pred = decoder(
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
... )
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
"""
def __init__(
self,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
) -> None:
"""
Initialize the MaskDecoder module for generating masks and their associated quality scores.
Args:
transformer_dim (int): Channel dimension for the transformer module.
transformer (nn.Module): Transformer module used for mask prediction.
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
iou_head_depth (int): Depth of the MLP used to predict mask quality.
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
Examples:
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
>>> print(decoder)
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
)
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): Embeddings from the image encoder.
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
multimask_output (bool): Whether to return multiple masks or a single mask.
Returns:
masks (torch.Tensor): Batched predicted masks.
iou_pred (torch.Tensor): Batched predictions of mask quality.
Examples:
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
>>> image_emb = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_emb = torch.rand(1, 2, 256)
>>> dense_emb = torch.rand(1, 256, 64, 64)
>>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
>>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
"""
masks, iou_pred = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
)
# Select the correct mask or masks for output
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
masks = masks[:, mask_slice, :, :]
iou_pred = iou_pred[:, mask_slice]
return masks, iou_pred
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
# Concatenate output tokens
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, 0, :]
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: list[torch.Tensor] = [
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
]
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
return masks, iou_pred
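
# Standalone sketch of the decoder, mirroring the wiring used by the SAM builder (a
# TwoWayTransformer with depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8 from this
# package's transformer module is assumed to be available):
#
#     >>> transformer = TwoWayTransformer(depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8)
#     >>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
#     >>> masks, iou_pred = decoder(
#     ...     image_embeddings=torch.rand(1, 256, 64, 64),
#     ...     image_pe=torch.rand(1, 256, 64, 64),
#     ...     sparse_prompt_embeddings=torch.rand(1, 2, 256),
#     ...     dense_prompt_embeddings=torch.rand(1, 256, 64, 64),
#     ...     multimask_output=True,
#     ... )
#     >>> masks.shape, iou_pred.shape  # (1, 3, 256, 256) and (1, 3) with the default 3 multimask outputs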
class SAM2MaskDecoder(nn.Module):
"""
Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
This class extends the functionality of the MaskDecoder, incorporating additional features such as
high-resolution feature processing, dynamic multimask output, and object score prediction.
Attributes:
transformer_dim (int): Channel dimension of the transformer.
transformer (nn.Module): Transformer used to predict masks.
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
iou_token (nn.Embedding): Embedding for IOU token.
num_mask_tokens (int): Total number of mask tokens.
mask_tokens (nn.Embedding): Embedding for mask tokens.
pred_obj_scores (bool): Whether to predict object scores.
obj_score_token (nn.Embedding): Embedding for object score token.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
output_upscaling (nn.Sequential): Upscaling layers for output.
use_high_res_features (bool): Whether to use high-resolution features.
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
iou_prediction_head (MLP): MLP for IOU prediction.
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
Methods:
forward: Predict masks given image and prompt embeddings.
predict_masks: Predict instance segmentation masks from image and prompt embeddings.
_get_stability_scores: Compute mask stability scores based on IoU between thresholds.
_dynamic_multimask_via_stability: Dynamically select the most stable mask output.
Examples:
>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = SAM2MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
... )
"""
def __init__(
self,
transformer_dim: int,
transformer: nn.Module,
num_multimask_outputs: int = 3,
activation: type[nn.Module] = nn.GELU,
iou_head_depth: int = 3,
iou_head_hidden_dim: int = 256,
use_high_res_features: bool = False,
iou_prediction_use_sigmoid=False,
dynamic_multimask_via_stability=False,
dynamic_multimask_stability_delta=0.05,
dynamic_multimask_stability_thresh=0.98,
pred_obj_scores: bool = False,
pred_obj_scores_mlp: bool = False,
use_multimask_token_for_obj_ptr: bool = False,
) -> None:
"""
Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
This decoder extends the functionality of MaskDecoder, incorporating additional features such as
high-resolution feature processing, dynamic multimask output, and object score prediction.
Args:
transformer_dim (int): Channel dimension of the transformer.
transformer (nn.Module): Transformer used to predict masks.
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
iou_head_depth (int): Depth of the MLP used to predict mask quality.
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
use_high_res_features (bool): Whether to use high-resolution features.
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
pred_obj_scores (bool): Whether to predict object scores.
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
Examples:
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
>>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
>>> print(decoder)
"""
super().__init__()
self.transformer_dim = transformer_dim
self.transformer = transformer
self.num_multimask_outputs = num_multimask_outputs
self.iou_token = nn.Embedding(1, transformer_dim)
self.num_mask_tokens = num_multimask_outputs + 1
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.pred_obj_scores = pred_obj_scores
if self.pred_obj_scores:
self.obj_score_token = nn.Embedding(1, transformer_dim)
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
self.output_upscaling = nn.Sequential(
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
LayerNorm2d(transformer_dim // 4),
activation(),
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
activation(),
)
self.use_high_res_features = use_high_res_features
if use_high_res_features:
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
self.output_hypernetworks_mlps = nn.ModuleList(
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
)
self.iou_prediction_head = MLP(
transformer_dim,
iou_head_hidden_dim,
self.num_mask_tokens,
iou_head_depth,
sigmoid=iou_prediction_use_sigmoid,
)
if self.pred_obj_scores:
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
if pred_obj_scores_mlp:
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
# When outputting a single mask, optionally we can dynamically fall back to the best
# multimask output token if the single mask output token gives low stability scores.
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Args:
image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
multimask_output (bool): Whether to return multiple masks or a single mask.
repeat_image (bool): Flag to repeat the image embeddings.
high_res_features (list[torch.Tensor] | None, optional): Optional high-resolution features.
Returns:
masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
Examples:
>>> image_embeddings = torch.rand(1, 256, 64, 64)
>>> image_pe = torch.rand(1, 256, 64, 64)
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
>>> decoder = SAM2MaskDecoder(256, transformer)
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
... )
"""
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
repeat_image=repeat_image,
high_res_features=high_res_features,
)
# Select the correct mask or masks for output
if multimask_output:
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
elif self.dynamic_multimask_via_stability and not self.training:
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
else:
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]
if multimask_output and self.use_multimask_token_for_obj_ptr:
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
else:
# Take the mask output token. Here we *always* use the token for single mask output.
# At test time, even if we track after 1-click (and using multimask_output=True),
# we still take the single mask token here. The rationale is that we always track
# after multiple clicks during training, so the past tokens seen during training
# are always the single mask token (and we'll let it be the object-memory token).
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
return masks, iou_pred, sam_tokens_out, object_score_logits
def predict_masks(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
repeat_image: bool,
high_res_features: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Predict instance segmentation masks from image and prompt embeddings using a transformer."""
# Concatenate output tokens
s = 0
if self.pred_obj_scores:
output_tokens = torch.cat(
[
self.obj_score_token.weight,
self.iou_token.weight,
self.mask_tokens.weight,
],
dim=0,
)
s = 1
else:
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
# Expand per-image data in batch direction to be per-mask
if repeat_image:
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
else:
assert image_embeddings.shape[0] == tokens.shape[0]
src = image_embeddings
src = src + dense_prompt_embeddings
assert image_pe.shape[0] == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
iou_token_out = hs[:, s, :]
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
if not self.use_high_res_features or high_res_features is None:
upscaled_embedding = self.output_upscaling(src)
else:
dc1, ln1, act1, dc2, act2 = self.output_upscaling
feat_s0, feat_s1 = high_res_features
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
hyper_in_list: list[torch.Tensor] = [
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
]
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
if self.pred_obj_scores:
assert s == 1
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
else:
# Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
return masks, iou_pred, mask_tokens_out, object_score_logits
def _get_stability_scores(self, mask_logits):
"""Compute mask stability scores based on IoU between upper and lower thresholds."""
mask_logits = mask_logits.flatten(-2)
stability_delta = self.dynamic_multimask_stability_delta
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
return torch.where(area_u > 0, area_i / area_u, 1.0)
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
"""
Dynamically select the most stable mask output based on stability scores and IoU predictions.
This method is used when outputting a single mask. If the stability score from the current single-mask
output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
(based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
for both clicking and tracking scenarios.
Args:
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
batch size, N is number of masks (typically 4), and H, W are mask dimensions.
all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
Returns:
mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
Examples:
>>> decoder = SAM2MaskDecoder(...)
>>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each
>>> all_iou_scores = torch.rand(2, 4)
>>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
>>> print(mask_logits.shape, iou_scores.shape)
torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
"""
# The best mask from multimask output tokens (1~3)
multimask_logits = all_mask_logits[:, 1:, :, :]
multimask_iou_scores = all_iou_scores[:, 1:]
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
batch_inds = torch.arange(multimask_iou_scores.shape[0], device=all_iou_scores.device)
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
best_multimask_logits = best_multimask_logits.unsqueeze(1)
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
# The mask from singlemask output token 0 and its stability score
singlemask_logits = all_mask_logits[:, 0:1, :, :]
singlemask_iou_scores = all_iou_scores[:, 0:1]
stability_scores = self._get_stability_scores(singlemask_logits)
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
# Dynamically fall back to best multimask output upon low stability scores.
mask_logits_out = torch.where(
is_stable[..., None, None].expand_as(singlemask_logits),
singlemask_logits,
best_multimask_logits,
)
iou_scores_out = torch.where(
is_stable.expand_as(singlemask_iou_scores),
singlemask_iou_scores,
best_multimask_iou_scores,
)
return mask_logits_out, iou_scores_out
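
# Numeric sketch of the stability test used above: a mask is considered stable when its area barely
# changes as the logit threshold moves by ±dynamic_multimask_stability_delta (defaults shown):
#
#     >>> logits = torch.randn(2, 1, 256, 256).flatten(-2)
#     >>> delta, thresh = 0.05, 0.98
#     >>> area_i = (logits > delta).sum(-1).float()
#     >>> area_u = (logits > -delta).sum(-1).float()
#     >>> stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))
#     >>> is_stable = stability >= thresh  # falls back to the best multimask output when False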

View File

@@ -0,0 +1,851 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.nn.modules import LayerNorm2d
from .blocks import (
Block,
CXBlock,
Fuser,
MaskDownSampler,
MultiScaleBlock,
PatchEmbed,
PositionEmbeddingRandom,
PositionEmbeddingSine,
)
class ImageEncoderViT(nn.Module):
"""
An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
This class processes images by splitting them into patches, applying transformer blocks, and generating a final
encoded representation through a neck module.
Attributes:
img_size (int): Dimension of input images, assumed to be square.
patch_embed (PatchEmbed): Module for patch embedding.
pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
neck (nn.Sequential): Neck module to further process the output.
Methods:
forward: Process input through patch embedding, positional embedding, blocks, and neck.
Examples:
>>> import torch
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
>>> input_image = torch.randn(1, 3, 224, 224)
>>> output = encoder(input_image)
>>> print(output.shape)
"""
def __init__(
self,
img_size: int = 1024,
patch_size: int = 16,
in_chans: int = 3,
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
mlp_ratio: float = 4.0,
out_chans: int = 256,
qkv_bias: bool = True,
norm_layer: type[nn.Module] = nn.LayerNorm,
act_layer: type[nn.Module] = nn.GELU,
use_abs_pos: bool = True,
use_rel_pos: bool = False,
rel_pos_zero_init: bool = True,
window_size: int = 0,
global_attn_indexes: tuple[int, ...] = (),
) -> None:
"""
Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
Args:
img_size (int): Input image size, assumed to be square.
patch_size (int): Size of image patches.
in_chans (int): Number of input image channels.
embed_dim (int): Dimension of patch embeddings.
depth (int): Number of transformer blocks.
num_heads (int): Number of attention heads in each block.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
out_chans (int): Number of output channels from the neck module.
qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
norm_layer (Type[nn.Module]): Type of normalization layer to use.
act_layer (Type[nn.Module]): Type of activation layer to use.
use_abs_pos (bool): If True, uses absolute positional embeddings.
use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
window_size (int): Size of attention window for windowed attention blocks.
global_attn_indexes (tuple[int, ...]): Indices of blocks that use global attention.
Examples:
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
>>> input_image = torch.randn(1, 3, 224, 224)
>>> output = encoder(input_image)
>>> print(output.shape)
"""
super().__init__()
self.img_size = img_size
self.patch_embed = PatchEmbed(
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size),
in_chans=in_chans,
embed_dim=embed_dim,
)
self.pos_embed: nn.Parameter | None = None
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
self.blocks = nn.ModuleList()
for i in range(depth):
block = Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
norm_layer=norm_layer,
act_layer=act_layer,
use_rel_pos=use_rel_pos,
rel_pos_zero_init=rel_pos_zero_init,
window_size=window_size if i not in global_attn_indexes else 0,
input_size=(img_size // patch_size, img_size // patch_size),
)
self.blocks.append(block)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dim,
out_chans,
kernel_size=1,
bias=False,
),
LayerNorm2d(out_chans),
nn.Conv2d(
out_chans,
out_chans,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(out_chans),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input through patch embedding, positional embedding, transformer blocks, and neck module."""
x = self.patch_embed(x)
if self.pos_embed is not None:
pos_embed = (
F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
if self.img_size != 1024
else self.pos_embed
)
x = x + pos_embed
for blk in self.blocks:
x = blk(x)
return self.neck(x.permute(0, 3, 1, 2))
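
# SAM-style configuration sketch (the ViT-B preset used by the SAM builder is assumed: depth 12,
# 768-dim embeddings, 12 heads, windowed attention of size 14 with four global-attention blocks):
#
#     >>> encoder = ImageEncoderViT(
#     ...     depth=12, embed_dim=768, num_heads=12, img_size=1024, patch_size=16, out_chans=256,
#     ...     use_rel_pos=True, window_size=14, global_attn_indexes=(2, 5, 8, 11),
#     ... )
#     >>> encoder(torch.randn(1, 3, 1024, 1024)).shape  # torch.Size([1, 256, 64, 64])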
class PromptEncoder(nn.Module):
"""
Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
Attributes:
embed_dim (int): Dimension of the embeddings.
input_image_size (tuple[int, int]): Size of the input image as (H, W).
image_embedding_size (tuple[int, int]): Spatial size of the image embedding as (H, W).
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
num_point_embeddings (int): Number of point embeddings for different types of points.
point_embeddings (nn.ModuleList): List of point embeddings.
not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
mask_input_size (tuple[int, int]): Size of the input mask.
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
Methods:
get_dense_pe: Return the positional encoding used to encode point prompts.
forward: Embed different types of prompts, returning both sparse and dense embeddings.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
>>> boxes = torch.rand(1, 2, 2)
>>> masks = torch.rand(1, 1, 256, 256)
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
def __init__(
self,
embed_dim: int,
image_embedding_size: tuple[int, int],
input_image_size: tuple[int, int],
mask_in_chans: int,
activation: type[nn.Module] = nn.GELU,
) -> None:
"""
Initialize the PromptEncoder module for encoding various types of prompts.
Args:
embed_dim (int): The dimension of the embeddings.
image_embedding_size (tuple[int, int]): The spatial size of the image embedding as (H, W).
input_image_size (tuple[int, int]): The padded size of the input image as (H, W).
mask_in_chans (int): The number of hidden channels used for encoding input masks.
activation (Type[nn.Module]): The activation function to use when encoding input masks.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
>>> boxes = torch.rand(1, 2, 2)
>>> masks = torch.rand(1, 1, 256, 256)
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
super().__init__()
self.embed_dim = embed_dim
self.input_image_size = input_image_size
self.image_embedding_size = image_embedding_size
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
activation(),
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans),
activation(),
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
)
self.no_mask_embed = nn.Embedding(1, embed_dim)
def get_dense_pe(self) -> torch.Tensor:
"""
Return the dense positional encoding used for encoding point prompts.
Generate a positional encoding for a dense set of points matching the shape of the image
encoding. The encoding is used to provide spatial information to the model when processing point prompts.
Returns:
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
height and width of the image embedding size, respectively.
Examples:
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> dense_pe = prompt_encoder.get_dense_pe()
>>> print(dense_pe.shape)
torch.Size([1, 256, 64, 64])
"""
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
"""Embed point prompts by applying positional encoding and label-specific embeddings."""
points = points + 0.5 # Shift to center of pixel
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), dtype=points.dtype, device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
point_embedding[labels == 2] += self.point_embeddings[2].weight
point_embedding[labels == 3] += self.point_embeddings[3].weight
return point_embedding
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embed box prompts by applying positional encoding and adding corner embeddings."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
"""Embed mask inputs by downscaling and processing through convolutional layers."""
return self.mask_downscaling(masks)
@staticmethod
def _get_batch_size(
points: tuple[torch.Tensor, torch.Tensor] | None,
boxes: torch.Tensor | None,
masks: torch.Tensor | None,
) -> int:
"""Get the batch size of the output given the batch size of the input prompts."""
if points is not None:
return points[0].shape[0]
elif boxes is not None:
return boxes.shape[0]
elif masks is not None:
return masks.shape[0]
else:
return 1
def forward(
self,
points: tuple[torch.Tensor, torch.Tensor] | None,
boxes: torch.Tensor | None,
masks: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Embed different types of prompts, returning both sparse and dense embeddings.
Args:
points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
shape (B, N).
boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
Returns:
sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
Examples:
>>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
>>> boxes = torch.rand(1, 1, 2, 2)
>>> masks = torch.rand(1, 1, 256, 256)
>>> sparse_emb, dense_emb = encoder(points, boxes, masks)
>>> print(sparse_emb.shape, dense_emb.shape)
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
"""
bs = self._get_batch_size(points, boxes, masks)
sparse_embeddings = torch.empty(
(bs, 0, self.embed_dim),
dtype=self.point_embeddings[0].weight.dtype,
device=self.point_embeddings[0].weight.device,
)
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
if boxes is not None:
box_embeddings = self._embed_boxes(boxes)
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
if masks is not None:
dense_embeddings = self._embed_masks(masks)
else:
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)
return sparse_embeddings, dense_embeddings
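
# Box-only sketch (no points, no masks): each box contributes two corner tokens to the sparse
# output, and the dense output falls back to the broadcast no_mask_embed:
#
#     >>> pe = PromptEncoder(256, (64, 64), (1024, 1024), 16)
#     >>> sparse, dense = pe(points=None, boxes=torch.rand(2, 1, 2, 2) * 1024, masks=None)
#     >>> sparse.shape, dense.shape  # torch.Size([2, 2, 256]), torch.Size([2, 256, 64, 64])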
class MemoryEncoder(nn.Module):
"""
Encode pixel features and masks into a memory representation for efficient image segmentation.
This class processes pixel-level features and masks, fusing them to generate encoded memory representations
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
Attributes:
mask_downsampler (MaskDownSampler): Module for downsampling input masks.
pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
fuser (Fuser): Module for fusing pixel features and masks.
position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
Methods:
forward: Process input pixel features and masks to generate encoded memory representations.
Examples:
>>> import torch
>>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
>>> pix_feat = torch.randn(1, 256, 64, 64)
>>> masks = torch.randn(1, 1, 64, 64)
>>> encoded_feat, pos = encoder(pix_feat, masks)
>>> print(encoded_feat.shape, pos.shape)
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
"""
def __init__(
self,
out_dim,
in_dim=256, # in_dim of pix_feats
):
"""
Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
Args:
out_dim (int): Output dimension of the encoded features.
in_dim (int): Input dimension of the pixel features.
Examples:
>>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
>>> pix_feat = torch.randn(1, 256, 64, 64)
>>> masks = torch.randn(1, 1, 64, 64)
>>> encoded_feat, pos = encoder(pix_feat, masks)
>>> print(encoded_feat.shape, pos.shape)
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
"""
super().__init__()
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
self.out_proj = nn.Identity()
if out_dim != in_dim:
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
def forward(
self,
pix_feat: torch.Tensor,
masks: torch.Tensor,
skip_mask_sigmoid: bool = False,
) -> dict:
"""Process pixel features and masks to generate encoded memory representations for segmentation."""
if not skip_mask_sigmoid:
masks = F.sigmoid(masks)
masks = self.mask_downsampler(masks)
# Fuse pix_feat with the downsampled masks; if the visual features are on CPU, cast them to the masks' device
pix_feat = pix_feat.to(masks.device)
x = self.pix_feat_proj(pix_feat)
x = x + masks
x = self.fuser(x)
x = self.out_proj(x)
pos = self.position_encoding(x).to(x.dtype)
return {"vision_features": x, "vision_pos_enc": [pos]}
class ImageEncoder(nn.Module):
"""
Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
This class combines a trunk network for feature extraction with a neck network for feature refinement
and positional encoding generation. It can optionally discard the lowest resolution features.
Attributes:
trunk (nn.Module): The trunk network for initial feature extraction.
neck (nn.Module): The neck network for feature refinement and positional encoding generation.
scalp (int): Number of lowest resolution feature levels to discard.
Methods:
forward: Process the input image through the trunk and neck networks.
Examples:
>>> trunk = SomeTrunkNetwork()
>>> neck = SomeNeckNetwork()
>>> encoder = ImageEncoder(trunk, neck, scalp=1)
>>> image = torch.randn(1, 3, 224, 224)
>>> output = encoder(image)
>>> print(output.keys())
dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
"""
def __init__(
self,
trunk: nn.Module,
neck: nn.Module,
scalp: int = 0,
):
"""
Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
This encoder combines a trunk network for feature extraction with a neck network for feature refinement
and positional encoding generation. It can optionally discard the lowest resolution features.
Args:
trunk (nn.Module): The trunk network for initial feature extraction.
neck (nn.Module): The neck network for feature refinement and positional encoding generation.
scalp (int): Number of lowest resolution feature levels to discard.
Examples:
>>> trunk = SomeTrunkNetwork()
>>> neck = SomeNeckNetwork()
>>> encoder = ImageEncoder(trunk, neck, scalp=1)
>>> image = torch.randn(1, 3, 224, 224)
>>> output = encoder(image)
>>> print(output.keys())
dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
"""
super().__init__()
self.trunk = trunk
self.neck = neck
self.scalp = scalp
assert self.trunk.channel_list == self.neck.backbone_channel_list, (
f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
)
def forward(self, sample: torch.Tensor):
"""Encode input through trunk and neck networks, returning multiscale features and positional encodings."""
features, pos = self.neck(self.trunk(sample))
if self.scalp > 0:
# Discard the lowest resolution features
features, pos = features[: -self.scalp], pos[: -self.scalp]
src = features[-1]
return {
"vision_features": src,
"vision_pos_enc": pos,
"backbone_fpn": features,
}
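
# Trunk/neck wiring sketch mirroring the SAM2 builder (a small Hiera is assumed; its per-stage
# channel dims, reversed, must equal the neck's backbone_channel_list or the assert above fires):
#
#     >>> trunk = Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))
#     >>> neck = FpnNeck(d_model=256, backbone_channel_list=[768, 384, 192, 96],
#     ...                fpn_top_down_levels=[2, 3], fpn_interp_model="nearest")
#     >>> encoder = ImageEncoder(trunk=trunk, neck=neck, scalp=1)
#     >>> out = encoder(torch.randn(1, 3, 1024, 1024))
#     >>> [f.shape[1] for f in out["backbone_fpn"]]  # [256, 256, 256] after dropping the lowest level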
class FpnNeck(nn.Module):
"""
A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
similar to ViT positional embedding interpolation.
Attributes:
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
convs (nn.ModuleList): List of convolutional layers for each backbone level.
backbone_channel_list (list[int]): List of channel dimensions from the backbone.
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
fpn_top_down_levels (list[int]): Levels to have top-down features in outputs.
Methods:
forward: Perform forward pass through the FPN neck.
Examples:
>>> backbone_channels = [64, 128, 256, 512]
>>> fpn_neck = FpnNeck(256, backbone_channels)
>>> inputs = [torch.rand(1, c, s, s) for c, s in zip(backbone_channels[::-1], [64, 32, 16, 8])]
>>> outputs, positions = fpn_neck(inputs)
>>> print(len(outputs), len(positions))
4 4
"""
def __init__(
self,
d_model: int,
backbone_channel_list: list[int],
kernel_size: int = 1,
stride: int = 1,
padding: int = 0,
fpn_interp_model: str = "bilinear",
fuse_type: str = "sum",
fpn_top_down_levels: list[int] | None = None,
):
"""
Initialize a modified Feature Pyramid Network (FPN) neck.
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
similar to ViT positional embedding interpolation.
Args:
d_model (int): Dimension of the model.
backbone_channel_list (list[int]): List of channel dimensions from the backbone.
kernel_size (int): Kernel size for the convolutional layers.
stride (int): Stride for the convolutional layers.
padding (int): Padding for the convolutional layers.
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
fpn_top_down_levels (Optional[list[int]]): Levels to have top-down features in outputs.
Examples:
>>> backbone_channels = [64, 128, 256, 512]
>>> fpn_neck = FpnNeck(256, backbone_channels)
>>> print(fpn_neck)
"""
super().__init__()
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
self.convs = nn.ModuleList()
self.backbone_channel_list = backbone_channel_list
for dim in backbone_channel_list:
current = nn.Sequential()
current.add_module(
"conv",
nn.Conv2d(
in_channels=dim,
out_channels=d_model,
kernel_size=kernel_size,
stride=stride,
padding=padding,
),
)
self.convs.append(current)
self.fpn_interp_model = fpn_interp_model
assert fuse_type in {"sum", "avg"}
self.fuse_type = fuse_type
# Levels to have top-down features in its outputs
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
# have top-down propagation, while outputs of level 0 and level 1 have only
# lateral features from the same backbone level
if fpn_top_down_levels is None:
# Default is to have top-down features on all levels
fpn_top_down_levels = range(len(self.convs))
self.fpn_top_down_levels = list(fpn_top_down_levels)
def forward(self, xs: list[torch.Tensor]):
"""
Perform forward pass through the Feature Pyramid Network (FPN) neck.
This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
Args:
xs (list[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
Returns:
out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape
(B, d_model, H, W).
pos (list[torch.Tensor]): List of positional encodings corresponding to each output feature map.
Examples:
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
>>> inputs = [torch.rand(1, c, s, s) for c, s in zip([512, 256, 128, 64], [64, 32, 16, 8])]
>>> outputs, positions = fpn_neck(inputs)
>>> print(len(outputs), len(positions))
4 4
"""
out = [None] * len(self.convs)
pos = [None] * len(self.convs)
assert len(xs) == len(self.convs)
# FPN forward pass
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
prev_features = None
# Forward in top-down order (from low to high resolution)
n = len(self.convs) - 1
for i in range(n, -1, -1):
x = xs[i]
lateral_features = self.convs[n - i](x)
if i in self.fpn_top_down_levels and prev_features is not None:
top_down_features = F.interpolate(
prev_features.to(dtype=x.dtype),
scale_factor=2.0,
mode=self.fpn_interp_model,
align_corners=(None if self.fpn_interp_model == "nearest" else False),
antialias=False,
)
prev_features = lateral_features + top_down_features
if self.fuse_type == "avg":
prev_features /= 2
else:
prev_features = lateral_features
x_out = prev_features
out[i] = x_out
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
return out, pos
class Hiera(nn.Module):
"""
Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
with optional pooling and global attention mechanisms.
Attributes:
window_spec (tuple[int, ...]): Window sizes for each stage.
q_stride (tuple[int, int]): Downsampling stride between stages.
stage_ends (list[int]): Indices of the last block in each stage.
q_pool_blocks (list[int]): Indices of blocks where pooling is applied.
return_interm_layers (bool): Whether to return intermediate layer outputs.
patch_embed (PatchEmbed): Module for patch embedding.
global_att_blocks (tuple[int, ...]): Indices of blocks with global attention.
window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding background.
pos_embed (nn.Parameter): Positional embedding for the background.
pos_embed_window (nn.Parameter): Positional embedding for the window.
blocks (nn.ModuleList): List of MultiScaleBlock modules.
channel_list (list[int]): List of output channel dimensions for each stage.
Methods:
_get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.
forward: Perform the forward pass through the Hiera model.
Examples:
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> output_features = model(input_tensor)
>>> for feat in output_features:
... print(feat.shape)
"""
def __init__(
self,
embed_dim: int = 96, # initial embed dim
num_heads: int = 1, # initial number of heads
drop_path_rate: float = 0.0, # stochastic depth
q_pool: int = 3, # number of q_pool stages
q_stride: tuple[int, int] = (2, 2), # downsample stride bet. stages
stages: tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
dim_mul: float = 2.0, # dim_mul factor at stage shift
head_mul: float = 2.0, # head_mul factor at stage shift
window_pos_embed_bkg_spatial_size: tuple[int, int] = (14, 14),
# window size per stage, when not using global att.
window_spec: tuple[int, ...] = (
8,
4,
14,
7,
),
# global attn in these blocks
global_att_blocks: tuple[int, ...] = (
12,
16,
20,
),
return_interm_layers=True, # return feats from every stage
):
"""
Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction
in image processing tasks. It uses a series of transformer blocks organized into stages, with optional
pooling and global attention mechanisms.
Args:
embed_dim (int): Initial embedding dimension for the model.
num_heads (int): Initial number of attention heads.
drop_path_rate (float): Stochastic depth rate.
q_pool (int): Number of query pooling stages.
q_stride (tuple[int, int]): Downsampling stride between stages.
stages (tuple[int, ...]): Number of blocks per stage.
dim_mul (float): Dimension multiplier factor at stage transitions.
head_mul (float): Head multiplier factor at stage transitions.
window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding background.
window_spec (tuple[int, ...]): Window sizes for each stage when not using global attention.
global_att_blocks (tuple[int, ...]): Indices of blocks that use global attention.
return_interm_layers (bool): Whether to return intermediate layer outputs.
Examples:
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> output_features = model(input_tensor)
>>> for feat in output_features:
... print(feat.shape)
"""
super().__init__()
assert len(stages) == len(window_spec)
self.window_spec = window_spec
depth = sum(stages)
self.q_stride = q_stride
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
assert 0 <= q_pool <= len(self.stage_ends[:-1])
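# Query pooling is applied at the first block of each new stage, limited to the first q_pool stage transitions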
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
self.return_interm_layers = return_interm_layers
self.patch_embed = PatchEmbed(
embed_dim=embed_dim,
kernel_size=(7, 7),
stride=(4, 4),
padding=(3, 3),
)
# Which blocks have global attention?
self.global_att_blocks = global_att_blocks
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
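# Learnable background positional embedding at a fixed spatial size, plus a smaller per-window embedding tiled over it in _get_pos_embed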
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
cur_stage = 1
self.blocks = nn.ModuleList()
for i in range(depth):
dim_out = embed_dim
# Lags by a block, so first block of next stage uses an initial window size
# of previous stage and final window size of current stage
window_size = self.window_spec[cur_stage - 1]
if self.global_att_blocks is not None:
window_size = 0 if i in self.global_att_blocks else window_size
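# At the first block of each new stage, grow the channel dimension and head count by dim_mul / head_mul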
if i - 1 in self.stage_ends:
dim_out = int(embed_dim * dim_mul)
num_heads = int(num_heads * head_mul)
cur_stage += 1
block = MultiScaleBlock(
dim=embed_dim,
dim_out=dim_out,
num_heads=num_heads,
drop_path=dpr[i],
q_stride=self.q_stride if i in self.q_pool_blocks else None,
window_size=window_size,
)
embed_dim = dim_out
self.blocks.append(block)
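# Per-stage output channels, ordered from the deepest stage to the shallowest when intermediate features are returned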
self.channel_list = (
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
if return_interm_layers
else [self.blocks[-1].dim_out]
)
def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
"""Generate positional embeddings by interpolating and combining window and background embeddings."""
h, w = hw
window_embed = self.pos_embed_window
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
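# Tile the window embedding over the interpolated background embedding (assumes the feature map divides evenly into windows)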
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
pos_embed = pos_embed.permute(0, 2, 3, 1)
return pos_embed
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
"""
Perform forward pass through Hiera model, extracting multiscale features from input images.
Args:
x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.
Returns:
(list[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where
C_i is the channel dimension and H_i, W_i are the spatial dimensions at scale i. The list is ordered
from highest resolution (fine features) to lowest resolution (coarse features) if return_interm_layers
is True, otherwise contains only the final output.
Examples:
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> output_features = model(input_tensor)
>>> for feat in output_features:
... print(feat.shape)
"""
x = self.patch_embed(x)
# x: (B, H, W, C)
# Add positional embedding
x = x + self._get_pos_embed(x.shape[1:3])
outputs = []
for i, blk in enumerate(self.blocks):
x = blk(x)
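# Collect channels-first (B, C, H, W) features at the end of each stage, or only after the final block if return_interm_layers is False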
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
feats = x.permute(0, 3, 1, 2)
outputs.append(feats)
return outputs

View File

@@ -0,0 +1,312 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import copy
import torch
from torch import nn
from .blocks import RoPEAttention
class MemoryAttentionLayer(nn.Module):
"""
Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
This class combines self-attention, cross-attention, and feedforward components to process input tensors and
generate memory-based attention outputs.
Attributes:
d_model (int): Dimensionality of the model.
dim_feedforward (int): Dimensionality of the feedforward network.
dropout_value (float): Dropout rate for regularization.
self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
linear1 (nn.Linear): First linear layer of the feedforward network.
linear2 (nn.Linear): Second linear layer of the feedforward network.
norm1 (nn.LayerNorm): Layer normalization for self-attention output.
norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
dropout1 (nn.Dropout): Dropout layer after self-attention.
dropout2 (nn.Dropout): Dropout layer after cross-attention.
dropout3 (nn.Dropout): Dropout layer after feedforward network.
activation (nn.ReLU): Activation function for the feedforward network.
pos_enc_at_attn (bool): Flag to add positional encoding at attention.
pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
Methods:
forward: Performs the full memory attention operation on input tensors.
_forward_sa: Performs self-attention on input tensor.
_forward_ca: Performs cross-attention between target and memory tensors.
Examples:
>>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
>>> tgt = torch.randn(1, 100, 256)
>>> memory = torch.randn(1, 100, 64)
>>> pos = torch.randn(1, 100, 256)
>>> query_pos = torch.randn(1, 100, 256)
>>> output = layer(tgt, memory, pos, query_pos)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
def __init__(
self,
d_model: int = 256,
dim_feedforward: int = 2048,
dropout: float = 0.1,
pos_enc_at_attn: bool = False,
pos_enc_at_cross_attn_keys: bool = True,
pos_enc_at_cross_attn_queries: bool = False,
):
"""
Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
Args:
d_model (int): Dimensionality of the model.
dim_feedforward (int): Dimensionality of the feedforward network.
dropout (float): Dropout rate for regularization.
pos_enc_at_attn (bool): Whether to add positional encoding at attention.
pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
"""
super().__init__()
self.d_model = d_model
self.dim_feedforward = dim_feedforward
self.dropout_value = dropout
self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
self.cross_attn_image = RoPEAttention(
rope_k_repeat=True,
embedding_dim=256,
num_heads=1,
downsample_rate=1,
kv_in_dim=64,
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = nn.ReLU()
# Where to add pos enc
self.pos_enc_at_attn = pos_enc_at_attn
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
def _forward_sa(self, tgt: torch.Tensor, query_pos: torch.Tensor | None) -> torch.Tensor:
"""Perform self-attention on input tensor using positional encoding and RoPE attention mechanism."""
tgt2 = self.norm1(tgt)
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
tgt2 = self.self_attn(q, k, v=tgt2)
tgt = tgt + self.dropout1(tgt2)
return tgt
def _forward_ca(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
query_pos: torch.Tensor | None,
pos: torch.Tensor | None,
num_k_exclude_rope: int = 0,
) -> torch.Tensor:
"""Perform cross-attention between target and memory tensors using RoPEAttention mechanism."""
kwds = {}
if num_k_exclude_rope > 0:
assert isinstance(self.cross_attn_image, RoPEAttention)
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
# Cross-Attention
tgt2 = self.norm2(tgt)
tgt2 = self.cross_attn_image(
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
v=memory,
**kwds,
)
tgt = tgt + self.dropout2(tgt2)
return tgt
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
pos: torch.Tensor | None = None,
query_pos: torch.Tensor | None = None,
num_k_exclude_rope: int = 0,
) -> torch.Tensor:
"""
Process input tensors through self-attention, cross-attention, and feedforward network layers.
Args:
tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
memory (torch.Tensor): Memory tensor for cross-attention with shape (N, S, D).
pos (Optional[torch.Tensor]): Positional encoding for memory tensor.
query_pos (Optional[torch.Tensor]): Positional encoding for target tensor.
num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.
Returns:
(torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).
"""
tgt = self._forward_sa(tgt, query_pos)
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
# MLP
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
class MemoryAttention(nn.Module):
"""
Memory attention module for processing sequential data with self and cross-attention mechanisms.
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
for processing sequential data, particularly useful in transformer-like architectures.
Attributes:
d_model (int): The dimension of the model's hidden state.
layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
num_layers (int): The number of attention layers.
norm (nn.LayerNorm): Layer normalization applied to the output.
pos_enc_at_input (bool): Whether to apply positional encoding at the input.
batch_first (bool): Whether the input tensors are in batch-first format.
Methods:
forward: Processes input tensors through the attention layers.
Examples:
>>> d_model = 256
>>> layer = MemoryAttentionLayer(d_model)
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
>>> curr_pos = torch.randn(10, 32, d_model)
>>> memory_pos = torch.randn(20, 32, d_model)
>>> output = attention(curr, memory, curr_pos, memory_pos)
>>> print(output.shape)
torch.Size([10, 32, 256])
"""
def __init__(
self,
d_model: int,
pos_enc_at_input: bool,
layer: nn.Module,
num_layers: int,
batch_first: bool = True, # Do layers expect batch first input?
):
"""
Initialize MemoryAttention with specified layers and normalization for sequential data processing.
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
for processing sequential data, particularly useful in transformer-like architectures.
Args:
d_model (int): The dimension of the model's hidden state.
pos_enc_at_input (bool): Whether to apply positional encoding at the input.
layer (nn.Module): The attention layer to be used in the module.
num_layers (int): The number of attention layers.
batch_first (bool): Whether the input tensors are in batch-first format.
Examples:
>>> d_model = 256
>>> layer = MemoryAttentionLayer(d_model)
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
>>> curr_pos = torch.randn(10, 32, d_model)
>>> memory_pos = torch.randn(20, 32, d_model)
>>> output = attention(curr, memory, curr_pos, memory_pos)
>>> print(output.shape)
torch.Size([10, 32, 256])
"""
super().__init__()
self.d_model = d_model
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
self.num_layers = num_layers
self.norm = nn.LayerNorm(d_model)
self.pos_enc_at_input = pos_enc_at_input
self.batch_first = batch_first
def forward(
self,
curr: torch.Tensor, # self-attention inputs
memory: torch.Tensor, # cross-attention inputs
curr_pos: torch.Tensor | None = None, # pos_enc for self-attention inputs
memory_pos: torch.Tensor | None = None, # pos_enc for cross-attention inputs
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
) -> torch.Tensor:
"""
Process inputs through attention layers, applying self and cross-attention with positional encoding.
Args:
curr (torch.Tensor): Self-attention input tensor, representing the current state.
memory (torch.Tensor): Cross-attention input tensor, representing memory information.
curr_pos (Optional[torch.Tensor]): Positional encoding for self-attention inputs.
memory_pos (Optional[torch.Tensor]): Positional encoding for cross-attention inputs.
num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.
Returns:
(torch.Tensor): Processed output tensor after applying attention layers and normalization.
Examples:
>>> d_model = 256
>>> layer = MemoryAttentionLayer(d_model)
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
>>> curr_pos = torch.randn(10, 32, d_model)
>>> memory_pos = torch.randn(20, 32, d_model)
>>> output = attention(curr, memory, curr_pos, memory_pos)
>>> print(output.shape)
torch.Size([10, 32, 256])
"""
if isinstance(curr, list):
assert isinstance(curr_pos, list)
assert len(curr) == len(curr_pos) == 1
curr, curr_pos = curr[0], curr_pos[0]
assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
output = curr
if self.pos_enc_at_input and curr_pos is not None:
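# Inject a down-weighted (0.1x) copy of the positional encoding into the input before the attention layers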
output = output + 0.1 * curr_pos
if self.batch_first:
# Convert to batch first
output = output.transpose(0, 1)
curr_pos = curr_pos.transpose(0, 1)
memory = memory.transpose(0, 1)
memory_pos = memory_pos.transpose(0, 1)
for layer in self.layers:
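# Object pointer tokens appended to the memory are excluded from rotary position embedding in cross-attention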
kwds = {}
if isinstance(layer.cross_attn_image, RoPEAttention):
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
output = layer(
tgt=output,
memory=memory,
pos=memory_pos,
query_pos=curr_pos,
**kwds,
)
normed_output = self.norm(output)
if self.batch_first:
# Convert back to seq first
normed_output = normed_output.transpose(0, 1)
curr_pos = curr_pos.transpose(0, 1)
return normed_output

File diff suppressed because it is too large

View File

@@ -0,0 +1,998 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
# --------------------------------------------------------
# TinyViT Model Architecture
# Copyright (c) 2022 Microsoft
# Adapted from LeViT and Swin Transformer
# LeViT: (https://github.com/facebookresearch/levit)
# Swin: (https://github.com/microsoft/swin-transformer)
# Build the TinyViT Model
# --------------------------------------------------------
from __future__ import annotations
import itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.nn.modules import LayerNorm2d
from ultralytics.utils.instance import to_2tuple
class Conv2d_BN(torch.nn.Sequential):
"""
A sequential container that performs 2D convolution followed by batch normalization.
This module combines a 2D convolution layer with batch normalization, providing a common building block
for convolutional neural networks. The batch normalization weights and biases are initialized to specific
values for optimal training performance.
Attributes:
c (torch.nn.Conv2d): 2D convolution layer.
bn (torch.nn.BatchNorm2d): Batch normalization layer.
Examples:
>>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
>>> input_tensor = torch.randn(1, 3, 224, 224)
>>> output = conv_bn(input_tensor)
>>> print(output.shape)
torch.Size([1, 64, 224, 224])
"""
def __init__(
self,
a: int,
b: int,
ks: int = 1,
stride: int = 1,
pad: int = 0,
dilation: int = 1,
groups: int = 1,
bn_weight_init: float = 1,
):
"""
Initialize a sequential container with 2D convolution followed by batch normalization.
Args:
a (int): Number of input channels.
b (int): Number of output channels.
ks (int, optional): Kernel size for the convolution.
stride (int, optional): Stride for the convolution.
pad (int, optional): Padding for the convolution.
dilation (int, optional): Dilation factor for the convolution.
groups (int, optional): Number of groups for the convolution.
bn_weight_init (float, optional): Initial value for batch normalization weight.
"""
super().__init__()
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = torch.nn.BatchNorm2d(b)
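# Initialize BN scale to bn_weight_init and bias to zero; a zero initial scale lets residual branches start as identity mappings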
torch.nn.init.constant_(bn.weight, bn_weight_init)
torch.nn.init.constant_(bn.bias, 0)
self.add_module("bn", bn)
class PatchEmbed(nn.Module):
"""
Embed images into patches and project them into a specified embedding dimension.
This module converts input images into patch embeddings using a sequence of convolutional layers,
effectively downsampling the spatial dimensions while increasing the channel dimension.
Attributes:
patches_resolution (tuple[int, int]): Resolution of the patches after embedding.
num_patches (int): Total number of patches.
in_chans (int): Number of input channels.
embed_dim (int): Dimension of the embedding.
seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
Examples:
>>> import torch
>>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
>>> x = torch.randn(1, 3, 224, 224)
>>> output = patch_embed(x)
>>> print(output.shape)
torch.Size([1, 96, 56, 56])
"""
def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
"""
Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
Args:
in_chans (int): Number of input channels.
embed_dim (int): Dimension of the embedding.
resolution (int): Input image resolution.
activation (nn.Module): Activation function to use between convolutions.
"""
super().__init__()
img_size: tuple[int, int] = to_2tuple(resolution)
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
n = embed_dim
self.seq = nn.Sequential(
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
activation(),
Conv2d_BN(n // 2, n, 3, 2, 1),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor through patch embedding sequence, converting images to patch embeddings."""
return self.seq(x)
class MBConv(nn.Module):
"""
Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,
and projection phases, along with residual connections for improved gradient flow.
Attributes:
in_chans (int): Number of input channels.
hidden_chans (int): Number of hidden channels after expansion.
out_chans (int): Number of output channels.
conv1 (Conv2d_BN): First convolutional layer for channel expansion.
act1 (nn.Module): First activation function.
conv2 (Conv2d_BN): Depthwise convolutional layer.
act2 (nn.Module): Second activation function.
conv3 (Conv2d_BN): Final convolutional layer for projection.
act3 (nn.Module): Third activation function.
drop_path (nn.Module): Drop path layer (Identity for inference).
Examples:
>>> in_chans, out_chans = 32, 64
>>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
>>> x = torch.randn(1, in_chans, 56, 56)
>>> output = mbconv(x)
>>> print(output.shape)
torch.Size([1, 64, 56, 56])
"""
def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
"""
Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
Args:
in_chans (int): Number of input channels.
out_chans (int): Number of output channels.
expand_ratio (float): Channel expansion ratio for the hidden layer.
activation (nn.Module): Activation function to use.
drop_path (float): Drop path rate for stochastic depth.
"""
super().__init__()
self.in_chans = in_chans
self.hidden_chans = int(in_chans * expand_ratio)
self.out_chans = out_chans
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
self.act1 = activation()
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
self.act2 = activation()
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
self.act3 = activation()
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Implement the forward pass of MBConv, applying convolutions and skip connection."""
shortcut = x
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop_path(x)
x += shortcut
return self.act3(x)
class PatchMerging(nn.Module):
"""
Merge neighboring patches in the feature map and project to a new dimension.
This class implements a patch merging operation that combines spatial information and adjusts the feature
dimension using a series of convolutional layers with batch normalization. It effectively reduces spatial
resolution while potentially increasing channel dimensions.
Attributes:
input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
dim (int): The input dimension of the feature map.
out_dim (int): The output dimension after merging and projection.
act (nn.Module): The activation function used between convolutions.
conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
conv3 (Conv2d_BN): The third convolutional layer for final projection.
Examples:
>>> input_resolution = (56, 56)
>>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
>>> x = torch.randn(4, 64, 56, 56)
>>> output = patch_merging(x)
>>> print(output.shape)
torch.Size([4, 784, 128])
"""
def __init__(self, input_resolution: tuple[int, int], dim: int, out_dim: int, activation):
"""
Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
Args:
input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
dim (int): The input dimension of the feature map.
out_dim (int): The output dimension after merging and projection.
activation (nn.Module): The activation function used between convolutions.
"""
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.out_dim = out_dim
self.act = activation()
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
stride_c = 1 if out_dim in {320, 448, 576} else 2
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply patch merging and dimension projection to the input feature map."""
if x.ndim == 3:
H, W = self.input_resolution
B = len(x)
# (B, C, H, W)
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
x = self.conv1(x)
x = self.act(x)
x = self.conv2(x)
x = self.act(x)
x = self.conv3(x)
return x.flatten(2).transpose(1, 2)
class ConvLayer(nn.Module):
"""
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
This layer optionally applies downsample operations to the output and supports gradient checkpointing
for memory efficiency during training.
Attributes:
dim (int): Dimensionality of the input and output.
input_resolution (tuple[int, int]): Resolution of the input image.
depth (int): Number of MBConv layers in the block.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
blocks (nn.ModuleList): List of MBConv layers.
downsample (Optional[nn.Module]): Function for downsampling the output.
Examples:
>>> input_tensor = torch.randn(1, 64, 56, 56)
>>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU)
>>> output = conv_layer(input_tensor)
>>> print(output.shape)
torch.Size([1, 64, 56, 56])
"""
def __init__(
self,
dim: int,
input_resolution: tuple[int, int],
depth: int,
activation,
drop_path: float | list[float] = 0.0,
downsample: nn.Module | None = None,
use_checkpoint: bool = False,
out_dim: int | None = None,
conv_expand_ratio: float = 4.0,
):
"""
Initialize the ConvLayer with the given dimensions and settings.
This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
optionally applies downsampling to the output.
Args:
dim (int): The dimensionality of the input and output.
input_resolution (tuple[int, int]): The resolution of the input image.
depth (int): The number of MBConv layers in the block.
activation (nn.Module): Activation function applied after each convolution.
drop_path (float | list[float], optional): Drop path rate. Single float or a list of floats for each MBConv.
downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.
use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
out_dim (Optional[int], optional): The dimensionality of the output. None means it will be the same as `dim`.
conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# Build blocks
self.blocks = nn.ModuleList(
[
MBConv(
dim,
dim,
conv_expand_ratio,
activation,
drop_path[i] if isinstance(drop_path, list) else drop_path,
)
for i in range(depth)
]
)
# Patch merging layer
self.downsample = (
None
if downsample is None
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input through convolutional layers, applying MBConv blocks and optional downsampling."""
for blk in self.blocks:
x = torch.utils.checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # warn: checkpoint is slow import
return x if self.downsample is None else self.downsample(x)
class MLP(nn.Module):
"""
Multi-layer Perceptron (MLP) module for transformer architectures.
This module applies layer normalization, two fully-connected layers with an activation function in between,
and dropout. It is commonly used in transformer-based architectures for processing token embeddings.
Attributes:
norm (nn.LayerNorm): Layer normalization applied to the input.
fc1 (nn.Linear): First fully-connected layer.
fc2 (nn.Linear): Second fully-connected layer.
act (nn.Module): Activation function applied after the first fully-connected layer.
drop (nn.Dropout): Dropout layer applied after the activation function.
Examples:
>>> import torch
>>> from torch import nn
>>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
>>> x = torch.randn(32, 100, 256)
>>> output = mlp(x)
>>> print(output.shape)
torch.Size([32, 100, 256])
"""
def __init__(
self,
in_features: int,
hidden_features: int | None = None,
out_features: int | None = None,
activation=nn.GELU,
drop: float = 0.0,
):
"""
Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
Args:
in_features (int): Number of input features.
hidden_features (Optional[int], optional): Number of hidden features.
out_features (Optional[int], optional): Number of output features.
activation (nn.Module): Activation function applied after the first fully-connected layer.
drop (float, optional): Dropout probability.
"""
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.norm = nn.LayerNorm(in_features)
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.act = activation()
self.drop = nn.Dropout(drop)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
x = self.norm(x)
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return self.drop(x)
class Attention(torch.nn.Module):
"""
Multi-head attention module with spatial awareness and trainable attention biases.
This module implements a multi-head attention mechanism with support for spatial awareness, applying
attention biases based on spatial resolution. It includes trainable attention biases for each unique
offset between spatial positions in the resolution grid.
Attributes:
num_heads (int): Number of attention heads.
scale (float): Scaling factor for attention scores.
key_dim (int): Dimensionality of the keys and queries.
nh_kd (int): Product of num_heads and key_dim.
d (int): Dimensionality of the value vectors.
dh (int): Product of d and num_heads.
attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
norm (nn.LayerNorm): Layer normalization applied to input.
qkv (nn.Linear): Linear layer for computing query, key, and value projections.
proj (nn.Linear): Linear layer for final projection.
attention_biases (nn.Parameter): Learnable attention biases.
attention_bias_idxs (torch.Tensor): Indices for attention biases.
ab (torch.Tensor): Cached attention biases for inference, deleted during training.
Examples:
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
>>> x = torch.randn(1, 196, 256)
>>> output = attn(x)
>>> print(output.shape)
torch.Size([1, 196, 256])
"""
def __init__(
self,
dim: int,
key_dim: int,
num_heads: int = 8,
attn_ratio: float = 4,
resolution: tuple[int, int] = (14, 14),
):
"""
Initialize the Attention module for multi-head attention with spatial awareness.
This module implements a multi-head attention mechanism with support for spatial awareness, applying
attention biases based on spatial resolution. It includes trainable attention biases for each unique
offset between spatial positions in the resolution grid.
Args:
dim (int): The dimensionality of the input and output.
key_dim (int): The dimensionality of the keys and queries.
num_heads (int, optional): Number of attention heads.
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.
resolution (tuple[int, int], optional): Spatial resolution of the input feature map.
"""
super().__init__()
assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument not tuple of length 2"
self.num_heads = num_heads
self.scale = key_dim**-0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
h = self.dh + nh_kd * 2
self.norm = nn.LayerNorm(dim)
self.qkv = nn.Linear(dim, h)
self.proj = nn.Linear(self.dh, dim)
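# Build relative-position attention biases: each unique absolute (dy, dx) offset between grid points shares one learnable bias per head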
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
@torch.no_grad()
def train(self, mode: bool = True):
"""Set the module in training mode and handle the 'ab' attribute for cached attention biases."""
super().train(mode)
if mode and hasattr(self, "ab"):
del self.ab
else:
self.ab = self.attention_biases[:, self.attention_bias_idxs]
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply multi-head attention with spatial awareness and trainable attention biases."""
B, N, _ = x.shape # B, N, C
# Normalization
x = self.norm(x)
qkv = self.qkv(x)
# (B, N, num_heads, d)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
# (B, num_heads, N, d)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
self.ab = self.ab.to(self.attention_biases.device)
attn = (q @ k.transpose(-2, -1)) * self.scale + (
self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
return self.proj(x)
class TinyViTBlock(nn.Module):
"""
TinyViT Block that applies self-attention and a local convolution to the input.
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
local convolutions to process input features efficiently. It supports windowed attention for
computational efficiency and includes residual connections.
Attributes:
dim (int): The dimensionality of the input and output.
input_resolution (tuple[int, int]): Spatial resolution of the input feature map.
num_heads (int): Number of attention heads.
window_size (int): Size of the attention window.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
drop_path (nn.Module): Stochastic depth layer, identity function during inference.
attn (Attention): Self-attention module.
mlp (MLP): Multi-layer perceptron module.
local_conv (Conv2d_BN): Depth-wise local convolution layer.
Examples:
>>> input_tensor = torch.randn(1, 196, 192)
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
>>> output = block(input_tensor)
>>> print(output.shape)
torch.Size([1, 196, 192])
"""
def __init__(
self,
dim: int,
input_resolution: tuple[int, int],
num_heads: int,
window_size: int = 7,
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_path: float = 0.0,
local_conv_size: int = 3,
activation=nn.GELU,
):
"""
Initialize a TinyViT block with self-attention and local convolution.
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
local convolutions to process input features efficiently.
Args:
dim (int): Dimensionality of the input and output features.
input_resolution (tuple[int, int]): Spatial resolution of the input feature map (height, width).
num_heads (int): Number of attention heads.
window_size (int, optional): Size of the attention window. Must be greater than 0.
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
drop (float, optional): Dropout rate.
drop_path (float, optional): Stochastic depth rate.
local_conv_size (int, optional): Kernel size of the local convolution.
activation (nn.Module): Activation function for MLP.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
assert window_size > 0, "window_size must be greater than 0"
self.window_size = window_size
self.mlp_ratio = mlp_ratio
# NOTE: `DropPath` is needed only for training.
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.drop_path = nn.Identity()
assert dim % num_heads == 0, "dim must be divisible by num_heads"
head_dim = dim // num_heads
window_resolution = (window_size, window_size)
self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
mlp_hidden_dim = int(dim * mlp_ratio)
mlp_activation = activation
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)
pad = local_conv_size // 2
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply self-attention, local convolution, and MLP operations to the input tensor."""
h, w = self.input_resolution
b, hw, c = x.shape # batch, height*width, channels
assert hw == h * w, "input feature has wrong size"
res_x = x
if h == self.window_size and w == self.window_size:
x = self.attn(x)
else:
x = x.view(b, h, w, c)
pad_b = (self.window_size - h % self.window_size) % self.window_size
pad_r = (self.window_size - w % self.window_size) % self.window_size
padding = pad_b > 0 or pad_r > 0
if padding:
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
pH, pW = h + pad_b, w + pad_r
nH = pH // self.window_size
nW = pW // self.window_size
# Window partition
x = (
x.view(b, nH, self.window_size, nW, self.window_size, c)
.transpose(2, 3)
.reshape(b * nH * nW, self.window_size * self.window_size, c)
)
x = self.attn(x)
# Window reverse
x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
if padding:
x = x[:, :h, :w].contiguous()
x = x.view(b, hw, c)
x = res_x + self.drop_path(x)
x = x.transpose(1, 2).reshape(b, c, h, w)
x = self.local_conv(x)
x = x.view(b, c, hw).transpose(1, 2)
return x + self.drop_path(self.mlp(x))
def extra_repr(self) -> str:
"""
Return a string representation of the TinyViTBlock's parameters.
This method provides a formatted string containing key information about the TinyViTBlock, including its
dimension, input resolution, number of attention heads, window size, and MLP ratio.
Returns:
(str): A formatted string containing the block's parameters.
Examples:
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
>>> print(block.extra_repr())
dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
"""
return (
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
)
class BasicLayer(nn.Module):
"""
A basic TinyViT layer for one stage in a TinyViT architecture.
This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
and an optional downsampling operation. It processes features at a specific resolution and
dimensionality within the overall architecture.
Attributes:
dim (int): The dimensionality of the input and output features.
input_resolution (tuple[int, int]): Spatial resolution of the input feature map.
depth (int): Number of TinyViT blocks in this layer.
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
Examples:
>>> input_tensor = torch.randn(1, 3136, 192)
>>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7)
>>> output = layer(input_tensor)
>>> print(output.shape)
torch.Size([1, 3136, 192])
"""
def __init__(
self,
dim: int,
input_resolution: tuple[int, int],
depth: int,
num_heads: int,
window_size: int,
mlp_ratio: float = 4.0,
drop: float = 0.0,
drop_path: float | list[float] = 0.0,
downsample: nn.Module | None = None,
use_checkpoint: bool = False,
local_conv_size: int = 3,
activation=nn.GELU,
out_dim: int | None = None,
):
"""
Initialize a BasicLayer in the TinyViT architecture.
This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
process feature maps at a specific resolution and dimensionality within the TinyViT model.
Args:
dim (int): Dimensionality of the input and output features.
input_resolution (tuple[int, int]): Spatial resolution of the input feature map (height, width).
depth (int): Number of TinyViT blocks in this layer.
num_heads (int): Number of attention heads in each TinyViT block.
window_size (int): Size of the local window for attention computation.
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
drop (float, optional): Dropout rate.
drop_path (float | list[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.
downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.
use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.
activation (nn.Module): Activation function used in the MLP.
out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as `dim`.
"""
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# Build blocks
self.blocks = nn.ModuleList(
[
TinyViTBlock(
dim=dim,
input_resolution=input_resolution,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
drop=drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
local_conv_size=local_conv_size,
activation=activation,
)
for i in range(depth)
]
)
# Patch merging layer
self.downsample = (
None
if downsample is None
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input through TinyViT blocks and optional downsampling."""
for blk in self.blocks:
x = torch.utils.checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # warn: checkpoint is slow import
return x if self.downsample is None else self.downsample(x)
def extra_repr(self) -> str:
"""Return a string with the layer's parameters for printing."""
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
class TinyViT(nn.Module):
"""
TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
This class implements the TinyViT model, which combines elements of vision transformers and convolutional
neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing
with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.
Attributes:
img_size (int): Input image size.
num_classes (int): Number of classification classes.
depths (tuple[int, int, int, int]): Number of blocks in each stage.
num_layers (int): Total number of layers in the network.
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
patch_embed (PatchEmbed): Module for patch embedding.
patches_resolution (tuple[int, int]): Resolution of embedded patches.
layers (nn.ModuleList): List of network layers.
norm_head (nn.LayerNorm): Layer normalization for the classifier head.
head (nn.Linear): Linear layer for final classification.
neck (nn.Sequential): Neck module for feature refinement.
Examples:
>>> model = TinyViT(img_size=224, num_classes=1000)
>>> x = torch.randn(1, 3, 224, 224)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 256, 56, 56])
"""
def __init__(
self,
img_size: int = 224,
in_chans: int = 3,
num_classes: int = 1000,
embed_dims: tuple[int, int, int, int] = (96, 192, 384, 768),
depths: tuple[int, int, int, int] = (2, 2, 6, 2),
num_heads: tuple[int, int, int, int] = (3, 6, 12, 24),
window_sizes: tuple[int, int, int, int] = (7, 7, 14, 7),
mlp_ratio: float = 4.0,
drop_rate: float = 0.0,
drop_path_rate: float = 0.1,
use_checkpoint: bool = False,
mbconv_expand_ratio: float = 4.0,
local_conv_size: int = 3,
layer_lr_decay: float = 1.0,
):
"""
Initialize the TinyViT model.
This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
attention and convolution blocks, and a classification head.
Args:
img_size (int, optional): Size of the input image.
in_chans (int, optional): Number of input channels.
num_classes (int, optional): Number of classes for classification.
embed_dims (tuple[int, int, int, int], optional): Embedding dimensions for each stage.
depths (tuple[int, int, int, int], optional): Number of blocks in each stage.
num_heads (tuple[int, int, int, int], optional): Number of attention heads in each stage.
window_sizes (tuple[int, int, int, int], optional): Window sizes for each stage.
mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.
drop_rate (float, optional): Dropout rate.
drop_path_rate (float, optional): Stochastic depth rate.
use_checkpoint (bool, optional): Whether to use checkpointing to save memory.
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.
local_conv_size (int, optional): Kernel size for local convolutions.
layer_lr_decay (float, optional): Layer-wise learning rate decay factor.
"""
super().__init__()
self.img_size = img_size
self.num_classes = num_classes
self.depths = depths
self.num_layers = len(depths)
self.mlp_ratio = mlp_ratio
activation = nn.GELU
self.patch_embed = PatchEmbed(
in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
)
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# Stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# Build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
kwargs = dict(
dim=embed_dims[i_layer],
input_resolution=(
patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
),
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
# patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
activation=activation,
)
if i_layer == 0:
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
else:
layer = BasicLayer(
num_heads=num_heads[i_layer],
window_size=window_sizes[i_layer],
mlp_ratio=self.mlp_ratio,
drop=drop_rate,
local_conv_size=local_conv_size,
**kwargs,
)
self.layers.append(layer)
# Classifier head
self.norm_head = nn.LayerNorm(embed_dims[-1])
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
# Init weights
self.apply(self._init_weights)
self.set_layer_lr_decay(layer_lr_decay)
self.neck = nn.Sequential(
nn.Conv2d(
embed_dims[-1],
256,
kernel_size=1,
bias=False,
),
LayerNorm2d(256),
nn.Conv2d(
256,
256,
kernel_size=3,
padding=1,
bias=False,
),
LayerNorm2d(256),
)
def set_layer_lr_decay(self, layer_lr_decay: float):
"""Set layer-wise learning rate decay for the TinyViT model based on depth."""
decay_rate = layer_lr_decay
# Layers -> blocks (depth)
depth = sum(self.depths)
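# Per-block LR scales decay geometrically with block index; with decay_rate < 1, earlier blocks receive smaller learning rates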
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
def _set_lr_scale(m, scale):
"""Set the learning rate scale for each layer in the model based on the layer's depth."""
for p in m.parameters():
p.lr_scale = scale
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
i = 0
for layer in self.layers:
for block in layer.blocks:
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
i += 1
if layer.downsample is not None:
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
assert i == depth
for m in {self.norm_head, self.head}:
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
for k, p in self.named_parameters():
p.param_name = k
def _check_lr_scale(m):
"""Check if the learning rate scale attribute is present in module's parameters."""
for p in m.parameters():
assert hasattr(p, "lr_scale"), p.param_name
self.apply(_check_lr_scale)
@staticmethod
def _init_weights(m):
"""Initialize weights for linear and normalization layers in the TinyViT model."""
if isinstance(m, nn.Linear):
# NOTE: This initialization is needed only for training.
# trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay_keywords(self):
"""Return a set of keywords for parameters that should not use weight decay."""
return {"attention_biases"}
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
"""Process input through feature extraction layers, returning spatial features."""
x = self.patch_embed(x) # x input is (N, C, H, W)
x = self.layers[0](x)
start_i = 1
for i in range(start_i, len(self.layers)):
layer = self.layers[i]
x = layer(x)
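# Reshape the final token sequence back to a (B, H, W, C) map at 1/4 of the patch-grid resolution before the neck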
batch, _, channel = x.shape
x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)
x = x.permute(0, 3, 1, 2)
return self.neck(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass through the TinyViT model, extracting features from the input image."""
return self.forward_features(x)
def set_imgsz(self, imgsz: list[int] = [1024, 1024]):
"""Set image size to make model compatible with different image sizes."""
imgsz = [s // 4 for s in imgsz]
self.patches_resolution = imgsz
for i, layer in enumerate(self.layers):
input_resolution = (
imgsz[0] // (2 ** (i - 1 if i == 3 else i)),
imgsz[1] // (2 ** (i - 1 if i == 3 else i)),
)
layer.input_resolution = input_resolution
if layer.downsample is not None:
layer.downsample.input_resolution = input_resolution
if isinstance(layer, BasicLayer):
for b in layer.blocks:
b.input_resolution = input_resolution

View File

@@ -0,0 +1,354 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import math
import torch
from torch import Tensor, nn
from ultralytics.nn.modules import MLPBlock
class TwoWayTransformer(nn.Module):
"""
A Two-Way Transformer module for simultaneous attention to image and query points.
This class implements a specialized transformer decoder that attends to an input image using queries with
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
cloud processing.
Attributes:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention.
mlp_dim (int): Internal channel dimension for the MLP block.
layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
final_attn_token_to_image (Attention): Final attention layer from queries to image.
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
Methods:
forward: Process image and point embeddings through the transformer.
Examples:
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
>>> image_embedding = torch.randn(1, 256, 32, 32)
>>> image_pe = torch.randn(1, 256, 32, 32)
>>> point_embedding = torch.randn(1, 100, 256)
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
>>> print(output_queries.shape, output_image.shape)
"""
def __init__(
self,
depth: int,
embedding_dim: int,
num_heads: int,
mlp_dim: int,
activation: type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
) -> None:
"""
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
Args:
depth (int): Number of layers in the transformer.
embedding_dim (int): Channel dimension for input embeddings.
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
mlp_dim (int): Internal channel dimension for the MLP block.
activation (Type[nn.Module], optional): Activation function to use in the MLP block.
attention_downsample_rate (int, optional): Downsampling rate for attention mechanism.
"""
super().__init__()
self.depth = depth
self.embedding_dim = embedding_dim
self.num_heads = num_heads
self.mlp_dim = mlp_dim
self.layers = nn.ModuleList()
for i in range(depth):
self.layers.append(
TwoWayAttentionBlock(
embedding_dim=embedding_dim,
num_heads=num_heads,
mlp_dim=mlp_dim,
activation=activation,
attention_downsample_rate=attention_downsample_rate,
skip_first_layer_pe=(i == 0),
)
)
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm_final_attn = nn.LayerNorm(embedding_dim)
def forward(
self,
image_embedding: torch.Tensor,
image_pe: torch.Tensor,
point_embedding: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Process image and point embeddings through the Two-Way Transformer.
Args:
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
Returns:
queries (torch.Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
keys (torch.Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
"""
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
image_pe = image_pe.flatten(2).permute(0, 2, 1)
# Prepare queries
queries = point_embedding
keys = image_embedding
# Apply transformer blocks and final layernorm
for layer in self.layers:
queries, keys = layer(
queries=queries,
keys=keys,
query_pe=point_embedding,
key_pe=image_pe,
)
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
return queries, keys
class TwoWayAttentionBlock(nn.Module):
"""
A two-way attention block for simultaneous attention to image and query points.
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
inputs to sparse inputs.
Attributes:
self_attn (Attention): Self-attention layer for queries.
norm1 (nn.LayerNorm): Layer normalization after self-attention.
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
mlp (MLPBlock): MLP block for transforming query embeddings.
norm3 (nn.LayerNorm): Layer normalization after MLP block.
norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
Methods:
forward: Apply self-attention and cross-attention to queries and keys.
Examples:
>>> embedding_dim, num_heads = 256, 8
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
>>> queries = torch.randn(1, 100, embedding_dim)
>>> keys = torch.randn(1, 1000, embedding_dim)
>>> query_pe = torch.randn(1, 100, embedding_dim)
>>> key_pe = torch.randn(1, 1000, embedding_dim)
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
mlp_dim: int = 2048,
activation: type[nn.Module] = nn.ReLU,
attention_downsample_rate: int = 2,
skip_first_layer_pe: bool = False,
) -> None:
"""
Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
This block implements a specialized transformer layer with four main components: self-attention on sparse
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
of dense inputs to sparse inputs.
Args:
embedding_dim (int): Channel dimension of the embeddings.
num_heads (int): Number of attention heads in the attention layers.
mlp_dim (int, optional): Hidden dimension of the MLP block.
activation (Type[nn.Module], optional): Activation function for the MLP block.
attention_downsample_rate (int, optional): Downsampling rate for the attention mechanism.
skip_first_layer_pe (bool, optional): Whether to skip positional encoding in the first layer.
"""
super().__init__()
self.self_attn = Attention(embedding_dim, num_heads)
self.norm1 = nn.LayerNorm(embedding_dim)
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.norm2 = nn.LayerNorm(embedding_dim)
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
self.norm3 = nn.LayerNorm(embedding_dim)
self.norm4 = nn.LayerNorm(embedding_dim)
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
self.skip_first_layer_pe = skip_first_layer_pe
def forward(
self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply two-way attention to process query and key embeddings in a transformer block.
Args:
queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
keys (torch.Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
query_pe (torch.Tensor): Positional encodings for queries with same shape as queries.
key_pe (torch.Tensor): Positional encodings for keys with same shape as keys.
Returns:
queries (torch.Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
keys (torch.Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
"""
# Self attention block
if self.skip_first_layer_pe:
queries = self.self_attn(q=queries, k=queries, v=queries)
else:
q = queries + query_pe
attn_out = self.self_attn(q=q, k=q, v=queries)
queries = queries + attn_out
queries = self.norm1(queries)
# Cross attention block, tokens attending to image embedding
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm2(queries)
# MLP block
mlp_out = self.mlp(queries)
queries = queries + mlp_out
queries = self.norm3(queries)
# Cross attention block, image embedding attending to tokens
q = queries + query_pe
k = keys + key_pe
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
keys = keys + attn_out
keys = self.norm4(keys)
return queries, keys
class Attention(nn.Module):
"""
An attention layer with downscaling capability for embedding size after projection.
This class implements a multi-head attention mechanism with the option to downsample the internal
dimension of queries, keys, and values.
Attributes:
embedding_dim (int): Dimensionality of input embeddings.
kv_in_dim (int): Dimensionality of key and value inputs.
internal_dim (int): Internal dimension after downsampling.
num_heads (int): Number of attention heads.
q_proj (nn.Linear): Linear projection for queries.
k_proj (nn.Linear): Linear projection for keys.
v_proj (nn.Linear): Linear projection for values.
out_proj (nn.Linear): Linear projection for output.
Methods:
_separate_heads: Separate input tensor into attention heads.
_recombine_heads: Recombine separated attention heads.
forward: Compute attention output for given query, key, and value tensors.
Examples:
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
>>> q = torch.randn(1, 100, 256)
>>> k = v = torch.randn(1, 50, 256)
>>> output = attn(q, k, v)
>>> print(output.shape)
torch.Size([1, 100, 256])
"""
def __init__(
self,
embedding_dim: int,
num_heads: int,
downsample_rate: int = 1,
kv_in_dim: int = None,
) -> None:
"""
Initialize the Attention module with specified dimensions and settings.
Args:
embedding_dim (int): Dimensionality of input embeddings.
num_heads (int): Number of attention heads.
downsample_rate (int, optional): Factor by which internal dimensions are downsampled.
kv_in_dim (int | None, optional): Dimensionality of key and value inputs. If None, uses embedding_dim.
Raises:
AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
"""
super().__init__()
self.embedding_dim = embedding_dim
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads must divide internal_dim (embedding_dim // downsample_rate)."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
@staticmethod
def _separate_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
"""Separate the input tensor into the specified number of attention heads."""
b, n, c = x.shape
x = x.reshape(b, n, num_heads, c // num_heads)
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
@staticmethod
    def _recombine_heads(x: torch.Tensor) -> torch.Tensor:
"""Recombine separated attention heads into a single tensor."""
b, n_heads, n_tokens, c_per_head = x.shape
x = x.transpose(1, 2)
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
"""
Apply multi-head attention to query, key, and value tensors with optional downsampling.
Args:
q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).
k (torch.Tensor): Key tensor with shape (B, N_k, embedding_dim).
v (torch.Tensor): Value tensor with shape (B, N_k, embedding_dim).
Returns:
(torch.Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
"""
# Input projections
q = self.q_proj(q)
k = self.k_proj(k)
v = self.v_proj(v)
# Separate into heads
q = self._separate_heads(q, self.num_heads)
k = self._separate_heads(k, self.num_heads)
v = self._separate_heads(v, self.num_heads)
# Attention
_, _, _, c_per_head = q.shape
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
attn = attn / math.sqrt(c_per_head)
attn = torch.softmax(attn, dim=-1)
# Get output
out = attn @ v
out = self._recombine_heads(out)
return self.out_proj(out)
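
# --- Hypothetical smoke test (illustration only, not part of the library) ---
# Wires the two classes above together the way the surrounding two-way transformer does:
# sparse prompt tokens attend to a flattened 64x64 image embedding. Shapes follow the
# class docstrings; all tensors are random.
if __name__ == "__main__":
    dim, heads = 256, 8
    block = TwoWayAttentionBlock(embedding_dim=dim, num_heads=heads, skip_first_layer_pe=True)
    queries = torch.randn(1, 5, dim)  # prompt/output tokens
    keys = torch.randn(1, 64 * 64, dim)  # flattened image embedding
    q_out, k_out = block(queries, keys, torch.randn_like(queries), torch.randn_like(keys))
    assert q_out.shape == queries.shape and k_out.shape == keys.shape
    attn = Attention(embedding_dim=dim, num_heads=heads, downsample_rate=2)
    assert attn(q_out, k_out, k_out).shape == q_out.shape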

View File

@@ -0,0 +1,388 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from typing import Any
import torch
import torch.nn.functional as F
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
"""
Select the closest conditioning frames to a given frame index.
Args:
frame_idx (int): Current frame index.
cond_frame_outputs (dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
max_cond_frame_num (int): Maximum number of conditioning frames to select.
Returns:
selected_outputs (dict[int, Any]): Selected items from cond_frame_outputs.
unselected_outputs (dict[int, Any]): Items not selected from cond_frame_outputs.
Examples:
>>> frame_idx = 5
>>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
>>> max_cond_frame_num = 2
>>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
>>> print(selected)
{3: 'b', 7: 'c'}
>>> print(unselected)
{1: 'a', 9: 'd'}
"""
if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
selected_outputs = cond_frame_outputs
unselected_outputs = {}
else:
assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
selected_outputs = {}
# The closest conditioning frame before `frame_idx` (if any)
idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
if idx_before is not None:
selected_outputs[idx_before] = cond_frame_outputs[idx_before]
# The closest conditioning frame after `frame_idx` (if any)
idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
if idx_after is not None:
selected_outputs[idx_after] = cond_frame_outputs[idx_after]
# Add other temporally closest conditioning frames until reaching a total
# of `max_cond_frame_num` conditioning frames.
num_remain = max_cond_frame_num - len(selected_outputs)
inds_remain = sorted(
(t for t in cond_frame_outputs if t not in selected_outputs),
key=lambda x: abs(x - frame_idx),
)[:num_remain]
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
return selected_outputs, unselected_outputs
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
"""
Generate 1D sinusoidal positional embeddings for given positions and dimensions.
Args:
pos_inds (torch.Tensor): Position indices for which to generate embeddings.
dim (int): Dimension of the positional embeddings. Should be an even number.
temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.
Returns:
(torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).
Examples:
>>> pos = torch.tensor([0, 1, 2, 3])
>>> embeddings = get_1d_sine_pe(pos, 128)
>>> embeddings.shape
torch.Size([4, 128])
"""
pe_dim = dim // 2
dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)
dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
pos_embed = pos_inds.unsqueeze(-1) / dim_t
pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
return pos_embed
def init_t_xy(end_x: int, end_y: int):
"""
Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
and corresponding x and y coordinate tensors.
Args:
end_x (int): Width of the grid (number of columns).
end_y (int): Height of the grid (number of rows).
Returns:
t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).
Examples:
>>> t_x, t_y = init_t_xy(3, 2)
>>> print(t_x)
tensor([0., 1., 2., 0., 1., 2.])
>>> print(t_y)
tensor([0., 0., 0., 1., 1., 1.])
"""
t = torch.arange(end_x * end_y, dtype=torch.float32)
t_x = (t % end_x).float()
t_y = torch.div(t, end_x, rounding_mode="floor").float()
return t_x, t_y
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
"""
Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
This function generates complex exponential positional encodings for a 2D grid of spatial positions,
using separate frequency components for the x and y dimensions.
Args:
dim (int): Dimension of the positional encoding.
end_x (int): Width of the 2D grid.
end_y (int): Height of the 2D grid.
theta (float, optional): Scaling factor for frequency computation.
Returns:
(torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
Examples:
>>> dim, end_x, end_y = 128, 8, 8
>>> freqs_cis = compute_axial_cis(dim, end_x, end_y)
>>> freqs_cis.shape
torch.Size([64, 64])
"""
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
t_x, t_y = init_t_xy(end_x, end_y)
freqs_x = torch.outer(t_x, freqs_x)
freqs_y = torch.outer(t_y, freqs_y)
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
"""
Reshape frequency tensor for broadcasting with input tensor.
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
This function is typically used in positional encoding operations.
Args:
freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
x (torch.Tensor): Input tensor to broadcast with.
Returns:
(torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.
Raises:
AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
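
def _rope_broadcast_sketch():
    """Hypothetical illustration (not part of this module): shape contract of reshape_for_broadcast.

    A (B, H, L, D/2) complex query view pairs with an (L, D/2) frequency table, which is
    reshaped to (1, 1, L, D/2) so it broadcasts across the batch and head dimensions.
    """
    xq = torch.randn(2, 8, 16, 64)  # [batch, heads, seq_len, dim]
    xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))  # (2, 8, 16, 32) complex
    freqs_cis = compute_axial_cis(dim=64, end_x=4, end_y=4)  # (16, 32) complex
    assert reshape_for_broadcast(freqs_cis, xq_).shape == (1, 1, 16, 32)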
def apply_rotary_enc(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
repeat_freqs_k: bool = False,
):
"""
Apply rotary positional encoding to query and key tensors.
This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
components. RoPE is a technique that injects relative position information into self-attention mechanisms.
Args:
xq (torch.Tensor): Query tensor to encode with positional information.
xk (torch.Tensor): Key tensor to encode with positional information.
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
last two dimensions of xq.
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
to match key sequence length.
Returns:
xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.
Examples:
>>> import torch
>>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]
>>> xk = torch.randn(2, 8, 16, 64)
>>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64
>>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
"""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
if xk_ is None:
# No keys to rotate, due to dropout
return xq_out.type_as(xq).to(xq.device), xk
# Repeat freqs along seq_len dim to match k seq_len
if repeat_freqs_k:
r = xk_.shape[-2] // xq_.shape[-2]
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
def window_partition(x: torch.Tensor, window_size: int):
"""
Partition input tensor into non-overlapping windows with padding if needed.
Args:
x (torch.Tensor): Input tensor with shape (B, H, W, C).
window_size (int): Size of each window.
Returns:
windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
padded_h_w (tuple[int, int]): Padded height and width before partition.
Examples:
>>> x = torch.randn(1, 16, 16, 3)
>>> windows, (Hp, Wp) = window_partition(x, window_size=4)
>>> print(windows.shape, Hp, Wp)
torch.Size([16, 4, 4, 3]) 16 16
"""
B, H, W, C = x.shape
pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
if pad_h > 0 or pad_w > 0:
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows, (Hp, Wp)
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
"""
Unpartition windowed sequences into original sequences and remove padding.
This function reverses the windowing process, reconstructing the original input from windowed segments
and removing any padding that was added during the windowing process.
Args:
windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
the size of each window, and C is the number of channels.
window_size (int): Size of each window.
pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
Returns:
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
are the original height and width, and C is the number of channels.
Examples:
>>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
>>> pad_hw = (16, 16) # Padded height and width
>>> hw = (15, 14) # Original height and width
>>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
>>> print(x.shape)
torch.Size([1, 15, 14, 64])
"""
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
x = x[:, :H, :W, :].contiguous()
return x
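
def _windowing_roundtrip_sketch():
    """Hypothetical illustration (not part of this module): window_partition/unpartition round-trip.

    Padding added before partitioning is zero-filled and cropped away again, so the recovered
    tensor matches the input exactly.
    """
    x = torch.randn(1, 15, 14, 64)
    windows, pad_hw = window_partition(x, window_size=8)  # (4, 8, 8, 64), pad_hw == (16, 16)
    x_rec = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=(15, 14))
    assert torch.equal(x, x_rec)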
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Extract relative positional embeddings based on query and key sizes.
Args:
q_size (int): Size of the query.
k_size (int): Size of the key.
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
distance and C is the embedding dimension.
Returns:
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
k_size, C).
Examples:
>>> q_size, k_size = 8, 16
>>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
>>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
>>> print(extracted_pos.shape)
torch.Size([8, 16, 64])
"""
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
size=max_rel_dist,
mode="linear",
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos
# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
return rel_pos_resized[relative_coords.long()]
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: tuple[int, int],
k_size: tuple[int, int],
) -> torch.Tensor:
"""
Add decomposed Relative Positional Embeddings to the attention map.
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
positions.
Args:
attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
q_size (tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
Returns:
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
(B, q_h * q_w, k_h * k_w).
Examples:
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
>>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
>>> q = torch.rand(B, q_h * q_w, C)
>>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
>>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
>>> q_size, k_size = (q_h, q_w), (k_h, k_w)
>>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
>>> print(updated_attn.shape)
torch.Size([1, 64, 64])
References:
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
"""
q_h, q_w = q_size
k_h, k_w = k_size
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
r_q = q.reshape(B, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
B, q_h * q_w, k_h * k_w
)
return attn

File diff suppressed because it is too large

View File

@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

View File

@@ -0,0 +1,478 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou
from .ops import HungarianMatcher
class DETRLoss(nn.Module):
"""
DETR (DEtection TRansformer) Loss class for calculating various loss components.
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
DETR object detection model.
Attributes:
nc (int): Number of classes.
loss_gain (dict[str, float]): Coefficients for different loss components.
aux_loss (bool): Whether to compute auxiliary losses.
use_fl (bool): Whether to use FocalLoss.
use_vfl (bool): Whether to use VarifocalLoss.
use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
matcher (HungarianMatcher): Object to compute matching cost and indices.
fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
device (torch.device): Device on which tensors are stored.
"""
def __init__(
self,
nc: int = 80,
loss_gain: dict[str, float] | None = None,
aux_loss: bool = True,
use_fl: bool = True,
use_vfl: bool = False,
use_uni_match: bool = False,
uni_match_ind: int = 0,
gamma: float = 1.5,
alpha: float = 0.25,
):
"""
Initialize DETR loss function with customizable components and gains.
Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
losses and various loss types.
Args:
nc (int): Number of classes.
loss_gain (dict[str, float], optional): Coefficients for different loss components.
aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
use_fl (bool): Whether to use FocalLoss.
use_vfl (bool): Whether to use VarifocalLoss.
use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
uni_match_ind (int): Index of fixed layer for uni_match.
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
alpha (float): The balancing factor used to address class imbalance.
"""
super().__init__()
if loss_gain is None:
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
self.nc = nc
self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
self.loss_gain = loss_gain
self.aux_loss = aux_loss
self.fl = FocalLoss(gamma, alpha) if use_fl else None
self.vfl = VarifocalLoss(gamma, alpha) if use_vfl else None
self.use_uni_match = use_uni_match
self.uni_match_ind = uni_match_ind
self.device = None
def _get_loss_class(
self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
) -> dict[str, torch.Tensor]:
"""
Compute classification loss based on predictions, target values, and ground truth scores.
Args:
pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
targets (torch.Tensor): Target class indices with shape (B, N).
gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
num_gts (int): Number of ground truth objects.
postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
Returns:
(dict[str, torch.Tensor]): Dictionary containing classification loss value.
Notes:
The function supports different classification loss types:
            - Varifocal Loss (if self.vfl is set and num_gts > 0)
            - Focal Loss (if self.fl is set)
            - BCE Loss (default fallback)
"""
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
name_class = f"loss_class{postfix}"
bs, nq = pred_scores.shape[:2]
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
one_hot = one_hot[..., :-1]
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
if self.fl:
if num_gts and self.vfl:
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
else:
loss_cls = self.fl(pred_scores, one_hot.float())
loss_cls /= max(num_gts, 1) / nq
else:
loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
def _get_loss_bbox(
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
) -> dict[str, torch.Tensor]:
"""
Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.
Returns:
(dict[str, torch.Tensor]): Dictionary containing:
- loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
- loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.
Notes:
If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
"""
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
name_bbox = f"loss_bbox{postfix}"
name_giou = f"loss_giou{postfix}"
loss = {}
if len(gt_bboxes) == 0:
loss[name_bbox] = torch.tensor(0.0, device=self.device)
loss[name_giou] = torch.tensor(0.0, device=self.device)
return loss
loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
return {k: v.squeeze() for k, v in loss.items()}
# This function is for future RT-DETR Segment models
# def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
# # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
# name_mask = f'loss_mask{postfix}'
# name_dice = f'loss_dice{postfix}'
#
# loss = {}
# if sum(len(a) for a in gt_mask) == 0:
# loss[name_mask] = torch.tensor(0., device=self.device)
# loss[name_dice] = torch.tensor(0., device=self.device)
# return loss
#
# num_gts = len(gt_mask)
# src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
# src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
# # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
# loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
# torch.tensor([num_gts], dtype=torch.float32))
# loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
# return loss
# This function is for future RT-DETR Segment models
# @staticmethod
# def _dice_loss(inputs, targets, num_gts):
# inputs = F.sigmoid(inputs).flatten(1)
# targets = targets.flatten(1)
# numerator = 2 * (inputs * targets).sum(1)
# denominator = inputs.sum(-1) + targets.sum(-1)
# loss = 1 - (numerator + 1) / (denominator + 1)
# return loss.sum() / num_gts
def _get_loss_aux(
self,
pred_bboxes: torch.Tensor,
pred_scores: torch.Tensor,
gt_bboxes: torch.Tensor,
gt_cls: torch.Tensor,
gt_groups: list[int],
match_indices: list[tuple] | None = None,
postfix: str = "",
masks: torch.Tensor | None = None,
gt_mask: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""
Get auxiliary losses for intermediate decoder layers.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
gt_cls (torch.Tensor): Ground truth classes.
gt_groups (list[int]): Number of ground truths per image.
match_indices (list[tuple], optional): Pre-computed matching indices.
postfix (str, optional): String to append to loss names.
masks (torch.Tensor, optional): Predicted masks if using segmentation.
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
Returns:
(dict[str, torch.Tensor]): Dictionary of auxiliary losses.
"""
# NOTE: loss class, bbox, giou, mask, dice
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
if match_indices is None and self.use_uni_match:
match_indices = self.matcher(
pred_bboxes[self.uni_match_ind],
pred_scores[self.uni_match_ind],
gt_bboxes,
gt_cls,
gt_groups,
masks=masks[self.uni_match_ind] if masks is not None else None,
gt_mask=gt_mask,
)
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
aux_masks = masks[i] if masks is not None else None
loss_ = self._get_loss(
aux_bboxes,
aux_scores,
gt_bboxes,
gt_cls,
gt_groups,
masks=aux_masks,
gt_mask=gt_mask,
postfix=postfix,
match_indices=match_indices,
)
loss[0] += loss_[f"loss_class{postfix}"]
loss[1] += loss_[f"loss_bbox{postfix}"]
loss[2] += loss_[f"loss_giou{postfix}"]
# if masks is not None and gt_mask is not None:
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
# loss[3] += loss_[f'loss_mask{postfix}']
# loss[4] += loss_[f'loss_dice{postfix}']
loss = {
f"loss_class_aux{postfix}": loss[0],
f"loss_bbox_aux{postfix}": loss[1],
f"loss_giou_aux{postfix}": loss[2],
}
# if masks is not None and gt_mask is not None:
# loss[f'loss_mask_aux{postfix}'] = loss[3]
# loss[f'loss_dice_aux{postfix}'] = loss[4]
return loss
@staticmethod
def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
"""
Extract batch indices, source indices, and destination indices from match indices.
Args:
match_indices (list[tuple]): List of tuples containing matched indices.
Returns:
batch_idx (tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
dst_idx (torch.Tensor): Destination indices.
"""
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
src_idx = torch.cat([src for (src, _) in match_indices])
dst_idx = torch.cat([dst for (_, dst) in match_indices])
return (batch_idx, src_idx), dst_idx
def _get_assigned_bboxes(
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes.
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
match_indices (list[tuple]): List of tuples containing matched indices.
Returns:
pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
"""
pred_assigned = torch.cat(
[
t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (i, _) in zip(pred_bboxes, match_indices)
]
)
gt_assigned = torch.cat(
[
t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
for t, (_, j) in zip(gt_bboxes, match_indices)
]
)
return pred_assigned, gt_assigned
def _get_loss(
self,
pred_bboxes: torch.Tensor,
pred_scores: torch.Tensor,
gt_bboxes: torch.Tensor,
gt_cls: torch.Tensor,
gt_groups: list[int],
masks: torch.Tensor | None = None,
gt_mask: torch.Tensor | None = None,
postfix: str = "",
match_indices: list[tuple] | None = None,
) -> dict[str, torch.Tensor]:
"""
Calculate losses for a single prediction layer.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes.
pred_scores (torch.Tensor): Predicted class scores.
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
gt_cls (torch.Tensor): Ground truth classes.
gt_groups (list[int]): Number of ground truths per image.
masks (torch.Tensor, optional): Predicted masks if using segmentation.
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
postfix (str, optional): String to append to loss names.
match_indices (list[tuple], optional): Pre-computed matching indices.
Returns:
(dict[str, torch.Tensor]): Dictionary of losses.
"""
if match_indices is None:
match_indices = self.matcher(
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
)
idx, gt_idx = self._get_index(match_indices)
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
bs, nq = pred_scores.shape[:2]
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
targets[idx] = gt_cls[gt_idx]
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
if len(gt_bboxes):
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
return {
**self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
**self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
# **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
}
def forward(
self,
pred_bboxes: torch.Tensor,
pred_scores: torch.Tensor,
batch: dict[str, Any],
postfix: str = "",
**kwargs: Any,
) -> dict[str, torch.Tensor]:
"""
Calculate loss for predicted bounding boxes and scores.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
batch (dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
postfix (str, optional): Postfix for loss names.
**kwargs (Any): Additional arguments, may include 'match_indices'.
Returns:
(dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).
Notes:
Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
self.aux_loss is True.
"""
self.device = pred_bboxes.device
match_indices = kwargs.get("match_indices", None)
gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
total_loss = self._get_loss(
pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
)
if self.aux_loss:
total_loss.update(
self._get_loss_aux(
pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
)
)
return total_loss
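
def _detr_loss_sketch():
    """Hypothetical smoke test (illustration only, not part of the library).

    Shapes follow DETRLoss.forward(): L decoder layers, B images, N queries, nc classes.
    """
    n_layers, bs, nq, nc = 3, 2, 100, 80
    criterion = DETRLoss(nc=nc, aux_loss=True)
    pred_bboxes = torch.rand(n_layers, bs, nq, 4)  # normalized xywh boxes
    pred_scores = torch.randn(n_layers, bs, nq, nc)  # raw class logits
    batch = {
        "cls": torch.randint(0, nc, (5,)),  # five ground-truth objects in total
        "bboxes": torch.rand(5, 4),  # normalized xywh boxes
        "gt_groups": [3, 2],  # per-image ground-truth counts
    }
    return criterion(pred_bboxes, pred_scores, batch)  # loss_class/loss_bbox/loss_giou (+ *_aux)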
class RTDETRDetectionLoss(DETRLoss):
"""
    Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends DETRLoss.
This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
an additional denoising training loss when provided with denoising metadata.
"""
def forward(
self,
preds: tuple[torch.Tensor, torch.Tensor],
batch: dict[str, Any],
dn_bboxes: torch.Tensor | None = None,
dn_scores: torch.Tensor | None = None,
dn_meta: dict[str, Any] | None = None,
) -> dict[str, torch.Tensor]:
"""
Forward pass to compute detection loss with optional denoising loss.
Args:
preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
batch (dict[str, Any]): Batch data containing ground truth information.
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
dn_scores (torch.Tensor, optional): Denoising scores.
dn_meta (dict[str, Any], optional): Metadata for denoising.
Returns:
(dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
"""
pred_bboxes, pred_scores = preds
total_loss = super().forward(pred_bboxes, pred_scores, batch)
# Check for denoising metadata to compute denoising training loss
if dn_meta is not None:
dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
assert len(batch["gt_groups"]) == len(dn_pos_idx)
# Get the match indices for denoising
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
# Compute the denoising training loss
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
total_loss.update(dn_loss)
else:
# If no denoising metadata is provided, set denoising loss to zero
total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
return total_loss
@staticmethod
def get_dn_match_indices(
dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
) -> list[tuple[torch.Tensor, torch.Tensor]]:
"""
Get match indices for denoising.
Args:
dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
dn_num_group (int): Number of denoising groups.
gt_groups (list[int]): List of integers representing number of ground truths per image.
Returns:
(list[tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
"""
dn_match_indices = []
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
for i, num_gt in enumerate(gt_groups):
if num_gt > 0:
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
gt_idx = gt_idx.repeat(dn_num_group)
assert len(dn_pos_idx[i]) == len(gt_idx), (
f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
)
dn_match_indices.append((dn_pos_idx[i], gt_idx))
else:
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
return dn_match_indices
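
def _dn_match_indices_sketch():
    """Hypothetical illustration (not part of the library): get_dn_match_indices bookkeeping.

    Two images with 2 and 1 ground truths and dn_num_group=2 repeat each image's GT indices
    once per denoising group, offset into the flattened ground-truth tensor.
    """
    dn_pos_idx = [torch.tensor([0, 1, 3, 4]), torch.tensor([0, 1])]  # per-image positive query indices
    indices = RTDETRDetectionLoss.get_dn_match_indices(dn_pos_idx, dn_num_group=2, gt_groups=[2, 1])
    # indices[0] -> (tensor([0, 1, 3, 4]), tensor([0, 1, 0, 1]))
    # indices[1] -> (tensor([0, 1]), tensor([2, 2]))
    return indices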

View File

@@ -0,0 +1,319 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment
from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
class HungarianMatcher(nn.Module):
"""
A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
used in end-to-end object detection models like DETR.
Attributes:
cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
components.
use_fl (bool): Whether to use Focal Loss for classification cost calculation.
with_mask (bool): Whether the model makes mask predictions.
num_sample_points (int): Number of sample points used in mask cost calculation.
alpha (float): Alpha factor in Focal Loss calculation.
gamma (float): Gamma factor in Focal Loss calculation.
Methods:
forward: Compute optimal assignment between predictions and ground truths for a batch.
_cost_mask: Compute mask cost and dice cost if masks are predicted.
Examples:
Initialize a HungarianMatcher with custom cost gains
>>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
Perform matching between predictions and ground truth
>>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100
>>> pred_scores = torch.rand(2, 100, 80) # 80 classes
>>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes
>>> gt_classes = torch.randint(0, 80, (10,))
>>> gt_groups = [5, 5] # 5 GT boxes per image
>>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
"""
def __init__(
self,
cost_gain: dict[str, float] | None = None,
use_fl: bool = True,
with_mask: bool = False,
num_sample_points: int = 12544,
alpha: float = 0.25,
gamma: float = 2.0,
):
"""
Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
Args:
cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
use_fl (bool): Whether to use Focal Loss for classification cost calculation.
with_mask (bool): Whether the model makes mask predictions.
num_sample_points (int): Number of sample points used in mask cost calculation.
alpha (float): Alpha factor in Focal Loss calculation.
gamma (float): Gamma factor in Focal Loss calculation.
"""
super().__init__()
if cost_gain is None:
cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
self.cost_gain = cost_gain
self.use_fl = use_fl
self.with_mask = with_mask
self.num_sample_points = num_sample_points
self.alpha = alpha
self.gamma = gamma
def forward(
self,
pred_bboxes: torch.Tensor,
pred_scores: torch.Tensor,
gt_bboxes: torch.Tensor,
gt_cls: torch.Tensor,
gt_groups: list[int],
masks: torch.Tensor | None = None,
gt_mask: list[torch.Tensor] | None = None,
) -> list[tuple[torch.Tensor, torch.Tensor]]:
"""
Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
Args:
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
num_classes).
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
Returns:
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
(index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)
and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
"""
bs, nq, nc = pred_scores.shape
if sum(gt_groups) == 0:
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
# Flatten to compute cost matrices in batch format
pred_scores = pred_scores.detach().view(-1, nc)
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
pred_bboxes = pred_bboxes.detach().view(-1, 4)
# Compute classification cost
pred_scores = pred_scores[:, gt_cls]
if self.use_fl:
neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
cost_class = pos_cost_class - neg_cost_class
else:
cost_class = -pred_scores
# Compute L1 cost between boxes
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
# Compute GIoU cost between boxes, (bs*num_queries, num_gt)
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
# Combine costs into final cost matrix
C = (
self.cost_gain["class"] * cost_class
+ self.cost_gain["bbox"] * cost_bbox
+ self.cost_gain["giou"] * cost_giou
)
# Add mask costs if available
if self.with_mask:
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
# Set invalid values (NaNs and infinities) to 0
C[C.isnan() | C.isinf()] = 0.0
C = C.view(bs, nq, -1).cpu()
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)
return [
(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
for k, (i, j) in enumerate(indices)
]
# This function is for future RT-DETR Segment models
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
# assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
# # all masks share the same set of points for efficient matching
# sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
# sample_points = 2.0 * sample_points - 1.0
#
# out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
# out_mask = out_mask.flatten(0, 1)
#
# tgt_mask = torch.cat(gt_mask).unsqueeze(1)
# sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
# tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
#
# with torch.amp.autocast("cuda", enabled=False):
# # binary cross entropy cost
# pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
# neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
# cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
# cost_mask /= self.num_sample_points
#
# # dice cost
# out_mask = F.sigmoid(out_mask)
# numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
# denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
# cost_dice = 1 - (numerator + 1) / (denominator + 1)
#
# C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
# return C
def get_cdn_group(
batch: dict[str, Any],
num_classes: int,
num_queries: int,
class_embed: torch.Tensor,
num_dn: int = 100,
cls_noise_ratio: float = 0.5,
box_noise_scale: float = 1.0,
training: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
"""
Generate contrastive denoising training group with positive and negative samples from ground truths.
This function creates denoising queries for contrastive denoising training by adding noise to ground truth
bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.
Args:
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of
ground truths per image.
num_classes (int): Total number of object classes.
num_queries (int): Number of object queries.
class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
num_dn (int): Number of denoising queries to generate.
cls_noise_ratio (float): Noise ratio for class labels.
box_noise_scale (float): Noise scale for bounding box coordinates.
training (bool): Whether model is in training mode.
Returns:
padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.
Examples:
Generate denoising group for training
>>> batch = {
... "cls": torch.tensor([0, 1, 2]),
... "bboxes": torch.rand(3, 4),
... "batch_idx": torch.tensor([0, 0, 1]),
... "gt_groups": [2, 1],
... }
>>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim
>>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
"""
if (not training) or num_dn <= 0 or batch is None:
return None, None, None, None
gt_groups = batch["gt_groups"]
total_num = sum(gt_groups)
max_nums = max(gt_groups)
if max_nums == 0:
return None, None, None, None
num_group = num_dn // max_nums
num_group = 1 if num_group == 0 else num_group
# Pad gt to max_num of a batch
bs = len(gt_groups)
gt_cls = batch["cls"] # (bs*num, )
gt_bbox = batch["bboxes"] # bs*num, 4
b_idx = batch["batch_idx"]
# Each group has positive and negative queries
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
# Positive and negative mask
# (bs*num*num_group, ), the second total_num*num_group part as negative samples
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
if cls_noise_ratio > 0:
# Apply class label noise to half of the samples
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
idx = torch.nonzero(mask).squeeze(-1)
# Randomly assign new class labels
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
dn_cls[idx] = new_label
if box_noise_scale > 0:
known_bbox = xywh2xyxy(dn_bbox)
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
rand_part = torch.rand_like(dn_bbox)
rand_part[neg_idx] += 1.0
rand_part *= rand_sign
known_bbox += rand_part * diff
known_bbox.clip_(min=0.0, max=1.0)
dn_bbox = xyxy2xywh(known_bbox)
dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid
num_dn = int(max_nums * 2 * num_group) # total denoising queries
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
tgt_size = num_dn + num_queries
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # Matching queries cannot see the reconstructed (denoising) queries
attn_mask[num_dn:, :num_dn] = True
    # Denoising groups cannot see each other
for i in range(num_group):
if i == 0:
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
if i == num_group - 1:
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
else:
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
dn_meta = {
"dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
"dn_num_group": num_group,
"dn_num_split": [num_dn, num_queries],
}
return (
padding_cls.to(class_embed.device),
padding_bbox.to(class_embed.device),
attn_mask.to(class_embed.device),
dn_meta,
)
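
def _cdn_group_sketch():
    """Hypothetical illustration (not part of the library): shapes produced by get_cdn_group.

    With gt_groups=[2, 1] and the default num_dn=100, max_nums=2 gives num_group=50, so
    2 * 2 * 50 = 200 denoising queries accompany the 100 regular object queries.
    """
    batch = {
        "cls": torch.tensor([0, 1, 2]),
        "bboxes": torch.rand(3, 4),
        "batch_idx": torch.tensor([0, 0, 1]),
        "gt_groups": [2, 1],
    }
    class_embed = torch.rand(80, 256)  # 80 classes, 256 embedding dim
    padding_cls, padding_bbox, attn_mask, dn_meta = get_cdn_group(batch, 80, 100, class_embed, training=True)
    assert padding_cls.shape == (2, 200, 256) and padding_bbox.shape == (2, 200, 4)
    assert attn_mask.shape == (300, 300) and dn_meta["dn_num_split"] == [200, 100]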

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe
from .model import YOLO, YOLOE, YOLOWorld
__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator
__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"

View File

@@ -0,0 +1,93 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import cv2
import torch
from PIL import Image
from ultralytics.data.augment import classify_transforms
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import DEFAULT_CFG, ops
class ClassificationPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a classification model.
This predictor handles the specific requirements of classification models, including preprocessing images
and postprocessing predictions to generate classification results.
Attributes:
args (dict): Configuration arguments for the predictor.
Methods:
preprocess: Convert input images to model-compatible format.
postprocess: Process model predictions into Results objects.
Notes:
- Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
>>> predictor = ClassificationPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
tasks. It ensures the task is set to 'classify' regardless of input configuration.
Args:
cfg (dict): Default configuration dictionary containing prediction settings.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be executed during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "classify"
def setup_source(self, source):
"""Set up source and inference mode and classify transforms."""
super().setup_source(source)
updated = (
self.model.model.transforms.transforms[0].size != max(self.imgsz)
if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
else False
)
self.transforms = (
classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
)
def preprocess(self, img):
"""Convert input images to model-compatible tensor format with appropriate normalization."""
if not isinstance(img, torch.Tensor):
img = torch.stack(
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
)
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
def postprocess(self, preds, img, orig_imgs):
"""
Process predictions to return Results objects with classification probabilities.
Args:
preds (torch.Tensor): Raw predictions from the model.
img (torch.Tensor): Input images after preprocessing.
orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
Returns:
(list[Results]): List of Results objects containing classification results for each image.
"""
if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
return [
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]
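
# Hypothetical high-level usage (illustration only): a classification checkpoint loaded through
# the standard YOLO API routes inference to this ClassificationPredictor.
if __name__ == "__main__":
    from ultralytics import YOLO

    model = YOLO("yolo11n-cls.pt")
    results = model.predict("https://ultralytics.com/images/bus.jpg")
    print(results[0].probs.top1, results[0].probs.top1conf)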

View File

@@ -0,0 +1,223 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from typing import Any
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import ClassificationModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.plotting import plot_images
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
class ClassificationTrainer(BaseTrainer):
"""
A trainer class extending BaseTrainer for training image classification models.
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
and torchvision models with comprehensive dataset handling and validation.
Attributes:
model (ClassificationModel): The classification model to be trained.
data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
loss_names (list[str]): Names of the loss functions used during training.
validator (ClassificationValidator): Validator instance for model evaluation.
Methods:
set_model_attributes: Set the model's class names from the loaded dataset.
get_model: Return a modified PyTorch model configured for training.
setup_model: Load, create or download model for classification.
build_dataset: Create a ClassificationDataset instance.
get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
preprocess_batch: Preprocess a batch of images and classes.
progress_string: Return a formatted string showing training progress.
get_validator: Return an instance of ClassificationValidator.
label_loss_items: Return a loss dict with labelled training loss items.
final_eval: Evaluate trained model and save validation results.
plot_training_samples: Plot training samples with their annotations.
Examples:
Initialize and train a classification model
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
>>> trainer = ClassificationTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a ClassificationTrainer object.
Args:
cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list[Any], optional): List of callback functions to be executed during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "classify"
if overrides.get("imgsz") is None:
overrides["imgsz"] = 224
super().__init__(cfg, overrides, _callbacks)
def set_model_attributes(self):
"""Set the YOLO model's class names from the loaded dataset."""
self.model.names = self.data["names"]
def get_model(self, cfg=None, weights=None, verbose: bool = True):
"""
Return a modified PyTorch model configured for training YOLO classification.
Args:
cfg (Any, optional): Model configuration.
weights (Any, optional): Pre-trained model weights.
verbose (bool, optional): Whether to display model information.
Returns:
(ClassificationModel): Configured PyTorch model for classification.
"""
model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
for m in model.modules():
if not self.args.pretrained and hasattr(m, "reset_parameters"):
m.reset_parameters()
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
m.p = self.args.dropout # set dropout
for p in model.parameters():
p.requires_grad = True # for training
return model
def setup_model(self):
"""
Load, create or download model for classification tasks.
Returns:
(Any): Model checkpoint if applicable, otherwise None.
"""
import torchvision # scope for faster 'import ultralytics'
if str(self.model) in torchvision.models.__dict__:
self.model = torchvision.models.__dict__[self.model](
weights="IMAGENET1K_V1" if self.args.pretrained else None
)
ckpt = None
else:
ckpt = super().setup_model()
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
return ckpt
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
"""
Create a ClassificationDataset instance given an image path and mode.
Args:
img_path (str): Path to the dataset images.
mode (str, optional): Dataset mode ('train', 'val', or 'test').
batch (Any, optional): Batch information (unused in this implementation).
Returns:
(ClassificationDataset): Dataset for the specified mode.
"""
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Return PyTorch DataLoader with transforms to preprocess images.
Args:
dataset_path (str): Path to the dataset.
batch_size (int, optional): Number of images per batch.
rank (int, optional): Process rank for distributed training.
mode (str, optional): 'train', 'val', or 'test' mode.
Returns:
(torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
"""
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode)
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
# Attach inference transforms
if mode != "train":
if is_parallel(self.model):
self.model.module.transforms = loader.dataset.torch_transforms
else:
self.model.transforms = loader.dataset.torch_transforms
return loader
def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""Preprocess a batch of images and classes."""
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
return batch
def progress_string(self) -> str:
"""Return a formatted string showing training progress."""
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def get_validator(self):
"""Return an instance of ClassificationValidator for validation."""
self.loss_names = ["loss"]
return yolo.classify.ClassificationValidator(
self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
"""
Return a loss dict with labelled training loss items tensor.
Args:
loss_items (torch.Tensor, optional): Loss tensor items.
prefix (str, optional): Prefix to prepend to loss names.
Returns:
keys (list[str]): List of loss keys if loss_items is None.
loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
"""
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is None:
return keys
loss_items = [round(float(loss_items), 5)]
return dict(zip(keys, loss_items))
def final_eval(self):
"""Evaluate trained model and save validation results."""
for f in self.last, self.best:
if f.exists():
strip_optimizer(f) # strip optimizers
if f is self.best:
LOGGER.info(f"\nValidating {f}...")
self.validator.args.data = self.args.data
self.validator.args.plots = self.args.plots
self.metrics = self.validator(model=f)
self.metrics.pop("fitness", None)
self.run_callbacks("on_fit_epoch_end")
def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
"""
Plot training samples with their annotations.
Args:
batch (dict[str, torch.Tensor]): Batch containing images and class labels.
ni (int): Number of iterations.
"""
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
plot_images(
labels=batch,
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
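# ----------------------------------------------------------------------------------------------
# Minimal training sketch (illustrative only), mirroring the class docstring. "imagenet10" is the
# small demo dataset referenced above and is assumed to auto-download; epochs is kept at 1 so the
# sketch stays cheap. Note that label_loss_items() with no tensor returns ["train/loss"], the key
# used when logging the single classification loss.
if __name__ == "__main__":
    args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=1, imgsz=224)
    trainer = ClassificationTrainer(overrides=args)
    trainer.train()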

View File

@@ -0,0 +1,214 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.data import ClassificationDataset, build_dataloader
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
from ultralytics.utils.plotting import plot_images
class ClassificationValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a classification model.
This validator handles the validation process for classification models, including metrics calculation,
confusion matrix generation, and visualization of results.
Attributes:
targets (list[torch.Tensor]): Ground truth class labels.
pred (list[torch.Tensor]): Model predictions.
metrics (ClassifyMetrics): Object to calculate and store classification metrics.
names (dict): Mapping of class indices to class names.
nc (int): Number of classes.
confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
Methods:
get_desc: Return a formatted string summarizing classification metrics.
init_metrics: Initialize confusion matrix, class names, and tracking containers.
preprocess: Preprocess input batch by moving data to device.
update_metrics: Update running metrics with model predictions and batch targets.
finalize_metrics: Finalize metrics including confusion matrix and processing speed.
postprocess: Extract the primary prediction from model output.
get_stats: Calculate and return a dictionary of metrics.
build_dataset: Create a ClassificationDataset instance for validation.
get_dataloader: Build and return a data loader for classification validation.
print_results: Print evaluation metrics for the classification model.
plot_val_samples: Plot validation image samples with their ground truth labels.
plot_predictions: Plot images with their predicted class labels.
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
Notes:
Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (str | Path, optional): Directory to save results.
args (dict, optional): Arguments containing model and validation configuration.
_callbacks (list, optional): List of callback functions to be called during validation.
Examples:
>>> from ultralytics.models.yolo.classify import ClassificationValidator
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
>>> validator = ClassificationValidator(args=args)
>>> validator()
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.targets = None
self.pred = None
self.args.task = "classify"
self.metrics = ClassifyMetrics()
def get_desc(self) -> str:
"""Return a formatted string summarizing classification metrics."""
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
def init_metrics(self, model: torch.nn.Module) -> None:
"""Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
self.names = model.names
self.nc = len(model.names)
self.pred = []
self.targets = []
self.confusion_matrix = ConfusionMatrix(names=model.names)
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
return batch
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
"""
Update running metrics with model predictions and batch targets.
Args:
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
batch (dict): Batch data containing images and class labels.
Notes:
This method appends the top-N predictions (sorted by confidence in descending order) to the
prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
"""
n5 = min(len(self.names), 5)
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
self.targets.append(batch["cls"].type(torch.int32).cpu())
def finalize_metrics(self) -> None:
"""
Finalize metrics including confusion matrix and processing speed.
Notes:
This method processes the accumulated predictions and targets to generate the confusion matrix,
optionally plots it, and updates the metrics object with speed information.
Examples:
>>> validator = ClassificationValidator()
>>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
>>> validator.targets = [torch.tensor([0])] # Ground truth class
>>> validator.finalize_metrics()
>>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
"""
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.save_dir = self.save_dir
self.metrics.confusion_matrix = self.confusion_matrix
def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
"""Extract the primary prediction from model output if it's in a list or tuple format."""
return preds[0] if isinstance(preds, (list, tuple)) else preds
def get_stats(self) -> dict[str, float]:
"""Calculate and return a dictionary of metrics by processing targets and predictions."""
self.metrics.process(self.targets, self.pred)
return self.metrics.results_dict
def build_dataset(self, img_path: str) -> ClassificationDataset:
"""Create a ClassificationDataset instance for validation."""
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Build and return a data loader for classification validation.
Args:
dataset_path (str | Path): Path to the dataset directory.
batch_size (int): Number of samples per batch.
Returns:
(torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
"""
dataset = self.build_dataset(dataset_path)
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
def print_results(self) -> None:
"""Print evaluation metrics for the classification model."""
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples with their ground truth labels.
Args:
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = ClassificationValidator()
>>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
>>> validator.plot_val_samples(batch, 0)
"""
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
plot_images(
labels=batch,
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
"""
Plot images with their predicted class labels and save the visualization.
Args:
batch (dict[str, Any]): Batch data containing images and other information.
preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = ClassificationValidator()
>>> batch = {"img": torch.rand(16, 3, 224, 224)}
>>> preds = torch.rand(16, 10) # 16 images, 10 classes
>>> validator.plot_predictions(batch, preds, 0)
"""
batched_preds = dict(
img=batch["img"],
batch_idx=torch.arange(batch["img"].shape[0]),
cls=torch.argmax(preds, dim=1),
)
plot_images(
batched_preds,
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
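# ----------------------------------------------------------------------------------------------
# Minimal validation sketch (illustrative only), mirroring the class docstring. Calling the
# validator runs preprocess -> update_metrics -> finalize_metrics above and returns the metrics
# dictionary produced by get_stats().
if __name__ == "__main__":
    validator = ClassificationValidator(args=dict(model="yolo11n-cls.pt", data="imagenet10"))
    stats = validator()
    print(stats)  # e.g. {"metrics/accuracy_top1": ..., "metrics/accuracy_top5": ...}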

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import DetectionPredictor
from .train import DetectionTrainer
from .val import DetectionValidator
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"

View File

@@ -0,0 +1,125 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import nms, ops
class DetectionPredictor(BasePredictor):
"""
A class extending the BasePredictor class for prediction based on a detection model.
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
with bounding boxes and class predictions.
Attributes:
args (namespace): Configuration arguments for the predictor.
model (nn.Module): The detection model used for inference.
batch (list): Batch of images and metadata for processing.
Methods:
postprocess: Process raw model predictions into detection results.
construct_results: Build Results objects from processed predictions.
construct_result: Create a single Results object from a prediction.
get_obj_feats: Extract object features from the feature maps.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.detect import DetectionPredictor
>>> args = dict(model="yolo11n.pt", source=ASSETS)
>>> predictor = DetectionPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def postprocess(self, preds, img, orig_imgs, **kwargs):
"""
Post-process predictions and return a list of Results objects.
This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
further analysis.
Args:
preds (torch.Tensor): Raw predictions from the model.
img (torch.Tensor): Processed input image tensor in model input format.
orig_imgs (torch.Tensor | list): Original input images before preprocessing.
**kwargs (Any): Additional keyword arguments.
Returns:
(list): List of Results objects containing the post-processed predictions.
Examples:
>>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
>>> results = predictor.predict("path/to/image.jpg")
>>> processed_results = predictor.postprocess(preds, img, orig_imgs)
"""
save_feats = getattr(self, "_feats", None) is not None
preds = nms.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
self.args.classes,
self.args.agnostic_nms,
max_det=self.args.max_det,
nc=0 if self.args.task == "detect" else len(self.model.names),
end2end=getattr(self.model, "end2end", False),
rotated=self.args.task == "obb",
return_idxs=save_feats,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
if save_feats:
obj_feats = self.get_obj_feats(self._feats, preds[1])
preds = preds[0]
results = self.construct_results(preds, img, orig_imgs, **kwargs)
if save_feats:
for r, f in zip(results, obj_feats):
r.feats = f # add object features to results
return results
def get_obj_feats(self, feat_maps, idxs):
"""Extract object features from the feature maps."""
import torch
s = min(x.shape[1] for x in feat_maps) # find shortest vector length
obj_feats = torch.cat(
[x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
) # mean reduce all vectors to same length
return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
def construct_results(self, preds, img, orig_imgs):
"""
Construct a list of Results objects from model predictions.
Args:
preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
img (torch.Tensor): Batch of preprocessed images used for inference.
orig_imgs (list[np.ndarray]): List of original images before preprocessing.
Returns:
(list[Results]): List of Results objects containing detection information for each image.
"""
return [
self.construct_result(pred, img, orig_img, img_path)
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
]
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct a single Results object from one image prediction.
Args:
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
img (torch.Tensor): Preprocessed image tensor used for inference.
orig_img (np.ndarray): Original image before preprocessing.
img_path (str): Path to the original image file.
Returns:
(Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
"""
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
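# ----------------------------------------------------------------------------------------------
# Minimal inference sketch (illustrative only), mirroring the class docstring. ASSETS ships with
# the package; "yolo11n.pt" is an assumed pretrained checkpoint. predict_cli() streams over the
# source and routes every batch through the postprocess() NMS pipeline defined above.
if __name__ == "__main__":
    from ultralytics.utils import ASSETS

    predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt", source=ASSETS, conf=0.25))
    predictor.predict_cli()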

View File

@@ -0,0 +1,236 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import math
import random
from copy import copy
from typing import Any
import numpy as np
import torch
import torch.nn as nn
from ultralytics.data import build_dataloader, build_yolo_dataset
from ultralytics.engine.trainer import BaseTrainer
from ultralytics.models import yolo
from ultralytics.nn.tasks import DetectionModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.patches import override_configs
from ultralytics.utils.plotting import plot_images, plot_labels
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
class DetectionTrainer(BaseTrainer):
"""
A class extending the BaseTrainer class for training based on a detection model.
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
for object detection including dataset building, data loading, preprocessing, and model configuration.
Attributes:
model (DetectionModel): The YOLO detection model being trained.
data (dict): Dictionary containing dataset information including class names and number of classes.
loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
Methods:
build_dataset: Build YOLO dataset for training or validation.
get_dataloader: Construct and return dataloader for the specified mode.
preprocess_batch: Preprocess a batch of images by scaling and converting to float.
set_model_attributes: Set model attributes based on dataset information.
get_model: Return a YOLO detection model.
get_validator: Return a validator for model evaluation.
label_loss_items: Return a loss dictionary with labeled training loss items.
progress_string: Return a formatted string of training progress.
plot_training_samples: Plot training samples with their annotations.
plot_training_labels: Create a labeled training plot of the YOLO model.
auto_batch: Calculate optimal batch size based on model memory requirements.
Examples:
>>> from ultralytics.models.yolo.detect import DetectionTrainer
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
>>> trainer = DetectionTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a DetectionTrainer object for training a YOLO object detection model.
Args:
cfg (dict, optional): Default configuration dictionary containing training parameters.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be executed during training.
"""
super().__init__(cfg, overrides, _callbacks)
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
"""
Build YOLO Dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
mode (str): 'train' or 'val' mode; users can customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for 'rect' mode.
Returns:
(Dataset): YOLO dataset object configured for the specified mode.
"""
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
"""
Construct and return dataloader for the specified mode.
Args:
dataset_path (str): Path to the dataset.
batch_size (int): Number of images per batch.
rank (int): Process rank for distributed training.
mode (str): 'train' for training dataloader, 'val' for validation dataloader.
Returns:
(DataLoader): PyTorch dataloader object.
"""
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
dataset = self.build_dataset(dataset_path, mode, batch_size)
shuffle = mode == "train"
if getattr(dataset, "rect", False) and shuffle:
LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
shuffle = False
return build_dataloader(
dataset,
batch=batch_size,
workers=self.args.workers if mode == "train" else self.args.workers * 2,
shuffle=shuffle,
rank=rank,
drop_last=self.args.compile and mode == "train",
)
def preprocess_batch(self, batch: dict) -> dict:
"""
Preprocess a batch of images by scaling and converting to float.
Args:
batch (dict): Dictionary containing batch data with 'img' tensor.
Returns:
(dict): Preprocessed batch with normalized images.
"""
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = batch["img"].float() / 255
if self.args.multi_scale:
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
) # size
sf = sz / max(imgs.shape[2:]) # scale factor
if sf != 1:
ns = [
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
] # new shape (stretched to gs-multiple)
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
def set_model_attributes(self):
"""Set model attributes based on dataset information."""
# Nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
# self.args.box *= 3 / nl # scale to layers
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
self.model.nc = self.data["nc"] # attach number of classes to model
self.model.names = self.data["names"] # attach class names to model
self.model.args = self.args # attach hyperparameters to model
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
"""
Return a YOLO detection model.
Args:
cfg (str, optional): Path to model configuration file.
weights (str, optional): Path to model weights.
verbose (bool): Whether to display model information.
Returns:
(DetectionModel): YOLO detection model.
"""
model = DetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return a DetectionValidator for YOLO model validation."""
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.detect.DetectionValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
"""
Return a loss dict with labeled training loss items tensor.
Args:
loss_items (list[float], optional): List of loss values.
prefix (str): Prefix for keys in the returned dictionary.
Returns:
(dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
"""
keys = [f"{prefix}/{x}" for x in self.loss_names]
if loss_items is not None:
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
return dict(zip(keys, loss_items))
else:
return keys
def progress_string(self):
"""Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
"Epoch",
"GPU_mem",
*self.loss_names,
"Instances",
"Size",
)
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot training samples with their annotations.
Args:
batch (dict[str, Any]): Dictionary containing batch data.
ni (int): Number of iterations.
"""
plot_images(
labels=batch,
paths=batch["im_file"],
fname=self.save_dir / f"train_batch{ni}.jpg",
on_plot=self.on_plot,
)
def plot_training_labels(self):
"""Create a labeled training plot of the YOLO model."""
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
def auto_batch(self):
"""
Get optimal batch size by calculating memory occupation of model.
Returns:
(int): Optimal batch size.
"""
with override_configs(self.args, overrides={"cache": False}) as self.args:
train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
del train_dataset # free memory
return super().auto_batch(max_num_obj)
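# ----------------------------------------------------------------------------------------------
# Minimal training sketch (illustrative only), mirroring the class docstring. "coco8.yaml" is the
# tiny demo dataset referenced above; multi_scale=True exercises the random-resize branch in
# preprocess_batch(), and setting batch=-1 lets auto_batch() above estimate a batch size from
# available GPU memory.
if __name__ == "__main__":
    args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=1, imgsz=640, multi_scale=True)
    trainer = DetectionTrainer(overrides=args)
    trainer.train()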

View File

@@ -0,0 +1,495 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import os
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
from ultralytics.engine.validator import BaseValidator
from ultralytics.utils import LOGGER, nms, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
from ultralytics.utils.plotting import plot_images
class DetectionValidator(BaseValidator):
"""
A class extending the BaseValidator class for validation based on a detection model.
This class implements validation functionality specific to object detection tasks, including metrics calculation,
prediction processing, and visualization of results.
Attributes:
is_coco (bool): Whether the dataset is COCO.
is_lvis (bool): Whether the dataset is LVIS.
class_map (list[int]): Mapping from model class indices to dataset class indices.
metrics (DetMetrics): Object detection metrics calculator.
iouv (torch.Tensor): IoU thresholds for mAP calculation.
niou (int): Number of IoU thresholds.
lb (list[Any]): List for storing ground truth labels for hybrid saving.
jdict (list[dict[str, Any]]): List for storing JSON detection results.
stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.
Examples:
>>> from ultralytics.models.yolo.detect import DetectionValidator
>>> args = dict(model="yolo11n.pt", data="coco8.yaml")
>>> validator = DetectionValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize detection validator with necessary variables and settings.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (Path, optional): Directory to save results.
args (dict[str, Any], optional): Arguments for the validator.
_callbacks (list[Any], optional): List of callback functions.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.is_coco = False
self.is_lvis = False
self.class_map = None
self.args.task = "detect"
self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
self.niou = self.iouv.numel()
self.metrics = DetMetrics()
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""
Preprocess batch of images for YOLO validation.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
Returns:
(dict[str, Any]): Preprocessed batch.
"""
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
return batch
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO detection validation.
Args:
model (torch.nn.Module): Model to validate.
"""
val = self.data.get(self.args.split, "") # validation path
self.is_coco = (
isinstance(val, str)
and "coco" in val
and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
) # is COCO
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
self.names = model.names
self.nc = len(model.names)
self.end2end = getattr(model, "end2end", False)
self.seen = 0
self.jdict = []
self.metrics.names = model.names
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
def get_desc(self) -> str:
"""Return a formatted string summarizing class metrics of YOLO model."""
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Apply Non-maximum suppression to prediction outputs.
Args:
preds (torch.Tensor): Raw predictions from the model.
Returns:
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
'bboxes', 'conf', 'cls', and 'extra' tensors.
"""
outputs = nms.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
nc=0 if self.args.task == "detect" else self.nc,
multi_label=True,
agnostic=self.args.single_cls or self.args.agnostic_nms,
max_det=self.args.max_det,
end2end=self.end2end,
rotated=self.args.task == "obb",
)
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch of images and annotations for validation.
Args:
si (int): Batch index.
batch (dict[str, Any]): Batch data containing images and annotations.
Returns:
(dict[str, Any]): Prepared batch with processed annotations.
"""
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if cls.shape[0]:
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
return {
"cls": cls,
"bboxes": bbox,
"ori_shape": ori_shape,
"imgsz": imgsz,
"ratio_pad": ratio_pad,
"im_file": batch["im_file"][si],
}
def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Prepare predictions for evaluation against ground truth.
Args:
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
Returns:
(dict[str, torch.Tensor]): Prepared predictions in native space.
"""
if self.args.single_cls:
pred["cls"] *= 0
return pred
def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
"""
Update metrics with new predictions and ground truth.
Args:
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
batch (dict[str, Any]): Batch data containing ground truth.
"""
for si, pred in enumerate(preds):
self.seen += 1
pbatch = self._prepare_batch(si, batch)
predn = self._prepare_pred(pred)
cls = pbatch["cls"].cpu().numpy()
no_pred = predn["cls"].shape[0] == 0
self.metrics.update_stats(
{
**self._process_batch(predn, pbatch),
"target_cls": cls,
"target_img": np.unique(cls),
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
}
)
# Evaluate
if self.args.plots:
self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
if self.args.visualize:
self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)
if no_pred:
continue
# Save
if self.args.save_json or self.args.save_txt:
predn_scaled = self.scale_preds(predn, pbatch)
if self.args.save_json:
self.pred_to_json(predn_scaled, pbatch)
if self.args.save_txt:
self.save_one_txt(
predn_scaled,
self.args.save_conf,
pbatch["ori_shape"],
self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
)
def finalize_metrics(self) -> None:
"""Set final values for metrics speed and confusion matrix."""
if self.args.plots:
for normalize in True, False:
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
self.metrics.speed = self.speed
self.metrics.confusion_matrix = self.confusion_matrix
self.metrics.save_dir = self.save_dir
def get_stats(self) -> dict[str, Any]:
"""
Calculate and return metrics statistics.
Returns:
(dict[str, Any]): Dictionary containing metrics results.
"""
self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
self.metrics.clear_stats()
return self.metrics.results_dict
def print_results(self) -> None:
"""Print training/validation set metrics per class."""
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
if self.metrics.nt_per_class.sum() == 0:
LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
# Print results per class
if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
for i, c in enumerate(self.metrics.ap_class_index):
LOGGER.info(
pf
% (
self.names[c],
self.metrics.nt_per_image[c],
self.metrics.nt_per_class[c],
*self.metrics.class_result(i),
)
)
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
Returns:
(dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
"""
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
iou = box_iou(batch["bboxes"], preds["bboxes"])
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
"""
Build YOLO Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` or `val` mode; users can customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`.
Returns:
(Dataset): YOLO dataset.
"""
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
"""
Construct and return dataloader.
Args:
dataset_path (str): Path to the dataset.
batch_size (int): Size of each batch.
Returns:
(torch.utils.data.DataLoader): Dataloader for validation.
"""
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
return build_dataloader(
dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
)
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
"""
Plot validation image samples.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
ni (int): Batch index.
"""
plot_images(
labels=batch,
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
names=self.names,
on_plot=self.on_plot,
)
def plot_predictions(
self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
) -> None:
"""
Plot predicted bounding boxes on input images and save the result.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
ni (int): Batch index.
max_det (int, optional): Maximum number of detections to plot.
"""
# TODO: optimize this
for i, pred in enumerate(preds):
pred["batch_idx"] = torch.ones_like(pred["conf"]) * i # add batch index to predictions
keys = preds[0].keys()
max_det = max_det or self.args.max_det
batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
# TODO: fix this
batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4]) # convert to xywh format
plot_images(
images=batch["img"],
labels=batched_preds,
paths=batch["im_file"],
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
names=self.names,
on_plot=self.on_plot,
) # pred
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO detections to a txt file in normalized coordinates in a specific format.
Args:
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
save_conf (bool): Whether to save confidence scores.
shape (tuple[int, int]): Shape of the original image (height, width).
file (Path): File path to save the detections.
"""
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Serialize YOLO predictions to COCO json format.
Args:
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Examples:
>>> result = {
... "image_id": 42,
... "file_name": "42.jpg",
... "category_id": 18,
... "bbox": [258.15, 41.29, 348.26, 243.78],
... "score": 0.236,
... }
"""
path = Path(pbatch["im_file"])
stem = path.stem
image_id = int(stem) if stem.isnumeric() else stem
box = ops.xyxy2xywh(predn["bboxes"]) # xywh
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
self.jdict.append(
{
"image_id": image_id,
"file_name": path.name,
"category_id": self.class_map[int(c)],
"bbox": [round(x, 3) for x in b],
"score": round(s, 5),
}
)
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**predn,
"bboxes": ops.scale_boxes(
pbatch["imgsz"],
predn["bboxes"].clone(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""
Evaluate YOLO output in JSON format and return performance statistics.
Args:
stats (dict[str, Any]): Current statistics dictionary.
Returns:
(dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
"""
pred_json = self.save_dir / "predictions.json" # predictions
anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
return self.coco_evaluate(stats, pred_json, anno_json)
def coco_evaluate(
self,
stats: dict[str, Any],
pred_json: str,
anno_json: str,
iou_types: str | list[str] = "bbox",
suffix: str | list[str] = "Box",
) -> dict[str, Any]:
"""
Evaluate COCO/LVIS metrics using faster-coco-eval library.
Performs evaluation using the faster-coco-eval library to compute mAP metrics
for object detection. Updates the provided stats dictionary with computed metrics
including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
Args:
stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
pred_json (str | Path): Path to JSON file containing predictions in COCO format.
anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
iou_types (str | list[str]): IoU type(s) for evaluation. Can be a single string or a list of strings.
Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
suffix (str | list[str]): Suffix to append to metric names in the stats dictionary. Should correspond
to iou_types if multiple types are provided. Defaults to "Box".
Returns:
(dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
"""
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
LOGGER.info(f"\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...")
try:
for x in pred_json, anno_json:
assert x.is_file(), f"{x} file not found"
iou_types = [iou_types] if isinstance(iou_types, str) else iou_types
suffix = [suffix] if isinstance(suffix, str) else suffix
check_requirements("faster-coco-eval>=1.6.7")
from faster_coco_eval import COCO, COCOeval_faster
anno = COCO(anno_json)
pred = anno.loadRes(pred_json)
for i, iou_type in enumerate(iou_types):
val = COCOeval_faster(
anno, pred, iouType=iou_type, lvis_style=self.is_lvis, print_function=LOGGER.info
)
val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
val.evaluate()
val.accumulate()
val.summarize()
# update mAP50-95 and mAP50
stats[f"metrics/mAP50({suffix[i][0]})"] = val.stats_as_dict["AP_50"]
stats[f"metrics/mAP50-95({suffix[i][0]})"] = val.stats_as_dict["AP_all"]
if self.is_lvis:
stats[f"metrics/APr({suffix[i][0]})"] = val.stats_as_dict["APr"]
stats[f"metrics/APc({suffix[i][0]})"] = val.stats_as_dict["APc"]
stats[f"metrics/APf({suffix[i][0]})"] = val.stats_as_dict["APf"]
if self.is_lvis:
stats["fitness"] = stats["metrics/mAP50-95(B)"] # always use box mAP50-95 for fitness
except Exception as e:
LOGGER.warning(f"faster-coco-eval unable to run: {e}")
return stats
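# ----------------------------------------------------------------------------------------------
# Minimal validation sketch (illustrative only), mirroring the class docstring. "coco8.yaml" is a
# tiny demo dataset; on full COCO/LVIS splits with save_json=True, eval_json()/coco_evaluate()
# above additionally run faster-coco-eval on the exported predictions.json.
if __name__ == "__main__":
    validator = DetectionValidator(args=dict(model="yolo11n.pt", data="coco8.yaml"))
    stats = validator()
    print(stats)  # e.g. {"metrics/mAP50(B)": ..., "metrics/mAP50-95(B)": ...}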

View File

@@ -0,0 +1,447 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import torch
from ultralytics.data.build import load_inference_source
from ultralytics.engine.model import Model
from ultralytics.models import yolo
from ultralytics.nn.tasks import (
ClassificationModel,
DetectionModel,
OBBModel,
PoseModel,
SegmentationModel,
WorldModel,
YOLOEModel,
YOLOESegModel,
)
from ultralytics.utils import ROOT, YAML
class YOLO(Model):
"""
YOLO (You Only Look Once) object detection model.
This class provides a unified interface for YOLO models, automatically switching to specialized model types
(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
detection, segmentation, classification, pose estimation, and oriented bounding box detection.
Attributes:
model: The loaded YOLO model instance.
task: The task type (detect, segment, classify, pose, obb).
overrides: Configuration overrides for the model.
Methods:
__init__: Initialize a YOLO model with automatic type detection.
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
Examples:
Load a pretrained YOLO11n detection model
>>> model = YOLO("yolo11n.pt")
Load a pretrained YOLO11n segmentation model
>>> model = YOLO("yolo11n-seg.pt")
Initialize from a YAML configuration
>>> model = YOLO("yolo11n.yaml")
"""
def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
"""
Initialize a YOLO model.
This constructor initializes a YOLO model, automatically switching to specialized model types
(YOLOWorld or YOLOE) based on the model filename.
Args:
model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
Defaults to auto-detection based on model.
verbose (bool): Display model info on load.
Examples:
>>> from ultralytics import YOLO
>>> model = YOLO("yolo11n.pt") # load a pretrained YOLO11n detection model
>>> model = YOLO("yolo11n-seg.pt") # load a pretrained YOLO11n segmentation model
"""
path = Path(model if isinstance(model, (str, Path)) else "")
if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
new_instance = YOLOWorld(path, verbose=verbose)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOE PyTorch model
new_instance = YOLOE(path, task=task, verbose=verbose)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
else:
# Continue with default YOLO initialization
super().__init__(model=model, task=task, verbose=verbose)
if hasattr(self.model, "model") and "RTDETR" in self.model.model[-1]._get_name(): # if RTDETR head
from ultralytics import RTDETR
new_instance = RTDETR(self)
self.__class__ = type(new_instance)
self.__dict__ = new_instance.__dict__
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Map head to model, trainer, validator, and predictor classes."""
return {
"classify": {
"model": ClassificationModel,
"trainer": yolo.classify.ClassificationTrainer,
"validator": yolo.classify.ClassificationValidator,
"predictor": yolo.classify.ClassificationPredictor,
},
"detect": {
"model": DetectionModel,
"trainer": yolo.detect.DetectionTrainer,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
},
"segment": {
"model": SegmentationModel,
"trainer": yolo.segment.SegmentationTrainer,
"validator": yolo.segment.SegmentationValidator,
"predictor": yolo.segment.SegmentationPredictor,
},
"pose": {
"model": PoseModel,
"trainer": yolo.pose.PoseTrainer,
"validator": yolo.pose.PoseValidator,
"predictor": yolo.pose.PosePredictor,
},
"obb": {
"model": OBBModel,
"trainer": yolo.obb.OBBTrainer,
"validator": yolo.obb.OBBValidator,
"predictor": yolo.obb.OBBPredictor,
},
}
class YOLOWorld(Model):
"""
YOLO-World object detection model.
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
without requiring training on specific classes. It extends the YOLO architecture to support real-time
open-vocabulary detection.
Attributes:
model: The loaded YOLO-World model instance.
task: Always set to 'detect' for object detection.
overrides: Configuration overrides for the model.
Methods:
__init__: Initialize YOLOv8-World model with a pre-trained model file.
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
set_classes: Set the model's class names for detection.
Examples:
Load a YOLOv8-World model
>>> model = YOLOWorld("yolov8s-world.pt")
Set custom classes for detection
>>> model.set_classes(["person", "car", "bicycle"])
"""
def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
"""
Initialize YOLOv8-World model with a pre-trained model file.
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
COCO class names.
Args:
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
verbose (bool): If True, prints additional information during initialization.
"""
super().__init__(model=model, task="detect", verbose=verbose)
# Assign default COCO class names when there are no custom names
if not hasattr(self.model, "names"):
self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Map head to model, validator, and predictor classes."""
return {
"detect": {
"model": WorldModel,
"validator": yolo.detect.DetectionValidator,
"predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.world.WorldTrainer,
}
}
def set_classes(self, classes: list[str]) -> None:
"""
Set the model's class names for detection.
Args:
classes (list[str]): A list of categories, e.g. ["person"].
"""
self.model.set_classes(classes)
# Remove background if it's given
background = " "
if background in classes:
classes.remove(background)
self.model.names = classes
# Reset method class names
if self.predictor:
self.predictor.model.names = classes
class YOLOE(Model):
"""
YOLOE object detection and segmentation model.
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
improved performance and additional features like visual and text positional embeddings.
Attributes:
model: The loaded YOLOE model instance.
task: The task type (detect or segment).
overrides: Configuration overrides for the model.
Methods:
__init__: Initialize YOLOE model with a pre-trained model file.
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
get_text_pe: Get text positional embeddings for the given texts.
get_visual_pe: Get visual positional embeddings for the given image and visual features.
set_vocab: Set vocabulary and class names for the YOLOE model.
get_vocab: Get vocabulary for the given class names.
set_classes: Set the model's class names and embeddings for detection.
val: Validate the model using text or visual prompts.
predict: Run prediction on images, videos, directories, streams, etc.
Examples:
Load a YOLOE detection model
>>> model = YOLOE("yoloe-11s-seg.pt")
Set vocabulary and class names
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
Predict with visual prompts
>>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
>>> results = model.predict("image.jpg", visual_prompts=prompts)
"""
def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
"""
Initialize YOLOE model with a pre-trained model file.
Args:
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
task (str, optional): Task type for the model. Auto-detected if None.
verbose (bool): If True, prints additional information during initialization.
"""
super().__init__(model=model, task=task, verbose=verbose)
@property
def task_map(self) -> dict[str, dict[str, Any]]:
"""Map head to model, validator, and predictor classes."""
return {
"detect": {
"model": YOLOEModel,
"validator": yolo.yoloe.YOLOEDetectValidator,
"predictor": yolo.detect.DetectionPredictor,
"trainer": yolo.yoloe.YOLOETrainer,
},
"segment": {
"model": YOLOESegModel,
"validator": yolo.yoloe.YOLOESegValidator,
"predictor": yolo.segment.SegmentationPredictor,
"trainer": yolo.yoloe.YOLOESegTrainer,
},
}
def get_text_pe(self, texts):
"""Get text positional embeddings for the given texts."""
assert isinstance(self.model, YOLOEModel)
return self.model.get_text_pe(texts)
def get_visual_pe(self, img, visual):
"""
Get visual positional embeddings for the given image and visual features.
This method extracts positional embeddings from visual features based on the input image. It requires
that the model is an instance of YOLOEModel.
Args:
img (torch.Tensor): Input image tensor.
visual (torch.Tensor): Visual features extracted from the image.
Returns:
(torch.Tensor): Visual positional embeddings.
Examples:
>>> model = YOLOE("yoloe-11s-seg.pt")
>>> img = torch.rand(1, 3, 640, 640)
>>> visual_features = torch.rand(1, 1, 80, 80)
>>> pe = model.get_visual_pe(img, visual_features)
"""
assert isinstance(self.model, YOLOEModel)
return self.model.get_visual_pe(img, visual)
def set_vocab(self, vocab: list[str], names: list[str]) -> None:
"""
Set vocabulary and class names for the YOLOE model.
This method configures the vocabulary and class names used by the model for text processing and
classification tasks. The model must be an instance of YOLOEModel.
Args:
vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
names (list[str]): List of class names that the model can detect or classify.
Raises:
AssertionError: If the model is not an instance of YOLOEModel.
Examples:
>>> model = YOLOE("yoloe-11s-seg.pt")
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
"""
assert isinstance(self.model, YOLOEModel)
self.model.set_vocab(vocab, names=names)
def get_vocab(self, names):
"""Get vocabulary for the given class names."""
assert isinstance(self.model, YOLOEModel)
return self.model.get_vocab(names)
def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
"""
Set the model's class names and embeddings for detection.
Args:
classes (list[str]): A list of categories, e.g. ["person"].
embeddings (torch.Tensor): Embeddings corresponding to the classes.
"""
assert isinstance(self.model, YOLOEModel)
if embeddings is None:
embeddings = self.get_text_pe(classes) # generate text embeddings if not provided
self.model.set_classes(classes, embeddings)
# Verify no background class is present
assert " " not in classes
self.model.names = classes
# Reset method class names
if self.predictor:
self.predictor.model.names = classes
def val(
self,
validator=None,
load_vp: bool = False,
refer_data: str | None = None,
**kwargs,
):
"""
Validate the model using text or visual prompts.
Args:
validator (callable, optional): A callable validator function. If None, a default validator is loaded.
load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
refer_data (str, optional): Path to the reference data for visual prompts.
**kwargs (Any): Additional keyword arguments to override default settings.
Returns:
(dict): Validation statistics containing metrics computed during validation.
"""
custom = {"rect": not load_vp} # method defaults
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
self.metrics = validator.metrics
return validator.metrics
def predict(
self,
source=None,
stream: bool = False,
visual_prompts: dict[str, list] = {},
refer_image=None,
predictor=yolo.yoloe.YOLOEVPDetectPredictor,
**kwargs,
):
"""
Run prediction on images, videos, directories, streams, etc.
Args:
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
generator as they are computed.
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
'bboxes' and 'cls' keys when non-empty.
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
predictor (callable, optional): Predictor class to use for visual-prompted inference. Defaults to
    yolo.yoloe.YOLOEVPDetectPredictor.
**kwargs (Any): Additional keyword arguments passed to the predictor.
Returns:
(list | generator): List of Results objects or generator of Results objects if stream=True.
Examples:
>>> model = YOLOE("yoloe-11s-seg.pt")
>>> results = model.predict("path/to/image.jpg")
>>> # With visual prompts
>>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
>>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
"""
if len(visual_prompts):
assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
)
assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
f"{len(visual_prompts['cls'])} respectively"
)
if type(self.predictor) is not predictor:
self.predictor = predictor(
overrides={
"task": self.model.task,
"mode": "predict",
"save": False,
"verbose": refer_image is None,
"batch": 1,
"device": kwargs.get("device", None),
"half": kwargs.get("half", False),
"imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
},
_callbacks=self.callbacks,
)
num_cls = (
max(len(set(c)) for c in visual_prompts["cls"])
if isinstance(source, list) and refer_image is None # means multiple images
else len(set(visual_prompts["cls"]))
)
self.model.model[-1].nc = num_cls
self.model.names = [f"object{i}" for i in range(num_cls)]
self.predictor.set_prompts(visual_prompts.copy())
self.predictor.setup_model(model=self.model)
if refer_image is None and source is not None:
dataset = load_inference_source(source)
if dataset.mode in {"video", "stream"}:
# NOTE: use the first frame as the reference image for video/stream inference
refer_image = next(iter(dataset))[1][0]
if refer_image is not None:
vpe = self.predictor.get_vpe(refer_image)
self.model.set_classes(self.model.names, vpe)
self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
self.predictor = None # reset predictor
elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
self.predictor = None # reset predictor if no visual prompts
return super().predict(source, stream, **kwargs)
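# Minimal usage sketch of text-prompted vs. visual-prompted YOLOE inference, assuming the
# "yoloe-11s-seg.pt" weights and a local "bus.jpg" image are available.
if __name__ == "__main__":
    from ultralytics import YOLOE

    model = YOLOE("yoloe-11s-seg.pt")

    # Text prompts: set the open-vocabulary classes, then predict as usual.
    model.set_classes(["person", "bus"])  # embeddings are generated internally via get_text_pe()
    results = model.predict("bus.jpg")

    # Visual prompts: bounding boxes plus class indices steer the prediction.
    prompts = {"bboxes": [[50, 100, 300, 400]], "cls": [0]}
    results = model.predict("bus.jpg", visual_prompts=prompts)
    print(len(results))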

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import OBBPredictor
from .train import OBBTrainer
from .val import OBBValidator
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"

View File

@@ -0,0 +1,65 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class OBBPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
bounding boxes.
Attributes:
args (namespace): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded YOLO OBB model.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.obb import OBBPredictor
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
>>> predictor = OBBPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize OBBPredictor with optional model and data configuration overrides.
Args:
cfg (dict, optional): Default configuration for the predictor.
overrides (dict, optional): Configuration overrides that take precedence over the default config.
_callbacks (list, optional): List of callback functions to be invoked during prediction.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.obb import OBBPredictor
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
>>> predictor = OBBPredictor(overrides=args)
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "obb"
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct the result object from the prediction.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
the last dimension contains [x, y, w, h, confidence, class_id, angle].
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
orig_img (np.ndarray): The original image before preprocessing.
img_path (str): The path to the original image.
Returns:
(Results): The result object containing the original image, image path, class names, and oriented bounding
boxes.
"""
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
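# Minimal usage sketch for reading oriented boxes from OBBPredictor results, assuming the
# "yolo11n-obb.pt" weights and a local aerial image "boats.jpg" are available.
if __name__ == "__main__":
    from ultralytics import YOLO

    model = YOLO("yolo11n-obb.pt")  # routes to OBBPredictor through the OBB task map
    for r in model("boats.jpg"):
        # r.obb.xywhr holds (x_center, y_center, width, height, angle) per detection;
        # r.obb.conf and r.obb.cls hold confidence scores and class indices.
        print(r.obb.xywhr.shape, r.obb.cls.tolist())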

View File

@@ -0,0 +1,82 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from typing import Any
from ultralytics.models import yolo
from ultralytics.nn.tasks import OBBModel
from ultralytics.utils import DEFAULT_CFG, RANK
class OBBTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
detecting objects at arbitrary angles rather than just axis-aligned rectangles.
Attributes:
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
and dfl_loss.
Methods:
get_model: Return OBBModel initialized with specified config and weights.
get_validator: Return an instance of OBBValidator for validation of YOLO model.
Examples:
>>> from ultralytics.models.yolo.obb import OBBTrainer
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
>>> trainer = OBBTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
"""
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
Args:
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
model configuration.
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
will take precedence over those in cfg.
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "obb"
super().__init__(cfg, overrides, _callbacks)
def get_model(
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
) -> OBBModel:
"""
Return OBBModel initialized with specified config and weights.
Args:
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
containing configuration parameters, or None to use default configuration.
weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
verbose (bool): Whether to display model information during initialization.
Returns:
(OBBModel): Initialized OBBModel with the specified configuration and weights.
Examples:
>>> trainer = OBBTrainer()
>>> model = trainer.get_model(cfg="yolo11n-obb.yaml", weights="yolo11n-obb.pt")
"""
model = OBBModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of OBBValidator for validation of YOLO model."""
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
return yolo.obb.OBBValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)

View File

@@ -0,0 +1,299 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
from ultralytics.utils.nms import TorchNMS
class OBBValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
satellite imagery where objects can appear at various orientations.
Attributes:
args (dict): Configuration arguments for the validator.
metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
Methods:
init_metrics: Initialize evaluation metrics for YOLO.
_process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
_prepare_batch: Prepare batch data for OBB validation.
_prepare_pred: Prepare predictions with scaled and padded bounding boxes.
plot_predictions: Plot predicted bounding boxes on input images.
pred_to_json: Serialize YOLO predictions to COCO json format.
save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
eval_json: Evaluate YOLO output in JSON format and return performance statistics.
Examples:
>>> from ultralytics.models.yolo.obb import OBBValidator
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
>>> validator = OBBValidator(args=args)
>>> validator(model=args["model"])
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
It extends the DetectionValidator class and configures it specifically for the OBB task.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
save_dir (str | Path, optional): Directory to save results.
args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
_callbacks (list, optional): List of callback functions to be called during validation.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.args.task = "obb"
self.metrics = OBBMetrics()
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO obb validation.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
val = self.data.get(self.args.split, "") # validation path
self.is_dota = isinstance(val, str) and "DOTA" in val # check if dataset is DOTA format
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
"""
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
Args:
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
class labels and bounding boxes.
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
class labels and bounding boxes.
Returns:
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
of predictions compared to the ground truth.
Examples:
>>> preds = {"cls": torch.randint(0, 5, (100,)), "bboxes": torch.rand(100, 5) * 100}  # 100 detections
>>> batch = {"cls": torch.randint(0, 5, (50,)), "bboxes": torch.rand(50, 5) * 100}  # 50 ground truths
>>> correct_matrix = validator._process_batch(preds, batch)
"""
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
iou = batch_probiou(batch["bboxes"], preds["bboxes"])
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Args:
preds (torch.Tensor): Raw predictions from the model.
Returns:
(list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
"""
preds = super().postprocess(preds)
for pred in preds:
pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1) # concatenate angle
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare batch data for OBB validation with proper scaling and formatting.
Args:
si (int): Batch index to process.
batch (dict[str, Any]): Dictionary containing batch data with keys:
- batch_idx: Tensor of batch indices
- cls: Tensor of class labels
- bboxes: Tensor of bounding boxes
- ori_shape: Original image shapes
- img: Batch of images
- ratio_pad: Ratio and padding information
Returns:
(dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
"""
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if cls.shape[0]:
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
return {
"cls": cls,
"bboxes": bbox,
"ori_shape": ori_shape,
"imgsz": imgsz,
"ratio_pad": ratio_pad,
"im_file": batch["im_file"][si],
}
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
"""
Plot predicted bounding boxes on input images and save the result.
Args:
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
preds (list[dict[str, torch.Tensor]]): List of prediction dictionaries for each image in the batch.
ni (int): Batch index used for naming the output file.
Examples:
>>> validator = OBBValidator()
>>> batch = {"img": images, "im_file": paths}
>>> preds = [{"bboxes": torch.rand(10, 5), "conf": torch.rand(10), "cls": torch.zeros(10)}]  # one image
>>> validator.plot_predictions(batch, preds, 0)
"""
for p in preds:
# TODO: fix this duplicated `xywh2xyxy`
p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4]) # convert to xyxy format for plotting
super().plot_predictions(batch, preds, ni) # plot bboxes
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
with bounding box coordinates, confidence scores, and class predictions.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:
This method processes rotated bounding box predictions and converts them to both rbox format
(x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
to the JSON dictionary.
"""
path = Path(pbatch["im_file"])
stem = path.stem
image_id = int(stem) if stem.isnumeric() else stem
rbox = predn["bboxes"]
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
self.jdict.append(
{
"image_id": image_id,
"file_name": path.name,
"category_id": self.class_map[int(c)],
"score": round(s, 5),
"rbox": [round(x, 3) for x in r],
"poly": [round(x, 3) for x in b],
}
)
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO OBB detections to a text file in normalized coordinates.
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes' with rotated boxes in
    (x, y, w, h, angle) format, plus 'conf' confidence scores and 'cls' class predictions.
save_conf (bool): Whether to save confidence scores in the text file.
shape (tuple[int, int]): Original image shape in format (height, width).
file (Path): Output file path to save detections.
Examples:
>>> validator = OBBValidator()
>>> predn = {"bboxes": torch.tensor([[100.0, 100.0, 50.0, 30.0, 0.3]]), "conf": torch.tensor([0.9]), "cls": torch.tensor([0])}
>>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
"""
import numpy as np
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
).save_txt(file, save_conf=save_conf)
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**predn,
"bboxes": ops.scale_boxes(
pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""
Evaluate YOLO output in JSON format and save predictions in DOTA format.
Args:
stats (dict[str, Any]): Performance statistics dictionary.
Returns:
(dict[str, Any]): Updated performance statistics.
"""
if self.args.save_json and self.is_dota and len(self.jdict):
import json
import re
from collections import defaultdict
pred_json = self.save_dir / "predictions.json" # predictions
pred_txt = self.save_dir / "predictions_txt" # predictions
pred_txt.mkdir(parents=True, exist_ok=True)
data = json.load(open(pred_json))
# Save split results
LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
for d in data:
image_id = d["image_id"]
score = d["score"]
classname = self.names[d["category_id"] - 1].replace(" ", "-")
p = d["poly"]
with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
# Save merged results; this can give a slightly lower mAP than the official merging script
# because of the probiou calculation.
pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
pred_merged_txt.mkdir(parents=True, exist_ok=True)
merged_results = defaultdict(list)
LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
for d in data:
image_id = d["image_id"].split("__", 1)[0]
pattern = re.compile(r"\d+___\d+")
x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
bbox[0] += x
bbox[1] += y
bbox.extend([score, cls])
merged_results[image_id].append(bbox)
for image_id, bbox in merged_results.items():
bbox = torch.tensor(bbox)
max_wh = torch.max(bbox[:, :2]).item() * 2
c = bbox[:, 6:7] * max_wh # classes
scores = bbox[:, 5] # scores
b = bbox[:, :5].clone()
b[:, :2] += c
# An IoU threshold of 0.3 gives results close to, and sometimes slightly better than, the official merging script.
i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
bbox = bbox[i]
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
classname = self.names[int(x[-1])].replace(" ", "-")
p = [round(i, 3) for i in x[:-2]] # poly
score = round(x[-2], 3)
with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
return stats
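# Minimal sketch of the class-offset trick used when merging DOTA predictions above: shifting each
# box by (class_id * max_wh) keeps boxes of different classes far apart, so one NMS call behaves
# like per-class NMS. This demo uses axis-aligned boxes and torchvision's NMS (assumed installed),
# not the probiou-based routine used in eval_json.
if __name__ == "__main__":
    import torch
    from torchvision.ops import nms

    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0]])  # heavily overlapping
    scores = torch.tensor([0.9, 0.8])
    classes = torch.tensor([0.0, 1.0])  # different classes, so both should survive
    max_wh = boxes[:, 2:].max() * 2
    keep = nms(boxes + (classes * max_wh).unsqueeze(-1), scores, iou_threshold=0.5)
    print(keep.tolist())  # [0, 1]: no cross-class suppression despite the overlap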

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"

View File

@@ -0,0 +1,80 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a pose model.
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
capabilities inherited from DetectionPredictor.
Attributes:
args (namespace): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
Methods:
construct_result: Construct the result object from the prediction, including keypoints.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.pose import PosePredictor
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
>>> predictor = PosePredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize PosePredictor for pose estimation tasks.
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
warnings for Apple MPS.
Args:
cfg (Any): Configuration for the predictor.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be invoked during prediction.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.pose import PosePredictor
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
>>> predictor = PosePredictor(overrides=args)
>>> predictor.predict_cli()
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "pose"
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def construct_result(self, pred, img, orig_img, img_path):
"""
Construct the result object from the prediction, including keypoints.
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
result object.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
the number of detections, K is the number of keypoints, and D is the keypoint dimension.
img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
orig_img (np.ndarray): The original unprocessed image as a numpy array.
img_path (str): The path to the original image file.
Returns:
(Results): The result object containing the original image, image path, class names, bounding boxes, and
keypoints.
"""
result = super().construct_result(pred, img, orig_img, img_path)
# Extract keypoints from prediction and reshape according to model's keypoint shape
pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
# Scale keypoints coordinates to match the original image dimensions
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
result.update(keypoints=pred_kpts)
return result
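# Minimal sketch of the keypoint reshape performed in construct_result above, assuming a COCO-style
# kpt_shape of (17, 3): each detection row carries 6 box/score/class values followed by 17 * 3
# flattened keypoint values.
if __name__ == "__main__":
    import torch

    kpt_shape = (17, 3)  # (num_keypoints, dims), dims = (x, y, visibility)
    pred = torch.rand(5, 6 + kpt_shape[0] * kpt_shape[1])  # 5 synthetic detections
    pred_kpts = pred[:, 6:].view(pred.shape[0], *kpt_shape)
    print(pred_kpts.shape)  # torch.Size([5, 17, 3])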

View File

@@ -0,0 +1,115 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from typing import Any
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
class PoseTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training YOLO pose estimation models.
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
of pose keypoints alongside bounding boxes.
Attributes:
args (dict): Configuration arguments for training.
model (PoseModel): The pose estimation model being trained.
data (dict): Dataset configuration including keypoint shape information.
loss_names (tuple): Names of the loss components used in training.
Methods:
get_model: Retrieve a pose estimation model with specified configuration.
set_model_attributes: Set keypoints shape attribute on the model.
get_validator: Create a validator instance for model evaluation.
plot_training_samples: Visualize training samples with keypoints.
get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
Examples:
>>> from ultralytics.models.yolo.pose import PoseTrainer
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
>>> trainer = PoseTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
"""
Initialize a PoseTrainer object for training YOLO pose estimation models.
Args:
cfg (dict, optional): Default configuration dictionary containing training parameters.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be executed during training.
Notes:
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
A warning is issued when using Apple MPS device due to known bugs with pose models.
"""
if overrides is None:
overrides = {}
overrides["task"] = "pose"
super().__init__(cfg, overrides, _callbacks)
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def get_model(
self,
cfg: str | Path | dict[str, Any] | None = None,
weights: str | Path | None = None,
verbose: bool = True,
) -> PoseModel:
"""
Get pose estimation model with specified configuration and weights.
Args:
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
weights (str | Path, optional): Path to the model weights file.
verbose (bool): Whether to display model information.
Returns:
(PoseModel): Initialized pose estimation model.
"""
model = PoseModel(
cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
)
if weights:
model.load(weights)
return model
def set_model_attributes(self):
"""Set keypoints shape attribute of PoseModel."""
super().set_model_attributes()
self.model.kpt_shape = self.data["kpt_shape"]
def get_validator(self):
"""Return an instance of the PoseValidator class for validation."""
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
return yolo.pose.PoseValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)
def get_dataset(self) -> dict[str, Any]:
"""
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
Returns:
(dict): A dictionary containing the training/validation/test dataset and category names.
Raises:
KeyError: If the `kpt_shape` key is not present in the dataset.
"""
data = super().get_dataset()
if "kpt_shape" not in data:
raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
return data
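# Minimal, hypothetical pose dataset YAML illustrating the `kpt_shape` key that get_dataset above
# requires (values mirror the COCO keypoint layout; paths and names are placeholders).
_EXAMPLE_POSE_DATA_YAML = """
path: ../datasets/coco8-pose
train: images/train
val: images/val
kpt_shape: [17, 3]  # 17 keypoints, each stored as (x, y, visibility)
flip_idx: [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
names:
  0: person
"""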

View File

@@ -0,0 +1,267 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
class PoseValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a pose model.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
Attributes:
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
args (dict): Arguments for the validator including task set to "pose".
metrics (PoseMetrics): Metrics object for pose evaluation.
Methods:
preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
get_desc: Return description of evaluation metrics in string format.
init_metrics: Initialize pose estimation metrics for YOLO model.
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
dimensions.
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
detections and ground truth.
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
pred_to_json: Convert YOLO predictions to COCO JSON format.
eval_json: Evaluate object detection model using COCO JSON format.
Examples:
>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize a PoseValidator object for pose estimation validation.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
save_dir (Path | str, optional): Directory to save results.
args (dict, optional): Arguments for the validator including task set to "pose".
_callbacks (list, optional): List of callback functions to be executed during validation.
Examples:
>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
Notes:
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
due to a known bug with pose models.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = "pose"
self.metrics = PoseMetrics()
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
batch = super().preprocess(batch)
batch["keypoints"] = batch["keypoints"].float()
return batch
def get_desc(self) -> str:
"""Return description of evaluation metrics in string format."""
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Pose(P",
"R",
"mAP50",
"mAP50-95)",
)
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO pose validation.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
self.kpt_shape = self.data["kpt_shape"]
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
"""
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
field of predictions and reshaping them according to the keypoint shape configuration.
The keypoints are reshaped from a flattened format to the proper dimensional structure
(typically [N, 17, 3] for COCO pose format).
Args:
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
bounding boxes, confidence scores, class predictions, and keypoint data.
Returns:
(list[dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
- 'bboxes': Bounding box coordinates
- 'conf': Confidence scores
- 'cls': Class predictions
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
Note:
Keypoints are recovered from the 'extra' field, which carries additional task-specific data
beyond the basic detection outputs, and reshaped to (num_detections, *kpt_shape).
"""
preds = super().postprocess(preds)
for pred in preds:
pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape) # remove extra if exists
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
Args:
si (int): Batch index.
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
Returns:
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
Notes:
This method extends the parent class's _prepare_batch method by adding keypoint processing.
Keypoints are scaled from normalized coordinates to original image dimensions.
"""
pbatch = super()._prepare_batch(si, batch)
kpts = batch["keypoints"][batch["batch_idx"] == si]
h, w = pbatch["imgsz"]
kpts = kpts.clone()
kpts[..., 0] *= w
kpts[..., 1] *= h
pbatch["keypoints"] = kpts
return pbatch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
and 'keypoints' for keypoint predictions.
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
Returns:
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
true positives across 10 IoU levels.
Notes:
`0.53` scale factor used in area computation is referenced from
https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
"""
tp = super()._process_batch(preds, batch)
gt_cls = batch["cls"]
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
else:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
return tp
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO pose detections to a text file in normalized coordinates.
Args:
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls', and 'keypoints'.
save_conf (bool): Whether to save confidence scores.
shape (tuple[int, int]): Shape of the original image (height, width).
file (Path): Output file path to save detections.
Notes:
The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
normalized (x, y, visibility) values for each point.
"""
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
keypoints=predn["keypoints"],
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format.
This method takes a prediction dictionary and batch metadata, converts the bounding boxes from YOLO format
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
and 'keypoints' tensors.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
before saving to the JSON dictionary.
"""
super().pred_to_json(predn, pbatch)
kpts = predn["kpts"]
for i, k in enumerate(kpts.flatten(1, 2).tolist()):
self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**super().scale_preds(predn, pbatch),
"kpts": ops.scale_coords(
pbatch["imgsz"],
predn["keypoints"].clone(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""Evaluate object detection model using COCO JSON format."""
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
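# Minimal sketch of the OKS-style matching used in _process_batch above: kpt_iou compares ground-truth
# and predicted keypoints with per-keypoint sigmas, scaled by the 0.53 * box-area term. Synthetic data
# only; identical keypoints give an IoU of ~1.0 on the diagonal.
if __name__ == "__main__":
    import torch
    from ultralytics.utils import ops
    from ultralytics.utils.metrics import OKS_SIGMA, kpt_iou

    gt_kpts = torch.rand(3, 17, 3)  # 3 instances, COCO layout (x, y, visibility)
    gt_kpts[..., 2] = 1.0  # mark every keypoint as labeled
    pred_kpts = gt_kpts.clone()
    gt_boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0]]).repeat(3, 1)  # xyxy boxes matching the 0-1 coords
    area = ops.xyxy2xywh(gt_boxes)[:, 2:].prod(1) * 0.53
    print(kpt_iou(gt_kpts, pred_kpts, sigma=OKS_SIGMA, area=area))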

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import SegmentationPredictor
from .train import SegmentationTrainer
from .val import SegmentationValidator
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"

View File

@@ -0,0 +1,113 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
class SegmentationPredictor(DetectionPredictor):
"""
A class extending the DetectionPredictor class for prediction based on a segmentation model.
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
prediction results.
Attributes:
args (dict): Configuration arguments for the predictor.
model (torch.nn.Module): The loaded YOLO segmentation model.
batch (list): Current batch of images being processed.
Methods:
postprocess: Apply non-max suppression and process segmentation detections.
construct_results: Construct a list of result objects from predictions.
construct_result: Construct a single result object from a prediction.
Examples:
>>> from ultralytics.utils import ASSETS
>>> from ultralytics.models.yolo.segment import SegmentationPredictor
>>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
>>> predictor = SegmentationPredictor(overrides=args)
>>> predictor.predict_cli()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
"""
Initialize the SegmentationPredictor with configuration, overrides, and callbacks.
This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
prediction results.
Args:
cfg (dict): Configuration for the predictor.
overrides (dict, optional): Configuration overrides that take precedence over cfg.
_callbacks (list, optional): List of callback functions to be invoked during prediction.
"""
super().__init__(cfg, overrides, _callbacks)
self.args.task = "segment"
def postprocess(self, preds, img, orig_imgs):
"""
Apply non-max suppression and process segmentation detections for each image in the input batch.
Args:
preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.
Returns:
(list): List of Results objects containing the segmentation predictions for each image in the batch.
Each Results object includes both bounding boxes and segmentation masks.
Examples:
>>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
>>> results = predictor.postprocess(preds, img, orig_img)
"""
# Extract protos - tuple if PyTorch model or array if exported
protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
return super().postprocess(preds[0], img, orig_imgs, protos=protos)
def construct_results(self, preds, img, orig_imgs, protos):
"""
Construct a list of result objects from the predictions.
Args:
preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
img (torch.Tensor): The image after preprocessing.
orig_imgs (list[np.ndarray]): List of original images before preprocessing.
protos (list[torch.Tensor]): List of prototype masks.
Returns:
(list[Results]): List of result objects containing the original images, image paths, class names,
bounding boxes, and masks.
"""
return [
self.construct_result(pred, img, orig_img, img_path, proto)
for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
]
def construct_result(self, pred, img, orig_img, img_path, proto):
"""
Construct a single result object from the prediction.
Args:
pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
img (torch.Tensor): The image after preprocessing.
orig_img (np.ndarray): The original image before preprocessing.
img_path (str): The path to the original image.
proto (torch.Tensor): The prototype masks.
Returns:
(Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
"""
if pred.shape[0] == 0: # save empty boxes
masks = None
elif self.args.retina_masks:
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC
else:
masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
if masks is not None:
keep = masks.sum((-2, -1)) > 0 # only keep predictions with masks
pred, masks = pred[keep], masks[keep]
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
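# Minimal usage sketch for reading masks from SegmentationPredictor results, assuming the
# "yolo11n-seg.pt" weights and a local "bus.jpg" image are available.
if __name__ == "__main__":
    from ultralytics import YOLO

    model = YOLO("yolo11n-seg.pt")  # routes to SegmentationPredictor through the segment task map
    r = model("bus.jpg", retina_masks=True)[0]  # retina_masks=True keeps masks at native resolution
    if r.masks is not None:
        print(r.masks.data.shape)  # (num_instances, H, W) binary masks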

View File

@@ -0,0 +1,72 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK
class SegmentationTrainer(yolo.detect.DetectionTrainer):
"""
A class extending the DetectionTrainer class for training based on a segmentation model.
This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
functionality including model initialization, validation, and visualization.
Attributes:
loss_names (tuple[str]): Names of the loss components used during training.
Examples:
>>> from ultralytics.models.yolo.segment import SegmentationTrainer
>>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
>>> trainer = SegmentationTrainer(overrides=args)
>>> trainer.train()
"""
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
"""
Initialize a SegmentationTrainer object.
Args:
cfg (dict): Configuration dictionary with default training settings.
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
_callbacks (list, optional): List of callback functions to be executed during training.
"""
if overrides is None:
overrides = {}
overrides["task"] = "segment"
super().__init__(cfg, overrides, _callbacks)
def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
"""
Initialize and return a SegmentationModel with specified configuration and weights.
Args:
cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
weights (str | Path, optional): Path to pretrained weights file.
verbose (bool): Whether to display model information during initialization.
Returns:
(SegmentationModel): Initialized segmentation model with loaded weights if specified.
Examples:
>>> trainer = SegmentationTrainer()
>>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
>>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
"""
model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def get_validator(self):
"""Return an instance of SegmentationValidator for validation of YOLO model."""
self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
return yolo.segment.SegmentationValidator(
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
)

View File

@@ -0,0 +1,259 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, mask_iou
class SegmentationValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a segmentation model.
This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
to compute metrics such as mAP for both detection and segmentation tasks.
Attributes:
plot_masks (list): List to store masks for plotting.
process (callable): Function to process masks based on save_json and save_txt flags.
args (namespace): Arguments for the validator.
metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
stats (dict): Dictionary to store statistics during validation.
Examples:
>>> from ultralytics.models.yolo.segment import SegmentationValidator
>>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
>>> validator = SegmentationValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
save_dir (Path, optional): Directory to save results.
args (namespace, optional): Arguments for the validator.
_callbacks (list, optional): List of callback functions.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.process = None
self.args.task = "segment"
self.metrics = SegmentMetrics()
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""
Preprocess batch of images for YOLO segmentation validation.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
Returns:
(dict[str, Any]): Preprocessed batch.
"""
batch = super().preprocess(batch)
batch["masks"] = batch["masks"].float()
return batch
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize metrics and select mask processing function based on save_json flag.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
if self.args.save_json:
check_requirements("faster-coco-eval>=1.6.7")
# More accurate vs faster
self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask
def get_desc(self) -> str:
"""Return a formatted description of evaluation metrics."""
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Mask(P",
"R",
"mAP50",
"mAP50-95)",
)
def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
"""
Post-process YOLO predictions and return output detections with proto.
Args:
preds (list[torch.Tensor]): Raw predictions from the model.
Returns:
(list[dict[str, torch.Tensor]]): Processed detection predictions with masks.
"""
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported
preds = super().postprocess(preds[0])
imgsz = [4 * x for x in proto.shape[2:]] # get image size from proto
for i, pred in enumerate(preds):
coefficient = pred.pop("extra")
pred["masks"] = (
self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
if coefficient.shape[0]
else torch.zeros(
(0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
dtype=torch.uint8,
device=pred["bboxes"].device,
)
)
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for training or inference by processing images and targets.
Args:
si (int): Batch index.
batch (dict[str, Any]): Batch data containing images and annotations.
Returns:
(dict[str, Any]): Prepared batch with processed annotations.
"""
prepared_batch = super()._prepare_batch(si, batch)
nl = prepared_batch["cls"].shape[0]
if self.args.overlap_mask:
masks = batch["masks"][si]
index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
masks = (masks == index).float()
else:
masks = batch["masks"][batch["batch_idx"] == si]
if nl:
mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
if masks.shape[1:] != mask_size:
masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
masks = masks.gt_(0.5)
prepared_batch["masks"] = masks
return prepared_batch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Compute correct prediction matrix for a batch based on bounding boxes and optional masks.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
batch (dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.
Returns:
(dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.
Notes:
- If `masks` is True, the function computes IoU between predicted and ground truth masks.
- If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.
Examples:
>>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
>>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
>>> correct_preds = validator._process_batch(preds, batch)
"""
tp = super()._process_batch(preds, batch)
gt_cls = batch["cls"]
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
else:
iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
tp.update({"tp_m": tp_m}) # update tp with mask IoU
return tp
def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
"""
Plot batch predictions with masks and bounding boxes.
Args:
batch (dict[str, Any]): Batch containing images and annotations.
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
ni (int): Batch index.
"""
for p in preds:
masks = p["masks"]
if masks.shape[0] > self.args.max_det:
LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
super().plot_predictions(batch, preds, ni, max_det=self.args.max_det) # plot bboxes
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
"""
Save YOLO detections to a txt file in normalized coordinates in a specific format.
Args:
predn (dict[str, torch.Tensor]): Predictions containing 'bboxes', 'conf', 'cls', and 'masks' tensors.
save_conf (bool): Whether to save confidence scores.
shape (tuple[int, int]): Shape of the original image.
file (Path): File path to save the detections.
"""
from ultralytics.engine.results import Results
Results(
np.zeros((shape[0], shape[1]), dtype=np.uint8),
path=None,
names=self.names,
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Save one JSON result for COCO evaluation.
Args:
predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
"""
from faster_coco_eval.core.mask import encode # noqa
def single_encode(x):
"""Encode predicted masks as RLE and append results to jdict."""
rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
rle["counts"] = rle["counts"].decode("utf-8")
return rle
pred_masks = np.transpose(predn["masks"], (2, 0, 1))
with ThreadPool(NUM_THREADS) as pool:
rles = pool.map(single_encode, pred_masks)
super().pred_to_json(predn, pbatch)
for i, r in enumerate(rles):
self.jdict[-len(rles) + i]["segmentation"] = r # segmentation
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**super().scale_preds(predn, pbatch),
"masks": ops.scale_image(
torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""Return COCO-style instance segmentation evaluation metrics."""
pred_json = self.save_dir / "predictions.json" # predictions
anno_json = (
self.data["path"]
/ "annotations"
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
) # annotations
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "segm"], suffix=["Box", "Mask"])
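# Minimal sketch of the RLE encoding performed in pred_to_json above: each binary mask is converted
# to COCO run-length encoding for predictions.json, with the counts bytes decoded for JSON safety.
# Assumes the faster-coco-eval package is installed (it is only required when save_json is enabled).
if __name__ == "__main__":
    import numpy as np
    from faster_coco_eval.core.mask import encode

    mask = np.zeros((160, 160), dtype=np.uint8)
    mask[40:120, 40:120] = 1  # a synthetic square instance
    rle = encode(np.asarray(mask[:, :, None], order="F", dtype="uint8"))[0]
    rle["counts"] = rle["counts"].decode("utf-8")
    print(rle["size"], len(rle["counts"]))  # [160, 160] and the length of the RLE string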

View File

@@ -0,0 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .train import WorldTrainer
__all__ = ["WorldTrainer"]

View File

@@ -0,0 +1,179 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
import itertools
from pathlib import Path
from typing import Any
import torch
from ultralytics.data import build_yolo_dataset
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model
def on_pretrain_routine_end(trainer) -> None:
"""Set up model classes and text encoder at the end of the pretrain routine."""
if RANK in {-1, 0}:
# Set class names for evaluation
names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
unwrap_model(trainer.ema.ema).set_classes(names, cache_clip_model=False)


class WorldTrainer(DetectionTrainer):
    """
    A trainer class for fine-tuning YOLO World models on close-set datasets.

    This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
    features for improved object detection and understanding. It handles text embedding generation and caching to
    accelerate training with multi-modal data.

    Attributes:
        text_embeddings (dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate
            training.
        model (WorldModel): The YOLO World model being trained.
        data (dict[str, Any]): Dataset configuration containing class information.
        args (Any): Training arguments and configuration.

    Methods:
        get_model: Return WorldModel initialized with specified config and weights.
        build_dataset: Build YOLO Dataset for training or validation.
        set_text_embeddings: Set text embeddings for datasets to accelerate training.
        generate_text_embeddings: Generate text embeddings for a list of text samples.
        preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.

    Examples:
        Initialize and train a YOLO World model
        >>> from ultralytics.models.yolo.world import WorldTrainer
        >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
        >>> trainer = WorldTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a WorldTrainer object with given arguments.

        Args:
            cfg (dict[str, Any]): Configuration for the trainer.
            overrides (dict[str, Any], optional): Configuration overrides.
            _callbacks (list[Any], optional): List of callback functions.
        """
        if overrides is None:
            overrides = {}
        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
        super().__init__(cfg, overrides, _callbacks)
        self.text_embeddings = None

    def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
        """
        Return WorldModel initialized with specified config and weights.

        Args:
            cfg (dict[str, Any] | str, optional): Model configuration.
            weights (str, optional): Path to pretrained weights.
            verbose (bool): Whether to display model info.

        Returns:
            (WorldModel): Initialized WorldModel.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)
        return model

    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`.

        Returns:
            (Any): YOLO dataset configured for training or validation.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        dataset = build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )
        if mode == "train":
            self.set_text_embeddings([dataset], batch)  # cache text embeddings to accelerate training
        return dataset

    def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
        """
        Set text embeddings for datasets to accelerate training by caching category names.

        This method collects unique category names from all datasets, then generates and caches text embeddings
        for these categories to improve training efficiency.

        Args:
            datasets (list[Any]): List of datasets from which to extract category names.
            batch (int | None): Batch size used for processing.

        Notes:
            This method collects category names from datasets that have the 'category_names' attribute,
            then uses the first dataset's image path to determine where to cache the generated text embeddings.
        """
        text_embeddings = {}
        for dataset in datasets:
            if not hasattr(dataset, "category_names"):
                continue
            text_embeddings.update(
                self.generate_text_embeddings(
                    list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent
                )
            )
        self.text_embeddings = text_embeddings

    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
        """
        Generate text embeddings for a list of text samples.

        Args:
            texts (list[str]): List of text samples to encode.
            batch (int): Batch size for processing.
            cache_dir (Path): Directory to save/load cached embeddings.

        Returns:
            (dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
        """
        model = "clip:ViT-B/32"
        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existing cache from '{cache_path}'")
            txt_map = torch.load(cache_path, map_location=self.device)
            if sorted(txt_map.keys()) == sorted(texts):
                return txt_map
        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map
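# Editor's note (illustrative, not part of the original file): with model = "clip:ViT-B/32" the
# cache file written above is '<cache_dir>/text_embeddings_clip_ViT-B_32.pt', and the cache is
# only reused when the sorted cached class names match the requested `texts` exactly:
#   >>> "clip:ViT-B/32".replace(":", "_").replace("/", "_")
#   'clip_ViT-B_32'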

    def preprocess_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Preprocess a batch of images and text for YOLOWorld training."""
        batch = DetectionTrainer.preprocess_batch(self, batch)

        # Add text features
        texts = list(itertools.chain(*batch["texts"]))
        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
            self.device, non_blocking=self.device.type == "cuda"
        )
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch
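As a rough illustration of the `txt_feats` reshape in `preprocess_batch`, the standalone sketch below substitutes random embeddings for the cached CLIP features; the 512-dim embedding width and the equal-length per-image text lists are assumptions made only for this example (the real pipeline pads text lists upstream).

import itertools
import torch

# Hypothetical batch: 2 images, each with the same number of text prompts.
batch_texts = [["person", "dog"], ["car", "truck"]]
embed_dim = 512  # assumed CLIP ViT-B/32 text embedding width
text_embeddings = {t: torch.randn(embed_dim) for t in set(itertools.chain(*batch_texts))}

texts = list(itertools.chain(*batch_texts))                   # flatten to 4 prompts
txt_feats = torch.stack([text_embeddings[t] for t in texts])  # shape (4, 512)
txt_feats = txt_feats.reshape(len(batch_texts), -1, txt_feats.shape[-1])
print(txt_feats.shape)  # torch.Size([2, 2, 512]) -> (images, texts per image, embed dim)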

Some files were not shown because too many files have changed in this diff.