|
import os
|
|
import gradio as gr
|
|
import time
|
|
import numpy as np
|
|
import onnx
|
|
import onnxruntime as ort
|
|
import torch
|
|
from onnxconverter_common.float16 import convert_float_to_float16
|
|
from onnxslim import slim
|
|
from spandrel import ModelLoader, ImageModelDescriptor
|
|
import spandrel_extra_arches
|
|
from rich.traceback import install
|
|
|
|
def get_out_path(out_dir: str, name: str, opset: int, fp16: bool = False, optimized: bool = False, static: bool = False) -> str:
    """Return the normalized output path for an exported ONNX file.

    The filename encodes the export settings in a fixed order:
    ``{name}_fp{16|32}[_static]_op{opset}[_onnxslim].onnx``.

    Args:
        out_dir: Directory the file will be written to.
        name: Base model name (no extension).
        opset: ONNX opset version used for the export.
        fp16: Tag the file as half precision instead of full precision.
        optimized: Append the ``_onnxslim`` suffix.
        static: Append the ``_static`` suffix.

    Returns:
        ``out_dir`` joined with the generated filename, passed through
        ``os.path.normpath``.
    """
    precision = "16" if fp16 else "32"

    pieces = [f"{name}_fp{precision}"]
    if static:
        pieces.append("_static")
    pieces.append(f"_op{opset}")
    if optimized:
        pieces.append("_onnxslim")
    pieces.append(".onnx")

    joined = os.path.join(out_dir, "".join(pieces))
    return os.path.normpath(joined)
|
|
|
|
def convert_and_save_onnx(model, name: str, torch_input, out_dir: str, opset: int, use_static_shapes: bool) -> tuple[onnx.ModelProto, str]:
    """Export ``model`` to an FP32 ONNX file and load the result back.

    Args:
        model: The PyTorch module to export.
        name: Base model name used to build the output filename.
        torch_input: Example input tensor handed to the exporter.
        out_dir: Directory the ``.onnx`` file is written to.
        opset: ONNX opset version to target.
        use_static_shapes: When true, export without dynamic axes or explicit
            input/output names; otherwise batch, height and width are marked
            dynamic on tensors named ``"input"`` and ``"output"``.

    Returns:
        A ``(model_proto, out_path)`` tuple: the exported model re-loaded with
        ``onnx.load`` and the path it was saved to.
    """
    if use_static_shapes:
        dynamic_axes = None
        input_names = None
        output_names = None
    else:
        # The example input is built as (batch, channels, height, width) by
        # convert_pipeline, so axis 2 is height and axis 3 is width.
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }
        input_names = ["input"]
        output_names = ["output"]

    # The raw export is always FP32 and never carries the optimized suffix.
    out_path = get_out_path(out_dir, name, opset, False, False, use_static_shapes)

    class FakeModel(torch.nn.Module):
        """Thin ``nn.Module`` wrapper whose forward just calls the wrapped model."""

        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.model = model

        def forward(self, x: torch.Tensor):
            return self.model(x)

    model = FakeModel(model)

    torch.onnx.export(
        model,
        (torch_input,),
        out_path,
        dynamo=False,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
    )

    model_proto = onnx.load(out_path)
    return model_proto, out_path
|
|
|
|
def verify_onnx(model, torch_input, onnx_path: str) -> None:
    """Check an exported ONNX file against the PyTorch model it came from.

    The PyTorch model is run on ``torch_input`` under ``torch.inference_mode``,
    the ONNX file is validated with ``onnx.checker.check_model``, and the file
    is then executed with ONNX Runtime on the CPU provider. The two outputs
    are compared with ``rtol=1e-02`` / ``atol=1e-03``.

    A numerical mismatch or a runtime failure is printed and surfaced as a
    Gradio warning rather than raised. Errors from the PyTorch forward pass or
    the ONNX checker happen outside the ``try`` and propagate to the caller.

    Args:
        model: Callable PyTorch model producing the reference output.
        torch_input: Input tensor fed to both the PyTorch model and ONNX Runtime.
        onnx_path: Path of the ONNX file to verify.
    """
    with torch.inference_mode():
        reference = model(torch_input).cpu().numpy()

    onnx.checker.check_model(onnx.load(onnx_path))

    try:
        session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        first_input_name = session.get_inputs()[0].name
        feed = {first_input_name: torch_input.cpu().numpy()}
        ort_outputs = session.run(None, feed)
        np.testing.assert_allclose(reference, ort_outputs[0], rtol=1e-02, atol=1e-03)
        print("ONNX output verified against PyTorch output successfully.")
    except AssertionError as mismatch:
        # Outputs differ beyond tolerance: report it but keep the export.
        print(f"ONNX verification completed with warnings: {mismatch}")
        gr.Warning("ONNX verification completed with warnings")
    except Exception as failure:
        # Anything else (session creation, inference, ...) is also non-fatal.
        print(f"ONNX verification failed: {failure}")
        gr.Warning("ONNX verification failed")
