init commit
This commit is contained in:
164
ultralytics/nn/modules/utils.py
Normal file
164
ultralytics/nn/modules/utils.py
Normal file
@@ -0,0 +1,164 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import copy
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.init import uniform_
|
||||
|
||||
__all__ = "multi_scale_deformable_attn_pytorch", "inverse_sigmoid"
|
||||
|
||||
|
||||
def _get_clones(module, n):
|
||||
"""
|
||||
Create a list of cloned modules from the given module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): The module to be cloned.
|
||||
n (int): Number of clones to create.
|
||||
|
||||
Returns:
|
||||
(nn.ModuleList): A ModuleList containing n clones of the input module.
|
||||
|
||||
Examples:
|
||||
>>> import torch.nn as nn
|
||||
>>> layer = nn.Linear(10, 10)
|
||||
>>> clones = _get_clones(layer, 3)
|
||||
>>> len(clones)
|
||||
3
|
||||
"""
|
||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])
|
||||
|
||||
|
||||
def bias_init_with_prob(prior_prob=0.01):
|
||||
"""
|
||||
Initialize conv/fc bias value according to a given probability value.
|
||||
|
||||
This function calculates the bias initialization value based on a prior probability using the inverse error function.
|
||||
It's commonly used in object detection models to initialize classification layers with a specific positive prediction
|
||||
probability.
|
||||
|
||||
Args:
|
||||
prior_prob (float, optional): Prior probability for bias initialization.
|
||||
|
||||
Returns:
|
||||
(float): Bias initialization value calculated from the prior probability.
|
||||
|
||||
Examples:
|
||||
>>> bias = bias_init_with_prob(0.01)
|
||||
>>> print(f"Bias initialization value: {bias:.4f}")
|
||||
Bias initialization value: -4.5951
|
||||
"""
|
||||
return float(-np.log((1 - prior_prob) / prior_prob)) # return bias_init
|
||||
|
||||
|
||||
def linear_init(module):
|
||||
"""
|
||||
Initialize the weights and biases of a linear module.
|
||||
|
||||
This function initializes the weights of a linear module using a uniform distribution within bounds calculated
|
||||
from the input dimension. If the module has a bias, it is also initialized.
|
||||
|
||||
Args:
|
||||
module (nn.Module): Linear module to initialize.
|
||||
|
||||
Returns:
|
||||
(nn.Module): The initialized module.
|
||||
|
||||
Examples:
|
||||
>>> import torch.nn as nn
|
||||
>>> linear = nn.Linear(10, 5)
|
||||
>>> initialized_linear = linear_init(linear)
|
||||
"""
|
||||
bound = 1 / math.sqrt(module.weight.shape[0])
|
||||
uniform_(module.weight, -bound, bound)
|
||||
if hasattr(module, "bias") and module.bias is not None:
|
||||
uniform_(module.bias, -bound, bound)
|
||||
|
||||
|
||||
def inverse_sigmoid(x, eps=1e-5):
|
||||
"""
|
||||
Calculate the inverse sigmoid function for a tensor.
|
||||
|
||||
This function applies the inverse of the sigmoid function to a tensor, which is useful in various neural network
|
||||
operations, particularly in attention mechanisms and coordinate transformations.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor with values in range [0, 1].
|
||||
eps (float, optional): Small epsilon value to prevent numerical instability.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): Tensor after applying the inverse sigmoid function.
|
||||
|
||||
Examples:
|
||||
>>> x = torch.tensor([0.2, 0.5, 0.8])
|
||||
>>> inverse_sigmoid(x)
|
||||
tensor([-1.3863, 0.0000, 1.3863])
|
||||
"""
|
||||
x = x.clamp(min=0, max=1)
|
||||
x1 = x.clamp(min=eps)
|
||||
x2 = (1 - x).clamp(min=eps)
|
||||
return torch.log(x1 / x2)
|
||||
|
||||
|
||||
def multi_scale_deformable_attn_pytorch(
|
||||
value: torch.Tensor,
|
||||
value_spatial_shapes: torch.Tensor,
|
||||
sampling_locations: torch.Tensor,
|
||||
attention_weights: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Implement multi-scale deformable attention in PyTorch.
|
||||
|
||||
This function performs deformable attention across multiple feature map scales, allowing the model to attend to
|
||||
different spatial locations with learned offsets.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor): The value tensor with shape (bs, num_keys, num_heads, embed_dims).
|
||||
value_spatial_shapes (torch.Tensor): Spatial shapes of the value tensor with shape (num_levels, 2).
|
||||
sampling_locations (torch.Tensor): The sampling locations with shape
|
||||
(bs, num_queries, num_heads, num_levels, num_points, 2).
|
||||
attention_weights (torch.Tensor): The attention weights with shape
|
||||
(bs, num_queries, num_heads, num_levels, num_points).
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The output tensor with shape (bs, num_queries, embed_dims).
|
||||
|
||||
References:
|
||||
https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
|
||||
"""
|
||||
bs, _, num_heads, embed_dims = value.shape
|
||||
_, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
|
||||
value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
|
||||
sampling_grids = 2 * sampling_locations - 1
|
||||
sampling_value_list = []
|
||||
for level, (H_, W_) in enumerate(value_spatial_shapes):
|
||||
# bs, H_*W_, num_heads, embed_dims ->
|
||||
# bs, H_*W_, num_heads*embed_dims ->
|
||||
# bs, num_heads*embed_dims, H_*W_ ->
|
||||
# bs*num_heads, embed_dims, H_, W_
|
||||
value_l_ = value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
|
||||
# bs, num_queries, num_heads, num_points, 2 ->
|
||||
# bs, num_heads, num_queries, num_points, 2 ->
|
||||
# bs*num_heads, num_queries, num_points, 2
|
||||
sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
|
||||
# bs*num_heads, embed_dims, num_queries, num_points
|
||||
sampling_value_l_ = F.grid_sample(
|
||||
value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
|
||||
)
|
||||
sampling_value_list.append(sampling_value_l_)
|
||||
# (bs, num_queries, num_heads, num_levels, num_points) ->
|
||||
# (bs, num_heads, num_queries, num_levels, num_points) ->
|
||||
# (bs, num_heads, 1, num_queries, num_levels*num_points)
|
||||
attention_weights = attention_weights.transpose(1, 2).reshape(
|
||||
bs * num_heads, 1, num_queries, num_levels * num_points
|
||||
)
|
||||
output = (
|
||||
(torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
|
||||
.sum(-1)
|
||||
.view(bs, num_heads * embed_dims, num_queries)
|
||||
)
|
||||
return output.transpose(1, 2).contiguous()
|
||||
Reference in New Issue
Block a user