init commit

 ultralytics/models/yolo/yoloe/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .predict import YOLOEVPDetectPredictor, YOLOEVPSegPredictor
from .train import YOLOEPEFreeTrainer, YOLOEPETrainer, YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .train_seg import YOLOEPESegTrainer, YOLOESegTrainer, YOLOESegTrainerFromScratch, YOLOESegVPTrainer
from .val import YOLOEDetectValidator, YOLOESegValidator

__all__ = [
    "YOLOETrainer",
    "YOLOEPETrainer",
    "YOLOESegTrainer",
    "YOLOEDetectValidator",
    "YOLOESegValidator",
    "YOLOEPESegTrainer",
    "YOLOESegTrainerFromScratch",
    "YOLOESegVPTrainer",
    "YOLOEVPTrainer",
    "YOLOEPEFreeTrainer",
    "YOLOEVPDetectPredictor",
    "YOLOEVPSegPredictor",
    "YOLOETrainerFromScratch",
]

 BIN  ultralytics/models/yolo/yoloe/__pycache__/train.cpython-310.pyc (new file, binary file not shown)
 BIN  ultralytics/models/yolo/yoloe/__pycache__/val.cpython-310.pyc (new file, binary file not shown)

 ultralytics/models/yolo/yoloe/predict.py (new file, 169 lines)
@@ -0,0 +1,169 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import numpy as np
import torch

from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.models.yolo.detect import DetectionPredictor
from ultralytics.models.yolo.segment import SegmentationPredictor


class YOLOEVPDetectPredictor(DetectionPredictor):
    """
    A mixin class for YOLO-EVP (Enhanced Visual Prompting) predictors.

    This mixin provides common functionality for YOLO models that use visual prompting, including
    model setup, prompt handling, and preprocessing transformations.

    Attributes:
        model (torch.nn.Module): The YOLO model for inference.
        device (torch.device): Device to run the model on (CPU or CUDA).
        prompts (dict | torch.Tensor): Visual prompts containing class indices and bounding boxes or masks.

    Methods:
        setup_model: Initialize the YOLO model and set it to evaluation mode.
        set_prompts: Set the visual prompts for the model.
        pre_transform: Preprocess images and prompts before inference.
        inference: Run inference with visual prompts.
        get_vpe: Process source to get visual prompt embeddings.
    """

    def setup_model(self, model, verbose: bool = True):
        """
        Set up the model for prediction.

        Args:
            model (torch.nn.Module): Model to load or use.
            verbose (bool, optional): If True, provides detailed logging.
        """
        super().setup_model(model, verbose=verbose)
        self.done_warmup = True

    def set_prompts(self, prompts):
        """
        Set the visual prompts for the model.

        Args:
            prompts (dict): Dictionary containing class indices and bounding boxes or masks.
                Must include a 'cls' key with class indices.
        """
        self.prompts = prompts

    def pre_transform(self, im):
        """
        Preprocess images and prompts before inference.

        This method applies letterboxing to the input image and transforms the visual prompts
        (bounding boxes or masks) accordingly.

        Args:
            im (list): List containing a single input image.

        Returns:
            (list): Preprocessed image ready for model inference.

        Raises:
            ValueError: If neither valid bounding boxes nor masks are provided in the prompts.
        """
        img = super().pre_transform(im)
        bboxes = self.prompts.pop("bboxes", None)
        masks = self.prompts.pop("masks", None)
        category = self.prompts["cls"]
        if len(img) == 1:
            visuals = self._process_single_image(img[0].shape[:2], im[0].shape[:2], category, bboxes, masks)
            prompts = visuals.unsqueeze(0).to(self.device)  # (1, N, H, W)
        else:
            # NOTE: only supports bboxes as prompts for now
            assert bboxes is not None, f"Expected bboxes, but got {bboxes}!"
            # NOTE: needs list[np.ndarray]
            assert isinstance(bboxes, list) and all(isinstance(b, np.ndarray) for b in bboxes), (
                f"Expected list[np.ndarray], but got {bboxes}!"
            )
            assert isinstance(category, list) and all(isinstance(b, np.ndarray) for b in category), (
                f"Expected list[np.ndarray], but got {category}!"
            )
            assert len(im) == len(category) == len(bboxes), (
                f"Expected same length for all inputs, but got {len(im)}vs{len(category)}vs{len(bboxes)}!"
            )
            visuals = [
                self._process_single_image(img[i].shape[:2], im[i].shape[:2], category[i], bboxes[i])
                for i in range(len(img))
            ]
            prompts = torch.nn.utils.rnn.pad_sequence(visuals, batch_first=True).to(self.device)  # (B, N, H, W)
        self.prompts = prompts.half() if self.model.fp16 else prompts.float()
        return img

    def _process_single_image(self, dst_shape, src_shape, category, bboxes=None, masks=None):
        """
        Process a single image by resizing bounding boxes or masks and generating visuals.

        Args:
            dst_shape (tuple): The target shape (height, width) of the image.
            src_shape (tuple): The original shape (height, width) of the image.
            category (str): The category of the image for visual prompts.
            bboxes (list | np.ndarray, optional): A list of bounding boxes in the format [x1, y1, x2, y2].
            masks (np.ndarray, optional): A list of masks corresponding to the image.

        Returns:
            (torch.Tensor): The processed visuals for the image.

        Raises:
            ValueError: If neither `bboxes` nor `masks` are provided.
        """
        if bboxes is not None and len(bboxes):
            bboxes = np.array(bboxes, dtype=np.float32)
            if bboxes.ndim == 1:
                bboxes = bboxes[None, :]
            # Calculate scaling factor and adjust bounding boxes
            gain = min(dst_shape[0] / src_shape[0], dst_shape[1] / src_shape[1])  # gain = old / new
            bboxes *= gain
            bboxes[..., 0::2] += round((dst_shape[1] - src_shape[1] * gain) / 2 - 0.1)
            bboxes[..., 1::2] += round((dst_shape[0] - src_shape[0] * gain) / 2 - 0.1)
        elif masks is not None:
            # Resize and process masks
            resized_masks = super().pre_transform(masks)
            masks = np.stack(resized_masks)  # (N, H, W)
            masks[masks == 114] = 0  # Reset padding values to 0
        else:
            raise ValueError("Please provide valid bboxes or masks")

        # Generate visuals using the visual prompt loader
        return LoadVisualPrompt().get_visuals(category, dst_shape, bboxes, masks)

    def inference(self, im, *args, **kwargs):
        """
        Run inference with visual prompts.

        Args:
            im (torch.Tensor): Input image tensor.
            *args (Any): Variable length argument list.
            **kwargs (Any): Arbitrary keyword arguments.

        Returns:
            (torch.Tensor): Model prediction results.
        """
        return super().inference(im, vpe=self.prompts, *args, **kwargs)

    def get_vpe(self, source):
        """
        Process the source to get the visual prompt embeddings (VPE).

        Args:
            source (str | Path | int | PIL.Image | np.ndarray | torch.Tensor | list | tuple): The source
                of the image to make predictions on. Accepts various types including file paths, URLs, PIL
                images, numpy arrays, and torch tensors.

        Returns:
            (torch.Tensor): The visual prompt embeddings (VPE) from the model.
        """
        self.setup_source(source)
        assert len(self.dataset) == 1, "get_vpe only supports one image!"
        for _, im0s, _ in self.dataset:
            im = self.preprocess(im0s)
            return self.model(im, vpe=self.prompts, return_vpe=True)


class YOLOEVPSegPredictor(YOLOEVPDetectPredictor, SegmentationPredictor):
    """Predictor for YOLO-EVP segmentation tasks combining detection and segmentation capabilities."""

    pass
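
For context (not part of the diff): the prompts dictionary consumed by set_prompts must contain a 'cls' array plus either 'bboxes' or 'masks'. A minimal single-image sketch, assuming the top-level YOLOE wrapper and a "yoloe-11s-seg.pt" checkpoint (neither is defined in this file) forward visual_prompts to the predictor:

import numpy as np

from ultralytics import YOLOE  # assumed wrapper, not part of this commit
from ultralytics.models.yolo.yoloe import YOLOEVPSegPredictor

# One prompt box in pixel xyxy coordinates of the source image; 'cls' is the only mandatory key
visual_prompts = dict(
    bboxes=np.array([[221.5, 405.8, 344.9, 857.5]], dtype=np.float32),
    cls=np.array([0]),
)
model = YOLOE("yoloe-11s-seg.pt")  # assumed checkpoint name
results = model.predict("bus.jpg", visual_prompts=visual_prompts, predictor=YOLOEVPSegPredictor)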

 ultralytics/models/yolo/yoloe/train.py (new file, 300 lines)
@@ -0,0 +1,300 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy, deepcopy
from pathlib import Path

import torch

from ultralytics.data import YOLOConcatDataset, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.models.yolo.detect import DetectionTrainer, DetectionValidator
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
from ultralytics.utils.torch_utils import unwrap_model

from ..world.train_world import WorldTrainerFromScratch
from .val import YOLOEDetectValidator


class YOLOETrainer(DetectionTrainer):
    """
    A trainer class for YOLOE object detection models.

    This class extends DetectionTrainer to provide specialized training functionality for YOLOE models,
    including custom model initialization, validation, and dataset building with multi-modal support.

    Attributes:
        loss_names (tuple): Names of loss components used during training.

    Methods:
        get_model: Initialize and return a YOLOEModel with specified configuration.
        get_validator: Return a YOLOEDetectValidator for model validation.
        build_dataset: Build YOLO dataset with multi-modal support for training.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """
        Initialize the YOLOE Trainer with specified configurations.

        Args:
            cfg (dict): Configuration dictionary with default training settings from DEFAULT_CFG.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be applied during training.
        """
        if overrides is None:
            overrides = {}
        assert not overrides.get("compile"), f"Training with 'model={overrides['model']}' requires 'compile=False'"
        overrides["overlap_mask"] = False
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Return a YOLOEModel initialized with the specified configuration and weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary containing a 'yaml_file' key,
                a direct path to a YAML file, or None to use default configuration.
            weights (str | Path, optional): Path to pretrained weights file to load into the model.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (YOLOEModel): The initialized YOLOE model.

        Notes:
            - The number of classes (nc) is hard-coded to a maximum of 80 following the official configuration.
            - The nc parameter here represents the maximum number of different text samples in one image,
              rather than the actual number of classes.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOEModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """Return a YOLOEDetectValidator for YOLOE model validation."""
        self.loss_names = "box", "cls", "dfl"
        return YOLOEDetectValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def build_dataset(self, img_path: str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset.

        Args:
            img_path (str): Path to the folder containing images.
            mode (str): 'train' mode or 'val' mode, users are able to customize different augmentations for each mode.
            batch (int, optional): Size of batches, this is for rectangular training.

        Returns:
            (Dataset): YOLO dataset configured for training or validation.
        """
        gs = max(int(unwrap_model(self.model).stride.max() if self.model else 0), 32)
        return build_yolo_dataset(
            self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
        )


class YOLOEPETrainer(DetectionTrainer):
    """
    Fine-tune YOLOE model using linear probing approach.

    This trainer freezes most model layers and only trains specific projection layers for efficient
    fine-tuning on new datasets while preserving pretrained features.

    Methods:
        get_model: Initialize YOLOEModel with frozen layers except projection layers.
    """

    def get_model(self, cfg=None, weights=None, verbose: bool = True):
        """
        Return YOLOEModel initialized with specified config and weights.

        Args:
            cfg (dict | str, optional): Model configuration.
            weights (str, optional): Path to pretrained weights.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOEModel): Initialized model with frozen layers except for specific projection layers.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOEModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=self.data["nc"],
            verbose=verbose and RANK == -1,
        )

        del model.model[-1].savpe

        assert weights is not None, "Pretrained weights must be provided for linear probing."
        if weights:
            model.load(weights)

        model.eval()
        names = list(self.data["names"].values())
        # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,
        # it'd get correct results as long as loading proper pretrained weights.
        tpe = model.get_text_pe(names)
        model.set_classes(names, tpe)
        model.model[-1].fuse(model.pe)  # fuse text embeddings to classify head
        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
        del model.pe
        model.train()

        return model


class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
    """
    Train YOLOE models from scratch with text embedding support.

    This trainer combines YOLOE training capabilities with world training features, enabling
    training from scratch with text embeddings and grounding datasets.

    Methods:
        build_dataset: Build datasets for training with grounding support.
        generate_text_embeddings: Generate and cache text embeddings for training.
    """

    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset for training or validation.

        This method constructs appropriate datasets based on the mode and input paths, handling both
        standard YOLO datasets and grounding datasets with different formats.

        Args:
            img_path (list[str] | str): Path to the folder containing images or list of paths.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
            batch (int, optional): Size of batches, used for rectangular training/validation.

        Returns:
            (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
        """
        return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)

    def generate_text_embeddings(self, texts: list[str], batch: int, cache_dir: Path):
        """
        Generate text embeddings for a list of text samples.

        Args:
            texts (list[str]): List of text samples to encode.
            batch (int): Batch size for processing.
            cache_dir (Path): Directory to save/load cached embeddings.

        Returns:
            (dict): Dictionary mapping text samples to their embeddings.
        """
        model = "mobileclip:blt"
        cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
        if cache_path.exists():
            LOGGER.info(f"Reading existed cache from '{cache_path}'")
            txt_map = torch.load(cache_path, map_location=self.device)
            if sorted(txt_map.keys()) == sorted(texts):
                return txt_map
        LOGGER.info(f"Caching text embeddings to '{cache_path}'")
        assert self.model is not None
        txt_feats = unwrap_model(self.model).get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
        txt_map = dict(zip(texts, txt_feats.squeeze(0)))
        torch.save(txt_map, cache_path)
        return txt_map


class YOLOEPEFreeTrainer(YOLOEPETrainer, YOLOETrainerFromScratch):
    """
    Train prompt-free YOLOE model.

    This trainer combines linear probing capabilities with from-scratch training for prompt-free
    YOLOE models that don't require text prompts during inference.

    Methods:
        get_validator: Return standard DetectionValidator for validation.
        preprocess_batch: Preprocess batches without text features.
        set_text_embeddings: Set text embeddings for datasets (no-op for prompt-free).
    """

    def get_validator(self):
        """Return a DetectionValidator for YOLO model validation."""
        self.loss_names = "box", "cls", "dfl"
        return DetectionValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def preprocess_batch(self, batch):
        """Preprocess a batch of images for YOLOE training, adjusting formatting and dimensions as needed."""
        return DetectionTrainer.preprocess_batch(self, batch)

    def set_text_embeddings(self, datasets, batch: int):
        """
        Set text embeddings for datasets to accelerate training by caching category names.

        This method collects unique category names from all datasets, generates text embeddings for them,
        and caches these embeddings to improve training efficiency. The embeddings are stored in a file
        in the parent directory of the first dataset's image path.

        Args:
            datasets (list[Dataset]): List of datasets containing category names to process.
            batch (int): Batch size for processing text embeddings.

        Notes:
            The method creates a dictionary mapping text samples to their embeddings and stores it
            at the path specified by 'cache_path'. If the cache file already exists, it will be loaded
            instead of regenerating the embeddings.
        """
        pass


class YOLOEVPTrainer(YOLOETrainerFromScratch):
    """
    Train YOLOE model with visual prompts.

    This trainer extends YOLOETrainerFromScratch to support visual prompt-based training,
    where visual cues are provided alongside images to guide the detection process.

    Methods:
        build_dataset: Build dataset with visual prompt loading transforms.
    """

    def build_dataset(self, img_path: list[str] | str, mode: str = "train", batch: int | None = None):
        """
        Build YOLO Dataset for training or validation with visual prompts.

        Args:
            img_path (list[str] | str): Path to the folder containing images or list of paths.
            mode (str): 'train' mode or 'val' mode, allowing customized augmentations for each mode.
            batch (int, optional): Size of batches, used for rectangular training/validation.

        Returns:
            (Dataset): YOLO dataset configured for training or validation, with visual prompts for training mode.
        """
        dataset = super().build_dataset(img_path, mode, batch)
        if isinstance(dataset, YOLOConcatDataset):
            for d in dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            dataset.transforms.append(LoadVisualPrompt())
        return dataset

    def _close_dataloader_mosaic(self):
        """Close mosaic augmentation and add visual prompt loading to the training dataset."""
        super()._close_dataloader_mosaic()
        if isinstance(self.train_loader.dataset, YOLOConcatDataset):
            for d in self.train_loader.dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            self.train_loader.dataset.transforms.append(LoadVisualPrompt())
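
For context (not part of the diff): YOLOETrainerFromScratch delegates dataset construction to WorldTrainerFromScratch, which expects a dict mixing detection and grounding sources. A hedged sketch, assuming the top-level YOLOE wrapper (not part of this commit); the model YAML, dataset YAMLs, and annotation paths are illustrative placeholders:

from ultralytics import YOLOE  # assumed wrapper, not part of this commit
from ultralytics.models.yolo.yoloe import YOLOETrainerFromScratch

# Detection + grounding data for training, a standard YOLO dataset for validation
data = dict(
    train=dict(
        yolo_data=["Objects365.yaml"],  # illustrative detection dataset
        grounding_data=[
            dict(
                img_path="flickr/full_images/",  # illustrative grounding image dir
                json_file="flickr/annotations/final_flickr_separateGT_train.json",  # illustrative annotations
            ),
        ],
    ),
    val=dict(yolo_data=["lvis.yaml"]),  # illustrative validation dataset
)
model = YOLOE("yoloe-11s.yaml")  # assumed model YAML name
model.train(data=data, epochs=30, batch=128, trainer=YOLOETrainerFromScratch)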

 ultralytics/models/yolo/yoloe/train_seg.py (new file, 127 lines)
@@ -0,0 +1,127 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from copy import copy, deepcopy

from ultralytics.models.yolo.segment import SegmentationTrainer
from ultralytics.nn.tasks import YOLOESegModel
from ultralytics.utils import RANK

from .train import YOLOETrainer, YOLOETrainerFromScratch, YOLOEVPTrainer
from .val import YOLOESegValidator


class YOLOESegTrainer(YOLOETrainer, SegmentationTrainer):
    """
    Trainer class for YOLOE segmentation models.

    This class combines YOLOETrainer and SegmentationTrainer to provide training functionality specifically for YOLOE
    segmentation models, enabling both object detection and instance segmentation capabilities.

    Attributes:
        cfg (dict): Configuration dictionary with training parameters.
        overrides (dict): Dictionary with parameter overrides.
        _callbacks (list): List of callback functions for training events.
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return YOLOESegModel initialized with specified config and weights.

        Args:
            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOESegModel): Initialized YOLOE segmentation model.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOESegModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=min(self.data["nc"], 80),
            verbose=verbose and RANK == -1,
        )
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """
        Create and return a validator for YOLOE segmentation model evaluation.

        Returns:
            (YOLOESegValidator): Validator for YOLOE segmentation models.
        """
        self.loss_names = "box", "seg", "cls", "dfl"
        return YOLOESegValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )


class YOLOEPESegTrainer(SegmentationTrainer):
    """
    Fine-tune YOLOESeg model in linear probing way.

    This trainer specializes in fine-tuning YOLOESeg models using a linear probing approach, which involves freezing
    most of the model and only training specific layers for efficient adaptation to new tasks.

    Attributes:
        data (dict): Dataset configuration containing channels, class names, and number of classes.
    """

    def get_model(self, cfg=None, weights=None, verbose=True):
        """
        Return YOLOESegModel initialized with specified config and weights for linear probing.

        Args:
            cfg (dict | str, optional): Model configuration dictionary or YAML file path.
            weights (str, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (YOLOESegModel): Initialized YOLOE segmentation model configured for linear probing.
        """
        # NOTE: This `nc` here is the max number of different text samples in one image, rather than the actual `nc`.
        # NOTE: Following the official config, nc hard-coded to 80 for now.
        model = YOLOESegModel(
            cfg["yaml_file"] if isinstance(cfg, dict) else cfg,
            ch=self.data["channels"],
            nc=self.data["nc"],
            verbose=verbose and RANK == -1,
        )

        del model.model[-1].savpe

        assert weights is not None, "Pretrained weights must be provided for linear probing."
        if weights:
            model.load(weights)

        model.eval()
        names = list(self.data["names"].values())
        # NOTE: `get_text_pe` related to text model and YOLOEDetect.reprta,
        # it'd get correct results as long as loading proper pretrained weights.
        tpe = model.get_text_pe(names)
        model.set_classes(names, tpe)
        model.model[-1].fuse(model.pe)
        model.model[-1].cv3[0][2] = deepcopy(model.model[-1].cv3[0][2]).requires_grad_(True)
        model.model[-1].cv3[1][2] = deepcopy(model.model[-1].cv3[1][2]).requires_grad_(True)
        model.model[-1].cv3[2][2] = deepcopy(model.model[-1].cv3[2][2]).requires_grad_(True)
        del model.pe
        model.train()

        return model


class YOLOESegTrainerFromScratch(YOLOETrainerFromScratch, YOLOESegTrainer):
    """Trainer for YOLOE segmentation models trained from scratch without pretrained weights."""

    pass


class YOLOESegVPTrainer(YOLOEVPTrainer, YOLOESegTrainerFromScratch):
    """Trainer for YOLOE segmentation models with Vision Prompt (VP) capabilities."""

    pass
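
For context (not part of the diff): YOLOEPESegTrainer.get_model requires pretrained weights and only re-enables gradients on the final classification convolutions, so it is intended for short linear-probing runs. A hedged sketch, assuming the YOLOE wrapper; the checkpoint, dataset, and epoch count are illustrative, and the released recipe may add explicit layer-freezing arguments:

from ultralytics import YOLOE  # assumed wrapper, not part of this commit
from ultralytics.models.yolo.yoloe import YOLOEPESegTrainer

model = YOLOE("yoloe-11s-seg.pt")  # pretrained weights are mandatory for linear probing
model.train(data="coco128-seg.yaml", epochs=10, trainer=YOLOEPESegTrainer)  # illustrative dataset/epochs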

 ultralytics/models/yolo/yoloe/val.py (new file, 211 lines)
@@ -0,0 +1,211 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import deepcopy
from pathlib import Path
from typing import Any

import torch
from torch.nn import functional as F

from ultralytics.data import YOLOConcatDataset, build_dataloader, build_yolo_dataset
from ultralytics.data.augment import LoadVisualPrompt
from ultralytics.data.utils import check_det_dataset
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.nn.modules.head import YOLOEDetect
from ultralytics.nn.tasks import YOLOEModel
from ultralytics.utils import LOGGER, TQDM
from ultralytics.utils.torch_utils import select_device, smart_inference_mode


class YOLOEDetectValidator(DetectionValidator):
    """
    A validator class for YOLOE detection models that handles both text and visual prompt embeddings.

    This class extends DetectionValidator to provide specialized validation functionality for YOLOE models.
    It supports validation using either text prompts or visual prompt embeddings extracted from training samples,
    enabling flexible evaluation strategies for prompt-based object detection.

    Attributes:
        device (torch.device): The device on which validation is performed.
        args (namespace): Configuration arguments for validation.
        dataloader (DataLoader): DataLoader for validation data.

    Methods:
        get_visual_pe: Extract visual prompt embeddings from training samples.
        preprocess: Preprocess batch data ensuring visuals are on the same device as images.
        get_vpe_dataloader: Create a dataloader for LVIS training visual prompt samples.
        __call__: Run validation using either text or visual prompt embeddings.

    Examples:
        Validate with text prompts
        >>> validator = YOLOEDetectValidator()
        >>> stats = validator(model=model, load_vp=False)

        Validate with visual prompts
        >>> stats = validator(model=model, refer_data="path/to/data.yaml", load_vp=True)
    """

    @smart_inference_mode()
    def get_visual_pe(self, dataloader: torch.utils.data.DataLoader, model: YOLOEModel) -> torch.Tensor:
        """
        Extract visual prompt embeddings from training samples.

        This method processes a dataloader to compute visual prompt embeddings for each class using a YOLOE model.
        It normalizes the embeddings and handles cases where no samples exist for a class by setting their
        embeddings to zero.

        Args:
            dataloader (torch.utils.data.DataLoader): The dataloader providing training samples.
            model (YOLOEModel): The YOLOE model from which to extract visual prompt embeddings.

        Returns:
            (torch.Tensor): Visual prompt embeddings with shape (1, num_classes, embed_dim).
        """
        assert isinstance(model, YOLOEModel)
        names = [name.split("/", 1)[0] for name in list(dataloader.dataset.data["names"].values())]
        visual_pe = torch.zeros(len(names), model.model[-1].embed, device=self.device)
        cls_visual_num = torch.zeros(len(names))

        desc = "Get visual prompt embeddings from samples"

        # Count samples per class
        for batch in dataloader:
            cls = batch["cls"].squeeze(-1).to(torch.int).unique()
            count = torch.bincount(cls, minlength=len(names))
            cls_visual_num += count

        cls_visual_num = cls_visual_num.to(self.device)

        # Extract visual prompt embeddings
        pbar = TQDM(dataloader, total=len(dataloader), desc=desc)
        for batch in pbar:
            batch = self.preprocess(batch)
            preds = model.get_visual_pe(batch["img"], visual=batch["visuals"])  # (B, max_n, embed_dim)

            batch_idx = batch["batch_idx"]
            for i in range(preds.shape[0]):
                cls = batch["cls"][batch_idx == i].squeeze(-1).to(torch.int).unique(sorted=True)
                pad_cls = torch.ones(preds.shape[1], device=self.device) * -1
                pad_cls[: cls.shape[0]] = cls
                for c in cls:
                    visual_pe[c] += preds[i][pad_cls == c].sum(0) / cls_visual_num[c]

        # Normalize embeddings for classes with samples, set others to zero
        visual_pe[cls_visual_num != 0] = F.normalize(visual_pe[cls_visual_num != 0], dim=-1, p=2)
        visual_pe[cls_visual_num == 0] = 0
        return visual_pe.unsqueeze(0)

    def get_vpe_dataloader(self, data: dict[str, Any]) -> torch.utils.data.DataLoader:
        """
        Create a dataloader for LVIS training visual prompt samples.

        This method prepares a dataloader for visual prompt embeddings (VPE) using the specified dataset.
        It applies necessary transformations including LoadVisualPrompt and configurations to the dataset
        for validation purposes.

        Args:
            data (dict): Dataset configuration dictionary containing paths and settings.

        Returns:
            (torch.utils.data.DataLoader): The dataloader for visual prompt samples.
        """
        dataset = build_yolo_dataset(
            self.args,
            data.get(self.args.split, data.get("val")),
            self.args.batch,
            data,
            mode="val",
            rect=False,
        )
        if isinstance(dataset, YOLOConcatDataset):
            for d in dataset.datasets:
                d.transforms.append(LoadVisualPrompt())
        else:
            dataset.transforms.append(LoadVisualPrompt())
        return build_dataloader(
            dataset,
            self.args.batch,
            self.args.workers,
            shuffle=False,
            rank=-1,
        )

    @smart_inference_mode()
    def __call__(
        self,
        trainer: Any | None = None,
        model: YOLOEModel | str | None = None,
        refer_data: str | None = None,
        load_vp: bool = False,
    ) -> dict[str, Any]:
        """
        Run validation on the model using either text or visual prompt embeddings.

        This method validates the model using either text prompts or visual prompts, depending on the load_vp flag.
        It supports validation during training (using a trainer object) or standalone validation with a provided
        model. For visual prompts, reference data can be specified to extract embeddings from a different dataset.

        Args:
            trainer (object, optional): Trainer object containing the model and device.
            model (YOLOEModel | str, optional): Model to validate. Required if trainer is not provided.
            refer_data (str, optional): Path to reference data for visual prompts.
            load_vp (bool): Whether to load visual prompts. If False, text prompts are used.

        Returns:
            (dict): Validation statistics containing metrics computed during validation.
        """
        if trainer is not None:
            self.device = trainer.device
            model = trainer.ema.ema
            names = [name.split("/", 1)[0] for name in list(self.dataloader.dataset.data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
                self.args.half = False
                # Directly use the same dataloader for visual embeddings extracted during training
                vpe = self.get_visual_pe(self.dataloader, model)
                model.set_classes(names, vpe)
            else:
                LOGGER.info("Validate using the text prompt.")
                tpe = model.get_text_pe(names)
                model.set_classes(names, tpe)
            stats = super().__call__(trainer, model)
        else:
            if refer_data is not None:
                assert load_vp, "Refer data is only used for visual prompt validation."
            self.device = select_device(self.args.device, verbose=False)

            if isinstance(model, (str, Path)):
                from ultralytics.nn.tasks import load_checkpoint

                model, _ = load_checkpoint(model, device=self.device)  # model, ckpt
            model.eval().to(self.device)
            data = check_det_dataset(refer_data or self.args.data)
            names = [name.split("/", 1)[0] for name in list(data["names"].values())]

            if load_vp:
                LOGGER.info("Validate using the visual prompt.")
                self.args.half = False
                # TODO: need to check if the names from refer data is consistent with the evaluated dataset
                # could use same dataset or refer to extract visual prompt embeddings
                dataloader = self.get_vpe_dataloader(data)
                vpe = self.get_visual_pe(dataloader, model)
                model.set_classes(names, vpe)
                stats = super().__call__(model=deepcopy(model))
            elif isinstance(model.model[-1], YOLOEDetect) and hasattr(model.model[-1], "lrpc"):  # prompt-free
                return super().__call__(trainer, model)
            else:
                LOGGER.info("Validate using the text prompt.")
                tpe = model.get_text_pe(names)
                model.set_classes(names, tpe)
                stats = super().__call__(model=deepcopy(model))
        return stats


class YOLOESegValidator(YOLOEDetectValidator, SegmentationValidator):
    """YOLOE segmentation validator that supports both text and visual prompt embeddings."""

    pass
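
For context (not part of the diff): the Examples block in the validator docstring can be exercised standalone, since __call__ accepts a checkpoint path and loads it itself. A hedged sketch; the data YAML and checkpoint name are illustrative placeholders:

from ultralytics.models.yolo.yoloe import YOLOEDetectValidator

# args is merged into the default validator configuration; only 'data' is strictly needed here
validator = YOLOEDetectValidator(args=dict(data="lvis.yaml", imgsz=640))
stats_vp = validator(model="yoloe-11s.pt", load_vp=True)   # classes from visual prompt embeddings
stats_tp = validator(model="yoloe-11s.pt", load_vp=False)  # classes from text prompt embeddings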