Initial commit

This commit is contained in:
2025-11-08 19:15:39 +01:00
parent ecffcb08e8
commit c7adacf53b
470 changed files with 73751 additions and 0 deletions

View File

@@ -0,0 +1,7 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .predict import PosePredictor
from .train import PoseTrainer
from .val import PoseValidator
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"

View File

@@ -0,0 +1,80 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
class PosePredictor(DetectionPredictor):
    """
    Pose-estimation predictor built on top of DetectionPredictor.

    Adds keypoint handling to the standard detection pipeline: each prediction's
    flattened keypoint columns are reshaped to the model's keypoint layout and
    rescaled to the original image before being attached to the result.

    Attributes:
        args (namespace): Configuration arguments for the predictor.
        model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.

    Methods:
        construct_result: Build the result object from a prediction, attaching keypoints.

    Examples:
        >>> from ultralytics.utils import ASSETS
        >>> from ultralytics.models.yolo.pose import PosePredictor
        >>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
        >>> predictor = PosePredictor(overrides=args)
        >>> predictor.predict_cli()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
        """
        Initialize PosePredictor for pose estimation tasks.

        Forces the task to 'pose' and emits a warning when the configured device is
        Apple MPS, which has a known bug with pose models.

        Args:
            cfg (Any): Configuration for the predictor.
            overrides (dict, optional): Configuration overrides that take precedence over cfg.
            _callbacks (list, optional): List of callback functions to be invoked during prediction.
        """
        super().__init__(cfg, overrides, _callbacks)
        self.args.task = "pose"
        device = self.args.device
        if isinstance(device, str) and device.lower() == "mps":
            LOGGER.warning(
                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def construct_result(self, pred, img, orig_img, img_path):
        """
        Build the result object for one image, attaching scaled keypoints.

        Args:
            pred (torch.Tensor): Predictions with shape (N, 6+K*D) — boxes, scores and
                flattened keypoints, where K is the keypoint count and D its dimension.
            img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
            orig_img (np.ndarray): The original unprocessed image as a numpy array.
            img_path (str): The path to the original image file.

        Returns:
            (Results): Result containing the original image, path, class names, boxes and keypoints.
        """
        result = super().construct_result(pred, img, orig_img, img_path)
        # Reshape the flattened keypoint columns (everything past the 6 box/score/class
        # columns) into the model's (K, D) layout, then map coordinates from the network
        # input size back to the original image dimensions.
        kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
        kpts = ops.scale_coords(img.shape[2:], kpts, orig_img.shape)
        result.update(keypoints=kpts)
        return result

View File

@@ -0,0 +1,115 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from copy import copy
from pathlib import Path
from typing import Any
from ultralytics.models import yolo
from ultralytics.nn.tasks import PoseModel
from ultralytics.utils import DEFAULT_CFG, LOGGER
class PoseTrainer(yolo.detect.DetectionTrainer):
    """
    Trainer for YOLO pose-estimation models, built on DetectionTrainer.

    Handles pose-specific concerns on top of standard detection training: forcing
    the task to 'pose', wiring the dataset's keypoint shape into the model, and
    using a pose-aware validator and loss names.

    Attributes:
        args (dict): Configuration arguments for training.
        model (PoseModel): The pose estimation model being trained.
        data (dict): Dataset configuration including keypoint shape information.
        loss_names (tuple): Names of the loss components used in training.

    Methods:
        get_model: Retrieve a pose estimation model with specified configuration.
        set_model_attributes: Set keypoints shape attribute on the model.
        get_validator: Create a validator instance for model evaluation.
        get_dataset: Retrieve the dataset and ensure it contains the required kpt_shape key.

    Examples:
        >>> from ultralytics.models.yolo.pose import PoseTrainer
        >>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
        >>> trainer = PoseTrainer(overrides=args)
        >>> trainer.train()
    """

    def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
        """
        Initialize a PoseTrainer for YOLO pose-estimation training.

        Args:
            cfg (dict, optional): Default configuration dictionary containing training parameters.
            overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
            _callbacks (list, optional): List of callback functions to be executed during training.

        Notes:
            The task is always forced to 'pose', overriding anything supplied in overrides.
            A warning is issued on Apple MPS devices due to a known bug with pose models.
        """
        overrides = {} if overrides is None else overrides
        overrides["task"] = "pose"
        super().__init__(cfg, overrides, _callbacks)
        device = self.args.device
        if isinstance(device, str) and device.lower() == "mps":
            LOGGER.warning(
                "Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
                "See https://github.com/ultralytics/ultralytics/issues/4031."
            )

    def get_model(
        self,
        cfg: str | Path | dict[str, Any] | None = None,
        weights: str | Path | None = None,
        verbose: bool = True,
    ) -> PoseModel:
        """
        Build a pose model from the given configuration, optionally loading weights.

        Args:
            cfg (str | Path | dict, optional): Model configuration file path or dictionary.
            weights (str | Path, optional): Path to the model weights file.
            verbose (bool): Whether to display model information.

        Returns:
            (PoseModel): Initialized pose estimation model.
        """
        model = PoseModel(
            cfg,
            nc=self.data["nc"],
            ch=self.data["channels"],
            data_kpt_shape=self.data["kpt_shape"],
            verbose=verbose,
        )
        if weights:
            model.load(weights)
        return model

    def set_model_attributes(self):
        """Propagate the dataset's keypoint shape onto the model."""
        super().set_model_attributes()
        self.model.kpt_shape = self.data["kpt_shape"]

    def get_validator(self):
        """Return a PoseValidator configured for this trainer's test loader."""
        self.loss_names = ("box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss")
        return yolo.pose.PoseValidator(
            self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
        )

    def get_dataset(self) -> dict[str, Any]:
        """
        Retrieve the dataset and ensure it declares a `kpt_shape`.

        Returns:
            (dict): A dictionary containing the training/validation/test dataset and category names.

        Raises:
            KeyError: If the `kpt_shape` key is not present in the dataset.
        """
        data = super().get_dataset()
        if "kpt_shape" not in data:
            raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
        return data