|
|
|
|
|
|
def convert_pipeline(model_path: str, opset: int = 17, verify: bool = True, optimize: bool = True, fp16: bool = False, static: bool = False) -> str:
    """Convert a PyTorch image model file to ONNX and return the final file path.

    Stages, in order:

    1. Load the model with spandrel and move it to CPU in eval mode.
    2. Export an FP32 ONNX file into ``./onnx`` (created if missing).
    3. Optionally verify the export against the PyTorch model.
    4. Optionally run ONNX Runtime's graph optimizer (``ORT_ENABLE_EXTENDED``)
       and save the optimized model; the file carries the ``_onnxslim`` suffix
       produced by ``get_out_path``.
    5. Optionally convert the most recent FP32 model to FP16 and save it.

    Args:
        model_path: Path to the model file to load.
        opset: ONNX opset version for the export.
        verify: Verify the FP32 export (and the optimized model, if produced).
        optimize: Produce an ONNX Runtime optimized copy of the FP32 export.
        fp16: Additionally write a half-precision copy.
        static: Export with static shapes using a 256x256 example input;
            otherwise a 32x32 example input with dynamic axes is used.

    Returns:
        Path of the last file written (FP16 if requested, else optimized if
        requested, else the plain FP32 export).

    Raises:
        AssertionError: If the loaded model is not an ``ImageModelDescriptor``.
    """
    loader = ModelLoader()
    model_desc = loader.load_from_file(model_path)
    assert isinstance(model_desc, ImageModelDescriptor)

    model = model_desc.model.to("cpu").eval()
    model_name = os.path.splitext(os.path.basename(model_path))[0]

    # Example input is (batch, channels, height, width).
    if static:
        height, width = 256, 256
        torch_input = torch.randn(1, model_desc.input_channels, height, width, device="cpu")
    else:
        torch_input = torch.randn(1, model_desc.input_channels, 32, 32, device="cpu")

    out_dir = "./onnx"
    os.makedirs(out_dir, exist_ok=True)

    # Stage 2: plain FP32 export.
    start_time = time.time()
    model_proto, out_path_fp32 = convert_and_save_onnx(
        model, model_name, torch_input, out_dir, opset, static
    )
    out_path = out_path_fp32

    print(f"Saved to {out_path_fp32} in {time.time() - start_time:.2f} seconds.")

    if verify:
        verify_onnx(model, torch_input, out_path_fp32)

    # Stage 4: creating the session is what triggers ONNX Runtime to write the
    # optimized graph to ``optimized_model_filepath``; the session itself is unused.
    if optimize:
        session_opt = ort.SessionOptions()
        session_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        session_opt.optimized_model_filepath = get_out_path(out_dir, model_name, opset, False, True, static)
        ort.InferenceSession(out_path_fp32, session_opt)
        if verify:
            verify_onnx(model, torch_input, session_opt.optimized_model_filepath)
        # Later stages (FP16) start from the optimized graph.
        model_proto = onnx.load(session_opt.optimized_model_filepath)
        out_path = session_opt.optimized_model_filepath

    # Stage 5: half-precision copy of whichever FP32 graph is current.
    if fp16:
        start_time = time.time()
        out_path = get_out_path(out_dir, model_name, opset, True, optimize, static)
        model_proto_fp16 = convert_float_to_float16(model_proto)
        onnx.save(model_proto_fp16, out_path)
        # Report the FP16 file that was just written, not the FP32 export.
        print(f"Saved to {out_path} in {time.time() - start_time:.2f} seconds.")

    return out_path
|
|
|
|
def _fp16_radio(**overrides):
    """Build the FP16 radio update shared by every ``load_model`` outcome."""
    return gr.Radio(
        choices=["True", "False"],
        label="FP16 - Export at half precision. Not supported by all models.",
        **overrides,
    )


def load_model(model_path: str):
    """Inspect an uploaded model and report its metadata to the UI.

    This handler is wired to two outputs (the metadata textbox and the FP16
    radio), so every branch returns a two-element list.

    Args:
        model_path: Path of the uploaded model file; may be empty/``None``.

    Returns:
        ``[text, radio]`` where ``text`` is ``"Ready"`` when no file is given,
        the stringified architecture info on success, or an error message on
        failure. ``radio`` is locked to ``"False"`` when the model does not
        support half precision and is interactive otherwise.
    """
    if not model_path:
        # Return both outputs here too; a bare string would not match the
        # two outputs this handler is registered with.
        return ["Ready", _fp16_radio(interactive=True)]

    loader = ModelLoader()
    try:
        model = loader.load_from_file(model_path)
        # Explicit check instead of ``assert``: it survives ``python -O`` and
        # gives the user a meaningful message in the metadata box.
        if not isinstance(model, ImageModelDescriptor):
            raise TypeError(
                f"expected an image model, got {type(model).__name__}"
            )

        architecture_info = {
            'architecture_name': getattr(model.architecture, 'name', str(model.architecture)),
            'input_channels': model.input_channels,
            'output_channels': model.output_channels,
            'scale': model.scale,
            'tags': model.tags,
            'supports_fp16': model.supports_half,
        }

        if model.supports_half:
            return [str(architecture_info), _fp16_radio(interactive=True)]
        # No half-precision support: force the choice to "False" and lock it.
        return [str(architecture_info), _fp16_radio(value="False", interactive=False)]
    except Exception as e:
        # UI boundary: report the failure in the textbox instead of raising.
        return [f"Error loading model: {e}", _fp16_radio(interactive=True)]
