# sisr2onnx / app.py
import os
import gradio as gr
import time
import numpy as np
import onnx
import onnxruntime as ort
import torch
from onnxconverter_common.float16 import convert_float_to_float16
from onnxslim import slim  # currently unused: the onnxslim pass is disabled below
from spandrel import ModelLoader, ImageModelDescriptor
import spandrel_extra_arches
from rich.traceback import install
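
# spandrel loads single-image super-resolution checkpoints into architecture
# descriptors; spandrel_extra_arches registers additional architectures beyond
# the core package.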
def get_out_path(out_dir: str, name: str, opset: int, fp16: bool = False, optimized: bool = False, static: bool = False) -> str:
    """Build the output .onnx path from the model name and export options."""
    filename = f"{name}_fp{'16' if fp16 else '32'}{'_static' if static else ''}_op{opset}{'_onnxslim' if optimized else ''}.onnx"
    return os.path.normpath(os.path.join(out_dir, filename))
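
# Illustrative examples (hypothetical model name):
#   get_out_path("./onnx", "4x_model", 17)                         -> "onnx/4x_model_fp32_op17.onnx"
#   get_out_path("./onnx", "4x_model", 17, fp16=True, static=True) -> "onnx/4x_model_fp16_static_op17.onnx"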
def convert_and_save_onnx(model, name: str, torch_input, out_dir: str, opset: int, use_static_shapes: bool) -> tuple[onnx.ModelProto, str]:
    """Export a PyTorch model to FP32 ONNX and return the loaded proto plus its path."""
    if use_static_shapes:
        dynamic_axes = None
        input_names = None
        output_names = None
    else:
        # NCHW layout: dim 2 is height and dim 3 is width.
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }
        input_names = ["input"]
        output_names = ["output"]
    out_path = get_out_path(out_dir, name, opset, False, False, use_static_shapes)

    # This wrapper class was taken from chaiNNer. Running the model through it
    # seems to fix export issues with various arches.
    class FakeModel(torch.nn.Module):
        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.model = model

        def forward(self, x: torch.Tensor):
            return self.model(x)

    model = FakeModel(model)
    torch.onnx.export(
        model,
        (torch_input,),
        out_path,
        dynamo=False,
        verbose=False,
        opset_version=opset,
        dynamic_axes=dynamic_axes,
        input_names=input_names,
        output_names=output_names,
    )
    model_proto = onnx.load(out_path)
    return model_proto, out_path
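
# Illustrative usage (hypothetical names): a dynamic-shape export accepts any
# input resolution at inference time, e.g.
#   proto, path = convert_and_save_onnx(model, "4x_model", torch.randn(1, 3, 32, 32), "./onnx", 17, False)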
def verify_onnx(model, torch_input, onnx_path: str) -> None:
    """Check the exported graph and compare ONNX Runtime output against PyTorch."""
    with torch.inference_mode():
        torch_output_np = model(torch_input).cpu().numpy()
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    try:
        ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
        ort_inputs = {ort_session.get_inputs()[0].name: torch_input.cpu().numpy()}
        onnx_output = ort_session.run(None, ort_inputs)
        np.testing.assert_allclose(
            torch_output_np,
            onnx_output[0],
            rtol=1e-02,
            atol=1e-03,
        )
        print("ONNX output verified against PyTorch output successfully.")
    except AssertionError as e:
        print(f"ONNX verification completed with warnings: {e}")
        gr.Warning("ONNX verification completed with warnings")
    except Exception as e:
        print(f"ONNX verification failed: {e}")
        gr.Warning("ONNX verification failed")
def convert_pipeline(model_path: str, opset: int = 17, verify: bool = True, optimize: bool = True, fp16: bool = False, static: bool = False) -> str:
    """Load a model with spandrel, export it to ONNX, then optionally verify, optimize, and convert to FP16."""
    loader = ModelLoader()
    model_desc = loader.load_from_file(model_path)
    assert isinstance(model_desc, ImageModelDescriptor)
    model = model_desc.model.to("cpu").eval()
    model_name = os.path.splitext(os.path.basename(model_path))[0]

    # Generate dummy input
    if static:
        height, width = 256, 256
        torch_input = torch.randn(1, model_desc.input_channels, height, width, device="cpu")
    else:
        torch_input = torch.randn(1, model_desc.input_channels, 32, 32, device="cpu")

    out_dir = "./onnx"
    os.makedirs(out_dir, exist_ok=True)

    # Convert to ONNX
    start_time = time.time()
    model_proto, out_path_fp32 = convert_and_save_onnx(
        model, model_name, torch_input, out_dir, opset, static
    )
    out_path = out_path_fp32
    print(f"Saved to {out_path_fp32} in {time.time() - start_time:.2f} seconds.")

    # Verify
    if verify:
        verify_onnx(model, torch_input, out_path_fp32)

    # Optimize with ONNX Runtime graph optimizations. (The onnxslim pass is
    # disabled, but get_out_path still uses the "_onnxslim" filename suffix.)
    if optimize:
        #model_proto = slim(model_proto)
        session_opt = ort.SessionOptions()
        session_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        session_opt.optimized_model_filepath = get_out_path(out_dir, model_name, opset, False, True, static)
        ort.InferenceSession(out_path_fp32, session_opt)
        if verify:
            verify_onnx(model, torch_input, session_opt.optimized_model_filepath)
        model_proto = onnx.load(session_opt.optimized_model_filepath)
        out_path = session_opt.optimized_model_filepath

    # Convert to FP16
    if fp16:
        start_time = time.time()
        out_path = get_out_path(out_dir, model_name, opset, True, optimize, static)
        model_proto_fp16 = convert_float_to_float16(model_proto)
        onnx.save(model_proto_fp16, out_path)
        print(f"Saved to {out_path} in {time.time() - start_time:.2f} seconds.")

    return out_path
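
# Illustrative run (hypothetical filename): convert_pipeline("4x_model.pth", opset=17, fp16=True)
# exports onnx/4x_model_fp32_op17.onnx, verifies it, optimizes it to
# onnx/4x_model_fp32_op17_onnxslim.onnx, then saves and returns
# onnx/4x_model_fp16_op17_onnxslim.onnx.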
def load_model(model_path: str):
    """Inspect an uploaded model with spandrel and report its architecture info."""
    fp16_label = "FP16 - Export at half precision. Not supported by all models."
    if not model_path:
        # Two outputs are wired up, so both must be returned.
        return ["Ready", gr.Radio(choices=["True", "False"], value="False", interactive=True, label=fp16_label)]
    loader = ModelLoader()
    try:
        model = loader.load_from_file(model_path)
        assert isinstance(model, ImageModelDescriptor)
        architecture_info = {
            'architecture_name': getattr(model.architecture, 'name', str(model.architecture)),
            'input_channels': model.input_channels,
            'output_channels': model.output_channels,
            'scale': model.scale,
            'tags': model.tags,
            'supports_fp16': model.supports_half,
            #'supports_bf16': model.supports_bfloat16,
            #'size_requirements': model.size_requirements
        }
        if model.supports_half:
            return [str(architecture_info), gr.Radio(choices=["True", "False"], interactive=True, label=fp16_label)]
        else:
            # Lock the FP16 option to False when the architecture cannot run at half precision.
            return [str(architecture_info), gr.Radio(choices=["True", "False"], value="False", interactive=False, label=fp16_label)]
    except Exception as e:
        return [f"Error loading model: {e}", gr.Radio(choices=["True", "False"], interactive=True, label=fp16_label)]
def process_choices(opset, fp16, static, slim, file):
    """Gradio handler: disable the button while converting, then expose the download."""
    if not file:
        print("No file loaded.")
        gr.Warning("No file loaded.")
        yield [gr.Button("Convert", interactive=True), gr.DownloadButton(label="💾 Download Converted Model", visible=False)]
        return

    # Convert the radio-button string choices to booleans
    fp16 = fp16 == "True"
    static = static == "True"
    slim = slim == "True"

    yield [gr.Button("Processing", interactive=False), gr.DownloadButton(label="💾 Download Converted Model", visible=False)]
    try:
        result = convert_pipeline(file, opset, True, slim, fp16, static)
        short_name = os.path.basename(result)
        yield [gr.Button("Convert", interactive=True), gr.DownloadButton(label=f"💾 {short_name}", value=result, visible=True)]
    except Exception as e:
        print(f"{e}")
        gr.Warning("Conversion error.")
        yield [gr.Button("Convert", interactive=True), gr.DownloadButton(label="💾 Download Converted Model", visible=False)]
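
# Note: process_choices is a generator; each yield pushes an interim UI state to
# Gradio, so the Convert button reads "Processing" for the duration of the
# conversion and is re-enabled when the final state is yielded.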
# Create Gradio interface
with gr.Blocks(title="PTH to ONNX Converter") as demo:
    install()  # rich tracebacks
    spandrel_extra_arches.install()  # register extra spandrel architectures
    file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
    metadata = gr.Textbox(value="Ready", label="File Information")
    dropdown_opset = gr.Dropdown(choices=[17, 18, 19, 20], value=20, label="Opset")
    radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16 - Not supported by all models. Not very useful because FP16 TRT engines can still be built from FP32 ONNX models.")
    radio_static = gr.Radio(choices=["True", "False"], value="False", label="Static Shapes - Might be required by some models, but can cause slower performance.")
    radio_slim = gr.Radio(choices=["True", "False"], value="False", visible=False, label="OnnxSlim - Can cause issues in some models. I have not yet found any cases where it helps. May remove in the future.")  # hidden for now
    gr.Markdown("After converting, click the logs button at the top to check for any errors or warnings.")
    process_button = gr.Button("Convert", interactive=True)
    file_output = gr.DownloadButton(label="💾 Download Converted Model", visible=False)
    gr.Markdown("""
    # Resources
    - [OpenModelDB](https://openmodeldb.info): Find upscaling models here
    - [VideoJaNai](https://github.com/the-database/VideoJaNai): For upscaling videos using ONNX models
    - [REAL Video Enhancer](https://github.com/TNTwise/REAL-Video-Enhancer): For upscaling videos using ONNX models
    """)

    process_button.click(fn=process_choices,
                         inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload],
                         outputs=[process_button, file_output])
    file_upload.upload(fn=load_model, inputs=file_upload, outputs=[metadata, radio_fp16])

if __name__ == "__main__":
    demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)