"""Gradio app that converts PyTorch image-model checkpoints to ONNX.

Checkpoints are loaded with spandrel and exported with ``torch.onnx.export``.
The export is optionally verified against PyTorch using ONNX Runtime (CPU),
optionally optimized (onnxslim followed by ONNX Runtime graph optimizations)
and optionally converted to FP16.
"""

import os
import sys
import time
from logging import Logger
from os import path as osp
from pathlib import Path
from types import SimpleNamespace

import gradio as gr
import numpy as np
import onnx
import onnxruntime as ort
import spandrel_extra_arches
import torch
from onnx import ModelProto
from onnxconverter_common.float16 import convert_float_to_float16
from onnxslim import slim
from rich.traceback import install
from spandrel import ImageModelDescriptor, ModelLoader
from torch import Tensor

# `trainner-redux` is not a valid Python identifier (hyphen), so
# `from trainner-redux.traiNNer...` is a SyntaxError. Put the checkout on
# sys.path and import its top-level `traiNNer` package instead.
# NOTE(review): assumes the checkout directory sits next to this script — confirm.
_TRAINNER_ROOT = osp.join(osp.dirname(osp.abspath(__file__)), "trainner-redux")
if _TRAINNER_ROOT not in sys.path:
    sys.path.insert(0, _TRAINNER_ROOT)

from traiNNer.models.base_model import BaseModel  # noqa: E402
from traiNNer.utils.logger import clickable_file_path, get_root_logger  # noqa: E402
from traiNNer.utils.redux_options import ReduxOptions  # noqa: E402

# Smallest side length of the dummy export input, used unless the model's
# size requirements demand something larger.
_BASE_INPUT_SIZE = 32


def get_out_path(
    out_dir: str, name: str, opset: int, fp16: bool = False, optimized: bool = False
) -> str:
    """Return the normalized output path for an exported ONNX file.

    The filename is ``<name>_fp16|fp32_op<opset>[_onnxslim].onnx`` inside
    *out_dir*.
    """
    precision = "16" if fp16 else "32"
    suffix = "_onnxslim" if optimized else ""
    filename = f"{name}_fp{precision}_op{opset}{suffix}.onnx"
    return osp.normpath(osp.join(out_dir, filename))


def convert_and_save_onnx(
    model: BaseModel,
    logger: Logger,
    opt: ReduxOptions,
    torch_input: Tensor,
    out_dir: str,
) -> tuple[ModelProto, int, str]:
    """Export ``model.net_g`` to an fp32 ONNX file and load it back.

    Args:
        model: Object exposing the network to export as ``net_g``.
        logger: Unused; kept for interface compatibility.
        opt: Options object; reads ``opt.name`` and ``opt.onnx.{opset,
            dynamo, use_static_shapes}``.
        torch_input: Example input used to trace the network.
        out_dir: Directory the ``.onnx`` file is written to.

    Returns:
        ``(model_proto, opset, out_path)`` for the written fp32 model.
    """
    assert model.net_g is not None
    assert opt.onnx is not None

    if opt.onnx.use_static_shapes:
        # Without dynamic_axes the exported graph is fixed to torch_input's shape.
        dynamic_axes = None
        input_names = None
        output_names = None
    else:
        # Tensors are laid out as (batch, channels, height, width).
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }
        input_names = ["input"]
        output_names = ["output"]

    out_path = get_out_path(out_dir, opt.name, opt.onnx.opset, False)

    torch.onnx.export(
        model.net_g,
        (torch_input,),
        out_path,
        dynamo=opt.onnx.dynamo,
        verbose=False,
        opset_version=opt.onnx.opset,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
    )

    model_proto = onnx.load(out_path)
    return model_proto, opt.onnx.opset, out_path


def verify_onnx(
    model: BaseModel, logger: Logger, torch_input: Tensor, onnx_path: str
) -> None:
    """Check the ONNX file and compare its CPU output with PyTorch's.

    A numerical mismatch (rtol=1e-2, atol=1e-3) is logged as a warning rather
    than raised; a structurally invalid model still raises from the checker.
    """
    assert model.net_g is not None
    with torch.inference_mode():
        torch_output_np = model.net_g(torch_input).cpu().numpy()

    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_inputs = {ort_session.get_inputs()[0].name: torch_input.cpu().numpy()}
    onnx_output = ort_session.run(None, ort_inputs)

    try:
        np.testing.assert_allclose(  # pyright: ignore # TODO onnx 1.18
            torch_output_np,
            onnx_output[0],  # pyright: ignore # TODO onnx 1.18
            rtol=1e-02,
            atol=1e-03,
        )
        logger.info("ONNX output verified against PyTorch output successfully.")
    except AssertionError as e:
        logger.warning("ONNX verification completed with warnings: %s", e)