|
|
|
|
def process_choices(opset, fp16, static, slim, file):
    """Run the conversion for the current UI selections, streaming button states.

    This is a generator: each yielded value is a two-element list holding the
    new state of the convert button and of the download button.

    Args:
        opset: Selected ONNX opset version.
        fp16: ``"True"``/``"False"`` string from the FP16 radio.
        static: ``"True"``/``"False"`` string from the static-shapes radio.
        slim: ``"True"``/``"False"`` string from the optimization radio.
        file: Path of the uploaded model file; falsy when nothing is loaded.

    Yields:
        A "Processing" state while the conversion runs, then either a visible
        download button pointing at the converted file or, on missing file or
        failure, the idle state with the download button hidden.
    """

    def idle_state():
        # Convert button re-enabled, download button hidden.
        return [
            gr.Button("Convert", interactive=True),
            gr.DownloadButton(label="💾 Download Converted Model", visible=False),
        ]

    if not file:
        print("No file loaded.")
        gr.Warning("No file loaded.")
        yield idle_state()
        return

    # The radios deliver strings; turn them into real booleans.
    use_fp16, use_static, use_slim = (choice == "True" for choice in (fp16, static, slim))

    yield [
        gr.Button("Processing", interactive=False),
        gr.DownloadButton(label="💾 Download Converted Model", visible=False),
    ]

    try:
        converted_path = convert_pipeline(file, opset, True, use_slim, use_fp16, use_static)
        display_name = os.path.basename(converted_path)
        yield [
            gr.Button("Convert", interactive=True),
            gr.DownloadButton(label=f"💾 {display_name}", value=converted_path, visible=True),
        ]
    except Exception as error:
        print(f"{error}")
        gr.Warning("Conversion error.")
        yield idle_state()
|
|
|
|
|
|
# Gradio UI: upload a model, pick export options, convert, then download.
with gr.Blocks(title="PTH to ONNX Converter") as demo:
    # Enable rich tracebacks; presumably the second call registers the extra
    # spandrel architectures -- confirm against spandrel_extra_arches.
    install()
    spandrel_extra_arches.install()

    # --- Model input and its metadata readout ---
    file_upload = gr.File(
        label="Upload a PyTorch model",
        file_types=['.pth', '.pt', '.safetensors'],
    )
    metadata = gr.Textbox(
        value="Ready",
        label="File Information",
    )

    # --- Export options (radios carry "True"/"False" strings) ---
    dropdown_opset = gr.Dropdown(
        choices=[17, 18, 19, 20],
        value=20,
        label="Opset",
    )
    radio_fp16 = gr.Radio(
        choices=["True", "False"],
        value="False",
        label="FP16 - Not supported by all models. Not very useful because FP16 TRT engines can still be built from FP32 ONNX models.",
    )
    radio_static = gr.Radio(
        choices=["True", "False"],
        value="False",
        label="Static Shapes - Might be required by some models, but can cause slower performance.",
    )
    # Hidden option; it still feeds the handler with its default "False".
    radio_slim = gr.Radio(
        choices=["True", "False"],
        value="False",
        visible=False,
        label="OnnxSlim - Can cause issues in some models. I have not yet found any cases where it helps. May remove in the future.",
    )

    # --- Actions ---
    gr.Markdown("After converting, click the logs button at the top to check for any errors or warnings.")
    process_button = gr.Button("Convert", interactive=True)
    file_output = gr.DownloadButton(
        label="💾 Download Converted Model",
        visible=False,
    )

    gr.Markdown("""
    # Resources
    - [OpenModelDB](https://openmodeldb.info): Find upscaling models here
    - [VideoJaNai](https://github.com/the-database/VideoJaNai): For upscaling videos using ONNX models
    - [REAL Video Enhancer](https://github.com/TNTwise/REAL-Video-Enhancer): For upscaling videos using ONNX models
    """)

    # --- Event wiring ---
    # Convert: streams button/download states from the process_choices generator.
    process_button.click(
        fn=process_choices,
        inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload],
        outputs=[process_button, file_output],
    )
    # Upload: refresh the metadata box and the FP16 radio for the new file.
    file_upload.upload(
        fn=load_model,
        inputs=file_upload,
        outputs=[metadata, radio_fp16],
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: start the Gradio app with the launch flags below.
    demo.launch(
        show_error=True,
        inbrowser=True,
        show_api=False,
        debug=False,
    )