|
import os
import time
from logging import Logger
from os import path as osp
from pathlib import Path

import gradio as gr
import numpy as np
import onnx
import onnxruntime as ort
import spandrel_extra_arches
import torch
from onnx import ModelProto
from onnxconverter_common.float16 import convert_float_to_float16
from onnxslim import slim
from rich.traceback import install
from spandrel import ImageModelDescriptor, ModelLoader
from torch import Tensor

# NOTE(review): the original lines imported from "trainner-redux.…", which is a
# SyntaxError — hyphens are illegal in Python import paths. The importable
# package inside the traiNNer-redux repo is `traiNNer`; confirm sys.path setup.
from traiNNer.models.base_model import BaseModel
from traiNNer.utils.logger import clickable_file_path, get_root_logger
from traiNNer.utils.redux_options import ReduxOptions
|
|
|
def get_out_path(
    out_dir: str, name: str, opset: int, fp16: bool = False, optimized: bool = False
) -> str:
    """Build the normalized output path for an exported ONNX file.

    The filename encodes the precision (fp16/fp32), the opset version, and
    whether the model went through onnxslim, e.g. ``model_fp32_op17.onnx``.
    """
    precision = "16" if fp16 else "32"
    slim_tag = "_onnxslim" if optimized else ""
    filename = f"{name}_fp{precision}_op{opset}{slim_tag}.onnx"
    return osp.normpath(osp.join(out_dir, filename))
|
|
|
|
|
def convert_and_save_onnx(
    model: BaseModel,
    logger: Logger,
    opt: ReduxOptions,
    torch_input: Tensor,
    out_dir: str,
) -> tuple[ModelProto, int, str]:
    """Export ``model.net_g`` to an fp32 ONNX file and load it back.

    Returns:
        The loaded ``ModelProto``, the opset version used for export, and
        the path of the saved file.
    """
    assert model.net_g is not None
    assert opt.onnx is not None

    # With static shapes, let torch.onnx.export use default tensor names and
    # fixed dimensions; otherwise mark batch and spatial dims as dynamic.
    dynamic_axes = None
    input_names = None
    output_names = None
    if not opt.onnx.use_static_shapes:
        dynamic_axes = {
            "input": {0: "batch_size", 2: "width", 3: "height"},
            "output": {0: "batch_size", 2: "width", 3: "height"},
        }
        input_names = ["input"]
        output_names = ["output"]

    out_path = get_out_path(out_dir, opt.name, opt.onnx.opset, False)

    torch.onnx.export(
        model.net_g,
        (torch_input,),
        out_path,
        dynamo=opt.onnx.dynamo,
        verbose=False,
        opset_version=opt.onnx.opset,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
    )

    model_proto = onnx.load(out_path)
    assert model_proto is not None

    return model_proto, opt.onnx.opset, out_path
|
|
|
|
|
def verify_onnx(
    model: BaseModel, logger: Logger, torch_input: Tensor, onnx_path: str
) -> None:
    """Validate an exported ONNX model against the PyTorch network.

    Runs the ONNX structural checker, executes the model on CPU through
    onnxruntime, and compares the result to ``model.net_g`` within loose
    tolerances. A numerical mismatch is logged as a warning, not raised.
    """
    assert model.net_g is not None

    # Reference output from the original PyTorch network.
    with torch.inference_mode():
        reference = model.net_g(torch_input).cpu().numpy()

    # Structural validation of the serialized graph.
    onnx.checker.check_model(onnx.load(onnx_path))

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    actual = session.run(None, {input_name: torch_input.cpu().numpy()})[0]

    try:
        np.testing.assert_allclose(
            reference,
            actual,
            rtol=1e-02,
            atol=1e-03,
        )
    except AssertionError as e:
        logger.warning("ONNX verification completed with warnings: %s", e)
    else:
        logger.info("ONNX output verified against PyTorch output successfully.")
