init commit
7
ultralytics/models/fastsam/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from .model import FastSAM
from .predict import FastSAMPredictor
from .val import FastSAMValidator

__all__ = "FastSAMPredictor", "FastSAM", "FastSAMValidator"
BIN
ultralytics/models/fastsam/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/fastsam/__pycache__/model.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/fastsam/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/fastsam/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/fastsam/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
81
ultralytics/models/fastsam/model.py
Normal file
@@ -0,0 +1,81 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from pathlib import Path
from typing import Any

from ultralytics.engine.model import Model

from .predict import FastSAMPredictor
from .val import FastSAMValidator


class FastSAM(Model):
    """
    FastSAM model interface for segment anything tasks.

    This class extends the base Model class to provide specific functionality for the FastSAM (Fast Segment Anything
    Model) implementation, allowing for efficient and accurate image segmentation with optional prompting support.

    Attributes:
        model (str): Path to the pre-trained FastSAM model file.
        task (str): The task type, set to "segment" for FastSAM models.

    Methods:
        predict: Perform segmentation prediction on image or video source with optional prompts.
        task_map: Returns mapping of segment task to predictor and validator classes.

    Examples:
        Initialize FastSAM model and run prediction
        >>> from ultralytics import FastSAM
        >>> model = FastSAM("FastSAM-x.pt")
        >>> results = model.predict("ultralytics/assets/bus.jpg")

        Run prediction with bounding box prompts
        >>> results = model.predict("image.jpg", bboxes=[[100, 100, 200, 200]])
    """

    def __init__(self, model: str = "FastSAM-x.pt"):
        """Initialize the FastSAM model with the specified pre-trained weights."""
        if str(model) == "FastSAM.pt":
            model = "FastSAM-x.pt"
        assert Path(model).suffix not in {".yaml", ".yml"}, "FastSAM models only support pre-trained models."
        super().__init__(model=model, task="segment")

    def predict(
        self,
        source,
        stream: bool = False,
        bboxes: list | None = None,
        points: list | None = None,
        labels: list | None = None,
        texts: list | None = None,
        **kwargs: Any,
    ):
        """
        Perform segmentation prediction on image or video source.

        Supports prompted segmentation with bounding boxes, points, labels, and texts. The method packages these
        prompts and passes them to the parent class predict method for processing.

        Args:
            source (str | PIL.Image | np.ndarray): Input source for prediction, can be a file path, URL, PIL image,
                or numpy array.
            stream (bool): Whether to enable real-time streaming mode for video inputs.
            bboxes (list, optional): Bounding box coordinates for prompted segmentation in format [[x1, y1, x2, y2]].
            points (list, optional): Point coordinates for prompted segmentation in format [[x, y]].
            labels (list, optional): Class labels for prompted segmentation.
            texts (list, optional): Text prompts for segmentation guidance.
            **kwargs (Any): Additional keyword arguments passed to the predictor.

        Returns:
            (list): List of Results objects containing the prediction results.
        """
        prompts = dict(bboxes=bboxes, points=points, labels=labels, texts=texts)
        return super().predict(source, stream, prompts=prompts, **kwargs)

    @property
    def task_map(self) -> dict[str, dict[str, Any]]:
        """Returns a dictionary mapping segment task to corresponding predictor and validator classes."""
        return {"segment": {"predictor": FastSAMPredictor, "validator": FastSAMValidator}}
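A minimal usage sketch for the prompted predict() interface defined above. The checkpoint name and bus.jpg path follow the docstring examples; the point and text prompt values are illustrative assumptions, not part of this commit.

from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")  # loads the pre-trained FastSAM weights

# Box prompt: keep the mask that best overlaps this xyxy region.
results = model.predict("ultralytics/assets/bus.jpg", bboxes=[[100, 100, 200, 200]])

# Point prompt with an explicit foreground label.
results = model.predict("ultralytics/assets/bus.jpg", points=[[200, 300]], labels=[1])

# Text prompt: candidate masks are ranked by CLIP similarity to the phrase.
results = model.predict("ultralytics/assets/bus.jpg", texts="a photo of a bus")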
181
ultralytics/models/fastsam/predict.py
Normal file
@@ -0,0 +1,181 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import torch
from PIL import Image

from ultralytics.models.yolo.segment import SegmentationPredictor
from ultralytics.utils import DEFAULT_CFG, checks
from ultralytics.utils.metrics import box_iou
from ultralytics.utils.ops import scale_masks
from ultralytics.utils.torch_utils import TORCH_1_10

from .utils import adjust_bboxes_to_image_border


class FastSAMPredictor(SegmentationPredictor):
    """
    FastSAMPredictor is specialized for fast SAM (Segment Anything Model) segmentation prediction tasks.

    This class extends the SegmentationPredictor, customizing the prediction pipeline specifically for fast SAM. It
    adjusts post-processing steps to incorporate mask prediction and non-maximum suppression while optimizing for
    single-class segmentation.

    Attributes:
        prompts (dict): Dictionary containing prompt information for segmentation (bboxes, points, labels, texts).
        device (torch.device): Device on which model and tensors are processed.
        clip_model (Any, optional): CLIP model for text-based prompting, loaded on demand.
        clip_preprocess (Any, optional): CLIP preprocessing function for images, loaded on demand.

    Methods:
        postprocess: Apply postprocessing to FastSAM predictions and handle prompts.
        prompt: Perform image segmentation inference based on various prompt types.
        set_prompts: Set prompts to be used during inference.
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize the FastSAMPredictor with configuration and callbacks.

        This initializes a predictor specialized for Fast SAM (Segment Anything Model) segmentation tasks. The
        predictor extends SegmentationPredictor with custom post-processing for mask prediction and non-maximum
        suppression optimized for single-class segmentation.

        Args:
            cfg (dict): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides.
            _callbacks (list, optional): List of callback functions.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.prompts = {}

    def postprocess(self, preds, img, orig_imgs):
        """
        Apply postprocessing to FastSAM predictions and handle prompts.

        Args:
            preds (list[torch.Tensor]): Raw predictions from the model.
            img (torch.Tensor): Input image tensor that was fed to the model.
            orig_imgs (list[np.ndarray]): Original images before preprocessing.

        Returns:
            (list[Results]): Processed results with prompts applied.
        """
        bboxes = self.prompts.pop("bboxes", None)
        points = self.prompts.pop("points", None)
        labels = self.prompts.pop("labels", None)
        texts = self.prompts.pop("texts", None)
        results = super().postprocess(preds, img, orig_imgs)
        for result in results:
            full_box = torch.tensor(
                [0, 0, result.orig_shape[1], result.orig_shape[0]], device=preds[0].device, dtype=torch.float32
            )
            boxes = adjust_bboxes_to_image_border(result.boxes.xyxy, result.orig_shape)
            idx = torch.nonzero(box_iou(full_box[None], boxes) > 0.9).flatten()
            if idx.numel() != 0:
                result.boxes.xyxy[idx] = full_box

        return self.prompt(results, bboxes=bboxes, points=points, labels=labels, texts=texts)

    def prompt(self, results, bboxes=None, points=None, labels=None, texts=None):
        """
        Perform image segmentation inference based on cues like bounding boxes, points, and text prompts.

        Args:
            results (Results | list[Results]): Original inference results from FastSAM models without any prompts.
            bboxes (np.ndarray | list, optional): Bounding boxes with shape (N, 4), in XYXY format.
            points (np.ndarray | list, optional): Points indicating object locations with shape (N, 2), in pixels.
            labels (np.ndarray | list, optional): Labels for point prompts, shape (N, ). 1 = foreground, 0 = background.
            texts (str | list[str], optional): Textual prompts, a list containing string objects.

        Returns:
            (list[Results]): Output results filtered and determined by the provided prompts.
        """
        if bboxes is None and points is None and texts is None:
            return results
        prompt_results = []
        if not isinstance(results, list):
            results = [results]
        for result in results:
            if len(result) == 0:
                prompt_results.append(result)
                continue
            masks = result.masks.data
            if masks.shape[1:] != result.orig_shape:
                masks = scale_masks(masks[None], result.orig_shape)[0]
            # bboxes prompt
            idx = torch.zeros(len(result), dtype=torch.bool, device=self.device)
            if bboxes is not None:
                bboxes = torch.as_tensor(bboxes, dtype=torch.int32, device=self.device)
                bboxes = bboxes[None] if bboxes.ndim == 1 else bboxes
                bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
                mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
                full_mask_areas = torch.sum(masks, dim=(1, 2))

                union = bbox_areas[:, None] + full_mask_areas - mask_areas
                idx[torch.argmax(mask_areas / union, dim=1)] = True
            if points is not None:
                points = torch.as_tensor(points, dtype=torch.int32, device=self.device)
                points = points[None] if points.ndim == 1 else points
                if labels is None:
                    labels = torch.ones(points.shape[0])
                labels = torch.as_tensor(labels, dtype=torch.int32, device=self.device)
                assert len(labels) == len(points), (
                    f"Expected `labels` with same size as `point`, but got {len(labels)} and {len(points)}"
                )
                point_idx = (
                    torch.ones(len(result), dtype=torch.bool, device=self.device)
                    if labels.sum() == 0  # all negative points
                    else torch.zeros(len(result), dtype=torch.bool, device=self.device)
                )
                for point, label in zip(points, labels):
                    point_idx[torch.nonzero(masks[:, point[1], point[0]], as_tuple=True)[0]] = bool(label)
                idx |= point_idx
            if texts is not None:
                if isinstance(texts, str):
                    texts = [texts]
                crop_ims, filter_idx = [], []
                for i, b in enumerate(result.boxes.xyxy.tolist()):
                    x1, y1, x2, y2 = (int(x) for x in b)
                    if (masks[i].sum() if TORCH_1_10 else masks[i].sum(0).sum()) <= 100:  # torch 1.9 bug workaround
                        filter_idx.append(i)
                        continue
                    crop_ims.append(Image.fromarray(result.orig_img[y1:y2, x1:x2, ::-1]))
                similarity = self._clip_inference(crop_ims, texts)
                text_idx = torch.argmax(similarity, dim=-1)  # (M, )
                if len(filter_idx):
                    text_idx += (torch.tensor(filter_idx, device=self.device)[None] <= int(text_idx)).sum(0)
                idx[text_idx] = True

            prompt_results.append(result[idx])

        return prompt_results

    def _clip_inference(self, images, texts):
        """
        Perform CLIP inference to calculate similarity between images and text prompts.

        Args:
            images (list[PIL.Image]): List of source images, each should be PIL.Image with RGB channel order.
            texts (list[str]): List of prompt texts, each should be a string object.

        Returns:
            (torch.Tensor): Similarity matrix between given images and texts with shape (M, N).
        """
        try:
            import clip
        except ImportError:
            checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
            import clip
        if (not hasattr(self, "clip_model")) or (not hasattr(self, "clip_preprocess")):
            self.clip_model, self.clip_preprocess = clip.load("ViT-B/32", device=self.device)
        images = torch.stack([self.clip_preprocess(image).to(self.device) for image in images])
        tokenized_text = clip.tokenize(texts).to(self.device)
        image_features = self.clip_model.encode_image(images)
        text_features = self.clip_model.encode_text(tokenized_text)
        image_features /= image_features.norm(dim=-1, keepdim=True)  # (N, 512)
        text_features /= text_features.norm(dim=-1, keepdim=True)  # (M, 512)
        return (image_features * text_features[:, None]).sum(-1)  # (M, N)

    def set_prompts(self, prompts):
        """Set prompts to be used during inference."""
        self.prompts = prompts
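The bboxes branch of prompt() above amounts to a box-to-mask IoU: the mask pixels falling inside the prompt box act as the intersection, and bbox_area + full_mask_area - intersection as the union, with argmax selecting the best-matching mask per box. A standalone sketch of that selection rule on toy tensors (shapes and values are invented for illustration):

import torch

# Two candidate masks on a 10x10 image: one small square, one large square.
masks = torch.zeros(2, 10, 10)
masks[0, 2:4, 2:4] = 1  # small object
masks[1, 1:8, 1:8] = 1  # large object

# One prompt box (x1, y1, x2, y2) roughly matching the large object.
bboxes = torch.tensor([[1, 1, 8, 8]])

bbox_areas = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
mask_areas = torch.stack([masks[:, b[1] : b[3], b[0] : b[2]].sum(dim=(1, 2)) for b in bboxes])
full_mask_areas = masks.sum(dim=(1, 2))

union = bbox_areas[:, None] + full_mask_areas - mask_areas
iou = mask_areas / union  # (num_boxes, num_masks)
print(iou)  # tensor([[0.0816, 1.0000]])
print(torch.argmax(iou, dim=1))  # tensor([1]) -> the large mask is selected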
24
ultralytics/models/fastsam/utils.py
Normal file
@@ -0,0 +1,24 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license


def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
    """
    Adjust bounding boxes to stick to image border if they are within a certain threshold.

    Args:
        boxes (torch.Tensor): Bounding boxes with shape (N, 4) in xyxy format.
        image_shape (tuple): Image dimensions as (height, width).
        threshold (int): Pixel threshold for considering a box close to the border.

    Returns:
        (torch.Tensor): Adjusted bounding boxes with shape (N, 4).
    """
    # Image dimensions
    h, w = image_shape

    # Adjust boxes that are close to image borders
    boxes[boxes[:, 0] < threshold, 0] = 0  # x1
    boxes[boxes[:, 1] < threshold, 1] = 0  # y1
    boxes[boxes[:, 2] > w - threshold, 2] = w  # x2
    boxes[boxes[:, 3] > h - threshold, 3] = h  # y2
    return boxes
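A quick illustration of the border-snapping helper above, assuming the module path added in this commit; the box values are arbitrary and only chosen to show the clamping within the default 20-pixel threshold.

import torch

from ultralytics.models.fastsam.utils import adjust_bboxes_to_image_border

# One box hugging the top-left corner of a 480x640 image, one comfortably inside it.
boxes = torch.tensor([[5.0, 12.0, 630.0, 100.0], [100.0, 100.0, 200.0, 200.0]])

adjusted = adjust_bboxes_to_image_border(boxes, image_shape=(480, 640), threshold=20)
print(adjusted)  # first box snaps to x1=0, y1=0, x2=640; second box is unchanged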
40
ultralytics/models/fastsam/val.py
Normal file
@@ -0,0 +1,40 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from ultralytics.models.yolo.segment import SegmentationValidator


class FastSAMValidator(SegmentationValidator):
    """
    Custom validation class for Fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.

    Extends the SegmentationValidator class, customizing the validation process specifically for Fast SAM. This class
    sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
    to avoid errors during validation.

    Attributes:
        dataloader (torch.utils.data.DataLoader): The data loader object used for validation.
        save_dir (Path): The directory where validation results will be saved.
        args (SimpleNamespace): Additional arguments for customization of the validation process.
        _callbacks (list): List of callback functions to be invoked during validation.
        metrics (SegmentMetrics): Segmentation metrics calculator for evaluation.

    Methods:
        __init__: Initialize the FastSAMValidator with custom settings for Fast SAM.
    """

    def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None):
        """
        Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.

        Args:
            dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
            save_dir (Path, optional): Directory to save results.
            args (SimpleNamespace, optional): Configuration for the validator.
            _callbacks (list, optional): List of callback functions to be invoked during validation.

        Notes:
            Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
        """
        super().__init__(dataloader, save_dir, args, _callbacks)
        self.args.task = "segment"
        self.args.plots = False  # disable ConfusionMatrix and other plots to avoid errors
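Since task_map wires FastSAMValidator into the segment task, validation would normally be triggered through the standard Model.val entry point. A hedged sketch, assuming a small segmentation dataset YAML such as coco8-seg.yaml is available locally:

from ultralytics import FastSAM

model = FastSAM("FastSAM-x.pt")

# Runs FastSAMValidator under the hood: task forced to "segment", plots disabled.
metrics = model.val(data="coco8-seg.yaml", imgsz=640)
print(metrics.box.map, metrics.seg.map)  # mAP50-95 for boxes and masks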