init commit
This commit is contained in:
115
ultralytics/models/yolo/pose/train.py
Normal file
115
ultralytics/models/yolo/pose/train.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import PoseModel
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER
|
||||
|
||||
|
||||
class PoseTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
||||
|
||||
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
||||
of pose keypoints alongside bounding boxes.
|
||||
|
||||
Attributes:
|
||||
args (dict): Configuration arguments for training.
|
||||
model (PoseModel): The pose estimation model being trained.
|
||||
data (dict): Dataset configuration including keypoint shape information.
|
||||
loss_names (tuple): Names of the loss components used in training.
|
||||
|
||||
Methods:
|
||||
get_model: Retrieve a pose estimation model with specified configuration.
|
||||
set_model_attributes: Set keypoints shape attribute on the model.
|
||||
get_validator: Create a validator instance for model evaluation.
|
||||
plot_training_samples: Visualize training samples with keypoints.
|
||||
get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseTrainer
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
||||
>>> trainer = PoseTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
||||
_callbacks (list, optional): List of callback functions to be executed during training.
|
||||
|
||||
Notes:
|
||||
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
|
||||
A warning is issued when using Apple MPS device due to known bugs with pose models.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "pose"
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
cfg: str | Path | dict[str, Any] | None = None,
|
||||
weights: str | Path | None = None,
|
||||
verbose: bool = True,
|
||||
) -> PoseModel:
|
||||
"""
|
||||
Get pose estimation model with specified configuration and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
||||
weights (str | Path, optional): Path to the model weights file.
|
||||
verbose (bool): Whether to display model information.
|
||||
|
||||
Returns:
|
||||
(PoseModel): Initialized pose estimation model.
|
||||
"""
|
||||
model = PoseModel(
|
||||
cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
|
||||
)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set keypoints shape attribute of PoseModel."""
|
||||
super().set_model_attributes()
|
||||
self.model.kpt_shape = self.data["kpt_shape"]
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of the PoseValidator class for validation."""
|
||||
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
|
||||
return yolo.pose.PoseValidator(
|
||||
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
|
||||
def get_dataset(self) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the training/validation/test dataset and category names.
|
||||
|
||||
Raises:
|
||||
KeyError: If the `kpt_shape` key is not present in the dataset.
|
||||
"""
|
||||
data = super().get_dataset()
|
||||
if "kpt_shape" not in data:
|
||||
raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
|
||||
return data
|
||||
Reference in New Issue
Block a user