|
|
|
|
|
def convert_pipeline(
    model_path: str,
    opset: int = 17,
    verify: bool = True,
    optimize: bool = True,
    fp16: bool = False,
) -> str:
    """Convert a PyTorch checkpoint to ONNX, optionally slimming and halving.

    Loads the checkpoint with spandrel, exports it to an fp32 ONNX file under
    ``./onnx``, and optionally runs onnxslim + onnxruntime graph optimization
    and an fp16 conversion pass.

    Args:
        model_path: Path to the ``.pth``/``.pt``/``.safetensors`` checkpoint.
        opset: ONNX opset version to export with.
        verify: Compare ONNX output against PyTorch output after export.
        optimize: Run onnxslim and onnxruntime graph optimizations.
        fp16: Additionally save an fp16 copy of the (possibly optimized) model.

    Returns:
        Path to the final saved ONNX file.
    """
    # NOTE(review): return annotation was `-> None` but the function has always
    # returned `out_path`; fixed to `-> str`.
    loader = ModelLoader()
    model_desc = loader.load_from_file(model_path)
    assert isinstance(model_desc, ImageModelDescriptor)

    model = model_desc.model.to("cpu").eval()

    class OptONNX:
        """Ad-hoc stand-in for ``ReduxOptions.onnx`` expected downstream."""

        def __init__(self):
            self.opset = opset
            self.dynamo = False
            self.use_static_shapes = True
            self.verify = verify
            self.optimize = optimize
            self.fp16 = fp16

            sr = model_desc.size_requirements
            # NOTE(review): spandrel's SizeRequirements exposes `minimum` and
            # `multiple_of` in recent versions; `size`/`multiple` may never
            # exist, in which case the defaults (32, 1) always apply — confirm
            # against the installed spandrel version.
            min_size = getattr(sr, "size", 32)
            multiple = getattr(sr, "multiple", 1)

            # Round the minimum size up to the nearest required multiple.
            h = ((min_size + multiple - 1) // multiple) * multiple
            w = ((min_size + multiple - 1) // multiple) * multiple
            self.shape = f"1x{model_desc.input_channels}x{h}x{w}"

    class Opt:
        """Ad-hoc stand-in for ``ReduxOptions``."""

        def __init__(self):
            self.name = Path(model_path).stem
            self.onnx = OptONNX()

    opt = Opt()

    if opt.onnx.use_static_shapes:
        dims = tuple(map(int, opt.onnx.shape.split("x")))
        torch_input = torch.randn(*dims, device="cpu")
    else:
        torch_input = torch.randn(
            1, model_desc.input_channels or 3, 32, 32, device="cpu"
        )

    assert model is not None
    logger = get_root_logger()

    out_dir = "./onnx"
    os.makedirs(out_dir, exist_ok=True)

    class ModelWrapper:
        """Minimal BaseModel-like shim exposing only ``.net_g``."""

        def __init__(self, net_g):
            self.net_g = net_g

    wrapped_model = ModelWrapper(model)

    start_time = time.time()
    model_proto, opset, out_path_fp32 = convert_and_save_onnx(
        wrapped_model, logger, opt, torch_input, out_dir
    )
    out_path = out_path_fp32

    logger.info(
        "Saved to %s in %.2f seconds.",
        clickable_file_path(Path(out_path_fp32).absolute().parent, out_path_fp32),
        time.time() - start_time,
    )

    if opt.onnx.verify:
        verify_onnx(wrapped_model, logger, torch_input, out_path_fp32)

    if opt.onnx.optimize:
        # Slim the graph and persist it BEFORE handing the file to
        # onnxruntime. Previously the slimmed proto was discarded: the ORT
        # session was opened on the original (un-slimmed) fp32 file and
        # model_proto was then reloaded from ORT's output.
        model_proto = slim(model_proto)
        onnx.save(model_proto, out_path_fp32)

        session_opt = ort.SessionOptions()
        session_opt.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        )
        session_opt.optimized_model_filepath = get_out_path(
            out_dir, opt.name, opset, False, True
        )
        # Creating the session writes the ORT-optimized model to disk.
        ort.InferenceSession(out_path_fp32, session_opt)
        verify_onnx(
            wrapped_model, logger, torch_input, session_opt.optimized_model_filepath
        )
        model_proto = onnx.load(session_opt.optimized_model_filepath)
        out_path = session_opt.optimized_model_filepath

    if opt.onnx.fp16:
        start_time = time.time()
        out_path = get_out_path(out_dir, opt.name, opset, True, opt.onnx.optimize)
        model_proto_fp16 = convert_float_to_float16(model_proto)
        onnx.save(model_proto_fp16, out_path)
        logger.info(
            "Saved to %s in %.2f seconds.",
            clickable_file_path(Path(out_path).absolute().parent, out_path),
            time.time() - start_time,
        )

    return out_path
|
|
|
|
|
|
|
def process_choices(fp16, slim, file):
    """Gradio callback: convert the uploaded checkpoint and expose a download.

    Args:
        fp16: Radio value, the string ``"True"`` or ``"False"``.
        slim: Radio value, the string ``"True"`` or ``"False"`` (run onnxslim).
            NOTE(review): this parameter name shadows the module-level
            ``onnxslim.slim`` import; kept for interface compatibility.
        file: Path of the uploaded model file.

    Returns:
        An updated, visible DownloadButton pointing at the converted file.
    """
    # The radios deliver strings. The original code only mapped "False" ->
    # False and passed the truthy string "True" through where convert_pipeline
    # expects a bool; convert both values to real booleans instead.
    use_fp16 = fp16 == "True"
    use_slim = slim == "True"
    result = convert_pipeline(file, 17, True, use_slim, use_fp16)
    short_name = os.path.basename(result)
    return gr.DownloadButton(label=f"💾 {short_name}", value=result, visible=True)
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Built at module import time so `demo` exists both for the __main__ launch
# below and for hosts (e.g. Hugging Face Spaces) that import this module.
with gr.Blocks(title="PTH to ONNX Converter") as demo:

    # Rich tracebacks in the console; register spandrel's extra architectures
    # so additional model types can be loaded. Side effects only — placement
    # inside the Blocks context has no UI meaning.
    install()
    spandrel_extra_arches.install()

    # Static, read-only notice shown at the top of the page.
    gr.Textbox(value="This space is still under construction. Model outputs may not be correct or valid.", label="Notice", interactive=False)

    # Checkpoint upload, restricted to the formats spandrel can load.
    file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])

    # Conversion flags. Values arrive in the callback as the strings
    # "True"/"False", not booleans.
    radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16")
    radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim")

    process_button = gr.Button("Convert")

    # Hidden until a conversion succeeds; process_choices returns an updated,
    # visible button pointing at the converted file.
    file_output = gr.DownloadButton(label="💾 Download Converted Model", visible=False)

    process_button.click(fn=process_choices, inputs=[radio_fp16, radio_slim, file_upload], outputs=file_output, show_progress="full")


if __name__ == "__main__":
    # Local run: open a browser tab, surface errors in the UI, no API docs.
    demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)