#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
– shows streaming status updates
– auto-downloads the generated video
Author: <your-handle>
"""
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
import torchvision.transforms.functional as TF
# ---------------------------------------------------------------------
# CONFIG ----------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
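# Note (assumption): Wan 2.1's temporal VAE compresses frames in groups of 4, so
# frame counts of the form 4*k + 1 (e.g. 81) tend to work best; other values may
# be adjusted or handled poorly by the pipeline.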
# ----------------------------------------------------------------------
def load_pipeline():
    """Load & shard the pipeline across CPU/GPU with Accelerate."""
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,  # lazy-load to CPU RAM
        device_map="balanced",   # shard across CPU/GPU
    )
    # switch to the fast Rust processor
    pipe.image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor", use_fast=True
    )
    return pipe
PIPE = load_pipeline()
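# The pipeline is built once at import time and shared by every request;
# device_map="balanced" lets Accelerate place submodules across the available
# GPU(s) and CPU RAM, which helps the 14B checkpoint start on smaller cards.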
# ----------------------------------------------------------------------
# UTILS ----------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize while keeping aspect ratio; snap H/W to the model's spatial grid."""
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = round(np.sqrt(max_area * ar)) // mod * mod
    w = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w


def center_crop_resize(img: Image.Image, h, w):
    """Scale the image to cover (w, h), then center-crop to exactly that size."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
    return TF.center_crop(img, [h, w])
# ----------------------------------------------------------------------
# GENERATE (streaming) --------------------------------------------------
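# `generate` is a Python generator: every `yield` streams an intermediate
# (video, seed, status) update to the UI, which is why `demo.queue()` is
# enabled at the bottom of the file.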
def generate(first_frame, last_frame, prompt, negative_prompt,
             steps, guidance, num_frames, seed, fps):
    # 1) Preprocess
    yield None, None, "Preprocessing images..."
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        last_frame = center_crop_resize(last_frame, h, w)
    # 2) Inference
    yield None, None, f"Running inference ({int(steps)} steps)..."
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(int(seed))
    output = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=int(num_frames),      # sliders can deliver floats
        num_inference_steps=int(steps),
        guidance_scale=float(guidance),
        generator=gen,
    )
    frames = output.frames[0]
    # 3) Export
    yield None, None, "Exporting video..."
    video_path = export_to_video(frames, fps=int(fps))
    # 4) Done
    yield video_path, seed, "Done! Your browser will download the video."
# ----------------------------------------------------------------------
# UI --------------------------------------------------------------------
with gr.Blocks() as demo:
    # Inject JS for auto-download.
    # NOTE: depending on the Gradio version, <script> tags placed inside gr.HTML
    # may not be executed by the browser; if the download never triggers, move
    # this snippet into gr.Blocks(head=...) instead.
    gr.HTML("""
    <script>
    function downloadVideo() {
        const container = document.getElementById('output_video');
        if (!container) return;
        const vid = container.querySelector('video');
        if (!vid) return;
        const src = vid.currentSrc;
        const a = document.createElement('a');
        a.href = src;
        a.download = 'output.mp4';
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
    }
    </script>
    """)
    gr.Markdown("## Wan 2.1 FLF2V – Streaming progress + auto-download")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run_btn = gr.Button("Generate")
    status = gr.Textbox(label="Status", interactive=False)
    video = gr.Video(label="Result", elem_id="output_video")
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
        outputs=[video, used_seed, status],
    ).then(
        # Fire the client-side download only after generation has finished.
        # (`_js` was renamed to `js` in newer Gradio releases; chaining it here
        # keeps it from running before the video exists.)
        fn=None,
        inputs=None,
        outputs=None,
        js="() => downloadVideo()",
    )
demo.queue()
demo.launch()