init commit
This commit is contained in:
7
ultralytics/models/yolo/obb/__init__.py
Normal file
7
ultralytics/models/yolo/obb/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from .predict import OBBPredictor
|
||||
from .train import OBBTrainer
|
||||
from .val import OBBValidator
|
||||
|
||||
__all__ = "OBBPredictor", "OBBTrainer", "OBBValidator"
|
||||
BIN
ultralytics/models/yolo/obb/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/obb/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/predict.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/obb/__pycache__/predict.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/train.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/obb/__pycache__/train.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/models/yolo/obb/__pycache__/val.cpython-310.pyc
Normal file
BIN
ultralytics/models/yolo/obb/__pycache__/val.cpython-310.pyc
Normal file
Binary file not shown.
65
ultralytics/models/yolo/obb/predict.py
Normal file
65
ultralytics/models/yolo/obb/predict.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
from ultralytics.models.yolo.detect.predict import DetectionPredictor
|
||||
from ultralytics.utils import DEFAULT_CFG, ops
|
||||
|
||||
|
||||
class OBBPredictor(DetectionPredictor):
|
||||
"""
|
||||
A class extending the DetectionPredictor class for prediction based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This predictor handles oriented bounding box detection tasks, processing images and returning results with rotated
|
||||
bounding boxes.
|
||||
|
||||
Attributes:
|
||||
args (namespace): Configuration arguments for the predictor.
|
||||
model (torch.nn.Module): The loaded YOLO OBB model.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
||||
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
||||
>>> predictor = OBBPredictor(overrides=args)
|
||||
>>> predictor.predict_cli()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
|
||||
"""
|
||||
Initialize OBBPredictor with optional model and data configuration overrides.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Default configuration for the predictor.
|
||||
overrides (dict, optional): Configuration overrides that take precedence over the default config.
|
||||
_callbacks (list, optional): List of callback functions to be invoked during prediction.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.utils import ASSETS
|
||||
>>> from ultralytics.models.yolo.obb import OBBPredictor
|
||||
>>> args = dict(model="yolo11n-obb.pt", source=ASSETS)
|
||||
>>> predictor = OBBPredictor(overrides=args)
|
||||
"""
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
self.args.task = "obb"
|
||||
|
||||
def construct_result(self, pred, img, orig_img, img_path):
|
||||
"""
|
||||
Construct the result object from the prediction.
|
||||
|
||||
Args:
|
||||
pred (torch.Tensor): The predicted bounding boxes, scores, and rotation angles with shape (N, 7) where
|
||||
the last dimension contains [x, y, w, h, confidence, class_id, angle].
|
||||
img (torch.Tensor): The image after preprocessing with shape (B, C, H, W).
|
||||
orig_img (np.ndarray): The original image before preprocessing.
|
||||
img_path (str): The path to the original image.
|
||||
|
||||
Returns:
|
||||
(Results): The result object containing the original image, image path, class names, and oriented bounding
|
||||
boxes.
|
||||
"""
|
||||
rboxes = ops.regularize_rboxes(torch.cat([pred[:, :4], pred[:, -1:]], dim=-1))
|
||||
rboxes[:, :4] = ops.scale_boxes(img.shape[2:], rboxes[:, :4], orig_img.shape, xywh=True)
|
||||
obb = torch.cat([rboxes, pred[:, 4:6]], dim=-1)
|
||||
return Results(orig_img, path=img_path, names=self.model.names, obb=obb)
|
||||
82
ultralytics/models/yolo/obb/train.py
Normal file
82
ultralytics/models/yolo/obb/train.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from copy import copy
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ultralytics.models import yolo
|
||||
from ultralytics.nn.tasks import OBBModel
|
||||
from ultralytics.utils import DEFAULT_CFG, RANK
|
||||
|
||||
|
||||
class OBBTrainer(yolo.detect.DetectionTrainer):
|
||||
"""
|
||||
A class extending the DetectionTrainer class for training based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This trainer specializes in training YOLO models that detect oriented bounding boxes, which are useful for
|
||||
detecting objects at arbitrary angles rather than just axis-aligned rectangles.
|
||||
|
||||
Attributes:
|
||||
loss_names (tuple): Names of the loss components used during training including box_loss, cls_loss,
|
||||
and dfl_loss.
|
||||
|
||||
Methods:
|
||||
get_model: Return OBBModel initialized with specified config and weights.
|
||||
get_validator: Return an instance of OBBValidator for validation of YOLO model.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.obb import OBBTrainer
|
||||
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml", epochs=3)
|
||||
>>> trainer = OBBTrainer(overrides=args)
|
||||
>>> trainer.train()
|
||||
"""
|
||||
|
||||
def __init__(self, cfg=DEFAULT_CFG, overrides: dict | None = None, _callbacks: list[Any] | None = None):
|
||||
"""
|
||||
Initialize an OBBTrainer object for training Oriented Bounding Box (OBB) models.
|
||||
|
||||
Args:
|
||||
cfg (dict, optional): Configuration dictionary for the trainer. Contains training parameters and
|
||||
model configuration.
|
||||
overrides (dict, optional): Dictionary of parameter overrides for the configuration. Any values here
|
||||
will take precedence over those in cfg.
|
||||
_callbacks (list[Any], optional): List of callback functions to be invoked during training.
|
||||
"""
|
||||
if overrides is None:
|
||||
overrides = {}
|
||||
overrides["task"] = "obb"
|
||||
super().__init__(cfg, overrides, _callbacks)
|
||||
|
||||
def get_model(
|
||||
self, cfg: str | dict | None = None, weights: str | Path | None = None, verbose: bool = True
|
||||
) -> OBBModel:
|
||||
"""
|
||||
Return OBBModel initialized with specified config and weights.
|
||||
|
||||
Args:
|
||||
cfg (str | dict, optional): Model configuration. Can be a path to a YAML config file, a dictionary
|
||||
containing configuration parameters, or None to use default configuration.
|
||||
weights (str | Path, optional): Path to pretrained weights file. If None, random initialization is used.
|
||||
verbose (bool): Whether to display model information during initialization.
|
||||
|
||||
Returns:
|
||||
(OBBModel): Initialized OBBModel with the specified configuration and weights.
|
||||
|
||||
Examples:
|
||||
>>> trainer = OBBTrainer()
|
||||
>>> model = trainer.get_model(cfg="yolo11n-obb.yaml", weights="yolo11n-obb.pt")
|
||||
"""
|
||||
model = OBBModel(cfg, nc=self.data["nc"], ch=self.data["channels"], verbose=verbose and RANK == -1)
|
||||
if weights:
|
||||
model.load(weights)
|
||||
|
||||
return model
|
||||
|
||||
def get_validator(self):
|
||||
"""Return an instance of OBBValidator for validation of YOLO model."""
|
||||
self.loss_names = "box_loss", "cls_loss", "dfl_loss"
|
||||
return yolo.obb.OBBValidator(
|
||||
self.test_loader, save_dir=self.save_dir, args=copy(self.args), _callbacks=self.callbacks
|
||||
)
|
||||
299
ultralytics/models/yolo/obb/val.py
Normal file
299
ultralytics/models/yolo/obb/val.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ultralytics.models.yolo.detect import DetectionValidator
|
||||
from ultralytics.utils import LOGGER, ops
|
||||
from ultralytics.utils.metrics import OBBMetrics, batch_probiou
|
||||
from ultralytics.utils.nms import TorchNMS
|
||||
|
||||
|
||||
class OBBValidator(DetectionValidator):
|
||||
"""
|
||||
A class extending the DetectionValidator class for validation based on an Oriented Bounding Box (OBB) model.
|
||||
|
||||
This validator specializes in evaluating models that predict rotated bounding boxes, commonly used for aerial and
|
||||
satellite imagery where objects can appear at various orientations.
|
||||
|
||||
Attributes:
|
||||
args (dict): Configuration arguments for the validator.
|
||||
metrics (OBBMetrics): Metrics object for evaluating OBB model performance.
|
||||
is_dota (bool): Flag indicating whether the validation dataset is in DOTA format.
|
||||
|
||||
Methods:
|
||||
init_metrics: Initialize evaluation metrics for YOLO.
|
||||
_process_batch: Process batch of detections and ground truth boxes to compute IoU matrix.
|
||||
_prepare_batch: Prepare batch data for OBB validation.
|
||||
_prepare_pred: Prepare predictions with scaled and padded bounding boxes.
|
||||
plot_predictions: Plot predicted bounding boxes on input images.
|
||||
pred_to_json: Serialize YOLO predictions to COCO json format.
|
||||
save_one_txt: Save YOLO detections to a txt file in normalized coordinates.
|
||||
eval_json: Evaluate YOLO output in JSON format and return performance statistics.
|
||||
|
||||
Examples:
|
||||
>>> from ultralytics.models.yolo.obb import OBBValidator
|
||||
>>> args = dict(model="yolo11n-obb.pt", data="dota8.yaml")
|
||||
>>> validator = OBBValidator(args=args)
|
||||
>>> validator(model=args["model"])
|
||||
"""
|
||||
|
||||
def __init__(self, dataloader=None, save_dir=None, args=None, _callbacks=None) -> None:
|
||||
"""
|
||||
Initialize OBBValidator and set task to 'obb', metrics to OBBMetrics.
|
||||
|
||||
This constructor initializes an OBBValidator instance for validating Oriented Bounding Box (OBB) models.
|
||||
It extends the DetectionValidator class and configures it specifically for the OBB task.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader, optional): Dataloader to be used for validation.
|
||||
save_dir (str | Path, optional): Directory to save results.
|
||||
args (dict | SimpleNamespace, optional): Arguments containing validation parameters.
|
||||
_callbacks (list, optional): List of callback functions to be called during validation.
|
||||
"""
|
||||
super().__init__(dataloader, save_dir, args, _callbacks)
|
||||
self.args.task = "obb"
|
||||
self.metrics = OBBMetrics()
|
||||
|
||||
def init_metrics(self, model: torch.nn.Module) -> None:
|
||||
"""
|
||||
Initialize evaluation metrics for YOLO obb validation.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model to validate.
|
||||
"""
|
||||
super().init_metrics(model)
|
||||
val = self.data.get(self.args.split, "") # validation path
|
||||
self.is_dota = isinstance(val, str) and "DOTA" in val # check if dataset is DOTA format
|
||||
self.confusion_matrix.task = "obb" # set confusion matrix task to 'obb'
|
||||
|
||||
def _process_batch(self, preds: dict[str, torch.Tensor], batch: dict[str, torch.Tensor]) -> dict[str, np.ndarray]:
|
||||
"""
|
||||
Compute the correct prediction matrix for a batch of detections and ground truth bounding boxes.
|
||||
|
||||
Args:
|
||||
preds (dict[str, torch.Tensor]): Prediction dictionary containing 'cls' and 'bboxes' keys with detected
|
||||
class labels and bounding boxes.
|
||||
batch (dict[str, torch.Tensor]): Batch dictionary containing 'cls' and 'bboxes' keys with ground truth
|
||||
class labels and bounding boxes.
|
||||
|
||||
Returns:
|
||||
(dict[str, np.ndarray]): Dictionary containing 'tp' key with the correct prediction matrix as a numpy
|
||||
array with shape (N, 10), which includes 10 IoU levels for each detection, indicating the accuracy
|
||||
of predictions compared to the ground truth.
|
||||
|
||||
Examples:
|
||||
>>> detections = torch.rand(100, 7) # 100 sample detections
|
||||
>>> gt_bboxes = torch.rand(50, 5) # 50 sample ground truth boxes
|
||||
>>> gt_cls = torch.randint(0, 5, (50,)) # 50 ground truth class labels
|
||||
>>> correct_matrix = validator._process_batch(detections, gt_bboxes, gt_cls)
|
||||
"""
|
||||
if batch["cls"].shape[0] == 0 or preds["cls"].shape[0] == 0:
|
||||
return {"tp": np.zeros((preds["cls"].shape[0], self.niou), dtype=bool)}
|
||||
iou = batch_probiou(batch["bboxes"], preds["bboxes"])
|
||||
return {"tp": self.match_predictions(preds["cls"], batch["cls"], iou).cpu().numpy()}
|
||||
|
||||
def postprocess(self, preds: torch.Tensor) -> list[dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
preds (torch.Tensor): Raw predictions from the model.
|
||||
|
||||
Returns:
|
||||
(list[dict[str, torch.Tensor]]): Processed predictions with angle information concatenated to bboxes.
|
||||
"""
|
||||
preds = super().postprocess(preds)
|
||||
for pred in preds:
|
||||
pred["bboxes"] = torch.cat([pred["bboxes"], pred.pop("extra")], dim=-1) # concatenate angle
|
||||
return preds
|
||||
|
||||
def _prepare_batch(self, si: int, batch: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare batch data for OBB validation with proper scaling and formatting.
|
||||
|
||||
Args:
|
||||
si (int): Batch index to process.
|
||||
batch (dict[str, Any]): Dictionary containing batch data with keys:
|
||||
- batch_idx: Tensor of batch indices
|
||||
- cls: Tensor of class labels
|
||||
- bboxes: Tensor of bounding boxes
|
||||
- ori_shape: Original image shapes
|
||||
- img: Batch of images
|
||||
- ratio_pad: Ratio and padding information
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Prepared batch data with scaled bounding boxes and metadata.
|
||||
"""
|
||||
idx = batch["batch_idx"] == si
|
||||
cls = batch["cls"][idx].squeeze(-1)
|
||||
bbox = batch["bboxes"][idx]
|
||||
ori_shape = batch["ori_shape"][si]
|
||||
imgsz = batch["img"].shape[2:]
|
||||
ratio_pad = batch["ratio_pad"][si]
|
||||
if cls.shape[0]:
|
||||
bbox[..., :4].mul_(torch.tensor(imgsz, device=self.device)[[1, 0, 1, 0]]) # target boxes
|
||||
return {
|
||||
"cls": cls,
|
||||
"bboxes": bbox,
|
||||
"ori_shape": ori_shape,
|
||||
"imgsz": imgsz,
|
||||
"ratio_pad": ratio_pad,
|
||||
"im_file": batch["im_file"][si],
|
||||
}
|
||||
|
||||
def plot_predictions(self, batch: dict[str, Any], preds: list[torch.Tensor], ni: int) -> None:
|
||||
"""
|
||||
Plot predicted bounding boxes on input images and save the result.
|
||||
|
||||
Args:
|
||||
batch (dict[str, Any]): Batch data containing images, file paths, and other metadata.
|
||||
preds (list[torch.Tensor]): List of prediction tensors for each image in the batch.
|
||||
ni (int): Batch index used for naming the output file.
|
||||
|
||||
Examples:
|
||||
>>> validator = OBBValidator()
|
||||
>>> batch = {"img": images, "im_file": paths}
|
||||
>>> preds = [torch.rand(10, 7)] # Example predictions for one image
|
||||
>>> validator.plot_predictions(batch, preds, 0)
|
||||
"""
|
||||
for p in preds:
|
||||
# TODO: fix this duplicated `xywh2xyxy`
|
||||
p["bboxes"][:, :4] = ops.xywh2xyxy(p["bboxes"][:, :4]) # convert to xyxy format for plotting
|
||||
super().plot_predictions(batch, preds, ni) # plot bboxes
|
||||
|
||||
def pred_to_json(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> None:
|
||||
"""
|
||||
Convert YOLO predictions to COCO JSON format with rotated bounding box information.
|
||||
|
||||
Args:
|
||||
predn (dict[str, torch.Tensor]): Prediction dictionary containing 'bboxes', 'conf', and 'cls' keys
|
||||
with bounding box coordinates, confidence scores, and class predictions.
|
||||
pbatch (dict[str, Any]): Batch dictionary containing 'imgsz', 'ori_shape', 'ratio_pad', and 'im_file'.
|
||||
|
||||
Notes:
|
||||
This method processes rotated bounding box predictions and converts them to both rbox format
|
||||
(x, y, w, h, angle) and polygon format (x1, y1, x2, y2, x3, y3, x4, y4) before adding them
|
||||
to the JSON dictionary.
|
||||
"""
|
||||
path = Path(pbatch["im_file"])
|
||||
stem = path.stem
|
||||
image_id = int(stem) if stem.isnumeric() else stem
|
||||
rbox = predn["bboxes"]
|
||||
poly = ops.xywhr2xyxyxyxy(rbox).view(-1, 8)
|
||||
for r, b, s, c in zip(rbox.tolist(), poly.tolist(), predn["conf"].tolist(), predn["cls"].tolist()):
|
||||
self.jdict.append(
|
||||
{
|
||||
"image_id": image_id,
|
||||
"file_name": path.name,
|
||||
"category_id": self.class_map[int(c)],
|
||||
"score": round(s, 5),
|
||||
"rbox": [round(x, 3) for x in r],
|
||||
"poly": [round(x, 3) for x in b],
|
||||
}
|
||||
)
|
||||
|
||||
def save_one_txt(self, predn: dict[str, torch.Tensor], save_conf: bool, shape: tuple[int, int], file: Path) -> None:
|
||||
"""
|
||||
Save YOLO OBB detections to a text file in normalized coordinates.
|
||||
|
||||
Args:
|
||||
predn (torch.Tensor): Predicted detections with shape (N, 7) containing bounding boxes, confidence scores,
|
||||
class predictions, and angles in format (x, y, w, h, conf, cls, angle).
|
||||
save_conf (bool): Whether to save confidence scores in the text file.
|
||||
shape (tuple[int, int]): Original image shape in format (height, width).
|
||||
file (Path): Output file path to save detections.
|
||||
|
||||
Examples:
|
||||
>>> validator = OBBValidator()
|
||||
>>> predn = torch.tensor([[100, 100, 50, 30, 0.9, 0, 45]]) # One detection: x,y,w,h,conf,cls,angle
|
||||
>>> validator.save_one_txt(predn, True, (640, 480), "detection.txt")
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from ultralytics.engine.results import Results
|
||||
|
||||
Results(
|
||||
np.zeros((shape[0], shape[1]), dtype=np.uint8),
|
||||
path=None,
|
||||
names=self.names,
|
||||
obb=torch.cat([predn["bboxes"], predn["conf"].unsqueeze(-1), predn["cls"].unsqueeze(-1)], dim=1),
|
||||
).save_txt(file, save_conf=save_conf)
|
||||
|
||||
def scale_preds(self, predn: dict[str, torch.Tensor], pbatch: dict[str, Any]) -> dict[str, torch.Tensor]:
|
||||
"""Scales predictions to the original image size."""
|
||||
return {
|
||||
**predn,
|
||||
"bboxes": ops.scale_boxes(
|
||||
pbatch["imgsz"], predn["bboxes"].clone(), pbatch["ori_shape"], ratio_pad=pbatch["ratio_pad"], xywh=True
|
||||
),
|
||||
}
|
||||
|
||||
def eval_json(self, stats: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Evaluate YOLO output in JSON format and save predictions in DOTA format.
|
||||
|
||||
Args:
|
||||
stats (dict[str, Any]): Performance statistics dictionary.
|
||||
|
||||
Returns:
|
||||
(dict[str, Any]): Updated performance statistics.
|
||||
"""
|
||||
if self.args.save_json and self.is_dota and len(self.jdict):
|
||||
import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
pred_json = self.save_dir / "predictions.json" # predictions
|
||||
pred_txt = self.save_dir / "predictions_txt" # predictions
|
||||
pred_txt.mkdir(parents=True, exist_ok=True)
|
||||
data = json.load(open(pred_json))
|
||||
# Save split results
|
||||
LOGGER.info(f"Saving predictions with DOTA format to {pred_txt}...")
|
||||
for d in data:
|
||||
image_id = d["image_id"]
|
||||
score = d["score"]
|
||||
classname = self.names[d["category_id"] - 1].replace(" ", "-")
|
||||
p = d["poly"]
|
||||
|
||||
with open(f"{pred_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
|
||||
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
|
||||
# Save merged results, this could result slightly lower map than using official merging script,
|
||||
# because of the probiou calculation.
|
||||
pred_merged_txt = self.save_dir / "predictions_merged_txt" # predictions
|
||||
pred_merged_txt.mkdir(parents=True, exist_ok=True)
|
||||
merged_results = defaultdict(list)
|
||||
LOGGER.info(f"Saving merged predictions with DOTA format to {pred_merged_txt}...")
|
||||
for d in data:
|
||||
image_id = d["image_id"].split("__", 1)[0]
|
||||
pattern = re.compile(r"\d+___\d+")
|
||||
x, y = (int(c) for c in re.findall(pattern, d["image_id"])[0].split("___"))
|
||||
bbox, score, cls = d["rbox"], d["score"], d["category_id"] - 1
|
||||
bbox[0] += x
|
||||
bbox[1] += y
|
||||
bbox.extend([score, cls])
|
||||
merged_results[image_id].append(bbox)
|
||||
for image_id, bbox in merged_results.items():
|
||||
bbox = torch.tensor(bbox)
|
||||
max_wh = torch.max(bbox[:, :2]).item() * 2
|
||||
c = bbox[:, 6:7] * max_wh # classes
|
||||
scores = bbox[:, 5] # scores
|
||||
b = bbox[:, :5].clone()
|
||||
b[:, :2] += c
|
||||
# 0.3 could get results close to the ones from official merging script, even slightly better.
|
||||
i = TorchNMS.fast_nms(b, scores, 0.3, iou_func=batch_probiou)
|
||||
bbox = bbox[i]
|
||||
|
||||
b = ops.xywhr2xyxyxyxy(bbox[:, :5]).view(-1, 8)
|
||||
for x in torch.cat([b, bbox[:, 5:7]], dim=-1).tolist():
|
||||
classname = self.names[int(x[-1])].replace(" ", "-")
|
||||
p = [round(i, 3) for i in x[:-2]] # poly
|
||||
score = round(x[-2], 3)
|
||||
|
||||
with open(f"{pred_merged_txt / f'Task1_{classname}'}.txt", "a", encoding="utf-8") as f:
|
||||
f.writelines(f"{image_id} {score} {p[0]} {p[1]} {p[2]} {p[3]} {p[4]} {p[5]} {p[6]} {p[7]}\n")
|
||||
|
||||
return stats
|
||||
Reference in New Issue
Block a user