alex committed on
Commit
166aa76
·
1 Parent(s): c0bc561

split gif preview added

Browse files
Files changed (3) hide show
  1. app.py +61 -99
  2. packages.txt +12 -0
  3. src/utils/render_utils.py +65 -5
app.py CHANGED
@@ -76,7 +76,7 @@ import importlib, site; site.addsitedir(site.getsitepackages()[0]); importlib.in
76
 
77
 
78
  from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
79
- from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings
80
  from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
81
  from src.utils.image_utils import prepare_image
82
  from src.models.briarmbg import BriaRMBG
@@ -101,82 +101,6 @@ def first_file_from_dir(directory, ext):
101
  return sorted(files)[0] if files else None
102
 
103
 
104
- def explode_mesh(mesh, explosion_scale=0.4):
105
-
106
- if isinstance(mesh, trimesh.Scene):
107
- scene = mesh
108
- elif isinstance(mesh, trimesh.Trimesh):
109
- print("Warning: Single mesh provided, can't create exploded view")
110
- scene = trimesh.Scene(mesh)
111
- return scene
112
- else:
113
- print(f"Warning: Unexpected mesh type: {type(mesh)}")
114
- scene = mesh
115
-
116
- if len(scene.geometry) <= 1:
117
- print("Only one geometry found - nothing to explode")
118
- return scene
119
-
120
- print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}")
121
- print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts")
122
-
123
- exploded_scene = trimesh.Scene()
124
-
125
- part_centers = []
126
- geometry_names = []
127
-
128
- for geometry_name, geometry in scene.geometry.items():
129
- if hasattr(geometry, 'vertices'):
130
- transform = scene.graph[geometry_name][0]
131
- vertices_global = trimesh.transformations.transform_points(
132
- geometry.vertices, transform)
133
- center = np.mean(vertices_global, axis=0)
134
- part_centers.append(center)
135
- geometry_names.append(geometry_name)
136
- print(f"[EXPLODE_MESH] Part {geometry_name}: center = {center}")
137
-
138
- if not part_centers:
139
- print("No valid geometries with vertices found")
140
- return scene
141
-
142
- part_centers = np.array(part_centers)
143
- global_center = np.mean(part_centers, axis=0)
144
-
145
- print(f"[EXPLODE_MESH] Global center: {global_center}")
146
-
147
- for i, (geometry_name, geometry) in enumerate(scene.geometry.items()):
148
- if hasattr(geometry, 'vertices'):
149
- if i < len(part_centers):
150
- part_center = part_centers[i]
151
- direction = part_center - global_center
152
-
153
- direction_norm = np.linalg.norm(direction)
154
- if direction_norm > 1e-6:
155
- direction = direction / direction_norm
156
- else:
157
- direction = np.random.randn(3)
158
- direction = direction / np.linalg.norm(direction)
159
-
160
- offset = direction * explosion_scale
161
- else:
162
- offset = np.zeros(3)
163
-
164
- original_transform = scene.graph[geometry_name][0].copy()
165
-
166
- new_transform = original_transform.copy()
167
- new_transform[:3, 3] = new_transform[:3, 3] + offset
168
-
169
- exploded_scene.add_geometry(
170
- geometry,
171
- transform=new_transform,
172
- geom_name=geometry_name
173
- )
174
-
175
- print(f"[EXPLODE_MESH] Part {geometry_name}: moved by {np.linalg.norm(offset):.4f}")
176
-
177
- print("[EXPLODE_MESH] Mesh explosion complete")
178
- return exploded_scene
179
-
180
 
