# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

from copy import copy
from pathlib import Path

from ultralytics.models import yolo
from ultralytics.nn.tasks import SegmentationModel
from ultralytics.utils import DEFAULT_CFG, RANK


class SegmentationTrainer(yolo.detect.DetectionTrainer):
    """
    A class extending the DetectionTrainer class for training based on a segmentation model.

    This trainer specializes in handling segmentation tasks, extending the detection trainer with
    segmentation-specific functionality including model initialization, validation, and visualization.

    Attributes:
        loss_names (tuple[str]): Names of the loss components used during training. Assigned when
            `get_validator` is called.

    Examples:
        >>> from ultralytics.models.yolo.segment import SegmentationTrainer
        >>> args = dict(model="yolo11n-seg.pt", data="coco8-seg.yaml", epochs=3)
        >>> trainer = SegmentationTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks=None):
        """
        Initialize a SegmentationTrainer object.

        The task is always forced to "segment", regardless of what `overrides` contains. The dictionary
        passed by the caller is left untouched; a shallow copy is forwarded to the parent initializer.

        Args:
            cfg (dict): Configuration dictionary with default training settings.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.
        """
        # Work on a copy so that forcing the task key does not mutate the caller's dictionary.
        overrides = {} if overrides is None else dict(overrides)
        overrides["task"] = "segment"
        super().__init__(cfg, overrides, _callbacks)

    def get_model(self, cfg: dict | str | None = None, weights: str | Path | None = None, verbose: bool = True):
        """
        Initialize and return a SegmentationModel with specified configuration and weights.

        Args:
            cfg (dict | str, optional): Model configuration. Can be a dictionary, a path to a YAML file, or None.
            weights (str | Path, optional): Path to pretrained weights file.
            verbose (bool): Whether to display model information during initialization.

        Returns:
            (SegmentationModel): Initialized segmentation model with loaded weights if specified.

        Examples:
            >>> trainer = SegmentationTrainer()
            >>> model = trainer.get_model(cfg="yolo11n-seg.yaml")
            >>> model = trainer.get_model(weights="yolo11n-seg.pt", verbose=False)
        """
        # Verbose output is only enabled when RANK == -1, in addition to the caller asking for it.
        model = SegmentationModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
        if weights:
            model.load(weights)

        return model

    def get_validator(self):
        """
        Return an instance of SegmentationValidator for validation of YOLO model.

        As a side effect, sets `self.loss_names` to the four segmentation loss component names.
        """
        self.loss_names = "box_loss", "seg_loss", "cls_loss", "dfl_loss"
        # The validator receives a shallow copy of the args rather than the trainer's own object.
        return yolo.segment.SegmentationValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )