init commit
9 ultralytics/models/__init__.py Normal file
@@ -0,0 +1,9 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .fastsam import FastSAM
from .nas import NAS
from .rtdetr import RTDETR
from .sam import SAM
from .yolo import YOLO, YOLOE, YOLOWorld

__all__ = "YOLO", "RTDETR", "SAM", "FastSAM", "NAS", "YOLOWorld", "YOLOE"  # allow simpler import
BIN ultralytics/models/__pycache__/__init__.cpython-310.pyc Normal file
Binary file not shown.
7 ultralytics/models/fastsam/__init__.py Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import FastSAM
from .predict import FastSAMPredictor
from .val import FastSAMValidator

__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
BIN ultralytics/models/fastsam/__pycache__/__init__.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/fastsam/__pycache__/model.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/fastsam/__pycache__/predict.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/fastsam/__pycache__/utils.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/fastsam/__pycache__/val.cpython-310.pyc Normal file
Binary file not shown.
81 ultralytics/models/fastsam/model.py Normal file
@@ -0,0 +1,81 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

from ultralytics.engine.model import Model

from .predict import FastSAMPredictor
from .val import FastSAMValidator


class FastSAM(Model):
    """
    FastSAM model interface for segment anything tasks.

    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
    Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.

    Attributes:
        model (str): Path to the pre-trained FastSAM model file.
        task (str): The task type, set to "segment" for FastSAM models.

    Methods:
        predict: Perform segmentation prediction on image or video source with optional prompts.
        task_map: Returns mapping of segment task to predictor and validator classes.

    Examples:
        Initialize FastSAM model and run prediction
        >>> from ultralytics import FastSAM
        >>> model = FastSAM("FastSAM-x.pt")
        >>> results = model.predict("ultralytics/assets/bus.jpg")

        Run prediction with bounding box prompts
        >>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
    """

    def __init__(self, model: str = "FastSAM-x.pt"):
        """Initialize the FastSAM model with the specified pre-trained weights."""
        if str(model) == "FastSAM.pt":
            model = "FastSAM-x.pt"
        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
        super().__init__(model=model, task="segment")

    def predict(
        self,
        source,
        stream: bool = False,
        bboxes: list | None = None,
        points: list | None = None,
        labels: list | None = None,
        texts: list | None = None,
        **kwargs: Any,
    ):
        """
        Perform segmentation prediction on image or video source.

        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
        prompts and passes them to the parent class predict method for processing.

        Args:
            source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
                or numpy array.
            stream (bool): Whether to enable real-time streaming mode for video inputs.
            bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
            points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
            labels (list, optional): Class labels for prompted segmentation.
            texts (list, optional): Text prompts for segmentation guidance.
            **kwargs (Any): Additional keyword arguments passed to the predictor.

        Returns:
            (list): List of Results objects containing the prediction results.
        """
        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
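A minimal usage sketch of the prompt interface defined above (illustrative only, not part of the commit; the FastSAM-s.pt weights name and the bus.jpg path are assumptions):

# Illustrative sketch of prompted FastSAM inference.
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")
results = model.predict("bus.jpg")                                  # everything-segmentation, no prompts
results = model.predict("bus.jpg", points=[[200, 300]], labels=[1])  # one foreground point prompt
results = model.predict("bus.jpg", texts="a yellow bus")             # text prompt resolved via CLIP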
181 ultralytics/models/fastsam/predict.py Normal file
@@ -0,0 +1,181 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
from PIL import Image

from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils import DEFAULT_CFG, checks
from ultralytics.utils.metrics import box_iou
from ultralytics.utils.ops import scale_masks
from ultralytics.utils.torch_utils import TORCH_1_10

from .utils import adjust_bboxes_to_image_border


class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
    single-class segmentation.

    Attributes:
        prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
        device (torch.device): Device on which model and tensors are processed.
        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.

    Methods:
        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
        prompt: Perform image segmentation inference based on various prompt types.
        set_prompts: Set prompts to be used during inference.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the FastSAMPredictor with configuration and callbacks.

        This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The
        predictor extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum
        suppression optimized for single-class segmentation.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.prompts = {}

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply postprocessing to FastSAM predictions and handle prompts.

        Args:
            preds (list[torch.Tensor]): Raw predictions from the model.
            img (torch.Tensor): Input image tensor that was fed to the model.
            orig_imgs (list[np.ndarray]): Original images before preprocessing.

        Returns:
            (list[Results]): Processed results with prompts applied.
        """
        bboxes = self.prompts.pop("bboxes", None)
        points = self.prompts.pop("points", None)
        labels = self.prompts.pop("labels", None)
        texts = self.prompts.pop("texts", None)
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box

        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)

    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
        """
        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

        Args:
            results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
            bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            texts (str | list[str], optional): Textual prompts, a list containing string objects.

        Returns:
            (list[Results]): Output results filtered and determined by the provided prompts.
        """
        if bboxes is None and points is None and texts is None:
            return results
        prompt_results = []
        if not isinstance(results, list):
            results = [results]
        for result in results:
            if len(result) == 0:
                prompt_results.append(result)
                continue
            masks = result.masks.data
            if masks.shape[1:] != result.orig_shape:
                masks = scale_masks(masks[None], result.orig_shape)[0]
            # bboxes prompt
            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
            if bboxes is not None:
                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
                full_mask_areas = torch.sum(masks, dim=(1, 2))

                union = bbox_areas[:, None] + full_mask_areas - mask_areas
                idx[torch.argmax(mask_areas / union, dim=1)] = True
            if points is not None:
                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
                points = points[None] if points.ndim == 1 else points
                if labels is None:
                    labels = torch.ones(points.shape[0])
                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                assert len(labels) == len(points), (
                    f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
                )
                point_idx = (
                    torch.ones(len(result), dtype=torch.bool, device=self.device)
                    if labels.sum() == 0  # all negative points
                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                )
                for point, label in zip(points, labels):
                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
                idx |= point_idx
            if texts is not None:
                if isinstance(texts, str):
                    texts = [texts]
                crop_ims, filter_idx = [], []
                for i, b in enumerate(result.boxes.xyxy.tolist()):
                    x1, y1, x2, y2 = (int(x) for x in b)
                    if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                        filter_idx.append(i)
                        continue
                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
                similarity = self._clip_inference(crop_ims, texts)
                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
                if len(filter_idx):
                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
                idx[text_idx] = True

            prompt_results.append(result[idx])

        return prompt_results

    def _clip_inference(self, images, texts):
        """
        Perform CLIP inference to calculate similarity between images and text prompts.

        Args:
            images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
            texts (list[str]): List of prompt texts, each should be a string object.

        Returns:
            (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
        """
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
        tokenized_text = clip.tokenize(texts).to(self.device)
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
        return (image_features * text_features[:, None]).sum(-1)  # (M, N)

    def set_prompts(self, prompts):
        """Set prompts to be used during inference."""
        self.prompts = prompts
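The bbox-prompt branch above scores every candidate mask against each prompt box with an IoU-style ratio, the mask area inside the box divided by (box area + full mask area - area inside the box), and keeps the best-matching mask. A toy sketch of that arithmetic, not part of the commit:

# Illustrative only: reproduce the bbox-prompt selection on toy tensors.
import torch

masks = torch.zeros(3, 8, 8, dtype=torch.bool)
masks[0, 0:4, 0:4] = True   # top-left blob
masks[1, 4:8, 4:8] = True   # bottom-right blob
masks[2, :, :] = True       # full-image mask
bboxes = torch.tensor([[0, 0, 4, 4]])  # one prompt box over the top-left region

bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
full_mask_areas = masks.sum(dim=(1, 2))
union = bbox_areas[:, None] + full_mask_areas - mask_areas
print(torch.argmax(mask_areas / union, dim=1))  # tensor([0]): mask 0 best matches the prompt box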
24 ultralytics/models/fastsam/utils.py Normal file
@@ -0,0 +1,24 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license


def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """
    Adjust bounding boxes to stick to image border if they are within a certain threshold.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
        image_shape (tuple): Image dimensions as (height, width).
        threshold (int): Pixel threshold for considering a box close to the border.

    Returns:
        (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
    """
    # Image dimensions
    h, w = image_shape

    # Adjust boxes that are close to image borders
    boxes[boxes[:, 0] < threshold, 0] = 0  # x1
    boxes[boxes[:, 1] < threshold, 1] = 0  # y1
    boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
    boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
    return boxes
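A quick sketch of the border-snapping behaviour, assuming the package layout introduced by this commit (illustrative only): coordinates within 20 px of an image edge are snapped onto that edge, boxes well inside the image are left unchanged.

# Illustrative only.
import torch
from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

boxes = torch.tensor(
    [[5.0, 15.0, 630.0, 470.0],      # near all four borders of a 480x640 image
     [100.0, 100.0, 200.0, 200.0]]   # well inside, left unchanged
)
print(adjust_bboxes_to_image_border(boxes, image_shape=(480, 640)))
# tensor([[  0.,   0., 640., 480.],
#         [100., 100., 200., 200.]])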
40 ultralytics/models/fastsam/val.py Normal file
@@ -0,0 +1,40 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.segment import SegmentationValidator


class FastSAMValidator(SegmentationValidator):
    """
    Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.

    Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
    sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
    to avoid errors during validation.

    Attributes:
        dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
        save_dir (Path): The directory where validation results will be saved.
        args (SimpleNamespace): Additional arguments for customization of the validation process.
        _callbacks (list): List of callback functions to be invoked during validation.
        metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.

    Methods:
        __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """
        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            args (SimpleNamespace, optional): Configuration for the validator.
            _callbacks (list, optional): List of callback functions to be invoked during validation.

        Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.args.task = "segment"
        self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
7 ultralytics/models/nas/__init__.py Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import NAS
from .predict import NASPredictor
from .val import NASValidator

__all__ = "NASPredictor", "NASValidator", "NAS"
BIN ultralytics/models/nas/__pycache__/__init__.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/nas/__pycache__/model.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/nas/__pycache__/predict.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/nas/__pycache__/val.cpython-310.pyc Normal file
Binary file not shown.
101 ultralytics/models/nas/model.py Normal file
@@ -0,0 +1,101 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import torch

from ultralytics.engine.model import Model
from ultralytics.utils import DEFAULT_CFG_DICT
from ultralytics.utils.downloads import attempt_download_asset
from ultralytics.utils.patches import torch_load
from ultralytics.utils.torch_utils import model_info

from .predict import NASPredictor
from .val import NASValidator


class NAS(Model):
    """
    YOLO-NAS model for object detection.

    This class provides an interface for the YOLO-NAS models and extends the `Model` class from ultralytics engine.
    It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.

    Attributes:
        model (torch.nn.Module): The loaded YOLO-NAS model.
        task (str): The task type for the model, defaults to 'detect'.
        predictor (NASPredictor): The predictor instance for making predictions.
        validator (NASValidator): The validator instance for model validation.

    Methods:
        info: Log model information and return model details.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> results = model.predict("ultralytics/assets/bus.jpg")

    Notes:
        YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
    """

    def __init__(self, model: str = "yolo_nas_s.pt") -> None:
        """Initialize the NAS model with the provided or default model."""
        assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
        super().__init__(model, task="detect")

    def _load(self, weights: str, task=None) -> None:
        """
        Load an existing NAS model weights or create a new NAS model with pretrained weights.

        Args:
            weights (str): Path to the model weights file or model name.
            task (str, optional): Task type for the model.
        """
        import super_gradients

        suffix = Path(weights).suffix
        if suffix == ".pt":
            self.model = torch_load(attempt_download_asset(weights))
        elif suffix == "":
            self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")

        # Override the forward method to ignore additional arguments
        def new_forward(x, *args, **kwargs):
            """Ignore additional __call__ arguments."""
            return self.model._original_forward(x)

        self.model._original_forward = self.model.forward
        self.model.forward = new_forward

        # Standardize model attributes for compatibility
        self.model.fuse = lambda verbose=True: self.model
        self.model.stride = torch.tensor([32])
        self.model.names = dict(enumerate(self.model._class_names))
        self.model.is_fused = lambda: False  # for info()
        self.model.yaml = {}  # for info()
        self.model.pt_path = weights  # for export()
        self.model.task = "detect"  # for export()
        self.model.args = {**DEFAULT_CFG_DICT, **self.overrides}  # for export()
        self.model.eval()

    def info(self, detailed: bool = False, verbose: bool = True) -> dict[str, Any]:
        """
        Log model information.

        Args:
            detailed (bool): Show detailed information about model.
            verbose (bool): Controls verbosity.

        Returns:
            (dict[str, Any]): Model information dictionary.
        """
        return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)

    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Return a dictionary mapping tasks to respective predictor and validator classes."""
        return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
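NAS._load wraps the super-gradients model so that extra Ultralytics-style keyword arguments never reach its original forward. A generic illustration of that wrapping pattern on a plain torch module (not part of the commit; the Linear layer and the kwargs shown are invented for the example):

# Illustrative only: stash the original forward, then replace it with a closure that drops extras.
import torch

net = torch.nn.Linear(4, 2)
net._original_forward = net.forward
net.forward = lambda x, *args, **kwargs: net._original_forward(x)

x = torch.randn(1, 4)
print(net(x, augment=False, visualize=False).shape)  # torch.Size([1, 2]), extra kwargs ignored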
58 ultralytics/models/nas/predict.py Normal file
@@ -0,0 +1,58 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import ops


class NASPredictor(DetectionPredictor):
    """
    Ultralytics YOLO NAS Predictor for object detection.

    This class extends the DetectionPredictor from ultralytics engine and is responsible for post-processing the
    raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
    scaling the bounding boxes to fit the original image dimensions.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing including confidence
            threshold, IoU threshold, agnostic NMS flag, maximum detections, and class filtering options.
        model (torch.nn.Module): The YOLO NAS model used for inference.
        batch (list): Batch of inputs for processing.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> predictor = model.predictor

        Assume that raw_preds, img, orig_imgs are available
        >>> results = predictor.postprocess(raw_preds, img, orig_imgs)

    Notes:
        Typically, this class is not instantiated directly. It is used internally within the NAS class.
    """

    def postprocess(self, preds_in, img, orig_imgs):
        """
        Postprocess NAS model predictions to generate final detection results.

        This method takes raw predictions from a YOLO NAS model, converts bounding box formats, and applies
        post-processing operations to generate the final detection results compatible with Ultralytics
        result visualization and analysis tools.

        Args:
            preds_in (list): Raw predictions from the NAS model, typically containing bounding boxes and class scores.
            img (torch.Tensor): Input image tensor that was fed to the model, with shape (B, C, H, W).
            orig_imgs (list | torch.Tensor | np.ndarray): Original images before preprocessing, used for scaling
                coordinates back to original dimensions.

        Returns:
            (list): List of Results objects containing the processed predictions for each image in the batch.

        Examples:
            >>> predictor = NAS("yolo_nas_s").predictor
            >>> results = predictor.postprocess(raw_preds, img, orig_imgs)
        """
        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding boxes from xyxy to xywh format
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with class scores
        return super().postprocess(preds, img, orig_imgs)
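NASPredictor.postprocess only reshapes the raw super-gradients output into the layout the inherited YOLO post-processing expects. A shape-only sketch of that conversion (illustrative; the 8400-anchor and 80-class sizes are assumptions, and the random values are placeholders):

# Illustrative only: per-image xyxy boxes plus class scores become one (B, 4 + nc, N) xywh tensor.
import torch
from ultralytics.utils import ops

boxes_xyxy = torch.rand(1, 8400, 4) * 640  # (B, N, 4) raw NAS boxes
cls_scores = torch.rand(1, 8400, 80)       # (B, N, num_classes)

preds = torch.cat((ops.xyxy2xywh(boxes_xyxy), cls_scores), -1).permute(0, 2, 1)
print(preds.shape)  # torch.Size([1, 84, 8400]), ready for DetectionPredictor.postprocess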
39 ultralytics/models/nas/val.py Normal file
@@ -0,0 +1,39 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops

__all__ = ["NASValidator"]


class NASValidator(DetectionValidator):
    """
    Ultralytics YOLO NAS Validator for object detection.

    Extends DetectionValidator from the Ultralytics models package and is designed to post-process the raw predictions
    generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
    ultimately producing the final detections.

    Attributes:
        args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU
            thresholds.
        lb (torch.Tensor): Optional tensor for multilabel NMS.

    Examples:
        >>> from ultralytics import NAS
        >>> model = NAS("yolo_nas_s")
        >>> validator = model.validator
        >>> # Assumes that raw_preds are available
        >>> final_preds = validator.postprocess(raw_preds)

    Notes:
        This class is generally not instantiated directly but is used internally within the NAS class.
    """

    def postprocess(self, preds_in):
        """Apply Non-maximum suppression to prediction outputs."""
        boxes = ops.xyxy2xywh(preds_in[0][0])  # Convert bounding box format from xyxy to xywh
        preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)  # Concatenate boxes with scores and permute
        return super().postprocess(preds)
7 ultralytics/models/rtdetr/__init__.py Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator

__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"
BIN ultralytics/models/rtdetr/__pycache__/__init__.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/rtdetr/__pycache__/model.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/rtdetr/__pycache__/predict.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/rtdetr/__pycache__/train.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/rtdetr/__pycache__/val.cpython-310.pyc Normal file
Binary file not shown.
66 ultralytics/models/rtdetr/model.py Normal file
@@ -0,0 +1,66 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector.

RT-DETR offers real-time performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT.
It features an efficient hybrid encoder and IoU-aware query selection for enhanced detection accuracy.

References:
    https://arxiv.org/pdf/2304.08069.pdf
"""

from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils.torch_utils import TORCH_1_11

from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator


class RTDETR(Model):
    """
    Interface for Baidu's RT-DETR model, a Vision Transformer-based real-time object detector.

    This model provides real-time performance with high accuracy. It supports efficient hybrid encoding, IoU-aware
    query selection, and adaptable inference speed.

    Attributes:
        model (str): Path to the pre-trained model.

    Methods:
        task_map: Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.

    Examples:
        Initialize RT-DETR with a pre-trained model
        >>> from ultralytics import RTDETR
        >>> model = RTDETR("rtdetr-l.pt")
        >>> results = model("image.jpg")
    """

    def __init__(self, model: str = "rtdetr-l.pt") -> None:
        """
        Initialize the RT-DETR model with the given pre-trained model file.

        Args:
            model (str): Path to the pre-trained model. Supports .pt, .yaml, and .yml formats.
        """
        assert TORCH_1_11, "RTDETR requires torch>=1.11"
        super().__init__(model=model, task="detect")

    @property
    def task_map(self) -> dict:
        """
        Return a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.

        Returns:
            (dict): A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
        """
        return {
            "detect": {
                "predictor": RTDETRPredictor,
                "validator": RTDETRValidator,
                "trainer": RTDETRTrainer,
                "model": RTDETRDetectionModel,
            }
        }
92 ultralytics/models/rtdetr/predict.py Normal file
@@ -0,0 +1,92 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch

from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops


class RTDETRPredictor(BasePredictor):
    """
    RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions.

    This class leverages Vision Transformers to provide real-time object detection while maintaining high accuracy.
    It supports key features like efficient hybrid encoding and IoU-aware query selection.

    Attributes:
        imgsz (int): Image size for inference (must be square and scale-filled).
        args (dict): Argument overrides for the predictor.
        model (torch.nn.Module): The loaded RT-DETR model.
        batch (list): Current batch of processed inputs.

    Methods:
        postprocess: Postprocess raw model predictions to generate bounding boxes and confidence scores.
        pre_transform: Pre-transform input images before feeding them into the model for inference.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.rtdetr import RTDETRPredictor
        >>> args = dict(model="rtdetr-l.pt", source=ASSETS)
        >>> predictor = RTDETRPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def postprocess(self, preds, img, orig_imgs):
        """
        Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.

        The method filters detections based on confidence and class if specified in `self.args`. It converts
        model predictions to Results objects containing properly scaled bounding boxes.

        Args:
            preds (list | tuple): List of [predictions, extra] from the model, where predictions contain
                bounding boxes and scores.
            img (torch.Tensor): Processed input images with shape (N, 3, H, W).
            orig_imgs (list | torch.Tensor): Original, unprocessed images.

        Returns:
            results (list[Results]): A list of Results objects containing the post-processed bounding boxes,
                confidence scores, and class labels.
        """
        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
            preds = [preds, None]

        nd = preds[0].shape[-1]
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)

        if not isinstance(orig_imgs, list):  # input images are a torch.Tensor, not a list
            orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)

        results = []
        for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            max_score, cls = score.max(-1, keepdim=True)  # (300, 1)
            idx = max_score.squeeze(-1) > self.args.conf  # (300, )
            if self.args.classes is not None:
                idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
            pred = torch.cat([bbox, max_score, cls], dim=-1)[idx]  # filter
            pred = pred[pred[:, 4].argsort(descending=True)][: self.args.max_det]
            oh, ow = orig_img.shape[:2]
            pred[..., [0, 2]] *= ow  # scale x coordinates to original width
            pred[..., [1, 3]] *= oh  # scale y coordinates to original height
            results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
        return results

    def pre_transform(self, im):
        """
        Pre-transform input images before feeding them into the model for inference.

        The input images are letterboxed to ensure a square aspect ratio and scale-filled. The size must be square
        (640) and scale_filled.

        Args:
            im (list[np.ndarray] | torch.Tensor): Input images of shape (N, 3, H, W) for tensor,
                [(H, W, 3) x N] for list.

        Returns:
            (list): List of pre-transformed images ready for model inference.
        """
        letterbox = LetterBox(self.imgsz, auto=False, scale_fill=True)
        return [letterbox(image=x) for x in im]
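RTDETRPredictor.postprocess treats the model output as normalized (cx, cy, w, h) boxes, converts them to xyxy, and scales by the original image width and height. A toy arithmetic check of that step (not part of the commit; the box values and 480x640 image size are invented):

# Illustrative only.
import torch
from ultralytics.utils import ops

pred_box = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # one normalized cx, cy, w, h box
xyxy = ops.xywh2xyxy(pred_box)                   # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])
oh, ow = 480, 640                                # original image height and width
xyxy[..., [0, 2]] *= ow
xyxy[..., [1, 3]] *= oh
print(xyxy)  # tensor([[256., 144., 384., 336.]])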
92 ultralytics/models/rtdetr/train.py Normal file
@@ -0,0 +1,92 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy

from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr

from .val import RTDETRDataset, RTDETRValidator


class RTDETRTrainer(DetectionTrainer):
    """
    Trainer class for the RT-DETR model developed by Baidu for real-time object detection.

    This class extends the DetectionTrainer class for YOLO to adapt to the specific features and architecture of
    RT-DETR. The model leverages Vision Transformers and has capabilities like IoU-aware query selection and adaptable
    inference speed.

    Attributes:
        loss_names (tuple): Names of the loss components used for training.
        data (dict): Dataset configuration containing class count and other parameters.
        args (dict): Training arguments and hyperparameters.
        save_dir (Path): Directory to save training results.
        test_loader (DataLoader): DataLoader for validation/testing data.

    Methods:
        get_model: Initialize and return an RT-DETR model for object detection tasks.
        build_dataset: Build and return an RT-DETR dataset for training or validation.
        get_validator: Return a DetectionValidator suitable for RT-DETR model validation.

    Notes:
        - F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
        - AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.

    Examples:
        >>> from ultralytics.models.rtdetr.train import RTDETRTrainer
        >>> args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
        >>> trainer = RTDETRTrainer(overrides=args)
        >>> trainer.train()
    """

    def get_model(self, cfg: dict | None = None, weights: str | None = None, verbose: bool = True):
        """
        Initialize and return an RT-DETR model for object detection tasks.

        Args:
            cfg (dict, optional): Model configuration.
            weights (str, optional): Path to pre-trained model weights.
            verbose (bool): Verbose logging if True.

        Returns:
            (RTDETRDetectionModel): Initialized model.
        """
        model = RTDETRDetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)
        return model

    def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None):
        """
        Build and return an RT-DETR dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): Dataset mode, either 'train' or 'val'.
            batch (int, optional): Batch size for rectangle training.

        Returns:
            (RTDETRDataset): Dataset object for the specific mode.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=mode == "train",
            hyp=self.args,
            rect=False,
            cache=self.args.cache or None,
            single_cls=self.args.single_cls or False,
            prefix=colorstr(f"{mode}: "),
            classes=self.args.classes,
            data=self.data,
            fraction=self.args.fraction if mode == "train" else 1.0,
        )

    def get_validator(self):
        """Return a DetectionValidator suitable for RT-DETR model validation."""
        self.loss_names = "giou_loss", "cls_loss", "l1_loss"
        return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
218 ultralytics/models/rtdetr/val.py Normal file
@@ -0,0 +1,218 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

import torch

from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops

__all__ = ("RTDETRValidator",)  # tuple or list


class RTDETRDataset(YOLODataset):
    """
    Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.

    This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
    real-time detection and tracking tasks.

    Attributes:
        augment (bool): Whether to apply data augmentation.
        rect (bool): Whether to use rectangular training.
        use_segments (bool): Whether to use segmentation masks.
        use_keypoints (bool): Whether to use keypoint annotations.
        imgsz (int): Target image size for training.

    Methods:
        load_image: Load one image from dataset index.
        build_transforms: Build transformation pipeline for the dataset.

    Examples:
        Initialize an RT-DETR dataset
        >>> dataset = RTDETRDataset(img_path="path/to/images", imgsz=640)
        >>> image, hw = dataset.load_image(0)
    """

    def __init__(self, *args, data=None, **kwargs):
        """
        Initialize the RTDETRDataset class by inheriting from the YOLODataset class.

        This constructor sets up a dataset specifically optimized for the RT-DETR (Real-Time DEtection and TRacking)
        model, building upon the base YOLODataset functionality.

        Args:
            *args (Any): Variable length argument list passed to the parent YOLODataset class.
            data (dict | None): Dictionary containing dataset information. If None, default values will be used.
            **kwargs (Any): Additional keyword arguments passed to the parent YOLODataset class.
        """
        super().__init__(*args, data=data, **kwargs)

    def load_image(self, i, rect_mode=False):
        """
        Load one image from dataset index 'i'.

        Args:
            i (int): Index of the image to load.
            rect_mode (bool, optional): Whether to use rectangular mode for batch inference.

        Returns:
            im (torch.Tensor): The loaded image.
            resized_hw (tuple): Height and width of the resized image with shape (2,).

        Examples:
            Load an image from the dataset
            >>> dataset = RTDETRDataset(img_path="path/to/images")
            >>> image, hw = dataset.load_image(0)
        """
        return super().load_image(i=i, rect_mode=rect_mode)

    def build_transforms(self, hyp=None):
        """
        Build transformation pipeline for the dataset.

        Args:
            hyp (dict, optional): Hyperparameters for transformations.

        Returns:
            (Compose): Composition of transformation functions.
        """
        if self.augment:
            hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
            hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
            hyp.cutmix = hyp.cutmix if self.augment and not self.rect else 0.0
            transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
        else:
            # transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scale_fill=True)])
            transforms = Compose([])
        transforms.append(
            Format(
                bbox_format="xywh",
                normalize=True,
                return_mask=self.use_segments,
                return_keypoint=self.use_keypoints,
                batch_idx=True,
                mask_ratio=hyp.mask_ratio,
                mask_overlap=hyp.overlap_mask,
            )
        )
        return transforms


class RTDETRValidator(DetectionValidator):
    """
    RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
    the RT-DETR (Real-Time DETR) object detection model.

    The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
    post-processing, and updates evaluation metrics accordingly.

    Attributes:
        args (Namespace): Configuration arguments for validation.
        data (dict): Dataset configuration dictionary.

    Methods:
        build_dataset: Build an RTDETR Dataset for validation.
        postprocess: Apply Non-maximum suppression to prediction outputs.

    Examples:
        Initialize and run RT-DETR validation
        >>> from ultralytics.models.rtdetr import RTDETRValidator
        >>> args = dict(model="rtdetr-l.pt", data="coco8.yaml")
        >>> validator = RTDETRValidator(args=args)
        >>> validator()

    Notes:
        For further details on the attributes and methods, refer to the parent DetectionValidator class.
    """

    def build_dataset(self, img_path, mode="val", batch=None):
        """
        Build an RTDETR Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str, optional): `train` mode or `val` mode, users are able to customize different augmentations for
                each mode.
            batch (int, optional): Size of batches, this is for `rect`.

        Returns:
            (RTDETRDataset): Dataset configured for RT-DETR validation.
        """
        return RTDETRDataset(
            img_path=img_path,
            imgsz=self.args.imgsz,
            batch_size=batch,
            augment=False,  # no augmentation
            hyp=self.args,
            rect=False,  # no rect
            cache=self.args.cache or None,
            prefix=colorstr(f"{mode}: "),
            data=self.data,
        )

    def postprocess(
        self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]
    ) -> list[dict[str, torch.Tensor]]:
        """
        Apply Non-maximum suppression to prediction outputs.

        Args:
            preds (torch.Tensor | list | tuple): Raw predictions from the model. If tensor, should have shape
                (batch_size, num_predictions, num_classes + 4) where last dimension contains bbox coords and class
                scores.

        Returns:
            (list[dict[str, torch.Tensor]]): List of dictionaries for each image, each containing:
                - 'bboxes': Tensor of shape (N, 4) with bounding box coordinates
                - 'conf': Tensor of shape (N,) with confidence scores
                - 'cls': Tensor of shape (N,) with class indices
        """
        if not isinstance(preds, (list, tuple)):  # list for PyTorch inference but list[0] Tensor for export inference
            preds = [preds, None]

        bs, _, nd = preds[0].shape
        bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
        bboxes *= self.args.imgsz
        outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
        for i, bbox in enumerate(bboxes):  # (300, 4)
            bbox = ops.xywh2xyxy(bbox)
            score, cls = scores[i].max(-1)  # (300, )
            pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1)  # filter
            # Sort by confidence to correctly get internal metrics
            pred = pred[score.argsort(descending=True)]
            outputs[i] = pred[score > self.args.conf]

        return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5]} for x in outputs]

    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """
        Serialize YOLO predictions to COCO json format.

        Args:
            predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
                with bounding box coordinates, confidence scores, and class predictions.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
        """
        path = Path(pbatch["im_file"])
        stem = path.stem
        image_id = int(stem) if stem.isnumeric() else stem
        box = predn["bboxes"].clone()
        box[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz  # native-space pred
        box[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz  # native-space pred
        box = ops.xyxy2xywh(box)  # xywh
        box[:, :2] -= box[:, 2:] / 2  # xy center to top-left corner
        for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
            self.jdict.append(
                {
                    "image_id": image_id,
                    "file_name": path.name,
                    "category_id": self.class_map[int(c)],
                    "bbox": [round(x, 3) for x in b],
                    "score": round(s, 5),
                }
            )
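For reference, pred_to_json above appends one dictionary per detection in COCO results format, with the bbox given as top-left x, y, width, height in original-image pixels. A hypothetical entry with invented values (illustrative only, not produced by this commit):

# Illustrative only: the shape of one entry appended to self.jdict.
example_entry = {
    "image_id": 139,
    "file_name": "000000000139.jpg",
    "category_id": 1,
    "bbox": [431.125, 211.281, 121.5, 301.75],
    "score": 0.92315,
}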
12 ultralytics/models/sam/__init__.py Normal file
@@ -0,0 +1,12 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import SAM
from .predict import Predictor, SAM2DynamicInteractivePredictor, SAM2Predictor, SAM2VideoPredictor

__all__ = (
    "SAM",
    "Predictor",
    "SAM2Predictor",
    "SAM2VideoPredictor",
    "SAM2DynamicInteractivePredictor",
)  # tuple or list of exportable items
BIN ultralytics/models/sam/__pycache__/__init__.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/sam/__pycache__/amg.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/sam/__pycache__/model.cpython-310.pyc Normal file
Binary file not shown.
BIN ultralytics/models/sam/__pycache__/predict.cpython-310.pyc Normal file
Binary file not shown.
281
ultralytics/models/sam/amg.py
Normal file
281
ultralytics/models/sam/amg.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Generator
|
||||
from itertools import product
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def is_box_near_crop_edge(
|
||||
boxes: torch.Tensor, crop_box: list[int], orig_box: list[int], atol: float = 20.0
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Determine if bounding boxes are near the edge of a cropped image region using a specified tolerance.
|
||||
|
||||
Args:
|
||||
boxes (torch.Tensor): Bounding boxes in XYXY format.
|
||||
crop_box (list[int]): Crop box coordinates in [x0, y0, x1, y1] format.
|
||||
orig_box (list[int]): Original image box coordinates in [x0, y0, x1, y1] format.
|
||||
atol (float, optional): Absolute tolerance for edge proximity detection.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Boolean tensor indicating which boxes are near crop edges.
|
||||
|
||||
Examples:
|
||||
>>> boxes = torch.tensor([[10, 10, 50, 50], [100, 100, 150, 150]])
|
||||
>>> crop_box = [0, 0, 200, 200]
|
||||
>>> orig_box = [0, 0, 300, 300]
|
||||
>>> near_edge = is_box_near_crop_edge(boxes, crop_box, orig_box, atol=20.0)
|
||||
"""
|
||||
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
||||
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
||||
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
||||
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
||||
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
||||
return torch.any(near_crop_edge, dim=1)
|
||||
|
||||
|
||||
def batch_iterator(batch_size: int, *args) -> Generator[list[Any]]:
|
||||
"""
|
||||
Yield batches of data from input arguments with specified batch size for efficient processing.
|
||||
|
||||
This function takes a batch size and any number of iterables, then yields batches of elements from those
|
||||
iterables. All input iterables must have the same length.
|
||||
|
||||
Args:
|
||||
batch_size (int): Size of each batch to yield.
|
||||
*args (Any): Variable length input iterables to batch. All iterables must have the same length.
|
||||
|
||||
Yields:
|
||||
(list[Any]): A list of batched elements from each input iterable.
|
||||
|
||||
Examples:
|
||||
>>> data = [1, 2, 3, 4, 5]
|
||||
>>> labels = ["a", "b", "c", "d", "e"]
|
||||
>>> for batch in batch_iterator(2, data, labels):
|
||||
... print(batch)
|
||||
[[1, 2], ['a', 'b']]
|
||||
[[3, 4], ['c', 'd']]
|
||||
[[5], ['e']]
|
||||
"""
|
||||
assert args and all(len(a) == len(args[0]) for a in args), "Batched iteration must have same-size inputs."
|
||||
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
||||
for b in range(n_batches):
|
||||
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
||||
|
||||
|
||||
def calculate_stability_score(masks: torch.Tensor, mask_threshold: float, threshold_offset: float) -> torch.Tensor:
|
||||
"""
|
||||
Compute the stability score for a batch of masks.
|
||||
|
||||
The stability score is the IoU between binary masks obtained by thresholding the predicted mask logits at
|
||||
high and low values.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Batch of predicted mask logits.
|
||||
mask_threshold (float): Threshold value for creating binary masks.
|
||||
threshold_offset (float): Offset applied to the threshold for creating high and low binary masks.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Stability scores for each mask in the batch.
|
||||
|
||||
Notes:
|
||||
- One mask is always contained inside the other.
|
||||
- Memory is saved by preventing unnecessary cast to torch.int64.
|
||||
|
||||
Examples:
|
||||
>>> masks = torch.rand(10, 256, 256) # Batch of 10 masks
|
||||
>>> mask_threshold = 0.5
|
||||
>>> threshold_offset = 0.1
|
||||
>>> stability_scores = calculate_stability_score(masks, mask_threshold, threshold_offset)
|
||||
"""
|
||||
intersections = (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
unions = (masks > (mask_threshold - threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32)
|
||||
return intersections / unions
|
||||
|
||||
|
||||
def build_point_grid(n_per_side: int) -> np.ndarray:
|
||||
"""Generate a 2D grid of evenly spaced points in the range [0,1]x[0,1] for image segmentation tasks."""
|
||||
offset = 1 / (2 * n_per_side)
|
||||
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
||||
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
||||
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
||||
return np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
||||
|
||||
|
||||
def build_all_layer_point_grids(n_per_side: int, n_layers: int, scale_per_layer: int) -> list[np.ndarray]:
|
||||
"""Generate point grids for multiple crop layers with varying scales and densities."""
|
||||
return [build_point_grid(int(n_per_side / (scale_per_layer**i))) for i in range(n_layers + 1)]
|
||||
|
||||
|
||||
def generate_crop_boxes(
|
||||
im_size: tuple[int, ...], n_layers: int, overlap_ratio: float
|
||||
) -> tuple[list[list[int]], list[int]]:
|
||||
"""
|
||||
Generate crop boxes of varying sizes for multiscale image processing, with layered overlapping regions.
|
||||
|
||||
Args:
|
||||
im_size (tuple[int, ...]): Height and width of the input image.
|
||||
n_layers (int): Number of layers to generate crop boxes for.
|
||||
overlap_ratio (float): Ratio of overlap between adjacent crop boxes.
|
||||
|
||||
Returns:
|
||||
crop_boxes (list[list[int]]): List of crop boxes in [x0, y0, x1, y1] format.
|
||||
layer_idxs (list[int]): List of layer indices corresponding to each crop box.
|
||||
|
||||
Examples:
|
||||
>>> im_size = (800, 1200) # Height, width
|
||||
>>> n_layers = 3
|
||||
>>> overlap_ratio = 0.25
|
||||
>>> crop_boxes, layer_idxs = generate_crop_boxes(im_size, n_layers, overlap_ratio)
|
||||
"""
|
||||
crop_boxes, layer_idxs = [], []
|
||||
im_h, im_w = im_size
|
||||
short_side = min(im_h, im_w)
|
||||
|
||||
# Original image
|
||||
crop_boxes.append([0, 0, im_w, im_h])
|
||||
layer_idxs.append(0)
|
||||
|
||||
def crop_len(orig_len, n_crops, overlap):
|
||||
"""Calculate the length of each crop given the original length, number of crops, and overlap."""
|
||||
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
|
||||
|
||||
for i_layer in range(n_layers):
|
||||
n_crops_per_side = 2 ** (i_layer + 1)
|
||||
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
||||
|
||||
crop_w = crop_len(im_w, n_crops_per_side, overlap)
|
||||
crop_h = crop_len(im_h, n_crops_per_side, overlap)
|
||||
|
||||
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
|
||||
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
|
||||
|
||||
# Crops in XYXY format (x0, y0, x1, y1), clipped to the image bounds
|
||||
for x0, y0 in product(crop_box_x0, crop_box_y0):
|
||||
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
|
||||
crop_boxes.append(box)
|
||||
layer_idxs.append(i_layer + 1)
|
||||
|
||||
return crop_boxes, layer_idxs
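# Illustrative sketch (assumed sizes, not part of the library): layer i contributes
# (2 ** (i + 1)) ** 2 overlapping crops on top of the full-image box at layer 0.
# >>> boxes, layers = generate_crop_boxes(im_size=(800, 1200), n_layers=1, overlap_ratio=0.25)
# >>> len(boxes), layers
# (5, [0, 1, 1, 1, 1])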
|
||||
|
||||
|
||||
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
|
||||
"""Uncrop bounding boxes by adding the crop box offset to their coordinates."""
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
||||
# Check if boxes has a channel dimension
|
||||
if len(boxes.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return boxes + offset
|
||||
|
||||
|
||||
def uncrop_points(points: torch.Tensor, crop_box: list[int]) -> torch.Tensor:
|
||||
"""Uncrop points by adding the crop box offset to their coordinates."""
|
||||
x0, y0, _, _ = crop_box
|
||||
offset = torch.tensor([[x0, y0]], device=points.device)
|
||||
# Check if points has a channel dimension
|
||||
if len(points.shape) == 3:
|
||||
offset = offset.unsqueeze(1)
|
||||
return points + offset
|
||||
|
||||
|
||||
def uncrop_masks(masks: torch.Tensor, crop_box: list[int], orig_h: int, orig_w: int) -> torch.Tensor:
|
||||
"""Uncrop masks by padding them to the original image size, handling coordinate transformations."""
|
||||
x0, y0, x1, y1 = crop_box
|
||||
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
||||
return masks
|
||||
# Coordinate transform masks
|
||||
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
|
||||
pad = (x0, pad_x - x0, y0, pad_y - y0)
|
||||
return torch.nn.functional.pad(masks, pad, value=0)
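# Illustrative sketch (toy sizes assumed, not part of the library): F.pad takes padding as
# (left, right, top, bottom) for the last two dims, so a crop starting at (x0, y0) is placed
# back at its original location inside an (orig_h, orig_w) canvas.
# >>> crop = torch.ones(1, 2, 3)  # one mask from a 3-wide, 2-tall crop
# >>> uncrop_masks(crop, crop_box=[4, 1, 7, 3], orig_h=5, orig_w=10).shape
# torch.Size([1, 5, 10])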
|
||||
|
||||
|
||||
def remove_small_regions(mask: np.ndarray, area_thresh: float, mode: str) -> tuple[np.ndarray, bool]:
|
||||
"""
|
||||
Remove small disconnected regions or holes in a mask based on area threshold and mode.
|
||||
|
||||
Args:
|
||||
mask (np.ndarray): Binary mask to process.
|
||||
area_thresh (float): Area threshold below which regions will be removed.
|
||||
mode (str): Processing mode, either 'holes' to fill small holes or 'islands' to remove small disconnected
|
||||
regions.
|
||||
|
||||
Returns:
|
||||
processed_mask (np.ndarray): Processed binary mask with small regions removed.
|
||||
modified (bool): Whether any regions were modified.
|
||||
|
||||
Examples:
|
||||
>>> mask = np.zeros((100, 100), dtype=np.bool_)
|
||||
>>> mask[40:60, 40:60] = True # Create a square
|
||||
>>> mask[45:55, 45:55] = False # Create a hole
|
||||
>>> processed_mask, modified = remove_small_regions(mask, 50, "holes")
|
||||
"""
|
||||
import cv2 # type: ignore
|
||||
|
||||
assert mode in {"holes", "islands"}, f"Provided mode {mode} is invalid"
|
||||
correct_holes = mode == "holes"
|
||||
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
||||
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
||||
sizes = stats[:, -1][1:] # Row 0 is background label
|
||||
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
||||
if not small_regions:
|
||||
return mask, False
|
||||
fill_labels = [0] + small_regions
|
||||
if not correct_holes:
|
||||
# If every region is below threshold, keep largest
|
||||
fill_labels = [i for i in range(n_labels) if i not in fill_labels] or [int(np.argmax(sizes)) + 1]
|
||||
mask = np.isin(regions, fill_labels)
|
||||
return mask, True
|
||||
|
||||
|
||||
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate bounding boxes in XYXY format around binary masks.
|
||||
|
||||
Args:
|
||||
masks (torch.Tensor): Binary masks with shape (B, H, W) or (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Bounding boxes in XYXY format with shape (B, 4) or (B, C, 4).
|
||||
|
||||
Notes:
|
||||
- Handles empty masks by returning zero boxes.
|
||||
- Preserves input tensor dimensions in the output.
|
||||
"""
|
||||
# torch.max below raises an error on empty inputs, so return all-zero boxes in that case
|
||||
if torch.numel(masks) == 0:
|
||||
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
||||
|
||||
# Normalize shape to CxHxW
|
||||
shape = masks.shape
|
||||
h, w = shape[-2:]
|
||||
masks = masks.flatten(0, -3) if len(shape) > 2 else masks.unsqueeze(0)
|
||||
# Get top and bottom edges
|
||||
in_height, _ = torch.max(masks, dim=-1)
|
||||
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
|
||||
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
||||
in_height_coords = in_height_coords + h * (~in_height)
|
||||
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
||||
|
||||
# Get left and right edges
|
||||
in_width, _ = torch.max(masks, dim=-2)
|
||||
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
|
||||
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
||||
in_width_coords = in_width_coords + w * (~in_width)
|
||||
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
||||
|
||||
# If the mask is empty the right edge will be to the left of the left edge.
|
||||
# Replace these boxes with [0, 0, 0, 0]
|
||||
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
||||
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
||||
out = out * (~empty_filter).unsqueeze(-1)
|
||||
|
||||
# Return to original shape
|
||||
return out.reshape(*shape[:-2], 4) if len(shape) > 2 else out[0]
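# Illustrative sketch (toy mask assumed, not part of the library): a 4x4 boolean mask with a
# 2x2 foreground block at rows 1-2, cols 2-3 yields an inclusive XYXY box; an all-False mask
# would instead yield [0, 0, 0, 0].
# >>> m = torch.zeros(1, 4, 4, dtype=torch.bool)
# >>> m[0, 1:3, 2:4] = True
# >>> batched_mask_to_box(m)
# tensor([[2, 1, 3, 2]])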
|
||||
358
ultralytics/models/sam/build.py
Normal file
358
ultralytics/models/sam/build.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils.downloads import attempt_download_asset
|
||||
|
||||
from .modules.decoders import MaskDecoder
|
||||
from .modules.encoders import FpnNeck, Hiera, ImageEncoder, ImageEncoderViT, MemoryEncoder, PromptEncoder
|
||||
from .modules.memory_attention import MemoryAttention, MemoryAttentionLayer
|
||||
from .modules.sam import SAM2Model, SAMModel
|
||||
from .modules.tiny_encoder import TinyViT
|
||||
from .modules.transformer import TwoWayTransformer
|
||||
|
||||
|
||||
def build_sam_vit_h(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) h-size model with specified encoder parameters."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_depth=32,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[7, 15, 23, 31],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_l(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) l-size model with specified encoder parameters."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=1024,
|
||||
encoder_depth=24,
|
||||
encoder_num_heads=16,
|
||||
encoder_global_attn_indexes=[5, 11, 17, 23],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam_vit_b(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model (SAM) b-size model with specified encoder parameters."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=768,
|
||||
encoder_depth=12,
|
||||
encoder_num_heads=12,
|
||||
encoder_global_attn_indexes=[2, 5, 8, 11],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_mobile_sam(checkpoint=None):
|
||||
"""Build and return a Mobile Segment Anything Model (Mobile-SAM) for efficient image segmentation."""
|
||||
return _build_sam(
|
||||
encoder_embed_dim=[64, 128, 160, 320],
|
||||
encoder_depth=[2, 2, 6, 2],
|
||||
encoder_num_heads=[2, 4, 5, 10],
|
||||
encoder_global_attn_indexes=None,
|
||||
mobile_sam=True,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_t(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model 2 (SAM2) tiny-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=96,
|
||||
encoder_stages=[1, 2, 7, 2],
|
||||
encoder_num_heads=1,
|
||||
encoder_global_att_blocks=[5, 7, 9],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_backbone_channel_list=[768, 384, 192, 96],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_s(checkpoint=None):
|
||||
"""Build and return a small-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=96,
|
||||
encoder_stages=[1, 2, 11, 2],
|
||||
encoder_num_heads=1,
|
||||
encoder_global_att_blocks=[7, 10, 13],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_backbone_channel_list=[768, 384, 192, 96],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_b(checkpoint=None):
|
||||
"""Build and return a Segment Anything Model 2 (SAM2) base-size model with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=112,
|
||||
encoder_stages=[2, 3, 16, 3],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[12, 16, 20],
|
||||
encoder_window_spec=[8, 4, 14, 7],
|
||||
encoder_window_spatial_size=[14, 14],
|
||||
encoder_backbone_channel_list=[896, 448, 224, 112],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_sam2_l(checkpoint=None):
|
||||
"""Build and return a large-size Segment Anything Model 2 (SAM2) with specified architecture parameters."""
|
||||
return _build_sam2(
|
||||
encoder_embed_dim=144,
|
||||
encoder_stages=[2, 6, 36, 4],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[23, 33, 43],
|
||||
encoder_window_spec=[8, 4, 16, 8],
|
||||
encoder_backbone_channel_list=[1152, 576, 288, 144],
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_sam(
|
||||
encoder_embed_dim,
|
||||
encoder_depth,
|
||||
encoder_num_heads,
|
||||
encoder_global_attn_indexes,
|
||||
checkpoint=None,
|
||||
mobile_sam=False,
|
||||
):
|
||||
"""
|
||||
Build a Segment Anything Model (SAM) with specified encoder parameters.
|
||||
|
||||
Args:
|
||||
encoder_embed_dim (int | list[int]): Embedding dimension for the encoder.
|
||||
encoder_depth (int | list[int]): Depth of the encoder.
|
||||
encoder_num_heads (int | list[int]): Number of attention heads in the encoder.
|
||||
encoder_global_attn_indexes (list[int] | None): Indexes for global attention in the encoder.
|
||||
checkpoint (str | None, optional): Path to the model checkpoint file.
|
||||
mobile_sam (bool, optional): Whether to build a Mobile-SAM model.
|
||||
|
||||
Returns:
|
||||
(SAMModel): A Segment Anything Model instance with the specified architecture.
|
||||
|
||||
Examples:
|
||||
>>> sam = _build_sam(768, 12, 12, [2, 5, 8, 11])
|
||||
>>> sam = _build_sam([64, 128, 160, 320], [2, 2, 6, 2], [2, 4, 5, 10], None, mobile_sam=True)
|
||||
"""
|
||||
prompt_embed_dim = 256
|
||||
image_size = 1024
|
||||
vit_patch_size = 16
|
||||
image_embedding_size = image_size // vit_patch_size
|
||||
image_encoder = (
|
||||
TinyViT(
|
||||
img_size=1024,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dims=encoder_embed_dim,
|
||||
depths=encoder_depth,
|
||||
num_heads=encoder_num_heads,
|
||||
window_sizes=[7, 7, 14, 7],
|
||||
mlp_ratio=4.0,
|
||||
drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
use_checkpoint=False,
|
||||
mbconv_expand_ratio=4.0,
|
||||
local_conv_size=3,
|
||||
layer_lr_decay=0.8,
|
||||
)
|
||||
if mobile_sam
|
||||
else ImageEncoderViT(
|
||||
depth=encoder_depth,
|
||||
embed_dim=encoder_embed_dim,
|
||||
img_size=image_size,
|
||||
mlp_ratio=4,
|
||||
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
||||
num_heads=encoder_num_heads,
|
||||
patch_size=vit_patch_size,
|
||||
qkv_bias=True,
|
||||
use_rel_pos=True,
|
||||
global_attn_indexes=encoder_global_attn_indexes,
|
||||
window_size=14,
|
||||
out_chans=prompt_embed_dim,
|
||||
)
|
||||
)
|
||||
sam = SAMModel(
|
||||
image_encoder=image_encoder,
|
||||
prompt_encoder=PromptEncoder(
|
||||
embed_dim=prompt_embed_dim,
|
||||
image_embedding_size=(image_embedding_size, image_embedding_size),
|
||||
input_image_size=(image_size, image_size),
|
||||
mask_in_chans=16,
|
||||
),
|
||||
mask_decoder=MaskDecoder(
|
||||
num_multimask_outputs=3,
|
||||
transformer=TwoWayTransformer(
|
||||
depth=2,
|
||||
embedding_dim=prompt_embed_dim,
|
||||
mlp_dim=2048,
|
||||
num_heads=8,
|
||||
),
|
||||
transformer_dim=prompt_embed_dim,
|
||||
iou_head_depth=3,
|
||||
iou_head_hidden_dim=256,
|
||||
),
|
||||
pixel_mean=[123.675, 116.28, 103.53],
|
||||
pixel_std=[58.395, 57.12, 57.375],
|
||||
)
|
||||
if checkpoint is not None:
|
||||
checkpoint = attempt_download_asset(checkpoint)
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)
|
||||
sam.load_state_dict(state_dict)
|
||||
sam.eval()
|
||||
return sam
|
||||
|
||||
|
||||
def _build_sam2(
|
||||
encoder_embed_dim=1280,
|
||||
encoder_stages=[2, 6, 36, 4],
|
||||
encoder_num_heads=2,
|
||||
encoder_global_att_blocks=[7, 15, 23, 31],
|
||||
encoder_backbone_channel_list=[1152, 576, 288, 144],
|
||||
encoder_window_spatial_size=[7, 7],
|
||||
encoder_window_spec=[8, 4, 16, 8],
|
||||
checkpoint=None,
|
||||
):
|
||||
"""
|
||||
Build and return a Segment Anything Model 2 (SAM2) with specified architecture parameters.
|
||||
|
||||
Args:
|
||||
encoder_embed_dim (int, optional): Embedding dimension for the encoder.
|
||||
encoder_stages (list[int], optional): Number of blocks in each stage of the encoder.
|
||||
encoder_num_heads (int, optional): Number of attention heads in the encoder.
|
||||
encoder_global_att_blocks (list[int], optional): Indices of global attention blocks in the encoder.
|
||||
encoder_backbone_channel_list (list[int], optional): Channel dimensions for each level of the encoder backbone.
|
||||
encoder_window_spatial_size (list[int], optional): Spatial size of the window for position embeddings.
|
||||
encoder_window_spec (list[int], optional): Window specifications for each stage of the encoder.
|
||||
checkpoint (str | None, optional): Path to the checkpoint file for loading pre-trained weights.
|
||||
|
||||
Returns:
|
||||
(SAM2Model): A configured and initialized SAM2 model.
|
||||
|
||||
Examples:
|
||||
>>> sam2_model = _build_sam2(encoder_embed_dim=96, encoder_stages=[1, 2, 7, 2])
|
||||
>>> sam2_model.eval()
|
||||
"""
|
||||
image_encoder = ImageEncoder(
|
||||
trunk=Hiera(
|
||||
embed_dim=encoder_embed_dim,
|
||||
num_heads=encoder_num_heads,
|
||||
stages=encoder_stages,
|
||||
global_att_blocks=encoder_global_att_blocks,
|
||||
window_pos_embed_bkg_spatial_size=encoder_window_spatial_size,
|
||||
window_spec=encoder_window_spec,
|
||||
),
|
||||
neck=FpnNeck(
|
||||
d_model=256,
|
||||
backbone_channel_list=encoder_backbone_channel_list,
|
||||
fpn_top_down_levels=[2, 3],
|
||||
fpn_interp_model="nearest",
|
||||
),
|
||||
scalp=1,
|
||||
)
|
||||
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, num_layers=4, layer=MemoryAttentionLayer())
|
||||
memory_encoder = MemoryEncoder(out_dim=64)
|
||||
|
||||
is_sam2_1 = checkpoint is not None and "sam2.1" in checkpoint
|
||||
sam2 = SAM2Model(
|
||||
image_encoder=image_encoder,
|
||||
memory_attention=memory_attention,
|
||||
memory_encoder=memory_encoder,
|
||||
num_maskmem=7,
|
||||
image_size=1024,
|
||||
sigmoid_scale_for_mem_enc=20.0,
|
||||
sigmoid_bias_for_mem_enc=-10.0,
|
||||
use_mask_input_as_output_without_sam=True,
|
||||
directly_add_no_mem_embed=True,
|
||||
use_high_res_features_in_sam=True,
|
||||
multimask_output_in_sam=True,
|
||||
iou_prediction_use_sigmoid=True,
|
||||
use_obj_ptrs_in_encoder=True,
|
||||
add_tpos_enc_to_obj_ptrs=True,
|
||||
only_obj_ptrs_in_the_past_for_eval=True,
|
||||
pred_obj_scores=True,
|
||||
pred_obj_scores_mlp=True,
|
||||
fixed_no_obj_ptr=True,
|
||||
multimask_output_for_tracking=True,
|
||||
use_multimask_token_for_obj_ptr=True,
|
||||
multimask_min_pt_num=0,
|
||||
multimask_max_pt_num=1,
|
||||
use_mlp_for_obj_ptr_proj=True,
|
||||
compile_image_encoder=False,
|
||||
no_obj_embed_spatial=is_sam2_1,
|
||||
proj_tpos_enc_in_obj_ptrs=is_sam2_1,
|
||||
use_signed_tpos_enc_to_obj_ptrs=is_sam2_1,
|
||||
sam_mask_decoder_extra_args=dict(
|
||||
dynamic_multimask_via_stability=True,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
),
|
||||
)
|
||||
|
||||
if checkpoint is not None:
|
||||
checkpoint = attempt_download_asset(checkpoint)
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f)["model"]
|
||||
sam2.load_state_dict(state_dict)
|
||||
sam2.eval()
|
||||
return sam2
|
||||
|
||||
|
||||
sam_model_map = {
|
||||
"sam_h.pt": build_sam_vit_h,
|
||||
"sam_l.pt": build_sam_vit_l,
|
||||
"sam_b.pt": build_sam_vit_b,
|
||||
"mobile_sam.pt": build_mobile_sam,
|
||||
"sam2_t.pt": build_sam2_t,
|
||||
"sam2_s.pt": build_sam2_s,
|
||||
"sam2_b.pt": build_sam2_b,
|
||||
"sam2_l.pt": build_sam2_l,
|
||||
"sam2.1_t.pt": build_sam2_t,
|
||||
"sam2.1_s.pt": build_sam2_s,
|
||||
"sam2.1_b.pt": build_sam2_b,
|
||||
"sam2.1_l.pt": build_sam2_l,
|
||||
}
|
||||
|
||||
|
||||
def build_sam(ckpt="sam_b.pt"):
|
||||
"""
|
||||
Build and return a Segment Anything Model (SAM) based on the provided checkpoint.
|
||||
|
||||
Args:
|
||||
ckpt (str | Path, optional): Path to the checkpoint file or name of a pre-defined SAM model.
|
||||
|
||||
Returns:
|
||||
(SAMModel | SAM2Model): A configured and initialized SAM or SAM2 model instance.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the provided checkpoint is not a supported SAM model.
|
||||
|
||||
Examples:
|
||||
>>> sam_model = build_sam("sam_b.pt")
|
||||
>>> sam_model = build_sam("path/to/custom_checkpoint.pt")
|
||||
|
||||
Notes:
|
||||
Supported pre-defined models include:
|
||||
- SAM: 'sam_h.pt', 'sam_l.pt', 'sam_b.pt', 'mobile_sam.pt'
|
||||
- SAM2: 'sam2_t.pt', 'sam2_s.pt', 'sam2_b.pt', 'sam2_l.pt'
|
||||
"""
|
||||
model_builder = None
|
||||
ckpt = str(ckpt) # to allow Path ckpt types
|
||||
for k in sam_model_map.keys():
|
||||
if ckpt.endswith(k):
|
||||
model_builder = sam_model_map.get(k)
|
||||
|
||||
if not model_builder:
|
||||
raise FileNotFoundError(f"{ckpt} is not a supported SAM model. Available models are: \n {sam_model_map.keys()}")
|
||||
|
||||
return model_builder(ckpt)
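# Illustrative note (hypothetical filename, not part of the library): dispatch uses str.endswith,
# so a custom checkpoint whose name ends with a supported key still resolves to the right builder.
# >>> sam2 = build_sam("weights/finetuned_sam2.1_l.pt")  # resolves to build_sam2_l, loads that file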
|
||||
172
ultralytics/models/sam/model.py
Normal file
172
ultralytics/models/sam/model.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
"""
|
||||
SAM model interface.
|
||||
|
||||
This module provides an interface to the Segment Anything Model (SAM) from ultralytics, designed for real-time image
|
||||
segmentation tasks. The SAM model supports promptable segmentation with flexible prompt types such as points and boxes,
|
||||
and has been trained on the SA-1B dataset. It features zero-shot performance capabilities, enabling it to adapt to new
|
||||
image distributions and tasks without prior knowledge.
|
||||
|
||||
Key Features:
|
||||
- Promptable segmentation
|
||||
- Real-time performance
|
||||
- Zero-shot transfer capabilities
|
||||
- Trained on SA-1B dataset
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.utils.torch_utils import model_info
|
||||
|
||||
from .predict import Predictor, SAM2Predictor
|
||||
|
||||
|
||||
class SAM(Model):
|
||||
"""
|
||||
SAM (Segment Anything Model) interface class for real-time image segmentation tasks.
|
||||
|
||||
This class provides an interface to the Segment Anything Model (SAM) from ultralytics, designed for
|
||||
promptable segmentation with versatility in image analysis. It supports various prompts such as bounding
|
||||
boxes, points, or labels, and features zero-shot performance capabilities.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): The loaded SAM model.
|
||||
is_sam2 (bool): Indicates whether the model is SAM2 variant.
|
||||
task (str): The task type, set to "segment" for SAM models.
|
||||
|
||||
Methods:
|
||||
predict: Perform segmentation prediction on the given image or video source.
|
||||
info: Log information about the SAM model.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam.predict("image.jpg", points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
>>> print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
|
||||
def __init__(self, model: str = "sam_b.pt") -> None:
|
||||
"""
|
||||
Initialize the SAM (Segment Anything Model) instance.
|
||||
|
||||
Args:
|
||||
model (str): Path to the pre-trained SAM model file. File should have a .pt or .pth extension.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the model file extension is not .pt or .pth.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> print(sam.is_sam2)
|
||||
"""
|
||||
if model and Path(model).suffix not in {".pt", ".pth"}:
|
||||
raise NotImplementedError("SAM prediction requires pre-trained *.pt or *.pth model.")
|
||||
self.is_sam2 = "sam2" in Path(model).stem
|
||||
super().__init__(model=model, task="segment")
|
||||
|
||||
def _load(self, weights: str, task=None):
|
||||
"""
|
||||
Load the specified weights into the SAM model.
|
||||
|
||||
Args:
|
||||
weights (str): Path to the weights file. Should be a .pt or .pth file containing the model parameters.
|
||||
task (str | None): Task name. If provided, it specifies the particular task the model is being loaded for.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> sam._load("path/to/custom_weights.pt")
|
||||
"""
|
||||
from .build import build_sam # slow import
|
||||
|
||||
self.model = build_sam(weights)
|
||||
|
||||
def predict(self, source, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
Perform segmentation prediction on the given image or video source.
|
||||
|
||||
Args:
|
||||
source (str | PIL.Image | np.ndarray): Path to the image or video file, or a PIL.Image object, or
|
||||
a np.ndarray object.
|
||||
stream (bool): If True, enables real-time streaming.
|
||||
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
|
||||
points (list[list[float]] | None): List of points for prompted segmentation.
|
||||
labels (list[int] | None): List of labels for prompted segmentation.
|
||||
**kwargs (Any): Additional keyword arguments for prediction.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam.predict("image.jpg", points=[[500, 375]])
|
||||
>>> for r in results:
|
||||
... print(f"Detected {len(r.masks)} masks")
|
||||
"""
|
||||
overrides = dict(conf=0.25, task="segment", mode="predict", imgsz=1024)
|
||||
kwargs = {**overrides, **kwargs}
|
||||
prompts = dict(bboxes=bboxes, points=points, labels=labels)
|
||||
return super().predict(source, stream, prompts=prompts, **kwargs)
|
||||
|
||||
def __call__(self, source=None, stream: bool = False, bboxes=None, points=None, labels=None, **kwargs):
|
||||
"""
|
||||
Perform segmentation prediction on the given image or video source.
|
||||
|
||||
This method is an alias for the 'predict' method, providing a convenient way to call the SAM model
|
||||
for segmentation tasks.
|
||||
|
||||
Args:
|
||||
source (str | PIL.Image | np.ndarray | None): Path to the image or video file, or a PIL.Image
|
||||
object, or a np.ndarray object.
|
||||
stream (bool): If True, enables real-time streaming.
|
||||
bboxes (list[list[float]] | None): List of bounding box coordinates for prompted segmentation.
|
||||
points (list[list[float]] | None): List of points for prompted segmentation.
|
||||
labels (list[int] | None): List of labels for prompted segmentation.
|
||||
**kwargs (Any): Additional keyword arguments to be passed to the predict method.
|
||||
|
||||
Returns:
|
||||
(list): The model predictions, typically containing segmentation masks and other relevant information.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> results = sam("image.jpg", points=[[500, 375]])
|
||||
>>> print(f"Detected {len(results[0].masks)} masks")
|
||||
"""
|
||||
return self.predict(source, stream, bboxes, points, labels, **kwargs)
|
||||
|
||||
def info(self, detailed: bool = False, verbose: bool = True):
|
||||
"""
|
||||
Log information about the SAM model.
|
||||
|
||||
Args:
|
||||
detailed (bool): If True, displays detailed information about the model layers and operations.
|
||||
verbose (bool): If True, prints the information to the console.
|
||||
|
||||
Returns:
|
||||
(tuple): A tuple containing the model's information (string representations of the model).
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> info = sam.info()
|
||||
>>> print(info[0]) # Print summary information
|
||||
"""
|
||||
return model_info(self.model, detailed=detailed, verbose=verbose)
|
||||
|
||||
@property
|
||||
def task_map(self) -> dict[str, dict[str, type[Predictor]]]:
|
||||
"""
|
||||
Provide a mapping from the 'segment' task to its corresponding 'Predictor'.
|
||||
|
||||
Returns:
|
||||
(dict[str, dict[str, Type[Predictor]]]): A dictionary mapping the 'segment' task to its corresponding
|
||||
Predictor class. For SAM2 models, it maps to SAM2Predictor, otherwise to the standard Predictor.
|
||||
|
||||
Examples:
|
||||
>>> sam = SAM("sam_b.pt")
|
||||
>>> task_map = sam.task_map
|
||||
>>> print(task_map)
|
||||
{'segment': {'predictor': <class 'ultralytics.models.sam.predict.Predictor'>}}
|
||||
"""
|
||||
return {"segment": {"predictor": SAM2Predictor if self.is_sam2 else Predictor}}
|
||||
1
ultralytics/models/sam/modules/__init__.py
Normal file
1
ultralytics/models/sam/modules/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
1128
ultralytics/models/sam/modules/blocks.py
Normal file
1128
ultralytics/models/sam/modules/blocks.py
Normal file
File diff suppressed because it is too large
513
ultralytics/models/sam/modules/decoders.py
Normal file
513
ultralytics/models/sam/modules/decoders.py
Normal file
@@ -0,0 +1,513 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ultralytics.nn.modules import MLP, LayerNorm2d
|
||||
|
||||
|
||||
class MaskDecoder(nn.Module):
|
||||
"""
|
||||
Decoder module for generating masks and their associated quality scores using a transformer architecture.
|
||||
|
||||
This class predicts masks given image and prompt embeddings, utilizing a transformer to process the inputs and
|
||||
generate mask predictions along with their quality scores.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): Transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for the IoU token.
|
||||
num_mask_tokens (int): Number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for the mask tokens.
|
||||
output_upscaling (nn.Sequential): Neural network sequence for upscaling the output.
|
||||
output_hypernetworks_mlps (nn.ModuleList): Hypernetwork MLPs for generating masks.
|
||||
iou_prediction_head (nn.Module): MLP for predicting mask quality.
|
||||
|
||||
Methods:
|
||||
forward: Predict masks given image and prompt embeddings.
|
||||
predict_masks: Internal method for mask prediction.
|
||||
|
||||
Examples:
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
|
||||
>>> masks, iou_pred = decoder(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, multimask_output=True
|
||||
... )
|
||||
>>> print(f"Predicted masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the MaskDecoder module for generating masks and their associated quality scores.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): Channel dimension for the transformer module.
|
||||
transformer (nn.Module): Transformer module used for mask prediction.
|
||||
num_multimask_outputs (int): Number of masks to predict for disambiguating masks.
|
||||
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
|
||||
iou_head_depth (int): Depth of the MLP used to predict mask quality.
|
||||
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
|
||||
|
||||
Examples:
|
||||
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer)
|
||||
>>> print(decoder)
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder.
|
||||
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings.
|
||||
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes.
|
||||
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs.
|
||||
multimask_output (bool): Whether to return multiple masks or a single mask.
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): Batched predicted masks.
|
||||
iou_pred (torch.Tensor): Batched predictions of mask quality.
|
||||
|
||||
Examples:
|
||||
>>> decoder = MaskDecoder(transformer_dim=256, transformer=transformer_module)
|
||||
>>> image_emb = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_emb = torch.rand(1, 2, 256)
|
||||
>>> dense_emb = torch.rand(1, 256, 64, 64)
|
||||
>>> masks, iou_pred = decoder(image_emb, image_pe, sparse_emb, dense_emb, multimask_output=True)
|
||||
>>> print(f"Masks shape: {masks.shape}, IoU predictions shape: {iou_pred.shape}")
|
||||
"""
|
||||
masks, iou_pred = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
mask_slice = slice(1, None) if multimask_output else slice(0, 1)
|
||||
masks = masks[:, mask_slice, :, :]
|
||||
iou_pred = iou_pred[:, mask_slice]
|
||||
|
||||
return masks, iou_pred
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Predict masks and quality scores using image and prompt embeddings via transformer architecture."""
|
||||
# Concatenate output tokens
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
src = src + dense_prompt_embeddings
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, 0, :]
|
||||
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
hyper_in_list: list[torch.Tensor] = [
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
|
||||
]
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
|
||||
return masks, iou_pred
|
||||
|
||||
|
||||
class SAM2MaskDecoder(nn.Module):
|
||||
"""
|
||||
Transformer-based decoder for predicting instance segmentation masks from image and prompt embeddings.
|
||||
|
||||
This class extends the functionality of the MaskDecoder, incorporating additional features such as
|
||||
high-resolution feature processing, dynamic multimask output, and object score prediction.
|
||||
|
||||
Attributes:
|
||||
transformer_dim (int): Channel dimension of the transformer.
|
||||
transformer (nn.Module): Transformer used to predict masks.
|
||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
|
||||
iou_token (nn.Embedding): Embedding for IOU token.
|
||||
num_mask_tokens (int): Total number of mask tokens.
|
||||
mask_tokens (nn.Embedding): Embedding for mask tokens.
|
||||
pred_obj_scores (bool): Whether to predict object scores.
|
||||
obj_score_token (nn.Embedding): Embedding for object score token.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
|
||||
output_upscaling (nn.Sequential): Upscaling layers for output.
|
||||
use_high_res_features (bool): Whether to use high-resolution features.
|
||||
conv_s0 (nn.Conv2d): Convolutional layer for high-resolution features (s0).
|
||||
conv_s1 (nn.Conv2d): Convolutional layer for high-resolution features (s1).
|
||||
output_hypernetworks_mlps (nn.ModuleList): List of MLPs for output hypernetworks.
|
||||
iou_prediction_head (MLP): MLP for IOU prediction.
|
||||
pred_obj_score_head (nn.Linear | MLP): Linear layer or MLP for object score prediction.
|
||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
|
||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
|
||||
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
|
||||
|
||||
Methods:
|
||||
forward: Predict masks given image and prompt embeddings.
|
||||
predict_masks: Predict instance segmentation masks from image and prompt embeddings.
|
||||
_get_stability_scores: Compute mask stability scores based on IoU between thresholds.
|
||||
_dynamic_multimask_via_stability: Dynamically select the most stable mask output.
|
||||
|
||||
Examples:
|
||||
>>> image_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
|
||||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
|
||||
... )
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer_dim: int,
|
||||
transformer: nn.Module,
|
||||
num_multimask_outputs: int = 3,
|
||||
activation: type[nn.Module] = nn.GELU,
|
||||
iou_head_depth: int = 3,
|
||||
iou_head_hidden_dim: int = 256,
|
||||
use_high_res_features: bool = False,
|
||||
iou_prediction_use_sigmoid=False,
|
||||
dynamic_multimask_via_stability=False,
|
||||
dynamic_multimask_stability_delta=0.05,
|
||||
dynamic_multimask_stability_thresh=0.98,
|
||||
pred_obj_scores: bool = False,
|
||||
pred_obj_scores_mlp: bool = False,
|
||||
use_multimask_token_for_obj_ptr: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the SAM2MaskDecoder module for predicting instance segmentation masks.
|
||||
|
||||
This decoder extends the functionality of MaskDecoder, incorporating additional features such as
|
||||
high-resolution feature processing, dynamic multimask output, and object score prediction.
|
||||
|
||||
Args:
|
||||
transformer_dim (int): Channel dimension of the transformer.
|
||||
transformer (nn.Module): Transformer used to predict masks.
|
||||
num_multimask_outputs (int): Number of masks to predict when disambiguating masks.
|
||||
activation (Type[nn.Module]): Type of activation to use when upscaling masks.
|
||||
iou_head_depth (int): Depth of the MLP used to predict mask quality.
|
||||
iou_head_hidden_dim (int): Hidden dimension of the MLP used to predict mask quality.
|
||||
use_high_res_features (bool): Whether to use high-resolution features.
|
||||
iou_prediction_use_sigmoid (bool): Whether to use sigmoid for IOU prediction.
|
||||
dynamic_multimask_via_stability (bool): Whether to use dynamic multimask via stability.
|
||||
dynamic_multimask_stability_delta (float): Delta value for dynamic multimask stability.
|
||||
dynamic_multimask_stability_thresh (float): Threshold for dynamic multimask stability.
|
||||
pred_obj_scores (bool): Whether to predict object scores.
|
||||
pred_obj_scores_mlp (bool): Whether to use MLP for object score prediction.
|
||||
use_multimask_token_for_obj_ptr (bool): Whether to use multimask token for object pointer.
|
||||
|
||||
Examples:
|
||||
>>> transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=256, nhead=8), num_layers=6)
|
||||
>>> decoder = SAM2MaskDecoder(transformer_dim=256, transformer=transformer)
|
||||
>>> print(decoder)
|
||||
"""
|
||||
super().__init__()
|
||||
self.transformer_dim = transformer_dim
|
||||
self.transformer = transformer
|
||||
|
||||
self.num_multimask_outputs = num_multimask_outputs
|
||||
|
||||
self.iou_token = nn.Embedding(1, transformer_dim)
|
||||
self.num_mask_tokens = num_multimask_outputs + 1
|
||||
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
|
||||
|
||||
self.pred_obj_scores = pred_obj_scores
|
||||
if self.pred_obj_scores:
|
||||
self.obj_score_token = nn.Embedding(1, transformer_dim)
|
||||
self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
|
||||
|
||||
self.output_upscaling = nn.Sequential(
|
||||
nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(transformer_dim // 4),
|
||||
activation(),
|
||||
nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
|
||||
activation(),
|
||||
)
|
||||
self.use_high_res_features = use_high_res_features
|
||||
if use_high_res_features:
|
||||
self.conv_s0 = nn.Conv2d(transformer_dim, transformer_dim // 8, kernel_size=1, stride=1)
|
||||
self.conv_s1 = nn.Conv2d(transformer_dim, transformer_dim // 4, kernel_size=1, stride=1)
|
||||
|
||||
self.output_hypernetworks_mlps = nn.ModuleList(
|
||||
[MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) for _ in range(self.num_mask_tokens)]
|
||||
)
|
||||
|
||||
self.iou_prediction_head = MLP(
|
||||
transformer_dim,
|
||||
iou_head_hidden_dim,
|
||||
self.num_mask_tokens,
|
||||
iou_head_depth,
|
||||
sigmoid=iou_prediction_use_sigmoid,
|
||||
)
|
||||
if self.pred_obj_scores:
|
||||
self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
|
||||
if pred_obj_scores_mlp:
|
||||
self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
|
||||
|
||||
# When outputting a single mask, optionally we can dynamically fall back to the best
|
||||
# multimask output token if the single mask output token gives low stability scores.
|
||||
self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
|
||||
self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
|
||||
self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
multimask_output: bool,
|
||||
repeat_image: bool,
|
||||
high_res_features: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Predict masks given image and prompt embeddings.
|
||||
|
||||
Args:
|
||||
image_embeddings (torch.Tensor): Embeddings from the image encoder with shape (B, C, H, W).
|
||||
image_pe (torch.Tensor): Positional encoding with the shape of image_embeddings (B, C, H, W).
|
||||
sparse_prompt_embeddings (torch.Tensor): Embeddings of the points and boxes with shape (B, N, C).
|
||||
dense_prompt_embeddings (torch.Tensor): Embeddings of the mask inputs with shape (B, C, H, W).
|
||||
multimask_output (bool): Whether to return multiple masks or a single mask.
|
||||
repeat_image (bool): Flag to repeat the image embeddings.
|
||||
high_res_features (list[torch.Tensor] | None, optional): Optional high-resolution features.
|
||||
|
||||
Returns:
|
||||
masks (torch.Tensor): Batched predicted masks with shape (B, N, H, W).
|
||||
iou_pred (torch.Tensor): Batched predictions of mask quality with shape (B, N).
|
||||
sam_tokens_out (torch.Tensor): Batched SAM token for mask output with shape (B, N, C).
|
||||
object_score_logits (torch.Tensor): Batched object score logits with shape (B, 1).
|
||||
|
||||
Examples:
|
||||
>>> image_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> image_pe = torch.rand(1, 256, 64, 64)
|
||||
>>> sparse_prompt_embeddings = torch.rand(1, 2, 256)
|
||||
>>> dense_prompt_embeddings = torch.rand(1, 256, 64, 64)
|
||||
>>> decoder = SAM2MaskDecoder(256, transformer)
|
||||
>>> masks, iou_pred, sam_tokens_out, obj_score_logits = decoder.forward(
|
||||
... image_embeddings, image_pe, sparse_prompt_embeddings, dense_prompt_embeddings, True, False
|
||||
... )
|
||||
"""
|
||||
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
|
||||
image_embeddings=image_embeddings,
|
||||
image_pe=image_pe,
|
||||
sparse_prompt_embeddings=sparse_prompt_embeddings,
|
||||
dense_prompt_embeddings=dense_prompt_embeddings,
|
||||
repeat_image=repeat_image,
|
||||
high_res_features=high_res_features,
|
||||
)
|
||||
|
||||
# Select the correct mask or masks for output
|
||||
if multimask_output:
|
||||
masks = masks[:, 1:, :, :]
|
||||
iou_pred = iou_pred[:, 1:]
|
||||
elif self.dynamic_multimask_via_stability and not self.training:
|
||||
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
|
||||
else:
|
||||
masks = masks[:, 0:1, :, :]
|
||||
iou_pred = iou_pred[:, 0:1]
|
||||
|
||||
if multimask_output and self.use_multimask_token_for_obj_ptr:
|
||||
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
|
||||
else:
|
||||
# Take the mask output token. Here we *always* use the token for single mask output.
|
||||
# At test time, even if we track after 1-click (and using multimask_output=True),
|
||||
# we still take the single mask token here. The rationale is that we always track
|
||||
# after multiple clicks during training, so the past tokens seen during training
|
||||
# are always the single mask token (and we'll let it be the object-memory token).
|
||||
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
|
||||
|
||||
return masks, iou_pred, sam_tokens_out, object_score_logits
|
||||
|
||||
def predict_masks(
|
||||
self,
|
||||
image_embeddings: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
sparse_prompt_embeddings: torch.Tensor,
|
||||
dense_prompt_embeddings: torch.Tensor,
|
||||
repeat_image: bool,
|
||||
high_res_features: list[torch.Tensor] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Predict instance segmentation masks from image and prompt embeddings using a transformer."""
|
||||
# Concatenate output tokens
|
||||
s = 0
|
||||
if self.pred_obj_scores:
|
||||
output_tokens = torch.cat(
|
||||
[
|
||||
self.obj_score_token.weight,
|
||||
self.iou_token.weight,
|
||||
self.mask_tokens.weight,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
s = 1
|
||||
else:
|
||||
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
|
||||
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.shape[0], -1, -1)
|
||||
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
|
||||
|
||||
# Expand per-image data in batch direction to be per-mask
|
||||
if repeat_image:
|
||||
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
|
||||
else:
|
||||
assert image_embeddings.shape[0] == tokens.shape[0]
|
||||
src = image_embeddings
|
||||
src = src + dense_prompt_embeddings
|
||||
assert image_pe.shape[0] == 1, "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
|
||||
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
|
||||
b, c, h, w = src.shape
|
||||
|
||||
# Run the transformer
|
||||
hs, src = self.transformer(src, pos_src, tokens)
|
||||
iou_token_out = hs[:, s, :]
|
||||
mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
|
||||
|
||||
# Upscale mask embeddings and predict masks using the mask tokens
|
||||
src = src.transpose(1, 2).view(b, c, h, w)
|
||||
if not self.use_high_res_features or high_res_features is None:
|
||||
upscaled_embedding = self.output_upscaling(src)
|
||||
else:
|
||||
dc1, ln1, act1, dc2, act2 = self.output_upscaling
|
||||
feat_s0, feat_s1 = high_res_features
|
||||
upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
|
||||
upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
|
||||
|
||||
hyper_in_list: list[torch.Tensor] = [
|
||||
self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) for i in range(self.num_mask_tokens)
|
||||
]
|
||||
hyper_in = torch.stack(hyper_in_list, dim=1)
|
||||
b, c, h, w = upscaled_embedding.shape
|
||||
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
|
||||
|
||||
# Generate mask quality predictions
|
||||
iou_pred = self.iou_prediction_head(iou_token_out)
|
||||
if self.pred_obj_scores:
|
||||
assert s == 1
|
||||
object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
|
||||
else:
|
||||
# Object score logits default to 10.0, i.e. assume the object is present since sigmoid(10) ~= 1
|
||||
object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
|
||||
|
||||
return masks, iou_pred, mask_tokens_out, object_score_logits
|
||||
|
||||
def _get_stability_scores(self, mask_logits):
|
||||
"""Compute mask stability scores based on IoU between upper and lower thresholds."""
|
||||
mask_logits = mask_logits.flatten(-2)
|
||||
stability_delta = self.dynamic_multimask_stability_delta
|
||||
area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
|
||||
area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
|
||||
return torch.where(area_u > 0, area_i / area_u, 1.0)
|
||||
|
||||
def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
|
||||
"""
|
||||
Dynamically select the most stable mask output based on stability scores and IoU predictions.
|
||||
|
||||
This method is used when outputting a single mask. If the stability score from the current single-mask
|
||||
output (based on output token 0) falls below a threshold, it instead selects from multi-mask outputs
|
||||
(based on output tokens 1-3) the mask with the highest predicted IoU score. This ensures a valid mask
|
||||
for both clicking and tracking scenarios.
|
||||
|
||||
Args:
|
||||
all_mask_logits (torch.Tensor): Logits for all predicted masks, shape (B, N, H, W) where B is
|
||||
batch size, N is number of masks (typically 4), and H, W are mask dimensions.
|
||||
all_iou_scores (torch.Tensor): Predicted IoU scores for all masks, shape (B, N).
|
||||
|
||||
Returns:
|
||||
mask_logits_out (torch.Tensor): Selected mask logits, shape (B, 1, H, W).
|
||||
iou_scores_out (torch.Tensor): Selected IoU scores, shape (B, 1).
|
||||
|
||||
Examples:
|
||||
>>> decoder = SAM2MaskDecoder(...)
|
||||
>>> all_mask_logits = torch.rand(2, 4, 256, 256) # 2 images, 4 masks each
|
||||
>>> all_iou_scores = torch.rand(2, 4)
|
||||
>>> mask_logits, iou_scores = decoder._dynamic_multimask_via_stability(all_mask_logits, all_iou_scores)
|
||||
>>> print(mask_logits.shape, iou_scores.shape)
|
||||
torch.Size([2, 1, 256, 256]) torch.Size([2, 1])
|
||||
"""
|
||||
# The best mask from multimask output tokens (1~3)
|
||||
multimask_logits = all_mask_logits[:, 1:, :, :]
|
||||
multimask_iou_scores = all_iou_scores[:, 1:]
|
||||
best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
|
||||
batch_inds = torch.arange(multimask_iou_scores.shape[0], device=all_iou_scores.device)
|
||||
best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
|
||||
best_multimask_logits = best_multimask_logits.unsqueeze(1)
|
||||
best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
|
||||
best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
|
||||
|
||||
# The mask from singlemask output token 0 and its stability score
|
||||
singlemask_logits = all_mask_logits[:, 0:1, :, :]
|
||||
singlemask_iou_scores = all_iou_scores[:, 0:1]
|
||||
stability_scores = self._get_stability_scores(singlemask_logits)
|
||||
is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
|
||||
|
||||
# Dynamically fall back to best multimask output upon low stability scores.
|
||||
mask_logits_out = torch.where(
|
||||
is_stable[..., None, None].expand_as(singlemask_logits),
|
||||
singlemask_logits,
|
||||
best_multimask_logits,
|
||||
)
|
||||
iou_scores_out = torch.where(
|
||||
is_stable.expand_as(singlemask_iou_scores),
|
||||
singlemask_iou_scores,
|
||||
best_multimask_iou_scores,
|
||||
)
|
||||
return mask_logits_out, iou_scores_out
|
||||
851
ultralytics/models/sam/modules/encoders.py
Normal file
851
ultralytics/models/sam/modules/encoders.py
Normal file
@@ -0,0 +1,851 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.nn.modules import LayerNorm2d
|
||||
|
||||
from .blocks import (
|
||||
Block,
|
||||
CXBlock,
|
||||
Fuser,
|
||||
MaskDownSampler,
|
||||
MultiScaleBlock,
|
||||
PatchEmbed,
|
||||
PositionEmbeddingRandom,
|
||||
PositionEmbeddingSine,
|
||||
)
|
||||
|
||||
|
||||
class ImageEncoderViT(nn.Module):
|
||||
"""
|
||||
An image encoder using Vision Transformer (ViT) architecture for encoding images into a compact latent space.
|
||||
|
||||
This class processes images by splitting them into patches, applying transformer blocks, and generating a final
|
||||
encoded representation through a neck module.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Dimension of input images, assumed to be square.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
pos_embed (nn.Parameter | None): Absolute positional embedding for patches.
|
||||
blocks (nn.ModuleList): List of transformer blocks for processing patch embeddings.
|
||||
neck (nn.Sequential): Neck module to further process the output.
|
||||
|
||||
Methods:
|
||||
forward: Process input through patch embedding, positional embedding, blocks, and neck.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
|
||||
>>> input_image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(input_image)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 1024,
|
||||
patch_size: int = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
depth: int = 12,
|
||||
num_heads: int = 12,
|
||||
mlp_ratio: float = 4.0,
|
||||
out_chans: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
act_layer: type[nn.Module] = nn.GELU,
|
||||
use_abs_pos: bool = True,
|
||||
use_rel_pos: bool = False,
|
||||
rel_pos_zero_init: bool = True,
|
||||
window_size: int = 0,
|
||||
global_attn_indexes: tuple[int, ...] = (),
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an ImageEncoderViT instance for encoding images using Vision Transformer architecture.
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size, assumed to be square.
|
||||
patch_size (int): Size of image patches.
|
||||
in_chans (int): Number of input image channels.
|
||||
embed_dim (int): Dimension of patch embeddings.
|
||||
depth (int): Number of transformer blocks.
|
||||
num_heads (int): Number of attention heads in each block.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
out_chans (int): Number of output channels from the neck module.
|
||||
qkv_bias (bool): If True, adds learnable bias to query, key, value projections.
|
||||
norm_layer (Type[nn.Module]): Type of normalization layer to use.
|
||||
act_layer (Type[nn.Module]): Type of activation layer to use.
|
||||
use_abs_pos (bool): If True, uses absolute positional embeddings.
|
||||
use_rel_pos (bool): If True, adds relative positional embeddings to attention maps.
|
||||
rel_pos_zero_init (bool): If True, initializes relative positional parameters to zero.
|
||||
window_size (int): Size of attention window for windowed attention blocks.
|
||||
global_attn_indexes (tuple[int, ...]): Indices of blocks that use global attention.
|
||||
|
||||
Examples:
|
||||
>>> encoder = ImageEncoderViT(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
|
||||
>>> input_image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(input_image)
|
||||
>>> print(output.shape)
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
kernel_size=(patch_size, patch_size),
|
||||
stride=(patch_size, patch_size),
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
|
||||
self.pos_embed: nn.Parameter | None = None
|
||||
if use_abs_pos:
|
||||
# Initialize absolute positional embedding with pretrain image size
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim))
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for i in range(depth):
|
||||
block = Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
use_rel_pos=use_rel_pos,
|
||||
rel_pos_zero_init=rel_pos_zero_init,
|
||||
window_size=window_size if i not in global_attn_indexes else 0,
|
||||
input_size=(img_size // patch_size, img_size // patch_size),
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dim,
|
||||
out_chans,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
nn.Conv2d(
|
||||
out_chans,
|
||||
out_chans,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(out_chans),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Process input through patch embedding, positional embedding, transformer blocks, and neck module."""
    x = self.patch_embed(x)
    if self.pos_embed is not None:
        pos_embed = (
            F.interpolate(self.pos_embed.permute(0, 3, 1, 2), scale_factor=self.img_size / 1024).permute(0, 2, 3, 1)
            if self.img_size != 1024
            else self.pos_embed
        )
        x = x + pos_embed
    for blk in self.blocks:
        x = blk(x)
    return self.neck(x.permute(0, 3, 1, 2))
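# Shape sketch (an illustrative note, assuming the default configuration): a (1, 3, 1024, 1024) input is
# patch-embedded to (1, 64, 64, 768), kept at that shape by the transformer blocks, and projected by the
# neck to (1, 256, 64, 64). The pos_embed rescaling above assumes the embedding was built for the
# 1024-pixel pretraining size.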
|
||||
|
||||
|
||||
class PromptEncoder(nn.Module):
|
||||
"""
|
||||
Encode different types of prompts for input to SAM's mask decoder, producing sparse and dense embeddings.
|
||||
|
||||
Attributes:
|
||||
embed_dim (int): Dimension of the embeddings.
|
||||
input_image_size (tuple[int, int]): Size of the input image as (H, W).
|
||||
image_embedding_size (tuple[int, int]): Spatial size of the image embedding as (H, W).
|
||||
pe_layer (PositionEmbeddingRandom): Module for random position embedding.
|
||||
num_point_embeddings (int): Number of point embeddings for different types of points.
|
||||
point_embeddings (nn.ModuleList): List of point embeddings.
|
||||
not_a_point_embed (nn.Embedding): Embedding for points that are not part of any label.
|
||||
mask_input_size (tuple[int, int]): Size of the input mask.
|
||||
mask_downscaling (nn.Sequential): Neural network for downscaling the mask.
|
||||
no_mask_embed (nn.Embedding): Embedding for cases where no mask is provided.
|
||||
|
||||
Methods:
|
||||
get_dense_pe: Return the positional encoding used to encode point prompts.
|
||||
forward: Embed different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
|
||||
>>> boxes = torch.rand(1, 2, 2)
|
||||
>>> masks = torch.rand(1, 1, 256, 256)
|
||||
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
|
||||
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
|
||||
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
image_embedding_size: tuple[int, int],
|
||||
input_image_size: tuple[int, int],
|
||||
mask_in_chans: int,
|
||||
activation: type[nn.Module] = nn.GELU,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the PromptEncoder module for encoding various types of prompts.
|
||||
|
||||
Args:
|
||||
embed_dim (int): The dimension of the embeddings.
|
||||
image_embedding_size (tuple[int, int]): The spatial size of the image embedding as (H, W).
|
||||
input_image_size (tuple[int, int]): The padded size of the input image as (H, W).
|
||||
mask_in_chans (int): The number of hidden channels used for encoding input masks.
|
||||
activation (Type[nn.Module]): The activation function to use when encoding input masks.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
|
||||
>>> boxes = torch.rand(1, 2, 2)
|
||||
>>> masks = torch.rand(1, 1, 256, 256)
|
||||
>>> sparse_embeddings, dense_embeddings = prompt_encoder(points, boxes, masks)
|
||||
>>> print(sparse_embeddings.shape, dense_embeddings.shape)
|
||||
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.input_image_size = input_image_size
|
||||
self.image_embedding_size = image_embedding_size
|
||||
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
|
||||
|
||||
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
|
||||
point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
|
||||
self.point_embeddings = nn.ModuleList(point_embeddings)
|
||||
self.not_a_point_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
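# Mask prompts are expected at 4x the image-embedding resolution (256x256 for a 64x64 embedding);
# the two stride-2 convs in mask_downscaling below bring them back down to the embedding grid.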
|
||||
self.mask_downscaling = nn.Sequential(
|
||||
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans // 4),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
|
||||
LayerNorm2d(mask_in_chans),
|
||||
activation(),
|
||||
nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
|
||||
)
|
||||
self.no_mask_embed = nn.Embedding(1, embed_dim)
|
||||
|
||||
def get_dense_pe(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the dense positional encoding used for encoding point prompts.
|
||||
|
||||
Generate a positional encoding for a dense set of points matching the shape of the image
|
||||
encoding. The encoding is used to provide spatial information to the model when processing point prompts.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Positional encoding tensor with shape (1, embed_dim, H, W), where H and W are the
|
||||
height and width of the image embedding size, respectively.
|
||||
|
||||
Examples:
|
||||
>>> prompt_encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> dense_pe = prompt_encoder.get_dense_pe()
|
||||
>>> print(dense_pe.shape)
|
||||
torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
return self.pe_layer(self.image_embedding_size).unsqueeze(0)
|
||||
|
||||
def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
|
||||
"""Embed point prompts by applying positional encoding and label-specific embeddings."""
|
||||
points = points + 0.5 # Shift to center of pixel
|
||||
if pad:
|
||||
padding_point = torch.zeros((points.shape[0], 1, 2), dtype=points.dtype, device=points.device)
|
||||
padding_label = -torch.ones((labels.shape[0], 1), dtype=labels.dtype, device=labels.device)
|
||||
points = torch.cat([points, padding_point], dim=1)
|
||||
labels = torch.cat([labels, padding_label], dim=1)
|
||||
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
|
||||
point_embedding[labels == -1] = 0.0
|
||||
point_embedding[labels == -1] += self.not_a_point_embed.weight
|
||||
point_embedding[labels == 0] += self.point_embeddings[0].weight
|
||||
point_embedding[labels == 1] += self.point_embeddings[1].weight
|
||||
point_embedding[labels == 2] += self.point_embeddings[2].weight
|
||||
point_embedding[labels == 3] += self.point_embeddings[3].weight
|
||||
return point_embedding
|
||||
|
||||
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
|
||||
"""Embed box prompts by applying positional encoding and adding corner embeddings."""
|
||||
boxes = boxes + 0.5 # Shift to center of pixel
|
||||
coords = boxes.reshape(-1, 2, 2)
|
||||
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
|
||||
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
|
||||
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
|
||||
return corner_embedding
|
||||
|
||||
def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
|
||||
"""Embed mask inputs by downscaling and processing through convolutional layers."""
|
||||
return self.mask_downscaling(masks)
|
||||
|
||||
@staticmethod
|
||||
def _get_batch_size(
|
||||
points: tuple[torch.Tensor, torch.Tensor] | None,
|
||||
boxes: torch.Tensor | None,
|
||||
masks: torch.Tensor | None,
|
||||
) -> int:
|
||||
"""Get the batch size of the output given the batch size of the input prompts."""
|
||||
if points is not None:
|
||||
return points[0].shape[0]
|
||||
elif boxes is not None:
|
||||
return boxes.shape[0]
|
||||
elif masks is not None:
|
||||
return masks.shape[0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
points: tuple[torch.Tensor, torch.Tensor] | None,
|
||||
boxes: torch.Tensor | None,
|
||||
masks: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Embed different types of prompts, returning both sparse and dense embeddings.
|
||||
|
||||
Args:
|
||||
points (tuple[torch.Tensor, torch.Tensor] | None): Point coordinates and labels to embed. The first
|
||||
tensor contains coordinates with shape (B, N, 2), and the second tensor contains labels with
|
||||
shape (B, N).
|
||||
boxes (torch.Tensor | None): Boxes to embed with shape (B, M, 2, 2), where M is the number of boxes.
|
||||
masks (torch.Tensor | None): Masks to embed with shape (B, 1, H, W).
|
||||
|
||||
Returns:
|
||||
sparse_embeddings (torch.Tensor): Sparse embeddings for points and boxes with shape (B, N, embed_dim).
|
||||
dense_embeddings (torch.Tensor): Dense embeddings for masks of shape (B, embed_dim, embed_H, embed_W).
|
||||
|
||||
Examples:
|
||||
>>> encoder = PromptEncoder(256, (64, 64), (1024, 1024), 16)
|
||||
>>> points = (torch.rand(1, 5, 2), torch.randint(0, 4, (1, 5)))
|
||||
>>> boxes = torch.rand(1, 2, 2, 2)
|
||||
>>> masks = torch.rand(1, 1, 256, 256)
|
||||
>>> sparse_emb, dense_emb = encoder(points, boxes, masks)
|
||||
>>> print(sparse_emb.shape, dense_emb.shape)
|
||||
torch.Size([1, 7, 256]) torch.Size([1, 256, 64, 64])
|
||||
"""
|
||||
bs = self._get_batch_size(points, boxes, masks)
|
||||
sparse_embeddings = torch.empty(
|
||||
(bs, 0, self.embed_dim),
|
||||
dtype=self.point_embeddings[0].weight.dtype,
|
||||
device=self.point_embeddings[0].weight.device,
|
||||
)
|
||||
if points is not None:
|
||||
coords, labels = points
|
||||
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
|
||||
if boxes is not None:
|
||||
box_embeddings = self._embed_boxes(boxes)
|
||||
sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
|
||||
|
||||
if masks is not None:
|
||||
dense_embeddings = self._embed_masks(masks)
|
||||
else:
|
||||
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
|
||||
bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
|
||||
)
|
||||
|
||||
return sparse_embeddings, dense_embeddings
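# Note: sparse embeddings carry one embed_dim-wide token per point or box corner, while dense embeddings
# always cover the full image_embedding_size grid, with no_mask_embed broadcast when no mask prompt is given.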
|
||||
|
||||
|
||||
class MemoryEncoder(nn.Module):
|
||||
"""
|
||||
Encode pixel features and masks into a memory representation for efficient image segmentation.
|
||||
|
||||
This class processes pixel-level features and masks, fusing them to generate encoded memory representations
|
||||
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
|
||||
|
||||
Attributes:
|
||||
mask_downsampler (MaskDownSampler): Module for downsampling input masks.
|
||||
pix_feat_proj (nn.Conv2d): Convolutional layer for projecting pixel features.
|
||||
fuser (Fuser): Module for fusing pixel features and masks.
|
||||
position_encoding (PositionEmbeddingSine): Module for adding positional encoding to features.
|
||||
out_proj (nn.Module): Output projection layer, either nn.Identity or nn.Conv2d.
|
||||
|
||||
Methods:
|
||||
forward: Process input pixel features and masks to generate encoded memory representations.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
|
||||
>>> pix_feat = torch.randn(1, 256, 64, 64)
|
||||
>>> masks = torch.randn(1, 1, 64, 64)
|
||||
>>> encoded_feat, pos = encoder(pix_feat, masks)
|
||||
>>> print(encoded_feat.shape, pos.shape)
|
||||
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_dim,
|
||||
in_dim=256, # in_dim of pix_feats
|
||||
):
|
||||
"""
|
||||
Initialize the MemoryEncoder for encoding pixel features and masks into memory representations.
|
||||
|
||||
This encoder processes pixel-level features and masks, fusing them to generate encoded memory representations
|
||||
suitable for downstream tasks in image segmentation models like SAM (Segment Anything Model).
|
||||
|
||||
Args:
|
||||
out_dim (int): Output dimension of the encoded features.
|
||||
in_dim (int): Input dimension of the pixel features.
|
||||
|
||||
Examples:
|
||||
>>> encoder = MemoryEncoder(out_dim=256, in_dim=256)
|
||||
>>> pix_feat = torch.randn(1, 256, 64, 64)
|
||||
>>> masks = torch.randn(1, 1, 64, 64)
|
||||
>>> encoded_feat, pos = encoder(pix_feat, masks)
|
||||
>>> print(encoded_feat.shape, pos.shape)
|
||||
torch.Size([1, 256, 64, 64]) torch.Size([1, 128, 64, 64])
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.mask_downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)
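# MaskDownSampler is expected to bring the (sigmoided) mask down to the pixel-feature resolution and
# feature width, so that the `x = x + masks` fusion in forward() adds tensors of matching shape.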
|
||||
|
||||
self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
|
||||
self.fuser = Fuser(CXBlock(dim=256), num_layers=2)
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=64)
|
||||
self.out_proj = nn.Identity()
|
||||
if out_dim != in_dim:
|
||||
self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
|
||||
|
||||
def forward(
    self,
    pix_feat: torch.Tensor,
    masks: torch.Tensor,
    skip_mask_sigmoid: bool = False,
) -> dict:
    """Process pixel features and masks to generate encoded memory representations for segmentation."""
    if not skip_mask_sigmoid:
        masks = F.sigmoid(masks)
    masks = self.mask_downsampler(masks)

    # Fuse pix_feats and the downsampled masks; if the visual features live on a different device
    # (e.g. CPU), move them to the masks' device first
    pix_feat = pix_feat.to(masks.device)

    x = self.pix_feat_proj(pix_feat)
    x = x + masks
    x = self.fuser(x)
    x = self.out_proj(x)

    pos = self.position_encoding(x).to(x.dtype)

    return {"vision_features": x, "vision_pos_enc": [pos]}
|
||||
|
||||
|
||||
class ImageEncoder(nn.Module):
|
||||
"""
|
||||
Encode images using a trunk-neck architecture, producing multiscale features and positional encodings.
|
||||
|
||||
This class combines a trunk network for feature extraction with a neck network for feature refinement
|
||||
and positional encoding generation. It can optionally discard the lowest resolution features.
|
||||
|
||||
Attributes:
|
||||
trunk (nn.Module): The trunk network for initial feature extraction.
|
||||
neck (nn.Module): The neck network for feature refinement and positional encoding generation.
|
||||
scalp (int): Number of lowest resolution feature levels to discard.
|
||||
|
||||
Methods:
|
||||
forward: Process the input image through the trunk and neck networks.
|
||||
|
||||
Examples:
|
||||
>>> trunk = SomeTrunkNetwork()
|
||||
>>> neck = SomeNeckNetwork()
|
||||
>>> encoder = ImageEncoder(trunk, neck, scalp=1)
|
||||
>>> image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(image)
|
||||
>>> print(output.keys())
|
||||
dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trunk: nn.Module,
|
||||
neck: nn.Module,
|
||||
scalp: int = 0,
|
||||
):
|
||||
"""
|
||||
Initialize the ImageEncoder with trunk and neck networks for feature extraction and refinement.
|
||||
|
||||
This encoder combines a trunk network for feature extraction with a neck network for feature refinement
|
||||
and positional encoding generation. It can optionally discard the lowest resolution features.
|
||||
|
||||
Args:
|
||||
trunk (nn.Module): The trunk network for initial feature extraction.
|
||||
neck (nn.Module): The neck network for feature refinement and positional encoding generation.
|
||||
scalp (int): Number of lowest resolution feature levels to discard.
|
||||
|
||||
Examples:
|
||||
>>> trunk = SomeTrunkNetwork()
|
||||
>>> neck = SomeNeckNetwork()
|
||||
>>> encoder = ImageEncoder(trunk, neck, scalp=1)
|
||||
>>> image = torch.randn(1, 3, 224, 224)
|
||||
>>> output = encoder(image)
|
||||
>>> print(output.keys())
|
||||
dict_keys(['vision_features', 'vision_pos_enc', 'backbone_fpn'])
|
||||
"""
|
||||
super().__init__()
|
||||
self.trunk = trunk
|
||||
self.neck = neck
|
||||
self.scalp = scalp
|
||||
assert self.trunk.channel_list == self.neck.backbone_channel_list, (
|
||||
f"Channel dims of trunk {self.trunk.channel_list} and neck {self.neck.backbone_channel_list} do not match."
|
||||
)
|
||||
|
||||
def forward(self, sample: torch.Tensor):
    """Encode input through trunk and neck networks, returning multiscale features and positional encodings."""
    features, pos = self.neck(self.trunk(sample))
    if self.scalp > 0:
        # Discard the lowest resolution features
        features, pos = features[: -self.scalp], pos[: -self.scalp]

    src = features[-1]
    return {
        "vision_features": src,
        "vision_pos_enc": pos,
        "backbone_fpn": features,
    }
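# Note: 'vision_features' is the coarsest level remaining after scalping, while 'backbone_fpn' keeps the
# full multiscale list for decoders that also want the finer maps.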
|
||||
|
||||
|
||||
class FpnNeck(nn.Module):
|
||||
"""
|
||||
A Feature Pyramid Network (FPN) neck variant for multiscale feature fusion in object detection models.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
|
||||
similar to ViT positional embedding interpolation.
|
||||
|
||||
Attributes:
|
||||
position_encoding (PositionEmbeddingSine): Sinusoidal positional encoding module.
|
||||
convs (nn.ModuleList): List of convolutional layers for each backbone level.
|
||||
backbone_channel_list (list[int]): List of channel dimensions from the backbone.
|
||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
||||
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
|
||||
fpn_top_down_levels (list[int]): Levels to have top-down features in outputs.
|
||||
|
||||
Methods:
|
||||
forward: Perform forward pass through the FPN neck.
|
||||
|
||||
Examples:
|
||||
>>> backbone_channels = [64, 128, 256, 512]
|
||||
>>> fpn_neck = FpnNeck(256, backbone_channels)
|
||||
>>> inputs = [torch.rand(1, c, 32, 32) for c in backbone_channels]
|
||||
>>> outputs, positions = fpn_neck(inputs)
|
||||
>>> print(len(outputs), len(positions))
|
||||
4 4
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
backbone_channel_list: list[int],
|
||||
kernel_size: int = 1,
|
||||
stride: int = 1,
|
||||
padding: int = 0,
|
||||
fpn_interp_model: str = "bilinear",
|
||||
fuse_type: str = "sum",
|
||||
fpn_top_down_levels: list[int] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize a modified Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This FPN variant removes the output convolution and uses bicubic interpolation for feature resizing,
|
||||
similar to ViT positional embedding interpolation.
|
||||
|
||||
Args:
|
||||
d_model (int): Dimension of the model.
|
||||
backbone_channel_list (list[int]): List of channel dimensions from the backbone.
|
||||
kernel_size (int): Kernel size for the convolutional layers.
|
||||
stride (int): Stride for the convolutional layers.
|
||||
padding (int): Padding for the convolutional layers.
|
||||
fpn_interp_model (str): Interpolation mode for FPN feature resizing.
|
||||
fuse_type (str): Type of feature fusion, either 'sum' or 'avg'.
|
||||
fpn_top_down_levels (Optional[list[int]]): Levels to have top-down features in outputs.
|
||||
|
||||
Examples:
|
||||
>>> backbone_channels = [64, 128, 256, 512]
|
||||
>>> fpn_neck = FpnNeck(256, backbone_channels)
|
||||
>>> print(fpn_neck)
|
||||
"""
|
||||
super().__init__()
|
||||
self.position_encoding = PositionEmbeddingSine(num_pos_feats=256)
|
||||
self.convs = nn.ModuleList()
|
||||
self.backbone_channel_list = backbone_channel_list
|
||||
for dim in backbone_channel_list:
|
||||
current = nn.Sequential()
|
||||
current.add_module(
|
||||
"conv",
|
||||
nn.Conv2d(
|
||||
in_channels=dim,
|
||||
out_channels=d_model,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
),
|
||||
)
|
||||
|
||||
self.convs.append(current)
|
||||
self.fpn_interp_model = fpn_interp_model
|
||||
assert fuse_type in {"sum", "avg"}
|
||||
self.fuse_type = fuse_type
|
||||
|
||||
# Levels to have top-down features in its outputs
|
||||
# e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
|
||||
# have top-down propagation, while outputs of level 0 and level 1 have only
|
||||
# lateral features from the same backbone level
|
||||
if fpn_top_down_levels is None:
|
||||
# Default is to have top-down features on all levels
|
||||
fpn_top_down_levels = range(len(self.convs))
|
||||
self.fpn_top_down_levels = list(fpn_top_down_levels)
|
||||
|
||||
def forward(self, xs: list[torch.Tensor]):
|
||||
"""
|
||||
Perform forward pass through the Feature Pyramid Network (FPN) neck.
|
||||
|
||||
This method processes a list of input tensors from the backbone through the FPN, applying lateral connections
|
||||
and top-down feature fusion. It generates output feature maps and corresponding positional encodings.
|
||||
|
||||
Args:
|
||||
xs (list[torch.Tensor]): List of input tensors from the backbone, each with shape (B, C, H, W).
|
||||
|
||||
Returns:
|
||||
out (list[torch.Tensor]): List of output feature maps after FPN processing, each with shape
|
||||
(B, d_model, H, W).
|
||||
pos (list[torch.Tensor]): List of positional encodings corresponding to each output feature map.
|
||||
|
||||
Examples:
|
||||
>>> fpn_neck = FpnNeck(d_model=256, backbone_channel_list=[64, 128, 256, 512])
|
||||
>>> inputs = [torch.rand(1, c, 32, 32) for c in [64, 128, 256, 512]]
|
||||
>>> outputs, positions = fpn_neck(inputs)
|
||||
>>> print(len(outputs), len(positions))
|
||||
4 4
|
||||
"""
|
||||
out = [None] * len(self.convs)
|
||||
pos = [None] * len(self.convs)
|
||||
assert len(xs) == len(self.convs)
|
||||
# FPN forward pass
|
||||
# see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
|
||||
prev_features = None
|
||||
# Forward in top-down order (from low to high resolution)
|
||||
n = len(self.convs) - 1
|
||||
for i in range(n, -1, -1):
|
||||
x = xs[i]
|
||||
lateral_features = self.convs[n - i](x)
|
||||
if i in self.fpn_top_down_levels and prev_features is not None:
|
||||
top_down_features = F.interpolate(
|
||||
prev_features.to(dtype=x.dtype),
|
||||
scale_factor=2.0,
|
||||
mode=self.fpn_interp_model,
|
||||
align_corners=(None if self.fpn_interp_model == "nearest" else False),
|
||||
antialias=False,
|
||||
)
|
||||
prev_features = lateral_features + top_down_features
|
||||
if self.fuse_type == "avg":
|
||||
prev_features /= 2
|
||||
else:
|
||||
prev_features = lateral_features
|
||||
x_out = prev_features
|
||||
out[i] = x_out
|
||||
pos[i] = self.position_encoding(x_out).to(x_out.dtype)
|
||||
|
||||
return out, pos
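# Note on the loop above: each backbone level is projected to d_model by a lateral conv (1x1 by default);
# the loop then walks from the coarsest to the finest level, upsampling the running top-down feature by 2x
# (bilinear by default) and adding it to the lateral projection at every level listed in
# fpn_top_down_levels except the coarsest, with 'avg' fusion halving the sum instead of keeping it.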
|
||||
|
||||
|
||||
class Hiera(nn.Module):
|
||||
"""
|
||||
Hierarchical vision transformer for efficient multiscale feature extraction in image processing tasks.
|
||||
|
||||
This class implements a Hiera model, which is a hierarchical vision transformer architecture designed for
|
||||
efficient multiscale feature extraction. It uses a series of transformer blocks organized into stages,
|
||||
with optional pooling and global attention mechanisms.
|
||||
|
||||
Attributes:
|
||||
window_spec (tuple[int, ...]): Window sizes for each stage.
|
||||
q_stride (tuple[int, int]): Downsampling stride between stages.
|
||||
stage_ends (list[int]): Indices of the last block in each stage.
|
||||
q_pool_blocks (list[int]): Indices of blocks where pooling is applied.
|
||||
return_interm_layers (bool): Whether to return intermediate layer outputs.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
global_att_blocks (tuple[int, ...]): Indices of blocks with global attention.
|
||||
window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding background.
|
||||
pos_embed (nn.Parameter): Positional embedding for the background.
|
||||
pos_embed_window (nn.Parameter): Positional embedding for the window.
|
||||
blocks (nn.ModuleList): List of MultiScaleBlock modules.
|
||||
channel_list (list[int]): List of output channel dimensions for each stage.
|
||||
|
||||
Methods:
|
||||
_get_pos_embed: Generate positional embeddings by interpolating and combining window and background embeddings.
|
||||
forward: Perform the forward pass through the Hiera model.
|
||||
|
||||
Examples:
|
||||
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
|
||||
>>> input_tensor = torch.randn(1, 3, 224, 224)
|
||||
>>> output_features = model(input_tensor)
|
||||
>>> for feat in output_features:
|
||||
... print(feat.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int = 96, # initial embed dim
|
||||
num_heads: int = 1, # initial number of heads
|
||||
drop_path_rate: float = 0.0, # stochastic depth
|
||||
q_pool: int = 3, # number of q_pool stages
|
||||
q_stride: tuple[int, int] = (2, 2), # downsample stride bet. stages
|
||||
stages: tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
|
||||
dim_mul: float = 2.0, # dim_mul factor at stage shift
|
||||
head_mul: float = 2.0, # head_mul factor at stage shift
|
||||
window_pos_embed_bkg_spatial_size: tuple[int, int] = (14, 14),
|
||||
# window size per stage, when not using global att.
|
||||
window_spec: tuple[int, ...] = (
|
||||
8,
|
||||
4,
|
||||
14,
|
||||
7,
|
||||
),
|
||||
# global attn in these blocks
|
||||
global_att_blocks: tuple[int, ...] = (
|
||||
12,
|
||||
16,
|
||||
20,
|
||||
),
|
||||
return_interm_layers=True, # return feats from every stage
|
||||
):
|
||||
"""
|
||||
Initialize a Hiera model, a hierarchical vision transformer for efficient multiscale feature extraction.
|
||||
|
||||
Hiera is a hierarchical vision transformer architecture designed for efficient multiscale feature extraction
|
||||
in image processing tasks. It uses a series of transformer blocks organized into stages, with optional
|
||||
pooling and global attention mechanisms.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Initial embedding dimension for the model.
|
||||
num_heads (int): Initial number of attention heads.
|
||||
drop_path_rate (float): Stochastic depth rate.
|
||||
q_pool (int): Number of query pooling stages.
|
||||
q_stride (tuple[int, int]): Downsampling stride between stages.
|
||||
stages (tuple[int, ...]): Number of blocks per stage.
|
||||
dim_mul (float): Dimension multiplier factor at stage transitions.
|
||||
head_mul (float): Head multiplier factor at stage transitions.
|
||||
window_pos_embed_bkg_spatial_size (tuple[int, int]): Spatial size for window positional embedding background.
|
||||
window_spec (tuple[int, ...]): Window sizes for each stage when not using global attention.
|
||||
global_att_blocks (tuple[int, ...]): Indices of blocks that use global attention.
|
||||
return_interm_layers (bool): Whether to return intermediate layer outputs.
|
||||
|
||||
Examples:
|
||||
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
|
||||
>>> input_tensor = torch.randn(1, 3, 224, 224)
|
||||
>>> output_features = model(input_tensor)
|
||||
>>> for feat in output_features:
|
||||
... print(feat.shape)
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert len(stages) == len(window_spec)
|
||||
self.window_spec = window_spec
|
||||
|
||||
depth = sum(stages)
|
||||
self.q_stride = q_stride
|
||||
self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
|
||||
assert 0 <= q_pool <= len(self.stage_ends[:-1])
|
||||
self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
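# Example with the default stages=(2, 3, 16, 3): stage_ends = [1, 4, 20, 23] and
# q_pool_blocks = [2, 5, 21], i.e. spatial pooling happens in the first block of each later stage.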
|
||||
self.return_interm_layers = return_interm_layers
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
embed_dim=embed_dim,
|
||||
kernel_size=(7, 7),
|
||||
stride=(4, 4),
|
||||
padding=(3, 3),
|
||||
)
|
||||
# Which blocks have global attention?
|
||||
self.global_att_blocks = global_att_blocks
|
||||
|
||||
# Windowed positional embedding (https://arxiv.org/abs/2311.05613)
|
||||
self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size))
|
||||
self.pos_embed_window = nn.Parameter(torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]))
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
cur_stage = 1
|
||||
self.blocks = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
dim_out = embed_dim
|
||||
# Lags by a block, so first block of next stage uses an initial window size
|
||||
# of previous stage and final window size of current stage
|
||||
window_size = self.window_spec[cur_stage - 1]
|
||||
|
||||
if self.global_att_blocks is not None:
|
||||
window_size = 0 if i in self.global_att_blocks else window_size
|
||||
|
||||
if i - 1 in self.stage_ends:
|
||||
dim_out = int(embed_dim * dim_mul)
|
||||
num_heads = int(num_heads * head_mul)
|
||||
cur_stage += 1
|
||||
|
||||
block = MultiScaleBlock(
|
||||
dim=embed_dim,
|
||||
dim_out=dim_out,
|
||||
num_heads=num_heads,
|
||||
drop_path=dpr[i],
|
||||
q_stride=self.q_stride if i in self.q_pool_blocks else None,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
embed_dim = dim_out
|
||||
self.blocks.append(block)
|
||||
|
||||
self.channel_list = (
|
||||
[self.blocks[i].dim_out for i in self.stage_ends[::-1]]
|
||||
if return_interm_layers
|
||||
else [self.blocks[-1].dim_out]
|
||||
)
|
||||
|
||||
def _get_pos_embed(self, hw: tuple[int, int]) -> torch.Tensor:
|
||||
"""Generate positional embeddings by interpolating and combining window and background embeddings."""
|
||||
h, w = hw
|
||||
window_embed = self.pos_embed_window
|
||||
pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
|
||||
pos_embed = pos_embed + window_embed.tile([x // y for x, y in zip(pos_embed.shape, window_embed.shape)])
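# The background embedding is interpolated to the full (h, w) grid, then the small window embedding is
# tiled across it; note the tiling assumes h and w are multiples of window_spec[0].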
|
||||
pos_embed = pos_embed.permute(0, 2, 3, 1)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
Perform forward pass through Hiera model, extracting multiscale features from input images.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape (B, C, H, W) representing a batch of images.
|
||||
|
||||
Returns:
|
||||
(list[torch.Tensor]): List of feature maps at different scales, each with shape (B, C_i, H_i, W_i), where
|
||||
C_i is the channel dimension and H_i, W_i are the spatial dimensions at scale i. The list is ordered
|
||||
from highest resolution (fine features) to lowest resolution (coarse features) if return_interm_layers
|
||||
is True, otherwise contains only the final output.
|
||||
|
||||
Examples:
|
||||
>>> model = Hiera(embed_dim=96, num_heads=1, stages=(2, 3, 16, 3))
|
||||
>>> input_tensor = torch.randn(1, 3, 224, 224)
|
||||
>>> output_features = model(input_tensor)
|
||||
>>> for feat in output_features:
|
||||
... print(feat.shape)
|
||||
"""
|
||||
x = self.patch_embed(x)
|
||||
# x: (B, H, W, C)
|
||||
|
||||
# Add positional embedding
|
||||
x = x + self._get_pos_embed(x.shape[1:3])
|
||||
|
||||
outputs = []
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if (i == self.stage_ends[-1]) or (i in self.stage_ends and self.return_interm_layers):
|
||||
feats = x.permute(0, 3, 1, 2)
|
||||
outputs.append(feats)
|
||||
|
||||
return outputs
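# Shape sketch (an illustrative note, assuming the defaults and a 1024x1024 input): the four returned maps
# are (B, 96, 256, 256), (B, 192, 128, 128), (B, 384, 64, 64) and (B, 768, 32, 32), ordered fine to coarse.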
|
||||
312
ultralytics/models/sam/modules/memory_attention.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .blocks import RoPEAttention
|
||||
|
||||
|
||||
class MemoryAttentionLayer(nn.Module):
|
||||
"""
|
||||
Implements a memory attention layer with self-attention and cross-attention mechanisms for neural networks.
|
||||
|
||||
This class combines self-attention, cross-attention, and feedforward components to process input tensors and
|
||||
generate memory-based attention outputs.
|
||||
|
||||
Attributes:
|
||||
d_model (int): Dimensionality of the model.
|
||||
dim_feedforward (int): Dimensionality of the feedforward network.
|
||||
dropout_value (float): Dropout rate for regularization.
|
||||
self_attn (RoPEAttention): Self-attention mechanism using RoPE (Rotary Position Embedding).
|
||||
cross_attn_image (RoPEAttention): Cross-attention mechanism for image processing.
|
||||
linear1 (nn.Linear): First linear layer of the feedforward network.
|
||||
linear2 (nn.Linear): Second linear layer of the feedforward network.
|
||||
norm1 (nn.LayerNorm): Layer normalization for self-attention output.
|
||||
norm2 (nn.LayerNorm): Layer normalization for cross-attention output.
|
||||
norm3 (nn.LayerNorm): Layer normalization for feedforward network output.
|
||||
dropout1 (nn.Dropout): Dropout layer after self-attention.
|
||||
dropout2 (nn.Dropout): Dropout layer after cross-attention.
|
||||
dropout3 (nn.Dropout): Dropout layer after feedforward network.
|
||||
activation (nn.ReLU): Activation function for the feedforward network.
|
||||
pos_enc_at_attn (bool): Flag to add positional encoding at attention.
|
||||
pos_enc_at_cross_attn_queries (bool): Flag to add positional encoding to cross-attention queries.
|
||||
pos_enc_at_cross_attn_keys (bool): Flag to add positional encoding to cross-attention keys.
|
||||
|
||||
Methods:
|
||||
forward: Performs the full memory attention operation on input tensors.
|
||||
_forward_sa: Performs self-attention on input tensor.
|
||||
_forward_ca: Performs cross-attention between target and memory tensors.
|
||||
|
||||
Examples:
|
||||
>>> layer = MemoryAttentionLayer(d_model=256, dim_feedforward=2048, dropout=0.1)
|
||||
>>> tgt = torch.randn(1, 100, 256)
|
||||
>>> memory = torch.randn(1, 100, 64)
|
||||
>>> pos = torch.randn(1, 100, 256)
|
||||
>>> query_pos = torch.randn(1, 100, 256)
|
||||
>>> output = layer(tgt, memory, pos, query_pos)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 100, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int = 256,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
pos_enc_at_attn: bool = False,
|
||||
pos_enc_at_cross_attn_keys: bool = True,
|
||||
pos_enc_at_cross_attn_queries: bool = False,
|
||||
):
|
||||
"""
|
||||
Initialize a memory attention layer with self-attention, cross-attention, and feedforward components.
|
||||
|
||||
Args:
|
||||
d_model (int): Dimensionality of the model.
|
||||
dim_feedforward (int): Dimensionality of the feedforward network.
|
||||
dropout (float): Dropout rate for regularization.
|
||||
pos_enc_at_attn (bool): Whether to add positional encoding at attention.
|
||||
pos_enc_at_cross_attn_keys (bool): Whether to add positional encoding to cross-attention keys.
|
||||
pos_enc_at_cross_attn_queries (bool): Whether to add positional encoding to cross-attention queries.
|
||||
"""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.dim_feedforward = dim_feedforward
|
||||
self.dropout_value = dropout
|
||||
self.self_attn = RoPEAttention(embedding_dim=256, num_heads=1, downsample_rate=1)
|
||||
self.cross_attn_image = RoPEAttention(
|
||||
rope_k_repeat=True,
|
||||
embedding_dim=256,
|
||||
num_heads=1,
|
||||
downsample_rate=1,
|
||||
kv_in_dim=64,
|
||||
)
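# Note: kv_in_dim=64 matches the 64-channel memory features shown in the Examples above, while queries
# remain at the 256-dim model width; rope_k_repeat is needed because the memory sequence is longer than
# the query sequence.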
|
||||
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = nn.ReLU()
|
||||
|
||||
# Where to add pos enc
|
||||
self.pos_enc_at_attn = pos_enc_at_attn
|
||||
self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
|
||||
self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
|
||||
|
||||
def _forward_sa(self, tgt: torch.Tensor, query_pos: torch.Tensor | None) -> torch.Tensor:
|
||||
"""Perform self-attention on input tensor using positional encoding and RoPE attention mechanism."""
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
|
||||
tgt2 = self.self_attn(q, k, v=tgt2)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
return tgt
|
||||
|
||||
def _forward_ca(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
query_pos: torch.Tensor | None,
|
||||
pos: torch.Tensor | None,
|
||||
num_k_exclude_rope: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""Perform cross-attention between target and memory tensors using RoPEAttention mechanism."""
|
||||
kwds = {}
|
||||
if num_k_exclude_rope > 0:
|
||||
assert isinstance(self.cross_attn_image, RoPEAttention)
|
||||
kwds = {"num_k_exclude_rope": num_k_exclude_rope}
|
||||
|
||||
# Cross-Attention
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.cross_attn_image(
|
||||
q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
|
||||
k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
|
||||
v=memory,
|
||||
**kwds,
|
||||
)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
pos: torch.Tensor | None = None,
|
||||
query_pos: torch.Tensor | None = None,
|
||||
num_k_exclude_rope: int = 0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process input tensors through self-attention, cross-attention, and feedforward network layers.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Target tensor for self-attention with shape (N, L, D).
|
||||
memory (torch.Tensor): Memory tensor for cross-attention with shape (N, S, D).
|
||||
pos (Optional[torch.Tensor]): Positional encoding for memory tensor.
|
||||
query_pos (Optional[torch.Tensor]): Positional encoding for target tensor.
|
||||
num_k_exclude_rope (int): Number of keys to exclude from rotary position embedding.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Processed tensor after attention and feedforward layers with shape (N, L, D).
|
||||
"""
|
||||
tgt = self._forward_sa(tgt, query_pos)
|
||||
tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
|
||||
# MLP
|
||||
tgt2 = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
return tgt
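# Note: each sub-block above is pre-norm with a residual connection, i.e.
# tgt = tgt + dropout(sublayer(norm(tgt))) for self-attention, cross-attention and the MLP in turn.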
|
||||
|
||||
|
||||
class MemoryAttention(nn.Module):
|
||||
"""
|
||||
Memory attention module for processing sequential data with self and cross-attention mechanisms.
|
||||
|
||||
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
|
||||
for processing sequential data, particularly useful in transformer-like architectures.
|
||||
|
||||
Attributes:
|
||||
d_model (int): The dimension of the model's hidden state.
|
||||
layers (nn.ModuleList): A list of MemoryAttentionLayer modules.
|
||||
num_layers (int): The number of attention layers.
|
||||
norm (nn.LayerNorm): Layer normalization applied to the output.
|
||||
pos_enc_at_input (bool): Whether to apply positional encoding at the input.
|
||||
batch_first (bool): Whether the input tensors are in batch-first format.
|
||||
|
||||
Methods:
|
||||
forward: Processes input tensors through the attention layers.
|
||||
|
||||
Examples:
|
||||
>>> d_model = 256
|
||||
>>> layer = MemoryAttentionLayer(d_model)
|
||||
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
|
||||
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
|
||||
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
|
||||
>>> curr_pos = torch.randn(10, 32, d_model)
|
||||
>>> memory_pos = torch.randn(20, 32, d_model)
|
||||
>>> output = attention(curr, memory, curr_pos, memory_pos)
|
||||
>>> print(output.shape)
|
||||
torch.Size([10, 32, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
pos_enc_at_input: bool,
|
||||
layer: nn.Module,
|
||||
num_layers: int,
|
||||
batch_first: bool = True, # Do layers expect batch first input?
|
||||
):
|
||||
"""
|
||||
Initialize MemoryAttention with specified layers and normalization for sequential data processing.
|
||||
|
||||
This class implements a multi-layer attention mechanism that combines self-attention and cross-attention
|
||||
for processing sequential data, particularly useful in transformer-like architectures.
|
||||
|
||||
Args:
|
||||
d_model (int): The dimension of the model's hidden state.
|
||||
pos_enc_at_input (bool): Whether to apply positional encoding at the input.
|
||||
layer (nn.Module): The attention layer to be used in the module.
|
||||
num_layers (int): The number of attention layers.
|
||||
batch_first (bool): Whether the input tensors are in batch-first format.
|
||||
|
||||
Examples:
|
||||
>>> d_model = 256
|
||||
>>> layer = MemoryAttentionLayer(d_model)
|
||||
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
|
||||
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
|
||||
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
|
||||
>>> curr_pos = torch.randn(10, 32, d_model)
|
||||
>>> memory_pos = torch.randn(20, 32, d_model)
|
||||
>>> output = attention(curr, memory, curr_pos, memory_pos)
|
||||
>>> print(output.shape)
|
||||
torch.Size([10, 32, 256])
|
||||
"""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_layers)])
|
||||
self.num_layers = num_layers
|
||||
self.norm = nn.LayerNorm(d_model)
|
||||
self.pos_enc_at_input = pos_enc_at_input
|
||||
self.batch_first = batch_first
|
||||
|
||||
def forward(
|
||||
self,
|
||||
curr: torch.Tensor, # self-attention inputs
|
||||
memory: torch.Tensor, # cross-attention inputs
|
||||
curr_pos: torch.Tensor | None = None, # pos_enc for self-attention inputs
|
||||
memory_pos: torch.Tensor | None = None, # pos_enc for cross-attention inputs
|
||||
num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Process inputs through attention layers, applying self and cross-attention with positional encoding.
|
||||
|
||||
Args:
|
||||
curr (torch.Tensor): Self-attention input tensor, representing the current state.
|
||||
memory (torch.Tensor): Cross-attention input tensor, representing memory information.
|
||||
curr_pos (Optional[torch.Tensor]): Positional encoding for self-attention inputs.
|
||||
memory_pos (Optional[torch.Tensor]): Positional encoding for cross-attention inputs.
|
||||
num_obj_ptr_tokens (int): Number of object pointer tokens to exclude from rotary position embedding.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Processed output tensor after applying attention layers and normalization.
|
||||
|
||||
Examples:
|
||||
>>> d_model = 256
|
||||
>>> layer = MemoryAttentionLayer(d_model)
|
||||
>>> attention = MemoryAttention(d_model, pos_enc_at_input=True, layer=layer, num_layers=3)
|
||||
>>> curr = torch.randn(10, 32, d_model) # (seq_len, batch_size, d_model)
|
||||
>>> memory = torch.randn(20, 32, d_model) # (mem_len, batch_size, d_model)
|
||||
>>> curr_pos = torch.randn(10, 32, d_model)
|
||||
>>> memory_pos = torch.randn(20, 32, d_model)
|
||||
>>> output = attention(curr, memory, curr_pos, memory_pos)
|
||||
>>> print(output.shape)
|
||||
torch.Size([10, 32, 256])
|
||||
"""
|
||||
if isinstance(curr, list):
|
||||
assert isinstance(curr_pos, list)
|
||||
assert len(curr) == len(curr_pos) == 1
|
||||
curr, curr_pos = curr[0], curr_pos[0]
|
||||
|
||||
assert curr.shape[1] == memory.shape[1], "Batch size must be the same for curr and memory"
|
||||
|
||||
output = curr
|
||||
if self.pos_enc_at_input and curr_pos is not None:
|
||||
output = output + 0.1 * curr_pos
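# The positional encoding is added to the input with a small 0.1 scale rather than at full magnitude.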
|
||||
|
||||
if self.batch_first:
|
||||
# Convert to batch first
|
||||
output = output.transpose(0, 1)
|
||||
curr_pos = curr_pos.transpose(0, 1)
|
||||
memory = memory.transpose(0, 1)
|
||||
memory_pos = memory_pos.transpose(0, 1)
|
||||
|
||||
for layer in self.layers:
|
||||
kwds = {}
|
||||
if isinstance(layer.cross_attn_image, RoPEAttention):
|
||||
kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
|
||||
|
||||
output = layer(
|
||||
tgt=output,
|
||||
memory=memory,
|
||||
pos=memory_pos,
|
||||
query_pos=curr_pos,
|
||||
**kwds,
|
||||
)
|
||||
normed_output = self.norm(output)
|
||||
|
||||
if self.batch_first:
|
||||
# Convert back to seq first
|
||||
normed_output = normed_output.transpose(0, 1)
|
||||
curr_pos = curr_pos.transpose(0, 1)
|
||||
|
||||
return normed_output
|
||||
1033
ultralytics/models/sam/modules/sam.py
Normal file
File diff suppressed because it is too large
998
ultralytics/models/sam/modules/tiny_encoder.py
Normal file
@@ -0,0 +1,998 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
# --------------------------------------------------------
|
||||
# TinyViT Model Architecture
|
||||
# Copyright (c) 2022 Microsoft
|
||||
# Adapted from LeViT and Swin Transformer
|
||||
# LeViT: (https://github.com/facebookresearch/levit)
|
||||
# Swin: (https://github.com/microsoft/swin-transformer)
|
||||
# Build the TinyViT Model
|
||||
# --------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.nn.modules import LayerNorm2d
|
||||
from ultralytics.utils.instance import to_2tuple
|
||||
|
||||
|
||||
class Conv2d_BN(torch.nn.Sequential):
|
||||
"""
|
||||
A sequential container that performs 2D convolution followed by batch normalization.
|
||||
|
||||
This module combines a 2D convolution layer with batch normalization, providing a common building block
|
||||
for convolutional neural networks. The batch normalization weights and biases are initialized to specific
|
||||
values for optimal training performance.
|
||||
|
||||
Attributes:
|
||||
c (torch.nn.Conv2d): 2D convolution layer.
|
||||
bn (torch.nn.BatchNorm2d): Batch normalization layer.
|
||||
|
||||
Examples:
|
||||
>>> conv_bn = Conv2d_BN(3, 64, ks=3, stride=1, pad=1)
|
||||
>>> input_tensor = torch.randn(1, 3, 224, 224)
|
||||
>>> output = conv_bn(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 64, 224, 224])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
a: int,
|
||||
b: int,
|
||||
ks: int = 1,
|
||||
stride: int = 1,
|
||||
pad: int = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bn_weight_init: float = 1,
|
||||
):
|
||||
"""
|
||||
Initialize a sequential container with 2D convolution followed by batch normalization.
|
||||
|
||||
Args:
|
||||
a (int): Number of input channels.
|
||||
b (int): Number of output channels.
|
||||
ks (int, optional): Kernel size for the convolution.
|
||||
stride (int, optional): Stride for the convolution.
|
||||
pad (int, optional): Padding for the convolution.
|
||||
dilation (int, optional): Dilation factor for the convolution.
|
||||
groups (int, optional): Number of groups for the convolution.
|
||||
bn_weight_init (float, optional): Initial value for batch normalization weight.
|
||||
"""
|
||||
super().__init__()
|
||||
self.add_module("c", torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
|
||||
bn = torch.nn.BatchNorm2d(b)
|
||||
torch.nn.init.constant_(bn.weight, bn_weight_init)
|
||||
torch.nn.init.constant_(bn.bias, 0)
|
||||
self.add_module("bn", bn)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
Embed images into patches and project them into a specified embedding dimension.
|
||||
|
||||
This module converts input images into patch embeddings using a sequence of convolutional layers,
|
||||
effectively downsampling the spatial dimensions while increasing the channel dimension.
|
||||
|
||||
Attributes:
|
||||
patches_resolution (tuple[int, int]): Resolution of the patches after embedding.
|
||||
num_patches (int): Total number of patches.
|
||||
in_chans (int): Number of input channels.
|
||||
embed_dim (int): Dimension of the embedding.
|
||||
seq (nn.Sequential): Sequence of convolutional and activation layers for patch embedding.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> patch_embed = PatchEmbed(in_chans=3, embed_dim=96, resolution=224, activation=nn.GELU)
|
||||
>>> x = torch.randn(1, 3, 224, 224)
|
||||
>>> output = patch_embed(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 96, 56, 56])
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans: int, embed_dim: int, resolution: int, activation):
|
||||
"""
|
||||
Initialize patch embedding with convolutional layers for image-to-patch conversion and projection.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
embed_dim (int): Dimension of the embedding.
|
||||
resolution (int): Input image resolution.
|
||||
activation (nn.Module): Activation function to use between convolutions.
|
||||
"""
|
||||
super().__init__()
|
||||
img_size: tuple[int, int] = to_2tuple(resolution)
|
||||
self.patches_resolution = (img_size[0] // 4, img_size[1] // 4)
|
||||
self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
n = embed_dim
|
||||
self.seq = nn.Sequential(
|
||||
Conv2d_BN(in_chans, n // 2, 3, 2, 1),
|
||||
activation(),
|
||||
Conv2d_BN(n // 2, n, 3, 2, 1),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input tensor through patch embedding sequence, converting images to patch embeddings."""
|
||||
return self.seq(x)
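# Note: the two stride-2 convolutions give the overall 4x spatial reduction reflected in
# patches_resolution, taking channels from in_chans to embed_dim // 2 and then to embed_dim.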
|
||||
|
||||
|
||||
class MBConv(nn.Module):
|
||||
"""
|
||||
Mobile Inverted Bottleneck Conv (MBConv) layer, part of the EfficientNet architecture.
|
||||
|
||||
This module implements the mobile inverted bottleneck convolution with expansion, depthwise convolution,
|
||||
and projection phases, along with residual connections for improved gradient flow.
|
||||
|
||||
Attributes:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels after expansion.
|
||||
out_chans (int): Number of output channels.
|
||||
conv1 (Conv2d_BN): First convolutional layer for channel expansion.
|
||||
act1 (nn.Module): First activation function.
|
||||
conv2 (Conv2d_BN): Depthwise convolutional layer.
|
||||
act2 (nn.Module): Second activation function.
|
||||
conv3 (Conv2d_BN): Final convolutional layer for projection.
|
||||
act3 (nn.Module): Third activation function.
|
||||
drop_path (nn.Module): Drop path layer (Identity for inference).
|
||||
|
||||
Examples:
|
||||
>>> in_chans, out_chans = 64, 64
>>> mbconv = MBConv(in_chans, out_chans, expand_ratio=4, activation=nn.ReLU, drop_path=0.1)
>>> x = torch.randn(1, in_chans, 56, 56)
>>> output = mbconv(x)
>>> print(output.shape)
torch.Size([1, 64, 56, 56])
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans: int, out_chans: int, expand_ratio: float, activation, drop_path: float):
|
||||
"""
|
||||
Initialize the MBConv layer with specified input/output channels, expansion ratio, and activation.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
out_chans (int): Number of output channels.
|
||||
expand_ratio (float): Channel expansion ratio for the hidden layer.
|
||||
activation (nn.Module): Activation function to use.
|
||||
drop_path (float): Drop path rate for stochastic depth.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_chans = in_chans
|
||||
self.hidden_chans = int(in_chans * expand_ratio)
|
||||
self.out_chans = out_chans
|
||||
|
||||
self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1)
|
||||
self.act1 = activation()
|
||||
|
||||
self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, ks=3, stride=1, pad=1, groups=self.hidden_chans)
|
||||
self.act2 = activation()
|
||||
|
||||
self.conv3 = Conv2d_BN(self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0)
|
||||
self.act3 = activation()
|
||||
|
||||
# NOTE: `DropPath` is needed only for training.
|
||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Implement the forward pass of MBConv, applying convolutions and skip connection."""
|
||||
shortcut = x
|
||||
x = self.conv1(x)
|
||||
x = self.act1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.act2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.drop_path(x)
|
||||
x += shortcut
|
||||
return self.act3(x)
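# Note: this is the classic inverted-bottleneck pattern: 1x1 expansion by expand_ratio, 3x3 depthwise conv
# (groups == hidden channels), 1x1 projection, with the shortcut added before the final activation.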
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
"""
|
||||
Merge neighboring patches in the feature map and project to a new dimension.
|
||||
|
||||
This class implements a patch merging operation that combines spatial information and adjusts the feature
|
||||
dimension using a series of convolutional layers with batch normalization. It effectively reduces spatial
|
||||
resolution while potentially increasing channel dimensions.
|
||||
|
||||
Attributes:
|
||||
input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
|
||||
dim (int): The input dimension of the feature map.
|
||||
out_dim (int): The output dimension after merging and projection.
|
||||
act (nn.Module): The activation function used between convolutions.
|
||||
conv1 (Conv2d_BN): The first convolutional layer for dimension projection.
|
||||
conv2 (Conv2d_BN): The second convolutional layer for spatial merging.
|
||||
conv3 (Conv2d_BN): The third convolutional layer for final projection.
|
||||
|
||||
Examples:
|
||||
>>> input_resolution = (56, 56)
>>> patch_merging = PatchMerging(input_resolution, dim=64, out_dim=128, activation=nn.ReLU)
>>> x = torch.randn(4, 64, 56, 56)
>>> output = patch_merging(x)
>>> print(output.shape)
torch.Size([4, 784, 128])
|
||||
"""
|
||||
|
||||
def __init__(self, input_resolution: tuple[int, int], dim: int, out_dim: int, activation):
|
||||
"""
|
||||
Initialize the PatchMerging module for merging and projecting neighboring patches in feature maps.
|
||||
|
||||
Args:
|
||||
input_resolution (tuple[int, int]): The input resolution (height, width) of the feature map.
|
||||
dim (int): The input dimension of the feature map.
|
||||
out_dim (int): The output dimension after merging and projection.
|
||||
activation (nn.Module): The activation function used between convolutions.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_resolution = input_resolution
|
||||
self.dim = dim
|
||||
self.out_dim = out_dim
|
||||
self.act = activation()
|
||||
self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0)
|
||||
stride_c = 1 if out_dim in {320, 448, 576} else 2
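# These out_dim values appear to correspond to the final-stage widths of the TinyViT variants
# (5M/11M/21M), where the merge keeps spatial resolution instead of halving it.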
|
||||
self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim)
|
||||
self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply patch merging and dimension projection to the input feature map."""
|
||||
if x.ndim == 3:
|
||||
H, W = self.input_resolution
|
||||
B = len(x)
|
||||
# (B, C, H, W)
|
||||
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.act(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
x = self.act(x)
|
||||
x = self.conv3(x)
|
||||
return x.flatten(2).transpose(1, 2)
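# Note: accepts either (B, L, C) tokens (reshaped via input_resolution) or (B, C, H, W) maps, and returns
# (B, H'*W', out_dim) tokens, halving H and W unless stride_c == 1 above.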
|
||||
|
||||
|
||||
class ConvLayer(nn.Module):
|
||||
"""
|
||||
Convolutional Layer featuring multiple MobileNetV3-style inverted bottleneck convolutions (MBConv).
|
||||
|
||||
This layer optionally applies downsample operations to the output and supports gradient checkpointing
|
||||
for memory efficiency during training.
|
||||
|
||||
Attributes:
|
||||
dim (int): Dimensionality of the input and output.
|
||||
input_resolution (tuple[int, int]): Resolution of the input image.
|
||||
depth (int): Number of MBConv layers in the block.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
blocks (nn.ModuleList): List of MBConv layers.
|
||||
downsample (Optional[nn.Module]): Function for downsampling the output.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 64, 56, 56)
|
||||
>>> conv_layer = ConvLayer(64, (56, 56), depth=3, activation=nn.ReLU, downsample=PatchMerging, out_dim=128)
|
||||
>>> output = conv_layer(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 784, 128])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
input_resolution: tuple[int, int],
|
||||
depth: int,
|
||||
activation,
|
||||
drop_path: float | list[float] = 0.0,
|
||||
downsample: nn.Module | None = None,
|
||||
use_checkpoint: bool = False,
|
||||
out_dim: int | None = None,
|
||||
conv_expand_ratio: float = 4.0,
|
||||
):
|
||||
"""
|
||||
Initialize the ConvLayer with the given dimensions and settings.
|
||||
|
||||
This layer consists of multiple MobileNetV3-style inverted bottleneck convolutions (MBConv) and
|
||||
optionally applies downsampling to the output.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (tuple[int, int]): The resolution of the input image.
|
||||
depth (int): The number of MBConv layers in the block.
|
||||
activation (nn.Module): Activation function applied after each convolution.
|
||||
drop_path (float | list[float], optional): Drop path rate. Single float or a list of floats for each MBConv.
|
||||
downsample (Optional[nn.Module], optional): Function for downsampling the output. None to skip downsampling.
|
||||
use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
|
||||
out_dim (Optional[int], optional): The dimensionality of the output. None means it will be the same as `dim`.
|
||||
conv_expand_ratio (float, optional): Expansion ratio for the MBConv layers.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# Build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
MBConv(
|
||||
dim,
|
||||
dim,
|
||||
conv_expand_ratio,
|
||||
activation,
|
||||
drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# Patch merging layer
|
||||
self.downsample = (
|
||||
None
|
||||
if downsample is None
|
||||
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input through convolutional layers, applying MBConv blocks and optional downsampling."""
|
||||
for blk in self.blocks:
|
||||
x = torch.utils.checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # checkpointing trades extra compute for lower memory
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
|
||||
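# Editor's sketch (assumption, not part of the original file): gradient checkpointing in
# ConvLayer/BasicLayer recomputes each block's forward pass during backward instead of
# storing activations, so the two code paths are numerically identical while the
# checkpointed one uses less memory at the cost of extra compute.
def _demo_conv_layer_checkpointing():
    """Hypothetical comparison of ConvLayer outputs with and without checkpointing."""
    layer = ConvLayer(dim=64, input_resolution=(56, 56), depth=2, activation=nn.GELU)
    x = torch.randn(1, 64, 56, 56, requires_grad=True)
    y_plain = layer(x)
    layer.use_checkpoint = True  # recompute block activations during backward
    y_ckpt = layer(x)
    return torch.allclose(y_plain, y_ckpt)  # True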
class MLP(nn.Module):
|
||||
"""
|
||||
Multi-layer Perceptron (MLP) module for transformer architectures.
|
||||
|
||||
This module applies layer normalization, two fully-connected layers with an activation function in between,
|
||||
and dropout. It is commonly used in transformer-based architectures for processing token embeddings.
|
||||
|
||||
Attributes:
|
||||
norm (nn.LayerNorm): Layer normalization applied to the input.
|
||||
fc1 (nn.Linear): First fully-connected layer.
|
||||
fc2 (nn.Linear): Second fully-connected layer.
|
||||
act (nn.Module): Activation function applied after the first fully-connected layer.
|
||||
drop (nn.Dropout): Dropout layer applied after the activation function.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> from torch import nn
|
||||
>>> mlp = MLP(in_features=256, hidden_features=512, out_features=256, activation=nn.GELU, drop=0.1)
|
||||
>>> x = torch.randn(32, 100, 256)
|
||||
>>> output = mlp(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([32, 100, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
activation=nn.GELU,
|
||||
drop: float = 0.0,
|
||||
):
|
||||
"""
|
||||
Initialize a multi-layer perceptron with configurable input, hidden, and output dimensions.
|
||||
|
||||
Args:
|
||||
in_features (int): Number of input features.
|
||||
hidden_features (Optional[int], optional): Number of hidden features.
|
||||
out_features (Optional[int], optional): Number of output features.
|
||||
activation (nn.Module): Activation function applied after the first fully-connected layer.
|
||||
drop (float, optional): Dropout probability.
|
||||
"""
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.norm = nn.LayerNorm(in_features)
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.act = activation()
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply MLP operations: layer norm, FC layers, activation, and dropout to the input tensor."""
|
||||
x = self.norm(x)
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
return self.drop(x)
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
"""
|
||||
Multi-head attention module with spatial awareness and trainable attention biases.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying
|
||||
attention biases based on spatial resolution. It includes trainable attention biases for each unique
|
||||
offset between spatial positions in the resolution grid.
|
||||
|
||||
Attributes:
|
||||
num_heads (int): Number of attention heads.
|
||||
scale (float): Scaling factor for attention scores.
|
||||
key_dim (int): Dimensionality of the keys and queries.
|
||||
nh_kd (int): Product of num_heads and key_dim.
|
||||
d (int): Dimensionality of the value vectors.
|
||||
dh (int): Product of d and num_heads.
|
||||
attn_ratio (float): Attention ratio affecting the dimensions of the value vectors.
|
||||
norm (nn.LayerNorm): Layer normalization applied to input.
|
||||
qkv (nn.Linear): Linear layer for computing query, key, and value projections.
|
||||
proj (nn.Linear): Linear layer for final projection.
|
||||
attention_biases (nn.Parameter): Learnable attention biases.
|
||||
attention_bias_idxs (torch.Tensor): Indices for attention biases.
|
||||
ab (torch.Tensor): Cached attention biases for inference, deleted during training.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(dim=256, key_dim=64, num_heads=8, resolution=(14, 14))
|
||||
>>> x = torch.randn(1, 196, 256)
|
||||
>>> output = attn(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
key_dim: int,
|
||||
num_heads: int = 8,
|
||||
attn_ratio: float = 4,
|
||||
resolution: tuple[int, int] = (14, 14),
|
||||
):
|
||||
"""
|
||||
Initialize the Attention module for multi-head attention with spatial awareness.
|
||||
|
||||
This module implements a multi-head attention mechanism with support for spatial awareness, applying
|
||||
attention biases based on spatial resolution. It includes trainable attention biases for each unique
|
||||
offset between spatial positions in the resolution grid.
|
||||
|
||||
Args:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
key_dim (int): The dimensionality of the keys and queries.
|
||||
num_heads (int, optional): Number of attention heads.
|
||||
attn_ratio (float, optional): Attention ratio, affecting the dimensions of the value vectors.
|
||||
resolution (tuple[int, int], optional): Spatial resolution of the input feature map.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert isinstance(resolution, tuple) and len(resolution) == 2, "'resolution' argument must be a tuple of length 2"
|
||||
self.num_heads = num_heads
|
||||
self.scale = key_dim**-0.5
|
||||
self.key_dim = key_dim
|
||||
self.nh_kd = nh_kd = key_dim * num_heads
|
||||
self.d = int(attn_ratio * key_dim)
|
||||
self.dh = int(attn_ratio * key_dim) * num_heads
|
||||
self.attn_ratio = attn_ratio
|
||||
h = self.dh + nh_kd * 2
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.qkv = nn.Linear(dim, h)
|
||||
self.proj = nn.Linear(self.dh, dim)
|
||||
|
||||
points = list(itertools.product(range(resolution[0]), range(resolution[1])))
|
||||
N = len(points)
|
||||
attention_offsets = {}
|
||||
idxs = []
|
||||
for p1 in points:
|
||||
for p2 in points:
|
||||
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
|
||||
if offset not in attention_offsets:
|
||||
attention_offsets[offset] = len(attention_offsets)
|
||||
idxs.append(attention_offsets[offset])
|
||||
self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
|
||||
self.register_buffer("attention_bias_idxs", torch.LongTensor(idxs).view(N, N), persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def train(self, mode: bool = True):
|
||||
"""Set the module in training mode and handle the 'ab' attribute for cached attention biases."""
|
||||
super().train(mode)
|
||||
if mode and hasattr(self, "ab"):
|
||||
del self.ab
|
||||
else:
|
||||
self.ab = self.attention_biases[:, self.attention_bias_idxs]
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply multi-head attention with spatial awareness and trainable attention biases."""
|
||||
B, N, _ = x.shape # B, N, C
|
||||
|
||||
# Normalization
|
||||
x = self.norm(x)
|
||||
|
||||
qkv = self.qkv(x)
|
||||
# (B, N, num_heads, d)
|
||||
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
|
||||
# (B, num_heads, N, d)
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
if not self.training:
    self.ab = self.ab.to(self.attention_biases.device)  # cached biases only exist in eval mode
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale + (
|
||||
self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
|
||||
)
|
||||
attn = attn.softmax(dim=-1)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
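# Editor's sketch (assumption, not part of the original file): the trainable attention
# biases above are shared across all position pairs with the same absolute offset, so a
# (H, W) attention window only needs H*W bias entries per head rather than (H*W)**2.
def _demo_attention_bias_sharing():
    """Hypothetical check of the attention-bias bookkeeping for a 7x7 window."""
    attn = Attention(dim=128, key_dim=16, num_heads=4, attn_ratio=1, resolution=(7, 7))
    attn.eval()  # caches the gathered biases in attn.ab
    n_tokens = 7 * 7
    assert attn.attention_bias_idxs.shape == (n_tokens, n_tokens)
    assert attn.attention_biases.shape == (4, 7 * 7)  # one bias per head and unique |offset| pair
    return attn(torch.randn(2, n_tokens, 128)).shape  # torch.Size([2, 49, 128])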
class TinyViTBlock(nn.Module):
|
||||
"""
|
||||
TinyViT Block that applies self-attention and a local convolution to the input.
|
||||
|
||||
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
|
||||
local convolutions to process input features efficiently. It supports windowed attention for
|
||||
computational efficiency and includes residual connections.
|
||||
|
||||
Attributes:
|
||||
dim (int): The dimensionality of the input and output.
|
||||
input_resolution (tuple[int, int]): Spatial resolution of the input feature map.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Size of the attention window.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop_path (nn.Module): Stochastic depth layer, identity function during inference.
|
||||
attn (Attention): Self-attention module.
|
||||
mlp (MLP): Multi-layer perceptron module.
|
||||
local_conv (Conv2d_BN): Depth-wise local convolution layer.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 196, 192)
|
||||
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3)
|
||||
>>> output = block(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 196, 192])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
input_resolution: tuple[int, int],
|
||||
num_heads: int,
|
||||
window_size: int = 7,
|
||||
mlp_ratio: float = 4.0,
|
||||
drop: float = 0.0,
|
||||
drop_path: float = 0.0,
|
||||
local_conv_size: int = 3,
|
||||
activation=nn.GELU,
|
||||
):
|
||||
"""
|
||||
Initialize a TinyViT block with self-attention and local convolution.
|
||||
|
||||
This block is a key component of the TinyViT architecture, combining self-attention mechanisms with
|
||||
local convolutions to process input features efficiently.
|
||||
|
||||
Args:
|
||||
dim (int): Dimensionality of the input and output features.
|
||||
input_resolution (tuple[int, int]): Spatial resolution of the input feature map (height, width).
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int, optional): Size of the attention window. Must be greater than 0.
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop (float, optional): Dropout rate.
|
||||
drop_path (float, optional): Stochastic depth rate.
|
||||
local_conv_size (int, optional): Kernel size of the local convolution.
|
||||
activation (nn.Module): Activation function for MLP.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.num_heads = num_heads
|
||||
assert window_size > 0, "window_size must be greater than 0"
|
||||
self.window_size = window_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
# NOTE: `DropPath` is needed only for training.
|
||||
# self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.drop_path = nn.Identity()
|
||||
|
||||
assert dim % num_heads == 0, "dim must be divisible by num_heads"
|
||||
head_dim = dim // num_heads
|
||||
|
||||
window_resolution = (window_size, window_size)
|
||||
self.attn = Attention(dim, head_dim, num_heads, attn_ratio=1, resolution=window_resolution)
|
||||
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
mlp_activation = activation
|
||||
self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=mlp_activation, drop=drop)
|
||||
|
||||
pad = local_conv_size // 2
|
||||
self.local_conv = Conv2d_BN(dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply self-attention, local convolution, and MLP operations to the input tensor."""
|
||||
h, w = self.input_resolution
|
||||
b, hw, c = x.shape # batch, height*width, channels
|
||||
assert hw == h * w, "input feature has wrong size"
|
||||
res_x = x
|
||||
if h == self.window_size and w == self.window_size:
|
||||
x = self.attn(x)
|
||||
else:
|
||||
x = x.view(b, h, w, c)
|
||||
pad_b = (self.window_size - h % self.window_size) % self.window_size
|
||||
pad_r = (self.window_size - w % self.window_size) % self.window_size
|
||||
padding = pad_b > 0 or pad_r > 0
|
||||
if padding:
|
||||
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
|
||||
|
||||
pH, pW = h + pad_b, w + pad_r
|
||||
nH = pH // self.window_size
|
||||
nW = pW // self.window_size
|
||||
|
||||
# Window partition
|
||||
x = (
|
||||
x.view(b, nH, self.window_size, nW, self.window_size, c)
|
||||
.transpose(2, 3)
|
||||
.reshape(b * nH * nW, self.window_size * self.window_size, c)
|
||||
)
|
||||
x = self.attn(x)
|
||||
|
||||
# Window reverse
|
||||
x = x.view(b, nH, nW, self.window_size, self.window_size, c).transpose(2, 3).reshape(b, pH, pW, c)
|
||||
if padding:
|
||||
x = x[:, :h, :w].contiguous()
|
||||
|
||||
x = x.view(b, hw, c)
|
||||
|
||||
x = res_x + self.drop_path(x)
|
||||
x = x.transpose(1, 2).reshape(b, c, h, w)
|
||||
x = self.local_conv(x)
|
||||
x = x.view(b, c, hw).transpose(1, 2)
|
||||
|
||||
return x + self.drop_path(self.mlp(x))
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""
|
||||
Return a string representation of the TinyViTBlock's parameters.
|
||||
|
||||
This method provides a formatted string containing key information about the TinyViTBlock, including its
|
||||
dimension, input resolution, number of attention heads, window size, and MLP ratio.
|
||||
|
||||
Returns:
|
||||
(str): A formatted string containing the block's parameters.
|
||||
|
||||
Examples:
|
||||
>>> block = TinyViTBlock(dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0)
|
||||
>>> print(block.extra_repr())
|
||||
dim=192, input_resolution=(14, 14), num_heads=3, window_size=7, mlp_ratio=4.0
|
||||
"""
|
||||
return (
|
||||
f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
|
||||
f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
|
||||
)
|
||||
|
||||
|
||||
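# Editor's sketch (assumption, not part of the original file): when the input resolution is
# not a multiple of window_size, TinyViTBlock pads the feature map, runs windowed attention,
# and crops the padding back off, so the token count is preserved end to end.
def _demo_tinyvit_block_padding():
    """Hypothetical check of the padded windowed-attention path."""
    block = TinyViTBlock(dim=64, input_resolution=(10, 10), num_heads=2, window_size=7).eval()
    with torch.no_grad():
        out = block(torch.randn(1, 10 * 10, 64))
    return out.shape  # torch.Size([1, 100, 64])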
class BasicLayer(nn.Module):
|
||||
"""
|
||||
A basic TinyViT layer for one stage in a TinyViT architecture.
|
||||
|
||||
This class represents a single layer in the TinyViT model, consisting of multiple TinyViT blocks
|
||||
and an optional downsampling operation. It processes features at a specific resolution and
|
||||
dimensionality within the overall architecture.
|
||||
|
||||
Attributes:
|
||||
dim (int): The dimensionality of the input and output features.
|
||||
input_resolution (tuple[int, int]): Spatial resolution of the input feature map.
|
||||
depth (int): Number of TinyViT blocks in this layer.
|
||||
use_checkpoint (bool): Whether to use gradient checkpointing to save memory.
|
||||
blocks (nn.ModuleList): List of TinyViT blocks that make up this layer.
|
||||
downsample (nn.Module | None): Downsample layer at the end of the layer, if specified.
|
||||
|
||||
Examples:
|
||||
>>> input_tensor = torch.randn(1, 3136, 192)
|
||||
>>> layer = BasicLayer(dim=192, input_resolution=(56, 56), depth=2, num_heads=3, window_size=7, downsample=PatchMerging, out_dim=384)
|
||||
>>> output = layer(input_tensor)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 784, 384])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
input_resolution: tuple[int, int],
|
||||
depth: int,
|
||||
num_heads: int,
|
||||
window_size: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
drop: float = 0.0,
|
||||
drop_path: float | list[float] = 0.0,
|
||||
downsample: nn.Module | None = None,
|
||||
use_checkpoint: bool = False,
|
||||
local_conv_size: int = 3,
|
||||
activation=nn.GELU,
|
||||
out_dim: int | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize a BasicLayer in the TinyViT architecture.
|
||||
|
||||
This layer consists of multiple TinyViT blocks and an optional downsampling operation. It is designed to
|
||||
process feature maps at a specific resolution and dimensionality within the TinyViT model.
|
||||
|
||||
Args:
|
||||
dim (int): Dimensionality of the input and output features.
|
||||
input_resolution (tuple[int, int]): Spatial resolution of the input feature map (height, width).
|
||||
depth (int): Number of TinyViT blocks in this layer.
|
||||
num_heads (int): Number of attention heads in each TinyViT block.
|
||||
window_size (int): Size of the local window for attention computation.
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dimension to embedding dimension.
|
||||
drop (float, optional): Dropout rate.
|
||||
drop_path (float | list[float], optional): Stochastic depth rate. Can be a float or a list of floats for each block.
|
||||
downsample (nn.Module | None, optional): Downsampling layer at the end of the layer. None to skip downsampling.
|
||||
use_checkpoint (bool, optional): Whether to use gradient checkpointing to save memory.
|
||||
local_conv_size (int, optional): Kernel size for the local convolution in each TinyViT block.
|
||||
activation (nn.Module): Activation function used in the MLP.
|
||||
out_dim (int | None, optional): Output dimension after downsampling. None means it will be the same as `dim`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.input_resolution = input_resolution
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# Build blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TinyViTBlock(
|
||||
dim=dim,
|
||||
input_resolution=input_resolution,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
mlp_ratio=mlp_ratio,
|
||||
drop=drop,
|
||||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
||||
local_conv_size=local_conv_size,
|
||||
activation=activation,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
# Patch merging layer
|
||||
self.downsample = (
|
||||
None
|
||||
if downsample is None
|
||||
else downsample(input_resolution, dim=dim, out_dim=out_dim, activation=activation)
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input through TinyViT blocks and optional downsampling."""
|
||||
for blk in self.blocks:
|
||||
x = torch.utils.checkpoint.checkpoint(blk, x) if self.use_checkpoint else blk(x)  # checkpointing trades extra compute for lower memory
|
||||
return x if self.downsample is None else self.downsample(x)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
"""Return a string with the layer's parameters for printing."""
|
||||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
||||
|
||||
|
||||
class TinyViT(nn.Module):
|
||||
"""
|
||||
TinyViT: A compact vision transformer architecture for efficient image classification and feature extraction.
|
||||
|
||||
This class implements the TinyViT model, which combines elements of vision transformers and convolutional
|
||||
neural networks for improved efficiency and performance on vision tasks. It features hierarchical processing
|
||||
with patch embedding, multiple stages of attention and convolution blocks, and a feature refinement neck.
|
||||
|
||||
Attributes:
|
||||
img_size (int): Input image size.
|
||||
num_classes (int): Number of classification classes.
|
||||
depths (tuple[int, int, int, int]): Number of blocks in each stage.
|
||||
num_layers (int): Total number of layers in the network.
|
||||
mlp_ratio (float): Ratio of MLP hidden dimension to embedding dimension.
|
||||
patch_embed (PatchEmbed): Module for patch embedding.
|
||||
patches_resolution (tuple[int, int]): Resolution of embedded patches.
|
||||
layers (nn.ModuleList): List of network layers.
|
||||
norm_head (nn.LayerNorm): Layer normalization for the classifier head.
|
||||
head (nn.Linear): Linear layer for final classification.
|
||||
neck (nn.Sequential): Neck module for feature refinement.
|
||||
|
||||
Examples:
|
||||
>>> model = TinyViT(img_size=224, embed_dims=(64, 128, 160, 320), num_heads=(2, 4, 5, 10))
|
||||
>>> x = torch.randn(1, 3, 224, 224)
|
||||
>>> features = model.forward_features(x)
|
||||
>>> print(features.shape)
|
||||
torch.Size([1, 256, 14, 14])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: int = 224,
|
||||
in_chans: int = 3,
|
||||
num_classes: int = 1000,
|
||||
embed_dims: tuple[int, int, int, int] = (96, 192, 384, 768),
|
||||
depths: tuple[int, int, int, int] = (2, 2, 6, 2),
|
||||
num_heads: tuple[int, int, int, int] = (3, 6, 12, 24),
|
||||
window_sizes: tuple[int, int, int, int] = (7, 7, 14, 7),
|
||||
mlp_ratio: float = 4.0,
|
||||
drop_rate: float = 0.0,
|
||||
drop_path_rate: float = 0.1,
|
||||
use_checkpoint: bool = False,
|
||||
mbconv_expand_ratio: float = 4.0,
|
||||
local_conv_size: int = 3,
|
||||
layer_lr_decay: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize the TinyViT model.
|
||||
|
||||
This constructor sets up the TinyViT architecture, including patch embedding, multiple layers of
|
||||
attention and convolution blocks, and a classification head.
|
||||
|
||||
Args:
|
||||
img_size (int, optional): Size of the input image.
|
||||
in_chans (int, optional): Number of input channels.
|
||||
num_classes (int, optional): Number of classes for classification.
|
||||
embed_dims (tuple[int, int, int, int], optional): Embedding dimensions for each stage.
|
||||
depths (tuple[int, int, int, int], optional): Number of blocks in each stage.
|
||||
num_heads (tuple[int, int, int, int], optional): Number of attention heads in each stage.
|
||||
window_sizes (tuple[int, int, int, int], optional): Window sizes for each stage.
|
||||
mlp_ratio (float, optional): Ratio of MLP hidden dim to embedding dim.
|
||||
drop_rate (float, optional): Dropout rate.
|
||||
drop_path_rate (float, optional): Stochastic depth rate.
|
||||
use_checkpoint (bool, optional): Whether to use checkpointing to save memory.
|
||||
mbconv_expand_ratio (float, optional): Expansion ratio for MBConv layer.
|
||||
local_conv_size (int, optional): Kernel size for local convolutions.
|
||||
layer_lr_decay (float, optional): Layer-wise learning rate decay factor.
|
||||
"""
|
||||
super().__init__()
|
||||
self.img_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.depths = depths
|
||||
self.num_layers = len(depths)
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
activation = nn.GELU
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
in_chans=in_chans, embed_dim=embed_dims[0], resolution=img_size, activation=activation
|
||||
)
|
||||
|
||||
patches_resolution = self.patch_embed.patches_resolution
|
||||
self.patches_resolution = patches_resolution
|
||||
|
||||
# Stochastic depth
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
||||
|
||||
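# Editor's note (illustrative, not in the original file): for the default depths=(2, 2, 6, 2)
# and drop_path_rate=0.1 this produces 12 linearly spaced rates, roughly
# [0.0, 0.009, 0.018, ..., 0.1], so later blocks get a higher stochastic-depth
# probability than earlier ones.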
# Build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
kwargs = dict(
|
||||
dim=embed_dims[i_layer],
|
||||
input_resolution=(
|
||||
patches_resolution[0] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
patches_resolution[1] // (2 ** (i_layer - 1 if i_layer == 3 else i_layer)),
|
||||
),
|
||||
# input_resolution=(patches_resolution[0] // (2 ** i_layer),
|
||||
# patches_resolution[1] // (2 ** i_layer)),
|
||||
depth=depths[i_layer],
|
||||
drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
|
||||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint,
|
||||
out_dim=embed_dims[min(i_layer + 1, len(embed_dims) - 1)],
|
||||
activation=activation,
|
||||
)
|
||||
if i_layer == 0:
|
||||
layer = ConvLayer(conv_expand_ratio=mbconv_expand_ratio, **kwargs)
|
||||
else:
|
||||
layer = BasicLayer(
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_sizes[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
drop=drop_rate,
|
||||
local_conv_size=local_conv_size,
|
||||
**kwargs,
|
||||
)
|
||||
self.layers.append(layer)
|
||||
|
||||
# Classifier head
|
||||
self.norm_head = nn.LayerNorm(embed_dims[-1])
|
||||
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity()
|
||||
|
||||
# Init weights
|
||||
self.apply(self._init_weights)
|
||||
self.set_layer_lr_decay(layer_lr_decay)
|
||||
self.neck = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
embed_dims[-1],
|
||||
256,
|
||||
kernel_size=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(256),
|
||||
nn.Conv2d(
|
||||
256,
|
||||
256,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
bias=False,
|
||||
),
|
||||
LayerNorm2d(256),
|
||||
)
|
||||
|
||||
def set_layer_lr_decay(self, layer_lr_decay: float):
|
||||
"""Set layer-wise learning rate decay for the TinyViT model based on depth."""
|
||||
decay_rate = layer_lr_decay
|
||||
|
||||
# Layers -> blocks (depth)
|
||||
depth = sum(self.depths)
|
||||
lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)]
|
||||
|
||||
def _set_lr_scale(m, scale):
|
||||
"""Set the learning rate scale for each layer in the model based on the layer's depth."""
|
||||
for p in m.parameters():
|
||||
p.lr_scale = scale
|
||||
|
||||
self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0]))
|
||||
i = 0
|
||||
for layer in self.layers:
|
||||
for block in layer.blocks:
|
||||
block.apply(lambda x: _set_lr_scale(x, lr_scales[i]))
|
||||
i += 1
|
||||
if layer.downsample is not None:
|
||||
layer.downsample.apply(lambda x: _set_lr_scale(x, lr_scales[i - 1]))
|
||||
assert i == depth
|
||||
for m in {self.norm_head, self.head}:
|
||||
m.apply(lambda x: _set_lr_scale(x, lr_scales[-1]))
|
||||
|
||||
for k, p in self.named_parameters():
|
||||
p.param_name = k
|
||||
|
||||
def _check_lr_scale(m):
|
||||
"""Check if the learning rate scale attribute is present in module's parameters."""
|
||||
for p in m.parameters():
|
||||
assert hasattr(p, "lr_scale"), p.param_name
|
||||
|
||||
self.apply(_check_lr_scale)
|
||||
|
||||
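# Editor's sketch (assumption, not part of the original file): the `lr_scale` attributes set
# by set_layer_lr_decay are pure bookkeeping; an optimizer must consume them explicitly, for
# example by building one parameter group per distinct scale as in this hypothetical helper.
def _lr_scaled_param_groups(self, base_lr: float = 1e-3):
    """Group parameters by their lr_scale attribute for an optimizer such as AdamW."""
    groups = {}
    for p in self.parameters():
        scale = getattr(p, "lr_scale", 1.0)
        groups.setdefault(scale, {"params": [], "lr": base_lr * scale})["params"].append(p)
    return list(groups.values())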
@staticmethod
|
||||
def _init_weights(m):
|
||||
"""Initialize weights for linear and normalization layers in the TinyViT model."""
|
||||
if isinstance(m, nn.Linear):
|
||||
# NOTE: This initialization is needed only for training.
|
||||
# trunc_normal_(m.weight, std=.02)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
"""Return a set of keywords for parameters that should not use weight decay."""
|
||||
return {"attention_biases"}
|
||||
|
||||
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input through feature extraction layers, returning spatial features."""
|
||||
x = self.patch_embed(x) # x input is (N, C, H, W)
|
||||
|
||||
x = self.layers[0](x)
|
||||
start_i = 1
|
||||
|
||||
for i in range(start_i, len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
x = layer(x)
|
||||
batch, _, channel = x.shape
|
||||
x = x.view(batch, self.patches_resolution[0] // 4, self.patches_resolution[1] // 4, channel)
|
||||
x = x.permute(0, 3, 1, 2)
|
||||
return self.neck(x)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform the forward pass through the TinyViT model, extracting features from the input image."""
|
||||
return self.forward_features(x)
|
||||
|
||||
def set_imgsz(self, imgsz: list[int] = [1024, 1024]):
|
||||
"""Set image size to make model compatible with different image sizes."""
|
||||
imgsz = [s // 4 for s in imgsz]
|
||||
self.patches_resolution = imgsz
|
||||
for i, layer in enumerate(self.layers):
|
||||
input_resolution = (
|
||||
imgsz[0] // (2 ** (i - 1 if i == 3 else i)),
|
||||
imgsz[1] // (2 ** (i - 1 if i == 3 else i)),
|
||||
)
|
||||
layer.input_resolution = input_resolution
|
||||
if layer.downsample is not None:
|
||||
layer.downsample.input_resolution = input_resolution
|
||||
if isinstance(layer, BasicLayer):
|
||||
for b in layer.blocks:
|
||||
b.input_resolution = input_resolution
|
||||
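# Editor's usage sketch (assumption, not part of the original file): set_imgsz only rewires
# the cached input resolutions, so it should be called before running inference at a
# non-default size; a 512x512 input then yields a 128x128 patch-token grid.
def _demo_set_imgsz():
    """Hypothetical example of resizing a SAM-style TinyViT to 512x512 inputs."""
    model = TinyViT(img_size=1024, embed_dims=(64, 128, 160, 320), num_heads=(2, 4, 5, 10))
    model.set_imgsz([512, 512])
    return model.patches_resolution  # [128, 128]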
354
ultralytics/models/sam/modules/transformer.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from ultralytics.nn.modules import MLPBlock
|
||||
|
||||
|
||||
class TwoWayTransformer(nn.Module):
|
||||
"""
|
||||
A Two-Way Transformer module for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer decoder that attends to an input image using queries with
|
||||
supplied positional embeddings. It's useful for tasks like object detection, image segmentation, and point
|
||||
cloud processing.
|
||||
|
||||
Attributes:
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
layers (nn.ModuleList): List of TwoWayAttentionBlock layers composing the transformer.
|
||||
final_attn_token_to_image (Attention): Final attention layer from queries to image.
|
||||
norm_final_attn (nn.LayerNorm): Layer normalization applied to final queries.
|
||||
|
||||
Methods:
|
||||
forward: Process image and point embeddings through the transformer.
|
||||
|
||||
Examples:
|
||||
>>> transformer = TwoWayTransformer(depth=6, embedding_dim=256, num_heads=8, mlp_dim=2048)
|
||||
>>> image_embedding = torch.randn(1, 256, 32, 32)
|
||||
>>> image_pe = torch.randn(1, 256, 32, 32)
|
||||
>>> point_embedding = torch.randn(1, 100, 256)
|
||||
>>> output_queries, output_image = transformer(image_embedding, image_pe, point_embedding)
|
||||
>>> print(output_queries.shape, output_image.shape)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
depth: int,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int,
|
||||
activation: type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a Two-Way Transformer for simultaneous attention to image and query points.
|
||||
|
||||
Args:
|
||||
depth (int): Number of layers in the transformer.
|
||||
embedding_dim (int): Channel dimension for input embeddings.
|
||||
num_heads (int): Number of heads for multihead attention. Must divide embedding_dim.
|
||||
mlp_dim (int): Internal channel dimension for the MLP block.
|
||||
activation (Type[nn.Module], optional): Activation function to use in the MLP block.
|
||||
attention_downsample_rate (int, optional): Downsampling rate for attention mechanism.
|
||||
"""
|
||||
super().__init__()
|
||||
self.depth = depth
|
||||
self.embedding_dim = embedding_dim
|
||||
self.num_heads = num_heads
|
||||
self.mlp_dim = mlp_dim
|
||||
self.layers = nn.ModuleList()
|
||||
|
||||
for i in range(depth):
|
||||
self.layers.append(
|
||||
TwoWayAttentionBlock(
|
||||
embedding_dim=embedding_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_dim=mlp_dim,
|
||||
activation=activation,
|
||||
attention_downsample_rate=attention_downsample_rate,
|
||||
skip_first_layer_pe=(i == 0),
|
||||
)
|
||||
)
|
||||
|
||||
self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
self.norm_final_attn = nn.LayerNorm(embedding_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image_embedding: torch.Tensor,
|
||||
image_pe: torch.Tensor,
|
||||
point_embedding: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Process image and point embeddings through the Two-Way Transformer.
|
||||
|
||||
Args:
|
||||
image_embedding (torch.Tensor): Image to attend to, with shape (B, embedding_dim, H, W).
|
||||
image_pe (torch.Tensor): Positional encoding to add to the image, with same shape as image_embedding.
|
||||
point_embedding (torch.Tensor): Embedding to add to query points, with shape (B, N_points, embedding_dim).
|
||||
|
||||
Returns:
|
||||
queries (torch.Tensor): Processed point embeddings with shape (B, N_points, embedding_dim).
|
||||
keys (torch.Tensor): Processed image embeddings with shape (B, H*W, embedding_dim).
|
||||
"""
|
||||
# BxCxHxW -> BxHWxC == B x N_image_tokens x C
|
||||
image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
|
||||
image_pe = image_pe.flatten(2).permute(0, 2, 1)
|
||||
|
||||
# Prepare queries
|
||||
queries = point_embedding
|
||||
keys = image_embedding
|
||||
|
||||
# Apply transformer blocks and final layernorm
|
||||
for layer in self.layers:
|
||||
queries, keys = layer(
|
||||
queries=queries,
|
||||
keys=keys,
|
||||
query_pe=point_embedding,
|
||||
key_pe=image_pe,
|
||||
)
|
||||
|
||||
# Apply the final attention layer from the points to the image
|
||||
q = queries + point_embedding
|
||||
k = keys + image_pe
|
||||
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm_final_attn(queries)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class TwoWayAttentionBlock(nn.Module):
|
||||
"""
|
||||
A two-way attention block for simultaneous attention to image and query points.
|
||||
|
||||
This class implements a specialized transformer block with four main layers: self-attention on sparse inputs,
|
||||
cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention of dense
|
||||
inputs to sparse inputs.
|
||||
|
||||
Attributes:
|
||||
self_attn (Attention): Self-attention layer for queries.
|
||||
norm1 (nn.LayerNorm): Layer normalization after self-attention.
|
||||
cross_attn_token_to_image (Attention): Cross-attention layer from queries to keys.
|
||||
norm2 (nn.LayerNorm): Layer normalization after token-to-image attention.
|
||||
mlp (MLPBlock): MLP block for transforming query embeddings.
|
||||
norm3 (nn.LayerNorm): Layer normalization after MLP block.
|
||||
norm4 (nn.LayerNorm): Layer normalization after image-to-token attention.
|
||||
cross_attn_image_to_token (Attention): Cross-attention layer from keys to queries.
|
||||
skip_first_layer_pe (bool): Whether to skip positional encoding in the first layer.
|
||||
|
||||
Methods:
|
||||
forward: Apply self-attention and cross-attention to queries and keys.
|
||||
|
||||
Examples:
|
||||
>>> embedding_dim, num_heads = 256, 8
|
||||
>>> block = TwoWayAttentionBlock(embedding_dim, num_heads)
|
||||
>>> queries = torch.randn(1, 100, embedding_dim)
|
||||
>>> keys = torch.randn(1, 1000, embedding_dim)
|
||||
>>> query_pe = torch.randn(1, 100, embedding_dim)
|
||||
>>> key_pe = torch.randn(1, 1000, embedding_dim)
|
||||
>>> processed_queries, processed_keys = block(queries, keys, query_pe, key_pe)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
mlp_dim: int = 2048,
|
||||
activation: type[nn.Module] = nn.ReLU,
|
||||
attention_downsample_rate: int = 2,
|
||||
skip_first_layer_pe: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a TwoWayAttentionBlock for simultaneous attention to image and query points.
|
||||
|
||||
This block implements a specialized transformer layer with four main components: self-attention on sparse
|
||||
inputs, cross-attention of sparse inputs to dense inputs, MLP block on sparse inputs, and cross-attention
|
||||
of dense inputs to sparse inputs.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): Channel dimension of the embeddings.
|
||||
num_heads (int): Number of attention heads in the attention layers.
|
||||
mlp_dim (int, optional): Hidden dimension of the MLP block.
|
||||
activation (Type[nn.Module], optional): Activation function for the MLP block.
|
||||
attention_downsample_rate (int, optional): Downsampling rate for the attention mechanism.
|
||||
skip_first_layer_pe (bool, optional): Whether to skip positional encoding in the first layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.self_attn = Attention(embedding_dim, num_heads)
|
||||
self.norm1 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
self.norm2 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.mlp = MLPBlock(embedding_dim, mlp_dim, activation)
|
||||
self.norm3 = nn.LayerNorm(embedding_dim)
|
||||
|
||||
self.norm4 = nn.LayerNorm(embedding_dim)
|
||||
self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
|
||||
|
||||
self.skip_first_layer_pe = skip_first_layer_pe
|
||||
|
||||
def forward(
|
||||
self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply two-way attention to process query and key embeddings in a transformer block.
|
||||
|
||||
Args:
|
||||
queries (torch.Tensor): Query embeddings with shape (B, N_queries, embedding_dim).
|
||||
keys (torch.Tensor): Key embeddings with shape (B, N_keys, embedding_dim).
|
||||
query_pe (torch.Tensor): Positional encodings for queries with same shape as queries.
|
||||
key_pe (torch.Tensor): Positional encodings for keys with same shape as keys.
|
||||
|
||||
Returns:
|
||||
queries (torch.Tensor): Processed query embeddings with shape (B, N_queries, embedding_dim).
|
||||
keys (torch.Tensor): Processed key embeddings with shape (B, N_keys, embedding_dim).
|
||||
"""
|
||||
# Self attention block
|
||||
if self.skip_first_layer_pe:
|
||||
queries = self.self_attn(q=queries, k=queries, v=queries)
|
||||
else:
|
||||
q = queries + query_pe
|
||||
attn_out = self.self_attn(q=q, k=q, v=queries)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm1(queries)
|
||||
|
||||
# Cross attention block, tokens attending to image embedding
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
|
||||
queries = queries + attn_out
|
||||
queries = self.norm2(queries)
|
||||
|
||||
# MLP block
|
||||
mlp_out = self.mlp(queries)
|
||||
queries = queries + mlp_out
|
||||
queries = self.norm3(queries)
|
||||
|
||||
# Cross attention block, image embedding attending to tokens
|
||||
q = queries + query_pe
|
||||
k = keys + key_pe
|
||||
attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
|
||||
keys = keys + attn_out
|
||||
keys = self.norm4(keys)
|
||||
|
||||
return queries, keys
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
An attention layer with downscaling capability for embedding size after projection.
|
||||
|
||||
This class implements a multi-head attention mechanism with the option to downsample the internal
|
||||
dimension of queries, keys, and values.
|
||||
|
||||
Attributes:
|
||||
embedding_dim (int): Dimensionality of input embeddings.
|
||||
kv_in_dim (int): Dimensionality of key and value inputs.
|
||||
internal_dim (int): Internal dimension after downsampling.
|
||||
num_heads (int): Number of attention heads.
|
||||
q_proj (nn.Linear): Linear projection for queries.
|
||||
k_proj (nn.Linear): Linear projection for keys.
|
||||
v_proj (nn.Linear): Linear projection for values.
|
||||
out_proj (nn.Linear): Linear projection for output.
|
||||
|
||||
Methods:
|
||||
_separate_heads: Separate input tensor into attention heads.
|
||||
_recombine_heads: Recombine separated attention heads.
|
||||
forward: Compute attention output for given query, key, and value tensors.
|
||||
|
||||
Examples:
|
||||
>>> attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
|
||||
>>> q = torch.randn(1, 100, 256)
|
||||
>>> k = v = torch.randn(1, 50, 256)
|
||||
>>> output = attn(q, k, v)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 100, 256])
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
num_heads: int,
|
||||
downsample_rate: int = 1,
|
||||
kv_in_dim: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the Attention module with specified dimensions and settings.
|
||||
|
||||
Args:
|
||||
embedding_dim (int): Dimensionality of input embeddings.
|
||||
num_heads (int): Number of attention heads.
|
||||
downsample_rate (int, optional): Factor by which internal dimensions are downsampled.
|
||||
kv_in_dim (int | None, optional): Dimensionality of key and value inputs. If None, uses embedding_dim.
|
||||
|
||||
Raises:
|
||||
AssertionError: If num_heads does not evenly divide the internal dim (embedding_dim / downsample_rate).
|
||||
"""
|
||||
super().__init__()
|
||||
self.embedding_dim = embedding_dim
|
||||
self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
|
||||
self.internal_dim = embedding_dim // downsample_rate
|
||||
self.num_heads = num_heads
|
||||
assert self.internal_dim % num_heads == 0, "num_heads must divide internal_dim (embedding_dim // downsample_rate)."
|
||||
|
||||
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
|
||||
self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
||||
self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
|
||||
self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
|
||||
|
||||
@staticmethod
|
||||
def _separate_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
|
||||
"""Separate the input tensor into the specified number of attention heads."""
|
||||
b, n, c = x.shape
|
||||
x = x.reshape(b, n, num_heads, c // num_heads)
|
||||
return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
|
||||
|
||||
@staticmethod
|
||||
def _recombine_heads(x: Tensor) -> Tensor:
|
||||
"""Recombine separated attention heads into a single tensor."""
|
||||
b, n_heads, n_tokens, c_per_head = x.shape
|
||||
x = x.transpose(1, 2)
|
||||
return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
|
||||
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply multi-head attention to query, key, and value tensors with optional downsampling.
|
||||
|
||||
Args:
|
||||
q (torch.Tensor): Query tensor with shape (B, N_q, embedding_dim).
|
||||
k (torch.Tensor): Key tensor with shape (B, N_k, embedding_dim).
|
||||
v (torch.Tensor): Value tensor with shape (B, N_k, embedding_dim).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Output tensor after attention with shape (B, N_q, embedding_dim).
|
||||
"""
|
||||
# Input projections
|
||||
q = self.q_proj(q)
|
||||
k = self.k_proj(k)
|
||||
v = self.v_proj(v)
|
||||
|
||||
# Separate into heads
|
||||
q = self._separate_heads(q, self.num_heads)
|
||||
k = self._separate_heads(k, self.num_heads)
|
||||
v = self._separate_heads(v, self.num_heads)
|
||||
|
||||
# Attention
|
||||
_, _, _, c_per_head = q.shape
|
||||
attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
|
||||
attn = attn / math.sqrt(c_per_head)
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
|
||||
# Get output
|
||||
out = attn @ v
|
||||
out = self._recombine_heads(out)
|
||||
return self.out_proj(out)
|
||||
388
ultralytics/models/sam/modules/utils.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def select_closest_cond_frames(frame_idx: int, cond_frame_outputs: dict[int, Any], max_cond_frame_num: int):
|
||||
"""
|
||||
Select the closest conditioning frames to a given frame index.
|
||||
|
||||
Args:
|
||||
frame_idx (int): Current frame index.
|
||||
cond_frame_outputs (dict[int, Any]): Dictionary of conditioning frame outputs keyed by frame indices.
|
||||
max_cond_frame_num (int): Maximum number of conditioning frames to select.
|
||||
|
||||
Returns:
|
||||
selected_outputs (dict[int, Any]): Selected items from cond_frame_outputs.
|
||||
unselected_outputs (dict[int, Any]): Items not selected from cond_frame_outputs.
|
||||
|
||||
Examples:
|
||||
>>> frame_idx = 5
|
||||
>>> cond_frame_outputs = {1: "a", 3: "b", 7: "c", 9: "d"}
|
||||
>>> max_cond_frame_num = 2
|
||||
>>> selected, unselected = select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num)
|
||||
>>> print(selected)
|
||||
{3: 'b', 7: 'c'}
|
||||
>>> print(unselected)
|
||||
{1: 'a', 9: 'd'}
|
||||
"""
|
||||
if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
|
||||
selected_outputs = cond_frame_outputs
|
||||
unselected_outputs = {}
|
||||
else:
|
||||
assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
|
||||
selected_outputs = {}
|
||||
|
||||
# The closest conditioning frame before `frame_idx` (if any)
|
||||
idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
|
||||
if idx_before is not None:
|
||||
selected_outputs[idx_before] = cond_frame_outputs[idx_before]
|
||||
|
||||
# The closest conditioning frame after `frame_idx` (if any)
|
||||
idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
|
||||
if idx_after is not None:
|
||||
selected_outputs[idx_after] = cond_frame_outputs[idx_after]
|
||||
|
||||
# Add other temporally closest conditioning frames until reaching a total
|
||||
# of `max_cond_frame_num` conditioning frames.
|
||||
num_remain = max_cond_frame_num - len(selected_outputs)
|
||||
inds_remain = sorted(
|
||||
(t for t in cond_frame_outputs if t not in selected_outputs),
|
||||
key=lambda x: abs(x - frame_idx),
|
||||
)[:num_remain]
|
||||
selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
|
||||
unselected_outputs = {t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs}
|
||||
|
||||
return selected_outputs, unselected_outputs
|
||||
|
||||
|
||||
def get_1d_sine_pe(pos_inds: torch.Tensor, dim: int, temperature: float = 10000):
|
||||
"""
|
||||
Generate 1D sinusoidal positional embeddings for given positions and dimensions.
|
||||
|
||||
Args:
|
||||
pos_inds (torch.Tensor): Position indices for which to generate embeddings.
|
||||
dim (int): Dimension of the positional embeddings. Should be an even number.
|
||||
temperature (float, optional): Scaling factor for the frequency of the sinusoidal functions.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Sinusoidal positional embeddings with shape (pos_inds.shape, dim).
|
||||
|
||||
Examples:
|
||||
>>> pos = torch.tensor([0, 1, 2, 3])
|
||||
>>> embeddings = get_1d_sine_pe(pos, 128)
|
||||
>>> embeddings.shape
|
||||
torch.Size([4, 128])
|
||||
"""
|
||||
pe_dim = dim // 2
|
||||
dim_t = torch.arange(pe_dim, dtype=pos_inds.dtype, device=pos_inds.device)
|
||||
dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
|
||||
|
||||
pos_embed = pos_inds.unsqueeze(-1) / dim_t
|
||||
pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def init_t_xy(end_x: int, end_y: int):
|
||||
"""
|
||||
Initialize 1D and 2D coordinate tensors for a grid of specified dimensions.
|
||||
|
||||
This function creates coordinate tensors for a grid with dimensions end_x × end_y. It generates a linear index tensor
|
||||
and corresponding x and y coordinate tensors.
|
||||
|
||||
Args:
|
||||
end_x (int): Width of the grid (number of columns).
|
||||
end_y (int): Height of the grid (number of rows).
|
||||
|
||||
Returns:
|
||||
t_x (torch.Tensor): X-coordinates for each position, with shape (end_x * end_y).
|
||||
t_y (torch.Tensor): Y-coordinates for each position, with shape (end_x * end_y).
|
||||
|
||||
Examples:
|
||||
>>> t_x, t_y = init_t_xy(3, 2)
|
||||
>>> print(t_x)
|
||||
tensor([0., 1., 2., 0., 1., 2.])
|
||||
>>> print(t_y)
|
||||
tensor([0., 0., 0., 1., 1., 1.])
|
||||
"""
|
||||
t = torch.arange(end_x * end_y, dtype=torch.float32)
|
||||
t_x = (t % end_x).float()
|
||||
t_y = torch.div(t, end_x, rounding_mode="floor").float()
|
||||
return t_x, t_y
|
||||
|
||||
|
||||
def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
|
||||
"""
|
||||
Compute axial complex exponential positional encodings for 2D spatial positions in a grid.
|
||||
|
||||
This function generates complex exponential positional encodings for a 2D grid of spatial positions,
|
||||
using separate frequency components for the x and y dimensions.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the positional encoding.
|
||||
end_x (int): Width of the 2D grid.
|
||||
end_y (int): Height of the 2D grid.
|
||||
theta (float, optional): Scaling factor for frequency computation.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Complex exponential positional encodings with shape (end_x*end_y, dim//2).
|
||||
|
||||
Examples:
|
||||
>>> dim, end_x, end_y = 128, 8, 8
|
||||
>>> freqs_cis = compute_axial_cis(dim, end_x, end_y)
|
||||
>>> freqs_cis.shape
|
||||
torch.Size([64, 64])
|
||||
"""
|
||||
freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
||||
freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
|
||||
|
||||
t_x, t_y = init_t_xy(end_x, end_y)
|
||||
freqs_x = torch.outer(t_x, freqs_x)
|
||||
freqs_y = torch.outer(t_y, freqs_y)
|
||||
freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
|
||||
freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
|
||||
return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
|
||||
|
||||
|
||||
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
||||
"""
|
||||
Reshape frequency tensor for broadcasting with input tensor.
|
||||
|
||||
Reshapes a frequency tensor to ensure dimensional compatibility for broadcasting with an input tensor.
|
||||
This function is typically used in positional encoding operations.
|
||||
|
||||
Args:
|
||||
freqs_cis (torch.Tensor): Frequency tensor with shape matching the last two dimensions of x.
|
||||
x (torch.Tensor): Input tensor to broadcast with.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Reshaped frequency tensor ready for broadcasting with the input tensor.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the shape of freqs_cis doesn't match the last two dimensions of x.
|
||||
"""
|
||||
ndim = x.ndim
|
||||
assert ndim > 1, "input tensor must have at least 2 dimensions"
|
||||
assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
|
||||
shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
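# Editor's sketch (assumption, not part of the original file): reshape_for_broadcast keeps
# only the last two axes of `x` and inserts singleton dimensions everywhere else, so the
# complex frequency table broadcasts over the batch and head dimensions.
def _demo_reshape_for_broadcast():
    """Hypothetical shape check for a (batch, heads, tokens, dim//2) complex tensor."""
    xq_ = torch.randn(2, 8, 16, 32, dtype=torch.complex64)  # 4x4 grid, dim=64 -> 32 complex pairs
    freqs_cis = compute_axial_cis(dim=64, end_x=4, end_y=4)  # shape (16, 32)
    return reshape_for_broadcast(freqs_cis, xq_).shape  # torch.Size([1, 1, 16, 32])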
def apply_rotary_enc(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
repeat_freqs_k: bool = False,
|
||||
):
|
||||
"""
|
||||
Apply rotary positional encoding to query and key tensors.
|
||||
|
||||
This function applies rotary positional encoding (RoPE) to query and key tensors using complex-valued frequency
|
||||
components. RoPE is a technique that injects relative position information into self-attention mechanisms.
|
||||
|
||||
Args:
|
||||
xq (torch.Tensor): Query tensor to encode with positional information.
|
||||
xk (torch.Tensor): Key tensor to encode with positional information.
|
||||
freqs_cis (torch.Tensor): Complex-valued frequency components for rotary encoding with shape matching the
|
||||
last two dimensions of xq.
|
||||
repeat_freqs_k (bool, optional): Whether to repeat frequency components along sequence length dimension
|
||||
to match key sequence length.
|
||||
|
||||
Returns:
|
||||
xq_out (torch.Tensor): Query tensor with rotary positional encoding applied.
|
||||
xk_out (torch.Tensor): Key tensor with rotary positional encoding applied, or original xk if xk is empty.
|
||||
|
||||
Examples:
|
||||
>>> import torch
|
||||
>>> xq = torch.randn(2, 8, 16, 64) # [batch, heads, seq_len, dim]
|
||||
>>> xk = torch.randn(2, 8, 16, 64)
|
||||
>>> freqs_cis = compute_axial_cis(64, 4, 4) # For a 4x4 spatial grid with dim=64
|
||||
>>> q_encoded, k_encoded = apply_rotary_enc(xq, xk, freqs_cis)
|
||||
"""
|
||||
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) if xk.shape[-2] != 0 else None
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
|
||||
if xk_ is None:
|
||||
# No keys to rotate, due to dropout
|
||||
return xq_out.type_as(xq).to(xq.device), xk
|
||||
# Repeat freqs along seq_len dim to match k seq_len
|
||||
if repeat_freqs_k:
|
||||
r = xk_.shape[-2] // xq_.shape[-2]
|
||||
freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
|
||||
return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
|
||||
|
||||
|
||||
def window_partition(x: torch.Tensor, window_size: int):
|
||||
"""
|
||||
Partition input tensor into non-overlapping windows with padding if needed.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with shape (B, H, W, C).
|
||||
window_size (int): Size of each window.
|
||||
|
||||
Returns:
|
||||
windows (torch.Tensor): Partitioned windows with shape (B * num_windows, window_size, window_size, C).
|
||||
padded_h_w (tuple[int, int]): Padded height and width before partition.
|
||||
|
||||
Examples:
|
||||
>>> x = torch.randn(1, 16, 16, 3)
|
||||
>>> windows, (Hp, Wp) = window_partition(x, window_size=4)
|
||||
>>> print(windows.shape, Hp, Wp)
|
||||
torch.Size([16, 4, 4, 3]) 16 16
|
||||
"""
|
||||
B, H, W, C = x.shape
|
||||
|
||||
pad_h = (window_size - H % window_size) % window_size
|
||||
pad_w = (window_size - W % window_size) % window_size
|
||||
if pad_h > 0 or pad_w > 0:
|
||||
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
|
||||
Hp, Wp = H + pad_h, W + pad_w
|
||||
|
||||
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||
return windows, (Hp, Wp)
|
||||
|
||||
|
||||
def window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]):
|
||||
"""
|
||||
Unpartition windowed sequences into original sequences and remove padding.
|
||||
|
||||
This function reverses the windowing process, reconstructing the original input from windowed segments
|
||||
and removing any padding that was added during the windowing process.
|
||||
|
||||
Args:
|
||||
windows (torch.Tensor): Input tensor of windowed sequences with shape (B * num_windows, window_size,
|
||||
window_size, C), where B is the batch size, num_windows is the number of windows, window_size is
|
||||
the size of each window, and C is the number of channels.
|
||||
window_size (int): Size of each window.
|
||||
pad_hw (tuple[int, int]): Padded height and width (Hp, Wp) of the input before windowing.
|
||||
hw (tuple[int, int]): Original height and width (H, W) of the input before padding and windowing.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Unpartitioned sequences with shape (B, H, W, C), where B is the batch size, H and W
|
||||
are the original height and width, and C is the number of channels.
|
||||
|
||||
Examples:
|
||||
>>> windows = torch.rand(32, 8, 8, 64) # 32 windows of size 8x8 with 64 channels
|
||||
>>> pad_hw = (16, 16) # Padded height and width
|
||||
>>> hw = (15, 14) # Original height and width
|
||||
>>> x = window_unpartition(windows, window_size=8, pad_hw=pad_hw, hw=hw)
|
||||
>>> print(x.shape)
|
||||
torch.Size([1, 15, 14, 64])
|
||||
"""
|
||||
Hp, Wp = pad_hw
|
||||
H, W = hw
|
||||
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
|
||||
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
|
||||
|
||||
if Hp > H or Wp > W:
|
||||
x = x[:, :H, :W, :].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
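
# Illustrative addition (not part of the original Ultralytics file): a minimal, hedged
# round-trip sketch showing that window_partition pads and tiles a feature map and that
# window_unpartition (defined below) restores it exactly. Shapes are arbitrary assumptions.
def _example_window_roundtrip():
    import torch

    x = torch.randn(2, 14, 14, 32)  # (B, H, W, C); 14 is not divisible by the window size
    windows, (Hp, Wp) = window_partition(x, window_size=8)  # pads H and W from 14 to 16
    assert windows.shape == (8, 8, 8, 32) and (Hp, Wp) == (16, 16)  # 2 images * 2 * 2 windows
    x_restored = window_unpartition(windows, window_size=8, pad_hw=(Hp, Wp), hw=(14, 14))
    assert torch.equal(x_restored, x)  # padding is cropped away, values are unchanged
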
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Extract relative positional embeddings based on query and key sizes.
|
||||
|
||||
Args:
|
||||
q_size (int): Size of the query.
|
||||
k_size (int): Size of the key.
|
||||
rel_pos (torch.Tensor): Relative position embeddings with shape (L, C), where L is the maximum relative
|
||||
distance and C is the embedding dimension.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Extracted positional embeddings according to relative positions, with shape (q_size,
|
||||
k_size, C).
|
||||
|
||||
Examples:
|
||||
>>> q_size, k_size = 8, 16
|
||||
>>> rel_pos = torch.randn(31, 64) # 31 = 2 * max(8, 16) - 1
|
||||
>>> extracted_pos = get_rel_pos(q_size, k_size, rel_pos)
|
||||
>>> print(extracted_pos.shape)
|
||||
torch.Size([8, 16, 64])
|
||||
"""
|
||||
max_rel_dist = int(2 * max(q_size, k_size) - 1)
|
||||
# Interpolate the relative position embeddings if they do not cover max_rel_dist
if rel_pos.shape[0] != max_rel_dist:
|
||||
rel_pos_resized = F.interpolate(
|
||||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
|
||||
size=max_rel_dist,
|
||||
mode="linear",
|
||||
)
|
||||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
|
||||
else:
|
||||
rel_pos_resized = rel_pos
|
||||
|
||||
# Scale the coordinates by the shorter length when the q and k sizes differ
|
||||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
|
||||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
|
||||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
|
||||
|
||||
return rel_pos_resized[relative_coords.long()]
|
||||
|
||||
|
||||
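
# Illustrative addition (not part of the original file): a tiny, hedged numeric check of
# get_rel_pos. With an embedding table whose row i simply stores the value i, entry (q, k)
# of the result should hold the embedding for relative offset (q - k) + (k_size - 1).
def _example_get_rel_pos():
    import torch

    q_size = k_size = 4
    rel_pos = torch.arange(2 * k_size - 1).float().unsqueeze(-1)  # (7, 1): row i stores i
    table = get_rel_pos(q_size, k_size, rel_pos)  # (4, 4, 1), no interpolation needed
    assert table[0, 0, 0] == 3 and table[3, 0, 0] == 6 and table[0, 3, 0] == 0
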
def add_decomposed_rel_pos(
|
||||
attn: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
rel_pos_h: torch.Tensor,
|
||||
rel_pos_w: torch.Tensor,
|
||||
q_size: tuple[int, int],
|
||||
k_size: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Add decomposed Relative Positional Embeddings to the attention map.
|
||||
|
||||
This function calculates and applies decomposed Relative Positional Embeddings as described in the MVITv2
|
||||
paper. It enhances the attention mechanism by incorporating spatial relationships between query and key
|
||||
positions.
|
||||
|
||||
Args:
|
||||
attn (torch.Tensor): Attention map with shape (B, q_h * q_w, k_h * k_w).
|
||||
q (torch.Tensor): Query tensor in the attention layer with shape (B, q_h * q_w, C).
|
||||
rel_pos_h (torch.Tensor): Relative position embeddings for height axis with shape (Lh, C).
|
||||
rel_pos_w (torch.Tensor): Relative position embeddings for width axis with shape (Lw, C).
|
||||
q_size (tuple[int, int]): Spatial sequence size of query q as (q_h, q_w).
|
||||
k_size (tuple[int, int]): Spatial sequence size of key k as (k_h, k_w).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Updated attention map with added relative positional embeddings, shape
|
||||
(B, q_h * q_w, k_h * k_w).
|
||||
|
||||
Examples:
|
||||
>>> B, C, q_h, q_w, k_h, k_w = 1, 64, 8, 8, 8, 8
|
||||
>>> attn = torch.rand(B, q_h * q_w, k_h * k_w)
|
||||
>>> q = torch.rand(B, q_h * q_w, C)
|
||||
>>> rel_pos_h = torch.rand(2 * max(q_h, k_h) - 1, C)
|
||||
>>> rel_pos_w = torch.rand(2 * max(q_w, k_w) - 1, C)
|
||||
>>> q_size, k_size = (q_h, q_w), (k_h, k_w)
|
||||
>>> updated_attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size)
|
||||
>>> print(updated_attn.shape)
|
||||
torch.Size([1, 64, 64])
|
||||
|
||||
References:
|
||||
https://github.com/facebookresearch/mvit/blob/main/mvit/models/attention.py
|
||||
"""
|
||||
q_h, q_w = q_size
|
||||
k_h, k_w = k_size
|
||||
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
|
||||
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
|
||||
|
||||
B, _, dim = q.shape
|
||||
r_q = q.reshape(B, q_h, q_w, dim)
|
||||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
|
||||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
|
||||
|
||||
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
|
||||
B, q_h * q_w, k_h * k_w
|
||||
)
|
||||
|
||||
return attn
|
||||
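
# Illustrative addition (not part of the original file): a hedged sanity check that with
# all-zero relative-position tables, add_decomposed_rel_pos leaves the attention map
# unchanged; the chosen sizes are arbitrary assumptions.
def _example_add_decomposed_rel_pos():
    import torch

    B, C, q_h, q_w, k_h, k_w = 1, 16, 4, 4, 4, 4
    attn = torch.rand(B, q_h * q_w, k_h * k_w)
    q = torch.rand(B, q_h * q_w, C)
    zeros_h = torch.zeros(2 * max(q_h, k_h) - 1, C)
    zeros_w = torch.zeros(2 * max(q_w, k_w) - 1, C)
    out = add_decomposed_rel_pos(attn.clone(), q, zeros_h, zeros_w, (q_h, q_w), (k_h, k_w))
    assert out.shape == (B, q_h * q_w, k_h * k_w) and torch.allclose(out, attn)
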
2042
ultralytics/models/sam/predict.py
Normal file
File diff suppressed because it is too large
1
ultralytics/models/utils/__init__.py
Normal file
@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
478
ultralytics/models/utils/loss.py
Normal file
@@ -0,0 +1,478 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ultralytics.utils.loss import FocalLoss, VarifocalLoss
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
|
||||
from .ops import HungarianMatcher
|
||||
|
||||
|
||||
class DETRLoss(nn.Module):
|
||||
"""
|
||||
DETR (DEtection TRansformer) Loss class for calculating various loss components.
|
||||
|
||||
This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
|
||||
DETR object detection model.
|
||||
|
||||
Attributes:
|
||||
nc (int): Number of classes.
|
||||
loss_gain (dict[str, float]): Coefficients for different loss components.
|
||||
aux_loss (bool): Whether to compute auxiliary losses.
|
||||
use_fl (bool): Whether to use FocalLoss.
|
||||
use_vfl (bool): Whether to use VarifocalLoss.
|
||||
use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
|
||||
uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
|
||||
matcher (HungarianMatcher): Object to compute matching cost and indices.
|
||||
fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
|
||||
vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
|
||||
device (torch.device): Device on which tensors are stored.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nc: int = 80,
|
||||
loss_gain: dict[str, float] | None = None,
|
||||
aux_loss: bool = True,
|
||||
use_fl: bool = True,
|
||||
use_vfl: bool = False,
|
||||
use_uni_match: bool = False,
|
||||
uni_match_ind: int = 0,
|
||||
gamma: float = 1.5,
|
||||
alpha: float = 0.25,
|
||||
):
|
||||
"""
|
||||
Initialize DETR loss function with customizable components and gains.
|
||||
|
||||
Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
|
||||
losses and various loss types.
|
||||
|
||||
Args:
|
||||
nc (int): Number of classes.
|
||||
loss_gain (dict[str, float], optional): Coefficients for different loss components.
|
||||
aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
|
||||
use_fl (bool): Whether to use FocalLoss.
|
||||
use_vfl (bool): Whether to use VarifocalLoss.
|
||||
use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
|
||||
uni_match_ind (int): Index of fixed layer for uni_match.
|
||||
gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
|
||||
alpha (float): The balancing factor used to address class imbalance.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
if loss_gain is None:
|
||||
loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
|
||||
self.nc = nc
|
||||
self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
|
||||
self.loss_gain = loss_gain
|
||||
self.aux_loss = aux_loss
|
||||
self.fl = FocalLoss(gamma, alpha) if use_fl else None
|
||||
self.vfl = VarifocalLoss(gamma, alpha) if use_vfl else None
|
||||
|
||||
self.use_uni_match = use_uni_match
|
||||
self.uni_match_ind = uni_match_ind
|
||||
self.device = None
|
||||
|
||||
def _get_loss_class(
|
||||
self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute classification loss based on predictions, target values, and ground truth scores.
|
||||
|
||||
Args:
|
||||
pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
|
||||
targets (torch.Tensor): Target class indices with shape (B, N).
|
||||
gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
|
||||
num_gts (int): Number of ground truth objects.
|
||||
postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Dictionary containing classification loss value.
|
||||
|
||||
Notes:
|
||||
The function supports different classification loss types:
|
||||
- Varifocal Loss (if self.vfl is True and num_gts > 0)
|
||||
- Focal Loss (if self.fl is True)
|
||||
- BCE Loss (default fallback)
|
||||
"""
|
||||
# Logits: [b, query, num_classes], gt_class: list[[n, 1]]
|
||||
name_class = f"loss_class{postfix}"
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
# one_hot = F.one_hot(targets, self.nc + 1)[..., :-1] # (bs, num_queries, num_classes)
|
||||
one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
|
||||
one_hot.scatter_(2, targets.unsqueeze(-1), 1)
|
||||
one_hot = one_hot[..., :-1]
|
||||
gt_scores = gt_scores.view(bs, nq, 1) * one_hot
|
||||
|
||||
if self.fl:
|
||||
if num_gts and self.vfl:
|
||||
loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
|
||||
else:
|
||||
loss_cls = self.fl(pred_scores, one_hot.float())
|
||||
loss_cls /= max(num_gts, 1) / nq
|
||||
else:
|
||||
loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum() # YOLO CLS loss
|
||||
|
||||
return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}
|
||||
|
||||
def _get_loss_bbox(
|
||||
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
|
||||
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
|
||||
postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Dictionary containing:
|
||||
- loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
|
||||
- loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.
|
||||
|
||||
Notes:
|
||||
If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
|
||||
"""
|
||||
# Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
|
||||
name_bbox = f"loss_bbox{postfix}"
|
||||
name_giou = f"loss_giou{postfix}"
|
||||
|
||||
loss = {}
|
||||
if len(gt_bboxes) == 0:
|
||||
loss[name_bbox] = torch.tensor(0.0, device=self.device)
|
||||
loss[name_giou] = torch.tensor(0.0, device=self.device)
|
||||
return loss
|
||||
|
||||
loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
|
||||
loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
|
||||
loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
|
||||
loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
|
||||
return {k: v.squeeze() for k, v in loss.items()}
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
# def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
|
||||
# # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
|
||||
# name_mask = f'loss_mask{postfix}'
|
||||
# name_dice = f'loss_dice{postfix}'
|
||||
#
|
||||
# loss = {}
|
||||
# if sum(len(a) for a in gt_mask) == 0:
|
||||
# loss[name_mask] = torch.tensor(0., device=self.device)
|
||||
# loss[name_dice] = torch.tensor(0., device=self.device)
|
||||
# return loss
|
||||
#
|
||||
# num_gts = len(gt_mask)
|
||||
# src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
|
||||
# src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
|
||||
# # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
|
||||
# loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
|
||||
# torch.tensor([num_gts], dtype=torch.float32))
|
||||
# loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
|
||||
# return loss
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
# @staticmethod
|
||||
# def _dice_loss(inputs, targets, num_gts):
|
||||
# inputs = F.sigmoid(inputs).flatten(1)
|
||||
# targets = targets.flatten(1)
|
||||
# numerator = 2 * (inputs * targets).sum(1)
|
||||
# denominator = inputs.sum(-1) + targets.sum(-1)
|
||||
# loss = 1 - (numerator + 1) / (denominator + 1)
|
||||
# return loss.sum() / num_gts
|
||||
|
||||
def _get_loss_aux(
|
||||
self,
|
||||
pred_bboxes: torch.Tensor,
|
||||
pred_scores: torch.Tensor,
|
||||
gt_bboxes: torch.Tensor,
|
||||
gt_cls: torch.Tensor,
|
||||
gt_groups: list[int],
|
||||
match_indices: list[tuple] | None = None,
|
||||
postfix: str = "",
|
||||
masks: torch.Tensor | None = None,
|
||||
gt_mask: torch.Tensor | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Get auxiliary losses for intermediate decoder layers.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
|
||||
pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
|
||||
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
||||
gt_cls (torch.Tensor): Ground truth classes.
|
||||
gt_groups (list[int]): Number of ground truths per image.
|
||||
match_indices (list[tuple], optional): Pre-computed matching indices.
|
||||
postfix (str, optional): String to append to loss names.
|
||||
masks (torch.Tensor, optional): Predicted masks if using segmentation.
|
||||
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Dictionary of auxiliary losses.
|
||||
"""
|
||||
# NOTE: loss class, bbox, giou, mask, dice
|
||||
loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
|
||||
if match_indices is None and self.use_uni_match:
|
||||
match_indices = self.matcher(
|
||||
pred_bboxes[self.uni_match_ind],
|
||||
pred_scores[self.uni_match_ind],
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=masks[self.uni_match_ind] if masks is not None else None,
|
||||
gt_mask=gt_mask,
|
||||
)
|
||||
for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
|
||||
aux_masks = masks[i] if masks is not None else None
|
||||
loss_ = self._get_loss(
|
||||
aux_bboxes,
|
||||
aux_scores,
|
||||
gt_bboxes,
|
||||
gt_cls,
|
||||
gt_groups,
|
||||
masks=aux_masks,
|
||||
gt_mask=gt_mask,
|
||||
postfix=postfix,
|
||||
match_indices=match_indices,
|
||||
)
|
||||
loss[0] += loss_[f"loss_class{postfix}"]
|
||||
loss[1] += loss_[f"loss_bbox{postfix}"]
|
||||
loss[2] += loss_[f"loss_giou{postfix}"]
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
|
||||
# loss[3] += loss_[f'loss_mask{postfix}']
|
||||
# loss[4] += loss_[f'loss_dice{postfix}']
|
||||
|
||||
loss = {
|
||||
f"loss_class_aux{postfix}": loss[0],
|
||||
f"loss_bbox_aux{postfix}": loss[1],
|
||||
f"loss_giou_aux{postfix}": loss[2],
|
||||
}
|
||||
# if masks is not None and gt_mask is not None:
|
||||
# loss[f'loss_mask_aux{postfix}'] = loss[3]
|
||||
# loss[f'loss_dice_aux{postfix}'] = loss[4]
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
|
||||
"""
|
||||
Extract batch indices, source indices, and destination indices from match indices.
|
||||
|
||||
Args:
|
||||
match_indices (list[tuple]): List of tuples containing matched indices.
|
||||
|
||||
Returns:
|
||||
batch_idx (tuple[torch.Tensor, torch.Tensor]): Tuple containing (batch_idx, src_idx).
|
||||
dst_idx (torch.Tensor): Destination indices.
|
||||
"""
|
||||
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
|
||||
src_idx = torch.cat([src for (src, _) in match_indices])
|
||||
dst_idx = torch.cat([dst for (_, dst) in match_indices])
|
||||
return (batch_idx, src_idx), dst_idx
|
||||
|
||||
def _get_assigned_bboxes(
|
||||
self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Assign predicted bounding boxes to ground truth bounding boxes based on match indices.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
||||
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
||||
match_indices (list[tuple]): List of tuples containing matched indices.
|
||||
|
||||
Returns:
|
||||
pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
|
||||
gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
|
||||
"""
|
||||
pred_assigned = torch.cat(
|
||||
[
|
||||
t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (i, _) in zip(pred_bboxes, match_indices)
|
||||
]
|
||||
)
|
||||
gt_assigned = torch.cat(
|
||||
[
|
||||
t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
|
||||
for t, (_, j) in zip(gt_bboxes, match_indices)
|
||||
]
|
||||
)
|
||||
return pred_assigned, gt_assigned
|
||||
|
||||
def _get_loss(
|
||||
self,
|
||||
pred_bboxes: torch.Tensor,
|
||||
pred_scores: torch.Tensor,
|
||||
gt_bboxes: torch.Tensor,
|
||||
gt_cls: torch.Tensor,
|
||||
gt_groups: list[int],
|
||||
masks: torch.Tensor | None = None,
|
||||
gt_mask: torch.Tensor | None = None,
|
||||
postfix: str = "",
|
||||
match_indices: list[tuple] | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate losses for a single prediction layer.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes.
|
||||
pred_scores (torch.Tensor): Predicted class scores.
|
||||
gt_bboxes (torch.Tensor): Ground truth bounding boxes.
|
||||
gt_cls (torch.Tensor): Ground truth classes.
|
||||
gt_groups (list[int]): Number of ground truths per image.
|
||||
masks (torch.Tensor, optional): Predicted masks if using segmentation.
|
||||
gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
|
||||
postfix (str, optional): String to append to loss names.
|
||||
match_indices (list[tuple], optional): Pre-computed matching indices.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Dictionary of losses.
|
||||
"""
|
||||
if match_indices is None:
|
||||
match_indices = self.matcher(
|
||||
pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
|
||||
)
|
||||
|
||||
idx, gt_idx = self._get_index(match_indices)
|
||||
pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]
|
||||
|
||||
bs, nq = pred_scores.shape[:2]
|
||||
targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
|
||||
targets[idx] = gt_cls[gt_idx]
|
||||
|
||||
gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
|
||||
if len(gt_bboxes):
|
||||
gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)
|
||||
|
||||
return {
|
||||
**self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
|
||||
**self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
|
||||
# **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
|
||||
}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_bboxes: torch.Tensor,
|
||||
pred_scores: torch.Tensor,
|
||||
batch: dict[str, Any],
|
||||
postfix: str = "",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Calculate loss for predicted bounding boxes and scores.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
|
||||
pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
|
||||
batch (dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
|
||||
postfix (str, optional): Postfix for loss names.
|
||||
**kwargs (Any): Additional arguments, may include 'match_indices'.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).
|
||||
|
||||
Notes:
|
||||
Uses last elements of pred_bboxes and pred_scores for main loss, and the rest for auxiliary losses if
|
||||
self.aux_loss is True.
|
||||
"""
|
||||
self.device = pred_bboxes.device
|
||||
match_indices = kwargs.get("match_indices", None)
|
||||
gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]
|
||||
|
||||
total_loss = self._get_loss(
|
||||
pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
|
||||
)
|
||||
|
||||
if self.aux_loss:
|
||||
total_loss.update(
|
||||
self._get_loss_aux(
|
||||
pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
|
||||
)
|
||||
)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
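
# Illustrative addition (not part of the original file): a minimal, hedged sketch of how
# DETRLoss is driven. The decoder depth, query count, and ground-truth layout below are
# arbitrary assumptions chosen only to exercise the forward pass.
def _example_detr_loss():
    import torch

    criterion = DETRLoss(nc=80, aux_loss=True)
    pred_bboxes = torch.rand(3, 2, 10, 4)  # (decoder layers, batch, queries, normalized xywh)
    pred_scores = torch.randn(3, 2, 10, 80)  # (decoder layers, batch, queries, class logits)
    batch = {
        "cls": torch.tensor([3, 17, 42]),  # class index of each ground-truth box
        "bboxes": torch.rand(3, 4),  # normalized xywh of each ground-truth box
        "gt_groups": [2, 1],  # ground-truth boxes per image
    }
    losses = criterion(pred_bboxes, pred_scores, batch)
    # Last decoder layer -> loss_class/loss_bbox/loss_giou; earlier layers -> *_aux variants
    print({k: float(v) for k, v in losses.items()})
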
class RTDETRDetectionLoss(DETRLoss):
    """
    Real-Time DEtection TRansformer (RT-DETR) detection loss class that extends DETRLoss.

    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well
    as an additional denoising training loss when denoising metadata is provided.
    """
|
||||
|
||||
def forward(
|
||||
self,
|
||||
preds: tuple[torch.Tensor, torch.Tensor],
|
||||
batch: dict[str, Any],
|
||||
dn_bboxes: torch.Tensor | None = None,
|
||||
dn_scores: torch.Tensor | None = None,
|
||||
dn_meta: dict[str, Any] | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Forward pass to compute detection loss with optional denoising loss.
|
||||
|
||||
Args:
|
||||
preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
|
||||
batch (dict[str, Any]): Batch data containing ground truth information.
|
||||
dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
|
||||
dn_scores (torch.Tensor, optional): Denoising scores.
|
||||
dn_meta (dict[str, Any], optional): Metadata for denoising.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
|
||||
"""
|
||||
pred_bboxes, pred_scores = preds
|
||||
total_loss = super().forward(pred_bboxes, pred_scores, batch)
|
||||
|
||||
# Check for denoising metadata to compute denoising training loss
|
||||
if dn_meta is not None:
|
||||
dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
|
||||
assert len(batch["gt_groups"]) == len(dn_pos_idx)
|
||||
|
||||
# Get the match indices for denoising
|
||||
match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])
|
||||
|
||||
# Compute the denoising training loss
|
||||
dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
|
||||
total_loss.update(dn_loss)
|
||||
else:
|
||||
# If no denoising metadata is provided, set denoising loss to zero
|
||||
total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})
|
||||
|
||||
return total_loss
|
||||
|
||||
@staticmethod
|
||||
def get_dn_match_indices(
|
||||
dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
|
||||
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Get match indices for denoising.
|
||||
|
||||
Args:
|
||||
dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
|
||||
dn_num_group (int): Number of denoising groups.
|
||||
gt_groups (list[int]): List of integers representing number of ground truths per image.
|
||||
|
||||
Returns:
|
||||
(list[tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
|
||||
"""
|
||||
dn_match_indices = []
|
||||
idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
|
||||
for i, num_gt in enumerate(gt_groups):
|
||||
if num_gt > 0:
|
||||
gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
|
||||
gt_idx = gt_idx.repeat(dn_num_group)
|
||||
assert len(dn_pos_idx[i]) == len(gt_idx), (
|
||||
f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
|
||||
)
|
||||
dn_match_indices.append((dn_pos_idx[i], gt_idx))
|
||||
else:
|
||||
dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
|
||||
return dn_match_indices
|
||||
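
# Illustrative addition (not part of the original file): a small, hedged example of the
# denoising match indices for 2 images with 2 and 1 ground-truth boxes and 3 denoising
# groups. The dn_pos_idx values mirror the layout produced by get_cdn_group.
def _example_dn_match_indices():
    import torch

    dn_pos_idx = [torch.tensor([0, 1, 2, 3, 4, 5]), torch.tensor([0, 2, 4])]
    indices = RTDETRDetectionLoss.get_dn_match_indices(dn_pos_idx, dn_num_group=3, gt_groups=[2, 1])
    # Image 0 repeats its two GT indices per group; image 1 repeats its single (offset) GT index
    assert indices[0][1].tolist() == [0, 1, 0, 1, 0, 1]
    assert indices[1][1].tolist() == [2, 2, 2]
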
319
ultralytics/models/utils/ops.py
Normal file
@@ -0,0 +1,319 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
from ultralytics.utils.metrics import bbox_iou
|
||||
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh
|
||||
|
||||
|
||||
class HungarianMatcher(nn.Module):
|
||||
"""
|
||||
A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.
|
||||
|
||||
HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
|
||||
function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
|
||||
used in end-to-end object detection models like DETR.
|
||||
|
||||
Attributes:
|
||||
cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
|
||||
components.
|
||||
use_fl (bool): Whether to use Focal Loss for classification cost calculation.
|
||||
with_mask (bool): Whether the model makes mask predictions.
|
||||
num_sample_points (int): Number of sample points used in mask cost calculation.
|
||||
alpha (float): Alpha factor in Focal Loss calculation.
|
||||
gamma (float): Gamma factor in Focal Loss calculation.
|
||||
|
||||
Methods:
|
||||
forward: Compute optimal assignment between predictions and ground truths for a batch.
|
||||
_cost_mask: Compute mask cost and dice cost if masks are predicted.
|
||||
|
||||
Examples:
|
||||
Initialize a HungarianMatcher with custom cost gains
|
||||
>>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
|
||||
|
||||
Perform matching between predictions and ground truth
|
||||
>>> pred_boxes = torch.rand(2, 100, 4) # batch_size=2, num_queries=100
|
||||
>>> pred_scores = torch.rand(2, 100, 80) # 80 classes
|
||||
>>> gt_boxes = torch.rand(10, 4) # 10 ground truth boxes
|
||||
>>> gt_classes = torch.randint(0, 80, (10,))
|
||||
>>> gt_groups = [5, 5] # 5 GT boxes per image
|
||||
>>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cost_gain: dict[str, float] | None = None,
|
||||
use_fl: bool = True,
|
||||
with_mask: bool = False,
|
||||
num_sample_points: int = 12544,
|
||||
alpha: float = 0.25,
|
||||
gamma: float = 2.0,
|
||||
):
|
||||
"""
|
||||
Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
|
||||
components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
|
||||
use_fl (bool): Whether to use Focal Loss for classification cost calculation.
|
||||
with_mask (bool): Whether the model makes mask predictions.
|
||||
num_sample_points (int): Number of sample points used in mask cost calculation.
|
||||
alpha (float): Alpha factor in Focal Loss calculation.
|
||||
gamma (float): Gamma factor in Focal Loss calculation.
|
||||
"""
|
||||
super().__init__()
|
||||
if cost_gain is None:
|
||||
cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
|
||||
self.cost_gain = cost_gain
|
||||
self.use_fl = use_fl
|
||||
self.with_mask = with_mask
|
||||
self.num_sample_points = num_sample_points
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_bboxes: torch.Tensor,
|
||||
pred_scores: torch.Tensor,
|
||||
gt_bboxes: torch.Tensor,
|
||||
gt_cls: torch.Tensor,
|
||||
gt_groups: list[int],
|
||||
masks: torch.Tensor | None = None,
|
||||
gt_mask: list[torch.Tensor] | None = None,
|
||||
) -> list[tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Compute optimal assignment between predictions and ground truth using Hungarian algorithm.
|
||||
|
||||
This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
|
||||
mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.
|
||||
|
||||
Args:
|
||||
pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
|
||||
pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
|
||||
num_classes).
|
||||
gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
|
||||
gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
|
||||
gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
|
||||
masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
|
||||
gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).
|
||||
|
||||
Returns:
|
||||
(list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
|
||||
(index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)
|
||||
and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
|
||||
For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
|
||||
"""
|
||||
bs, nq, nc = pred_scores.shape
|
||||
|
||||
if sum(gt_groups) == 0:
|
||||
return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]
|
||||
|
||||
# Flatten to compute cost matrices in batch format
|
||||
pred_scores = pred_scores.detach().view(-1, nc)
|
||||
pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
|
||||
pred_bboxes = pred_bboxes.detach().view(-1, 4)
|
||||
|
||||
# Compute classification cost
|
||||
pred_scores = pred_scores[:, gt_cls]
|
||||
if self.use_fl:
|
||||
neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
|
||||
pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
|
||||
cost_class = pos_cost_class - neg_cost_class
|
||||
else:
|
||||
cost_class = -pred_scores
|
||||
|
||||
# Compute L1 cost between boxes
|
||||
cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1) # (bs*num_queries, num_gt)
|
||||
|
||||
# Compute GIoU cost between boxes, (bs*num_queries, num_gt)
|
||||
cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)
|
||||
|
||||
# Combine costs into final cost matrix
|
||||
C = (
|
||||
self.cost_gain["class"] * cost_class
|
||||
+ self.cost_gain["bbox"] * cost_bbox
|
||||
+ self.cost_gain["giou"] * cost_giou
|
||||
)
|
||||
|
||||
# Add mask costs if available
|
||||
if self.with_mask:
|
||||
C += self._cost_mask(bs, gt_groups, masks, gt_mask)
|
||||
|
||||
# Set invalid values (NaNs and infinities) to 0
|
||||
C[C.isnan() | C.isinf()] = 0.0
|
||||
|
||||
C = C.view(bs, nq, -1).cpu()
|
||||
indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
|
||||
gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0) # (idx for queries, idx for gt)
|
||||
return [
|
||||
(torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
|
||||
for k, (i, j) in enumerate(indices)
|
||||
]
|
||||
|
||||
# This function is for future RT-DETR Segment models
|
||||
# def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
|
||||
# assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
|
||||
# # all masks share the same set of points for efficient matching
|
||||
# sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
|
||||
# sample_points = 2.0 * sample_points - 1.0
|
||||
#
|
||||
# out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
|
||||
# out_mask = out_mask.flatten(0, 1)
|
||||
#
|
||||
# tgt_mask = torch.cat(gt_mask).unsqueeze(1)
|
||||
# sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
|
||||
# tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
|
||||
#
|
||||
# with torch.amp.autocast("cuda", enabled=False):
|
||||
# # binary cross entropy cost
|
||||
# pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
|
||||
# neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
|
||||
# cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
|
||||
# cost_mask /= self.num_sample_points
|
||||
#
|
||||
# # dice cost
|
||||
# out_mask = F.sigmoid(out_mask)
|
||||
# numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
|
||||
# denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
|
||||
# cost_dice = 1 - (numerator + 1) / (denominator + 1)
|
||||
#
|
||||
# C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
|
||||
# return C
|
||||
|
||||
|
||||
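
# Illustrative addition (not part of the original file): a hedged sketch showing that when
# two queries reproduce the ground-truth boxes exactly and score their classes highest, the
# matcher assigns exactly those query-to-GT pairs. All shapes and values are assumptions.
def _example_hungarian_matcher():
    import torch

    matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
    gt_bboxes = torch.tensor([[0.3, 0.3, 0.2, 0.2], [0.7, 0.7, 0.2, 0.2]])  # normalized xywh
    gt_cls = torch.tensor([1, 3])
    pred_bboxes = torch.zeros(1, 4, 4)
    pred_bboxes[0, :2] = gt_bboxes  # queries 0 and 1 coincide with the two GT boxes
    pred_scores = torch.full((1, 4, 5), -10.0)  # low logits everywhere ...
    pred_scores[0, 0, 1] = 10.0  # ... except query 0 -> class 1
    pred_scores[0, 1, 3] = 10.0  # ... and query 1 -> class 3
    src, dst = matcher(pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups=[2])[0]
    assert set(zip(src.tolist(), dst.tolist())) == {(0, 0), (1, 1)}
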
def get_cdn_group(
|
||||
batch: dict[str, Any],
|
||||
num_classes: int,
|
||||
num_queries: int,
|
||||
class_embed: torch.Tensor,
|
||||
num_dn: int = 100,
|
||||
cls_noise_ratio: float = 0.5,
|
||||
box_noise_scale: float = 1.0,
|
||||
training: bool = False,
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
|
||||
"""
|
||||
Generate contrastive denoising training group with positive and negative samples from ground truths.
|
||||
|
||||
This function creates denoising queries for contrastive denoising training by adding noise to ground truth
|
||||
bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch dictionary containing 'gt_cls' (torch.Tensor with shape (num_gts,)),
|
||||
'gt_bboxes' (torch.Tensor with shape (num_gts, 4)), and 'gt_groups' (list[int]) indicating number of
|
||||
ground truths per image.
|
||||
num_classes (int): Total number of object classes.
|
||||
num_queries (int): Number of object queries.
|
||||
class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
|
||||
num_dn (int): Number of denoising queries to generate.
|
||||
cls_noise_ratio (float): Noise ratio for class labels.
|
||||
box_noise_scale (float): Noise scale for bounding box coordinates.
|
||||
training (bool): Whether model is in training mode.
|
||||
|
||||
Returns:
|
||||
padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
|
||||
padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
|
||||
attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
|
||||
dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.
|
||||
|
||||
Examples:
|
||||
Generate denoising group for training
|
||||
>>> batch = {
|
||||
... "cls": torch.tensor([0, 1, 2]),
|
||||
... "bboxes": torch.rand(3, 4),
|
||||
... "batch_idx": torch.tensor([0, 0, 1]),
|
||||
... "gt_groups": [2, 1],
|
||||
... }
|
||||
>>> class_embed = torch.rand(80, 256) # 80 classes, 256 embedding dim
|
||||
>>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
|
||||
"""
|
||||
if (not training) or num_dn <= 0 or batch is None:
|
||||
return None, None, None, None
|
||||
gt_groups = batch["gt_groups"]
|
||||
total_num = sum(gt_groups)
|
||||
max_nums = max(gt_groups)
|
||||
if max_nums == 0:
|
||||
return None, None, None, None
|
||||
|
||||
num_group = num_dn // max_nums
|
||||
num_group = 1 if num_group == 0 else num_group
|
||||
# Pad gt to max_num of a batch
|
||||
bs = len(gt_groups)
|
||||
gt_cls = batch["cls"] # (bs*num, )
|
||||
gt_bbox = batch["bboxes"] # bs*num, 4
|
||||
b_idx = batch["batch_idx"]
|
||||
|
||||
# Each group has positive and negative queries
|
||||
dn_cls = gt_cls.repeat(2 * num_group) # (2*num_group*bs*num, )
|
||||
dn_bbox = gt_bbox.repeat(2 * num_group, 1) # 2*num_group*bs*num, 4
|
||||
dn_b_idx = b_idx.repeat(2 * num_group).view(-1) # (2*num_group*bs*num, )
|
||||
|
||||
# Positive and negative mask
|
||||
# (bs*num*num_group, ), the second total_num*num_group part as negative samples
|
||||
neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num
|
||||
|
||||
if cls_noise_ratio > 0:
|
||||
# Apply class label noise to half of the samples
|
||||
mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
|
||||
idx = torch.nonzero(mask).squeeze(-1)
|
||||
# Randomly assign new class labels
|
||||
new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
|
||||
dn_cls[idx] = new_label
|
||||
|
||||
if box_noise_scale > 0:
|
||||
known_bbox = xywh2xyxy(dn_bbox)
|
||||
|
||||
diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale # 2*num_group*bs*num, 4
|
||||
|
||||
rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
|
||||
rand_part = torch.rand_like(dn_bbox)
|
||||
rand_part[neg_idx] += 1.0
|
||||
rand_part *= rand_sign
|
||||
known_bbox += rand_part * diff
|
||||
known_bbox.clip_(min=0.0, max=1.0)
|
||||
dn_bbox = xyxy2xywh(known_bbox)
|
||||
dn_bbox = torch.logit(dn_bbox, eps=1e-6) # inverse sigmoid
|
||||
|
||||
num_dn = int(max_nums * 2 * num_group) # total denoising queries
|
||||
dn_cls_embed = class_embed[dn_cls] # bs*num * 2 * num_group, 256
|
||||
padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
|
||||
padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)
|
||||
|
||||
map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
|
||||
pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)
|
||||
|
||||
map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
|
||||
padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
|
||||
padding_bbox[(dn_b_idx, map_indices)] = dn_bbox
|
||||
|
||||
tgt_size = num_dn + num_queries
|
||||
attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
|
||||
# Matching queries cannot see the reconstruction (denoising) queries
attn_mask[num_dn:, :num_dn] = True
# Reconstruction queries from different denoising groups cannot see each other
|
||||
for i in range(num_group):
|
||||
if i == 0:
|
||||
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
|
||||
if i == num_group - 1:
|
||||
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
|
||||
else:
|
||||
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
|
||||
attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
|
||||
dn_meta = {
|
||||
"dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
|
||||
"dn_num_group": num_group,
|
||||
"dn_num_split": [num_dn, num_queries],
|
||||
}
|
||||
|
||||
return (
|
||||
padding_cls.to(class_embed.device),
|
||||
padding_bbox.to(class_embed.device),
|
||||
attn_mask.to(class_embed.device),
|
||||
dn_meta,
|
||||
)
|
||||
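
# Illustrative addition (not part of the original file): a hedged, toy-sized sketch of the
# tensors returned by get_cdn_group for 1 image with 1 ground-truth box, num_dn=4 and 10
# object queries. All sizes are assumptions used only to show the expected shapes.
def _example_get_cdn_group():
    import torch

    batch = {
        "cls": torch.tensor([0]),
        "bboxes": torch.tensor([[0.5, 0.5, 0.2, 0.2]]),
        "batch_idx": torch.tensor([0]),
        "gt_groups": [1],
    }
    class_embed = torch.rand(80, 256)  # 80 classes, 256-dim embeddings
    padding_cls, padding_bbox, attn_mask, dn_meta = get_cdn_group(
        batch, num_classes=80, num_queries=10, class_embed=class_embed, num_dn=4, training=True
    )
    # num_group = num_dn // max GT count = 4, and each group holds a positive and a negative copy
    assert padding_cls.shape == (1, 8, 256) and padding_bbox.shape == (1, 8, 4)
    assert attn_mask.shape == (18, 18)  # 8 denoising queries + 10 matching queries
    assert dn_meta["dn_num_group"] == 4 and dn_meta["dn_num_split"] == [8, 10]
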
7
ultralytics/models/yolo/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo import classify, detect, obb, pose, segment, world, yoloe

from .model import YOLO, YOLOE, YOLOWorld

__all__ = "classify", "segment", "detect", "pose", "obb", "world", "yoloe", "YOLO", "YOLOWorld", "YOLOE"
BIN
ultralytics/models/yolo/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/__pycache__/model.cpython-310.pyc
Normal file
Binary file not shown.
7
ultralytics/models/yolo/classify/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.classify.predict import ClassificationPredictor
from ultralytics.models.yolo.classify.train import ClassificationTrainer
from ultralytics.models.yolo.classify.val import ClassificationValidator

__all__ = "ClassificationPredictor", "ClassificationTrainer", "ClassificationValidator"
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
ultralytics/models/yolo/classify/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
93
ultralytics/models/yolo/classify/predict.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from ultralytics.data.augment import classify_transforms
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
|
||||
|
||||
class ClassificationPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on a classification model.
|
||||
|
||||
This predictor handles the specific requirements of classification models, including preprocessing images
|
||||
and postprocessing predictions to generate classification results.
|
||||
|
||||
Attributes:
|
||||
args (dict): Configuration arguments for the predictor.
|
||||
|
||||
Methods:
|
||||
preprocess: Convert input images to model-compatible format.
|
||||
postprocess: Process model predictions into Results objects.
|
||||
|
||||
Notes:
|
||||
- Torchvision classification models can also be passed to the 'model' argument, e.g. model='resnet18'.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.classify import ClassificationPredictor
|
||||
>>> args = dict(model="yolo11n-cls.pt", source=ASSETS)
|
||||
>>> predictor = ClassificationPredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize the ClassificationPredictor with the specified configuration and set task to 'classify'.
|
||||
|
||||
This constructor initializes a ClassificationPredictor instance, which extends BasePredictor for classification
|
||||
tasks. It ensures the task is set to 'classify' regardless of input configuration.
|
||||
|
||||
Args:
|
||||
cfg (dict): Default configuration dictionary containing prediction settings.
|
||||
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
||||
_callbacks (list, optional): List of callback functions to be executed during prediction.
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = "classify"
|
||||
|
||||
def setup_source(self, source):
|
||||
"""Set up source and inference mode and classify transforms."""
|
||||
super().setup_source(source)
|
||||
updated = (
|
||||
self.model.model.transforms.transforms[0].size != max(self.imgsz)
|
||||
if hasattr(self.model.model, "transforms") and hasattr(self.model.model.transforms.transforms[0], "size")
|
||||
else False
|
||||
)
|
||||
self.transforms = (
|
||||
classify_transforms(self.imgsz) if updated or not self.model.pt else self.model.model.transforms
|
||||
)
|
||||
|
||||
def preprocess(self, img):
|
||||
"""Convert input images to model-compatible tensor format with appropriate normalization."""
|
||||
if not isinstance(img, torch.Tensor):
|
||||
img = torch.stack(
|
||||
[self.transforms(Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))) for im in img], dim=0
|
||||
)
|
||||
img = (img if isinstance(img, torch.Tensor) else torch.from_numpy(img)).to(self.model.device)
|
||||
return img.half() if self.model.fp16 else img.float() # Convert uint8 to fp16/32
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Process predictions to return Results objects with classification probabilities.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
img (torch.Tensor): Input images after preprocessing.
|
||||
orig_imgs (list[np.ndarray] | torch.Tensor): Original images before preprocessing.
|
||||
|
||||
Returns:
|
||||
(list[Results]): List of Results objects containing classification results for each image.
|
||||
"""
|
||||
if not isinstance(orig_imgs, list): # Input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
preds = preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||
return [
|
||||
Results(orig_img, path=img_path, names=self.model.names, probs=pred)
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
|
||||
]
|
||||
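
# Illustrative addition (not part of the original file): a hedged end-to-end sketch using the
# high-level YOLO API, which routes classification weights to this predictor. It assumes the
# yolo11n-cls.pt weights and the bundled bus.jpg asset are available (or downloadable).
def _example_classification_predict():
    from ultralytics import YOLO
    from ultralytics.utils import ASSETS

    model = YOLO("yolo11n-cls.pt")
    results = model(ASSETS / "bus.jpg")
    print(results[0].probs.top1, float(results[0].probs.top1conf))
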
223
ultralytics/models/yolo/classify/train.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import ClassificationModel
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
from ultralytics.utils.torch_utils import is_parallel, strip_optimizer, torch_distributed_zero_first
|
||||
|
||||
|
||||
class ClassificationTrainer(BaseTrainer):
|
||||
"""
|
||||
A trainer class extending BaseTrainer for training image classification models.
|
||||
|
||||
This trainer handles the training process for image classification tasks, supporting both YOLO classification models
|
||||
and torchvision models with comprehensive dataset handling and validation.
|
||||
|
||||
Attributes:
|
||||
model (ClassificationModel): The classification model to be trained.
|
||||
data (dict[str, Any]): Dictionary containing dataset information including class names and number of classes.
|
||||
loss_names (list[str]): Names of the loss functions used during training.
|
||||
validator (ClassificationValidator): Validator instance for model evaluation.
|
||||
|
||||
Methods:
|
||||
set_model_attributes: Set the model's class names from the loaded dataset.
|
||||
get_model: Return a modified PyTorch model configured for training.
|
||||
setup_model: Load, create or download model for classification.
|
||||
build_dataset: Create a ClassificationDataset instance.
|
||||
get_dataloader: Return PyTorch DataLoader with transforms for image preprocessing.
|
||||
preprocess_batch: Preprocess a batch of images and classes.
|
||||
progress_string: Return a formatted string showing training progress.
|
||||
get_validator: Return an instance of ClassificationValidator.
|
||||
label_loss_items: Return a loss dict with labelled training loss items.
|
||||
final_eval: Evaluate trained model and save validation results.
|
||||
plot_training_samples: Plot training samples with their annotations.
|
||||
|
||||
Examples:
|
||||
Initialize and train a classification model
|
||||
>>> from ultralytics.models.yolo.classify import ClassificationTrainer
|
||||
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10", epochs=3)
|
||||
>>> trainer = ClassificationTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a ClassificationTrainer object.
|
||||
|
||||
Args:
|
||||
cfg (dict[str, Any], optional): Default configuration dictionary containing training parameters.
|
||||
overrides (dict[str, Any], optional): Dictionary of parameter overrides for the default configuration.
|
||||
_callbacks (list[Any], optional): List of callback functions to be executed during training.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "classify"
|
||||
if overrides.get("imgsz") is None:
|
||||
overrides["imgsz"] = 224
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set the YOLO model's class names from the loaded dataset."""
|
||||
self.model.names = self.data["names"]
|
||||
|
||||
def get_model(self, cfg=None, weights=None, verbose: bool = True):
|
||||
"""
|
||||
Return a modified PyTorch model configured for training YOLO classification.
|
||||
|
||||
Args:
|
||||
cfg (Any, optional): Model configuration.
|
||||
weights (Any, optional): Pre-trained model weights.
|
||||
verbose (bool, optional): Whether to display model information.
|
||||
|
||||
Returns:
|
||||
(ClassificationModel): Configured PyTorch model for classification.
|
||||
"""
|
||||
model = ClassificationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
for m in model.modules():
|
||||
if not self.args.pretrained and hasattr(m, "reset_parameters"):
|
||||
m.reset_parameters()
|
||||
if isinstance(m, torch.nn.Dropout) and self.args.dropout:
|
||||
m.p = self.args.dropout # set dropout
|
||||
for p in model.parameters():
|
||||
p.requires_grad = True # for training
|
||||
return model
|
||||
|
||||
def setup_model(self):
|
||||
"""
|
||||
Load, create or download model for classification tasks.
|
||||
|
||||
Returns:
|
||||
(Any): Model checkpoint if applicable, otherwise None.
|
||||
"""
|
||||
import torchvision # scope for faster 'import ultralytics'
|
||||
|
||||
if str(self.model) in torchvision.models.__dict__:
|
||||
self.model = torchvision.models.__dict__[self.model](
|
||||
weights="IMAGENET1K_V1" if self.args.pretrained else None
|
||||
)
|
||||
ckpt = None
|
||||
else:
|
||||
ckpt = super().setup_model()
|
||||
ClassificationModel.reshape_outputs(self.model, self.data["nc"])
|
||||
return ckpt
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch=None):
|
||||
"""
|
||||
Create a ClassificationDataset instance given an image path and mode.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the dataset images.
|
||||
mode (str, optional): Dataset mode ('train', 'val', or 'test').
|
||||
batch (Any, optional): Batch information (unused in this implementation).
|
||||
|
||||
Returns:
|
||||
(ClassificationDataset): Dataset for the specified mode.
|
||||
"""
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=mode == "train", prefix=mode)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
||||
"""
|
||||
Return PyTorch DataLoader with transforms to preprocess images.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
batch_size (int, optional): Number of images per batch.
|
||||
rank (int, optional): Process rank for distributed training.
|
||||
mode (str, optional): 'train', 'val', or 'test' mode.
|
||||
|
||||
Returns:
|
||||
(torch.utils.data.DataLoader): DataLoader for the specified dataset and mode.
|
||||
"""
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode)
|
||||
|
||||
loader = build_dataloader(dataset, batch_size, self.args.workers, rank=rank, drop_last=self.args.compile)
|
||||
# Attach inference transforms
|
||||
if mode != "train":
|
||||
if is_parallel(self.model):
|
||||
self.model.module.transforms = loader.dataset.torch_transforms
|
||||
else:
|
||||
self.model.transforms = loader.dataset.torch_transforms
|
||||
return loader
|
||||
|
||||
def preprocess_batch(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""Preprocess a batch of images and classes."""
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
return batch
|
||||
|
||||
def progress_string(self) -> str:
|
||||
"""Return a formatted string showing training progress."""
|
||||
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
||||
"Epoch",
|
||||
"GPU_mem",
|
||||
*self.loss_names,
|
||||
"Instances",
|
||||
"Size",
|
||||
)
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of ClassificationValidator for validation."""
|
||||
self.loss_names = ["loss"]
|
||||
return yolo.classify.ClassificationValidator(
|
||||
self.test_loader, self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
|
||||
def label_loss_items(self, loss_items: torch.Tensor | None = None, prefix: str = "train"):
|
||||
"""
|
||||
Return a loss dict with labelled training loss items tensor.
|
||||
|
||||
Args:
|
||||
loss_items (torch.Tensor, optional): Loss tensor items.
|
||||
prefix (str, optional): Prefix to prepend to loss names.
|
||||
|
||||
Returns:
|
||||
keys (list[str]): List of loss keys if loss_items is None.
|
||||
loss_dict (dict[str, float]): Dictionary of loss items if loss_items is provided.
|
||||
"""
|
||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||
if loss_items is None:
|
||||
return keys
|
||||
loss_items = [round(float(loss_items), 5)]
|
||||
return dict(zip(keys, loss_items))
|
||||
|
||||
def final_eval(self):
|
||||
"""Evaluate trained model and save validation results."""
|
||||
for f in self.last, self.best:
|
||||
if f.exists():
|
||||
strip_optimizer(f) # strip optimizers
|
||||
if f is self.best:
|
||||
LOGGER.info(f"\nValidating {f}...")
|
||||
self.validator.args.data = self.args.data
|
||||
self.validator.args.plots = self.args.plots
|
||||
self.metrics = self.validator(model=f)
|
||||
self.metrics.pop("fitness", None)
|
||||
self.run_callbacks("on_fit_epoch_end")
|
||||
|
||||
def plot_training_samples(self, batch: dict[str, torch.Tensor], ni: int):
|
||||
"""
|
||||
Plot training samples with their annotations.
|
||||
|
||||
Args:
|
||||
batch (dict[str, torch.Tensor]): Batch containing images and class labels.
|
||||
ni (int): Number of iterations (used to name the saved plot file).
|
||||
"""
|
||||
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
|
||||
plot_images(
|
||||
labels=batch,
|
||||
fname=self.save_dir / f"train_batch{ni}.jpg",
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
214
ultralytics/models/yolo/classify/val.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data import ClassificationDataset, build_dataloader
|
||||
from ultralytics.engine.validator import BaseValidator
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.metrics import ClassifyMetrics, ConfusionMatrix
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
class ClassificationValidator(BaseValidator):
|
||||
"""
|
||||
A class extending the BaseValidator class for validation based on a classification model.
|
||||
|
||||
This validator handles the validation process for classification models, including metrics calculation,
|
||||
confusion matrix generation, and visualization of results.
|
||||
|
||||
Attributes:
|
||||
targets (list[torch.Tensor]): Ground truth class labels.
|
||||
pred (list[torch.Tensor]): Model predictions.
|
||||
metrics (ClassifyMetrics): Object to calculate and store classification metrics.
|
||||
names (dict): Mapping of class indices to class names.
|
||||
nc (int): Number of classes.
|
||||
confusion_matrix (ConfusionMatrix): Matrix to evaluate model performance across classes.
|
||||
|
||||
Methods:
|
||||
get_desc: Return a formatted string summarizing classification metrics.
|
||||
init_metrics: Initialize confusion matrix, class names, and tracking containers.
|
||||
preprocess: Preprocess input batch by moving data to device.
|
||||
update_metrics: Update running metrics with model predictions and batch targets.
|
||||
finalize_metrics: Finalize metrics including confusion matrix and processing speed.
|
||||
postprocess: Extract the primary prediction from model output.
|
||||
get_stats: Calculate and return a dictionary of metrics.
|
||||
build_dataset: Create a ClassificationDataset instance for validation.
|
||||
get_dataloader: Build and return a data loader for classification validation.
|
||||
print_results: Print evaluation metrics for the classification model.
|
||||
plot_val_samples: Plot validation image samples with their ground truth labels.
|
||||
plot_predictions: Plot images with their predicted class labels.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.classify import ClassificationValidator
|
||||
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
|
||||
>>> validator = ClassificationValidator(args=args)
|
||||
>>> validator()
|
||||
|
||||
Notes:
|
||||
Torchvision classification models can also be passed to the 'model' argument, i.e. model='resnet18'.
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize ClassificationValidator with dataloader, save directory, and other parameters.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
save_dir (str | Path, optional): Directory to save results.
|
||||
args (dict, optional): Arguments containing model and validation configuration.
|
||||
_callbacks (list, optional): List of callback functions to be called during validation.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.classify import ClassificationValidator
|
||||
>>> args = dict(model="yolo11n-cls.pt", data="imagenet10")
|
||||
>>> validator = ClassificationValidator(args=args)
|
||||
>>> validator()
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.targets = None
|
||||
self.pred = None
|
||||
self.args.task = "classify"
|
||||
self.metrics = ClassifyMetrics()
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return a formatted string summarizing classification metrics."""
|
||||
return ("%22s" + "%11s" * 2) % ("classes", "top1_acc", "top5_acc")
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""Initialize confusion matrix, class names, and tracking containers for predictions and targets."""
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.pred = []
|
||||
self.targets = []
|
||||
self.confusion_matrix = ConfusionMatrix(names=model.names)
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Preprocess input batch by moving data to device and converting to appropriate dtype."""
|
||||
batch["img"] = batch["img"].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
batch["img"] = batch["img"].half() if self.args.half else batch["img"].float()
|
||||
batch["cls"] = batch["cls"].to(self.device, non_blocking=self.device.type == "cuda")
|
||||
return batch
|
||||
|
||||
def update_metrics(self, preds: torch.Tensor, batch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update running metrics with model predictions and batch targets.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Model predictions, typically logits or probabilities for each class.
|
||||
batch (dict): Batch data containing images and class labels.
|
||||
|
||||
Notes:
|
||||
This method appends the top-N predictions (sorted by confidence in descending order) to the
|
||||
prediction list for later evaluation. N is limited to the minimum of 5 and the number of classes.
|
||||
"""
|
||||
n5 = min(len(self.names), 5)
|
||||
self.pred.append(preds.argsort(1, descending=True)[:, :n5].type(torch.int32).cpu())
|
||||
self.targets.append(batch["cls"].type(torch.int32).cpu())
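# Worked sketch of the top-n selection above (hypothetical logits for 3 classes, so n5 == 3):
# >>> preds = torch.tensor([[0.1, 2.0, 0.5]])
# >>> preds.argsort(1, descending=True)[:, :3]
# tensor([[1, 2, 0]])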
|
||||
|
||||
def finalize_metrics(self) -> None:
|
||||
"""
|
||||
Finalize metrics including confusion matrix and processing speed.
|
||||
|
||||
Notes:
|
||||
This method processes the accumulated predictions and targets to generate the confusion matrix,
|
||||
optionally plots it, and updates the metrics object with speed information.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
>>> validator.pred = [torch.tensor([[0, 1, 2]])] # Top-3 predictions for one sample
|
||||
>>> validator.targets = [torch.tensor([0])] # Ground truth class
|
||||
>>> validator.finalize_metrics()
|
||||
>>> print(validator.metrics.confusion_matrix) # Access the confusion matrix
|
||||
"""
|
||||
self.confusion_matrix.process_cls_preds(self.pred, self.targets)
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.save_dir = self.save_dir
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
|
||||
def postprocess(self, preds: torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor]) -> torch.Tensor:
|
||||
"""Extract the primary prediction from model output if it's in a list or tuple format."""
|
||||
return preds[0] if isinstance(preds, (list, tuple)) else preds
|
||||
|
||||
def get_stats(self) -> dict[str, float]:
|
||||
"""Calculate and return a dictionary of metrics by processing targets and predictions."""
|
||||
self.metrics.process(self.targets, self.pred)
|
||||
return self.metrics.results_dict
|
||||
|
||||
def build_dataset(self, img_path: str) -> ClassificationDataset:
|
||||
"""Create a ClassificationDataset instance for validation."""
|
||||
return ClassificationDataset(root=img_path, args=self.args, augment=False, prefix=self.args.split)
|
||||
|
||||
def get_dataloader(self, dataset_path: Path | str, batch_size: int) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Build and return a data loader for classification validation.
|
||||
|
||||
Args:
|
||||
dataset_path (str | Path): Path to the dataset directory.
|
||||
batch_size (int): Number of samples per batch.
|
||||
|
||||
Returns:
|
||||
(torch.utils.data.DataLoader): DataLoader object for the classification validation dataset.
|
||||
"""
|
||||
dataset = self.build_dataset(dataset_path)
|
||||
return build_dataloader(dataset, batch_size, self.args.workers, rank=-1)
|
||||
|
||||
def print_results(self) -> None:
|
||||
"""Print evaluation metrics for the classification model."""
|
||||
pf = "%22s" + "%11.3g" * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ("all", self.metrics.top1, self.metrics.top5))
|
||||
|
||||
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot validation image samples with their ground truth labels.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Dictionary containing batch data with 'img' (images) and 'cls' (class labels).
|
||||
ni (int): Batch index used for naming the output file.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
>>> batch = {"img": torch.rand(16, 3, 224, 224), "cls": torch.randint(0, 10, (16,))}
|
||||
>>> validator.plot_val_samples(batch, 0)
|
||||
"""
|
||||
batch["batch_idx"] = torch.arange(batch["img"].shape[0]) # add batch index for plotting
|
||||
plot_images(
|
||||
labels=batch,
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: torch.Tensor, ni: int) -> None:
|
||||
"""
|
||||
Plot images with their predicted class labels and save the visualization.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch data containing images and other information.
|
||||
preds (torch.Tensor): Model predictions with shape (batch_size, num_classes).
|
||||
ni (int): Batch index used for naming the output file.
|
||||
|
||||
Examples:
|
||||
>>> validator = ClassificationValidator()
|
||||
>>> batch = {"img": torch.rand(16, 3, 224, 224)}
|
||||
>>> preds = torch.rand(16, 10) # 16 images, 10 classes
|
||||
>>> validator.plot_predictions(batch, preds, 0)
|
||||
"""
|
||||
batched_preds = dict(
|
||||
img=batch["img"],
|
||||
batch_idx=torch.arange(batch["img"].shape[0]),
|
||||
cls=torch.argmax(preds, dim=1),
|
||||
)
|
||||
plot_images(
|
||||
batched_preds,
|
||||
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
) # pred
|
||||
7
ultralytics/models/yolo/detect/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from .predict import DetectionPredictor
|
||||
from .train import DetectionTrainer
|
||||
from .val import DetectionValidator
|
||||
|
||||
__all__ = "DetectionPredictor", "DetectionTrainer", "DetectionValidator"
|
||||
Binary file not shown.
Binary file not shown.
BIN
ultralytics/models/yolo/detect/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/detect/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
125
ultralytics/models/yolo/detect/predict.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from ultralytics.engine.predictor import BasePredictor
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.utils import nms, ops
|
||||
|
||||
|
||||
class DetectionPredictor(BasePredictor):
|
||||
"""
|
||||
A class extending the BasePredictor class for prediction based on a detection model.
|
||||
|
||||
This predictor specializes in object detection tasks, processing model outputs into meaningful detection results
|
||||
with bounding boxes and class predictions.
|
||||
|
||||
Attributes:
|
||||
args (namespace): Configuration arguments for the predictor.
|
||||
model (nn.Module): The detection model used for inference.
|
||||
batch (list): Batch of images and metadata for processing.
|
||||
|
||||
Methods:
|
||||
postprocess: Process raw model predictions into detection results.
|
||||
construct_results: Build Results objects from processed predictions.
|
||||
construct_result: Create a single Result object from a prediction.
|
||||
get_obj_feats: Extract object features from the feature maps.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.detect import DetectionPredictor
|
||||
>>> args = dict(model="yolo11n.pt", source=ASSETS)
|
||||
>>> predictor = DetectionPredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
|
||||
def postprocess(self, preds, img, orig_imgs, **kwargs):
|
||||
"""
|
||||
Post-process predictions and return a list of Results objects.
|
||||
|
||||
This method applies non-maximum suppression to raw model predictions and prepares them for visualization and
|
||||
further analysis.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
img (torch.Tensor): Processed input image tensor in model input format.
|
||||
orig_imgs (torch.Tensor | list): Original input images before preprocessing.
|
||||
**kwargs (Any): Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
(list): List of Results objects containing the post-processed predictions.
|
||||
|
||||
Examples:
|
||||
>>> predictor = DetectionPredictor(overrides=dict(model="yolo11n.pt"))
|
||||
>>> results = predictor.predict("path/to/image.jpg")
|
||||
>>> processed_results = predictor.postprocess(preds, img, orig_imgs)
|
||||
"""
|
||||
save_feats = getattr(self, "_feats", None) is not None
|
||||
preds = nms.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
self.args.classes,
|
||||
self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
nc=0 if self.args.task == "detect" else len(self.model.names),
|
||||
end2end=getattr(self.model, "end2end", False),
|
||||
rotated=self.args.task == "obb",
|
||||
return_idxs=save_feats,
|
||||
)
|
||||
|
||||
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
|
||||
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
|
||||
|
||||
if save_feats:
|
||||
obj_feats = self.get_obj_feats(self._feats, preds[1])
|
||||
preds = preds[0]
|
||||
|
||||
results = self.construct_results(preds, img, orig_imgs, **kwargs)
|
||||
|
||||
if save_feats:
|
||||
for r, f in zip(results, obj_feats):
|
||||
r.feats = f # add object features to results
|
||||
|
||||
return results
|
||||
|
||||
def get_obj_feats(self, feat_maps, idxs):
|
||||
"""Extract object features from the feature maps."""
|
||||
import torch
|
||||
|
||||
s = min(x.shape[1] for x in feat_maps) # find shortest vector length
|
||||
obj_feats = torch.cat(
|
||||
[x.permute(0, 2, 3, 1).reshape(x.shape[0], -1, s, x.shape[1] // s).mean(dim=-1) for x in feat_maps], dim=1
|
||||
) # mean reduce all vectors to same length
|
||||
return [feats[idx] if idx.shape[0] else [] for feats, idx in zip(obj_feats, idxs)] # for each img in batch
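# Shape sketch for the reduction above: a (B, C, H, W) feature map is permuted and reshaped to
# (B, H*W, s, C//s) and mean-reduced over the last dim to (B, H*W, s), where s is the smallest
# channel count among the maps. Concatenating over scales gives one s-dim vector per anchor
# location, and indexing with the kept-box indices returns one feature vector per final detection.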
|
||||
|
||||
def construct_results(self, preds, img, orig_imgs):
|
||||
"""
|
||||
Construct a list of Results objects from model predictions.
|
||||
|
||||
Args:
|
||||
preds (list[torch.Tensor]): List of predicted bounding boxes and scores for each image.
|
||||
img (torch.Tensor): Batch of preprocessed images used for inference.
|
||||
orig_imgs (list[np.ndarray]): List of original images before preprocessing.
|
||||
|
||||
Returns:
|
||||
(list[Results]): List of Results objects containing detection information for each image.
|
||||
"""
|
||||
return [
|
||||
self.construct_result(pred, img, orig_img, img_path)
|
||||
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0])
|
||||
]
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct a single Results object from one image prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): Predicted boxes and scores with shape (N, 6) where N is the number of detections.
|
||||
img (torch.Tensor): Preprocessed image tensor used for inference.
|
||||
orig_img (np.ndarray): Original image before preprocessing.
|
||||
img_path (str): Path to the original image file.
|
||||
|
||||
Returns:
|
||||
(Results): Results object containing the original image, image path, class names, and scaled bounding boxes.
|
||||
"""
|
||||
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
|
||||
return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])
|
||||
236
ultralytics/models/yolo/detect/train.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import random
|
||||
from copy import copy
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ultralytics.data import build_dataloader, build_yolo_dataset
|
||||
from ultralytics.engine.trainer import BaseTrainer
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import DetectionModel
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
|
||||
from ultralytics.utils.patches import override_configs
|
||||
from ultralytics.utils.plotting import plot_images, plot_labels
|
||||
from ultralytics.utils.torch_utils import torch_distributed_zero_first, unwrap_model
|
||||
|
||||
|
||||
class DetectionTrainer(BaseTrainer):
|
||||
"""
|
||||
A class extending the BaseTrainer class for training based on a detection model.
|
||||
|
||||
This trainer specializes in object detection tasks, handling the specific requirements for training YOLO models
|
||||
for object detection including dataset building, data loading, preprocessing, and model configuration.
|
||||
|
||||
Attributes:
|
||||
model (DetectionModel): The YOLO detection model being trained.
|
||||
data (dict): Dictionary containing dataset information including class names and number of classes.
|
||||
loss_names (tuple): Names of the loss components used in training (box_loss, cls_loss, dfl_loss).
|
||||
|
||||
Methods:
|
||||
build_dataset: Build YOLO dataset for training or validation.
|
||||
get_dataloader: Construct and return dataloader for the specified mode.
|
||||
preprocess_batch: Preprocess a batch of images by scaling and converting to float.
|
||||
set_model_attributes: Set model attributes based on dataset information.
|
||||
get_model: Return a YOLO detection model.
|
||||
get_validator: Return a validator for model evaluation.
|
||||
label_loss_items: Return a loss dictionary with labeled training loss items.
|
||||
progress_string: Return a formatted string of training progress.
|
||||
plot_training_samples: Plot training samples with their annotations.
|
||||
plot_training_labels: Create a labeled training plot of the YOLO model.
|
||||
auto_batch: Calculate optimal batch size based on model memory requirements.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.detect import DetectionTrainer
|
||||
>>> args = dict(model="yolo11n.pt", data="coco8.yaml", epochs=3)
|
||||
>>> trainer = DetectionTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a DetectionTrainer object for YOLO object detection model training.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
||||
_callbacks (list, optional): List of callback functions to be executed during training.
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
|
||||
"""
|
||||
Build YOLO Dataset for training or validation.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for 'rect' mode.
|
||||
|
||||
Returns:
|
||||
(Dataset): YOLO dataset object configured for the specified mode.
|
||||
"""
|
||||
gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int = 16, rank: int = 0, mode: str = "train"):
|
||||
"""
|
||||
Construct and return dataloader for the specified mode.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
batch_size (int): Number of images per batch.
|
||||
rank (int): Process rank for distributed training.
|
||||
mode (str): 'train' for training dataloader, 'val' for validation dataloader.
|
||||
|
||||
Returns:
|
||||
(DataLoader): PyTorch dataloader object.
|
||||
"""
|
||||
assert mode in {"train", "val"}, f"Mode must be 'train' or 'val', not {mode}."
|
||||
with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
|
||||
dataset = self.build_dataset(dataset_path, mode, batch_size)
|
||||
shuffle = mode == "train"
|
||||
if getattr(dataset, "rect", False) and shuffle:
|
||||
LOGGER.warning("'rect=True' is incompatible with DataLoader shuffle, setting shuffle=False")
|
||||
shuffle = False
|
||||
return build_dataloader(
|
||||
dataset,
|
||||
batch=batch_size,
|
||||
workers=self.args.workers if mode == "train" else self.args.workers * 2,
|
||||
shuffle=shuffle,
|
||||
rank=rank,
|
||||
drop_last=self.args.compile and mode == "train",
|
||||
)
|
||||
|
||||
def preprocess_batch(self, batch: dict) -> dict:
|
||||
"""
|
||||
Preprocess a batch of images by scaling and converting to float.
|
||||
|
||||
Args:
|
||||
batch (dict): Dictionary containing batch data with 'img' tensor.
|
||||
|
||||
Returns:
|
||||
(dict): Preprocessed batch with normalized images.
|
||||
"""
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
||||
batch["img"] = batch["img"].float() / 255
|
||||
if self.args.multi_scale:
|
||||
imgs = batch["img"]
|
||||
sz = (
|
||||
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
|
||||
// self.stride
|
||||
* self.stride
|
||||
) # size
|
||||
sf = sz / max(imgs.shape[2:]) # scale factor
|
||||
if sf != 1:
|
||||
ns = [
|
||||
math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]
|
||||
] # new shape (stretched to gs-multiple)
|
||||
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
|
||||
batch["img"] = imgs
|
||||
return batch
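# Worked sketch of the multi-scale sizing above (hypothetical imgsz=640, stride=32):
# >>> random.randrange(320, 992) // 32 * 32  # sz is snapped to a stride multiple in [320, 960]
# >>> sf = 608 / 640  # e.g. a drawn size of 608 for a 640x640 batch
# >>> [math.ceil(x * sf / 32) * 32 for x in (640, 640)]  # new shape, stretched to stride multiples
# [608, 608]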
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set model attributes based on dataset information."""
|
||||
# nl = de_parallel(self.model).model[-1].nl # number of detection layers (to scale hyps)
|
||||
# self.args.box *= 3 / nl # scale to layers
|
||||
# self.args.cls *= self.data["nc"] / 80 * 3 / nl # scale to classes and layers
|
||||
# self.args.cls *= (self.args.imgsz / 640) ** 2 * 3 / nl # scale to image size and layers
|
||||
self.model.nc = self.data["nc"] # attach number of classes to model
|
||||
self.model.names = self.data["names"] # attach class names to model
|
||||
self.model.args = self.args # attach hyperparameters to model
|
||||
# TODO: self.model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc
|
||||
|
||||
def get_model(self, cfg: str | None = None, weights: str | None = None, verbose: bool = True):
|
||||
"""
|
||||
Return a YOLO detection model.
|
||||
|
||||
Args:
|
||||
cfg (str, optional): Path to model configuration file.
|
||||
weights (str, optional): Path to model weights.
|
||||
verbose (bool): Whether to display model information.
|
||||
|
||||
Returns:
|
||||
(DetectionModel): YOLO detection model.
|
||||
"""
|
||||
model = DetectionModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
"""Return a DetectionValidator for YOLO model validation."""
|
||||
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
||||
return yolo.detect.DetectionValidator(
|
||||
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
|
||||
def label_loss_items(self, loss_items: list[float] | None = None, prefix: str = "train"):
|
||||
"""
|
||||
Return a loss dict with labeled training loss items tensor.
|
||||
|
||||
Args:
|
||||
loss_items (list[float], optional): List of loss values.
|
||||
prefix (str): Prefix for keys in the returned dictionary.
|
||||
|
||||
Returns:
|
||||
(dict | list): Dictionary of labeled loss items if loss_items is provided, otherwise list of keys.
|
||||
"""
|
||||
keys = [f"{prefix}/{x}" for x in self.loss_names]
|
||||
if loss_items is not None:
|
||||
loss_items = [round(float(x), 5) for x in loss_items] # convert tensors to 5 decimal place floats
|
||||
return dict(zip(keys, loss_items))
|
||||
else:
|
||||
return keys
|
||||
|
||||
def progress_string(self):
|
||||
"""Return a formatted string of training progress with epoch, GPU memory, loss, instances and size."""
|
||||
return ("\n" + "%11s" * (4 + len(self.loss_names))) % (
|
||||
"Epoch",
|
||||
"GPU_mem",
|
||||
*self.loss_names,
|
||||
"Instances",
|
||||
"Size",
|
||||
)
|
||||
|
||||
def plot_training_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot training samples with their annotations.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Dictionary containing batch data.
|
||||
ni (int): Number of iterations (used to name the saved plot file).
|
||||
"""
|
||||
plot_images(
|
||||
labels=batch,
|
||||
paths=batch["im_file"],
|
||||
fname=self.save_dir / f"train_batch{ni}.jpg",
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
|
||||
def plot_training_labels(self):
|
||||
"""Create a labeled training plot of the YOLO model."""
|
||||
boxes = np.concatenate([lb["bboxes"] for lb in self.train_loader.dataset.labels], 0)
|
||||
cls = np.concatenate([lb["cls"] for lb in self.train_loader.dataset.labels], 0)
|
||||
plot_labels(boxes, cls.squeeze(), names=self.data["names"], save_dir=self.save_dir, on_plot=self.on_plot)
|
||||
|
||||
def auto_batch(self):
|
||||
"""
|
||||
Get optimal batch size by calculating memory occupation of model.
|
||||
|
||||
Returns:
|
||||
(int): Optimal batch size.
|
||||
"""
|
||||
with override_configs(self.args, overrides={"cache": False}) as self.args:
|
||||
train_dataset = self.build_dataset(self.data["train"], mode="train", batch=16)
|
||||
max_num_obj = max(len(label["cls"]) for label in train_dataset.labels) * 4 # 4 for mosaic augmentation
|
||||
del train_dataset # free memory
|
||||
return super().auto_batch(max_num_obj)
|
||||
495
ultralytics/models/yolo/detect/val.py
Normal file
@@ -0,0 +1,495 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.data import build_dataloader, build_yolo_dataset, converter
|
||||
from ultralytics.engine.validator import BaseValidator
|
||||
from ultralytics.utils import LOGGER, nms, ops
|
||||
from ultralytics.utils.checks import check_requirements
|
||||
from ultralytics.utils.metrics import ConfusionMatrix, DetMetrics, box_iou
|
||||
from ultralytics.utils.plotting import plot_images
|
||||
|
||||
|
||||
class DetectionValidator(BaseValidator):
|
||||
"""
|
||||
A class extending the BaseValidator class for validation based on a detection model.
|
||||
|
||||
This class implements validation functionality specific to object detection tasks, including metrics calculation,
|
||||
prediction processing, and visualization of results.
|
||||
|
||||
Attributes:
|
||||
is_coco (bool): Whether the dataset is COCO.
|
||||
is_lvis (bool): Whether the dataset is LVIS.
|
||||
class_map (list[int]): Mapping from model class indices to dataset class indices.
|
||||
metrics (DetMetrics): Object detection metrics calculator.
|
||||
iouv (torch.Tensor): IoU thresholds for mAP calculation.
|
||||
niou (int): Number of IoU thresholds.
|
||||
lb (list[Any]): List for storing ground truth labels for hybrid saving.
|
||||
jdict (list[dict[str, Any]]): List for storing JSON detection results.
|
||||
stats (dict[str, list[torch.Tensor]]): Dictionary for storing statistics during validation.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.detect import DetectionValidator
|
||||
>>> args = dict(model="yolo11n.pt", data="coco8.yaml")
|
||||
>>> validator = DetectionValidator(args=args)
|
||||
>>> validator()
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize detection validator with necessary variables and settings.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
|
||||
save_dir (Path, optional): Directory to save results.
|
||||
args (dict[str, Any], optional): Arguments for the validator.
|
||||
_callbacks (list[Any], optional): List of callback functions.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.is_coco = False
|
||||
self.is_lvis = False
|
||||
self.class_map = None
|
||||
self.args.task = "detect"
|
||||
self.iouv = torch.linspace(0.5, 0.95, 10) # IoU vector for mAP@0.5:0.95
|
||||
self.niou = self.iouv.numel()
|
||||
self.metrics = DetMetrics()
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Preprocess batch of images for YOLO validation.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Preprocessed batch.
|
||||
"""
|
||||
for k, v in batch.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
||||
batch["img"] = (batch["img"].half() if self.args.half else batch["img"].float()) / 255
|
||||
return batch
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO detection validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
val = self.data.get(self.args.split, "") # validation path
|
||||
self.is_coco = (
|
||||
isinstance(val, str)
|
||||
and "coco" in val
|
||||
and (val.endswith(f"{os.sep}val2017.txt") or val.endswith(f"{os.sep}test-dev2017.txt"))
|
||||
) # is COCO
|
||||
self.is_lvis = isinstance(val, str) and "lvis" in val and not self.is_coco # is LVIS
|
||||
self.class_map = converter.coco80_to_coco91_class() if self.is_coco else list(range(1, len(model.names) + 1))
|
||||
self.args.save_json |= self.args.val and (self.is_coco or self.is_lvis) and not self.training # run final val
|
||||
self.names = model.names
|
||||
self.nc = len(model.names)
|
||||
self.end2end = getattr(model, "end2end", False)
|
||||
self.seen = 0
|
||||
self.jdict = []
|
||||
self.metrics.names = model.names
|
||||
self.confusion_matrix = ConfusionMatrix(names=model.names, save_matches=self.args.plots and self.args.visualize)
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return a formatted string summarizing class metrics of YOLO model."""
|
||||
return ("%22s" + "%11s" * 6) % ("Class", "Images", "Instances", "Box(P", "R", "mAP50", "mAP50-95)")
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Apply Non-maximum suppression to prediction outputs.
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
||||
Returns:
|
||||
(list[dict[str, torch.Tensor]]): Processed predictions after NMS, where each dict contains
|
||||
'bboxes', 'conf', 'cls', and 'extra' tensors.
|
||||
"""
|
||||
outputs = nms.non_max_suppression(
|
||||
preds,
|
||||
self.args.conf,
|
||||
self.args.iou,
|
||||
nc=0 if self.args.task == "detect" else self.nc,
|
||||
multi_label=True,
|
||||
agnostic=self.args.single_cls or self.args.agnostic_nms,
|
||||
max_det=self.args.max_det,
|
||||
end2end=self.end2end,
|
||||
rotated=self.args.task == "obb",
|
||||
)
|
||||
return [{"bboxes": x[:, :4], "conf": x[:, 4], "cls": x[:, 5], "extra": x[:, 6:]} for x in outputs]
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch of images and annotations for validation.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
batch (dict[str, Any]): Batch data containing images and annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch with processed annotations.
|
||||
"""
|
||||
idx = batch["batch_idx"] == si
|
||||
cls = batch["cls"][idx].squeeze(-1)
|
||||
bbox = batch["bboxes"][idx]
|
||||
ori_shape = batch["ori_shape"][si]
|
||||
imgsz = batch["img"].shape[2:]
|
||||
ratio_pad = batch["ratio_pad"][si]
|
||||
if cls.shape[0]:
|
||||
bbox = ops.xywh2xyxy(bbox) * torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]] # target boxes
|
||||
return {
|
||||
"cls": cls,
|
||||
"bboxes": bbox,
|
||||
"ori_shape": ori_shape,
|
||||
"imgsz": imgsz,
|
||||
"ratio_pad": ratio_pad,
|
||||
"im_file": batch["im_file"][si],
|
||||
}
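# Sketch of the box denormalization above (hypothetical normalized xywh box, square imgsz of 640):
# >>> bbox = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # normalized cx, cy, w, h
# >>> ops.xywh2xyxy(bbox) * torch.tensor([640, 640, 640, 640])
# tensor([[256., 192., 384., 448.]])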
|
||||
|
||||
def _prepare_pred(self, pred: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Prepare predictions for evaluation against ground truth.
|
||||
|
||||
Args:
|
||||
pred (dict[str, torch.Tensor]): Post-processed predictions from the model.
|
||||
|
||||
Returns:
|
||||
(dict[str, torch.Tensor]): Prepared predictions in native space.
|
||||
"""
|
||||
if self.args.single_cls:
|
||||
pred["cls"] *= 0
|
||||
return pred
|
||||
|
||||
def update_metrics(self, preds: list[dict[str, torch.Tensor]], batch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Update metrics with new predictions and ground truth.
|
||||
|
||||
Args:
|
||||
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
|
||||
batch (dict[str, Any]): Batch data containing ground truth.
|
||||
"""
|
||||
for si, pred in enumerate(preds):
|
||||
self.seen += 1
|
||||
pbatch = self._prepare_batch(si, batch)
|
||||
predn = self._prepare_pred(pred)
|
||||
|
||||
cls = pbatch["cls"].cpu().numpy()
|
||||
no_pred = predn["cls"].shape[0] == 0
|
||||
self.metrics.update_stats(
|
||||
{
|
||||
**self._process_batch(predn, pbatch),
|
||||
"target_cls": cls,
|
||||
"target_img": np.unique(cls),
|
||||
"conf": np.zeros(0) if no_pred else predn["conf"].cpu().numpy(),
|
||||
"pred_cls": np.zeros(0) if no_pred else predn["cls"].cpu().numpy(),
|
||||
}
|
||||
)
|
||||
# Evaluate
|
||||
if self.args.plots:
|
||||
self.confusion_matrix.process_batch(predn, pbatch, conf=self.args.conf)
|
||||
if self.args.visualize:
|
||||
self.confusion_matrix.plot_matches(batch["img"][si], pbatch["im_file"], self.save_dir)
|
||||
|
||||
if no_pred:
|
||||
continue
|
||||
|
||||
# Save
|
||||
if self.args.save_json or self.args.save_txt:
|
||||
predn_scaled = self.scale_preds(predn, pbatch)
|
||||
if self.args.save_json:
|
||||
self.pred_to_json(predn_scaled, pbatch)
|
||||
if self.args.save_txt:
|
||||
self.save_one_txt(
|
||||
predn_scaled,
|
||||
self.args.save_conf,
|
||||
pbatch["ori_shape"],
|
||||
self.save_dir / "labels" / f"{Path(pbatch['im_file']).stem}.txt",
|
||||
)
|
||||
|
||||
def finalize_metrics(self) -> None:
|
||||
"""Set final values for metrics speed and confusion matrix."""
|
||||
if self.args.plots:
|
||||
for normalize in True, False:
|
||||
self.confusion_matrix.plot(save_dir=self.save_dir, normalize=normalize, on_plot=self.on_plot)
|
||||
self.metrics.speed = self.speed
|
||||
self.metrics.confusion_matrix = self.confusion_matrix
|
||||
self.metrics.save_dir = self.save_dir
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate and return metrics statistics.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Dictionary containing metrics results.
|
||||
"""
|
||||
self.metrics.process(save_dir=self.save_dir, plot=self.args.plots, on_plot=self.on_plot)
|
||||
self.metrics.clear_stats()
|
||||
return self.metrics.results_dict
|
||||
|
||||
def print_results(self) -> None:
|
||||
"""Print training/validation set metrics per class."""
|
||||
pf = "%22s" + "%11i" * 2 + "%11.3g" * len(self.metrics.keys) # print format
|
||||
LOGGER.info(pf % ("all", self.seen, self.metrics.nt_per_class.sum(), *self.metrics.mean_results()))
|
||||
if self.metrics.nt_per_class.sum() == 0:
|
||||
LOGGER.warning(f"no labels found in {self.args.task} set, can not compute metrics without labels")
|
||||
|
||||
# Print results per class
|
||||
if self.args.verbose and not self.training and self.nc > 1 and len(self.metrics.stats):
|
||||
for i, c in enumerate(self.metrics.ap_class_index):
|
||||
LOGGER.info(
|
||||
pf
|
||||
% (
|
||||
self.names[c],
|
||||
self.metrics.nt_per_image[c],
|
||||
self.metrics.nt_per_class[c],
|
||||
*self.metrics.class_result(i),
|
||||
)
|
||||
)
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with 'bboxes' and 'cls' keys.
|
||||
batch (dict[str, Any]): Batch dictionary containing ground truth data with 'bboxes' and 'cls' keys.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing 'tp' key with correct prediction matrix of shape (N, 10) for 10 IoU levels.
|
||||
"""
|
||||
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
|
||||
iou = box_iou(batch["bboxes"], preds["bboxes"])
|
||||
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
||||
|
||||
def build_dataset(self, img_path: str, mode: str = "val", batch: int | None = None) -> torch.utils.data.Dataset:
|
||||
"""
|
||||
Build YOLO Dataset.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the folder containing images.
|
||||
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
|
||||
batch (int, optional): Size of batches, this is for `rect`.
|
||||
|
||||
Returns:
|
||||
(Dataset): YOLO dataset.
|
||||
"""
|
||||
return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, stride=self.stride)
|
||||
|
||||
def get_dataloader(self, dataset_path: str, batch_size: int) -> torch.utils.data.DataLoader:
|
||||
"""
|
||||
Construct and return dataloader.
|
||||
|
||||
Args:
|
||||
dataset_path (str): Path to the dataset.
|
||||
batch_size (int): Size of each batch.
|
||||
|
||||
Returns:
|
||||
(torch.utils.data.DataLoader): Dataloader for validation.
|
||||
"""
|
||||
dataset = self.build_dataset(dataset_path, batch=batch_size, mode="val")
|
||||
return build_dataloader(
|
||||
dataset, batch_size, self.args.workers, shuffle=False, rank=-1, drop_last=self.args.compile
|
||||
)
|
||||
|
||||
def plot_val_samples(self, batch: dict[str, Any], ni: int) -> None:
|
||||
"""
|
||||
Plot validation image samples.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
ni (int): Batch index.
|
||||
"""
|
||||
plot_images(
|
||||
labels=batch,
|
||||
paths=batch["im_file"],
|
||||
fname=self.save_dir / f"val_batch{ni}_labels.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
)
|
||||
|
||||
def plot_predictions(
|
||||
self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int, max_det: int | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Plot predicted bounding boxes on input images and save the result.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch containing images and annotations.
|
||||
preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
|
||||
ni (int): Batch index.
|
||||
max_det (int, optional): Maximum number of detections to plot.
|
||||
"""
|
||||
# TODO: optimize this
|
||||
for i, pred in enumerate(preds):
|
||||
pred["batch_idx"] = torch.ones_like(pred["conf"]) * i # add batch index to predictions
|
||||
keys = preds[0].keys()
|
||||
max_det = max_det or self.args.max_det
|
||||
batched_preds = {k: torch.cat([x[k][:max_det] for x in preds], dim=0) for k in keys}
|
||||
# TODO: fix this
|
||||
batched_preds["bboxes"][:, :4] = ops.xyxy2xywh(batched_preds["bboxes"][:, :4]) # convert to xywh format
|
||||
plot_images(
|
||||
images=batch["img"],
|
||||
labels=batched_preds,
|
||||
paths=batch["im_file"],
|
||||
fname=self.save_dir / f"val_batch{ni}_pred.jpg",
|
||||
names=self.names,
|
||||
on_plot=self.on_plot,
|
||||
) # pred
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO detections to a txt file in normalized coordinates in a specific format.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', and 'cls'.
|
||||
save_conf (bool): Whether to save confidence scores.
|
||||
shape (tuple[int, int]): Shape of the original image (height, width).
|
||||
file (Path): File path to save the detections.
|
||||
"""
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
Results(
|
||||
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
||||
path=None,
|
||||
names=self.names,
|
||||
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
||||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Serialize YOLO predictions to COCO json format.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Predictions dictionary containing 'bboxes', 'conf', and 'cls' keys
|
||||
with bounding box coordinates, confidence scores, and class predictions.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Examples:
|
||||
>>> result = {
|
||||
... "image_id": 42,
|
||||
... "file_name": "42.jpg",
|
||||
... "category_id": 18,
|
||||
... "bbox": [258.15, 41.29, 348.26, 243.78],
|
||||
... "score": 0.236,
|
||||
... }
|
||||
"""
|
||||
path = Path(pbatch["im_file"])
|
||||
stem = path.stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
box = ops.xyxy2xywh(predn["bboxes"]) # xywh
|
||||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner
|
||||
for b, s, c in zip(box.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
|
||||
self.jdict.append(
|
||||
{
|
||||
"image_id": image_id,
|
||||
"file_name": path.name,
|
||||
"category_id": self.class_map[int(c)],
|
||||
"bbox": [round(x, 3) for x in b],
|
||||
"score": round(s, 5),
|
||||
}
|
||||
)
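# Sketch of the COCO box conversion above (hypothetical xyxy box):
# >>> box = ops.xyxy2xywh(torch.tensor([[100.0, 50.0, 300.0, 250.0]]))  # -> cx, cy, w, h
# >>> box[:, :2] -= box[:, 2:] / 2  # COCO boxes use the top-left corner
# >>> box
# tensor([[100., 50., 200., 200.]])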
|
||||
|
||||
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
"""Scales predictions to the original image size."""
|
||||
return {
|
||||
**predn,
|
||||
"bboxes": ops.scale_boxes(
|
||||
pbatch["imgsz"],
|
||||
predn["bboxes"].clone(),
|
||||
pbatch["ori_shape"],
|
||||
ratio_pad=pbatch["ratio_pad"],
|
||||
),
|
||||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate YOLO output in JSON format and return performance statistics.
|
||||
|
||||
Args:
|
||||
stats (dict[str, Any]): Current statistics dictionary.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Updated statistics dictionary with COCO/LVIS evaluation results.
|
||||
"""
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
anno_json = (
|
||||
self.data["path"]
|
||||
/ "annotations"
|
||||
/ ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
|
||||
) # annotations
|
||||
return self.coco_evaluate(stats, pred_json, anno_json)
|
||||
|
||||
def coco_evaluate(
|
||||
self,
|
||||
stats: dict[str, Any],
|
||||
pred_json: str,
|
||||
anno_json: str,
|
||||
iou_types: str | list[str] = "bbox",
|
||||
suffix: str | list[str] = "Box",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate COCO/LVIS metrics using faster-coco-eval library.
|
||||
|
||||
Performs evaluation using the faster-coco-eval library to compute mAP metrics
|
||||
for object detection. Updates the provided stats dictionary with computed metrics
|
||||
including mAP50, mAP50-95, and LVIS-specific metrics if applicable.
|
||||
|
||||
Args:
|
||||
stats (dict[str, Any]): Dictionary to store computed metrics and statistics.
|
||||
pred_json (str | Path): Path to JSON file containing predictions in COCO format.
|
||||
anno_json (str | Path): Path to JSON file containing ground truth annotations in COCO format.
|
||||
iou_types (str | list[str]): IoU type(s) for evaluation. Can be single string or list of strings.
|
||||
Common values include "bbox", "segm", "keypoints". Defaults to "bbox".
|
||||
suffix (str | list[str]): Suffix to append to metric names in stats dictionary. Should correspond
|
||||
to iou_types if multiple types provided. Defaults to "Box".
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Updated stats dictionary containing the computed COCO/LVIS evaluation metrics.
|
||||
"""
|
||||
if self.args.save_json and (self.is_coco or self.is_lvis) and len(self.jdict):
|
||||
LOGGER.info(f"\nEvaluating faster-coco-eval mAP using {pred_json} and {anno_json}...")
|
||||
try:
|
||||
for x in pred_json, anno_json:
|
||||
assert x.is_file(), f"{x} file not found"
|
||||
iou_types = [iou_types] if isinstance(iou_types, str) else iou_types
|
||||
suffix = [suffix] if isinstance(suffix, str) else suffix
|
||||
check_requirements("faster-coco-eval>=1.6.7")
|
||||
from faster_coco_eval import COCO, COCOeval_faster
|
||||
|
||||
anno = COCO(anno_json)
|
||||
pred = anno.loadRes(pred_json)
|
||||
for i, iou_type in enumerate(iou_types):
|
||||
val = COCOeval_faster(
|
||||
anno, pred, iouType=iou_type, lvis_style=self.is_lvis, print_function=LOGGER.info
|
||||
)
|
||||
val.params.imgIds = [int(Path(x).stem) for x in self.dataloader.dataset.im_files] # images to eval
|
||||
val.evaluate()
|
||||
val.accumulate()
|
||||
val.summarize()
|
||||
|
||||
# update mAP50-95 and mAP50
|
||||
stats[f"metrics/mAP50({suffix[i][0]})"] = val.stats_as_dict["AP_50"]
|
||||
stats[f"metrics/mAP50-95({suffix[i][0]})"] = val.stats_as_dict["AP_all"]
|
||||
|
||||
if self.is_lvis:
|
||||
stats[f"metrics/APr({suffix[i][0]})"] = val.stats_as_dict["APr"]
|
||||
stats[f"metrics/APc({suffix[i][0]})"] = val.stats_as_dict["APc"]
|
||||
stats[f"metrics/APf({suffix[i][0]})"] = val.stats_as_dict["APf"]
|
||||
|
||||
if self.is_lvis:
|
||||
stats["fitness"] = stats["metrics/mAP50-95(B)"] # always use box mAP50-95 for fitness
|
||||
except Exception as e:
|
||||
LOGGER.warning(f"faster-coco-eval unable to run: {e}")
|
||||
return stats
|
||||
447
ultralytics/models/yolo/model.py
Normal file
@@ -0,0 +1,447 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.data.build import load_inference_source
|
||||
from ultralytics.engine.model import Model
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import (
|
||||
ClassificationModel,
|
||||
DetectionModel,
|
||||
OBBModel,
|
||||
PoseModel,
|
||||
SegmentationModel,
|
||||
WorldModel,
|
||||
YOLOEModel,
|
||||
YOLOESegModel,
|
||||
)
|
||||
from ultralytics.utils import ROOT, YAML
|
||||
|
||||
|
||||
class YOLO(Model):
|
||||
"""
|
||||
YOLO (You Only Look Once) object detection model.
|
||||
|
||||
This class provides a unified interface for YOLO models, automatically switching to specialized model types
|
||||
(YOLOWorld or YOLOE) based on the model filename. It supports various computer vision tasks including object
|
||||
detection, segmentation, classification, pose estimation, and oriented bounding box detection.
|
||||
|
||||
Attributes:
|
||||
model: The loaded YOLO model instance.
|
||||
task: The task type (detect, segment, classify, pose, obb).
|
||||
overrides: Configuration overrides for the model.
|
||||
|
||||
Methods:
|
||||
__init__: Initialize a YOLO model with automatic type detection.
|
||||
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
|
||||
|
||||
Examples:
|
||||
Load a pretrained YOLO11n detection model
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
|
||||
Load a pretrained YOLO11n segmentation model
|
||||
>>> model = YOLO("yolo11n-seg.pt")
|
||||
|
||||
Initialize from a YAML configuration
|
||||
>>> model = YOLO("yolo11n.yaml")
|
||||
"""
|
||||
|
||||
def __init__(self, model: str | Path = "yolo11n.pt", task: str | None = None, verbose: bool = False):
|
||||
"""
|
||||
Initialize a YOLO model.
|
||||
|
||||
This constructor initializes a YOLO model, automatically switching to specialized model types
|
||||
(YOLOWorld or YOLOE) based on the model filename.
|
||||
|
||||
Args:
|
||||
model (str | Path): Model name or path to model file, i.e. 'yolo11n.pt', 'yolo11n.yaml'.
|
||||
task (str, optional): YOLO task specification, i.e. 'detect', 'segment', 'classify', 'pose', 'obb'.
|
||||
Defaults to auto-detection based on model.
|
||||
verbose (bool): Display model info on load.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics import YOLO
|
||||
>>> model = YOLO("yolo11n.pt")  # load a pretrained YOLO11n detection model
|
||||
>>> model = YOLO("yolo11n-seg.pt") # load a pretrained YOLO11n segmentation model
|
||||
"""
|
||||
path = Path(model if isinstance(model, (str, Path)) else "")
|
||||
if "-world" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOWorld PyTorch model
|
||||
new_instance = YOLOWorld(path, verbose=verbose)
|
||||
self.__class__ = type(new_instance)
|
||||
self.__dict__ = new_instance.__dict__
|
||||
elif "yoloe" in path.stem and path.suffix in {".pt", ".yaml", ".yml"}: # if YOLOE PyTorch model
|
||||
new_instance = YOLOE(path, task=task, verbose=verbose)
|
||||
self.__class__ = type(new_instance)
|
||||
self.__dict__ = new_instance.__dict__
|
||||
else:
|
||||
# Continue with default YOLO initialization
|
||||
super().__init__(model=model, task=task, verbose=verbose)
|
||||
if hasattr(self.model, "model") and "RTDETR" in self.model.model[-1]._get_name(): # if RTDETR head
|
||||
from ultralytics import RTDETR
|
||||
|
||||
new_instance = RTDETR(self)
|
||||
self.__class__ = type(new_instance)
|
||||
self.__dict__ = new_instance.__dict__
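# Usage sketch of the filename-based dispatch above (weights names are illustrative and may be downloaded):
# >>> type(YOLO("yolov8s-world.pt")).__name__  # '-world' in the stem -> YOLOWorld
# 'YOLOWorld'
# >>> type(YOLO("yoloe-11s-seg.pt")).__name__  # 'yoloe' in the stem -> YOLOE
# 'YOLOE'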
|
||||
|
||||
@property
|
||||
def task_map(self) -> dict[str, dict[str, Any]]:
|
||||
"""Map head to model, trainer, validator, and predictor classes."""
|
||||
return {
|
||||
"classify": {
|
||||
"model": ClassificationModel,
|
||||
"trainer": yolo.classify.ClassificationTrainer,
|
||||
"validator": yolo.classify.ClassificationValidator,
|
||||
"predictor": yolo.classify.ClassificationPredictor,
|
||||
},
|
||||
"detect": {
|
||||
"model": DetectionModel,
|
||||
"trainer": yolo.detect.DetectionTrainer,
|
||||
"validator": yolo.detect.DetectionValidator,
|
||||
"predictor": yolo.detect.DetectionPredictor,
|
||||
},
|
||||
"segment": {
|
||||
"model": SegmentationModel,
|
||||
"trainer": yolo.segment.SegmentationTrainer,
|
||||
"validator": yolo.segment.SegmentationValidator,
|
||||
"predictor": yolo.segment.SegmentationPredictor,
|
||||
},
|
||||
"pose": {
|
||||
"model": PoseModel,
|
||||
"trainer": yolo.pose.PoseTrainer,
|
||||
"validator": yolo.pose.PoseValidator,
|
||||
"predictor": yolo.pose.PosePredictor,
|
||||
},
|
||||
"obb": {
|
||||
"model": OBBModel,
|
||||
"trainer": yolo.obb.OBBTrainer,
|
||||
"validator": yolo.obb.OBBValidator,
|
||||
"predictor": yolo.obb.OBBPredictor,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class YOLOWorld(Model):
|
||||
"""
|
||||
YOLO-World object detection model.
|
||||
|
||||
YOLO-World is an open-vocabulary object detection model that can detect objects based on text descriptions
|
||||
without requiring training on specific classes. It extends the YOLO architecture to support real-time
|
||||
open-vocabulary detection.
|
||||
|
||||
Attributes:
|
||||
model: The loaded YOLO-World model instance.
|
||||
task: Always set to 'detect' for object detection.
|
||||
overrides: Configuration overrides for the model.
|
||||
|
||||
Methods:
|
||||
__init__: Initialize YOLOv8-World model with a pre-trained model file.
|
||||
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
|
||||
set_classes: Set the model's class names for detection.
|
||||
|
||||
Examples:
|
||||
Load a YOLOv8-World model
|
||||
>>> model = YOLOWorld("yolov8s-world.pt")
|
||||
|
||||
Set custom classes for detection
|
||||
>>> model.set_classes(["person", "car", "bicycle"])
|
||||
"""
|
||||
|
||||
def __init__(self, model: str | Path = "yolov8s-world.pt", verbose: bool = False) -> None:
|
||||
"""
|
||||
Initialize YOLOv8-World model with a pre-trained model file.
|
||||
|
||||
Loads a YOLOv8-World model for object detection. If no custom class names are provided, it assigns default
|
||||
COCO class names.
|
||||
|
||||
Args:
|
||||
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
|
||||
verbose (bool): If True, prints additional information during initialization.
|
||||
"""
|
||||
super().__init__(model=model, task="detect", verbose=verbose)
|
||||
|
||||
# Assign default COCO class names when there are no custom names
|
||||
if not hasattr(self.model, "names"):
|
||||
self.model.names = YAML.load(ROOT / "cfg/datasets/coco8.yaml").get("names")
|
||||
|
||||
@property
|
||||
def task_map(self) -> dict[str, dict[str, Any]]:
|
||||
"""Map head to model, validator, and predictor classes."""
|
||||
return {
|
||||
"detect": {
|
||||
"model": WorldModel,
|
||||
"validator": yolo.detect.DetectionValidator,
|
||||
"predictor": yolo.detect.DetectionPredictor,
|
||||
"trainer": yolo.world.WorldTrainer,
|
||||
}
|
||||
}
|
||||
|
||||
def set_classes(self, classes: list[str]) -> None:
|
||||
"""
|
||||
Set the model's class names for detection.
|
||||
|
||||
Args:
|
||||
classes (list[str]): A list of categories i.e. ["person"].
|
||||
"""
|
||||
self.model.set_classes(classes)
|
||||
# Remove background if it's given
|
||||
background = " "
|
||||
if background in classes:
|
||||
classes.remove(background)
|
||||
self.model.names = classes
|
||||
|
||||
# Reset method class names
|
||||
if self.predictor:
|
||||
self.predictor.model.names = classes
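# Sketch of the background handling above (class names are illustrative):
# >>> model = YOLOWorld("yolov8s-world.pt")
# >>> model.set_classes([" ", "person", "car"])  # a lone-space "background" entry is dropped
# >>> model.model.names
# ['person', 'car']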
|
||||
|
||||
|
||||
class YOLOE(Model):
|
||||
"""
|
||||
YOLOE object detection and segmentation model.
|
||||
|
||||
YOLOE is an enhanced YOLO model that supports both object detection and instance segmentation tasks with
|
||||
improved performance and additional features like visual and text positional embeddings.
|
||||
|
||||
Attributes:
|
||||
model: The loaded YOLOE model instance.
|
||||
task: The task type (detect or segment).
|
||||
overrides: Configuration overrides for the model.
|
||||
|
||||
Methods:
|
||||
__init__: Initialize YOLOE model with a pre-trained model file.
|
||||
task_map: Map tasks to their corresponding model, trainer, validator, and predictor classes.
|
||||
get_text_pe: Get text positional embeddings for the given texts.
|
||||
get_visual_pe: Get visual positional embeddings for the given image and visual features.
|
||||
set_vocab: Set vocabulary and class names for the YOLOE model.
|
||||
get_vocab: Get vocabulary for the given class names.
|
||||
set_classes: Set the model's class names and embeddings for detection.
|
||||
val: Validate the model using text or visual prompts.
|
||||
predict: Run prediction on images, videos, directories, streams, etc.
|
||||
|
||||
Examples:
|
||||
Load a YOLOE detection model
|
||||
>>> model = YOLOE("yoloe-11s-seg.pt")
|
||||
|
||||
Set vocabulary and class names
|
||||
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
|
||||
|
||||
Predict with visual prompts
|
||||
>>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
|
||||
>>> results = model.predict("image.jpg", visual_prompts=prompts)
|
||||
"""
|
||||
|
||||
def __init__(self, model: str | Path = "yoloe-11s-seg.pt", task: str | None = None, verbose: bool = False) -> None:
|
||||
"""
|
||||
Initialize YOLOE model with a pre-trained model file.
|
||||
|
||||
Args:
|
||||
model (str | Path): Path to the pre-trained model file. Supports *.pt and *.yaml formats.
|
||||
task (str, optional): Task type for the model. Auto-detected if None.
|
||||
verbose (bool): If True, prints additional information during initialization.
|
||||
"""
|
||||
super().__init__(model=model, task=task, verbose=verbose)
|
||||
|
||||
@property
|
||||
def task_map(self) -> dict[str, dict[str, Any]]:
|
||||
"""Map head to model, validator, and predictor classes."""
|
||||
return {
|
||||
"detect": {
|
||||
"model": YOLOEModel,
|
||||
"validator": yolo.yoloe.YOLOEDetectValidator,
|
||||
"predictor": yolo.detect.DetectionPredictor,
|
||||
"trainer": yolo.yoloe.YOLOETrainer,
|
||||
},
|
||||
"segment": {
|
||||
"model": YOLOESegModel,
|
||||
"validator": yolo.yoloe.YOLOESegValidator,
|
||||
"predictor": yolo.segment.SegmentationPredictor,
|
||||
"trainer": yolo.yoloe.YOLOESegTrainer,
|
||||
},
|
||||
}
|
||||
|
||||
def get_text_pe(self, texts):
|
||||
"""Get text positional embeddings for the given texts."""
|
||||
assert isinstance(self.model, YOLOEModel)
|
||||
return self.model.get_text_pe(texts)
|
||||
|
||||
def get_visual_pe(self, img, visual):
|
||||
"""
|
||||
Get visual positional embeddings for the given image and visual features.
|
||||
|
||||
This method extracts positional embeddings from visual features based on the input image. It requires
|
||||
that the model is an instance of YOLOEModel.
|
||||
|
||||
Args:
|
||||
img (torch.Tensor): Input image tensor.
|
||||
visual (torch.Tensor): Visual features extracted from the image.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Visual positional embeddings.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLOE("yoloe-11s-seg.pt")
|
||||
>>> img = torch.rand(1, 3, 640, 640)
|
||||
>>> visual_features = torch.rand(1, 1, 80, 80)
|
||||
>>> pe = model.get_visual_pe(img, visual_features)
|
||||
"""
|
||||
assert isinstance(self.model, YOLOEModel)
|
||||
return self.model.get_visual_pe(img, visual)
|
||||
|
||||
def set_vocab(self, vocab: list[str], names: list[str]) -> None:
|
||||
"""
|
||||
Set vocabulary and class names for the YOLOE model.
|
||||
|
||||
This method configures the vocabulary and class names used by the model for text processing and
|
||||
classification tasks. The model must be an instance of YOLOEModel.
|
||||
|
||||
Args:
|
||||
vocab (list[str]): Vocabulary list containing tokens or words used by the model for text processing.
|
||||
names (list[str]): List of class names that the model can detect or classify.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the model is not an instance of YOLOEModel.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLOE("yoloe-11s-seg.pt")
|
||||
>>> model.set_vocab(["person", "car", "dog"], ["person", "car", "dog"])
|
||||
"""
|
||||
assert isinstance(self.model, YOLOEModel)
|
||||
self.model.set_vocab(vocab, names=names)
|
||||
|
||||
def get_vocab(self, names):
|
||||
"""Get vocabulary for the given class names."""
|
||||
assert isinstance(self.model, YOLOEModel)
|
||||
return self.model.get_vocab(names)
|
||||
|
||||
def set_classes(self, classes: list[str], embeddings: torch.Tensor | None = None) -> None:
|
||||
"""
|
||||
Set the model's class names and embeddings for detection.
|
||||
|
||||
Args:
|
||||
classes (list[str]): A list of category names, e.g. ["person"].
|
||||
embeddings (torch.Tensor): Embeddings corresponding to the classes.
|
||||
"""
|
||||
assert isinstance(self.model, YOLOEModel)
|
||||
if embeddings is None:
|
||||
embeddings = self.get_text_pe(classes) # generate text embeddings if not provided
|
||||
self.model.set_classes(classes, embeddings)
|
||||
# Verify no background class is present
|
||||
assert " " not in classes
|
||||
self.model.names = classes
|
||||
|
||||
# Reset method class names
|
||||
if self.predictor:
|
||||
self.predictor.model.names = classes
|
||||
|
||||
def val(
|
||||
self,
|
||||
validator=None,
|
||||
load_vp: bool = False,
|
||||
refer_data: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Validate the model using text or visual prompts.
|
||||
|
||||
Args:
|
||||
validator (callable, optional): A callable validator function. If None, a default validator is loaded.
|
||||
load_vp (bool): Whether to load visual prompts. If False, text prompts are used.
|
||||
refer_data (str, optional): Path to the reference data for visual prompts.
|
||||
**kwargs (Any): Additional keyword arguments to override default settings.
|
||||
|
||||
Returns:
|
||||
(dict): Validation statistics containing metrics computed during validation.
|
||||
"""
|
||||
custom = {"rect": not load_vp} # method defaults
|
||||
args = {**self.overrides, **custom, **kwargs, "mode": "val"} # highest priority args on the right
|
||||
|
||||
validator = (validator or self._smart_load("validator"))(args=args, _callbacks=self.callbacks)
|
||||
validator(model=self.model, load_vp=load_vp, refer_data=refer_data)
|
||||
self.metrics = validator.metrics
|
||||
return validator.metrics
|
||||
|
||||
def predict(
|
||||
self,
|
||||
source=None,
|
||||
stream: bool = False,
|
||||
visual_prompts: dict[str, list] = {},
|
||||
refer_image=None,
|
||||
predictor=yolo.yoloe.YOLOEVPDetectPredictor,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Run prediction on images, videos, directories, streams, etc.
|
||||
|
||||
Args:
|
||||
source (str | int | PIL.Image | np.ndarray, optional): Source for prediction. Accepts image paths,
|
||||
directory paths, URL/YouTube streams, PIL images, numpy arrays, or webcam indices.
|
||||
stream (bool): Whether to stream the prediction results. If True, results are yielded as a
|
||||
generator as they are computed.
|
||||
visual_prompts (dict[str, list]): Dictionary containing visual prompts for the model. Must include
|
||||
'bboxes' and 'cls' keys when non-empty.
|
||||
refer_image (str | PIL.Image | np.ndarray, optional): Reference image for visual prompts.
|
||||
predictor (callable, optional): Custom predictor function. If None, a predictor is automatically
|
||||
loaded based on the task.
|
||||
**kwargs (Any): Additional keyword arguments passed to the predictor.
|
||||
|
||||
Returns:
|
||||
(list | generator): List of Results objects or generator of Results objects if stream=True.
|
||||
|
||||
Examples:
|
||||
>>> model = YOLOE("yoloe-11s-seg.pt")
|
||||
>>> results = model.predict("path/to/image.jpg")
|
||||
>>> # With visual prompts
|
||||
>>> prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
|
||||
>>> results = model.predict("path/to/image.jpg", visual_prompts=prompts)
|
||||
"""
|
||||
if len(visual_prompts):
|
||||
assert "bboxes" in visual_prompts and "cls" in visual_prompts, (
|
||||
f"Expected 'bboxes' and 'cls' in visual prompts, but got {visual_prompts.keys()}"
|
||||
)
|
||||
assert len(visual_prompts["bboxes"]) == len(visual_prompts["cls"]), (
|
||||
f"Expected equal number of bounding boxes and classes, but got {len(visual_prompts['bboxes'])} and "
|
||||
f"{len(visual_prompts['cls'])} respectively"
|
||||
)
|
||||
if type(self.predictor) is not predictor:
|
||||
self.predictor = predictor(
|
||||
overrides={
|
||||
"task": self.model.task,
|
||||
"mode": "predict",
|
||||
"save": False,
|
||||
"verbose": refer_image is None,
|
||||
"batch": 1,
|
||||
"device": kwargs.get("device", None),
|
||||
"half": kwargs.get("half", False),
|
||||
"imgsz": kwargs.get("imgsz", self.overrides["imgsz"]),
|
||||
},
|
||||
_callbacks=self.callbacks,
|
||||
)
|
||||
|
||||
num_cls = (
|
||||
max(len(set(c)) for c in visual_prompts["cls"])
|
||||
if isinstance(source, list) and refer_image is None # means multiple images
|
||||
else len(set(visual_prompts["cls"]))
|
||||
)
|
||||
self.model.model[-1].nc = num_cls
|
||||
self.model.names = [f"object{i}" for i in range(num_cls)]
|
||||
self.predictor.set_prompts(visual_prompts.copy())
|
||||
self.predictor.setup_model(model=self.model)
|
||||
|
||||
if refer_image is None and source is not None:
|
||||
dataset = load_inference_source(source)
|
||||
if dataset.mode in {"video", "stream"}:
|
||||
# NOTE: set the first frame as refer image for videos/streams inference
|
||||
refer_image = next(iter(dataset))[1][0]
|
||||
if refer_image is not None:
|
||||
vpe = self.predictor.get_vpe(refer_image)
|
||||
self.model.set_classes(self.model.names, vpe)
|
||||
self.task = "segment" if isinstance(self.predictor, yolo.segment.SegmentationPredictor) else "detect"
|
||||
self.predictor = None # reset predictor
|
||||
elif isinstance(self.predictor, yolo.yoloe.YOLOEVPDetectPredictor):
|
||||
self.predictor = None # reset predictor if no visual prompts
|
||||
|
||||
return super().predict(source, stream, **kwargs)
|
||||
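Before moving on to the OBB modules, here is a compact sketch of the visual-prompt path implemented in YOLOE.predict above; the prompt keys and model name come from the docstrings, and omitting visual_prompts falls back to the ordinary text-prompt pipeline.

from ultralytics import YOLOE  # assumes the top-level package re-exports YOLOE, as in the docstrings

model = YOLOE("yoloe-11s-seg.pt")

# 'bboxes' and 'cls' are required and must have matching lengths (asserted inside predict()).
prompts = {"bboxes": [[10, 20, 100, 200]], "cls": ["person"]}
results = model.predict("ultralytics/assets/bus.jpg", visual_prompts=prompts)

# Without visual prompts, the call is routed through the standard predictor instead.
plain_results = model.predict("ultralytics/assets/bus.jpg")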
7
ultralytics/models/yolo/obb/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from .predict import OBBPredictor
|
||||
from .train import OBBTrainer
|
||||
from .val import OBBValidator
|
||||
|
||||
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
|
||||
BIN
ultralytics/models/yolo/obb/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
65
ultralytics/models/yolo/obb/predict.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
|
||||
|
||||
class OBBPredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
|
||||
bounding boxes.
|
||||
|
||||
Attributes:
|
||||
args (namespace): Configuration arguments for the predictor.
|
||||
model (torch.nn.Module): The loaded YOLO OBB model.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
||||
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
||||
>>> predictor = OBBPredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize OBBPredictor with optional model and data configuration overrides.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration for the predictor.
|
||||
overrides (dict, optional): Configuration overrides that take precedence over the default config.
|
||||
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
||||
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
||||
>>> predictor = OBBPredictor(overrides=args)
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = "obb"
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct the result object from the prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
|
||||
the last dimension contains [x, y, w, h, confidence, class_id, angle].
|
||||
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, and oriented bounding
|
||||
boxes.
|
||||
"""
|
||||
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
|
||||
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
|
||||
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
|
||||
return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
|
||||
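As a reading aid for construct_result above, this sketch replays the tensor bookkeeping on dummy data: predictions arrive as (N, 7) rows of [x, y, w, h, conf, cls, angle] and the Results object receives (N, 7) rows ordered [x, y, w, h, angle, conf, cls]. The regularize_rboxes and scale_boxes steps are skipped here since they need a real image shape.

import torch

# One dummy prediction row: [x, y, w, h, conf, cls, angle]
pred = torch.tensor([[320.0, 240.0, 80.0, 40.0, 0.9, 0.0, 0.3]])

rboxes = torch.cat([pred[:, :4], pred[:, -1:]], dim=-1)  # (N, 5): [x, y, w, h, angle]
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)          # (N, 7): [x, y, w, h, angle, conf, cls]
print(obb.shape)  # torch.Size([1, 7])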
82
ultralytics/models/yolo/obb/train.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import OBBModel
|
||||
from ultralytics.utils import DEFAULT_CFG, RANK
|
||||
|
||||
|
||||
class OBBTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
|
||||
detecting objects at arbitrary angles rather than just axis-aligned rectangles.
|
||||
|
||||
Attributes:
|
||||
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
|
||||
and dfl_loss.
|
||||
|
||||
Methods:
|
||||
get_model: Return OBBModel initialized with specified config and weights.
|
||||
get_validator: Return an instance of OBBValidator for validation of YOLO model.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.obb import OBBTrainer
|
||||
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
|
||||
>>> trainer = OBBTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
|
||||
"""
|
||||
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
||||
model configuration.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
||||
will take precedence over those in cfg.
|
||||
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "obb"
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(
|
||||
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
|
||||
) -> OBBModel:
|
||||
"""
|
||||
Return OBBModel initialized with specified config and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
|
||||
containing configuration parameters, or None to use default configuration.
|
||||
weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
|
||||
verbose (bool): Whether to display model information during initialization.
|
||||
|
||||
Returns:
|
||||
(OBBModel): Initialized OBBModel with the specified configuration and weights.
|
||||
|
||||
Examples:
|
||||
>>> trainer = OBBTrainer()
|
||||
>>> model = trainer.get_model(cfg="yolo11n-obb.yaml", weights="yolo11n-obb.pt")
|
||||
"""
|
||||
model = OBBModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of OBBValidator for validation of YOLO model."""
|
||||
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
||||
return yolo.obb.OBBValidator(
|
||||
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
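The trainer above is normally driven end to end through train(); the sketch below mirrors the docstring example with a few illustrative arguments (the epochs/imgsz values are placeholders, the dataset and weight names follow the docstring).

from ultralytics.models.yolo.obb import OBBTrainer

# Task is forced to "obb" inside __init__ regardless of what overrides contain.
args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=1, imgsz=640)
trainer = OBBTrainer(overrides=args)
trainer.train()

# get_validator(), called during training, also sets the reported loss names.
print(trainer.loss_names)  # ('box_loss', 'cls_loss', 'dfl_loss')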
299
ultralytics/models/yolo/obb/val.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import LOGGER, ops
|
||||
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
|
||||
from ultralytics.utils.nms import TorchNMS
|
||||
|
||||
|
||||
class OBBValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
|
||||
satellite imagery where objects can appear at various orientations.
|
||||
|
||||
Attributes:
|
||||
args (dict): Configuration arguments for the validator.
|
||||
metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
|
||||
is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
|
||||
|
||||
Methods:
|
||||
init_metrics: Initialize evaluation metrics for YOLO.
|
||||
_process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
|
||||
_prepare_batch: Prepare batch data for OBB validation.
|
||||
_prepare_pred: Prepare predictions with scaled and padded bounding boxes.
|
||||
plot_predictions: Plot predicted bounding boxes on input images.
|
||||
pred_to_json: Serialize YOLO predictions to COCO json format.
|
||||
save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
|
||||
eval_json: Evaluate YOLO output in JSON format and return performance statistics.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.obb import OBBValidator
|
||||
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
|
||||
>>> validator = OBBValidator(args=args)
|
||||
>>> validator(model=args["model"])
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
||||
|
||||
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
|
||||
It extends the DetectionValidator class and configures it specifically for the OBB task.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
save_dir (str | Path, optional): Directory to save results.
|
||||
args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
|
||||
_callbacks (list, optional): List of callback functions to be called during validation.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.args.task = "obb"
|
||||
self.metrics = OBBMetrics()
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO obb validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
super().init_metrics(model)
|
||||
val = self.data.get(self.args.split, "") # validation path
|
||||
self.is_dota = isinstance(val, str) and "DOTA" in val # check if dataset is DOTA format
|
||||
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
|
||||
class labels and bounding boxes.
|
||||
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
|
||||
class labels and bounding boxes.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
|
||||
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
|
||||
of predictions compared to the ground truth.
|
||||
|
||||
Examples:
    >>> preds = {"cls": torch.zeros(100), "bboxes": torch.rand(100, 5)}  # 100 sample detections
    >>> batch = {"cls": torch.zeros(50), "bboxes": torch.rand(50, 5)}  # 50 ground-truth rotated boxes
    >>> correct_matrix = validator._process_batch(preds, batch)
|
||||
"""
|
||||
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
|
||||
iou = batch_probiou(batch["bboxes"], preds["bboxes"])
|
||||
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
||||
Returns:
|
||||
(list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
|
||||
"""
|
||||
preds = super().postprocess(preds)
|
||||
for pred in preds:
|
||||
pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1) # concatenate angle
|
||||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare batch data for OBB validation with proper scaling and formatting.
|
||||
|
||||
Args:
|
||||
si (int): Batch index to process.
|
||||
batch (dict[str, Any]): Dictionary containing batch data with keys:
|
||||
- batch_idx: Tensor of batch indices
|
||||
- cls: Tensor of class labels
|
||||
- bboxes: Tensor of bounding boxes
|
||||
- ori_shape: Original image shapes
|
||||
- img: Batch of images
|
||||
- ratio_pad: Ratio and padding information
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
|
||||
"""
|
||||
idx = batch["batch_idx"] == si
|
||||
cls = batch["cls"][idx].squeeze(-1)
|
||||
bbox = batch["bboxes"][idx]
|
||||
ori_shape = batch["ori_shape"][si]
|
||||
imgsz = batch["img"].shape[2:]
|
||||
ratio_pad = batch["ratio_pad"][si]
|
||||
if cls.shape[0]:
|
||||
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
|
||||
return {
|
||||
"cls": cls,
|
||||
"bboxes": bbox,
|
||||
"ori_shape": ori_shape,
|
||||
"imgsz": imgsz,
|
||||
"ratio_pad": ratio_pad,
|
||||
"im_file": batch["im_file"][si],
|
||||
}
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
|
||||
"""
|
||||
Plot predicted bounding boxes on input images and save the result.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
|
||||
preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
|
||||
ni (int): Batch index used for naming the output file.
|
||||
|
||||
Examples:
|
||||
>>> validator = OBBValidator()
|
||||
>>> batch = {"img": images, "im_file": paths}
|
||||
>>> preds = [torch.rand(10, 7)] # Example predictions for one image
|
||||
>>> validator.plot_predictions(batch, preds, 0)
|
||||
"""
|
||||
for p in preds:
|
||||
# TODO: fix this duplicated `xywh2xyxy`
|
||||
p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4]) # convert to xyxy format for plotting
|
||||
super().plot_predictions(batch, preds, ni) # plot bboxes
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
|
||||
with bounding box coordinates, confidence scores, and class predictions.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
This method processes rotated bounding box predictions and converts them to both rbox format
|
||||
(x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
|
||||
to the JSON dictionary.
|
||||
"""
|
||||
path = Path(pbatch["im_file"])
|
||||
stem = path.stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
rbox = predn["bboxes"]
|
||||
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
|
||||
for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
|
||||
self.jdict.append(
|
||||
{
|
||||
"image_id": image_id,
|
||||
"file_name": path.name,
|
||||
"category_id": self.class_map[int(c)],
|
||||
"score": round(s, 5),
|
||||
"rbox": [round(x, 3) for x in r],
|
||||
"poly": [round(x, 3) for x in b],
|
||||
}
|
||||
)
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO OBB detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes' in (x, y, w, h, angle) format,
    'conf' confidence scores, and 'cls' class predictions.
|
||||
save_conf (bool): Whether to save confidence scores in the text file.
|
||||
shape (tuple[int, int]): Original image shape in format (height, width).
|
||||
file (Path): Output file path to save detections.
|
||||
|
||||
Examples:
|
||||
>>> validator = OBBValidator()
>>> predn = {"bboxes": torch.rand(1, 5), "conf": torch.tensor([0.9]), "cls": torch.tensor([0.0])}
>>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
|
||||
"""
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
Results(
|
||||
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
||||
path=None,
|
||||
names=self.names,
|
||||
obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
||||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
"""Scales predictions to the original image size."""
|
||||
return {
|
||||
**predn,
|
||||
"bboxes": ops.scale_boxes(
|
||||
pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
|
||||
),
|
||||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
||||
|
||||
Args:
|
||||
stats (dict[str, Any]): Performance statistics dictionary.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Updated performance statistics.
|
||||
"""
|
||||
if self.args.save_json and self.is_dota and len(self.jdict):
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
pred_txt = self.save_dir / "predictions_txt" # predictions
|
||||
pred_txt.mkdir(parents=True, exist_ok=True)
|
||||
with open(pred_json) as f:
    data = json.load(f)
|
||||
# Save split results
|
||||
LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
|
||||
for d in data:
|
||||
image_id = d["image_id"]
|
||||
score = d["score"]
|
||||
classname = self.names[d["category_id"] - 1].replace(" ", "-")
|
||||
p = d["poly"]
|
||||
|
||||
with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
|
||||
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
|
||||
# Save merged results. This could yield a slightly lower mAP than the official merging script
# because of the probiou calculation.
|
||||
pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
|
||||
pred_merged_txt.mkdir(parents=True, exist_ok=True)
|
||||
merged_results = defaultdict(list)
|
||||
LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
|
||||
for d in data:
|
||||
image_id = d["image_id"].split("__", 1)[0]
|
||||
pattern = re.compile(r"\d+___\d+")
|
||||
x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
|
||||
bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
|
||||
bbox[0] += x
|
||||
bbox[1] += y
|
||||
bbox.extend([score, cls])
|
||||
merged_results[image_id].append(bbox)
|
||||
for image_id, bbox in merged_results.items():
|
||||
bbox = torch.tensor(bbox)
|
||||
max_wh = torch.max(bbox[:, :2]).item() * 2
|
||||
c = bbox[:, 6:7] * max_wh # classes
|
||||
scores = bbox[:, 5] # scores
|
||||
b = bbox[:, :5].clone()
|
||||
b[:, :2] += c
|
||||
# An IoU threshold of 0.3 gives results close to, or slightly better than, the official merging script.
|
||||
i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
|
||||
bbox = bbox[i]
|
||||
|
||||
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
|
||||
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
|
||||
classname = self.names[int(x[-1])].replace(" ", "-")
|
||||
p = [round(i, 3) for i in x[:-2]] # poly
|
||||
score = round(x[-2], 3)
|
||||
|
||||
with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
|
||||
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
|
||||
|
||||
return stats
|
||||
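The merging step in eval_json above leans on a standard trick: offsetting every box by class_id * max_wh so that a single class-agnostic NMS pass never suppresses boxes across classes. A standalone sketch of the idea on axis-aligned boxes, using torchvision's NMS and plain IoU instead of fast_nms/probiou so it runs without the rotated-box utilities:

import torch
from torchvision.ops import nms

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0], [1.0, 1.0, 11.0, 11.0], [0.0, 0.0, 10.0, 10.0]])
scores = torch.tensor([0.9, 0.8, 0.7])
cls = torch.tensor([0.0, 0.0, 1.0])

# Shift each box by its class id times a value larger than any coordinate, so boxes of
# different classes can never overlap and therefore never suppress each other.
max_wh = boxes.max().item() * 2
keep = nms(boxes + (cls * max_wh).unsqueeze(-1), scores, iou_threshold=0.3)
print(keep.tolist())  # [0, 2]: box 1 is suppressed by box 0 (same class), box 2 survives (other class)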
7
ultralytics/models/yolo/pose/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from .predict import PosePredictor
|
||||
from .train import PoseTrainer
|
||||
from .val import PoseValidator
|
||||
|
||||
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
|
||||
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
80
ultralytics/models/yolo/pose/predict.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
|
||||
|
||||
|
||||
class PosePredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on a pose model.
|
||||
|
||||
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
|
||||
capabilities inherited from DetectionPredictor.
|
||||
|
||||
Attributes:
|
||||
args (namespace): Configuration arguments for the predictor.
|
||||
model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
|
||||
|
||||
Methods:
|
||||
construct_result: Construct the result object from the prediction, including keypoints.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.pose import PosePredictor
|
||||
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
|
||||
>>> predictor = PosePredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize PosePredictor for pose estimation tasks.
|
||||
|
||||
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
|
||||
warnings for Apple MPS.
|
||||
|
||||
Args:
|
||||
cfg (Any): Configuration for the predictor.
|
||||
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
||||
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.pose import PosePredictor
|
||||
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
|
||||
>>> predictor = PosePredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = "pose"
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct the result object from the prediction, including keypoints.
|
||||
|
||||
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
|
||||
result object.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
|
||||
the number of detections, K is the number of keypoints, and D is the keypoint dimension.
|
||||
img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
|
||||
orig_img (np.ndarray): The original unprocessed image as a numpy array.
|
||||
img_path (str): The path to the original image file.
|
||||
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, bounding boxes, and
|
||||
keypoints.
|
||||
"""
|
||||
result = super().construct_result(pred, img, orig_img, img_path)
|
||||
# Extract keypoints from prediction and reshape according to model's keypoint shape
|
||||
pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
|
||||
# Scale keypoints coordinates to match the original image dimensions
|
||||
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
|
||||
result.update(keypoints=pred_kpts)
|
||||
return result
|
||||
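The keypoint handling in construct_result above reduces to slicing the trailing columns and reshaping them with the model's kpt_shape; this sketch replays that on dummy data (a COCO-style kpt_shape of (17, 3) is assumed, and the ops.scale_coords rescaling is omitted).

import torch

kpt_shape = (17, 3)  # assumed COCO layout: 17 keypoints, each (x, y, visibility)
pred = torch.rand(4, 6 + kpt_shape[0] * kpt_shape[1])  # 6 box/conf/cls columns + flattened keypoints

pred_kpts = pred[:, 6:].view(pred.shape[0], *kpt_shape)
print(pred_kpts.shape)  # torch.Size([4, 17, 3])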
115
ultralytics/models/yolo/pose/train.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import PoseModel
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER
|
||||
|
||||
|
||||
class PoseTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
||||
|
||||
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
||||
of pose keypoints alongside bounding boxes.
|
||||
|
||||
Attributes:
|
||||
args (dict): Configuration arguments for training.
|
||||
model (PoseModel): The pose estimation model being trained.
|
||||
data (dict): Dataset configuration including keypoint shape information.
|
||||
loss_names (tuple): Names of the loss components used in training.
|
||||
|
||||
Methods:
|
||||
get_model: Retrieve a pose estimation model with specified configuration.
|
||||
set_model_attributes: Set keypoints shape attribute on the model.
|
||||
get_validator: Create a validator instance for model evaluation.
|
||||
plot_training_samples: Visualize training samples with keypoints.
|
||||
get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseTrainer
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
||||
>>> trainer = PoseTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
||||
_callbacks (list, optional): List of callback functions to be executed during training.
|
||||
|
||||
Notes:
|
||||
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
|
||||
A warning is issued when using the Apple MPS device due to a known bug with pose models.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "pose"
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
cfg: str | Path | dict[str, Any] | None = None,
|
||||
weights: str | Path | None = None,
|
||||
verbose: bool = True,
|
||||
) -> PoseModel:
|
||||
"""
|
||||
Get pose estimation model with specified configuration and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
||||
weights (str | Path, optional): Path to the model weights file.
|
||||
verbose (bool): Whether to display model information.
|
||||
|
||||
Returns:
|
||||
(PoseModel): Initialized pose estimation model.
|
||||
"""
|
||||
model = PoseModel(
|
||||
cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
|
||||
)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set keypoints shape attribute of PoseModel."""
|
||||
super().set_model_attributes()
|
||||
self.model.kpt_shape = self.data["kpt_shape"]
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of the PoseValidator class for validation."""
|
||||
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
|
||||
return yolo.pose.PoseValidator(
|
||||
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
|
||||
def get_dataset(self) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the training/validation/test dataset and category names.
|
||||
|
||||
Raises:
|
||||
KeyError: If the `kpt_shape` key is not present in the dataset.
|
||||
"""
|
||||
data = super().get_dataset()
|
||||
if "kpt_shape" not in data:
|
||||
raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
|
||||
return data
|
||||
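get_dataset above fails fast when a pose dataset is missing kpt_shape; the sketch below shows the kind of data dictionary the check expects (paths, class names, and counts are placeholders, only the kpt_shape entry matters here).

# Illustrative stand-in for the dictionary returned by the parent get_dataset().
data = {
    "train": "images/train",
    "val": "images/val",
    "nc": 1,
    "names": {0: "person"},
    "kpt_shape": [17, 3],  # 17 keypoints with (x, y, visibility) each
}

if "kpt_shape" not in data:
    raise KeyError("No `kpt_shape` in the dataset. See https://docs.ultralytics.com/datasets/pose/")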
267
ultralytics/models/yolo/pose/val.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import LOGGER, ops
|
||||
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
|
||||
|
||||
|
||||
class PoseValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on a pose model.
|
||||
|
||||
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
||||
specialized metrics for pose evaluation.
|
||||
|
||||
Attributes:
|
||||
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
||||
kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
|
||||
args (dict): Arguments for the validator including task set to "pose".
|
||||
metrics (PoseMetrics): Metrics object for pose evaluation.
|
||||
|
||||
Methods:
|
||||
preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
|
||||
get_desc: Return description of evaluation metrics in string format.
|
||||
init_metrics: Initialize pose estimation metrics for YOLO model.
|
||||
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
|
||||
dimensions.
|
||||
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
|
||||
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
|
||||
detections and ground truth.
|
||||
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
||||
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
|
||||
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
|
||||
pred_to_json: Convert YOLO predictions to COCO JSON format.
|
||||
eval_json: Evaluate object detection model using COCO JSON format.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseValidator
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
||||
>>> validator = PoseValidator(args=args)
|
||||
>>> validator()
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize a PoseValidator object for pose estimation validation.
|
||||
|
||||
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
||||
specialized metrics for pose evaluation.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
save_dir (Path | str, optional): Directory to save results.
|
||||
args (dict, optional): Arguments for the validator including task set to "pose".
|
||||
_callbacks (list, optional): List of callback functions to be executed during validation.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseValidator
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
||||
>>> validator = PoseValidator(args=args)
|
||||
>>> validator()
|
||||
|
||||
Notes:
|
||||
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
||||
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
||||
due to a known bug with pose models.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.sigma = None
|
||||
self.kpt_shape = None
|
||||
self.args.task = "pose"
|
||||
self.metrics = PoseMetrics()
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch["keypoints"] = batch["keypoints"].float()
|
||||
return batch
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return description of evaluation metrics in string format."""
|
||||
return ("%22s" + "%11s" * 10) % (
|
||||
"Class",
|
||||
"Images",
|
||||
"Instances",
|
||||
"Box(P",
|
||||
"R",
|
||||
"mAP50",
|
||||
"mAP50-95)",
|
||||
"Pose(P",
|
||||
"R",
|
||||
"mAP50",
|
||||
"mAP50-95)",
|
||||
)
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO pose validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
super().init_metrics(model)
|
||||
self.kpt_shape = self.data["kpt_shape"]
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
nkpt = self.kpt_shape[0]
|
||||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
||||
|
||||
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
|
||||
field of predictions and reshaping them according to the keypoint shape configuration.
|
||||
The keypoints are reshaped from a flattened format to the proper dimensional structure
|
||||
(typically [N, 17, 3] for COCO pose format).
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
||||
bounding boxes, confidence scores, class predictions, and keypoint data.
|
||||
|
||||
Returns:
|
||||
(list[dict[str, torch.Tensor]]): List of processed prediction dictionaries, each containing:
|
||||
- 'bboxes': Bounding box coordinates
|
||||
- 'conf': Confidence scores
|
||||
- 'cls': Class predictions
|
||||
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
|
||||
|
||||
Note:
    Keypoints are extracted from the 'extra' field, which carries task-specific data beyond the basic
    detection outputs, and are reshaped to (num_detections, *kpt_shape) for every prediction in the batch.
|
||||
"""
|
||||
preds = super().postprocess(preds)
|
||||
for pred in preds:
|
||||
pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape) # remove extra if exists
|
||||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
|
||||
|
||||
Notes:
|
||||
This method extends the parent class's _prepare_batch method by adding keypoint processing.
|
||||
Keypoints are scaled from normalized coordinates to original image dimensions.
|
||||
"""
|
||||
pbatch = super()._prepare_batch(si, batch)
|
||||
kpts = batch["keypoints"][batch["batch_idx"] == si]
|
||||
h, w = pbatch["imgsz"]
|
||||
kpts = kpts.clone()
|
||||
kpts[..., 0] *= w
|
||||
kpts[..., 1] *= h
|
||||
pbatch["keypoints"] = kpts
|
||||
return pbatch
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
||||
and 'keypoints' for keypoint predictions.
|
||||
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
||||
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
||||
true positives across 10 IoU levels.
|
||||
|
||||
Notes:
|
||||
The `0.53` scale factor used in the area computation is taken from
https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
|
||||
"""
|
||||
tp = super()._process_batch(preds, batch)
|
||||
gt_cls = batch["cls"]
|
||||
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
|
||||
else:
|
||||
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
||||
area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
|
||||
iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
|
||||
tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
|
||||
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
|
||||
return tp
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO pose detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls', and 'keypoints'.
|
||||
save_conf (bool): Whether to save confidence scores.
|
||||
shape (tuple[int, int]): Shape of the original image (height, width).
|
||||
file (Path): Output file path to save detections.
|
||||
|
||||
Notes:
|
||||
The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
|
||||
normalized (x, y, visibility) values for each point.
|
||||
"""
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
Results(
|
||||
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
||||
path=None,
|
||||
names=self.names,
|
||||
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
||||
keypoints=predn["keypoints"],
|
||||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format.
|
||||
|
||||
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
||||
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
||||
and 'keypoints' tensors.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
|
||||
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
|
||||
before saving to the JSON dictionary.
|
||||
"""
|
||||
super().pred_to_json(predn, pbatch)
|
||||
kpts = predn["kpts"]
|
||||
for i, k in enumerate(kpts.flatten(1, 2).tolist()):
|
||||
self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints
|
||||
|
||||
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
"""Scales predictions to the original image size."""
|
||||
return {
|
||||
**super().scale_preds(predn, pbatch),
|
||||
"kpts": ops.scale_coords(
|
||||
pbatch["imgsz"],
|
||||
predn["keypoints"].clone(),
|
||||
pbatch["ori_shape"],
|
||||
ratio_pad=pbatch["ratio_pad"],
|
||||
),
|
||||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Evaluate object detection model using COCO JSON format."""
|
||||
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
|
||||
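To make the OKS-based matching in _process_batch above concrete, this sketch computes the keypoint similarity term on dummy tensors with the same kpt_iou/OKS_SIGMA utilities and the 0.53 area factor; tensor values are random, so only the shapes are meaningful.

import torch

from ultralytics.utils import ops
from ultralytics.utils.metrics import OKS_SIGMA, kpt_iou

gt_kpts = torch.rand(2, 17, 3) * 640    # 2 ground-truth poses, (x, y, visibility) per keypoint
pred_kpts = torch.rand(3, 17, 3) * 640  # 3 predicted poses
gt_boxes = torch.tensor([[0.0, 0.0, 200.0, 300.0], [100.0, 100.0, 400.0, 500.0]])  # xyxy

area = ops.xyxy2xywh(gt_boxes)[:, 2:].prod(1) * 0.53  # same area scaling as above
iou = kpt_iou(gt_kpts, pred_kpts, sigma=OKS_SIGMA, area=area)
print(iou.shape)  # one OKS value per (ground truth, prediction) pair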
7
ultralytics/models/yolo/segment/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from .predict import SegmentationPredictor
|
||||
from .train import SegmentationTrainer
|
||||
from .val import SegmentationValidator
|
||||
|
||||
__all__ = "SegmentationPredictor", "SegmentationTrainer", "SegmentationValidator"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
ultralytics/models/yolo/segment/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
113
ultralytics/models/yolo/segment/predict.py
Normal file
@@ -0,0 +1,113 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops


class SegmentationPredictor(DetectionPredictor):
    """
    A class extending the DetectionPredictor class for prediction based on a segmentation model.

    This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
    prediction results.

    Attributes:
        args (dict): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO segmentation model.
        batch (list): Current batch of images being processed.

    Methods:
        postprocess: Apply non-max suppression and process segmentation detections.
        construct_results: Construct a list of result objects from predictions.
        construct_result: Construct a single result object from a prediction.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.segment import SegmentationPredictor
        >>> args = dict(model="yolo11n-seg.pt", source=ASSETS)
        >>> predictor = SegmentationPredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the SegmentationPredictor with configuration, overrides, and callbacks.

        This class specializes in processing segmentation model outputs, handling both bounding boxes and masks in the
        prediction results.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "segment"

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply non-max suppression and process segmentation detections for each image in the input batch.

        Args:
            preds (tuple): Model predictions, containing bounding boxes, scores, classes, and mask coefficients.
            img (torch.Tensor): Input image tensor in model format, with shape (B, C, H, W).
            orig_imgs (list | torch.Tensor | np.ndarray): Original image or batch of images.

        Returns:
            (list): List of Results objects containing the segmentation predictions for each image in the batch.
                Each Results object includes both bounding boxes and segmentation masks.

        Examples:
            >>> predictor = SegmentationPredictor(overrides=dict(model="yolo11n-seg.pt"))
            >>> results = predictor.postprocess(preds, img, orig_img)
        """
        # Extract protos - tuple if PyTorch model or array if exported
        protos = preds[1][-1] if isinstance(preds[1], tuple) else preds[1]
        return super().postprocess(preds[0], img, orig_imgs, protos=protos)

    def construct_results(self, preds, img, orig_imgs, protos):
        """
        Construct a list of result objects from the predictions.

        Args:
            preds (list[torch.Tensor]): List of predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_imgs (list[np.ndarray]): List of original images before preprocessing.
            protos (list[torch.Tensor]): List of prototype masks.

        Returns:
            (list[Results]): List of result objects containing the original images, image paths, class names,
                bounding boxes, and masks.
        """
        return [
            self.construct_result(pred, img, orig_img, img_path, proto)
            for pred, orig_img, img_path, proto in zip(preds, orig_imgs, self.batch[0], protos)
        ]

    def construct_result(self, pred, img, orig_img, img_path, proto):
        """
        Construct a single result object from the prediction.

        Args:
            pred (torch.Tensor): The predicted bounding boxes, scores, and masks.
            img (torch.Tensor): The image after preprocessing.
            orig_img (np.ndarray): The original image before preprocessing.
            img_path (str): The path to the original image.
            proto (torch.Tensor): The prototype masks.

        Returns:
            (Results): Result object containing the original image, image path, class names, bounding boxes, and masks.
        """
        if pred.shape[0] == 0:  # save empty boxes
            masks = None
        elif self.args.retina_masks:
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
            masks = ops.process_mask_native(proto, pred[:, 6:], pred[:, :4], orig_img.shape[:2])  # HWC
        else:
            masks = ops.process_mask(proto, pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
        if masks is not None:
            keep = masks.sum((-2, -1)) > 0  # only keep predictions with masks
            pred, masks = pred[keep], masks[keep]
        return Results(orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)
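As the class docstring suggests, this predictor is normally driven through the high-level YOLO API rather than instantiated directly. A minimal usage sketch (model name and image path are illustrative):

from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")  # selects SegmentationPredictor for task="segment"
results = model.predict("bus.jpg")  # runs postprocess() -> construct_results() shown above
masks = results[0].masks  # per-instance masks built by construct_result()
boxes = results[0].boxes  # matching bounding boxes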
72
ultralytics/models/yolo/segment/train.py
Normal file
72
ultralytics/models/yolo/segment/train.py
Normal file
@@ -0,0 +1,72 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from pathlib import Path

from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK


class SegmentationTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    This trainer specializes in handling segmentation tasks, extending the detection trainer with segmentation-specific
    functionality including model initialization, validation, and visualization.

    Attributes:
        loss_names (tuple[str]): Names of the loss components used during training.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
        >>> trainer = SegmentationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """
        Initialize a SegmentationTrainer object.

        Args:
            cfg (dict): Configuration dictionary with default training settings.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.
        """
        if overrides is None:
            overrides = {}
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
        """
        Initialize and return a SegmentationModel with specified configuration and weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
            weights (str | Path, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (SegmentationModel): Initialized segmentation model with loaded weights if specified.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
        """
        model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return an instance of SegmentationValidator for validation of YOLO model."""
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )
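Following the class docstring, a short training sketch (dataset name and epoch count are illustrative):

from ultralytics.models.yolo.segment import SegmentationTrainer

args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3, imgsz=640)
trainer = SegmentationTrainer(overrides=args)  # forces task="segment" as in __init__ above
trainer.train()  # get_model() and get_validator() are called internally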
259
ultralytics/models/yolo/segment/val.py
Normal file
259
ultralytics/models/yolo/segment/val.py
Normal file
@@ -0,0 +1,259 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F

from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, NUM_THREADS, ops
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, mask_iou


class SegmentationValidator(DetectionValidator):
    """
    A class extending the DetectionValidator class for validation based on a segmentation model.

    This validator handles the evaluation of segmentation models, processing both bounding box and mask predictions
    to compute metrics such as mAP for both detection and segmentation tasks.

    Attributes:
        plot_masks (list): List to store masks for plotting.
        process (callable): Function to process masks based on save_json and save_txt flags.
        args (namespace): Arguments for the validator.
        metrics (SegmentMetrics): Metrics calculator for segmentation tasks.
        stats (dict): Dictionary to store statistics during validation.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationValidator
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml")
        >>> validator = SegmentationValidator(args=args)
        >>> validator()
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
        """
        Initialize SegmentationValidator and set task to 'segment', metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to use for validation.
            save_dir (Path, optional): Directory to save results.
            args (namespace, optional): Arguments for the validator.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.process = None
        self.args.task = "segment"
        self.metrics = SegmentMetrics()

    def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
        """
        Preprocess batch of images for YOLO segmentation validation.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.

        Returns:
            (dict[str, Any]): Preprocessed batch.
        """
        batch = super().preprocess(batch)
        batch["masks"] = batch["masks"].float()
        return batch

    def init_metrics(self, model: torch.nn.Module) -> None:
        """
        Initialize metrics and select mask processing function based on save_json flag.

        Args:
            model (torch.nn.Module): Model to validate.
        """
        super().init_metrics(model)
        if self.args.save_json:
            check_requirements("faster-coco-eval>=1.6.7")
        # More accurate vs faster
        self.process = ops.process_mask_native if self.args.save_json or self.args.save_txt else ops.process_mask

    def get_desc(self) -> str:
        """Return a formatted description of evaluation metrics."""
        return ("%22s" + "%11s" * 10) % (
            "Class",
            "Images",
            "Instances",
            "Box(P",
            "R",
            "mAP50",
            "mAP50-95)",
            "Mask(P",
            "R",
            "mAP50",
            "mAP50-95)",
        )

    def postprocess(self, preds: list[torch.Tensor]) -> list[dict[str, torch.Tensor]]:
        """
        Post-process YOLO predictions and return output detections with proto.

        Args:
            preds (list[torch.Tensor]): Raw predictions from the model.

        Returns:
            (list[dict[str, torch.Tensor]]): Processed detection predictions with masks.
        """
        proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported
        preds = super().postprocess(preds[0])
        imgsz = [4 * x for x in proto.shape[2:]]  # get image size from proto
        for i, pred in enumerate(preds):
            coefficient = pred.pop("extra")
            pred["masks"] = (
                self.process(proto[i], coefficient, pred["bboxes"], shape=imgsz)
                if coefficient.shape[0]
                else torch.zeros(
                    (0, *(imgsz if self.process is ops.process_mask_native else proto.shape[2:])),
                    dtype=torch.uint8,
                    device=pred["bboxes"].device,
                )
            )
        return preds

    def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
        """
        Prepare a batch for training or inference by processing images and targets.

        Args:
            si (int): Batch index.
            batch (dict[str, Any]): Batch data containing images and annotations.

        Returns:
            (dict[str, Any]): Prepared batch with processed annotations.
        """
        prepared_batch = super()._prepare_batch(si, batch)
        nl = prepared_batch["cls"].shape[0]
        if self.args.overlap_mask:
            masks = batch["masks"][si]
            index = torch.arange(1, nl + 1, device=masks.device).view(nl, 1, 1)
            masks = (masks == index).float()
        else:
            masks = batch["masks"][batch["batch_idx"] == si]
        if nl:
            mask_size = [s if self.process is ops.process_mask_native else s // 4 for s in prepared_batch["imgsz"]]
            if masks.shape[1:] != mask_size:
                masks = F.interpolate(masks[None], mask_size, mode="bilinear", align_corners=False)[0]
                masks = masks.gt_(0.5)
        prepared_batch["masks"] = masks
        return prepared_batch

    def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
        """
        Compute correct prediction matrix for a batch based on bounding boxes and optional masks.

        Args:
            preds (dict[str, torch.Tensor]): Dictionary containing predictions with keys like 'cls' and 'masks'.
            batch (dict[str, Any]): Dictionary containing batch data with keys like 'cls' and 'masks'.

        Returns:
            (dict[str, np.ndarray]): A dictionary containing correct prediction matrices including 'tp_m' for mask IoU.

        Notes:
            - If `masks` is True, the function computes IoU between predicted and ground truth masks.
            - If `overlap` is True and `masks` is True, overlapping masks are taken into account when computing IoU.

        Examples:
            >>> preds = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> batch = {"cls": torch.tensor([1, 0]), "masks": torch.rand(2, 640, 640), "bboxes": torch.rand(2, 4)}
            >>> correct_preds = validator._process_batch(preds, batch)
        """
        tp = super()._process_batch(preds, batch)
        gt_cls = batch["cls"]
        if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
            tp_m = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
        else:
            iou = mask_iou(batch["masks"].flatten(1), preds["masks"].flatten(1))
            tp_m = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
        tp.update({"tp_m": tp_m})  # update tp with mask IoU
        return tp

    def plot_predictions(self, batch: dict[str, Any], preds: list[dict[str, torch.Tensor]], ni: int) -> None:
        """
        Plot batch predictions with masks and bounding boxes.

        Args:
            batch (dict[str, Any]): Batch containing images and annotations.
            preds (list[dict[str, torch.Tensor]]): List of predictions from the model.
            ni (int): Batch index.
        """
        for p in preds:
            masks = p["masks"]
            if masks.shape[0] > self.args.max_det:
                LOGGER.warning(f"Limiting validation plots to 'max_det={self.args.max_det}' items.")
            p["masks"] = torch.as_tensor(masks[: self.args.max_det], dtype=torch.uint8).cpu()
        super().plot_predictions(batch, preds, ni, max_det=self.args.max_det)  # plot bboxes

    def save_one_txt(self, predn: torch.Tensor, save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """
        Save YOLO detections to a txt file in normalized coordinates in a specific format.

        Args:
            predn (torch.Tensor): Predictions in the format (x1, y1, x2, y2, conf, class).
            save_conf (bool): Whether to save confidence scores.
            shape (tuple[int, int]): Shape of the original image.
            file (Path): File path to save the detections.
        """
        from ultralytics.engine.results import Results

        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
            masks=torch.as_tensor(predn["masks"], dtype=torch.uint8),
        ).save_txt(file, save_conf=save_conf)

    def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
        """
        Save one JSON result for COCO evaluation.

        Args:
            predn (dict[str, torch.Tensor]): Predictions containing bboxes, masks, confidence scores, and classes.
            pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
        """
        from faster_coco_eval.core.mask import encode  # noqa

        def single_encode(x):
            """Encode predicted masks as RLE and append results to jdict."""
            rle = encode(np.asarray(x[:, :, None], order="F", dtype="uint8"))[0]
            rle["counts"] = rle["counts"].decode("utf-8")
            return rle

        pred_masks = np.transpose(predn["masks"], (2, 0, 1))
        with ThreadPool(NUM_THREADS) as pool:
            rles = pool.map(single_encode, pred_masks)
        super().pred_to_json(predn, pbatch)
        for i, r in enumerate(rles):
            self.jdict[-len(rles) + i]["segmentation"] = r  # segmentation

    def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
        """Scale predictions to the original image size."""
        return {
            **super().scale_preds(predn, pbatch),
            "masks": ops.scale_image(
                torch.as_tensor(predn["masks"], dtype=torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy(),
                pbatch["ori_shape"],
                ratio_pad=pbatch["ratio_pad"],
            ),
        }

    def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
        """Return COCO-style instance segmentation evaluation metrics."""
        pred_json = self.save_dir / "predictions.json"  # predictions
        anno_json = (
            self.data["path"]
            / "annotations"
            / ("instances_val2017.json" if self.is_coco else f"lvis_v1_{self.args.split}.json")
        )  # annotations
        return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "segm"], suffix=["Box", "Mask"])
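A corresponding validation sketch; with save_json=True the validator writes predictions.json and runs the COCO/LVIS evaluation in eval_json() shown above (weights and dataset names are illustrative):

from ultralytics import YOLO

model = YOLO("yolo11n-seg.pt")
metrics = model.val(data="coco8-seg.yaml", save_json=True)  # uses SegmentationValidator for task="segment"
print(metrics.box.map, metrics.seg.map)  # box and mask mAP50-95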
5
ultralytics/models/yolo/world/__init__.py
Normal file
5
ultralytics/models/yolo/world/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .train import WorldTrainer

__all__ = ["WorldTrainer"]
Binary file not shown.
BIN
ultralytics/models/yolo/world/__pycache__/train.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/world/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
179
ultralytics/models/yolo/world/train.py
Normal file
179
ultralytics/models/yolo/world/train.py
Normal file
@@ -0,0 +1,179 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import itertools
from pathlib import Path
from typing import Any

import torch

from ultralytics.data import build_yolo_dataset
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import WorldModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model


def on_pretrain_routine_end(trainer) -> None:
    """Set up model classes and text encoder at the end of the pretrain routine."""
    if RANK in {-1, 0}:
        # Set class names for evaluation
        names = [name.split("/", 1)[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
        unwrap_model(trainer.ema.ema).set_classes(names, cache_clip_model=False)


class WorldTrainer(DetectionTrainer):
    """
    A trainer class for fine-tuning YOLO World models on close-set datasets.

    This trainer extends the DetectionTrainer to support training YOLO World models, which combine visual and textual
    features for improved object detection and understanding. It handles text embedding generation and caching to
    accelerate training with multi-modal data.

    Attributes:
        text_embeddings (dict[str, torch.Tensor] | None): Cached text embeddings for category names to accelerate
            training.
        model (WorldModel): The YOLO World model being trained.
        data (dict[str, Any]): Dataset configuration containing class information.
        args (Any): Training arguments and configuration.

    Methods:
        get_model: Return WorldModel initialized with specified config and weights.
        build_dataset: Build YOLO Dataset for training or validation.
        set_text_embeddings: Set text embeddings for datasets to accelerate training.
        generate_text_embeddings: Generate text embeddings for a list of text samples.
        preprocess_batch: Preprocess a batch of images and text for YOLOWorld training.

    Examples:
        Initialize and train a YOLO World model
        >>> from ultralytics.models.yolo.world import WorldTrainer
        >>> args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
        >>> trainer = WorldTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a WorldTrainer object with given arguments.

        Args:
            cfg (dict[str, Any]): Configuration for the trainer.
            overrides (dict[str, Any], optional): Configuration overrides.
            _callbacks (list[Any], optional): List of callback functions.
        """
        if overrides is None:
            overrides = {}
        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
        super().__init__(cfg, overrides, _callbacks)
        self.text_embeddings = None

    def get_model(self, cfg=None, weights: str | None = None, verbose: bool = True) -> WorldModel:
        """
        Return WorldModel initialized with specified config and weights.

        Args:
            cfg (dict[str, Any] | str, optional): Model configuration.
            weights (str, optional): Path to pretrained weights.
            verbose (bool): Whether to display model info.

        Returns:
            (WorldModel): Initialized WorldModel.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = WorldModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)
        self.add_callback("on_pretrain_routine_end", on_pretrain_routine_end)

        return model

    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset for training or validation.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for `rect`.

        Returns:
            (Any): YOLO dataset configured for training or validation.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        dataset = build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )
        if mode == "train":
            self.set_text_embeddings([dataset], batch)  # cache text embeddings to accelerate training
        return dataset

    def set_text_embeddings(self, datasets: list[Any], batch: int | None) -> None:
        """
        Set text embeddings for datasets to accelerate training by caching category names.

        This method collects unique category names from all datasets, then generates and caches text embeddings
        for these categories to improve training efficiency.

        Args:
            datasets (list[Any]): List of datasets from which to extract category names.
            batch (int | None): Batch size used for processing.

        Notes:
            This method collects category names from datasets that have the 'category_names' attribute,
            then uses the first dataset's image path to determine where to cache the generated text embeddings.
        """
        text_embeddings = {}
        for dataset in datasets:
            if not hasattr(dataset, "category_names"):
                continue
            text_embeddings.update(
                self.generate_text_embeddings(
                    list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent
                )
            )
        self.text_embeddings = text_embeddings

    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path) -> dict[str, torch.Tensor]:
        """
        Generate text embeddings for a list of text samples.

        Args:
            texts (list[str]): List of text samples to encode.
            batch (int): Batch size for processing.
            cache_dir (Path): Directory to save/load cached embeddings.

        Returns:
            (dict[str, torch.Tensor]): Dictionary mapping text samples to their embeddings.
        """
        model = "clip:ViT-B/32"
        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existing cache from '{cache_path}'")
            txt_map = torch.load(cache_path, map_location=self.device)
            if sorted(txt_map.keys()) == sorted(texts):
                return txt_map
        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map

    def preprocess_batch(self, batch: dict[str, Any]) -> dict[str, Any]:
        """Preprocess a batch of images and text for YOLOWorld training."""
        batch = DetectionTrainer.preprocess_batch(self, batch)

        # Add text features
        texts = list(itertools.chain(*batch["texts"]))
        txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(
            self.device, non_blocking=self.device.type == "cuda"
        )
        batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
        return batch
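Following the class docstring, a minimal fine-tuning sketch for YOLO World (model and dataset names are illustrative; note the trainer asserts compile=False):

from ultralytics.models.yolo.world import WorldTrainer

args = dict(model="yolov8s-world.pt", data="coco8.yaml", epochs=3)
trainer = WorldTrainer(overrides=args)  # caches CLIP ViT-B/32 text embeddings on the first build_dataset() call
trainer.train()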
Some files were not shown because too many files have changed in this diff.