init commit
This commit is contained in:
239
ultralytics/utils/export/__init__.py
Normal file
239
ultralytics/utils/export/__init__.py
Normal file
@@ -0,0 +1,239 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.utils import IS_JETSON, LOGGER
|
||||
|
||||
from .imx import torch2imx # noqa
|
||||
|
||||
|
||||
def torch2onnx(
|
||||
torch_model: torch.nn.Module,
|
||||
im: torch.Tensor,
|
||||
onnx_file: str,
|
||||
opset: int = 14,
|
||||
input_names: list[str] = ["images"],
|
||||
output_names: list[str] = ["output0"],
|
||||
dynamic: bool | dict = False,
|
||||
) -> None:
|
||||
"""
|
||||
Export a PyTorch model to ONNX format.
|
||||
|
||||
Args:
|
||||
torch_model (torch.nn.Module): The PyTorch model to export.
|
||||
im (torch.Tensor): Example input tensor for the model.
|
||||
onnx_file (str): Path to save the exported ONNX file.
|
||||
opset (int): ONNX opset version to use for export.
|
||||
input_names (list[str]): List of input tensor names.
|
||||
output_names (list[str]): List of output tensor names.
|
||||
dynamic (bool | dict, optional): Whether to enable dynamic axes.
|
||||
|
||||
Notes:
|
||||
Setting `do_constant_folding=True` may cause issues with DNN inference for torch>=1.12.
|
||||
"""
|
||||
torch.onnx.export(
|
||||
torch_model,
|
||||
im,
|
||||
onnx_file,
|
||||
verbose=False,
|
||||
opset_version=opset,
|
||||
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic or None,
|
||||
)
|
||||
|
||||
|
||||
def onnx2engine(
|
||||
onnx_file: str,
|
||||
engine_file: str | None = None,
|
||||
workspace: int | None = None,
|
||||
half: bool = False,
|
||||
int8: bool = False,
|
||||
dynamic: bool = False,
|
||||
shape: tuple[int, int, int, int] = (1, 3, 640, 640),
|
||||
dla: int | None = None,
|
||||
dataset=None,
|
||||
metadata: dict | None = None,
|
||||
verbose: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
"""
|
||||
Export a YOLO model to TensorRT engine format.
|
||||
|
||||
Args:
|
||||
onnx_file (str): Path to the ONNX file to be converted.
|
||||
engine_file (str, optional): Path to save the generated TensorRT engine file.
|
||||
workspace (int, optional): Workspace size in GB for TensorRT.
|
||||
half (bool, optional): Enable FP16 precision.
|
||||
int8 (bool, optional): Enable INT8 precision.
|
||||
dynamic (bool, optional): Enable dynamic input shapes.
|
||||
shape (tuple[int, int, int, int], optional): Input shape (batch, channels, height, width).
|
||||
dla (int, optional): DLA core to use (Jetson devices only).
|
||||
dataset (ultralytics.data.build.InfiniteDataLoader, optional): Dataset for INT8 calibration.
|
||||
metadata (dict, optional): Metadata to include in the engine file.
|
||||
verbose (bool, optional): Enable verbose logging.
|
||||
prefix (str, optional): Prefix for log messages.
|
||||
|
||||
Raises:
|
||||
ValueError: If DLA is enabled on non-Jetson devices or required precision is not set.
|
||||
RuntimeError: If the ONNX file cannot be parsed.
|
||||
|
||||
Notes:
|
||||
TensorRT version compatibility is handled for workspace size and engine building.
|
||||
INT8 calibration requires a dataset and generates a calibration cache.
|
||||
Metadata is serialized and written to the engine file if provided.
|
||||
"""
|
||||
import tensorrt as trt # noqa
|
||||
|
||||
engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
|
||||
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
if verbose:
|
||||
logger.min_severity = trt.Logger.Severity.VERBOSE
|
||||
|
||||
# Engine builder
|
||||
builder = trt.Builder(logger)
|
||||
config = builder.create_builder_config()
|
||||
workspace_bytes = int((workspace or 0) * (1 << 30))
|
||||
is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
|
||||
if is_trt10 and workspace_bytes > 0:
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
|
||||
elif workspace_bytes > 0: # TensorRT versions 7, 8
|
||||
config.max_workspace_size = workspace_bytes
|
||||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
network = builder.create_network(flag)
|
||||
half = builder.platform_has_fast_fp16 and half
|
||||
int8 = builder.platform_has_fast_int8 and int8
|
||||
|
||||
# Optionally switch to DLA if enabled
|
||||
if dla is not None:
|
||||
if not IS_JETSON:
|
||||
raise ValueError("DLA is only available on NVIDIA Jetson devices")
|
||||
LOGGER.info(f"{prefix} enabling DLA on core {dla}...")
|
||||
if not half and not int8:
|
||||
raise ValueError(
|
||||
"DLA requires either 'half=True' (FP16) or 'int8=True' (INT8) to be enabled. Please enable one of them and try again."
|
||||
)
|
||||
config.default_device_type = trt.DeviceType.DLA
|
||||
config.DLA_core = int(dla)
|
||||
config.set_flag(trt.BuilderFlag.GPU_FALLBACK)
|
||||
|
||||
# Read ONNX file
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
if not parser.parse_from_file(onnx_file):
|
||||
raise RuntimeError(f"failed to load ONNX file: {onnx_file}")
|
||||
|
||||
# Network inputs
|
||||
inputs = [network.get_input(i) for i in range(network.num_inputs)]
|
||||
outputs = [network.get_output(i) for i in range(network.num_outputs)]
|
||||
for inp in inputs:
|
||||
LOGGER.info(f'{prefix} input "{inp.name}" with shape{inp.shape} {inp.dtype}')
|
||||
for out in outputs:
|
||||
LOGGER.info(f'{prefix} output "{out.name}" with shape{out.shape} {out.dtype}')
|
||||
|
||||
if dynamic:
|
||||
profile = builder.create_optimization_profile()
|
||||
min_shape = (1, shape[1], 32, 32) # minimum input shape
|
||||
max_shape = (*shape[:2], *(int(max(2, workspace or 2) * d) for d in shape[2:])) # max input shape
|
||||
for inp in inputs:
|
||||
profile.set_shape(inp.name, min=min_shape, opt=shape, max=max_shape)
|
||||
config.add_optimization_profile(profile)
|
||||
if int8:
|
||||
config.set_calibration_profile(profile)
|
||||
|
||||
LOGGER.info(f"{prefix} building {'INT8' if int8 else 'FP' + ('16' if half else '32')} engine as {engine_file}")
|
||||
if int8:
|
||||
config.set_flag(trt.BuilderFlag.INT8)
|
||||
config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
|
||||
|
||||
class EngineCalibrator(trt.IInt8Calibrator):
|
||||
"""
|
||||
Custom INT8 calibrator for TensorRT engine optimization.
|
||||
|
||||
This calibrator provides the necessary interface for TensorRT to perform INT8 quantization calibration
|
||||
using a dataset. It handles batch generation, caching, and calibration algorithm selection.
|
||||
|
||||
Attributes:
|
||||
dataset: Dataset for calibration.
|
||||
data_iter: Iterator over the calibration dataset.
|
||||
algo (trt.CalibrationAlgoType): Calibration algorithm type.
|
||||
batch (int): Batch size for calibration.
|
||||
cache (Path): Path to save the calibration cache.
|
||||
|
||||
Methods:
|
||||
get_algorithm: Get the calibration algorithm to use.
|
||||
get_batch_size: Get the batch size to use for calibration.
|
||||
get_batch: Get the next batch to use for calibration.
|
||||
read_calibration_cache: Use existing cache instead of calibrating again.
|
||||
write_calibration_cache: Write calibration cache to disk.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset, # ultralytics.data.build.InfiniteDataLoader
|
||||
cache: str = "",
|
||||
) -> None:
|
||||
"""Initialize the INT8 calibrator with dataset and cache path."""
|
||||
trt.IInt8Calibrator.__init__(self)
|
||||
self.dataset = dataset
|
||||
self.data_iter = iter(dataset)
|
||||
self.algo = (
|
||||
trt.CalibrationAlgoType.ENTROPY_CALIBRATION_2 # DLA quantization needs ENTROPY_CALIBRATION_2
|
||||
if dla is not None
|
||||
else trt.CalibrationAlgoType.MINMAX_CALIBRATION
|
||||
)
|
||||
self.batch = dataset.batch_size
|
||||
self.cache = Path(cache)
|
||||
|
||||
def get_algorithm(self) -> trt.CalibrationAlgoType:
|
||||
"""Get the calibration algorithm to use."""
|
||||
return self.algo
|
||||
|
||||
def get_batch_size(self) -> int:
|
||||
"""Get the batch size to use for calibration."""
|
||||
return self.batch or 1
|
||||
|
||||
def get_batch(self, names) -> list[int] | None:
|
||||
"""Get the next batch to use for calibration, as a list of device memory pointers."""
|
||||
try:
|
||||
im0s = next(self.data_iter)["img"] / 255.0
|
||||
im0s = im0s.to("cuda") if im0s.device.type == "cpu" else im0s
|
||||
return [int(im0s.data_ptr())]
|
||||
except StopIteration:
|
||||
# Return None to signal to TensorRT there is no calibration data remaining
|
||||
return None
|
||||
|
||||
def read_calibration_cache(self) -> bytes | None:
|
||||
"""Use existing cache instead of calibrating again, otherwise, implicitly return None."""
|
||||
if self.cache.exists() and self.cache.suffix == ".cache":
|
||||
return self.cache.read_bytes()
|
||||
|
||||
def write_calibration_cache(self, cache: bytes) -> None:
|
||||
"""Write calibration cache to disk."""
|
||||
_ = self.cache.write_bytes(cache)
|
||||
|
||||
# Load dataset w/ builder (for batching) and calibrate
|
||||
config.int8_calibrator = EngineCalibrator(
|
||||
dataset=dataset,
|
||||
cache=str(Path(onnx_file).with_suffix(".cache")),
|
||||
)
|
||||
|
||||
elif half:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
|
||||
# Write file
|
||||
build = builder.build_serialized_network if is_trt10 else builder.build_engine
|
||||
with build(network, config) as engine, open(engine_file, "wb") as t:
|
||||
# Metadata
|
||||
if metadata is not None:
|
||||
meta = json.dumps(metadata)
|
||||
t.write(len(meta).to_bytes(4, byteorder="little", signed=True))
|
||||
t.write(meta.encode())
|
||||
# Model
|
||||
t.write(engine if is_trt10 else engine.serialize())
|
||||
BIN
ultralytics/utils/export/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
ultralytics/utils/export/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
ultralytics/utils/export/__pycache__/imx.cpython-310.pyc
Normal file
BIN
ultralytics/utils/export/__pycache__/imx.cpython-310.pyc
Normal file
Binary file not shown.
289
ultralytics/utils/export/imx.py
Normal file
289
ultralytics/utils/export/imx.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from ultralytics.nn.modules import Detect, Pose
|
||||
from ultralytics.utils import LOGGER
|
||||
from ultralytics.utils.tal import make_anchors
|
||||
from ultralytics.utils.torch_utils import copy_attr
|
||||
|
||||
|
||||
class FXModel(torch.nn.Module):
|
||||
"""
|
||||
A custom model class for torch.fx compatibility.
|
||||
|
||||
This class extends `torch.nn.Module` and is designed to ensure compatibility with torch.fx for tracing and graph
|
||||
manipulation. It copies attributes from an existing model and explicitly sets the model attribute to ensure proper
|
||||
copying.
|
||||
|
||||
Attributes:
|
||||
model (nn.Module): The original model's layers.
|
||||
"""
|
||||
|
||||
def __init__(self, model, imgsz=(640, 640)):
|
||||
"""
|
||||
Initialize the FXModel.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The original model to wrap for torch.fx compatibility.
|
||||
imgsz (tuple[int, int]): The input image size (height, width). Default is (640, 640).
|
||||
"""
|
||||
super().__init__()
|
||||
copy_attr(self, model)
|
||||
# Explicitly set `model` since `copy_attr` somehow does not copy it.
|
||||
self.model = model.model
|
||||
self.imgsz = imgsz
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the model.
|
||||
|
||||
This method performs the forward pass through the model, handling the dependencies between layers and saving
|
||||
intermediate outputs.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor to the model.
|
||||
|
||||
Returns:
|
||||
(torch.Tensor): The output tensor from the model.
|
||||
"""
|
||||
y = [] # outputs
|
||||
for m in self.model:
|
||||
if m.f != -1: # if not from previous layer
|
||||
# from earlier layers
|
||||
x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
|
||||
if isinstance(m, Detect):
|
||||
m._inference = types.MethodType(_inference, m) # bind method to Detect
|
||||
m.anchors, m.strides = (
|
||||
x.transpose(0, 1)
|
||||
for x in make_anchors(
|
||||
torch.cat([s / m.stride.unsqueeze(-1) for s in self.imgsz], dim=1), m.stride, 0.5
|
||||
)
|
||||
)
|
||||
if type(m) is Pose:
|
||||
m.forward = types.MethodType(pose_forward, m) # bind method to Detect
|
||||
x = m(x) # run
|
||||
y.append(x) # save output
|
||||
return x
|
||||
|
||||
|
||||
def _inference(self, x: list[torch.Tensor]) -> tuple[torch.Tensor]:
|
||||
"""Decode boxes and cls scores for imx object detection."""
|
||||
x_cat = torch.cat([xi.view(x[0].shape[0], self.no, -1) for xi in x], 2)
|
||||
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
|
||||
dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides
|
||||
return dbox.transpose(1, 2), cls.sigmoid().permute(0, 2, 1)
|
||||
|
||||
|
||||
def pose_forward(self, x: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward pass for imx pose estimation, including keypoint decoding."""
|
||||
bs = x[0].shape[0] # batch size
|
||||
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
|
||||
x = Detect.forward(self, x)
|
||||
pred_kpt = self.kpts_decode(bs, kpt)
|
||||
return (*x, pred_kpt.permute(0, 2, 1))
|
||||
|
||||
|
||||
class NMSWrapper(torch.nn.Module):
|
||||
"""Wrap PyTorch Module with multiclass_nms layer from sony_custom_layers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
score_threshold: float = 0.001,
|
||||
iou_threshold: float = 0.7,
|
||||
max_detections: int = 300,
|
||||
task: str = "detect",
|
||||
):
|
||||
"""
|
||||
Initialize NMSWrapper with PyTorch Module and NMS parameters.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): Model instance.
|
||||
score_threshold (float): Score threshold for non-maximum suppression.
|
||||
iou_threshold (float): Intersection over union threshold for non-maximum suppression.
|
||||
max_detections (int): The number of detections to return.
|
||||
task (str): Task type, either 'detect' or 'pose'.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.score_threshold = score_threshold
|
||||
self.iou_threshold = iou_threshold
|
||||
self.max_detections = max_detections
|
||||
self.task = task
|
||||
|
||||
def forward(self, images):
|
||||
"""Forward pass with model inference and NMS post-processing."""
|
||||
from sony_custom_layers.pytorch import multiclass_nms_with_indices
|
||||
|
||||
# model inference
|
||||
outputs = self.model(images)
|
||||
boxes, scores = outputs[0], outputs[1]
|
||||
nms_outputs = multiclass_nms_with_indices(
|
||||
boxes=boxes,
|
||||
scores=scores,
|
||||
score_threshold=self.score_threshold,
|
||||
iou_threshold=self.iou_threshold,
|
||||
max_detections=self.max_detections,
|
||||
)
|
||||
if self.task == "pose":
|
||||
kpts = outputs[2] # (bs, max_detections, kpts 17*3)
|
||||
out_kpts = torch.gather(kpts, 1, nms_outputs.indices.unsqueeze(-1).expand(-1, -1, kpts.size(-1)))
|
||||
return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, out_kpts
|
||||
return nms_outputs.boxes, nms_outputs.scores, nms_outputs.labels, nms_outputs.n_valid
|
||||
|
||||
|
||||
def torch2imx(
|
||||
model: torch.nn.Module,
|
||||
file: Path | str,
|
||||
conf: float,
|
||||
iou: float,
|
||||
max_det: int,
|
||||
metadata: dict | None = None,
|
||||
gptq: bool = False,
|
||||
dataset=None,
|
||||
prefix: str = "",
|
||||
):
|
||||
"""
|
||||
Export YOLO model to IMX format for deployment on Sony IMX500 devices.
|
||||
|
||||
This function quantizes a YOLO model using Model Compression Toolkit (MCT) and exports it
|
||||
to IMX format compatible with Sony IMX500 edge devices. It supports both YOLOv8n and YOLO11n
|
||||
models for detection and pose estimation tasks.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The YOLO model to export. Must be YOLOv8n or YOLO11n.
|
||||
file (Path | str): Output file path for the exported model.
|
||||
conf (float): Confidence threshold for NMS post-processing.
|
||||
iou (float): IoU threshold for NMS post-processing.
|
||||
max_det (int): Maximum number of detections to return.
|
||||
metadata (dict | None, optional): Metadata to embed in the ONNX model. Defaults to None.
|
||||
gptq (bool, optional): Whether to use Gradient-Based Post Training Quantization.
|
||||
If False, uses standard Post Training Quantization. Defaults to False.
|
||||
dataset (optional): Representative dataset for quantization calibration. Defaults to None.
|
||||
prefix (str, optional): Logging prefix string. Defaults to "".
|
||||
|
||||
Returns:
|
||||
f (Path): Path to the exported IMX model directory
|
||||
|
||||
Raises:
|
||||
ValueError: If the model is not a supported YOLOv8n or YOLO11n variant.
|
||||
|
||||
Example:
|
||||
>>> from ultralytics import YOLO
|
||||
>>> model = YOLO("yolo11n.pt")
|
||||
>>> path, _ = export_imx(model, "model.imx", conf=0.25, iou=0.45, max_det=300)
|
||||
|
||||
Note:
|
||||
- Requires model_compression_toolkit, onnx, edgemdt_tpc, and sony_custom_layers packages
|
||||
- Only supports YOLOv8n and YOLO11n models (detection and pose tasks)
|
||||
- Output includes quantized ONNX model, IMX binary, and labels.txt file
|
||||
"""
|
||||
import model_compression_toolkit as mct
|
||||
import onnx
|
||||
from edgemdt_tpc import get_target_platform_capabilities
|
||||
|
||||
LOGGER.info(f"\n{prefix} starting export with model_compression_toolkit {mct.__version__}...")
|
||||
|
||||
def representative_dataset_gen(dataloader=dataset):
|
||||
for batch in dataloader:
|
||||
img = batch["img"]
|
||||
img = img / 255.0
|
||||
yield [img]
|
||||
|
||||
tpc = get_target_platform_capabilities(tpc_version="4.0", device_type="imx500")
|
||||
|
||||
bit_cfg = mct.core.BitWidthConfig()
|
||||
if "C2PSA" in model.__str__(): # YOLO11
|
||||
if model.task == "detect":
|
||||
layer_names = ["sub", "mul_2", "add_14", "cat_21"]
|
||||
weights_memory = 2585350.2439
|
||||
n_layers = 238 # 238 layers for fused YOLO11n
|
||||
elif model.task == "pose":
|
||||
layer_names = ["sub", "mul_2", "add_14", "cat_22", "cat_23", "mul_4", "add_15"]
|
||||
weights_memory = 2437771.67
|
||||
n_layers = 257 # 257 layers for fused YOLO11n-pose
|
||||
else: # YOLOv8
|
||||
if model.task == "detect":
|
||||
layer_names = ["sub", "mul", "add_6", "cat_17"]
|
||||
weights_memory = 2550540.8
|
||||
n_layers = 168 # 168 layers for fused YOLOv8n
|
||||
elif model.task == "pose":
|
||||
layer_names = ["add_7", "mul_2", "cat_19", "mul", "sub", "add_6", "cat_18"]
|
||||
weights_memory = 2482451.85
|
||||
n_layers = 187 # 187 layers for fused YOLO11n-pose
|
||||
|
||||
# Check if the model has the expected number of layers
|
||||
if len(list(model.modules())) != n_layers:
|
||||
raise ValueError("IMX export only supported for YOLOv8n and YOLO11n models.")
|
||||
|
||||
for layer_name in layer_names:
|
||||
bit_cfg.set_manual_activation_bit_width([mct.core.common.network_editors.NodeNameFilter(layer_name)], 16)
|
||||
|
||||
config = mct.core.CoreConfig(
|
||||
mixed_precision_config=mct.core.MixedPrecisionQuantizationConfig(num_of_images=10),
|
||||
quantization_config=mct.core.QuantizationConfig(concat_threshold_update=True),
|
||||
bit_width_config=bit_cfg,
|
||||
)
|
||||
|
||||
resource_utilization = mct.core.ResourceUtilization(weights_memory=weights_memory)
|
||||
|
||||
quant_model = (
|
||||
mct.gptq.pytorch_gradient_post_training_quantization( # Perform Gradient-Based Post Training Quantization
|
||||
model=model,
|
||||
representative_data_gen=representative_dataset_gen,
|
||||
target_resource_utilization=resource_utilization,
|
||||
gptq_config=mct.gptq.get_pytorch_gptq_config(
|
||||
n_epochs=1000, use_hessian_based_weights=False, use_hessian_sample_attention=False
|
||||
),
|
||||
core_config=config,
|
||||
target_platform_capabilities=tpc,
|
||||
)[0]
|
||||
if gptq
|
||||
else mct.ptq.pytorch_post_training_quantization( # Perform post training quantization
|
||||
in_module=model,
|
||||
representative_data_gen=representative_dataset_gen,
|
||||
target_resource_utilization=resource_utilization,
|
||||
core_config=config,
|
||||
target_platform_capabilities=tpc,
|
||||
)[0]
|
||||
)
|
||||
|
||||
quant_model = NMSWrapper(
|
||||
model=quant_model,
|
||||
score_threshold=conf or 0.001,
|
||||
iou_threshold=iou,
|
||||
max_detections=max_det,
|
||||
task=model.task,
|
||||
)
|
||||
|
||||
f = Path(str(file).replace(file.suffix, "_imx_model"))
|
||||
f.mkdir(exist_ok=True)
|
||||
onnx_model = f / Path(str(file.name).replace(file.suffix, "_imx.onnx")) # js dir
|
||||
mct.exporter.pytorch_export_model(
|
||||
model=quant_model, save_model_path=onnx_model, repr_dataset=representative_dataset_gen
|
||||
)
|
||||
|
||||
model_onnx = onnx.load(onnx_model) # load onnx model
|
||||
for k, v in metadata.items():
|
||||
meta = model_onnx.metadata_props.add()
|
||||
meta.key, meta.value = k, str(v)
|
||||
|
||||
onnx.save(model_onnx, onnx_model)
|
||||
|
||||
subprocess.run(
|
||||
["imxconv-pt", "-i", str(onnx_model), "-o", str(f), "--no-input-persistency", "--overwrite-output"],
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Needed for imx models.
|
||||
with open(f / "labels.txt", "w", encoding="utf-8") as file:
|
||||
file.writelines([f"{name}\n" for _, name in model.names.items()])
|
||||
|
||||
return f
|
||||
Reference in New Issue
Block a user