def _make_dummy_input(model_desc: ImageModelDescriptor) -> Tensor:
    """Build a random square example input satisfying the model's size requirements."""
    sr = model_desc.size_requirements
    # spandrel's SizeRequirements exposes `minimum` and `multiple_of`.
    multiple = max(1, getattr(sr, "multiple_of", 1))
    side = max(_BASE_INPUT_SIZE, getattr(sr, "minimum", 0))
    side = ((side + multiple - 1) // multiple) * multiple  # round up to a multiple
    channels = model_desc.input_channels or 3
    return torch.randn(1, channels, side, side, device="cpu")


def convert_pipeline(
    model_path: str,
    opset: int = 17,
    verify: bool = True,
    optimize: bool = True,
    fp16: bool = False,
) -> str:
    """Convert a spandrel-loadable checkpoint to ONNX files under ``./onnx``.

    Args:
        model_path: Path to the checkpoint file.
        opset: ONNX opset version used for the export.
        verify: Compare ONNX Runtime output against PyTorch after export
            (and after optimization, when enabled).
        optimize: Run onnxslim, then ONNX Runtime extended graph
            optimizations, and save the result as the ``_onnxslim`` file.
        fp16: Additionally write an FP16 copy of the final model.

    Returns:
        Path of the last file written (fp32, optimized, or fp16).

    Raises:
        ValueError: If the checkpoint is not an ``ImageModelDescriptor``.
    """
    model_desc = ModelLoader().load_from_file(model_path)
    if not isinstance(model_desc, ImageModelDescriptor):
        raise ValueError(
            f"{model_path!r} is not a supported image model "
            f"(got {type(model_desc).__name__})."
        )

    # convert_and_save_onnx / verify_onnx expect an object exposing `.net_g`.
    wrapped_model = SimpleNamespace(net_g=model_desc.model.to("cpu").eval())

    # Minimal stand-in for traiNNer's ReduxOptions; only the attributes read
    # by convert_and_save_onnx are provided.
    # NOTE: static shapes mean the exported graph only accepts inputs of
    # exactly the dummy input's size.
    opt = SimpleNamespace(
        name=Path(model_path).stem,
        onnx=SimpleNamespace(
            opset=opset,
            dynamo=False,
            use_static_shapes=True,
            verify=verify,
            optimize=optimize,
            fp16=fp16,
        ),
    )
    torch_input = _make_dummy_input(model_desc)

    logger = get_root_logger()
    out_dir = "./onnx"
    os.makedirs(out_dir, exist_ok=True)

    start_time = time.time()
    model_proto, opset, out_path_fp32 = convert_and_save_onnx(
        wrapped_model, logger, opt, torch_input, out_dir
    )
    out_path = out_path_fp32
    logger.info(
        "Saved to %s in %.2f seconds.",
        clickable_file_path(Path(out_path_fp32).absolute().parent, out_path_fp32),
        time.time() - start_time,
    )

    if opt.onnx.verify:
        verify_onnx(wrapped_model, logger, torch_input, out_path_fp32)

    if opt.onnx.optimize:
        optimized_path = get_out_path(out_dir, opt.name, opset, False, True)
        # Feed the *slimmed* graph to ONNX Runtime (as serialized bytes) so the
        # onnxslim result is not discarded; ORT writes its optimized graph to
        # optimized_path.
        slimmed_proto = slim(model_proto)
        session_opt = ort.SessionOptions()
        session_opt.graph_optimization_level = (
            ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        )
        session_opt.optimized_model_filepath = optimized_path
        ort.InferenceSession(
            slimmed_proto.SerializeToString(),
            session_opt,
            providers=["CPUExecutionProvider"],
        )
        if opt.onnx.verify:
            verify_onnx(wrapped_model, logger, torch_input, optimized_path)
        model_proto = onnx.load(optimized_path)
        out_path = optimized_path

    if opt.onnx.fp16:
        start_time = time.time()
        out_path = get_out_path(out_dir, opt.name, opset, True, opt.onnx.optimize)
        model_proto_fp16 = convert_float_to_float16(model_proto)
        onnx.save(model_proto_fp16, out_path)
        logger.info(
            "Saved to %s in %.2f seconds.",
            clickable_file_path(Path(out_path).absolute().parent, out_path),
            time.time() - start_time,
        )

    return out_path


def process_choices(fp16, slim, file):
    """Gradio callback: convert the uploaded checkpoint and show a download button.

    Args:
        fp16: ``"True"``/``"False"`` string from the FP16 radio.
        slim: ``"True"``/``"False"`` string from the OnnxSlim radio. The name
            shadows ``onnxslim.slim`` only inside this function.
        file: The uploaded file value, or ``None`` when nothing was uploaded
            (presumably a filepath string, gr.File's default in Gradio 4+).

    Raises:
        gr.Error: If no file was uploaded.
    """
    if file is None:
        raise gr.Error("Please upload a PyTorch model first.")
    result = convert_pipeline(file, 17, True, slim == "True", fp16 == "True")
    short_name = os.path.basename(result)
    return gr.DownloadButton(label=f"💾 {short_name}", value=result, visible=True)


# Pretty tracebacks, and register spandrel's extra architectures before any
# model is loaded.
install()
spandrel_extra_arches.install()

# Create a Gradio Blocks interface.
with gr.Blocks(title="PTH to ONNX Converter") as demo:
    gr.Textbox(value="This space is still under construction. Model outputs may not be correct or valid.", label="Notice", interactive=False)

    # File upload component at the top.
    file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])

    # Two yes/no radio groups; values arrive in the callback as strings.
    radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16")
    radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim")

    # Button that triggers the conversion.
    process_button = gr.Button("Convert")

    # Hidden until a conversion finishes; process_choices returns a visible replacement.
    file_output = gr.DownloadButton(label="💾 Download Converted Model", visible=False)

    process_button.click(fn=process_choices, inputs=[radio_fp16, radio_slim, file_upload], outputs=file_output, show_progress="full")

# Launch the interface.
if __name__ == "__main__":
    demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)