import os
import gradio as gr
import time
from logging import Logger
from os import path as osp
from pathlib import Path
import numpy as np
import onnx
import onnxruntime as ort
import torch
from onnx import ModelProto
from onnxconverter_common.float16 import convert_float_to_float16
from onnxslim import slim
from spandrel import ModelLoader, ImageModelDescriptor
import spandrel_extra_arches
from rich.traceback import install
from torch import Tensor
# The traiNNer-redux sources are importable as the `traiNNer` package
# (a hyphenated path like `trainner-redux` is not a valid Python import).
from traiNNer.models.base_model import BaseModel
from traiNNer.utils.logger import clickable_file_path, get_root_logger
from traiNNer.utils.redux_options import ReduxOptions
def get_out_path(
out_dir: str, name: str, opset: int, fp16: bool = False, optimized: bool = False
) -> str:
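    # e.g. get_out_path("./onnx", "model", 17, fp16=True, optimized=True)
    #      -> "onnx/model_fp16_op17_onnxslim.onnx"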
filename = f"{name}_fp{'16' if fp16 else '32'}_op{opset}{'_onnxslim' if optimized else ''}.onnx"
return osp.normpath(osp.join(out_dir, filename))
def convert_and_save_onnx(
model: BaseModel,
logger: Logger,
opt: ReduxOptions,
torch_input: Tensor,
out_dir: str,
) -> tuple[ModelProto, int, str]:
assert model.net_g is not None
assert opt.onnx is not None
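    # With static shapes, the dummy input's dimensions are baked into the graph;
    # otherwise batch size, height, and width are exported as dynamic axes.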
if opt.onnx.use_static_shapes:
dynamic_axes = None
input_names = None
output_names = None
else:
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }
input_names = ["input"]
output_names = ["output"]
out_path = get_out_path(out_dir, opt.name, opt.onnx.opset, False)
torch.onnx.export(
model.net_g,
(torch_input,),
out_path,
dynamo=opt.onnx.dynamo,
verbose=False,
opset_version=opt.onnx.opset,
dynamic_axes=dynamic_axes,
input_names=input_names,
output_names=output_names,
)
model_proto = onnx.load(out_path)
assert model_proto is not None
return model_proto, opt.onnx.opset, out_path
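# Sanity-check the exported graph: run it through onnxruntime on CPU and compare
# its output against the PyTorch model's output within a loose tolerance.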
def verify_onnx(
model: BaseModel, logger: Logger, torch_input: Tensor, onnx_path: str
) -> None:
assert model.net_g is not None
with torch.inference_mode():
torch_output_np = model.net_g(torch_input).cpu().numpy()
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
ort_inputs = {ort_session.get_inputs()[0].name: torch_input.cpu().numpy()}
onnx_output = ort_session.run(None, ort_inputs)
try:
np.testing.assert_allclose( # pyright: ignore # TODO onnx 1.18
torch_output_np,
onnx_output[0], # pyright: ignore # TODO onnx 1.18
rtol=1e-02,
atol=1e-03, # pyright: ignore # TODO onnx 1.18
)
logger.info("ONNX output verified against PyTorch output successfully.")
except AssertionError as e:
logger.warning("ONNX verification completed with warnings: %s", e)
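# End-to-end conversion: load the checkpoint with spandrel, export it to ONNX,
# optionally optimize with onnxslim + onnxruntime, and optionally convert the
# weights to fp16. Returns the path of the final .onnx file.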
def convert_pipeline(
    model_path: str,
    opset: int = 17,
    verify: bool = True,
    optimize: bool = True,
    fp16: bool = False,
) -> str:
loader = ModelLoader()
model_desc = loader.load_from_file(model_path)
assert isinstance(model_desc, ImageModelDescriptor)
model = model_desc.model.to("cpu").eval()
# Simulate the `opt` object from traiNNer config
class OptONNX:
def __init__(self):
self.opset = opset
self.dynamo = False
self.use_static_shapes = True
self.verify = verify
self.optimize = optimize
self.fp16 = fp16
            # spandrel's SizeRequirements exposes `minimum` and `multiple_of`
            # (assumption: fall back to a 32px input if no minimum is given).
            sr = model_desc.size_requirements
            min_size = max(getattr(sr, "minimum", 0), 32)
            multiple = getattr(sr, "multiple_of", 1)
            h = ((min_size + multiple - 1) // multiple) * multiple
            w = ((min_size + multiple - 1) // multiple) * multiple
self.shape = f"1x{model_desc.input_channels}x{h}x{w}"
class Opt:
def __init__(self):
self.name = Path(model_path).stem
self.onnx = OptONNX()
opt = Opt()
# Generate dummy input
if opt.onnx.use_static_shapes:
dims = tuple(map(int, opt.onnx.shape.split("x")))
torch_input = torch.randn(*dims, device="cpu")
else:
torch_input = torch.randn(1, model_desc.input_channels or 3, 32, 32, device="cpu")
assert model is not None
logger = get_root_logger()
out_dir = "./onnx"
os.makedirs(out_dir, exist_ok=True)
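    # convert_and_save_onnx / verify_onnx expect a BaseModel-like object that
    # exposes the network as `.net_g`, so wrap the bare spandrel module.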
class ModelWrapper:
def __init__(self, net_g):
self.net_g = net_g
wrapped_model = ModelWrapper(model)
start_time = time.time()
model_proto, opset, out_path_fp32 = convert_and_save_onnx(
wrapped_model, logger, opt, torch_input, out_dir
)
out_path = out_path_fp32
logger.info(
"Saved to %s in %.2f seconds.",
clickable_file_path(Path(out_path_fp32).absolute().parent, out_path_fp32),
time.time() - start_time,
)
if opt.onnx.verify:
verify_onnx(wrapped_model, logger, torch_input, out_path_fp32)
    if opt.onnx.optimize:
        # Slim the exported graph with onnxslim and write it back out, then let
        # onnxruntime apply its own graph optimizations and save the result.
        model_proto = slim(model_proto)
        onnx.save(model_proto, out_path_fp32)
        session_opt = ort.SessionOptions()
        session_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
        session_opt.optimized_model_filepath = get_out_path(out_dir, opt.name, opset, False, True)
        ort.InferenceSession(out_path_fp32, session_opt)
        verify_onnx(wrapped_model, logger, torch_input, session_opt.optimized_model_filepath)
        model_proto = onnx.load(session_opt.optimized_model_filepath)
        out_path = session_opt.optimized_model_filepath
if opt.onnx.fp16:
start_time = time.time()
out_path = get_out_path(out_dir, opt.name, opset, True, opt.onnx.optimize)
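        # convert_float_to_float16 casts initializers and tensor types to fp16;
        # by default it keeps ops on its block list in fp32 and inserts casts.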
model_proto_fp16 = convert_float_to_float16(model_proto)
onnx.save(model_proto_fp16, out_path)
logger.info(
"Saved to %s in %.2f seconds.",
clickable_file_path(Path(out_path).absolute().parent, out_path),
time.time() - start_time,
)
return out_path
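# Example usage outside Gradio (hypothetical checkpoint name):
#   convert_pipeline("4x_model.pth", opset=17, optimize=False, fp16=True)
# writes onnx/4x_model_fp32_op17.onnx and onnx/4x_model_fp16_op17.onnx,
# returning the fp16 path.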
# Process the selected radio button choices and the uploaded file.
def process_choices(fp16, use_slim, file):
    # Gradio radio buttons return their choices as strings, so map them to bools.
    fp16 = fp16 == "True"
    use_slim = use_slim == "True"
    result = convert_pipeline(file, 17, True, use_slim, fp16)
    short_name = os.path.basename(result)
    return gr.DownloadButton(label=f"💾 {short_name}", value=result, visible=True)
# Create a Gradio Blocks interface.
with gr.Blocks(title="PTH to ONNX Converter") as demo:
install()
spandrel_extra_arches.install()
gr.Textbox(value="This space is still under construction. Model outputs may not be correct or valid.", label="Notice", interactive=False)
# Create a file upload component at the top.
file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
    # Create two yes/no radio buttons for the conversion options.
radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16")
radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim")
# Create a button to trigger the processing function.
process_button = gr.Button("Convert")
#file_result = gr.Textbox(value="Ready", label="Result")
# Create a file output component to allow the user to download the processed file.
file_output = gr.DownloadButton(label="💾 Download Converted Model", visible=False)
# Define the event listener for the button click.
process_button.click(fn=process_choices, inputs=[radio_fp16, radio_slim, file_upload], outputs=file_output, show_progress="full")
# Launch the interface.
if __name__ == "__main__":
demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)