init commit

2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions

ultralytics/models/fastsam/__init__.py

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import FastSAM
from .predict import FastSAMPredictor
from .val import FastSAMValidator

__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
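
The `__all__` tuple above defines the public surface of the subpackage, so the three classes can be imported directly from `ultralytics.models.fastsam` (module path assumed from the relative imports); `FastSAM` is also re-exported at the top level, as the model docstring below shows. A minimal, illustrative sketch, not part of this commit:

# Illustrative sketch only; module path assumed from the relative imports above.
from ultralytics.models.fastsam import FastSAM, FastSAMPredictor, FastSAMValidator
import ultralytics.models.fastsam as fastsam

print(fastsam.__all__)  # ("FastSAMPredictor", "FastSAM", "FastSAMValidator")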

ultralytics/models/fastsam/model.py

@@ -0,0 +1,81 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

from ultralytics.engine.model import Model

from .predict import FastSAMPredictor
from .val import FastSAMValidator


class FastSAM(Model):
    """
    FastSAM model interface for segment anything tasks.

    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
    Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.

    Attributes:
        model (str): Path to the pre-trained FastSAM model file.
        task (str): The task type, set to "segment" for FastSAM models.

    Methods:
        predict: Perform segmentation prediction on image or video source with optional prompts.
        task_map: Returns mapping of segment task to predictor and validator classes.

    Examples:
        Initialize FastSAM model and run prediction
        >>> from ultralytics import FastSAM
        >>> model = FastSAM("FastSAM-x.pt")
        >>> results = model.predict("ultralytics/assets/bus.jpg")

        Run prediction with bounding box prompts
        >>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
    """

    def __init__(self, model: str = "FastSAM-x.pt"):
        """Initialize the FastSAM model with the specified pre-trained weights."""
        if str(model) == "FastSAM.pt":
            model = "FastSAM-x.pt"
        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
        super().__init__(model=model, task="segment")

    def predict(
        self,
        source,
        stream: bool = False,
        bboxes: list | None = None,
        points: list | None = None,
        labels: list | None = None,
        texts: list | None = None,
        **kwargs: Any,
    ):
        """
        Perform segmentation prediction on image or video source.

        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
        prompts and passes them to the parent class predict method for processing.

        Args:
            source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
                or numpy array.
            stream (bool): Whether to enable real-time streaming mode for video inputs.
            bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
            points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
            labels (list, optional): Class labels for prompted segmentation.
            texts (list, optional): Text prompts for segmentation guidance.
            **kwargs (Any): Additional keyword arguments passed to the predictor.

        Returns:
            (list): List of Results objects containing the prediction results.
        """
        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
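
For reference, a hedged usage sketch of the prompt arguments that `predict` forwards: the class docstring shows a box prompt, while point and text prompts travel the same path and are resolved by `FastSAMPredictor` below. Weight and image paths are illustrative, not part of this commit:

# Illustrative sketch only; paths, weights, and prompt values are made up.
from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")

# Point prompt: one foreground point at pixel (220, 220)
results = model.predict("ultralytics/assets/bus.jpg", points=[[220, 220]], labels=[1])

# Text prompt: keeps the mask whose crop best matches the text via CLIP
results = model.predict("ultralytics/assets/bus.jpg", texts=["a photo of a bus"])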

ultralytics/models/fastsam/predict.py

@@ -0,0 +1,181 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
from PIL import Image

from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils import DEFAULT_CFG, checks
from ultralytics.utils.metrics import box_iou
from ultralytics.utils.ops import scale_masks
from ultralytics.utils.torch_utils import TORCH_1_10

from .utils import adjust_bboxes_to_image_border


class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
    single-class segmentation.

    Attributes:
        prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
        device (torch.device): Device on which model and tensors are processed.
        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.

    Methods:
        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
        prompt: Perform image segmentation inference based on various prompt types.
        set_prompts: Set prompts to be used during inference.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the FastSAMPredictor with configuration and callbacks.

        This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The predictor
        extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum suppression
        optimized for single-class segmentation.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.prompts = {}

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply postprocessing to FastSAM predictions and handle prompts.

        Args:
            preds (list[torch.Tensor]): Raw predictions from the model.
            img (torch.Tensor): Input image tensor that was fed to the model.
            orig_imgs (list[np.ndarray]): Original images before preprocessing.

        Returns:
            (list[Results]): Processed results with prompts applied.
        """
        bboxes = self.prompts.pop("bboxes", None)
        points = self.prompts.pop("points", None)
        labels = self.prompts.pop("labels", None)
        texts = self.prompts.pop("texts", None)
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box

        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)

    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
        """
        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

        Args:
            results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
            bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            texts (str | list[str], optional): Textual prompts, a list containing string objects.

        Returns:
            (list[Results]): Output results filtered and determined by the provided prompts.
        """
        if bboxes is None and points is None and texts is None:
            return results
        prompt_results = []
        if not isinstance(results, list):
            results = [results]
        for result in results:
            if len(result) == 0:
                prompt_results.append(result)
                continue
            masks = result.masks.data
            if masks.shape[1:] != result.orig_shape:
                masks = scale_masks(masks[None], result.orig_shape)[0]
            # bboxes prompt
            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
            if bboxes is not None:
                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
                full_mask_areas = torch.sum(masks, dim=(1, 2))

                union = bbox_areas[:, None] + full_mask_areas - mask_areas
                idx[torch.argmax(mask_areas / union, dim=1)] = True
            if points is not None:
                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
                points = points[None] if points.ndim == 1 else points
                if labels is None:
                    labels = torch.ones(points.shape[0])
                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                assert len(labels) == len(points), (
                    f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
                )
                point_idx = (
                    torch.ones(len(result), dtype=torch.bool, device=self.device)
                    if labels.sum() == 0  # all negative points
                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                )
                for point, label in zip(points, labels):
                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
                idx |= point_idx
            if texts is not None:
                if isinstance(texts, str):
                    texts = [texts]
                crop_ims, filter_idx = [], []
                for i, b in enumerate(result.boxes.xyxy.tolist()):
                    x1, y1, x2, y2 = (int(x) for x in b)
                    if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                        filter_idx.append(i)
                        continue
                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
                similarity = self._clip_inference(crop_ims, texts)
                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
                if len(filter_idx):
                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
                idx[text_idx] = True

            prompt_results.append(result[idx])

        return prompt_results

    def _clip_inference(self, images, texts):
        """
        Perform CLIP inference to calculate similarity between images and text prompts.

        Args:
            images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
            texts (list[str]): List of prompt texts, each should be a string object.

        Returns:
            (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
        """
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
        tokenized_text = clip.tokenize(texts).to(self.device)
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
        return (image_features * text_features[:, None]).sum(-1)  # (M, N)

    def set_prompts(self, prompts):
        """Set prompts to be used during inference."""
        self.prompts = prompts
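
The box-prompt branch of `prompt` above scores each predicted mask against every prompt box: the mask area falling inside the box is divided by the union (box area + full mask area - overlap), and the best-scoring mask per box is kept. A self-contained toy with dummy tensors mirroring those lines (values invented for illustration, no Ultralytics dependency):

# Toy illustration of the bbox-prompt selection in FastSAMPredictor.prompt()
import torch

masks = torch.zeros(2, 8, 8)           # two candidate masks on an 8x8 image
masks[0, 1:4, 1:4] = 1                 # small object near the top-left
masks[1, 2:7, 2:7] = 1                 # larger object in the centre
bboxes = torch.tensor([[2, 2, 7, 7]])  # one prompt box in xyxy pixels

bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
full_mask_areas = torch.sum(masks, dim=(1, 2))

union = bbox_areas[:, None] + full_mask_areas - mask_areas
print(torch.argmax(mask_areas / union, dim=1))  # tensor([1]) -> the centre mask is kept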

ultralytics/models/fastsam/utils.py

@@ -0,0 +1,24 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license


def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """
    Adjust bounding boxes to stick to image border if they are within a certain threshold.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
        image_shape (tuple): Image dimensions as (height, width).
        threshold (int): Pixel threshold for considering a box close to the border.

    Returns:
        (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
    """
    # Image dimensions
    h, w = image_shape

    # Adjust boxes that are close to image borders
    boxes[boxes[:, 0] < threshold, 0] = 0  # x1
    boxes[boxes[:, 1] < threshold, 1] = 0  # y1
    boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
    boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
    return boxes
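
A quick sanity check of the snapping behaviour on dummy boxes; the import path assumes the file location ultralytics/models/fastsam/utils.py used above, and the boxes are modified in place:

# Illustrative check on a 480x640 image with the default threshold of 20 px.
import torch
from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border  # assumed path

boxes = torch.tensor([[5.0, 12.0, 630.0, 470.0],      # within 20 px of every border -> snapped
                      [100.0, 100.0, 200.0, 200.0]])  # interior box -> unchanged
print(adjust_bboxes_to_image_border(boxes, image_shape=(480, 640)))
# tensor([[  0.,   0., 640., 480.],
#         [100., 100., 200., 200.]])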

ultralytics/models/fastsam/val.py

@@ -0,0 +1,40 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.segment import SegmentationValidator


class FastSAMValidator(SegmentationValidator):
    """
    Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.

    Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
    sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
    to avoid errors during validation.

    Attributes:
        dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
        save_dir (Path): The directory where validation results will be saved.
        args (SimpleNamespace): Additional arguments for customization of the validation process.
        _callbacks (list): List of callback functions to be invoked during validation.
        metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.

    Methods:
        __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """
        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            args (SimpleNamespace, optional): Configuration for the validator.
            _callbacks (list, optional): List of callback functions to be invoked during validation.

        Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.args.task = "segment"
        self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
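
A hedged sketch of how this validator is normally reached: calling `val()` on a FastSAM model resolves `FastSAMValidator` through the `task_map` defined in model.py. The dataset YAML and metric attribute names below follow common Ultralytics conventions and are assumptions, not part of this commit:

# Sketch only: dataset YAML and metric attribute names are assumptions.
from ultralytics import FastSAM

model = FastSAM("FastSAM-s.pt")
metrics = model.val(data="coco8-seg.yaml", imgsz=640)  # dispatched to FastSAMValidator
print(metrics.seg.map)  # mask mAP50-95, per SegmentMetrics naming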