Upload app.py
Browse files
app.py
CHANGED
@@ -20,6 +20,8 @@ def convert_and_save_onnx(model, name: str, torch_input, out_dir: str, opset: in
|
|
20 |
dynamic_axes = None
|
21 |
input_names = None
|
22 |
output_names = None
|
|
|
|
|
23 |
else:
|
24 |
dynamic_axes = {
|
25 |
"input": {0: "batch_size", 2: "width", 3: "height"},
|
@@ -76,14 +78,10 @@ def convert_pipeline(model_path: str, opset: int = 17, verify: bool = True, opti
|
|
76 |
|
77 |
# Generate dummy input
|
78 |
if static:
|
79 |
-
|
80 |
-
|
81 |
-
multiple = getattr(sr, "multiple", 1)
|
82 |
-
h = ((min_size + multiple - 1) // multiple) * multiple
|
83 |
-
w = ((min_size + multiple - 1) // multiple) * multiple
|
84 |
-
torch_input = torch.randn(1, model_desc.input_channels, h, w, device="cpu")
|
85 |
else:
|
86 |
-
torch_input = torch.randn(1, model_desc.input_channels
|
87 |
|
88 |
out_dir = "./onnx"
|
89 |
os.makedirs(out_dir, exist_ok=True)
|
@@ -137,14 +135,19 @@ def load_model(model_path: str):
|
|
137 |
'input_channels': model.input_channels,
|
138 |
'output_channels': model.output_channels,
|
139 |
'scale': model.scale,
|
140 |
-
|
141 |
-
'
|
|
|
|
|
142 |
}
|
143 |
|
144 |
-
|
|
|
|
|
|
|
145 |
|
146 |
except Exception as e:
|
147 |
-
return f"Error loading model: {e}"
|
148 |
|
149 |
def process_choices(opset, fp16, static, slim, file):
|
150 |
if not file:
|
@@ -176,15 +179,18 @@ with gr.Blocks(title="PTH to ONNX Converter") as demo:
|
|
176 |
install()
|
177 |
spandrel_extra_arches.install()
|
178 |
|
179 |
-
gr.Textbox(value="This
|
180 |
-
|
|
|
|
|
|
|
181 |
|
182 |
file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
|
183 |
metadata = gr.Textbox(value="Ready", label="File Information")
|
184 |
|
185 |
dropdown_opset = gr.Dropdown(choices=[17, 18, 19, 20], value=20, label="Opset")
|
186 |
-
radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16 -
|
187 |
-
radio_static = gr.Radio(choices=["True", "False"], value="False", label="Static Shapes -
|
188 |
radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim - Can potentially optimize the model further, but usually has little effect and can cause the model to not work correctly.")
|
189 |
|
190 |
process_button = gr.Button("Convert", interactive=True)
|
@@ -200,7 +206,7 @@ with gr.Blocks(title="PTH to ONNX Converter") as demo:
|
|
200 |
process_button.click(fn=process_choices,
|
201 |
inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload],
|
202 |
outputs=[process_button, file_output])
|
203 |
-
file_upload.upload(fn=load_model, inputs=file_upload, outputs=metadata)
|
204 |
|
205 |
if __name__ == "__main__":
|
206 |
demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)
|
|
|
20 |
dynamic_axes = None
|
21 |
input_names = None
|
22 |
output_names = None
|
23 |
+
#input_names = ["input"]
|
24 |
+
#output_names = ["output"]
|
25 |
else:
|
26 |
dynamic_axes = {
|
27 |
"input": {0: "batch_size", 2: "width", 3: "height"},
|
|
|
78 |
|
79 |
# Generate dummy input
|
80 |
if static:
|
81 |
+
height, width = 256, 256
|
82 |
+
torch_input = torch.randn(1, model_desc.input_channels, height, width, device="cpu")
|
|
|
|
|
|
|
|
|
83 |
else:
|
84 |
+
torch_input = torch.randn(1, model_desc.input_channels, 32, 32, device="cpu")
|
85 |
|
86 |
out_dir = "./onnx"
|
87 |
os.makedirs(out_dir, exist_ok=True)
|
|
|
135 |
'input_channels': model.input_channels,
|
136 |
'output_channels': model.output_channels,
|
137 |
'scale': model.scale,
|
138 |
+
'tags': model.tags,
|
139 |
+
'supports_fp16': model.supports_half
|
140 |
+
#'supports_bf16': model.supports_bfloat16,
|
141 |
+
#'size_requirements': model.size_requirements
|
142 |
}
|
143 |
|
144 |
+
if model.supports_half:
|
145 |
+
return [str(architecture_info), gr.Radio(choices=["True", "False"], interactive=True, label="FP16 - Export at half precision. Not supported by all models.")]
|
146 |
+
else:
|
147 |
+
return [str(architecture_info), gr.Radio(choices=["True", "False"], value="False", interactive=False, label="FP16 - Export at half precision. Not supported by all models.")]
|
148 |
|
149 |
except Exception as e:
|
150 |
+
return [f"Error loading model: {e}", gr.Radio(choices=["True", "False"], interactive=True, label="FP16 - Export at half precision. Not supported by all models.")]
|
151 |
|
152 |
def process_choices(opset, fp16, static, slim, file):
|
153 |
if not file:
|
|
|
179 |
install()
|
180 |
spandrel_extra_arches.install()
|
181 |
|
182 |
+
gr.Textbox(value="This tool is still in development. Outputs may be incorrect.", label="Notice")
|
183 |
+
# with gr.Accordion("Model Cheat Sheet", open=False):
|
184 |
+
# gr.Markdown("""
|
185 |
+
# After converting a model, be sure to click the log button above to check the output for any errors or warnings.
|
186 |
+
# """)
|
187 |
|
188 |
file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
|
189 |
metadata = gr.Textbox(value="Ready", label="File Information")
|
190 |
|
191 |
dropdown_opset = gr.Dropdown(choices=[17, 18, 19, 20], value=20, label="Opset")
|
192 |
+
radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16 - Export at half precision. Not supported by all models.")
|
193 |
+
radio_static = gr.Radio(choices=["True", "False"], value="False", label="Static Shapes - Some models may work better with static shapes. You may need to try both options.")
|
194 |
radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim - Can potentially optimize the model further, but usually has little effect and can cause the model to not work correctly.")
|
195 |
|
196 |
process_button = gr.Button("Convert", interactive=True)
|
|
|
206 |
process_button.click(fn=process_choices,
|
207 |
inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload],
|
208 |
outputs=[process_button, file_output])
|
209 |
+
file_upload.upload(fn=load_model, inputs=file_upload, outputs=[metadata, radio_fp16])
|
210 |
|
211 |
if __name__ == "__main__":
|
212 |
demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)
|