Update app.py
app.py
CHANGED
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 FLF2V – full streaming progress
+No globals: pipeline, resize utils all use the local `pipe`.
 Author: <your-handle>
 """
 
@@ -22,17 +23,20 @@ DEFAULT_FRAMES = 81
 # ----------------------------------------------------------------------
 
 def load_pipeline(progress):
-    """Load
-
-    progress(0.0, desc="Initializing model load…")
+    """Load & shard the pipeline across CPU/GPU with streaming progress."""
+    progress(0.00, desc="Init: loading image encoder…")
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    progress(0.
+    progress(0.10, desc="Loaded image encoder")
+
+    progress(0.10, desc="Loading VAE…")
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    progress(0.
+    progress(0.20, desc="Loaded VAE")
+
+    progress(0.20, desc="Assembling pipeline…")
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
@@ -41,81 +45,82 @@ def load_pipeline(progress):
         low_cpu_mem_usage=True,
         device_map="balanced",
     )
-    progress(0.
+    progress(0.30, desc="Pipeline assembled")
+
+    progress(0.30, desc="Loading fast image processor…")
     pipe.image_processor = CLIPImageProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
-    progress(0.
-    return pipe
+    progress(0.40, desc="Processor ready")
 
-
+    return pipe
 
-
-
-def aspect_resize(img: Image.Image, max_area=MAX_AREA):
+def aspect_resize(img: Image.Image, pipe, max_area=MAX_AREA):
+    """Resize while respecting model patch multiples, using `pipe` for scale."""
     ar = img.height / img.width
-    mod =
+    mod = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
     w = round(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w
 
-def center_crop_resize(img: Image.Image, h, w):
+def center_crop_resize(img: Image.Image, pipe, h, w):
+    """Center-crop & resize to H×W, using same Lanczos filter."""
     ratio = max(w / img.width, h / img.height)
-    img = img.resize(
+    img = img.resize(
+        (round(img.width * ratio), round(img.height * ratio)),
+        Image.LANCZOS
+    )
     return TF.center_crop(img, [h, w])
 
-# ----------------------------------------------------------------------
-# GENERATE --------------------------------------------------------------
 def generate(first_frame, last_frame, prompt, negative_prompt,
              steps, guidance, num_frames, seed, fps,
-             progress=gr.Progress()): #
+             progress=gr.Progress()): # Gradio progress hook
 
-    # 1) Load
+    # 1) Load & shard pipeline
     pipe = load_pipeline(progress)
 
-    # 2) Preprocess
-    progress(0.
-    first_frame, h, w = aspect_resize(first_frame)
+    # 2) Preprocess
+    progress(0.45, desc="Preprocessing first frame…")
+    first_frame, h, w = aspect_resize(first_frame, pipe)
     if last_frame.size != first_frame.size:
-
-
+        progress(0.50, desc="Preprocessing last frame…")
+        last_frame = center_crop_resize(last_frame, pipe, h, w)
+    progress(0.55, desc="Frames ready")
 
-    # 3)
+    # 3) Run inference with per-step callbacks
    if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=pipe.device).manual_seed(seed)
 
-    def
-
-        frac =
-        progress(frac, desc=f"Inference: step {step+1}/{steps}")
+    def _cb(step, timestep, latents):
+        frac = 0.55 + 0.35 * ((step + 1) / steps)
+        progress(frac, desc=f"Inference step {step+1}/{steps}")
 
-    progress(0.
+    progress(0.55, desc="Starting inference…")
     output = pipe(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
         negative_prompt=negative_prompt or None,
-        height=h,
+        height=h,
+        width=w,
         num_frames=num_frames,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
-        callback_on_step_end=
-        callback_steps=1,
+        callback_on_step_end=_cb,
+        callback_steps=1,
     )
     frames = output.frames[0]
 
-    # 4) Export
-    progress(0.92, desc="
+    # 4) Export video
+    progress(0.92, desc="Exporting video…")
     video_path = export_to_video(frames, fps=fps)
 
-    # 5)
-    progress(1.0, desc="
+    # 5) Done
+    progress(1.0, desc="Complete!")
     return video_path
 
-# ----------------------------------------------------------------------
-# UI --------------------------------------------------------------------
 with gr.Blocks() as demo:
     gr.Markdown("## Wan2.1 FLF2V – Full Streaming Progress")
 
@@ -123,8 +128,8 @@ with gr.Blocks() as demo:
     first_img = gr.Image(label="First frame", type="pil")
     last_img = gr.Image(label="Last frame", type="pil")
 
-    prompt = gr.Textbox(label="Prompt")
-    negative = gr.Textbox(label="Negative prompt (optional)")
+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
 
     with gr.Accordion("Advanced parameters", open=False):
         steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
@@ -135,13 +140,12 @@ with gr.Blocks() as demo:
 
     video = gr.Video(label="Result (.mp4)")
 
-
-
-    run_btn.click(
+    btn = gr.Button("Generate")
+    btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
         outputs=[video],
     )
 
-demo.queue() # enable
+demo.queue() # enable streaming updates
 demo.launch()
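
For intuition, the sizing rule that the updated `aspect_resize` applies can be checked in isolation. Below is a minimal sketch, assuming an illustrative `mod` of 16 and a 1280×720 `max_area`; in the app these values come from the pipeline config and `MAX_AREA`, so treat both as assumptions.

    import numpy as np

    def aligned_size(height, width, max_area=1280 * 720, mod=16):
        # Keep the input aspect ratio, cap the area at max_area, and round
        # both sides down to a multiple of `mod` (a stand-in for
        # vae_scale_factor_spatial * patch_size; 16 is an assumed value).
        ar = height / width
        h = int(round(np.sqrt(max_area * ar)) // mod * mod)
        w = int(round(np.sqrt(max_area / ar)) // mod * mod)
        return h, w

    print(aligned_size(1080, 1920))  # -> (720, 1280) for a 16:9 input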