JichenHu commited on
Commit
27113d5
·
verified ·
1 Parent(s): ea095c9

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -4
app.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
 
 
7
 
8
  # 延迟 CUDA 初始化
9
  weight_dtype = torch.float32
@@ -57,7 +59,7 @@ def process_image(input_image):
57
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
58
  processed_frame = Image.fromarray(processed_frame)
59
 
60
- return processed_frame
61
 
62
  # 创建 Gradio 界面
63
  def create_gradio_interface():
@@ -73,13 +75,20 @@ def create_gradio_interface():
73
  input_image = gr.Image(label="Input Image", type="numpy")
74
  submit_btn = gr.Button("Remove Reflection", variant="primary")
75
  with gr.Column():
76
- output_image = gr.Image(label="Processed Image")
 
 
 
 
 
 
 
77
 
78
  # 添加示例
79
  gr.Examples(
80
  examples=example_images,
81
  inputs=input_image,
82
- outputs=output_image,
83
  fn=process_image,
84
  cache_examples=False, # 缓存结果以加快加载速度
85
  label="Example Images",
@@ -89,7 +98,7 @@ def create_gradio_interface():
89
  submit_btn.click(
90
  fn=process_image,
91
  inputs=input_image,
92
- outputs=output_image,
93
  )
94
 
95
  return demo
 
4
  import torch
5
  from PIL import Image
6
  import gradio as gr
7
+ from gradio_imageslider import ImageSlider
8
+
9
 
10
  # 延迟 CUDA 初始化
11
  weight_dtype = torch.float32
 
59
  processed_frame = (processed_frame[0] * 255).astype(np.uint8)
60
  processed_frame = Image.fromarray(processed_frame)
61
 
62
+ return input_image, processed_frame
63
 
64
  # 创建 Gradio 界面
65
  def create_gradio_interface():
 
75
  input_image = gr.Image(label="Input Image", type="numpy")
76
  submit_btn = gr.Button("Remove Reflection", variant="primary")
77
  with gr.Column():
78
+ image_output_slider = ImageSlider(
79
+ label="outputs",
80
+ type="filepath",
81
+ show_download_button=True,
82
+ show_share_button=True,
83
+ interactive=False,
84
+ elem_classes="slider",
85
+ )
86
 
87
  # 添加示例
88
  gr.Examples(
89
  examples=example_images,
90
  inputs=input_image,
91
+ outputs=image_output_slider,
92
  fn=process_image,
93
  cache_examples=False, # 缓存结果以加快加载速度
94
  label="Example Images",
 
98
  submit_btn.click(
99
  fn=process_image,
100
  inputs=input_image,
101
+ outputs=image_output_slider,
102
  )
103
 
104
  return demo