init commit
ultralytics/models/utils/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
ultralytics/models/utils/loss.py (new file, 478 lines)
@@ -0,0 +1,478 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics.utils.loss import FocalLoss, VarifocalLoss
from ultralytics.utils.metrics import bbox_iou

from .ops import HungarianMatcher


class DETRLoss(nn.Module):
    """
    DETR (DEtection TRansformer) Loss class for calculating various loss components.

    This class computes classification loss, bounding box loss, GIoU loss, and optionally auxiliary losses for the
    DETR object detection model.

    Attributes:
        nc (int): Number of classes.
        loss_gain (dict[str, float]): Coefficients for different loss components.
        aux_loss (bool): Whether to compute auxiliary losses.
        use_fl (bool): Whether to use FocalLoss.
        use_vfl (bool): Whether to use VarifocalLoss.
        use_uni_match (bool): Whether to use a fixed layer for auxiliary branch label assignment.
        uni_match_ind (int): Index of fixed layer to use if use_uni_match is True.
        matcher (HungarianMatcher): Object to compute matching cost and indices.
        fl (FocalLoss | None): Focal Loss object if use_fl is True, otherwise None.
        vfl (VarifocalLoss | None): Varifocal Loss object if use_vfl is True, otherwise None.
        device (torch.device): Device on which tensors are stored.
    """

    def __init__(
        self,
        nc: int = 80,
        loss_gain: dict[str, float] | None = None,
        aux_loss: bool = True,
        use_fl: bool = True,
        use_vfl: bool = False,
        use_uni_match: bool = False,
        uni_match_ind: int = 0,
        gamma: float = 1.5,
        alpha: float = 0.25,
    ):
        """
        Initialize DETR loss function with customizable components and gains.

        Uses default loss_gain if not provided. Initializes HungarianMatcher with preset cost gains. Supports auxiliary
        losses and various loss types.

        Args:
            nc (int): Number of classes.
            loss_gain (dict[str, float], optional): Coefficients for different loss components.
            aux_loss (bool): Whether to use auxiliary losses from each decoder layer.
            use_fl (bool): Whether to use FocalLoss.
            use_vfl (bool): Whether to use VarifocalLoss.
            use_uni_match (bool): Whether to use fixed layer for auxiliary branch label assignment.
            uni_match_ind (int): Index of fixed layer for uni_match.
            gamma (float): The focusing parameter that controls how much the loss focuses on hard-to-classify examples.
            alpha (float): The balancing factor used to address class imbalance.
        """
        super().__init__()

        if loss_gain is None:
            loss_gain = {"class": 1, "bbox": 5, "giou": 2, "no_object": 0.1, "mask": 1, "dice": 1}
        self.nc = nc
        self.matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})
        self.loss_gain = loss_gain
        self.aux_loss = aux_loss
        self.fl = FocalLoss(gamma, alpha) if use_fl else None
        self.vfl = VarifocalLoss(gamma, alpha) if use_vfl else None

        self.use_uni_match = use_uni_match
        self.uni_match_ind = uni_match_ind
        self.device = None

    def _get_loss_class(
        self, pred_scores: torch.Tensor, targets: torch.Tensor, gt_scores: torch.Tensor, num_gts: int, postfix: str = ""
    ) -> dict[str, torch.Tensor]:
        """
        Compute classification loss based on predictions, target values, and ground truth scores.

        Args:
            pred_scores (torch.Tensor): Predicted class scores with shape (B, N, C).
            targets (torch.Tensor): Target class indices with shape (B, N).
            gt_scores (torch.Tensor): Ground truth confidence scores with shape (B, N).
            num_gts (int): Number of ground truth objects.
            postfix (str, optional): String to append to the loss name for identification in multi-loss scenarios.

        Returns:
            (dict[str, torch.Tensor]): Dictionary containing classification loss value.

        Notes:
            The function supports different classification loss types:
            - Varifocal Loss (if self.vfl is set and num_gts > 0)
            - Focal Loss (if self.fl is set)
            - BCE Loss (default fallback)
        """
        # Logits: [b, query, num_classes], gt_class: list[[n, 1]]
        name_class = f"loss_class{postfix}"
        bs, nq = pred_scores.shape[:2]
        # one_hot = F.one_hot(targets, self.nc + 1)[..., :-1]  # (bs, num_queries, num_classes)
        one_hot = torch.zeros((bs, nq, self.nc + 1), dtype=torch.int64, device=targets.device)
        one_hot.scatter_(2, targets.unsqueeze(-1), 1)
        one_hot = one_hot[..., :-1]
        gt_scores = gt_scores.view(bs, nq, 1) * one_hot

        if self.fl:
            if num_gts and self.vfl:
                loss_cls = self.vfl(pred_scores, gt_scores, one_hot)
            else:
                loss_cls = self.fl(pred_scores, one_hot.float())
            loss_cls /= max(num_gts, 1) / nq
        else:
            loss_cls = nn.BCEWithLogitsLoss(reduction="none")(pred_scores, gt_scores).mean(1).sum()  # YOLO CLS loss

        return {name_class: loss_cls.squeeze() * self.loss_gain["class"]}

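    # Illustrative sketch (not part of the API): how the scatter-based one-hot and IoU-weighted targets above
    # behave for a toy case with nc=2, bs=1, nq=3, where query 0 is matched to class 1 and queries 1-2 are
    # background (index nc). All values below are made-up assumptions for illustration only.
    # >>> targets = torch.tensor([[1, 2, 2]])                                   # (bs, nq), background = nc = 2
    # >>> one_hot = torch.zeros((1, 3, 3), dtype=torch.int64)
    # >>> one_hot.scatter_(2, targets.unsqueeze(-1), 1)
    # >>> one_hot = one_hot[..., :-1]                                           # drop background -> (1, 3, 2)
    # >>> gt_scores = torch.tensor([[0.9, 0.0, 0.0]]).view(1, 3, 1) * one_hot   # IoU weight on the matched query
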
    def _get_loss_bbox(
        self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, postfix: str = ""
    ) -> dict[str, torch.Tensor]:
        """
        Compute bounding box and GIoU losses for predicted and ground truth bounding boxes.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (N, 4).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (N, 4).
            postfix (str, optional): String to append to the loss names for identification in multi-loss scenarios.

        Returns:
            (dict[str, torch.Tensor]): Dictionary containing:
                - loss_bbox{postfix}: L1 loss between predicted and ground truth boxes, scaled by the bbox loss gain.
                - loss_giou{postfix}: GIoU loss between predicted and ground truth boxes, scaled by the giou loss gain.

        Notes:
            If no ground truth boxes are provided (empty list), zero-valued tensors are returned for both losses.
        """
        # Boxes: [b, query, 4], gt_bbox: list[[n, 4]]
        name_bbox = f"loss_bbox{postfix}"
        name_giou = f"loss_giou{postfix}"

        loss = {}
        if len(gt_bboxes) == 0:
            loss[name_bbox] = torch.tensor(0.0, device=self.device)
            loss[name_giou] = torch.tensor(0.0, device=self.device)
            return loss

        loss[name_bbox] = self.loss_gain["bbox"] * F.l1_loss(pred_bboxes, gt_bboxes, reduction="sum") / len(gt_bboxes)
        loss[name_giou] = 1.0 - bbox_iou(pred_bboxes, gt_bboxes, xywh=True, GIoU=True)
        loss[name_giou] = loss[name_giou].sum() / len(gt_bboxes)
        loss[name_giou] = self.loss_gain["giou"] * loss[name_giou]
        return {k: v.squeeze() for k, v in loss.items()}

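    # Illustrative sketch (not part of the API): the two box losses above on a pair of matched xywh boxes.
    # The box values are made-up assumptions for illustration only.
    # >>> pred = torch.tensor([[0.5, 0.5, 0.2, 0.2], [0.3, 0.3, 0.1, 0.1]])
    # >>> gt = torch.tensor([[0.5, 0.5, 0.25, 0.25], [0.3, 0.3, 0.1, 0.1]])
    # >>> l1 = F.l1_loss(pred, gt, reduction="sum") / len(gt)                       # summed L1 per matched GT box
    # >>> giou = (1.0 - bbox_iou(pred, gt, xywh=True, GIoU=True)).sum() / len(gt)   # mean (1 - GIoU)
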
    # This function is for future RT-DETR Segment models
    # def _get_loss_mask(self, masks, gt_mask, match_indices, postfix=''):
    #     # masks: [b, query, h, w], gt_mask: list[[n, H, W]]
    #     name_mask = f'loss_mask{postfix}'
    #     name_dice = f'loss_dice{postfix}'
    #
    #     loss = {}
    #     if sum(len(a) for a in gt_mask) == 0:
    #         loss[name_mask] = torch.tensor(0., device=self.device)
    #         loss[name_dice] = torch.tensor(0., device=self.device)
    #         return loss
    #
    #     num_gts = len(gt_mask)
    #     src_masks, target_masks = self._get_assigned_bboxes(masks, gt_mask, match_indices)
    #     src_masks = F.interpolate(src_masks.unsqueeze(0), size=target_masks.shape[-2:], mode='bilinear')[0]
    #     # TODO: torch does not have `sigmoid_focal_loss`, but it's not urgent since we don't use mask branch for now.
    #     loss[name_mask] = self.loss_gain['mask'] * F.sigmoid_focal_loss(src_masks, target_masks,
    #                                                                     torch.tensor([num_gts], dtype=torch.float32))
    #     loss[name_dice] = self.loss_gain['dice'] * self._dice_loss(src_masks, target_masks, num_gts)
    #     return loss

    # This function is for future RT-DETR Segment models
    # @staticmethod
    # def _dice_loss(inputs, targets, num_gts):
    #     inputs = F.sigmoid(inputs).flatten(1)
    #     targets = targets.flatten(1)
    #     numerator = 2 * (inputs * targets).sum(1)
    #     denominator = inputs.sum(-1) + targets.sum(-1)
    #     loss = 1 - (numerator + 1) / (denominator + 1)
    #     return loss.sum() / num_gts

    def _get_loss_aux(
        self,
        pred_bboxes: torch.Tensor,
        pred_scores: torch.Tensor,
        gt_bboxes: torch.Tensor,
        gt_cls: torch.Tensor,
        gt_groups: list[int],
        match_indices: list[tuple] | None = None,
        postfix: str = "",
        masks: torch.Tensor | None = None,
        gt_mask: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Get auxiliary losses for intermediate decoder layers.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes from auxiliary layers.
            pred_scores (torch.Tensor): Predicted scores from auxiliary layers.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes.
            gt_cls (torch.Tensor): Ground truth classes.
            gt_groups (list[int]): Number of ground truths per image.
            match_indices (list[tuple], optional): Pre-computed matching indices.
            postfix (str, optional): String to append to loss names.
            masks (torch.Tensor, optional): Predicted masks if using segmentation.
            gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.

        Returns:
            (dict[str, torch.Tensor]): Dictionary of auxiliary losses.
        """
        # NOTE: loss class, bbox, giou, mask, dice
        loss = torch.zeros(5 if masks is not None else 3, device=pred_bboxes.device)
        if match_indices is None and self.use_uni_match:
            match_indices = self.matcher(
                pred_bboxes[self.uni_match_ind],
                pred_scores[self.uni_match_ind],
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=masks[self.uni_match_ind] if masks is not None else None,
                gt_mask=gt_mask,
            )
        for i, (aux_bboxes, aux_scores) in enumerate(zip(pred_bboxes, pred_scores)):
            aux_masks = masks[i] if masks is not None else None
            loss_ = self._get_loss(
                aux_bboxes,
                aux_scores,
                gt_bboxes,
                gt_cls,
                gt_groups,
                masks=aux_masks,
                gt_mask=gt_mask,
                postfix=postfix,
                match_indices=match_indices,
            )
            loss[0] += loss_[f"loss_class{postfix}"]
            loss[1] += loss_[f"loss_bbox{postfix}"]
            loss[2] += loss_[f"loss_giou{postfix}"]
            # if masks is not None and gt_mask is not None:
            #     loss_ = self._get_loss_mask(aux_masks, gt_mask, match_indices, postfix)
            #     loss[3] += loss_[f'loss_mask{postfix}']
            #     loss[4] += loss_[f'loss_dice{postfix}']

        loss = {
            f"loss_class_aux{postfix}": loss[0],
            f"loss_bbox_aux{postfix}": loss[1],
            f"loss_giou_aux{postfix}": loss[2],
        }
        # if masks is not None and gt_mask is not None:
        #     loss[f'loss_mask_aux{postfix}'] = loss[3]
        #     loss[f'loss_dice_aux{postfix}'] = loss[4]
        return loss

    @staticmethod
    def _get_index(match_indices: list[tuple]) -> tuple[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Extract batch indices, source indices, and destination indices from match indices.

        Args:
            match_indices (list[tuple]): List of tuples containing matched indices.

        Returns:
            index_tuple (tuple[torch.Tensor, torch.Tensor]): Tuple of (batch_idx, src_idx) used to index predictions.
            dst_idx (torch.Tensor): Destination (ground truth) indices.
        """
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)])
        src_idx = torch.cat([src for (src, _) in match_indices])
        dst_idx = torch.cat([dst for (_, dst) in match_indices])
        return (batch_idx, src_idx), dst_idx

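    # Illustrative sketch (not part of the API): what _get_index produces for a toy two-image batch.
    # The index values are made-up assumptions for illustration only.
    # >>> match_indices = [
    # ...     (torch.tensor([3, 7]), torch.tensor([0, 1])),  # image 0: queries 3, 7 -> GTs 0, 1
    # ...     (torch.tensor([2]), torch.tensor([0])),        # image 1: query 2 -> GT 0
    # ... ]
    # >>> (batch_idx, src_idx), dst_idx = DETRLoss._get_index(match_indices)
    # >>> batch_idx.tolist(), src_idx.tolist(), dst_idx.tolist()
    # ([0, 0, 1], [3, 7, 2], [0, 1, 0])
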
    def _get_assigned_bboxes(
        self, pred_bboxes: torch.Tensor, gt_bboxes: torch.Tensor, match_indices: list[tuple]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Assign predicted bounding boxes to ground truth bounding boxes based on match indices.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes.
            match_indices (list[tuple]): List of tuples containing matched indices.

        Returns:
            pred_assigned (torch.Tensor): Assigned predicted bounding boxes.
            gt_assigned (torch.Tensor): Assigned ground truth bounding boxes.
        """
        pred_assigned = torch.cat(
            [
                t[i] if len(i) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (i, _) in zip(pred_bboxes, match_indices)
            ]
        )
        gt_assigned = torch.cat(
            [
                t[j] if len(j) > 0 else torch.zeros(0, t.shape[-1], device=self.device)
                for t, (_, j) in zip(gt_bboxes, match_indices)
            ]
        )
        return pred_assigned, gt_assigned

    def _get_loss(
        self,
        pred_bboxes: torch.Tensor,
        pred_scores: torch.Tensor,
        gt_bboxes: torch.Tensor,
        gt_cls: torch.Tensor,
        gt_groups: list[int],
        masks: torch.Tensor | None = None,
        gt_mask: torch.Tensor | None = None,
        postfix: str = "",
        match_indices: list[tuple] | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate losses for a single prediction layer.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes.
            pred_scores (torch.Tensor): Predicted class scores.
            gt_bboxes (torch.Tensor): Ground truth bounding boxes.
            gt_cls (torch.Tensor): Ground truth classes.
            gt_groups (list[int]): Number of ground truths per image.
            masks (torch.Tensor, optional): Predicted masks if using segmentation.
            gt_mask (torch.Tensor, optional): Ground truth masks if using segmentation.
            postfix (str, optional): String to append to loss names.
            match_indices (list[tuple], optional): Pre-computed matching indices.

        Returns:
            (dict[str, torch.Tensor]): Dictionary of losses.
        """
        if match_indices is None:
            match_indices = self.matcher(
                pred_bboxes, pred_scores, gt_bboxes, gt_cls, gt_groups, masks=masks, gt_mask=gt_mask
            )

        idx, gt_idx = self._get_index(match_indices)
        pred_bboxes, gt_bboxes = pred_bboxes[idx], gt_bboxes[gt_idx]

        bs, nq = pred_scores.shape[:2]
        targets = torch.full((bs, nq), self.nc, device=pred_scores.device, dtype=gt_cls.dtype)
        targets[idx] = gt_cls[gt_idx]

        gt_scores = torch.zeros([bs, nq], device=pred_scores.device)
        if len(gt_bboxes):
            gt_scores[idx] = bbox_iou(pred_bboxes.detach(), gt_bboxes, xywh=True).squeeze(-1)

        return {
            **self._get_loss_class(pred_scores, targets, gt_scores, len(gt_bboxes), postfix),
            **self._get_loss_bbox(pred_bboxes, gt_bboxes, postfix),
            # **(self._get_loss_mask(masks, gt_mask, match_indices, postfix) if masks is not None and gt_mask is not None else {})
        }

    def forward(
        self,
        pred_bboxes: torch.Tensor,
        pred_scores: torch.Tensor,
        batch: dict[str, Any],
        postfix: str = "",
        **kwargs: Any,
    ) -> dict[str, torch.Tensor]:
        """
        Calculate loss for predicted bounding boxes and scores.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes, shape (L, B, N, 4).
            pred_scores (torch.Tensor): Predicted class scores, shape (L, B, N, C).
            batch (dict[str, Any]): Batch information containing cls, bboxes, and gt_groups.
            postfix (str, optional): Postfix for loss names.
            **kwargs (Any): Additional arguments, may include 'match_indices'.

        Returns:
            (dict[str, torch.Tensor]): Computed losses, including main and auxiliary (if enabled).

        Notes:
            Uses the last element of pred_bboxes and pred_scores for the main loss, and the remaining elements for
            auxiliary losses when self.aux_loss is True.
        """
        self.device = pred_bboxes.device
        match_indices = kwargs.get("match_indices", None)
        gt_cls, gt_bboxes, gt_groups = batch["cls"], batch["bboxes"], batch["gt_groups"]

        total_loss = self._get_loss(
            pred_bboxes[-1], pred_scores[-1], gt_bboxes, gt_cls, gt_groups, postfix=postfix, match_indices=match_indices
        )

        if self.aux_loss:
            total_loss.update(
                self._get_loss_aux(
                    pred_bboxes[:-1], pred_scores[:-1], gt_bboxes, gt_cls, gt_groups, match_indices, postfix
                )
            )

        return total_loss


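# Illustrative usage sketch (not part of this module): shapes and values below are assumptions for
# illustration only, with 2 decoder layers (L=2), batch size 2, 100 queries, and 80 classes.
# >>> criterion = DETRLoss(nc=80, aux_loss=True)
# >>> pred_bboxes = torch.rand(2, 2, 100, 4)   # (L, B, N, 4), normalized xywh
# >>> pred_scores = torch.rand(2, 2, 100, 80)  # (L, B, N, C), raw logits
# >>> batch = {
# ...     "cls": torch.tensor([0, 3, 5]),      # 3 GT objects in total
# ...     "bboxes": torch.rand(3, 4),          # normalized xywh
# ...     "gt_groups": [2, 1],                 # GTs per image
# ... }
# >>> losses = criterion(pred_bboxes, pred_scores, batch)
# >>> sorted(losses)  # loss_class, loss_bbox, loss_giou plus *_aux variants
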
class RTDETRDetectionLoss(DETRLoss):
    """
    Real-Time DEtection TRansformer (RT-DETR) Detection Loss class that extends the DETRLoss.

    This class computes the detection loss for the RT-DETR model, which includes the standard detection loss as well as
    an additional denoising training loss when provided with denoising metadata.
    """

    def forward(
        self,
        preds: tuple[torch.Tensor, torch.Tensor],
        batch: dict[str, Any],
        dn_bboxes: torch.Tensor | None = None,
        dn_scores: torch.Tensor | None = None,
        dn_meta: dict[str, Any] | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass to compute detection loss with optional denoising loss.

        Args:
            preds (tuple[torch.Tensor, torch.Tensor]): Tuple containing predicted bounding boxes and scores.
            batch (dict[str, Any]): Batch data containing ground truth information.
            dn_bboxes (torch.Tensor, optional): Denoising bounding boxes.
            dn_scores (torch.Tensor, optional): Denoising scores.
            dn_meta (dict[str, Any], optional): Metadata for denoising.

        Returns:
            (dict[str, torch.Tensor]): Dictionary containing total loss and denoising loss if applicable.
        """
        pred_bboxes, pred_scores = preds
        total_loss = super().forward(pred_bboxes, pred_scores, batch)

        # Check for denoising metadata to compute denoising training loss
        if dn_meta is not None:
            dn_pos_idx, dn_num_group = dn_meta["dn_pos_idx"], dn_meta["dn_num_group"]
            assert len(batch["gt_groups"]) == len(dn_pos_idx)

            # Get the match indices for denoising
            match_indices = self.get_dn_match_indices(dn_pos_idx, dn_num_group, batch["gt_groups"])

            # Compute the denoising training loss
            dn_loss = super().forward(dn_bboxes, dn_scores, batch, postfix="_dn", match_indices=match_indices)
            total_loss.update(dn_loss)
        else:
            # If no denoising metadata is provided, set denoising loss to zero
            total_loss.update({f"{k}_dn": torch.tensor(0.0, device=self.device) for k in total_loss.keys()})

        return total_loss

    @staticmethod
    def get_dn_match_indices(
        dn_pos_idx: list[torch.Tensor], dn_num_group: int, gt_groups: list[int]
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Get match indices for denoising.

        Args:
            dn_pos_idx (list[torch.Tensor]): List of tensors containing positive indices for denoising.
            dn_num_group (int): Number of denoising groups.
            gt_groups (list[int]): List of integers representing number of ground truths per image.

        Returns:
            (list[tuple[torch.Tensor, torch.Tensor]]): List of tuples containing matched indices for denoising.
        """
        dn_match_indices = []
        idx_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)
        for i, num_gt in enumerate(gt_groups):
            if num_gt > 0:
                gt_idx = torch.arange(end=num_gt, dtype=torch.long) + idx_groups[i]
                gt_idx = gt_idx.repeat(dn_num_group)
                assert len(dn_pos_idx[i]) == len(gt_idx), (
                    f"Expected the same length, but got {len(dn_pos_idx[i])} and {len(gt_idx)} respectively."
                )
                dn_match_indices.append((dn_pos_idx[i], gt_idx))
            else:
                dn_match_indices.append((torch.zeros([0], dtype=torch.long), torch.zeros([0], dtype=torch.long)))
        return dn_match_indices
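
# Illustrative sketch (not part of the API): get_dn_match_indices for a toy batch with gt_groups=[2, 1]
# and two denoising groups. The dn_pos_idx values are made-up assumptions for illustration only.
# >>> dn_pos_idx = [torch.tensor([0, 1, 3, 4]), torch.tensor([2, 5])]  # positive denoising query slots per image
# >>> RTDETRDetectionLoss.get_dn_match_indices(dn_pos_idx, dn_num_group=2, gt_groups=[2, 1])
# [(tensor([0, 1, 3, 4]), tensor([0, 1, 0, 1])), (tensor([2, 5]), tensor([2, 2]))]
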
ultralytics/models/utils/ops.py (new file, 319 lines)
@@ -0,0 +1,319 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

from ultralytics.utils.metrics import bbox_iou
from ultralytics.utils.ops import xywh2xyxy, xyxy2xywh


class HungarianMatcher(nn.Module):
    """
    A module implementing the HungarianMatcher for optimal assignment between predictions and ground truth.

    HungarianMatcher performs optimal bipartite assignment over predicted and ground truth bounding boxes using a cost
    function that considers classification scores, bounding box coordinates, and optionally mask predictions. This is
    used in end-to-end object detection models like DETR.

    Attributes:
        cost_gain (dict[str, float]): Dictionary of cost coefficients for 'class', 'bbox', 'giou', 'mask', and 'dice'
            components.
        use_fl (bool): Whether to use Focal Loss for classification cost calculation.
        with_mask (bool): Whether the model makes mask predictions.
        num_sample_points (int): Number of sample points used in mask cost calculation.
        alpha (float): Alpha factor in Focal Loss calculation.
        gamma (float): Gamma factor in Focal Loss calculation.

    Methods:
        forward: Compute optimal assignment between predictions and ground truths for a batch.
        _cost_mask: Compute mask cost and dice cost if masks are predicted.

    Examples:
        Initialize a HungarianMatcher with custom cost gains
        >>> matcher = HungarianMatcher(cost_gain={"class": 2, "bbox": 5, "giou": 2})

        Perform matching between predictions and ground truth
        >>> pred_boxes = torch.rand(2, 100, 4)  # batch_size=2, num_queries=100
        >>> pred_scores = torch.rand(2, 100, 80)  # 80 classes
        >>> gt_boxes = torch.rand(10, 4)  # 10 ground truth boxes
        >>> gt_classes = torch.randint(0, 80, (10,))
        >>> gt_groups = [5, 5]  # 5 GT boxes per image
        >>> indices = matcher(pred_boxes, pred_scores, gt_boxes, gt_classes, gt_groups)
    """

    def __init__(
        self,
        cost_gain: dict[str, float] | None = None,
        use_fl: bool = True,
        with_mask: bool = False,
        num_sample_points: int = 12544,
        alpha: float = 0.25,
        gamma: float = 2.0,
    ):
        """
        Initialize HungarianMatcher for optimal assignment of predicted and ground truth bounding boxes.

        Args:
            cost_gain (dict[str, float], optional): Dictionary of cost coefficients for different matching cost
                components. Should contain keys 'class', 'bbox', 'giou', 'mask', and 'dice'.
            use_fl (bool): Whether to use Focal Loss for classification cost calculation.
            with_mask (bool): Whether the model makes mask predictions.
            num_sample_points (int): Number of sample points used in mask cost calculation.
            alpha (float): Alpha factor in Focal Loss calculation.
            gamma (float): Gamma factor in Focal Loss calculation.
        """
        super().__init__()
        if cost_gain is None:
            cost_gain = {"class": 1, "bbox": 5, "giou": 2, "mask": 1, "dice": 1}
        self.cost_gain = cost_gain
        self.use_fl = use_fl
        self.with_mask = with_mask
        self.num_sample_points = num_sample_points
        self.alpha = alpha
        self.gamma = gamma

    def forward(
        self,
        pred_bboxes: torch.Tensor,
        pred_scores: torch.Tensor,
        gt_bboxes: torch.Tensor,
        gt_cls: torch.Tensor,
        gt_groups: list[int],
        masks: torch.Tensor | None = None,
        gt_mask: list[torch.Tensor] | None = None,
    ) -> list[tuple[torch.Tensor, torch.Tensor]]:
        """
        Compute optimal assignment between predictions and ground truth using the Hungarian algorithm.

        This method calculates matching costs based on classification scores, bounding box coordinates, and optionally
        mask predictions, then finds the optimal bipartite assignment between predictions and ground truth.

        Args:
            pred_bboxes (torch.Tensor): Predicted bounding boxes with shape (batch_size, num_queries, 4).
            pred_scores (torch.Tensor): Predicted classification scores with shape (batch_size, num_queries,
                num_classes).
            gt_bboxes (torch.Tensor): Ground truth bounding boxes with shape (num_gts, 4).
            gt_cls (torch.Tensor): Ground truth class labels with shape (num_gts,).
            gt_groups (list[int]): Number of ground truth boxes for each image in the batch.
            masks (torch.Tensor, optional): Predicted masks with shape (batch_size, num_queries, height, width).
            gt_mask (list[torch.Tensor], optional): Ground truth masks, each with shape (num_masks, Height, Width).

        Returns:
            (list[tuple[torch.Tensor, torch.Tensor]]): A list of size batch_size, each element is a tuple
                (index_i, index_j), where index_i is the tensor of indices of the selected predictions (in order)
                and index_j is the tensor of indices of the corresponding selected ground truth targets (in order).
                For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
        """
        bs, nq, nc = pred_scores.shape

        if sum(gt_groups) == 0:
            return [(torch.tensor([], dtype=torch.long), torch.tensor([], dtype=torch.long)) for _ in range(bs)]

        # Flatten to compute cost matrices in batch format
        pred_scores = pred_scores.detach().view(-1, nc)
        pred_scores = F.sigmoid(pred_scores) if self.use_fl else F.softmax(pred_scores, dim=-1)
        pred_bboxes = pred_bboxes.detach().view(-1, 4)

        # Compute classification cost
        pred_scores = pred_scores[:, gt_cls]
        if self.use_fl:
            neg_cost_class = (1 - self.alpha) * (pred_scores**self.gamma) * (-(1 - pred_scores + 1e-8).log())
            pos_cost_class = self.alpha * ((1 - pred_scores) ** self.gamma) * (-(pred_scores + 1e-8).log())
            cost_class = pos_cost_class - neg_cost_class
        else:
            cost_class = -pred_scores

        # Compute L1 cost between boxes
        cost_bbox = (pred_bboxes.unsqueeze(1) - gt_bboxes.unsqueeze(0)).abs().sum(-1)  # (bs*num_queries, num_gt)

        # Compute GIoU cost between boxes, (bs*num_queries, num_gt)
        cost_giou = 1.0 - bbox_iou(pred_bboxes.unsqueeze(1), gt_bboxes.unsqueeze(0), xywh=True, GIoU=True).squeeze(-1)

        # Combine costs into final cost matrix
        C = (
            self.cost_gain["class"] * cost_class
            + self.cost_gain["bbox"] * cost_bbox
            + self.cost_gain["giou"] * cost_giou
        )

        # Add mask costs if available
        if self.with_mask:
            C += self._cost_mask(bs, gt_groups, masks, gt_mask)

        # Set invalid values (NaNs and infinities) to 0
        C[C.isnan() | C.isinf()] = 0.0

        C = C.view(bs, nq, -1).cpu()
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(gt_groups, -1))]
        gt_groups = torch.as_tensor([0, *gt_groups[:-1]]).cumsum_(0)  # (idx for queries, idx for gt)
        return [
            (torch.tensor(i, dtype=torch.long), torch.tensor(j, dtype=torch.long) + gt_groups[k])
            for k, (i, j) in enumerate(indices)
        ]

    # This function is for future RT-DETR Segment models
    # def _cost_mask(self, bs, num_gts, masks=None, gt_mask=None):
    #     assert masks is not None and gt_mask is not None, 'Make sure the input has `mask` and `gt_mask`'
    #     # all masks share the same set of points for efficient matching
    #     sample_points = torch.rand([bs, 1, self.num_sample_points, 2])
    #     sample_points = 2.0 * sample_points - 1.0
    #
    #     out_mask = F.grid_sample(masks.detach(), sample_points, align_corners=False).squeeze(-2)
    #     out_mask = out_mask.flatten(0, 1)
    #
    #     tgt_mask = torch.cat(gt_mask).unsqueeze(1)
    #     sample_points = torch.cat([a.repeat(b, 1, 1, 1) for a, b in zip(sample_points, num_gts) if b > 0])
    #     tgt_mask = F.grid_sample(tgt_mask, sample_points, align_corners=False).squeeze([1, 2])
    #
    #     with torch.amp.autocast("cuda", enabled=False):
    #         # binary cross entropy cost
    #         pos_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.ones_like(out_mask), reduction='none')
    #         neg_cost_mask = F.binary_cross_entropy_with_logits(out_mask, torch.zeros_like(out_mask), reduction='none')
    #         cost_mask = torch.matmul(pos_cost_mask, tgt_mask.T) + torch.matmul(neg_cost_mask, 1 - tgt_mask.T)
    #         cost_mask /= self.num_sample_points
    #
    #         # dice cost
    #         out_mask = F.sigmoid(out_mask)
    #         numerator = 2 * torch.matmul(out_mask, tgt_mask.T)
    #         denominator = out_mask.sum(-1, keepdim=True) + tgt_mask.sum(-1).unsqueeze(0)
    #         cost_dice = 1 - (numerator + 1) / (denominator + 1)
    #
    #         C = self.cost_gain['mask'] * cost_mask + self.cost_gain['dice'] * cost_dice
    #         return C


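# Illustrative sketch (not part of the API): reading the matcher output. Shapes and values are assumptions
# for illustration only; indices[k] pairs query indices with indices into the concatenated gt tensors
# (per-image offsets are already added by the matcher).
# >>> matcher = HungarianMatcher()
# >>> indices = matcher(torch.rand(2, 10, 4), torch.rand(2, 10, 3), torch.rand(3, 4),
# ...                   torch.tensor([0, 1, 2]), gt_groups=[2, 1])
# >>> query_idx, gt_idx = indices[1]  # assignments for image 1; gt_idx values fall in {2}

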
def get_cdn_group(
    batch: dict[str, Any],
    num_classes: int,
    num_queries: int,
    class_embed: torch.Tensor,
    num_dn: int = 100,
    cls_noise_ratio: float = 0.5,
    box_noise_scale: float = 1.0,
    training: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, dict[str, Any] | None]:
    """
    Generate contrastive denoising training group with positive and negative samples from ground truths.

    This function creates denoising queries for contrastive denoising training by adding noise to ground truth
    bounding boxes and class labels. It generates both positive and negative samples to improve model robustness.

    Args:
        batch (dict[str, Any]): Batch dictionary containing 'cls' (torch.Tensor with shape (num_gts,)), 'bboxes'
            (torch.Tensor with shape (num_gts, 4)), 'batch_idx' (torch.Tensor with shape (num_gts,)), and
            'gt_groups' (list[int]) indicating the number of ground truths per image.
        num_classes (int): Total number of object classes.
        num_queries (int): Number of object queries.
        class_embed (torch.Tensor): Class embedding weights to map labels to embedding space.
        num_dn (int): Number of denoising queries to generate.
        cls_noise_ratio (float): Noise ratio for class labels.
        box_noise_scale (float): Noise scale for bounding box coordinates.
        training (bool): Whether model is in training mode.

    Returns:
        padding_cls (torch.Tensor | None): Modified class embeddings for denoising with shape (bs, num_dn, embed_dim).
        padding_bbox (torch.Tensor | None): Modified bounding boxes for denoising with shape (bs, num_dn, 4).
        attn_mask (torch.Tensor | None): Attention mask for denoising with shape (tgt_size, tgt_size).
        dn_meta (dict[str, Any] | None): Meta information dictionary containing denoising parameters.

    Examples:
        Generate denoising group for training
        >>> batch = {
        ...     "cls": torch.tensor([0, 1, 2]),
        ...     "bboxes": torch.rand(3, 4),
        ...     "batch_idx": torch.tensor([0, 0, 1]),
        ...     "gt_groups": [2, 1],
        ... }
        >>> class_embed = torch.rand(80, 256)  # 80 classes, 256 embedding dim
        >>> cdn_outputs = get_cdn_group(batch, 80, 100, class_embed, training=True)
    """
    if (not training) or num_dn <= 0 or batch is None:
        return None, None, None, None
    gt_groups = batch["gt_groups"]
    total_num = sum(gt_groups)
    max_nums = max(gt_groups)
    if max_nums == 0:
        return None, None, None, None

    num_group = num_dn // max_nums
    num_group = 1 if num_group == 0 else num_group
    # Pad ground truths to the max number per image in the batch
    bs = len(gt_groups)
    gt_cls = batch["cls"]  # (bs*num, )
    gt_bbox = batch["bboxes"]  # bs*num, 4
    b_idx = batch["batch_idx"]

    # Each group has positive and negative queries
    dn_cls = gt_cls.repeat(2 * num_group)  # (2*num_group*bs*num, )
    dn_bbox = gt_bbox.repeat(2 * num_group, 1)  # 2*num_group*bs*num, 4
    dn_b_idx = b_idx.repeat(2 * num_group).view(-1)  # (2*num_group*bs*num, )

    # Positive and negative mask
    # (bs*num*num_group, ); the second total_num*num_group half is used as negative samples
    neg_idx = torch.arange(total_num * num_group, dtype=torch.long, device=gt_bbox.device) + num_group * total_num

    if cls_noise_ratio > 0:
        # Apply class label noise to half of the samples
        mask = torch.rand(dn_cls.shape) < (cls_noise_ratio * 0.5)
        idx = torch.nonzero(mask).squeeze(-1)
        # Randomly assign new class labels
        new_label = torch.randint_like(idx, 0, num_classes, dtype=dn_cls.dtype, device=dn_cls.device)
        dn_cls[idx] = new_label

    if box_noise_scale > 0:
        known_bbox = xywh2xyxy(dn_bbox)

        diff = (dn_bbox[..., 2:] * 0.5).repeat(1, 2) * box_noise_scale  # 2*num_group*bs*num, 4

        rand_sign = torch.randint_like(dn_bbox, 0, 2) * 2.0 - 1.0
        rand_part = torch.rand_like(dn_bbox)
        rand_part[neg_idx] += 1.0
        rand_part *= rand_sign
        known_bbox += rand_part * diff
        known_bbox.clip_(min=0.0, max=1.0)
        dn_bbox = xyxy2xywh(known_bbox)
        dn_bbox = torch.logit(dn_bbox, eps=1e-6)  # inverse sigmoid

    num_dn = int(max_nums * 2 * num_group)  # total denoising queries
    dn_cls_embed = class_embed[dn_cls]  # bs*num * 2 * num_group, 256
    padding_cls = torch.zeros(bs, num_dn, dn_cls_embed.shape[-1], device=gt_cls.device)
    padding_bbox = torch.zeros(bs, num_dn, 4, device=gt_bbox.device)

    map_indices = torch.cat([torch.tensor(range(num), dtype=torch.long) for num in gt_groups])
    pos_idx = torch.stack([map_indices + max_nums * i for i in range(num_group)], dim=0)

    map_indices = torch.cat([map_indices + max_nums * i for i in range(2 * num_group)])
    padding_cls[(dn_b_idx, map_indices)] = dn_cls_embed
    padding_bbox[(dn_b_idx, map_indices)] = dn_bbox

    tgt_size = num_dn + num_queries
    attn_mask = torch.zeros([tgt_size, tgt_size], dtype=torch.bool)
    # Matching queries cannot see the denoising (reconstruction) queries
    attn_mask[num_dn:, :num_dn] = True
    # Denoising groups cannot see each other
    for i in range(num_group):
        if i == 0:
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
        if i == num_group - 1:
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * i * 2] = True
        else:
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), max_nums * 2 * (i + 1) : num_dn] = True
            attn_mask[max_nums * 2 * i : max_nums * 2 * (i + 1), : max_nums * 2 * i] = True
    dn_meta = {
        "dn_pos_idx": [p.reshape(-1) for p in pos_idx.cpu().split(list(gt_groups), dim=1)],
        "dn_num_group": num_group,
        "dn_num_split": [num_dn, num_queries],
    }

    return (
        padding_cls.to(class_embed.device),
        padding_bbox.to(class_embed.device),
        attn_mask.to(class_embed.device),
        dn_meta,
    )
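
# Illustrative sketch (not part of this function): how the returned dn_meta is presumably consumed
# downstream. The decoder tensor names below are assumptions for illustration only; the key point is that
# "dn_num_split" separates the denoising queries from the regular object queries along the query axis.
# >>> # dec_bboxes: (L, B, num_dn + num_queries, 4) from a decoder run with the CDN queries prepended
# >>> # dn_out_bboxes, dec_out_bboxes = torch.split(dec_bboxes, dn_meta["dn_num_split"], dim=2)
# >>> # dn_out_scores, dec_out_scores = torch.split(dec_scores, dn_meta["dn_num_split"], dim=2)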