import os
import gradio as gr
import time
import numpy as np
import onnx
import onnxruntime as ort
import torch
from onnxconverter_common.float16 import convert_float_to_float16
from onnxslim import slim
from spandrel import ModelLoader, ImageModelDescriptor
import spandrel_extra_arches
from rich.traceback import install


def get_out_path(out_dir: str, name: str, opset: int, fp16: bool = False, optimized: bool = False, static: bool = False) -> str:
    """Build the output filename from the chosen export options."""
    filename = f"{name}_fp{'16' if fp16 else '32'}{'_static' if static else ''}_op{opset}{'_onnxslim' if optimized else ''}.onnx"
    return os.path.normpath(os.path.join(out_dir, filename))


def convert_and_save_onnx(model, name: str, torch_input, out_dir: str, opset: int, use_static_shapes: bool) -> tuple[onnx.ModelProto, str]:
    """Export the model to ONNX and return the loaded proto plus its path."""
    if use_static_shapes:
        dynamic_axes = None
        input_names = None
        output_names = None
        # input_names = ["input"]
        # output_names = ["output"]
    else:
        # NCHW layout: dim 2 is height, dim 3 is width.
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }
        input_names = ["input"]
        output_names = ["output"]
    out_path = get_out_path(out_dir, name, opset, False, False, use_static_shapes)

    # This wrapper class was taken from chaiNNer. Running the model through it
    # seems to fix export issues with various architectures.
    class FakeModel(torch.nn.Module):
        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.model = model

        def forward(self, x: torch.Tensor):
            return self.model(x)

    model = FakeModel(model)
    torch.onnx.export(
        model,
        (torch_input,),
        out_path,
        dynamo=False,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
    )
    model_proto = onnx.load(out_path)
    return model_proto, out_path


def verify_onnx(model, torch_input, onnx_path: str) -> None:
    """Check the exported graph and compare ONNX Runtime output against PyTorch."""
    with torch.inference_mode():
        torch_output_np = model(torch_input).cpu().numpy()
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    try:
        ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        ort_inputs = {ort_session.get_inputs()[0].name: torch_input.cpu().numpy()}
        onnx_output = ort_session.run(None, ort_inputs)
        np.testing.assert_allclose(
            torch_output_np,
            onnx_output[0],
            rtol=1e-02,
            atol=1e-03,
        )
        print("ONNX output verified against PyTorch output successfully.")
    except AssertionError as e:
        print(f"ONNX verification completed with warnings: {e}")
        gr.Warning("ONNX verification completed with warnings")
    except Exception as e:
        print(f"ONNX verification failed: {e}")
        gr.Warning("ONNX verification failed")


def convert_pipeline(model_path: str, opset: int = 17, verify: bool = True, optimize: bool = True, fp16: bool = False, static: bool = False) -> str:
    """Load a PyTorch model with spandrel and run the full export pipeline."""
    loader = ModelLoader()
    model_desc = loader.load_from_file(model_path)
    assert isinstance(model_desc, ImageModelDescriptor)
    model = model_desc.model.to("cpu").eval()
    model_name = os.path.splitext(os.path.basename(model_path))[0]

    # Generate a dummy input for tracing
    if static:
        height, width = 256, 256
        torch_input = torch.randn(1, model_desc.input_channels, height, width, device="cpu")
    else:
        torch_input = torch.randn(1, model_desc.input_channels, 32, 32, device="cpu")

    out_dir = "./onnx"
    os.makedirs(out_dir, exist_ok=True)

    # Convert to ONNX
    start_time = time.time()
    model_proto, out_path_fp32 = convert_and_save_onnx(
        model, model_name, torch_input, out_dir, opset, static
    )
    out_path = out_path_fp32
    print(f"Saved to {out_path_fp32} in {time.time() - start_time:.2f} seconds.")

    # Verify
    if verify:
        verify_onnx(model, torch_input, out_path_fp32)

    # Optimize
    if optimize:
        # model_proto = slim(model_proto)
        session_opt = ort.SessionOptions()
        session_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        session_opt.optimized_model_filepath = get_out_path(out_dir, model_name, opset, False, True, static)
        # Creating the session writes the optimized graph to optimized_model_filepath.
        ort.InferenceSession(out_path_fp32, session_opt)
        if verify:
            verify_onnx(model, torch_input, session_opt.optimized_model_filepath)
        model_proto = onnx.load(session_opt.optimized_model_filepath)
        out_path = session_opt.optimized_model_filepath

    # Convert to FP16
    if fp16:
        start_time = time.time()
        out_path = get_out_path(out_dir, model_name, opset, True, optimize, static)
        model_proto_fp16 = convert_float_to_float16(model_proto)
        onnx.save(model_proto_fp16, out_path)
        print(f"Saved to {out_path} in {time.time() - start_time:.2f} seconds.")

    return out_path


def load_model(model_path: str):
    """Inspect an uploaded model and update the FP16 radio to match its capabilities."""
    fp16_label = "FP16 - Export at half precision. Not supported by all models."
    if not model_path:
        return ["Ready", gr.Radio(choices=["True", "False"], value="False", interactive=True, label=fp16_label)]
    loader = ModelLoader()
    try:
        model = loader.load_from_file(model_path)
        assert isinstance(model, ImageModelDescriptor)
        architecture_info = {
            'architecture_name': getattr(model.architecture, 'name', str(model.architecture)),
            'input_channels': model.input_channels,
            'output_channels': model.output_channels,
            'scale': model.scale,
            'tags': model.tags,
            'supports_fp16': model.supports_half,
            # 'supports_bf16': model.supports_bfloat16,
            # 'size_requirements': model.size_requirements
        }
        if model.supports_half:
            return [str(architecture_info), gr.Radio(choices=["True", "False"], interactive=True, label=fp16_label)]
        else:
            return [str(architecture_info), gr.Radio(choices=["True", "False"], value="False", interactive=False, label=fp16_label)]
    except Exception as e:
        return [f"Error loading model: {e}", gr.Radio(choices=["True", "False"], interactive=True, label=fp16_label)]


def process_choices(opset, fp16, static, slim, file):
    """Run the conversion for the chosen options and update the UI as it progresses."""
    if not file:
        print("No file loaded.")
        gr.Warning("No file loaded.")
        yield [gr.Button("Convert", interactive=True), gr.DownloadButton(label="💾 Download Converted Model", visible=False)]
        return

    # Convert string choices to booleans
    fp16 = fp16 == "True"
    static = static == "True"
    slim = slim == "True"

    yield [gr.Button("Processing", interactive=False), gr.DownloadButton(label="💾 Download Converted Model", visible=False)]
    try:
        result = convert_pipeline(file, opset, verify=True, optimize=slim, fp16=fp16, static=static)
        short_name = os.path.basename(result)
        yield [gr.Button("Convert", interactive=True), gr.DownloadButton(label=f"💾 {short_name}", value=result, visible=True)]
    except Exception as e:
        print(f"Conversion error: {e}")
        gr.Warning("Conversion error.")
        yield [gr.Button("Convert", interactive=True), gr.DownloadButton(label="💾 Download Converted Model", visible=False)]


# Create Gradio interface
with gr.Blocks(title="PTH to ONNX Converter") as demo:
    install()
    spandrel_extra_arches.install()
    file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
    metadata = gr.Textbox(value="Ready", label="File Information")
    dropdown_opset = gr.Dropdown(choices=[17, 18, 19, 20], value=20, label="Opset")
    radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16 - Not supported by all models. Not very useful because FP16 TRT engines can still be built from FP32 ONNX models.")
    radio_static = gr.Radio(choices=["True", "False"], value="False", label="Static Shapes - Might be required by some models, but can cause slower performance.")
    radio_slim = gr.Radio(choices=["True", "False"], value="False", visible=False, label="OnnxSlim - Can cause issues in some models. I have not yet found any cases where it helps. May remove in the future.")  # turned off for now
    gr.Markdown("After converting, click the logs button at the top to check for any errors or warnings.")
    process_button = gr.Button("Convert", interactive=True)
    file_output = gr.DownloadButton(label="💾 Download Converted Model", visible=False)
    gr.Markdown("""
    # Resources
    - [OpenModelDB](https://openmodeldb.info): Find upscaling models here
    - [VideoJaNai](https://github.com/the-database/VideoJaNai): For upscaling videos using ONNX models
    - [REAL Video Enhancer](https://github.com/TNTwise/REAL-Video-Enhancer): For upscaling videos using ONNX models
    """)

    process_button.click(fn=process_choices, inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload], outputs=[process_button, file_output])
    file_upload.upload(fn=load_model, inputs=file_upload, outputs=[metadata, radio_fp16])

if __name__ == "__main__":
    demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)
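
# Programmatic use (a minimal sketch; "model.pth" is a hypothetical path to any
# spandrel-supported checkpoint):
#
#     spandrel_extra_arches.install()  # register the extra architectures first
#     out_path = convert_pipeline("model.pth", opset=20, verify=True,
#                                 optimize=False, fp16=False, static=False)
#     print(out_path)  # e.g. onnx/model_fp32_op20.onnx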