View File

@@ -0,0 +1,267 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import LOGGER, ops
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
class PoseValidator(DetectionValidator):
"""
A class extending the DetectionValidator class for validation based on a pose model.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
Attributes:
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
args (dict): Arguments for the validator including task set to "pose".
metrics (PoseMetrics): Metrics object for pose evaluation.
Methods:
preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
get_desc: Return description of evaluation metrics in string format.
init_metrics: Initialize pose estimation metrics for YOLO model.
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
dimensions.
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
detections and ground truth.
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
pred_to_json: Convert YOLO predictions to COCO JSON format.
eval_json: Evaluate object detection model using COCO JSON format.
Examples:
>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
"""
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
"""
Initialize a PoseValidator object for pose estimation validation.
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
specialized metrics for pose evaluation.
Args:
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
save_dir (Path | str, optional): Directory to save results.
args (dict, optional): Arguments for the validator including task set to "pose".
_callbacks (list, optional): List of callback functions to be executed during validation.
Examples:
>>> from ultralytics.models.yolo.pose import PoseValidator
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
>>> validator = PoseValidator(args=args)
>>> validator()
Notes:
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
due to a known bug with pose models.
"""
super().__init__(dataloader, save_dir, args, _callbacks)
self.sigma = None
self.kpt_shape = None
self.args.task = "pose"
self.metrics = PoseMetrics()
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
LOGGER.warning(
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
"See https://github.com/ultralytics/ultralytics/issues/4031."
)
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
batch = super().preprocess(batch)
batch["keypoints"] = batch["keypoints"].float()
return batch
def get_desc(self) -> str:
"""Return description of evaluation metrics in string format."""
return ("%22s" + "%11s" * 10) % (
"Class",
"Images",
"Instances",
"Box(P",
"R",
"mAP50",
"mAP50-95)",
"Pose(P",
"R",
"mAP50",
"mAP50-95)",
)
def init_metrics(self, model: torch.nn.Module) -> None:
"""
Initialize evaluation metrics for YOLO pose validation.
Args:
model (torch.nn.Module): Model to validate.
"""
super().init_metrics(model)
self.kpt_shape = self.data["kpt_shape"]
is_pose = self.kpt_shape == [17, 3]
nkpt = self.kpt_shape[0]
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
"""
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
field of predictions and reshaping them according to the keypoint shape configuration.
The keypoints are reshaped from a flattened format to the proper dimensional structure
(typically [N, 17, 3] for COCO pose format).
Args:
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
bounding boxes, confidence scores, class predictions, and keypoint data.
Returns:
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
- 'bboxes': Bounding box coordinates
- 'conf': Confidence scores
- 'cls': Class predictions
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
Note:
If no keypoints are present in a prediction (empty keypoints), that prediction
is skipped and continues to the next one. The keypoints are extracted from the
'extra' field which contains additional task-specific data beyond basic detection.
"""
preds = super().postprocess(preds)
for pred in preds:
pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape) # remove extra if exists
return preds
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
"""
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
Args:
si (int): Batch index.
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
Returns:
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
Notes:
This method extends the parent class's _prepare_batch method by adding keypoint processing.
Keypoints are scaled from normalized coordinates to original image dimensions.
"""
pbatch = super()._prepare_batch(si, batch)
kpts = batch["keypoints"][batch["batch_idx"] == si]
h, w = pbatch["imgsz"]
kpts = kpts.clone()
kpts[..., 0] *= w
kpts[..., 1] *= h
pbatch["keypoints"] = kpts
return pbatch
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
"""
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
Args:
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
and 'keypoints' for keypoint predictions.
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
Returns:
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
true positives across 10 IoU levels.
Notes:
`0.53` scale factor used in area computation is referenced from
https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
"""
tp = super()._process_batch(preds, batch)
gt_cls = batch["cls"]
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
else:
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
return tp
    def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
        """
        Save YOLO pose detections to a text file in normalized coordinates.

        Args:
            predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints'.
            save_conf (bool): Whether to save confidence scores.
            shape (tuple[int, int]): Shape of the original image (height, width).
            file (Path): Output file path to save detections.

        Notes:
            The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
            normalized (x, y, visibility) values for each point.
        """
        from ultralytics.engine.results import Results

        # Placeholder zero image carrying the original (height, width); presumably only its
        # shape is needed by save_txt to normalize coordinates — verify against Results.save_txt.
        Results(
            np.zeros((shape[0], shape[1]), dtype=np.uint8),
            path=None,
            names=self.names,
            boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
            keypoints=predn["keypoints"],
        ).save_txt(file, save_conf=save_conf)
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
"""
Convert YOLO predictions to COCO JSON format.
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
Args:
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
and 'keypoints' tensors.
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
Notes:
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
before saving to the JSON dictionary.
"""
super().pred_to_json(predn, pbatch)
kpts = predn["kpts"]
for i, k in enumerate(kpts.flatten(1, 2).tolist()):
self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
"""Scales predictions to the original image size."""
return {
**super().scale_preds(predn, pbatch),
"kpts": ops.scale_coords(
pbatch["imgsz"],
predn["keypoints"].clone(),
pbatch["ori_shape"],
ratio_pad=pbatch["ratio_pad"],
),
}
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
"""Evaluate object detection model using COCO JSON format."""
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
pred_json = self.save_dir / "predictions.json" # predictions
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])