init commit
This commit is contained in:
7
ultralytics/models/yolo/pose/__init__.py
Normal file
7
ultralytics/models/yolo/pose/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from .predict import PosePredictor
|
||||
from .train import PoseTrainer
|
||||
from .val import PoseValidator
|
||||
|
||||
__all__ = "PoseTrainer", "PoseValidator", "PosePredictor"
|
||||
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/predict.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/pose/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/train.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/pose/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/pose/__pycache__/val.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/pose/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
80
ultralytics/models/yolo/pose/predict.py
Normal file
80
ultralytics/models/yolo/pose/predict.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
|
||||
|
||||
|
||||
class PosePredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on a pose model.
|
||||
|
||||
This class specializes in pose estimation, handling keypoints detection alongside standard object detection
|
||||
capabilities inherited from DetectionPredictor.
|
||||
|
||||
Attributes:
|
||||
args (namespace): Configuration arguments for the predictor.
|
||||
model (torch.nn.Module): The loaded YOLO pose model with keypoint detection capabilities.
|
||||
|
||||
Methods:
|
||||
construct_result: Construct the result object from the prediction, including keypoints.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.pose import PosePredictor
|
||||
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
|
||||
>>> predictor = PosePredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize PosePredictor for pose estimation tasks.
|
||||
|
||||
Sets up a PosePredictor instance, configuring it for pose detection tasks and handling device-specific
|
||||
warnings for Apple MPS.
|
||||
|
||||
Args:
|
||||
cfg (Any): Configuration for the predictor.
|
||||
overrides (dict, optional): Configuration overrides that take precedence over cfg.
|
||||
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.pose import PosePredictor
|
||||
>>> args = dict(model="yolo11n-pose.pt", source=ASSETS)
|
||||
>>> predictor = PosePredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = "pose"
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct the result object from the prediction, including keypoints.
|
||||
|
||||
Extends the parent class implementation by extracting keypoint data from predictions and adding them to the
|
||||
result object.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and keypoints with shape (N, 6+K*D) where N is
|
||||
the number of detections, K is the number of keypoints, and D is the keypoint dimension.
|
||||
img (torch.Tensor): The processed input image tensor with shape (B, C, H, W).
|
||||
orig_img (np.ndarray): The original unprocessed image as a numpy array.
|
||||
img_path (str): The path to the original image file.
|
||||
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, bounding boxes, and
|
||||
keypoints.
|
||||
"""
|
||||
result = super().construct_result(pred, img, orig_img, img_path)
|
||||
# Extract keypoints from prediction and reshape according to model's keypoint shape
|
||||
pred_kpts = pred[:, 6:].view(pred.shape[0], *self.model.kpt_shape)
|
||||
# Scale keypoints coordinates to match the original image dimensions
|
||||
pred_kpts = ops.scale_coords(img.shape[2:], pred_kpts, orig_img.shape)
|
||||
result.update(keypoints=pred_kpts)
|
||||
return result
|
||||
115
ultralytics/models/yolo/pose/train.py
Normal file
115
ultralytics/models/yolo/pose/train.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import PoseModel
|
||||
from ultralytics.utils import DEFAULT_CFG, LOGGER
|
||||
|
||||
|
||||
class PoseTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training YOLO pose estimation models.
|
||||
|
||||
This trainer specializes in handling pose estimation tasks, managing model training, validation, and visualization
|
||||
of pose keypoints alongside bounding boxes.
|
||||
|
||||
Attributes:
|
||||
args (dict): Configuration arguments for training.
|
||||
model (PoseModel): The pose estimation model being trained.
|
||||
data (dict): Dataset configuration including keypoint shape information.
|
||||
loss_names (tuple): Names of the loss components used in training.
|
||||
|
||||
Methods:
|
||||
get_model: Retrieve a pose estimation model with specified configuration.
|
||||
set_model_attributes: Set keypoints shape attribute on the model.
|
||||
get_validator: Create a validator instance for model evaluation.
|
||||
plot_training_samples: Visualize training samples with keypoints.
|
||||
get_dataset: Retrieve the dataset and ensure it contains required kpt_shape key.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseTrainer
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml", epochs=3)
|
||||
>>> trainer = PoseTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict[str, Any] | None = None, _callbacks=None):
|
||||
"""
|
||||
Initialize a PoseTrainer object for training YOLO pose estimation models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration dictionary containing training parameters.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the default configuration.
|
||||
_callbacks (list, optional): List of callback functions to be executed during training.
|
||||
|
||||
Notes:
|
||||
This trainer will automatically set the task to 'pose' regardless of what is provided in overrides.
|
||||
A warning is issued when using Apple MPS device due to known bugs with pose models.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "pose"
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def get_model(
|
||||
self,
|
||||
cfg: str | Path | dict[str, Any] | None = None,
|
||||
weights: str | Path | None = None,
|
||||
verbose: bool = True,
|
||||
) -> PoseModel:
|
||||
"""
|
||||
Get pose estimation model with specified configuration and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | Path | dict, optional): Model configuration file path or dictionary.
|
||||
weights (str | Path, optional): Path to the model weights file.
|
||||
verbose (bool): Whether to display model information.
|
||||
|
||||
Returns:
|
||||
(PoseModel): Initialized pose estimation model.
|
||||
"""
|
||||
model = PoseModel(
|
||||
cfg, nc=self.data["nc"], ch=self.data["channels"], data_kpt_shape=self.data["kpt_shape"], verbose=verbose
|
||||
)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def set_model_attributes(self):
|
||||
"""Set keypoints shape attribute of PoseModel."""
|
||||
super().set_model_attributes()
|
||||
self.model.kpt_shape = self.data["kpt_shape"]
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of the PoseValidator class for validation."""
|
||||
self.loss_names = "box_loss", "pose_loss", "kobj_loss", "cls_loss", "dfl_loss"
|
||||
return yolo.pose.PoseValidator(
|
||||
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
|
||||
def get_dataset(self) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve the dataset and ensure it contains the required `kpt_shape` key.
|
||||
|
||||
Returns:
|
||||
(dict): A dictionary containing the training/validation/test dataset and category names.
|
||||
|
||||
Raises:
|
||||
KeyError: If the `kpt_shape` key is not present in the dataset.
|
||||
"""
|
||||
data = super().get_dataset()
|
||||
if "kpt_shape" not in data:
|
||||
raise KeyError(f"No `kpt_shape` in the {self.args.data}. See https://docs.ultralytics.com/datasets/pose/")
|
||||
return data
|
||||
267
ultralytics/models/yolo/pose/val.py
Normal file
267
ultralytics/models/yolo/pose/val.py
Normal file
@@ -0,0 +1,267 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import LOGGER, ops
|
||||
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, kpt_iou
|
||||
|
||||
|
||||
class PoseValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on a pose model.
|
||||
|
||||
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
||||
specialized metrics for pose evaluation.
|
||||
|
||||
Attributes:
|
||||
sigma (np.ndarray): Sigma values for OKS calculation, either OKS_SIGMA or ones divided by number of keypoints.
|
||||
kpt_shape (list[int]): Shape of the keypoints, typically [17, 3] for COCO format.
|
||||
args (dict): Arguments for the validator including task set to "pose".
|
||||
metrics (PoseMetrics): Metrics object for pose evaluation.
|
||||
|
||||
Methods:
|
||||
preprocess: Preprocess batch by converting keypoints data to float and moving it to the device.
|
||||
get_desc: Return description of evaluation metrics in string format.
|
||||
init_metrics: Initialize pose estimation metrics for YOLO model.
|
||||
_prepare_batch: Prepare a batch for processing by converting keypoints to float and scaling to original
|
||||
dimensions.
|
||||
_prepare_pred: Prepare and scale keypoints in predictions for pose processing.
|
||||
_process_batch: Return correct prediction matrix by computing Intersection over Union (IoU) between
|
||||
detections and ground truth.
|
||||
plot_val_samples: Plot and save validation set samples with ground truth bounding boxes and keypoints.
|
||||
plot_predictions: Plot and save model predictions with bounding boxes and keypoints.
|
||||
save_one_txt: Save YOLO pose detections to a text file in normalized coordinates.
|
||||
pred_to_json: Convert YOLO predictions to COCO JSON format.
|
||||
eval_json: Evaluate object detection model using COCO JSON format.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseValidator
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
||||
>>> validator = PoseValidator(args=args)
|
||||
>>> validator()
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize a PoseValidator object for pose estimation validation.
|
||||
|
||||
This validator is specifically designed for pose estimation tasks, handling keypoints and implementing
|
||||
specialized metrics for pose evaluation.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
save_dir (Path | str, optional): Directory to save results.
|
||||
args (dict, optional): Arguments for the validator including task set to "pose".
|
||||
_callbacks (list, optional): List of callback functions to be executed during validation.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.pose import PoseValidator
|
||||
>>> args = dict(model="yolo11n-pose.pt", data="coco8-pose.yaml")
|
||||
>>> validator = PoseValidator(args=args)
|
||||
>>> validator()
|
||||
|
||||
Notes:
|
||||
This class extends DetectionValidator with pose-specific functionality. It initializes with sigma values
|
||||
for OKS calculation and sets up PoseMetrics for evaluation. A warning is displayed when using Apple MPS
|
||||
due to a known bug with pose models.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.sigma = None
|
||||
self.kpt_shape = None
|
||||
self.args.task = "pose"
|
||||
self.metrics = PoseMetrics()
|
||||
if isinstance(self.args.device, str) and self.args.device.lower() == "mps":
|
||||
LOGGER.warning(
|
||||
"Apple MPS known Pose bug. Recommend 'device=cpu' for Pose models. "
|
||||
"See https://github.com/ultralytics/ultralytics/issues/4031."
|
||||
)
|
||||
|
||||
def preprocess(self, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Preprocess batch by converting keypoints data to float and moving it to the device."""
|
||||
batch = super().preprocess(batch)
|
||||
batch["keypoints"] = batch["keypoints"].float()
|
||||
return batch
|
||||
|
||||
def get_desc(self) -> str:
|
||||
"""Return description of evaluation metrics in string format."""
|
||||
return ("%22s" + "%11s" * 10) % (
|
||||
"Class",
|
||||
"Images",
|
||||
"Instances",
|
||||
"Box(P",
|
||||
"R",
|
||||
"mAP50",
|
||||
"mAP50-95)",
|
||||
"Pose(P",
|
||||
"R",
|
||||
"mAP50",
|
||||
"mAP50-95)",
|
||||
)
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO pose validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
super().init_metrics(model)
|
||||
self.kpt_shape = self.data["kpt_shape"]
|
||||
is_pose = self.kpt_shape == [17, 3]
|
||||
nkpt = self.kpt_shape[0]
|
||||
self.sigma = OKS_SIGMA if is_pose else np.ones(nkpt) / nkpt
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Postprocess YOLO predictions to extract and reshape keypoints for pose estimation.
|
||||
|
||||
This method extends the parent class postprocessing by extracting keypoints from the 'extra'
|
||||
field of predictions and reshaping them according to the keypoint shape configuration.
|
||||
The keypoints are reshaped from a flattened format to the proper dimensional structure
|
||||
(typically [N, 17, 3] for COCO pose format).
|
||||
|
||||
Args:
|
||||
preds (torch.Tensor): Raw prediction tensor from the YOLO pose model containing
|
||||
bounding boxes, confidence scores, class predictions, and keypoint data.
|
||||
|
||||
Returns:
|
||||
(dict[torch.Tensor]): Dict of processed prediction dictionaries, each containing:
|
||||
- 'bboxes': Bounding box coordinates
|
||||
- 'conf': Confidence scores
|
||||
- 'cls': Class predictions
|
||||
- 'keypoints': Reshaped keypoint coordinates with shape (-1, *self.kpt_shape)
|
||||
|
||||
Note:
|
||||
If no keypoints are present in a prediction (empty keypoints), that prediction
|
||||
is skipped and continues to the next one. The keypoints are extracted from the
|
||||
'extra' field which contains additional task-specific data beyond basic detection.
|
||||
"""
|
||||
preds = super().postprocess(preds)
|
||||
for pred in preds:
|
||||
pred["keypoints"] = pred.pop("extra").view(-1, *self.kpt_shape) # remove extra if exists
|
||||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare a batch for processing by converting keypoints to float and scaling to original dimensions.
|
||||
|
||||
Args:
|
||||
si (int): Batch index.
|
||||
batch (dict[str, Any]): Dictionary containing batch data with keys like 'keypoints', 'batch_idx', etc.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch with keypoints scaled to original image dimensions.
|
||||
|
||||
Notes:
|
||||
This method extends the parent class's _prepare_batch method by adding keypoint processing.
|
||||
Keypoints are scaled from normalized coordinates to original image dimensions.
|
||||
"""
|
||||
pbatch = super()._prepare_batch(si, batch)
|
||||
kpts = batch["keypoints"][batch["batch_idx"] == si]
|
||||
h, w = pbatch["imgsz"]
|
||||
kpts = kpts.clone()
|
||||
kpts[..., 0] *= w
|
||||
kpts[..., 1] *= h
|
||||
pbatch["keypoints"] = kpts
|
||||
return pbatch
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, Any]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Return correct prediction matrix by computing Intersection over Union (IoU) between detections and ground truth.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Dictionary containing prediction data with keys 'cls' for class predictions
|
||||
and 'keypoints' for keypoint predictions.
|
||||
batch (dict[str, Any]): Dictionary containing ground truth data with keys 'cls' for class labels,
|
||||
'bboxes' for bounding boxes, and 'keypoints' for keypoint annotations.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing the correct prediction matrix including 'tp_p' for pose
|
||||
true positives across 10 IoU levels.
|
||||
|
||||
Notes:
|
||||
`0.53` scale factor used in area computation is referenced from
|
||||
https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384.
|
||||
"""
|
||||
tp = super()._process_batch(preds, batch)
|
||||
gt_cls = batch["cls"]
|
||||
if gt_cls.shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
tp_p = np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)
|
||||
else:
|
||||
# `0.53` is from https://github.com/jin-s13/xtcocoapi/blob/master/xtcocotools/cocoeval.py#L384
|
||||
area = ops.xyxy2xywh(batch["bboxes"])[:, 2:].prod(1) * 0.53
|
||||
iou = kpt_iou(batch["keypoints"], preds["keypoints"], sigma=self.sigma, area=area)
|
||||
tp_p = self.match_predictions(preds["cls"], gt_cls, iou).cpu().numpy()
|
||||
tp.update({"tp_p": tp_p}) # update tp with kpts IoU
|
||||
return tp
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO pose detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Dictionary containing predictions with keys 'bboxes', 'conf', 'cls' and 'keypoints.
|
||||
save_conf (bool): Whether to save confidence scores.
|
||||
shape (tuple[int, int]): Shape of the original image (height, width).
|
||||
file (Path): Output file path to save detections.
|
||||
|
||||
Notes:
|
||||
The output format is: class_id x_center y_center width height confidence keypoints where keypoints are
|
||||
normalized (x, y, visibility) values for each point.
|
||||
"""
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
Results(
|
||||
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
||||
path=None,
|
||||
names=self.names,
|
||||
boxes=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
||||
keypoints=predn["keypoints"],
|
||||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format.
|
||||
|
||||
This method takes prediction tensors and a filename, converts the bounding boxes from YOLO format
|
||||
to COCO format, and appends the results to the internal JSON dictionary (self.jdict).
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', 'cls',
|
||||
and 'keypoints' tensors.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
The method extracts the image ID from the filename stem (either as an integer if numeric, or as a string),
|
||||
converts bounding boxes from xyxy to xywh format, and adjusts coordinates from center to top-left corner
|
||||
before saving to the JSON dictionary.
|
||||
"""
|
||||
super().pred_to_json(predn, pbatch)
|
||||
kpts = predn["kpts"]
|
||||
for i, k in enumerate(kpts.flatten(1, 2).tolist()):
|
||||
self.jdict[-len(kpts) + i]["keypoints"] = k # keypoints
|
||||
|
||||
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
"""Scales predictions to the original image size."""
|
||||
return {
|
||||
**super().scale_preds(predn, pbatch),
|
||||
"kpts": ops.scale_coords(
|
||||
pbatch["imgsz"],
|
||||
predn["keypoints"].clone(),
|
||||
pbatch["ori_shape"],
|
||||
ratio_pad=pbatch["ratio_pad"],
|
||||
),
|
||||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Evaluate object detection model using COCO JSON format."""
|
||||
anno_json = self.data["path"] / "annotations/person_keypoints_val2017.json" # annotations
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
return super().coco_evaluate(stats, pred_json, anno_json, ["bbox", "keypoints"], suffix=["Box", "Pose"])
|
||||
Reference in New Issue
Block a user