update
- .gitmodules +3 -0
- app.py +153 -70
- demo.py +88 -21
- models/pipelines.py +327 -122
.gitmodules
CHANGED
@@ -1,3 +1,6 @@
 [submodule "submodules/MoGe"]
 	path = submodules/MoGe
 	url = https://github.com/microsoft/MoGe.git
+[submodule "submodules/vggt"]
+	path = submodules/vggt
+	url = https://github.com/facebookresearch/vggt.git
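The vggt submodule added above has to be present locally for the new imports in app.py and demo.py to resolve. A minimal sketch, assuming git is on PATH and the repository was cloned normally (this helper is not part of the commit):

# Hypothetical helper: fetch/refresh all submodules, including the newly added
# submodules/vggt, before launching the app. Equivalent to running
# `git submodule update --init --recursive` in the repository root.
import subprocess

subprocess.run(["git", "submodule", "update", "--init", "--recursive"], check=True)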
app.py
CHANGED
@@ -16,6 +16,7 @@ sys.path.append(project_root)
 
 try:
     sys.path.append(os.path.join(project_root, "submodules/MoGe"))
+    sys.path.append(os.path.join(project_root, "submodules/vggt"))
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 except:
     print("Warning: MoGe not found, motion transfer will not be applied")
@@ -27,6 +28,8 @@ hf_hub_download(repo_id="EXCAI/Diffusion-As-Shader", filename='spatracker/spaT_f
 
 from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
 from submodules.MoGe.moge.model import MoGeModel
+from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from submodules.vggt.vggt.models.vggt import VGGT
 
 # Parse command line arguments
 parser = argparse.ArgumentParser(description="Diffusion as Shader Web UI")
@@ -47,6 +50,7 @@ os.makedirs("outputs", exist_ok=True)
 # Create project tmp directory instead of using system temp
 os.makedirs(os.path.join(project_root, "tmp"), exist_ok=True)
 os.makedirs(os.path.join(project_root, "tmp", "gradio"), exist_ok=True)
+
 def load_media(media_path, max_frames=49, transform=None):
     """Load video or image frames and convert to tensor
 
@@ -69,22 +73,52 @@ def load_media(media_path, max_frames=49, transform=None):
     is_video = ext in ['.mp4', '.avi', '.mov']
 
     if is_video:
-
-
+        # Load video file info
+        video_clip = VideoFileClip(media_path)
+        duration = video_clip.duration
+        original_fps = video_clip.fps
+
+        # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
+        if duration > 6.0:
+            sampling_fps = 8  # 8 frames per second
+            frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
+            fps = sampling_fps
+        # Cases 2 and 3: Video shorter than 6 seconds
+        else:
+            # Load all frames
+            frames = load_video(media_path)
+
+            # Case 2: Total frames less than max_frames, need interpolation
+            if len(frames) < max_frames:
+                fps = len(frames) / duration  # Keep original fps
+
+                # Evenly interpolate to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+            # Case 3: Total frames more than max_frames but video less than 6 seconds
+            else:
+                # Evenly sample to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+                fps = max_frames / duration  # New fps to maintain duration
     else:
         # Handle image as single frame
         image = load_image(media_path)
         frames = [image]
         fps = 8  # Default fps for images
-
-
-    if len(frames) > max_frames:
-        frames = frames[:max_frames]
-    elif len(frames) < max_frames:
-        last_frame = frames[-1]
+
+    # Duplicate frame to max_frames
     while len(frames) < max_frames:
-        frames.append(
-
+        frames.append(frames[0].copy())
+
     # Convert frames to tensor
     video_tensor = torch.stack([transform(frame) for frame in frames])
 
@@ -131,6 +165,7 @@ def save_uploaded_file(file):
 
 das_pipeline = None
 moge_model = None
+vggt_model = None
 
 @spaces.GPU
 def get_das_pipeline():
@@ -147,6 +182,13 @@ def get_moge_model():
         moge_model = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
     return moge_model
 
+@spaces.GPU
+def get_vggt_model():
+    global vggt_model
+    if vggt_model is None:
+        das = get_das_pipeline()
+        vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
+    return vggt_model
 
 def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image):
     """Process video motion transfer task"""
@@ -154,19 +196,20 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None
+            return None, None
 
         print(f"DEBUG: Repaint option: {mt_repaint_option}")
         print(f"DEBUG: Repaint image: {mt_repaint_image}")
 
-
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
         if not is_video:
             tracking_method = "moge"
             print("Image input detected, using MoGe for tracking video generation.")
         else:
-            tracking_method = "
+            tracking_method = "cotracker"
 
         repaint_img_tensor = None
         if mt_repaint_image is not None:
@@ -180,7 +223,9 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
                 prompt=prompt,
                 depth_path=None
             )
+
         tracking_tensor = None
+        tracking_path = None
         if tracking_method == "moge":
             moge = get_moge_model()
             infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
@@ -195,32 +240,31 @@ def process_motion_transfer(source, prompt, mt_repaint_option, mt_repaint_image)
 
             pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
 
-
+            tracking_path, tracking_tensor = das.visualize_tracking_moge(
                 pred_tracks.cpu().numpy(),
                 infer_result["mask"].cpu().numpy()
             )
             print('Export tracking video via MoGe')
         else:
-
-
-
-            print('Export tracking video via 
+            # use cotracker
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
+            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
+            print('Export tracking video via cotracker')
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
            prompt=prompt,
            checkpoint_path=DEFAULT_MODEL_PATH
        )
 
-        return output_path
+        return tracking_path, output_path
    except Exception as e:
        import traceback
        print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
-
+        return None, None
 
 def process_camera_control(source, prompt, camera_motion, tracking_method):
     """Process camera control task"""
@@ -228,17 +272,18 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
         # Save uploaded files
         input_media_path = save_uploaded_file(source)
         if input_media_path is None:
-            return None
+            return None, None
 
         print(f"DEBUG: Camera motion: '{camera_motion}'")
         print(f"DEBUG: Tracking method: '{tracking_method}'")
 
         das = get_das_pipeline()
-
         video_tensor, fps, is_video = load_media(input_media_path)
-
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
+        if not is_video:
             tracking_method = "moge"
-            print("Image input detected
+            print("Image input detected, switching to MoGe")
 
         cam_motion = CameraMotionGenerator(camera_motion)
         repaint_img_tensor = None
@@ -267,32 +312,54 @@ def process_camera_control(source, prompt, camera_motion, tracking_method):
             )
             print('Export tracking video via MoGe')
         else:
-
-            pred_tracks, pred_visibility 
+            # use cotracker
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
+
+            t, c, h, w = video_tensor.shape
+            new_width = 518
+            new_height = round(h * (new_width / w) / 14) * 14
+            resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+            video_vggt = resize_transform(video_tensor) # [T, C, H, W]
+
+            if new_height > 518:
+                start_y = (new_height - 518) // 2
+                video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+            vggt_model = get_vggt_model()
+
+            with torch.no_grad():
+                with torch.cuda.amp.autocast(dtype=das.dtype):
+                    video_vggt = video_vggt.unsqueeze(0) # [1, T, C, H, W]
+                    aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
+
+                    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+
+            cam_motion.set_intr(intr)
+            cam_motion.set_extr(extr)
+
             if camera_motion:
                 poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
-
+                pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
+                pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses) # [T, N, 3]
                 print("Camera motion applied")
-
-
-            print('Export tracking video via 
-
+
+            tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, None)
+            print('Export tracking video via cotracker')
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
             prompt=prompt,
             checkpoint_path=DEFAULT_MODEL_PATH
         )
 
-        return output_path
+        return tracking_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
-
+        return None, None
 
 def process_object_manipulation(source, prompt, object_motion, object_mask, tracking_method):
     """Process object manipulation task"""
@@ -300,21 +367,21 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         # Save uploaded files
         input_image_path = save_uploaded_file(source)
         if input_image_path is None:
-            return None
+            return None, None
 
         object_mask_path = save_uploaded_file(object_mask)
         if object_mask_path is None:
             print("Object mask not provided")
-            return None
-
+            return None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_image_path)
-
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
+        if not is_video:
             tracking_method = "moge"
-            print("Image input detected
+            print("Image input detected, switching to MoGe")
 
-
         mask_image = Image.open(object_mask_path).convert('L')
         mask_image = transforms.Resize((480, 720))(mask_image)
         mask = torch.from_numpy(np.array(mask_image) > 127)
@@ -322,10 +389,10 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
         motion_generator = ObjectMotionGenerator(device=das.device)
         repaint_img_tensor = None
         tracking_tensor = None
+
         if tracking_method == "moge":
             moge = get_moge_model()
 
-
             infer_result = moge.infer(video_tensor[0].to(das.device)) # [C, H, W] in range [0,1]
             H, W = infer_result["points"].shape[0:2]
             pred_tracks = infer_result["points"].unsqueeze(0).repeat(49, 1, 1, 1) #[T, H, W, 3]
@@ -342,7 +409,6 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
             poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
             pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
 
-
             cam_motion = CameraMotionGenerator(None)
             cam_motion.set_intr(infer_result["intrinsics"])
             pred_tracks = cam_motion.w2s(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
@@ -353,9 +419,27 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
             )
             print('Export tracking video via MoGe')
         else:
+            # use cotracker
+            pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor)
 
-
+            t, c, h, w = video_tensor.shape
+            new_width = 518
+            new_height = round(h * (new_width / w) / 14) * 14
+            resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+            video_vggt = resize_transform(video_tensor) # [T, C, H, W]
 
+            if new_height > 518:
+                start_y = (new_height - 518) // 2
+                video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+            vggt_model = get_vggt_model()
+
+            with torch.no_grad():
+                with torch.cuda.amp.autocast(dtype=das.dtype):
+                    video_vggt = video_vggt.unsqueeze(0) # [1, T, C, H, W]
+                    aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
+
+                    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
 
         pred_tracks = motion_generator.apply_motion(
             pred_tracks=pred_tracks.squeeze(),
@@ -363,30 +447,27 @@ def process_object_manipulation(source, prompt, object_motion, object_mask, trac
             motion_type=object_motion,
             distance=50,
             num_frames=49,
-            tracking_method="
-        )
+            tracking_method="cotracker"
+        )
         print(f"Object motion '{object_motion}' applied using provided mask")
 
-
-
-        print('Export tracking video via SpaTracker')
-
+        tracking_path, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks.unsqueeze(0), None)
+        print('Export tracking video via cotracker')
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
             prompt=prompt,
             checkpoint_path=DEFAULT_MODEL_PATH
         )
 
-        return output_path
+        return tracking_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
-
+        return None, None
 
 def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma_repaint_image):
     """Process mesh animation task"""
@@ -394,15 +475,16 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
         # Save uploaded files
         input_video_path = save_uploaded_file(source)
         if input_video_path is None:
-            return None
+            return None, None
 
         tracking_video_path = save_uploaded_file(tracking_video)
         if tracking_video_path is None:
-            return None
-
+            return None, None
 
         das = get_das_pipeline()
         video_tensor, fps, is_video = load_media(input_video_path)
+        das.fps = fps  # set das.fps to the fps returned by load_media
+
         tracking_tensor, tracking_fps, _ = load_media(tracking_video_path)
         repaint_img_tensor = None
         if ma_repaint_image is not None:
@@ -420,18 +502,18 @@ def process_mesh_animation(source, prompt, tracking_video, ma_repaint_option, ma
 
         output_path = das.apply_tracking(
             video_tensor=video_tensor,
-            fps=
+            fps=fps,  # use the fps returned by load_media
             tracking_tensor=tracking_tensor,
             img_cond_tensor=repaint_img_tensor,
             prompt=prompt,
             checkpoint_path=DEFAULT_MODEL_PATH
         )
 
-        return output_path
+        return tracking_video_path, output_path
     except Exception as e:
         import traceback
         print(f"Processing failed: {str(e)}\n{traceback.format_exc()}")
-        return None
+        return None, None
 
 # Create Gradio interface with updated layout
 with gr.Blocks(title="Diffusion as Shader") as demo:
@@ -444,6 +526,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
 
     with right_column:
         output_video = gr.Video(label="Generated Video")
+        tracking_video = gr.Video(label="Tracking Video")
 
     with left_column:
         source = gr.File(label="Source", file_types=["image", "video"])
@@ -479,7 +562,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                     source, common_prompt,
                     mt_repaint_option, mt_repaint_image
                 ],
-                outputs=[output_video]
+                outputs=[tracking_video, output_video]
             )
 
         # Camera Control tab
@@ -597,8 +680,8 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
 
             cc_tracking_method = gr.Radio(
                 label="Tracking Method",
-                choices=["
-                value="
+                choices=["moge", "cotracker"],
+                value="cotracker"
             )
 
            # Add run button for Camera Control tab
@@ -611,7 +694,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                    source, common_prompt,
                    cc_camera_motion, cc_tracking_method
                ],
-                outputs=[output_video]
+                outputs=[tracking_video, output_video]
            )
 
        # Object Manipulation tab
@@ -629,8 +712,8 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
            )
            om_tracking_method = gr.Radio(
                label="Tracking Method",
-                choices=["
-                value="
+                choices=["moge", "cotracker"],
+                value="cotracker"
            )
 
            # Add run button for Object Manipulation tab
@@ -643,7 +726,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                    source, common_prompt,
                    om_object_motion, om_object_mask, om_tracking_method
                ],
-                outputs=[output_video]
+                outputs=[tracking_video, output_video]
            )
 
        # Animating meshes to video tab
@@ -683,7 +766,7 @@ with gr.Blocks(title="Diffusion as Shader") as demo:
                    source, common_prompt,
                    ma_tracking_video, ma_repaint_option, ma_repaint_image
                ],
-                outputs=[output_video]
+                outputs=[tracking_video, output_video]
            )
 
    # Launch interface
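For reference, a minimal sketch of the frame-sampling policy the reworked load_media() implements, isolated from file I/O so the three cases can be checked directly. plan_sampling, duration and n_frames are illustrative names, not part of the commit:

import numpy as np

def plan_sampling(duration, n_frames, max_frames=49):
    """Return (frame_indices, fps) following the three cases in load_media()."""
    if duration > 6.0:
        # Case 1: long clip -> let the loader resample at a fixed 8 fps
        return None, 8
    if n_frames < max_frames:
        fps = n_frames / duration            # Case 2: keep the original fps
    else:
        fps = max_frames / duration          # Case 3: new fps preserves duration
    # Both short-clip cases index evenly into the decoded frames
    indices = np.linspace(0, n_frames - 1, max_frames).astype(int)
    return indices, fps

print(plan_sampling(3.0, 30))    # Case 2: indices repeat frames, fps = 10.0
print(plan_sampling(4.0, 100))   # Case 3: even subsampling, fps = 12.25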
demo.py
CHANGED
@@ -5,6 +5,7 @@ from PIL import Image
 project_root = os.path.dirname(os.path.abspath(__file__))
 try:
     sys.path.append(os.path.join(project_root, "submodules/MoGe"))
+    sys.path.append(os.path.join(project_root, "submodules/vggt"))
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 except:
     print("Warning: MoGe not found, motion transfer will not be applied")
@@ -18,6 +19,8 @@ from diffusers.utils import load_image, load_video
 
 from models.pipelines import DiffusionAsShaderPipeline, FirstFrameRepainter, CameraMotionGenerator, ObjectMotionGenerator
 from submodules.MoGe.moge.model import MoGeModel
+from submodules.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
+from submodules.vggt.vggt.models.vggt import VGGT
 
 def load_media(media_path, max_frames=49, transform=None):
     """Load video or image frames and convert to tensor
@@ -28,7 +31,7 @@ def load_media(media_path, max_frames=49, transform=None):
         transform (callable): Transform to apply to frames
 
     Returns:
-        Tuple[torch.Tensor, float]: Video tensor [T,C,H,W] and 
+        Tuple[torch.Tensor, float, bool]: Video tensor [T,C,H,W], FPS, and is_video flag
     """
     if transform is None:
         transform = transforms.Compose([
@@ -41,22 +44,52 @@ def load_media(media_path, max_frames=49, transform=None):
     is_video = ext in ['.mp4', '.avi', '.mov']
 
     if is_video:
-
-
+        # Load video file info
+        video_clip = VideoFileClip(media_path)
+        duration = video_clip.duration
+        original_fps = video_clip.fps
+
+        # Case 1: Video longer than 6 seconds, sample first 6 seconds + 1 frame
+        if duration > 6.0:
+            sampling_fps = 8  # 8 frames per second
+            frames = load_video(media_path, sampling_fps=sampling_fps, max_frames=max_frames)
+            fps = sampling_fps
+        # Cases 2 and 3: Video shorter than 6 seconds
+        else:
+            # Load all frames
+            frames = load_video(media_path)
+
+            # Case 2: Total frames less than max_frames, need interpolation
+            if len(frames) < max_frames:
+                fps = len(frames) / duration  # Keep original fps
+
+                # Evenly interpolate to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+            # Case 3: Total frames more than max_frames but video less than 6 seconds
+            else:
+                # Evenly sample to max_frames
+                indices = np.linspace(0, len(frames) - 1, max_frames)
+                new_frames = []
+                for i in indices:
+                    idx = int(i)
+                    new_frames.append(frames[idx])
+                frames = new_frames
+                fps = max_frames / duration  # New fps to maintain duration
     else:
         # Handle image as single frame
         image = load_image(media_path)
         frames = [image]
         fps = 8  # Default fps for images
-
-
-    if len(frames) > max_frames:
-        frames = frames[:max_frames]
-    elif len(frames) < max_frames:
-        last_frame = frames[-1]
+
+    # Duplicate frame to max_frames
     while len(frames) < max_frames:
-        frames.append(
-
+        frames.append(frames[0].copy())
+
     # Convert frames to tensor
     video_tensor = torch.stack([transform(frame) for frame in frames])
 
@@ -77,8 +110,8 @@ if __name__ == "__main__":
                         help='Camera motion mode: "trans <dx> <dy> <dz>" or "rot <axis> <angle>" or "spiral <radius>"')
     parser.add_argument('--object_motion', type=str, default=None, help='Object motion mode: up/down/left/right')
     parser.add_argument('--object_mask', type=str, default=None, help='Path to object mask image (binary image)')
-    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge'],
-                        help='Tracking method to use (spatracker or moge)')
+    parser.add_argument('--tracking_method', type=str, default='spatracker', choices=['spatracker', 'moge', 'cotracker'],
+                        help='Tracking method to use (spatracker, cotracker or moge)')
     args = parser.parse_args()
 
     # Load input video/image
@@ -89,6 +122,7 @@ if __name__ == "__main__":
 
     # Initialize pipeline
     das = DiffusionAsShaderPipeline(gpu_id=args.gpu, output_dir=args.output_dir)
+    das.fps = fps
     if args.tracking_method == "moge" and args.tracking_path is None:
         moge = MoGeModel.from_pretrained("Ruicheng/moge-vitl").to(das.device)
 
@@ -153,7 +187,7 @@ if __name__ == "__main__":
             poses = torch.eye(4).unsqueeze(0).repeat(49, 1, 1)
             # change pred_tracks into screen coordinate
             pred_tracks_flatten = pred_tracks.reshape(video_tensor.shape[0], H*W, 3)
-            pred_tracks = cam_motion.
+            pred_tracks = cam_motion.w2s_moge(pred_tracks_flatten, poses).reshape([video_tensor.shape[0], H, W, 3]) # [T, H, W, 3]
             _, tracking_tensor = das.visualize_tracking_moge(
                 pred_tracks.cpu().numpy(),
                 infer_result["mask"].cpu().numpy()
@@ -161,13 +195,44 @@ if __name__ == "__main__":
             print('export tracking video via MoGe.')
 
         else:
-
-
+
+            if args.tracking_method == "cotracker":
+                pred_tracks, pred_visibility = das.generate_tracking_cotracker(video_tensor) # T N 3, T N
+            else:
+                pred_tracks, pred_visibility, T_Firsts = das.generate_tracking_spatracker(video_tensor) # T N 3, T N, B N
+
+            # Preprocess video tensor to match VGGT requirements
+            t, c, h, w = video_tensor.shape
+            new_width = 518
+            new_height = round(h * (new_width / w) / 14) * 14
+            resize_transform = transforms.Resize((new_height, new_width), interpolation=Image.BICUBIC)
+            video_vggt = resize_transform(video_tensor) # [T, C, H, W]
+
+            if new_height > 518:
+                start_y = (new_height - 518) // 2
+                video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
+
+            # Get extrinsic and intrinsic matrices
+            vggt_model = VGGT.from_pretrained("facebook/VGGT-1B").to(das.device)
+
+            with torch.no_grad():
+                with torch.cuda.amp.autocast(dtype=das.dtype):
+
+                    video_vggt = video_vggt.unsqueeze(0) # [1, T, C, H, W]
+                    aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_vggt.to(das.device))
+
+                    # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
+                    extr, intr = pose_encoding_to_extri_intri(vggt_model.camera_head(aggregated_tokens_list)[-1], video_vggt.shape[-2:])
+                    depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_vggt, ps_idx)
+
+            cam_motion.set_intr(intr)
+            cam_motion.set_extr(extr)
 
             # Apply camera motion if specified
             if args.camera_motion:
                 poses = cam_motion.get_default_motion() # shape: [49, 4, 4]
-
+                pred_tracks_world = cam_motion.s2w_vggt(pred_tracks, extr, intr)
+                pred_tracks = cam_motion.w2s_vggt(pred_tracks_world, extr, intr, poses) # [T, N, 3]
                 print("Camera motion applied")
 
     # Apply object motion if specified
@@ -184,7 +249,7 @@ if __name__ == "__main__":
         motion_generator = ObjectMotionGenerator(device=das.device)
 
         pred_tracks = motion_generator.apply_motion(
-            pred_tracks=pred_tracks
+            pred_tracks=pred_tracks,
             mask=mask,
             motion_type=args.object_motion,
             distance=50,
@@ -193,12 +258,14 @@ if __name__ == "__main__":
         ).unsqueeze(0)
         print(f"Object motion '{args.object_motion}' applied using mask from {args.object_mask}")
 
-
-
+        if args.tracking_method == "cotracker":
+            _, tracking_tensor = das.visualize_tracking_cotracker(pred_tracks, pred_visibility)
+        else:
+            _, tracking_tensor = das.visualize_tracking_spatracker(video_tensor, pred_tracks, pred_visibility, T_Firsts)
 
     das.apply_tracking(
         video_tensor=video_tensor,
-        fps=
+        fps=fps,
         tracking_tensor=tracking_tensor,
         img_cond_tensor=repaint_img_tensor,
         prompt=args.prompt,
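Both app.py and demo.py resize the clip before handing it to VGGT. A shape-only sketch of that preprocessing, with no model loaded; torchvision's InterpolationMode is used here instead of the PIL constant the diff passes, and the tensor size is illustrative:

import torch
from torchvision import transforms

video_tensor = torch.rand(49, 3, 480, 720)   # [T, C, H, W] stand-in for the loaded clip
t, c, h, w = video_tensor.shape
new_width = 518
new_height = round(h * (new_width / w) / 14) * 14   # keep height a multiple of the 14-px patch size
resize = transforms.Resize((new_height, new_width),
                           interpolation=transforms.InterpolationMode.BICUBIC)
video_vggt = resize(video_tensor)
if new_height > 518:                          # center-crop overly tall frames back to 518
    start_y = (new_height - 518) // 2
    video_vggt = video_vggt[:, :, start_y:start_y + 518, :]
print(video_vggt.shape)                       # torch.Size([49, 3, 350, 518]) for 480x720 input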
models/pipelines.py
CHANGED
@@ -22,9 +22,9 @@ from models.spatracker.utils.visualizer import Visualizer
|
|
22 |
from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
|
23 |
|
24 |
from submodules.MoGe.moge.model import MoGeModel
|
|
|
25 |
from image_gen_aux import DepthPreprocessor
|
26 |
from moviepy.editor import ImageSequenceClip
|
27 |
-
import spaces
|
28 |
|
29 |
class DiffusionAsShaderPipeline:
|
30 |
def __init__(self, gpu_id=0, output_dir='outputs'):
|
@@ -45,6 +45,7 @@ class DiffusionAsShaderPipeline:
|
|
45 |
# device
|
46 |
self.device = f"cuda:{gpu_id}"
|
47 |
torch.cuda.set_device(gpu_id)
|
|
|
48 |
|
49 |
# files
|
50 |
self.output_dir = output_dir
|
@@ -56,7 +57,6 @@ class DiffusionAsShaderPipeline:
|
|
56 |
transforms.ToTensor()
|
57 |
])
|
58 |
|
59 |
-
@spaces.GPU(duration=240)
|
60 |
@torch.no_grad()
|
61 |
def _infer(
|
62 |
self,
|
@@ -65,7 +65,7 @@ class DiffusionAsShaderPipeline:
|
|
65 |
tracking_tensor: torch.Tensor = None,
|
66 |
image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
|
67 |
output_path: str = "./output.mp4",
|
68 |
-
num_inference_steps: int =
|
69 |
guidance_scale: float = 6.0,
|
70 |
num_videos_per_prompt: int = 1,
|
71 |
dtype: torch.dtype = torch.bfloat16,
|
@@ -114,6 +114,8 @@ class DiffusionAsShaderPipeline:
|
|
114 |
pipe.text_encoder.eval()
|
115 |
pipe.vae.eval()
|
116 |
|
|
|
|
|
117 |
# Process tracking tensor
|
118 |
tracking_maps = tracking_tensor.float() # [T, C, H, W]
|
119 |
tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
|
@@ -167,60 +169,9 @@ class DiffusionAsShaderPipeline:
|
|
167 |
|
168 |
def _set_camera_motion(self, camera_motion):
|
169 |
self.camera_motion = camera_motion
|
170 |
-
|
171 |
-
def _get_intr(self, fov, H=480, W=720):
|
172 |
-
fov_rad = math.radians(fov)
|
173 |
-
focal_length = (W / 2) / math.tan(fov_rad / 2)
|
174 |
-
|
175 |
-
cx = W / 2
|
176 |
-
cy = H / 2
|
177 |
-
|
178 |
-
intr = torch.tensor([
|
179 |
-
[focal_length, 0, cx],
|
180 |
-
[0, focal_length, cy],
|
181 |
-
[0, 0, 1]
|
182 |
-
], dtype=torch.float32)
|
183 |
-
|
184 |
-
return intr
|
185 |
-
|
186 |
-
@spaces.GPU
|
187 |
-
def _apply_poses(self, pts, intr, poses):
|
188 |
-
"""
|
189 |
-
Args:
|
190 |
-
pts (torch.Tensor): pointclouds coordinates [T, N, 3]
|
191 |
-
intr (torch.Tensor): camera intrinsics [T, 3, 3]
|
192 |
-
poses (numpy.ndarray): camera poses [T, 4, 4]
|
193 |
-
"""
|
194 |
-
poses = torch.from_numpy(poses).float().to(self.device)
|
195 |
-
|
196 |
-
T, N, _ = pts.shape
|
197 |
-
ones = torch.ones(T, N, 1, device=self.device, dtype=torch.float)
|
198 |
-
pts_hom = torch.cat([pts[:, :, :2], ones], dim=-1) # (T, N, 3)
|
199 |
-
pts_cam = torch.bmm(pts_hom, torch.linalg.inv(intr).transpose(1, 2)) # (T, N, 3)
|
200 |
-
pts_cam[:,:, :3] /= pts[:, :, 2:3]
|
201 |
-
|
202 |
-
# to homogeneous
|
203 |
-
pts_cam = torch.cat([pts_cam, ones], dim=-1) # (T, N, 4)
|
204 |
-
|
205 |
-
if poses.shape[0] == 1:
|
206 |
-
poses = poses.repeat(T, 1, 1)
|
207 |
-
elif poses.shape[0] != T:
|
208 |
-
raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
|
209 |
-
|
210 |
-
pts_world = torch.bmm(pts_cam, poses.transpose(1, 2))[:, :, :3] # (T, N, 3)
|
211 |
-
|
212 |
-
pts_proj = torch.bmm(pts_world, intr.transpose(1, 2)) # (T, N, 3)
|
213 |
-
pts_proj[:, :, :2] /= pts_proj[:, :, 2:3]
|
214 |
-
|
215 |
-
return pts_proj
|
216 |
-
|
217 |
-
def apply_traj_on_tracking(self, pred_tracks, camera_motion=None, fov=55, frame_num=49):
|
218 |
-
intr = self._get_intr(fov).unsqueeze(0).repeat(frame_num, 1, 1).to(self.device)
|
219 |
-
tracking_pts = self._apply_poses(pred_tracks.squeeze(), intr, camera_motion).unsqueeze(0)
|
220 |
-
return tracking_pts
|
221 |
|
222 |
##============= SpatialTracker =============##
|
223 |
-
|
224 |
def generate_tracking_spatracker(self, video_tensor, density=70):
|
225 |
"""Generate tracking video
|
226 |
|
@@ -233,7 +184,7 @@ class DiffusionAsShaderPipeline:
|
|
233 |
print("Loading tracking models...")
|
234 |
# Load tracking model
|
235 |
tracker = SpaTrackerPredictor(
|
236 |
-
checkpoint=os.path.join(project_root, 'checkpoints/
|
237 |
interp_shape=(384, 576),
|
238 |
seq_length=12
|
239 |
).to(self.device)
|
@@ -268,14 +219,13 @@ class DiffusionAsShaderPipeline:
|
|
268 |
progressive_tracking=False
|
269 |
)
|
270 |
|
271 |
-
return pred_tracks, pred_visibility, T_Firsts
|
272 |
|
273 |
finally:
|
274 |
# Clean up GPU memory
|
275 |
del tracker, self.depth_preprocessor
|
276 |
torch.cuda.empty_cache()
|
277 |
|
278 |
-
@spaces.GPU
|
279 |
def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
|
280 |
video = video.unsqueeze(0).to(self.device)
|
281 |
vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
|
@@ -365,7 +315,6 @@ class DiffusionAsShaderPipeline:
|
|
365 |
outline=tuple(color),
|
366 |
)
|
367 |
|
368 |
-
@spaces.GPU
|
369 |
def visualize_tracking_moge(self, points, mask, save_tracking=True):
|
370 |
"""Visualize tracking results from MoGe model
|
371 |
|
@@ -399,8 +348,6 @@ class DiffusionAsShaderPipeline:
|
|
399 |
normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
|
400 |
colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
|
401 |
colors = colors.astype(np.uint8)
|
402 |
-
# colors = colors * mask[..., None]
|
403 |
-
# points = points * mask[None, :, :, None]
|
404 |
|
405 |
points = points.reshape(T, -1, 3)
|
406 |
colors = colors.reshape(-1, 3)
|
@@ -408,7 +355,7 @@ class DiffusionAsShaderPipeline:
|
|
408 |
# Initialize list to store frames
|
409 |
frames = []
|
410 |
|
411 |
-
for i, pts_i in enumerate(tqdm(points)):
|
412 |
pixels, depths = pts_i[..., :2], pts_i[..., 2]
|
413 |
pixels[..., 0] = pixels[..., 0] * W
|
414 |
pixels[..., 1] = pixels[..., 1] * H
|
@@ -451,8 +398,178 @@ class DiffusionAsShaderPipeline:
|
|
451 |
tracking_path = None
|
452 |
|
453 |
return tracking_path, tracking_video
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
454 |
|
455 |
-
@spaces.GPU(duration=240)
|
456 |
def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
|
457 |
"""Generate final video with motion transfer
|
458 |
|
@@ -478,7 +595,7 @@ class DiffusionAsShaderPipeline:
|
|
478 |
tracking_tensor=tracking_tensor,
|
479 |
image_tensor=img_cond_tensor,
|
480 |
output_path=final_output,
|
481 |
-
num_inference_steps=
|
482 |
guidance_scale=6.0,
|
483 |
dtype=torch.bfloat16,
|
484 |
fps=self.fps
|
@@ -493,7 +610,6 @@ class DiffusionAsShaderPipeline:
|
|
493 |
"""
|
494 |
self.object_motion = motion_type
|
495 |
|
496 |
-
@spaces.GPU(duration=120)
|
497 |
class FirstFrameRepainter:
|
498 |
def __init__(self, gpu_id=0, output_dir='outputs'):
|
499 |
"""Initialize FirstFrameRepainter
|
@@ -506,8 +622,7 @@ class FirstFrameRepainter:
|
|
506 |
self.output_dir = output_dir
|
507 |
self.max_depth = 65.0
|
508 |
os.makedirs(output_dir, exist_ok=True)
|
509 |
-
|
510 |
-
@spaces.GPU(duration=120)
|
511 |
def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
|
512 |
"""Repaint first frame using Flux
|
513 |
|
@@ -599,48 +714,158 @@ class CameraMotionGenerator:
|
|
599 |
fx = fy = (W / 2) / math.tan(fov_rad / 2)
|
600 |
|
601 |
self.intr[0, 0] = fx
|
602 |
-
self.intr[1, 1] = fy
|
|
|
|
|
603 |
|
604 |
-
def
|
605 |
"""
|
|
|
|
|
606 |
Args:
|
607 |
-
|
608 |
-
|
609 |
-
|
|
|
|
|
|
|
610 |
"""
|
611 |
-
if isinstance(
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
|
|
|
|
|
|
|
|
|
|
623 |
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
raise ValueError(f"Poses length ({poses.shape[0]}) must match sequence length ({T})")
|
628 |
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
633 |
|
634 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
635 |
|
636 |
-
def
|
637 |
if isinstance(poses, np.ndarray):
|
638 |
poses = torch.from_numpy(poses)
|
639 |
assert poses.shape[0] == self.frame_num
|
640 |
poses = poses.to(torch.float32).to(self.device)
|
641 |
T, N, _ = pts.shape # (T, N, 3)
|
642 |
intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
|
643 |
-
# Step 1: 扩展点的维度,使其变成 (T, N, 4),最后一维填充1 (齐次坐标)
|
644 |
ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
|
645 |
points_world_h = torch.cat([pts, ones], dim=-1)
|
646 |
points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))
|
@@ -649,22 +874,21 @@ class CameraMotionGenerator:
|
|
649 |
points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
|
650 |
|
651 |
uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
|
652 |
-
|
653 |
-
# Step 5: 提取深度 (Z) 并拼接
|
654 |
depth = points_camera[:, :, 2:3] # (T, N, 1)
|
655 |
uvd = torch.cat([uv, depth], dim=-1) # (T, N, 3)
|
656 |
|
657 |
-
return uvd
|
658 |
-
|
659 |
-
def apply_motion_on_pts(self, pts, camera_motion):
|
660 |
-
tracking_pts = self._apply_poses(pts.squeeze(), camera_motion).unsqueeze(0)
|
661 |
-
return tracking_pts
|
662 |
|
663 |
def set_intr(self, K):
|
664 |
if isinstance(K, np.ndarray):
|
665 |
K = torch.from_numpy(K)
|
666 |
self.intr = K.to(self.device)
|
667 |
|
|
|
|
|
|
|
|
|
|
|
668 |
def rot_poses(self, angle, axis='y'):
|
669 |
"""Generate a single rotation matrix
|
670 |
|
@@ -783,26 +1007,6 @@ class CameraMotionGenerator:
|
|
783 |
camera_poses = np.concatenate(cam_poses, axis=0)
|
784 |
return torch.from_numpy(camera_poses).to(self.device)
|
785 |
|
786 |
-
def rot(self, pts, angle, axis):
|
787 |
-
"""
|
788 |
-
pts: torch.Tensor, (T, N, 2)
|
789 |
-
"""
|
790 |
-
rot_mats = self.rot_poses(angle, axis)
|
791 |
-
pts = self.apply_motion_on_pts(pts, rot_mats)
|
792 |
-
return pts
|
793 |
-
|
794 |
-
def trans(self, pts, dx, dy, dz):
|
795 |
-
if pts.shape[-1] != 3:
|
796 |
-
raise ValueError("points should be in the 3d coordinate.")
|
797 |
-
trans_mats = self.trans_poses(dx, dy, dz)
|
798 |
-
pts = self.apply_motion_on_pts(pts, trans_mats)
|
799 |
-
return pts
|
800 |
-
|
801 |
-
def spiral(self, pts, radius):
|
802 |
-
spiral_poses = self.spiral_poses(radius)
|
803 |
-
pts = self.apply_motion_on_pts(pts, spiral_poses)
|
804 |
-
return pts
|
805 |
-
|
806 |
def get_default_motion(self):
|
807 |
"""Parse motion parameters and generate corresponding motion matrices
|
808 |
|
@@ -820,6 +1024,7 @@ class CameraMotionGenerator:
|
|
820 |
- if not specified, defaults to 0-49
|
821 |
- frames after end_frame will maintain the final transformation
|
822 |
- for combined transformations, they are applied in sequence
|
|
|
823 |
|
824 |
Returns:
|
825 |
torch.Tensor: Motion matrices [num_frames, 4, 4]
|
|
|
22 |
from models.cogvideox_tracking import CogVideoXImageToVideoPipelineTracking
|
23 |
|
24 |
from submodules.MoGe.moge.model import MoGeModel
|
25 |
+
|
26 |
from image_gen_aux import DepthPreprocessor
|
27 |
from moviepy.editor import ImageSequenceClip
|
|
|
28 |
|
29 |
class DiffusionAsShaderPipeline:
|
30 |
def __init__(self, gpu_id=0, output_dir='outputs'):
|
|
|
45 |
# device
|
46 |
self.device = f"cuda:{gpu_id}"
|
47 |
torch.cuda.set_device(gpu_id)
|
48 |
+
self.dtype = torch.bfloat16
|
49 |
|
50 |
# files
|
51 |
self.output_dir = output_dir
|
|
|
57 |
transforms.ToTensor()
|
58 |
])
|
59 |
|
|
|
60 |
@torch.no_grad()
|
61 |
def _infer(
|
62 |
self,
|
|
|
65 |
tracking_tensor: torch.Tensor = None,
|
66 |
image_tensor: torch.Tensor = None, # [C,H,W] in range [0,1]
|
67 |
output_path: str = "./output.mp4",
|
68 |
+
num_inference_steps: int = 25,
|
69 |
guidance_scale: float = 6.0,
|
70 |
num_videos_per_prompt: int = 1,
|
71 |
dtype: torch.dtype = torch.bfloat16,
|
|
|
114 |
pipe.text_encoder.eval()
|
115 |
pipe.vae.eval()
|
116 |
|
117 |
+
self.dtype = dtype
|
118 |
+
|
119 |
# Process tracking tensor
|
120 |
tracking_maps = tracking_tensor.float() # [T, C, H, W]
|
121 |
tracking_maps = tracking_maps.to(device=self.device, dtype=dtype)
|
|
|
169 |
|
170 |
def _set_camera_motion(self, camera_motion):
|
171 |
self.camera_motion = camera_motion
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
##============= SpatialTracker =============##
|
174 |
+
|
175 |
def generate_tracking_spatracker(self, video_tensor, density=70):
|
176 |
"""Generate tracking video
|
177 |
|
|
|
184 |
print("Loading tracking models...")
|
185 |
# Load tracking model
|
186 |
tracker = SpaTrackerPredictor(
|
187 |
+
checkpoint=os.path.join(project_root, 'checkpoints/spaT_final.pth'),
|
188 |
interp_shape=(384, 576),
|
189 |
seq_length=12
|
190 |
).to(self.device)
|
|
|
219 |
progressive_tracking=False
|
220 |
)
|
221 |
|
222 |
+
return pred_tracks.squeeze(0), pred_visibility.squeeze(0), T_Firsts
|
223 |
|
224 |
finally:
|
225 |
# Clean up GPU memory
|
226 |
del tracker, self.depth_preprocessor
|
227 |
torch.cuda.empty_cache()
|
228 |
|
|
|
229 |
def visualize_tracking_spatracker(self, video, pred_tracks, pred_visibility, T_Firsts, save_tracking=True):
|
230 |
video = video.unsqueeze(0).to(self.device)
|
231 |
vis = Visualizer(save_dir=self.output_dir, grayscale=False, fps=24, pad_value=0)
|
|
|
315 |
outline=tuple(color),
|
316 |
)
|
317 |
|
|
|
318 |
def visualize_tracking_moge(self, points, mask, save_tracking=True):
|
319 |
"""Visualize tracking results from MoGe model
|
320 |
|
|
|
348 |
normalized_z = np.clip((inv_z - p2) / (p98 - p2), 0, 1)
|
349 |
colors[:, :, 2] = (normalized_z * 255).astype(np.uint8)
|
350 |
colors = colors.astype(np.uint8)
|
|
|
|
|
351 |
|
352 |
points = points.reshape(T, -1, 3)
|
353 |
colors = colors.reshape(-1, 3)
|
|
|
355 |
# Initialize list to store frames
|
356 |
frames = []
|
357 |
|
358 |
+
for i, pts_i in enumerate(tqdm(points, desc="rendering frames")):
|
359 |
pixels, depths = pts_i[..., :2], pts_i[..., 2]
|
360 |
pixels[..., 0] = pixels[..., 0] * W
|
361 |
pixels[..., 1] = pixels[..., 1] * H
|
|
|
398 |
tracking_path = None
|
399 |
|
400 |
return tracking_path, tracking_video
|
401 |
+
|
402 |
+
|
 403 | +   ##============= CoTracker =============##
 404 | +
 405 | +   def generate_tracking_cotracker(self, video_tensor, density=70):
 406 | +       """Generate tracking video
 407 | +
 408 | +       Args:
 409 | +           video_tensor (torch.Tensor): Input video tensor
 410 | +
 411 | +       Returns:
 412 | +           tuple: (pred_tracks, pred_visibility)
 413 | +               - pred_tracks (torch.Tensor): Tracking points with depth [T, N, 3]
 414 | +               - pred_visibility (torch.Tensor): Visibility mask [T, N, 1]
 415 | +       """
 416 | +       # Generate tracking points
 417 | +       cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker3_offline").to(self.device)
 418 | +
 419 | +       # Load depth model
 420 | +       if not hasattr(self, 'depth_preprocessor') or self.depth_preprocessor is None:
 421 | +           self.depth_preprocessor = DepthPreprocessor.from_pretrained("Intel/zoedepth-nyu-kitti")
 422 | +           self.depth_preprocessor.to(self.device)
 423 | +
 424 | +       try:
 425 | +           video = video_tensor.unsqueeze(0).to(self.device)
 426 | +
 427 | +           # Process all frames to get depth maps
 428 | +           video_depths = []
 429 | +           for i in tqdm(range(video_tensor.shape[0]), desc="estimating depth"):
 430 | +               frame = (video_tensor[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
 431 | +               depth = self.depth_preprocessor(Image.fromarray(frame))[0]
 432 | +               depth_tensor = transforms.ToTensor()(depth)  # [1, H, W]
 433 | +               video_depths.append(depth_tensor)
 434 | +
 435 | +           video_depth = torch.stack(video_depths, dim=0).to(self.device)  # [T, 1, H, W]
 436 | +
 437 | +           # Get tracking points and visibility
 438 | +           print("tracking...")
 439 | +           pred_tracks, pred_visibility = cotracker(video, grid_size=density)  # B T N 2, B T N 1
 440 | +
 441 | +           # Extract dimensions
 442 | +           B, T, N, _ = pred_tracks.shape
 443 | +           H, W = video_depth.shape[2], video_depth.shape[3]
 444 | +
 445 | +           # Create output tensor with depth
 446 | +           pred_tracks_with_depth = torch.zeros((B, T, N, 3), device=self.device)
 447 | +           pred_tracks_with_depth[:, :, :, :2] = pred_tracks  # Copy x,y coordinates
 448 | +
 449 | +           # Vectorized approach to get depths for all points
 450 | +           # Reshape pred_tracks to process all batches and frames at once
 451 | +           flat_tracks = pred_tracks.reshape(B*T, N, 2)
 452 | +
 453 | +           # Clamp coordinates to valid image bounds
 454 | +           x_coords = flat_tracks[:, :, 0].clamp(0, W-1).long()  # [B*T, N]
 455 | +           y_coords = flat_tracks[:, :, 1].clamp(0, H-1).long()  # [B*T, N]
 456 | +
 457 | +           # Get depths for all points at once
 458 | +           # For each point in the flattened batch, get its depth from the corresponding frame
 459 | +           depths = torch.zeros((B*T, N), device=self.device)
 460 | +           for bt in range(B*T):
 461 | +               t = bt % T  # Time index
 462 | +               depths[bt] = video_depth[t, 0, y_coords[bt], x_coords[bt]]
 463 | +
 464 | +           # Reshape depths back to [B, T, N] and assign to output tensor
 465 | +           pred_tracks_with_depth[:, :, :, 2] = depths.reshape(B, T, N)
 466 | +
 467 | +           return pred_tracks_with_depth.squeeze(0), pred_visibility.squeeze(0)
 468 | +
 469 | +       finally:
 470 | +           del cotracker
 471 | +           torch.cuda.empty_cache()
 472 | +
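For reference, the depth-sampling step above (clamp each 2D track coordinate to the image bounds, then read the per-frame ZoeDepth map at that pixel) can also be written without the explicit Python loop over B*T by indexing all frames at once. The sketch below is illustrative only and not part of this commit; every name in it is local to the example.

import torch

def attach_depth(tracks_2d: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
    # tracks_2d: [T, N, 2] pixel coordinates (x, y); depth: [T, 1, H, W] depth maps
    T, N, _ = tracks_2d.shape
    H, W = depth.shape[-2:]
    x = tracks_2d[..., 0].clamp(0, W - 1).long()               # [T, N]
    y = tracks_2d[..., 1].clamp(0, H - 1).long()               # [T, N]
    t = torch.arange(T, device=tracks_2d.device)[:, None].expand(T, N)
    z = depth[t, 0, y, x]                                      # [T, N], advanced indexing
    return torch.cat([tracks_2d, z.unsqueeze(-1)], dim=-1)     # [T, N, 3] = (x, y, z)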
 473 | +   def visualize_tracking_cotracker(self, points, vis_mask=None, save_tracking=True, point_wise=4, video_size=(480, 720)):
 474 | +       """Visualize tracking results from CoTracker
 475 | +
 476 | +       Args:
 477 | +           points (torch.Tensor): Points array of shape [T, N, 3]
 478 | +           vis_mask (torch.Tensor): Visibility mask of shape [T, N, 1]
 479 | +           save_tracking (bool): Whether to save tracking video
 480 | +           point_wise (int): Size of points in visualization
 481 | +           video_size (tuple): Render size (height, width)
 482 | +
 483 | +       Returns:
 484 | +           tuple: (tracking_path, tracking_video)
 485 | +       """
 486 | +       # Move tensors to CPU and convert to numpy
 487 | +       if isinstance(points, torch.Tensor):
 488 | +           points = points.detach().cpu().numpy()
 489 | +
 490 | +       if vis_mask is not None and isinstance(vis_mask, torch.Tensor):
 491 | +           vis_mask = vis_mask.detach().cpu().numpy()
 492 | +           # Reshape if needed
 493 | +           if vis_mask.ndim == 3 and vis_mask.shape[2] == 1:
 494 | +               vis_mask = vis_mask.squeeze(-1)
 495 | +
 496 | +       T, N, _ = points.shape
 497 | +       H, W = video_size
 498 | +
 499 | +       if vis_mask is None:
 500 | +           vis_mask = np.ones((T, N), dtype=bool)
 501 | +
 502 | +       colors = np.zeros((N, 3), dtype=np.uint8)
 503 | +
 504 | +       first_frame_pts = points[0]
 505 | +
 506 | +       u_min, u_max = 0, W
 507 | +       u_normalized = np.clip((first_frame_pts[:, 0] - u_min) / (u_max - u_min), 0, 1)
 508 | +       colors[:, 0] = (u_normalized * 255).astype(np.uint8)
 509 | +
 510 | +       v_min, v_max = 0, H
 511 | +       v_normalized = np.clip((first_frame_pts[:, 1] - v_min) / (v_max - v_min), 0, 1)
 512 | +       colors[:, 1] = (v_normalized * 255).astype(np.uint8)
 513 | +
 514 | +       z_values = first_frame_pts[:, 2]
 515 | +       if np.all(z_values == 0):
 516 | +           colors[:, 2] = np.random.randint(0, 256, N, dtype=np.uint8)
 517 | +       else:
 518 | +           inv_z = 1 / (z_values + 1e-10)
 519 | +           p2 = np.percentile(inv_z, 2)
 520 | +           p98 = np.percentile(inv_z, 98)
 521 | +           normalized_z = np.clip((inv_z - p2) / (p98 - p2 + 1e-10), 0, 1)
 522 | +           colors[:, 2] = (normalized_z * 255).astype(np.uint8)
 523 | +
 524 | +       frames = []
 525 | +
 526 | +       for i in tqdm(range(T), desc="rendering frames"):
 527 | +           pts_i = points[i]
 528 | +
 529 | +           visibility = vis_mask[i]
 530 | +
 531 | +           pixels, depths = pts_i[visibility, :2], pts_i[visibility, 2]
 532 | +           pixels = pixels.astype(int)
 533 | +
 534 | +           in_frame = self.valid_mask(pixels, W, H)
 535 | +           pixels = pixels[in_frame]
 536 | +           depths = depths[in_frame]
 537 | +           frame_rgb = colors[visibility][in_frame]
 538 | +
 539 | +           img = Image.fromarray(np.zeros((H, W, 3), dtype=np.uint8), mode="RGB")
 540 | +
 541 | +           sorted_pixels, _, sort_index = self.sort_points_by_depth(pixels, depths)
 542 | +           sorted_rgb = frame_rgb[sort_index]
 543 | +
 544 | +           for j in range(sorted_pixels.shape[0]):
 545 | +               self.draw_rectangle(
 546 | +                   img,
 547 | +                   coord=(sorted_pixels[j, 0], sorted_pixels[j, 1]),
 548 | +                   side_length=point_wise,
 549 | +                   color=sorted_rgb[j],
 550 | +               )
 551 | +
 552 | +           frames.append(np.array(img))
 553 | +
 554 | +       # Convert frames to video tensor in range [0,1]
 555 | +       tracking_video = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
 556 | +
 557 | +       tracking_path = None
 558 | +       if save_tracking:
 559 | +           try:
 560 | +               tracking_path = os.path.join(self.output_dir, "tracking_video_cotracker.mp4")
 561 | +               # Convert back to uint8 for saving
 562 | +               uint8_frames = [frame.astype(np.uint8) for frame in frames]
 563 | +               clip = ImageSequenceClip(uint8_frames, fps=self.fps)
 564 | +               clip.write_videofile(tracking_path, codec="libx264", fps=self.fps, logger=None)
 565 | +               print(f"Video saved to {tracking_path}")
 566 | +           except Exception as e:
 567 | +               print(f"Warning: Failed to save tracking video: {e}")
 568 | +               tracking_path = None
 569 | +
 570 | +       return tracking_path, tracking_video
 571 | +
 572 |
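A minimal usage sketch of the two CoTracker helpers added above, assuming `das` is an already-initialized DiffusionAsShaderPipeline (so `fps` and `output_dir` are set) and `video_tensor` is a [T, C, H, W] float video in [0, 1]; this call chain is illustrative, not code from the commit.

# Hypothetical call chain; the helper names come from the diff above
tracks, visibility = das.generate_tracking_cotracker(video_tensor, density=70)
tracking_path, tracking_video = das.visualize_tracking_cotracker(
    tracks, vis_mask=visibility, save_tracking=True, point_wise=4, video_size=(480, 720)
)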
 573 |     def apply_tracking(self, video_tensor, fps=8, tracking_tensor=None, img_cond_tensor=None, prompt=None, checkpoint_path=None):
 574 |         """Generate final video with motion transfer
 575 |

 595 |             tracking_tensor=tracking_tensor,
 596 |             image_tensor=img_cond_tensor,
 597 |             output_path=final_output,
 598 | +           num_inference_steps=25,
 599 |             guidance_scale=6.0,
 600 |             dtype=torch.bfloat16,
 601 |             fps=self.fps

 610 |         """
 611 |         self.object_motion = motion_type
 612 |
 613 | class FirstFrameRepainter:
 614 |     def __init__(self, gpu_id=0, output_dir='outputs'):
 615 |         """Initialize FirstFrameRepainter

 622 |         self.output_dir = output_dir
 623 |         self.max_depth = 65.0
 624 |         os.makedirs(output_dir, exist_ok=True)
 625 | +
 626 |     def repaint(self, image_tensor, prompt, depth_path=None, method="dav"):
 627 |         """Repaint first frame using Flux
 628 |
 714 |         fx = fy = (W / 2) / math.tan(fov_rad / 2)
 715 |
 716 |         self.intr[0, 0] = fx
 717 | +       self.intr[1, 1] = fy
 718 | +
 719 | +       self.extr = torch.eye(4, device=device)
 720 |
 721 | +   def s2w_vggt(self, points, extrinsics, intrinsics):
 722 |         """
 723 | +       Transform points from pixel coordinates to world coordinates
 724 | +
 725 |         Args:
 726 | +           points: Point cloud data of shape [T, N, 3] in uvz format
 727 | +           extrinsics: Camera extrinsic matrices [B, T, 3, 4] or [T, 3, 4]
 728 | +           intrinsics: Camera intrinsic matrices [B, T, 3, 3] or [T, 3, 3]
 729 | +
 730 | +       Returns:
 731 | +           world_points: Point cloud in world coordinates [T, N, 3]
 732 |         """
 733 | +       if isinstance(points, torch.Tensor):
 734 | +           points = points.detach().cpu().numpy()
 735 | +
 736 | +       if isinstance(extrinsics, torch.Tensor):
 737 | +           extrinsics = extrinsics.detach().cpu().numpy()
 738 | +       # Handle batch dimension
 739 | +       if extrinsics.ndim == 4:  # [B, T, 3, 4]
 740 | +           extrinsics = extrinsics[0]  # Take first batch
 741 | +
 742 | +       if isinstance(intrinsics, torch.Tensor):
 743 | +           intrinsics = intrinsics.detach().cpu().numpy()
 744 | +       # Handle batch dimension
 745 | +       if intrinsics.ndim == 4:  # [B, T, 3, 3]
 746 | +           intrinsics = intrinsics[0]  # Take first batch
 747 | +
 748 | +       T, N, _ = points.shape
 749 | +       world_points = np.zeros_like(points)
 750 |
 751 | +       # Extract uvz coordinates
 752 | +       uvz = points
 753 | +       valid_mask = uvz[..., 2] > 0
 754 |
 755 | +       # Create homogeneous coordinates [u, v, 1]
 756 | +       uv_homogeneous = np.concatenate([uvz[..., :2], np.ones((T, N, 1))], axis=-1)
 757 | +
 758 | +       # Transform from pixel to camera coordinates
 759 | +       for i in range(T):
 760 | +           K = intrinsics[i]
 761 | +           K_inv = np.linalg.inv(K)
 762 | +
 763 | +           R = extrinsics[i, :, :3]
 764 | +           t = extrinsics[i, :, 3]
 765 | +
 766 | +           R_inv = np.linalg.inv(R)
 767 | +
 768 | +           valid_indices = np.where(valid_mask[i])[0]
 769 | +
 770 | +           if len(valid_indices) > 0:
 771 | +               valid_uv = uv_homogeneous[i, valid_indices]
 772 | +               valid_z = uvz[i, valid_indices, 2]
 773 | +
 774 | +               valid_xyz_camera = valid_uv @ K_inv.T
 775 | +               valid_xyz_camera = valid_xyz_camera * valid_z[:, np.newaxis]
 776 | +
 777 | +               # Transform from camera to world coordinates: X_world = R^-1 * (X_camera - t)
 778 | +               valid_world_points = (valid_xyz_camera - t) @ R_inv.T
 779 | +
 780 | +               world_points[i, valid_indices] = valid_world_points
 781 | +
 782 | +       return world_points
 783 |
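As a quick sanity check on the math inside `s2w_vggt` (unproject with the inverse intrinsics, scale by depth, then undo the extrinsic rotation and translation), here is a single-point NumPy example with made-up camera parameters; it is illustrative only and not part of the commit.

import numpy as np

# X_cam = z * K^-1 @ [u, v, 1]^T ;  X_world = R^-1 @ (X_cam - t)
K = np.array([[500.0,   0.0, 320.0],
              [  0.0, 500.0, 240.0],
              [  0.0,   0.0,   1.0]])
R, t = np.eye(3), np.zeros(3)                       # identity pose for the example
u, v, z = 320.0, 240.0, 2.0                         # principal point, depth 2
x_cam = z * (np.linalg.inv(K) @ np.array([u, v, 1.0]))
x_world = np.linalg.inv(R) @ (x_cam - t)
print(x_world)                                      # -> [0. 0. 2.]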
 784 | +   def w2s_vggt(self, world_points, extrinsics, intrinsics, poses=None):
 785 | +       """
 786 | +       Project points from world coordinates to camera view
 787 | +
 788 | +       Args:
 789 | +           world_points: Point cloud in world coordinates [T, N, 3]
 790 | +           extrinsics: Original camera extrinsic matrices [B, T, 3, 4] or [T, 3, 4]
 791 | +           intrinsics: Camera intrinsic matrices [B, T, 3, 3] or [T, 3, 3]
 792 | +           poses: Camera pose matrices [T, 4, 4], if None use first frame extrinsics
 793 | +
 794 | +       Returns:
 795 | +           camera_points: Point cloud in camera coordinates [T, N, 3] in uvz format
 796 | +       """
 797 | +       if isinstance(world_points, torch.Tensor):
 798 | +           world_points = world_points.detach().cpu().numpy()
 799 | +
 800 | +       if isinstance(extrinsics, torch.Tensor):
 801 | +           extrinsics = extrinsics.detach().cpu().numpy()
 802 | +       if extrinsics.ndim == 4:
 803 | +           extrinsics = extrinsics[0]
 804 | +
 805 | +       if isinstance(intrinsics, torch.Tensor):
 806 | +           intrinsics = intrinsics.detach().cpu().numpy()
 807 | +       if intrinsics.ndim == 4:
 808 | +           intrinsics = intrinsics[0]
 809 | +
 810 | +       T, N, _ = world_points.shape
 811 | +
 812 | +       # If no poses provided, use first frame extrinsics
 813 | +       if poses is None:
 814 | +           pose1 = np.eye(4)
 815 | +           pose1[:3, :3] = extrinsics[0, :, :3]
 816 | +           pose1[:3, 3] = extrinsics[0, :, 3]
 817 | +
 818 | +           camera_poses = np.tile(pose1[np.newaxis, :, :], (T, 1, 1))
 819 | +       else:
 820 | +           if isinstance(poses, torch.Tensor):
 821 | +               camera_poses = poses.cpu().numpy()
 822 | +           else:
 823 | +               camera_poses = poses
 824 | +
 825 | +       # Scale translation by 1/5
 826 | +       scaled_poses = camera_poses.copy()
 827 | +       scaled_poses[:, :3, 3] = camera_poses[:, :3, 3] / 5.0
 828 | +       camera_poses = scaled_poses
 829 | +
 830 | +       # Add homogeneous coordinates
 831 | +       ones = np.ones([T, N, 1])
 832 | +       world_points_hom = np.concatenate([world_points, ones], axis=-1)
 833 | +
 834 | +       # Transform points using batch matrix multiplication
 835 | +       pts_cam_hom = np.matmul(world_points_hom, np.transpose(camera_poses, (0, 2, 1)))
 836 | +       pts_cam = pts_cam_hom[..., :3]
 837 | +
 838 | +       # Extract depth information
 839 | +       depths = pts_cam[..., 2:3]
 840 | +       valid_mask = depths[..., 0] > 0
 841 | +
 842 | +       # Normalize coordinates
 843 | +       normalized_pts = pts_cam / (depths + 1e-10)
 844 | +
 845 | +       # Apply intrinsic matrix for projection
 846 | +       pts_pixel = np.matmul(normalized_pts, np.transpose(intrinsics, (0, 2, 1)))
 847 | +
 848 | +       # Extract pixel coordinates
 849 | +       u = pts_pixel[..., 0:1]
 850 | +       v = pts_pixel[..., 1:2]
 851 | +
 852 | +       # Set invalid points to zero
 853 | +       u[~valid_mask] = 0
 854 | +       v[~valid_mask] = 0
 855 | +       depths[~valid_mask] = 0
 856 | +
 857 | +       # Return points in uvz format
 858 | +       result = np.concatenate([u, v, depths], axis=-1)
 859 | +
 860 | +       return torch.from_numpy(result)
 861 |
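The forward direction used by `w2s_vggt` (apply the 3x4 pose, divide by depth, then apply the intrinsics) can be checked point-wise the same way. Again an illustrative sketch with synthetic parameters, not repository code.

import numpy as np

K = np.array([[500.0,   0.0, 320.0],
              [  0.0, 500.0, 240.0],
              [  0.0,   0.0,   1.0]])
P = np.hstack([np.eye(3), np.zeros((3, 1))])        # 3x4 [R | t], identity pose
x_world_h = np.array([0.0, 0.0, 2.0, 1.0])          # homogeneous world point
x_cam = P @ x_world_h                               # camera coordinates, depth z = 2
u, v, _ = K @ (x_cam / x_cam[2])                    # perspective divide, then project
print(u, v, x_cam[2])                               # -> 320.0 240.0 2.0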
 862 | +   def w2s_moge(self, pts, poses):
 863 |         if isinstance(poses, np.ndarray):
 864 |             poses = torch.from_numpy(poses)
 865 |         assert poses.shape[0] == self.frame_num
 866 |         poses = poses.to(torch.float32).to(self.device)
 867 |         T, N, _ = pts.shape  # (T, N, 3)
 868 |         intr = self.intr.unsqueeze(0).repeat(self.frame_num, 1, 1)
 869 |         ones = torch.ones((T, N, 1), device=self.device, dtype=pts.dtype)
 870 |         points_world_h = torch.cat([pts, ones], dim=-1)
 871 |         points_camera_h = torch.bmm(poses, points_world_h.permute(0, 2, 1))

 874 |         points_image_h = torch.bmm(points_camera, intr.permute(0, 2, 1))
 875 |
 876 |         uv = points_image_h[:, :, :2] / points_image_h[:, :, 2:3]
 877 |         depth = points_camera[:, :, 2:3]  # (T, N, 1)
 878 |         uvd = torch.cat([uv, depth], dim=-1)  # (T, N, 3)
 879 |
 880 | +       return uvd
 881 |
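`w2s_moge` performs the same projection for all frames at once via `torch.bmm` on batched poses and intrinsics. Below is a compact, self-contained sketch of that batched pattern; every tensor and name is synthetic and local to the example, not taken from the repository.

import torch

T, N = 49, 100
pts_world = torch.rand(T, N, 3) + torch.tensor([0.0, 0.0, 1.0])    # keep depth > 0
poses = torch.eye(4).repeat(T, 1, 1)                    # [T, 4, 4] world -> camera
intr = torch.tensor([[500.0,   0.0, 360.0],
                     [  0.0, 500.0, 240.0],
                     [  0.0,   0.0,   1.0]]).repeat(T, 1, 1)        # [T, 3, 3]
ones = torch.ones(T, N, 1)
pts_h = torch.cat([pts_world, ones], dim=-1)            # [T, N, 4] homogeneous
cam = torch.bmm(pts_h, poses.transpose(1, 2))[..., :3]  # [T, N, 3] camera coords
img_h = torch.bmm(cam, intr.transpose(1, 2))            # [T, N, 3] = (K @ X_cam)^T
uv = img_h[..., :2] / img_h[..., 2:3]                   # perspective divide
uvd = torch.cat([uv, cam[..., 2:3]], dim=-1)            # [T, N, 3]: u, v, depth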
 882 |     def set_intr(self, K):
 883 |         if isinstance(K, np.ndarray):
 884 |             K = torch.from_numpy(K)
 885 |         self.intr = K.to(self.device)
 886 |
 887 | +   def set_extr(self, extr):
 888 | +       if isinstance(extr, np.ndarray):
 889 | +           extr = torch.from_numpy(extr)
 890 | +       self.extr = extr.to(self.device)
 891 | +
 892 |     def rot_poses(self, angle, axis='y'):
 893 |         """Generate a single rotation matrix
 894 |

1007 |         camera_poses = np.concatenate(cam_poses, axis=0)
1008 |         return torch.from_numpy(camera_poses).to(self.device)
1009 |
1010 |     def get_default_motion(self):
1011 |         """Parse motion parameters and generate corresponding motion matrices
1012 |

1024 |             - if not specified, defaults to 0-49
1025 |             - frames after end_frame will maintain the final transformation
1026 |             - for combined transformations, they are applied in sequence
1027 | +           - moving left, up and zoom out is positive in video
1028 |
1029 |         Returns:
1030 |             torch.Tensor: Motion matrices [num_frames, 4, 4]