Wan2.1-FLF2V / app.py
GeradeHouse's picture
Update app.py
2b5109d verified
raw
history blame
5.5 kB
#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
Loads the huge model once, uses balanced device placement,
streams high-level progress, and auto-offers the .mp4 for download.
"""
import os
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF
# --------------------------------------------------------------------
# CONFIG
# Hub repository of the diffusers-format FLF2V checkpoint.
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
# Working precision for the VAE and the rest of the pipeline (float16).
DTYPE = torch.float16
# Per-frame pixel budget: never larger than one 1280x720 frame.
MAX_AREA = 720 * 1280
# Default clip length; roughly 5 s when played back at 16 fps.
DEFAULT_FRAMES = 81
# --------------------------------------------------------------------
def load_pipeline():
    """Build the Wan2.1 FLF2V image-to-video pipeline.

    The CLIP image encoder is loaded in float32. The VAE and the remaining
    pipeline components use ``DTYPE``. Weights are distributed with
    diffusers' ``"balanced"`` device map.

    Returns:
        WanImageToVideoPipeline: the ready-to-call pipeline.
    """
    # 1) image encoder in full precision
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE in reduced precision to save memory.
    # NOTE(review): diffusers' Wan examples load the VAE in float32; confirm
    # float16 decoding is artifact-free before relying on it.
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) CLIPImageProcessor so we get the right class. It is stored under
    #    the pipeline component name "image_processor"; an empty subfolder
    #    would make from_pretrained look for preprocessor_config.json at the
    #    repo root instead. A processor holds no weights, so no dtype is given.
    image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor"
    )
    # 4) load everything else, spreading weights with a balanced device map
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        image_processor=image_processor,
        torch_dtype=DTYPE,
        device_map="balanced",  # splits weights across available devices
    )
    return pipe


# Load once at import time; every request reuses this pipeline.
PIPE = load_pipeline()
# --------------------------------------------------------------------
# UTILS
def aspect_resize(img: Image.Image, max_area=MAX_AREA, mod=None):
    """Resize *img* to fit a pixel budget while keeping its aspect ratio.

    Both output sides are rounded down to a multiple of ``mod`` so the frame
    divides evenly into the model's VAE/patch grid.

    Args:
        img: source frame.
        max_area: approximate upper bound on ``h * w`` in pixels.
        mod: side-length granularity. Defaults to the pipeline's spatial VAE
            scale factor times the transformer's spatial patch size.

    Returns:
        tuple: ``(resized_img, h, w)``.
    """
    if mod is None:
        mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    ar = img.height / img.width
    # Clamp each side to at least one `mod` unit: for a very tall or very wide
    # input the floor-to-multiple would otherwise yield a 0-px side, i.e. an
    # empty frame that can be neither resized to nor generated.
    h = max(mod, round(np.sqrt(max_area * ar)) // mod * mod)
    w = max(mod, round(np.sqrt(max_area / ar)) // mod * mod)
    return img.resize((w, h), Image.LANCZOS), h, w
def center_crop_resize(img: Image.Image, h, w):
    """Scale *img* to cover an h x w box, then centre-crop it to exactly (h, w).

    The scale is the smallest factor at which both target sides are covered.
    Any overflow along one axis is trimmed symmetrically.
    """
    scale = max(h / img.height, w / img.width)
    covering_size = (round(img.width * scale), round(img.height * scale))
    covered = img.resize(covering_size, Image.LANCZOS)
    return TF.center_crop(covered, [h, w])
# --------------------------------------------------------------------
# GENERATE (with simple progress streaming)
def _resolve_seed(seed):
    """Turn the UI seed into a concrete int; ``-1`` or an empty field means random.

    Random seeds are drawn below 2**31 with ``torch.randint``. This has two
    effects:

    - The value echoed into the "Seed used" number box is exactly
      representable in the browser, unlike the 64-bit value
      ``torch.seed()`` returns.
    - torch's global RNG is not reseeded, which ``torch.seed()`` would do.
    """
    if seed is None or int(seed) == -1:
        return int(torch.randint(0, 2**31 - 1, (1,)).item())
    return int(seed)


def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),  # gradio's built-in progress callback
):
    """Generate a video that moves from *first_frame* to *last_frame*.

    Progress is reported in three phases:

    - resizing: 0-20 %
    - denoising: 20-80 %, updated after every sampling step
    - MP4 export: 80-100 %

    Returns:
        tuple: ``(video_path, seed)``. ``video_path`` is the exported .mp4,
        which gr.File offers for download. ``seed`` is the one actually
        used, so the run can be reproduced.

    Raises:
        gr.Error: if either frame has not been uploaded.
    """
    if first_frame is None or last_frame is None:
        raise gr.Error("Please upload both a first and a last frame.")
    # Gradio widgets may hand back floats; the pipeline and exporter want ints.
    steps, num_frames, fps = int(steps), int(num_frames), int(fps)

    seed = _resolve_seed(seed)
    # A CPU generator works wherever the balanced device map puts the latents
    # (diffusers samples on CPU and moves the noise). A CUDA generator fails
    # if the latents end up on CPU.
    gen = torch.Generator(device="cpu").manual_seed(seed)

    # 0→20%: bring both frames to one shared, patch-aligned resolution
    progress(0.0, desc="Resizing first frame…")
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        progress(0.1, desc="Resizing last frame…")
        last_frame = center_crop_resize(last_frame, h, w)

    # 20→80%: denoising, advanced once per sampling step
    progress(0.2, desc="Starting video inference…")
    total_steps = max(steps, 1)

    def _on_step_end(pipe, step_index, timestep, callback_kwargs):
        done = min(step_index + 1, total_steps)
        progress(
            0.2 + 0.6 * done / total_steps,
            desc=f"Denoising step {done}/{total_steps}…",
        )
        # diffusers expects the callback to return the (possibly modified)
        # tensor dict; hand it back untouched.
        return callback_kwargs

    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=_on_step_end,
    )

    # 80→100%: export
    progress(0.8, desc="Assembling video file…")
    video_path = export_to_video(result.frames[0], fps=fps)
    progress(1.0, desc="Done!")
    # return path so gr.File offers immediate download, plus seed used
    return video_path, seed
# --------------------------------------------------------------------
# UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        # The Wan pipeline rounds num_frames to the form 4k + 1 (its temporal
        # VAE factor). Stepping by 4 from 17 offers only values it accepts
        # unchanged, and 81 and 129 are still reachable.
        num_frames = gr.Slider(17, 129, value=DEFAULT_FRAMES, step=4, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
    run_btn = gr.Button("Generate")
    download = gr.File(label="Download video", interactive=False)
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative,
                steps, guidance, num_frames, seed, fps],
        outputs=[download, used_seed],
    )

if __name__ == "__main__":
    # queue tasks so users see the little task-queue progress bar
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)