Update app.py
app.py CHANGED
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 """
-Gradio demo for Wan2.1
-– shows streaming status updates
-– auto-downloads the generated video
+Gradio demo for Wan2.1 FLF2V – full streaming progress
 Author: <your-handle>
 """
 
@@ -23,29 +21,34 @@ MAX_AREA = 1280 * 720
 DEFAULT_FRAMES = 81
 # ----------------------------------------------------------------------
 
-def load_pipeline():
-    """Load
+def load_pipeline(progress):
+    """Load model components with progress updates."""
+    # 0% → 5%: start loading
+    progress(0.0, desc="Initializing model load…")
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
+    progress(0.02, desc="Image encoder loaded")
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
+    progress(0.04, desc="VAE loaded")
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
         image_encoder=image_encoder,
         torch_dtype=DTYPE,
-        low_cpu_mem_usage=True,
-        device_map="balanced",
+        low_cpu_mem_usage=True,
+        device_map="balanced",
     )
-
+    progress(0.06, desc="Pipeline assembled")
     pipe.image_processor = CLIPImageProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
+    progress(0.08, desc="Processor ready")
     return pipe
 
-
+# Preload nothing here—model loads in-function to stream progress.
 
 # ----------------------------------------------------------------------
 # UTILS ----------------------------------------------------------------
@@ -62,90 +65,83 @@ def center_crop_resize(img: Image.Image, h, w):
     return TF.center_crop(img, [h, w])
 
 # ----------------------------------------------------------------------
-# GENERATE
+# GENERATE --------------------------------------------------------------
 def generate(first_frame, last_frame, prompt, negative_prompt,
-             steps, guidance, num_frames, seed, fps
-
-
+             steps, guidance, num_frames, seed, fps,
+             progress=gr.Progress()):  # ← inject Gradio progress tracker
+
+    # 1) Load the pipeline with streaming
+    pipe = load_pipeline(progress)
+
+    # 2) Preprocess images
+    progress(0.10, desc="Preprocessing frames…")
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
+    progress(0.12, desc="Frames ready")
 
-    #
-    yield None, None, f"Running inference ({steps} steps)..."
+    # 3) Inference with per-step updates
     if seed == -1:
         seed = torch.seed()
-    gen = torch.Generator(device=
-
+    gen = torch.Generator(device=pipe.device).manual_seed(seed)
+
+    def _callback(step, timestep, latents):
+        # Map step to [0.12…0.90] fraction of bar
+        frac = 0.12 + 0.78 * (step + 1) / steps
+        progress(frac, desc=f"Inference: step {step+1}/{steps}")
+
+    progress(0.12, desc="Starting inference…")
+    output = pipe(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
         negative_prompt=negative_prompt or None,
-        height=h,
-        width=w,
+        height=h, width=w,
         num_frames=num_frames,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
+        callback_on_step_end=_callback,
+        callback_steps=1,  # call our callback every step
     )
     frames = output.frames[0]
 
-    #
-
+    # 4) Export
+    progress(0.92, desc="Building video…")
    video_path = export_to_video(frames, fps=fps)
 
-    #
-
+    # 5) Complete!
+    progress(1.0, desc="Done!")
+    return video_path
 
 # ----------------------------------------------------------------------
 # UI --------------------------------------------------------------------
 with gr.Blocks() as demo:
-
-    gr.HTML("""
-    <script>
-    function downloadVideo() {
-        const container = document.getElementById('output_video');
-        if (!container) return;
-        const vid = container.querySelector('video');
-        if (!vid) return;
-        const src = vid.currentSrc;
-        const a = document.createElement('a');
-        a.href = src;
-        a.download = 'output.mp4';
-        document.body.appendChild(a);
-        a.click();
-        document.body.removeChild(a);
-    }
-    </script>
-    """)
-
-    gr.Markdown("## Wan 2.1 FLF2V – Streaming progress + auto-download")
+    gr.Markdown("## Wan2.1 FLF2V – Full Streaming Progress")
 
     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")
 
-    prompt = gr.Textbox(label="Prompt"
-    negative = gr.Textbox(label="Negative prompt (optional)"
+    prompt = gr.Textbox(label="Prompt")
+    negative = gr.Textbox(label="Negative prompt (optional)")
 
     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, step=1, label="
-        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance
-        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES,
-        fps = gr.Slider(4, 30, value=16,
-        seed = gr.Number(value=-1, precision=0, label="Seed
+        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
+        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
+        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, label="Frames")
+        fps = gr.Slider(4, 30, value=16, label="FPS")
+        seed = gr.Number(value=-1, precision=0, label="Seed")
 
-
-    status = gr.Textbox(label="Status", interactive=False)
-    video = gr.Video(label="Result", elem_id="output_video")
-    used_seed = gr.Number(label="Seed used", interactive=False)
+    video = gr.Video(label="Result (.mp4)")
 
+    # bind generator to button; progress bar overlays on the video output
+    run_btn = gr.Button("Generate")
     run_btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
-        outputs=[video
-        _js="downloadVideo"
+        outputs=[video],
     )
 
-    demo.queue()
+    demo.queue()  # enable progress tracking
     demo.launch()
|