Zarxrax committed on
Commit
f35b114
·
verified ·
1 Parent(s): 62f6628

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -20,6 +20,8 @@ def convert_and_save_onnx(model, name: str, torch_input, out_dir: str, opset: in
20
  dynamic_axes = None
21
  input_names = None
22
  output_names = None
 
 
23
  else:
24
  dynamic_axes = {
25
  "input": {0: "batch_size", 2: "width", 3: "height"},
@@ -76,14 +78,10 @@ def convert_pipeline(model_path: str, opset: int = 17, verify: bool = True, opti
76
 
77
  # Generate dummy input
78
  if static:
79
- sr = model_desc.size_requirements
80
- min_size = getattr(sr, "size", 32)
81
- multiple = getattr(sr, "multiple", 1)
82
- h = ((min_size + multiple - 1) // multiple) * multiple
83
- w = ((min_size + multiple - 1) // multiple) * multiple
84
- torch_input = torch.randn(1, model_desc.input_channels, h, w, device="cpu")
85
  else:
86
- torch_input = torch.randn(1, model_desc.input_channels or 3, 32, 32, device="cpu")
87
 
88
  out_dir = "./onnx"
89
  os.makedirs(out_dir, exist_ok=True)
@@ -137,14 +135,19 @@ def load_model(model_path: str):
137
  'input_channels': model.input_channels,
138
  'output_channels': model.output_channels,
139
  'scale': model.scale,
140
- #'size_requirements': model.size_requirements,
141
- 'tags': model.tags
 
 
142
  }
143
 
144
- return str(architecture_info)
 
 
 
145
 
146
  except Exception as e:
147
- return f"Error loading model: {e}"
148
 
149
  def process_choices(opset, fp16, static, slim, file):
150
  if not file:
@@ -176,15 +179,18 @@ with gr.Blocks(title="PTH to ONNX Converter") as demo:
176
  install()
177
  spandrel_extra_arches.install()
178
 
179
- gr.Textbox(value="This space is still under construction. Model outputs may not be correct or valid.",
180
- label="Notice", interactive=False)
 
 
 
181
 
182
  file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
183
  metadata = gr.Textbox(value="Ready", label="File Information")
184
 
185
  dropdown_opset = gr.Dropdown(choices=[17, 18, 19, 20], value=20, label="Opset")
186
- radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16 - Usually fine to leave false. An FP32 ONNX model can be ran as FP16, but not the other way around.")
187
- radio_static = gr.Radio(choices=["True", "False"], value="False", label="Static Shapes - Usually best left false unless the conversion fails or the resulting model doesn't work correctly.")
188
  radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim - Can potentially optimize the model further, but usually has little effect and can cause the model to not work correctly.")
189
 
190
  process_button = gr.Button("Convert", interactive=True)
@@ -200,7 +206,7 @@ with gr.Blocks(title="PTH to ONNX Converter") as demo:
200
  process_button.click(fn=process_choices,
201
  inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload],
202
  outputs=[process_button, file_output])
203
- file_upload.upload(fn=load_model, inputs=file_upload, outputs=metadata)
204
 
205
  if __name__ == "__main__":
206
  demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)
 
20
  dynamic_axes = None
21
  input_names = None
22
  output_names = None
23
+ #input_names = ["input"]
24
+ #output_names = ["output"]
25
  else:
26
  dynamic_axes = {
27
  "input": {0: "batch_size", 2: "width", 3: "height"},
 
78
 
79
  # Generate dummy input
80
  if static:
81
+ height, width = 256, 256
82
+ torch_input = torch.randn(1, model_desc.input_channels, height, width, device="cpu")
 
 
 
 
83
  else:
84
+ torch_input = torch.randn(1, model_desc.input_channels, 32, 32, device="cpu")
85
 
86
  out_dir = "./onnx"
87
  os.makedirs(out_dir, exist_ok=True)
 
135
  'input_channels': model.input_channels,
136
  'output_channels': model.output_channels,
137
  'scale': model.scale,
138
+ 'tags': model.tags,
139
+ 'supports_fp16': model.supports_half
140
+ #'supports_bf16': model.supports_bfloat16,
141
+ #'size_requirements': model.size_requirements
142
  }
143
 
144
+ if model.supports_half:
145
+ return [str(architecture_info), gr.Radio(choices=["True", "False"], interactive=True, label="FP16 - Export at half precision. Not supported by all models.")]
146
+ else:
147
+ return [str(architecture_info), gr.Radio(choices=["True", "False"], value="False", interactive=False, label="FP16 - Export at half precision. Not supported by all models.")]
148
 
149
  except Exception as e:
150
+ return [f"Error loading model: {e}", gr.Radio(choices=["True", "False"], interactive=True, label="FP16 - Export at half precision. Not supported by all models.")]
151
 
152
  def process_choices(opset, fp16, static, slim, file):
153
  if not file:
 
179
  install()
180
  spandrel_extra_arches.install()
181
 
182
+ gr.Textbox(value="This tool is still in development. Outputs may be incorrect.", label="Notice")
183
+ # with gr.Accordion("Model Cheat Sheet", open=False):
184
+ # gr.Markdown("""
185
+ # After converting a model, be sure to click the log button above to check the output for any errors or warnings.
186
+ # """)
187
 
188
  file_upload = gr.File(label="Upload a PyTorch model", file_types=['.pth', '.pt', '.safetensors'])
189
  metadata = gr.Textbox(value="Ready", label="File Information")
190
 
191
  dropdown_opset = gr.Dropdown(choices=[17, 18, 19, 20], value=20, label="Opset")
192
+ radio_fp16 = gr.Radio(choices=["True", "False"], value="False", label="FP16 - Export at half precision. Not supported by all models.")
193
+ radio_static = gr.Radio(choices=["True", "False"], value="False", label="Static Shapes - Some models may work better with static shapes. You may need to try both options.")
194
  radio_slim = gr.Radio(choices=["True", "False"], value="False", label="OnnxSlim - Can potentially optimize the model further, but usually has little effect and can cause the model to not work correctly.")
195
 
196
  process_button = gr.Button("Convert", interactive=True)
 
206
  process_button.click(fn=process_choices,
207
  inputs=[dropdown_opset, radio_fp16, radio_static, radio_slim, file_upload],
208
  outputs=[process_button, file_output])
209
+ file_upload.upload(fn=load_model, inputs=file_upload, outputs=[metadata, radio_fp16])
210
 
211
  if __name__ == "__main__":
212
  demo.launch(show_error=True, inbrowser=True, show_api=False, debug=False)