Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -80,7 +80,7 @@ from src.models.briarmbg import BriaRMBG
|
|
80 |
|
81 |
# Constants
|
82 |
MAX_NUM_PARTS = 16
|
83 |
-
DEVICE = "cuda"
|
84 |
DTYPE = torch.float16
|
85 |
|
86 |
# Download and initialize models
|
@@ -95,22 +95,22 @@ pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weig
|
|
95 |
|
96 |
@spaces.GPU()
|
97 |
@torch.no_grad()
|
98 |
-
def run_triposg(
|
99 |
-
num_parts: int,
|
100 |
-
seed: int,
|
101 |
-
num_tokens: int,
|
102 |
-
num_inference_steps: int,
|
103 |
-
guidance_scale: float,
|
104 |
-
max_num_expanded_coords: float,
|
105 |
-
use_flash_decoder: bool,
|
106 |
-
rmbg: bool):
|
107 |
"""
|
108 |
Generate 3D part meshes from an input image.
|
109 |
"""
|
110 |
if rmbg:
|
111 |
-
img_pil = prepare_image(
|
112 |
else:
|
113 |
-
img_pil =
|
114 |
|
115 |
set_seed(seed)
|
116 |
start_time = time.time()
|
@@ -159,7 +159,7 @@ def build_demo():
|
|
159 |
)
|
160 |
with gr.Row():
|
161 |
with gr.Column(scale=1):
|
162 |
-
input_image = gr.Image(type="
|
163 |
num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
|
164 |
seed = gr.Number(value=0, label="Random Seed", precision=0)
|
165 |
num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
|
|
|
80 |
|
81 |
# Constants
|
82 |
MAX_NUM_PARTS = 16
|
83 |
+
DEVICE = "cuda"
|
84 |
DTYPE = torch.float16
|
85 |
|
86 |
# Download and initialize models
|
|
|
95 |
|
96 |
@spaces.GPU()
|
97 |
@torch.no_grad()
|
98 |
+
def run_triposg(image_path: str,
|
99 |
+
num_parts: int = 10,
|
100 |
+
seed: int = 123,
|
101 |
+
num_tokens: int = 1024,
|
102 |
+
num_inference_steps: int = 50,
|
103 |
+
guidance_scale: float = 7.0,
|
104 |
+
max_num_expanded_coords: float = 1e9,
|
105 |
+
use_flash_decoder: bool = False,
|
106 |
+
rmbg: bool = True):
|
107 |
"""
|
108 |
Generate 3D part meshes from an input image.
|
109 |
"""
|
110 |
if rmbg:
|
111 |
+
img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
|
112 |
else:
|
113 |
+
img_pil = Image.open(image_path_or_pil)
|
114 |
|
115 |
set_seed(seed)
|
116 |
start_time = time.time()
|
|
|
159 |
)
|
160 |
with gr.Row():
|
161 |
with gr.Column(scale=1):
|
162 |
+
input_image = gr.Image(type="filepath", label="Input Image")
|
163 |
num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
|
164 |
seed = gr.Number(value=0, label="Random Seed", precision=0)
|
165 |
num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
|