alexnasa commited on
Commit
8dd9712
·
verified ·
1 Parent(s): 6a82d6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -80,7 +80,7 @@ from src.models.briarmbg import BriaRMBG
80
 
81
  # Constants
82
  MAX_NUM_PARTS = 16
83
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
84
  DTYPE = torch.float16
85
 
86
  # Download and initialize models
@@ -95,22 +95,22 @@ pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weig
95
 
96
  @spaces.GPU()
97
  @torch.no_grad()
98
- def run_triposg(image: Image.Image,
99
- num_parts: int,
100
- seed: int,
101
- num_tokens: int,
102
- num_inference_steps: int,
103
- guidance_scale: float,
104
- max_num_expanded_coords: float,
105
- use_flash_decoder: bool,
106
- rmbg: bool):
107
  """
108
  Generate 3D part meshes from an input image.
109
  """
110
  if rmbg:
111
- img_pil = prepare_image(image, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
112
  else:
113
- img_pil = image
114
 
115
  set_seed(seed)
116
  start_time = time.time()
@@ -159,7 +159,7 @@ def build_demo():
159
  )
160
  with gr.Row():
161
  with gr.Column(scale=1):
162
- input_image = gr.Image(type="pil", label="Input Image")
163
  num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
164
  seed = gr.Number(value=0, label="Random Seed", precision=0)
165
  num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
 
80
 
81
  # Constants
82
  MAX_NUM_PARTS = 16
83
+ DEVICE = "cuda"
84
  DTYPE = torch.float16
85
 
86
  # Download and initialize models
 
95
 
96
  @spaces.GPU()
97
  @torch.no_grad()
98
+ def run_triposg(image_path: str,
99
+ num_parts: int = 10,
100
+ seed: int = 123,
101
+ num_tokens: int = 1024,
102
+ num_inference_steps: int = 50,
103
+ guidance_scale: float = 7.0,
104
+ max_num_expanded_coords: float = 1e9,
105
+ use_flash_decoder: bool = False,
106
+ rmbg: bool = True):
107
  """
108
  Generate 3D part meshes from an input image.
109
  """
110
  if rmbg:
111
+ img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
112
  else:
113
+ img_pil = Image.open(image_path_or_pil)
114
 
115
  set_seed(seed)
116
  start_time = time.time()
 
159
  )
160
  with gr.Row():
161
  with gr.Column(scale=1):
162
+ input_image = gr.Image(type="filepath", label="Input Image")
163
  num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
164
  seed = gr.Number(value=0, label="Random Seed", precision=0)
165
  num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")