181
  def get_duration(
182
  image_path,
@@ -200,10 +124,50 @@ def get_duration(
200
 
201
  return int(duration_seconds)
202
 
203
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  @spaces.GPU(duration=get_duration)
205
  @torch.no_grad()
206
- def run_triposg(image_path: str,
207
  num_parts: int = 1,
208
  seed: int = 0,
209
  num_tokens: int = 1024,
@@ -220,7 +184,7 @@ def run_triposg(image_path: str,
220
  This function takes a single 2D image as input and produces a set of part-based 3D meshes,
221
  using compositional latent diffusion with attention to structure and part separation.
222
  Optionally removes the background using a pretrained background removal model (RMBG),
223
- and outputs a merged object mesh, a split preview (exploded view).
224
 
225
  Args:
226
  image_path (str): Path to the input image file on disk.
@@ -237,13 +201,10 @@ def run_triposg(image_path: str,
237
  Returns:
238
  Tuple[str, str, str, str]:
239
  - `merged_path` (str): File path to the merged full object mesh (`object.glb`).
240
- - `split_preview_path` (str): File path to the exploded-view mesh (`split.glb`).
241
- - `export_dir` (str): Directory where all generated meshes were saved.
242
 
243
  Notes:
244
  - This function utilizes HuggingFace pretrained weights for both part generation and background removal.
245
- - The final output includes exploded and merged views to visualize object structure.
246
- - Parts are exported in `.glb` format, and zipped for bulk download.
247
  - Generation time depends on the number of parts and inference parameters.
248
  """
249
 
@@ -299,11 +260,8 @@ def run_triposg(image_path: str,
299
 
300
  merged_path = os.path.join(export_dir, "object.glb")
301
  merged.export(merged_path)
302
-
303
- split_preview_path = os.path.join(export_dir, "split.glb")
304
- split_mesh.export(split_preview_path)
305
-
306
- return merged_path, split_preview_path, export_dir
307
 
308
  def cleanup(request: gr.Request):
309
 
@@ -350,7 +308,8 @@ def build_demo():
350
  )
351
  input_image = gr.Image(type="filepath", label="Input Image", height=256)
352
  num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
353
- run_button = gr.Button("🧩 Generate 3D Parts", variant="primary")
 
354
 
355
  with gr.Accordion("Advanced Settings", open=False):
356
  seed = gr.Number(value=0, label="Random Seed", precision=0)
@@ -369,9 +328,8 @@ def build_demo():
369
  """
370
  )
371
  with gr.Row():
372
- output_model = gr.Model3D(label="Merged 3D Object", height=512)
373
- split_model = gr.Model3D(label="Split Preview", height=512)
374
- output_dir = gr.Textbox(label="Export Directory", visible=False)
375
  with gr.Row():
376
  with gr.Column():
377
  examples = gr.Examples(
@@ -396,19 +354,23 @@ def build_demo():
396
 
397
  ],
398
  inputs=[input_image, num_parts],
399
- outputs=[output_model, split_model, output_dir],
400
- fn=run_triposg,
401
- cache_examples=True,
402
  )
403
 
404
- run_button.click(fn=run_triposg,
405
  inputs=[input_image, num_parts, seed, num_tokens, num_steps,
406
  guidance, flash_decoder, remove_bg, session_state],
407
- outputs=[output_model, split_model, output_dir])
 
 
 
 
408
  return demo
409
 
410
  if __name__ == "__main__":
411
  demo = build_demo()
412
  demo.unload(cleanup)
413
  demo.queue()
414
- demo.launch()
 
76
 
77
 
78
  from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
79
+ from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings, explode_mesh
80
  from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
81
  from src.utils.image_utils import prepare_image
82
  from src.models.briarmbg import BriaRMBG
 
101
  return sorted(files)[0] if files else None
102
 
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  def get_duration(
106
  image_path,
 
124
 
125
  return int(duration_seconds)
126
 
127
+
128
@spaces.GPU(duration=135)
def gen_model_n_video(
    image_path: str,
    num_parts: int,
    progress=gr.Progress(track_tqdm=True),
):
    """Craft the merged part model for an image and render its GIF preview.

    Runs ``run_partcrafter`` on the image and feeds the resulting model path
    straight into ``gen_video``.

    Args:
        image_path: Path to the input image file on disk.
        num_parts: Number of parts to generate, forwarded to
            ``run_partcrafter``.
        progress: Gradio progress tracker, forwarded to ``run_partcrafter``.

    Returns:
        A ``(model_path, video_path)`` pair: the value returned by
        ``run_partcrafter`` and the value returned by ``gen_video`` for it.
    """
    crafted_model = run_partcrafter(image_path, num_parts=num_parts, progress=progress)
    return crafted_model, gen_video(crafted_model)
137
+
138
@spaces.GPU()
def gen_video(model_path):
    """Render a GIF of the model stored at ``model_path``.

    The mesh file is loaded with ``trimesh.load``, rendered from 36 views at
    radius 4 via ``render_views_around_mesh``, and written with
    ``export_renderings`` at 7 fps to ``rendering.gif`` in the same directory
    as the model file.

    Args:
        model_path: Path to the exported model file, or ``None`` when no
            model has been generated yet.

    Returns:
        The path of the written GIF, or ``None`` (after showing a Gradio info
        message) when ``model_path`` is ``None``.
    """
    # Nothing to render until a model exists; tell the user and bail out.
    if model_path is None:
        gr.Info("You must craft the 3d parts first")
        return None

    gif_path = os.path.join(os.path.dirname(model_path), "rendering.gif")
    frames = render_views_around_mesh(
        trimesh.load(model_path),
        num_views=36,
        radius=4,
    )
    export_renderings(frames, gif_path, fps=7)
    return gif_path
167
+
168
  @spaces.GPU(duration=get_duration)
169
  @torch.no_grad()
170
+ def run_partcrafter(image_path: str,
171
  num_parts: int = 1,
172
  seed: int = 0,
173
  num_tokens: int = 1024,
 
184
  This function takes a single 2D image as input and produces a set of part-based 3D meshes,
185
  using compositional latent diffusion with attention to structure and part separation.
186
  Optionally removes the background using a pretrained background removal model (RMBG),
187
+ and outputs a merged object mesh.
188
 
189
  Args:
190
  image_path (str): Path to the input image file on disk.
 
201
  Returns:
202
  str:
203
  - `merged_path` (str): File path to the merged full object mesh (`object.glb`).
 
 
204
 
205
  Notes:
206
  - This function utilizes HuggingFace pretrained weights for both part generation and background removal.
207
+ - The final output includes merged model parts to visualize object structure.
 
208
  - Generation time depends on the number of parts and inference parameters.
209
  """
210
 
 
260
 
261
  merged_path = os.path.join(export_dir, "object.glb")
262
  merged.export(merged_path)
263
+
264
+ return merged_path
 
 
 
265
 
266
  def cleanup(request: gr.Request):
267
 
 
308
  )
309
  input_image = gr.Image(type="filepath", label="Input Image", height=256)
310
  num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
311
+ run_button = gr.Button("🧩 Craft 3D Parts", variant="primary")
312
+ video_button = gr.Button("🎥 Generate Gif")
313
 
314
  with gr.Accordion("Advanced Settings", open=False):
315
  seed = gr.Number(value=0, label="Random Seed", precision=0)
 
328
  """
329
  )
330
  with gr.Row():
331
+ output_model = gr.Model3D(label="Merged 3D Object", height=512, interactive=False)
332
+ video_output = gr.Image(label="Split Preview", height=512)
 
333
  with gr.Row():
334
  with gr.Column():
335
  examples = gr.Examples(
 
354
 
355
  ],
356
  inputs=[input_image, num_parts],
357
+ outputs=[output_model, video_output],
358
+ fn=gen_model_n_video,
359
+ cache_examples=True
360
  )
361
 
362
+ run_button.click(fn=run_partcrafter,
363
  inputs=[input_image, num_parts, seed, num_tokens, num_steps,
364
  guidance, flash_decoder, remove_bg, session_state],
365
+ outputs=[output_model])
366
+ video_button.click(fn=gen_video,
367
+ inputs=[output_model],
368
+ outputs=[video_output])
369
+
370
  return demo
371
 
372
  if __name__ == "__main__":
373
  demo = build_demo()
374
  demo.unload(cleanup)
375
  demo.queue()
376
+ demo.launch(mcp_server=True)
packages.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ libglfw3-dev
2
+ libgles2-mesa-dev
3
+ libgl1
4
+ freeglut3-dev
5
+ unzip
6
+ ffmpeg
7
+ libsm6
8
+ libxext6
9
+ libgl1-mesa-dri
10
+ libegl1-mesa
11
+ libgbm1
12
+ build-essential
src/utils/render_utils.py CHANGED
@@ -10,9 +10,60 @@ from diffusers.utils import export_to_video
10
  from diffusers.utils.loading_utils import load_video
11
  import torch
12
  from torchvision.utils import make_grid
 
13
 
14
  os.environ['PYOPENGL_PLATFORM'] = 'egl'
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def render(
17
  scene: pyrender.Scene,
18
  renderer: pyrender.Renderer,
@@ -123,13 +174,22 @@ def render_views_around_mesh(
123
  Tuple[List[Image.Image], List[Image.Image]],
124
  Tuple[List[np.ndarray], List[np.ndarray]]
125
  ]:
 
 
 
126
 
127
  if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
128
  raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
129
  if isinstance(mesh, trimesh.Trimesh):
130
- mesh = trimesh.Scene(mesh)
131
-
132
- scene = pyrender.Scene.from_trimesh_scene(mesh)
 
 
 
 
 
 
133
  light = pyrender.DirectionalLight(
134
  color=np.ones(3),
135
  intensity=light_intensity
@@ -149,9 +209,9 @@ def render_views_around_mesh(
149
  )
150
 
151
  images, depths = [], []
152
- for pose in camera_poses:
153
  image, depth = render(
154
- scene, renderer, camera, pose, light,
155
  normalize_depth=normalize_depth,
156
  flags=flags,
157
  return_type=return_type
 
10
  from diffusers.utils.loading_utils import load_video
11
  import torch
12
  from torchvision.utils import make_grid
13
+ import math
14
 
15
  os.environ['PYOPENGL_PLATFORM'] = 'egl'
16
 
17
def explode_mesh(mesh, explosion_scale=0.4):
    """Build an exploded view by pushing each part away from the common center.

    A part's center is the mean of its vertices after applying the transform
    stored for it in the scene graph. Every part is translated by
    ``explosion_scale`` along the unit vector that points from the mean of
    all part centers to that part's own center.

    Args:
        mesh: A ``trimesh.Trimesh`` or ``trimesh.Scene``.
        explosion_scale: Length of the translation applied to each part.
            Because the offset is ``direction * explosion_scale``, a negative
            value moves parts toward the center instead.

    Returns:
        A new ``trimesh.Scene`` that reuses the input geometry objects under
        shifted transforms. When the scene holds at most one geometry, the
        input scene (or the single mesh wrapped in a scene) is returned as is.

    Raises:
        ValueError: If ``mesh`` is neither a ``Trimesh`` nor a ``Scene``.
    """
    # Normalize the input to a Scene.
    if isinstance(mesh, trimesh.Trimesh):
        scene = trimesh.Scene(mesh)
    elif isinstance(mesh, trimesh.Scene):
        scene = mesh
    else:
        raise ValueError(f"Expected Trimesh or Scene, got {type(mesh)}")

    if len(scene.geometry) <= 1:
        print("Nothing to explode")
        return scene

    # Collect (name, geometry, transform, center) once per part so the graph
    # is not queried a second time when the offsets are applied.
    # graph.get() returns the 4x4 transform plus the geometry name attached to
    # that node; only the transform is needed here.
    # NOTE(review): the graph is queried by geometry name, which assumes every
    # geometry hangs off a node of the same name - confirm for loaded files.
    parts = []
    for name, geom in scene.geometry.items():
        world_tf, _ = scene.graph.get(name)
        pts = trimesh.transformations.transform_points(geom.vertices, world_tf)
        parts.append((name, geom, world_tf, pts.mean(axis=0)))

    global_center = np.mean([center for _, _, _, center in parts], axis=0)

    exploded = trimesh.Scene()
    for index, (name, geom, world_tf, center) in enumerate(parts):
        dir_vec = center - global_center
        norm = np.linalg.norm(dir_vec)
        if norm < 1e-6:
            # The part sits on the global center, so it has no natural
            # direction. Seed the fallback with the part index so repeated
            # calls (e.g. one per animation frame) move it the same way
            # every time instead of making it jump between frames.
            dir_vec = np.random.RandomState(index).randn(3)
            norm = np.linalg.norm(dir_vec)

        offset = dir_vec / norm * explosion_scale

        # Copy before editing so the source scene's transform stays untouched;
        # only the translation column changes.
        shifted_tf = world_tf.copy()
        shifted_tf[:3, 3] += offset

        exploded.add_geometry(geom, transform=shifted_tf, geom_name=name)
        print(f"[explode] {name} moved by {np.linalg.norm(offset):.4f}")

    return exploded
64
+
65
+
66
+
67
  def render(
68
  scene: pyrender.Scene,
69
  renderer: pyrender.Renderer,
 
174
  Tuple[List[Image.Image], List[Image.Image]],
175
  Tuple[List[np.ndarray], List[np.ndarray]]
176
  ]:
177
+
178
+ meshes = []
179
+ scenes = []
180
 
181
  if not isinstance(mesh, (trimesh.Trimesh, trimesh.Scene)):
182
  raise ValueError("mesh must be a trimesh.Trimesh or trimesh.Scene object")
183
  if isinstance(mesh, trimesh.Trimesh):
184
+ for i in range(num_views):
185
+ scenes.append(pyrender.Scene.from_trimesh_scene(trimesh.Scene(mesh)))
186
+ else:
187
+ for i in range(num_views):
188
+ value = math.sin(math.pi * (i - 1) / num_views)
189
+ scenes.append(pyrender.Scene.from_trimesh_scene(explode_mesh(mesh, 0.2 * value),
190
+ ambient_light=[0.02, 0.02, 0.02],
191
+ bg_color=[0.0, 0.0, 0.0, 1.0]))
192
+
193
  light = pyrender.DirectionalLight(
194
  color=np.ones(3),
195
  intensity=light_intensity
 
209
  )
210
 
211
  images, depths = [], []
212
+ for i, pose in enumerate(camera_poses):
213
  image, depth = render(
214
+ scenes[i], renderer, camera, pose, light,
215
  normalize_depth=normalize_depth,
216
  flags=flags,
217
  return_type=return_type