#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
• Single global load (no repeated downloads)
• Balanced device_map to avoid OOM on 24 GB A10
• Fast CLIP processor via use_fast=True
• High-level streaming progress
• Auto-download via gr.File
"""
import os
# persist Hugging Face cache so safetensors only download once
os.environ["HF_HOME"] = "/mnt/data/huggingface"
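# NOTE: HF_HOME needs to be set before transformers/diffusers are imported, since the
# default cache path is resolved at import time. /mnt/data assumes persistent storage
# is mounted (e.g. a Space with a persistent disk).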
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF
# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
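# 81 frames at the default 16 fps is roughly a 5-second clip; MAX_AREA is the
# 1280x720 pixel budget the input frames are resized to fit.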
# -----------------------------------------------------------------------------
# LOAD PIPELINE ONCE
# -----------------------------------------------------------------------------
def load_pipeline():
    # 1) CLIP image encoder (fp32)
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE (fp16)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) Balanced device placement + fast processor
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=image_encoder,
        vae=vae,
        torch_dtype=DTYPE,
        device_map="balanced",  # spread weights CPU↔GPU
        use_fast=True,          # internal fast CLIPImageProcessor
    )
    return pipe
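# If balanced placement still OOMs, a lower-VRAM alternative (untested sketch, trades
# speed for memory) is to drop device_map inside load_pipeline() and offload instead:
#   pipe = WanImageToVideoPipeline.from_pretrained(
#       MODEL_ID, image_encoder=image_encoder, vae=vae, torch_dtype=DTYPE
#   )
#   pipe.enable_model_cpu_offload()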
PIPE = load_pipeline()
# -----------------------------------------------------------------------------
# HELPERS
# -----------------------------------------------------------------------------
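# aspect_resize keeps the input aspect ratio while fitting inside MAX_AREA, snapping
# both sides down to a multiple of the VAE stride x transformer patch size (assumed
# to be 16 for this model). Example: a 1920x1080 input gives ar = 0.5625,
# h = sqrt(921600 * 0.5625) = 720 and w = sqrt(921600 / 0.5625) = 1280, i.e. 1280x720.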
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = int(np.sqrt(max_area * ar)) // mod * mod
    w = int(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
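# center_crop_resize scales the last frame just enough to cover the target h x w
# (no letterboxing), then center-crops the overflow so both conditioning frames
# end up at exactly the same resolution.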
def center_crop_resize(img: Image.Image, h, w):
    ratio = max(w / img.width, h / img.height)
    img2 = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img2, [h, w])
# -----------------------------------------------------------------------------
# GENERATION + STREAMING
# -----------------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    # choose seed
    if seed == -1:
        seed = torch.seed()
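    # The chosen seed (random or user-supplied) drives a dedicated Generator and is
    # returned at the end so the run can be reproduced via the "Seed used" field.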
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # 0–15%: resize
    progress(0.0, desc="Resizing first frame…")
    f_resized, h, w = aspect_resize(first_frame)
    if last_frame.size != f_resized.size:
        progress(0.15, desc="Resizing last frame…")
        l_resized = center_crop_resize(last_frame, h, w)
    else:
        # last frame already matches the resized first frame; keep it as-is
        l_resized = last_frame
    # 15–25%: spin up pipeline
    progress(0.25, desc="Launching inference…")
    out = PIPE(
        image=f_resized,
        last_image=l_resized,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
    # 90–100%: export
    progress(0.90, desc="Building video file…")
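    # export_to_video writes to a temporary .mp4 when no output path is given and
    # returns that path, which gr.File then serves for download.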
    video_path = export_to_video(out.frames[0], fps=fps)
    progress(1.0, desc="Done!")
    return video_path, seed
# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (opt)", placeholder="blurry, lowres")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=rand)")
    run_btn = gr.Button("Generate")
    download = gr.File(label="Download .mp4", interactive=False)
    seed_used = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative,
                steps, guidance, num_frames, seed_input, fps],
        outputs=[download, seed_used],
    )
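# queue() serializes requests, so at Gradio's default concurrency only one generation
# holds the GPU at a time; 7860 is the standard Gradio/Spaces port.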
demo.queue().launch(server_name="0.0.0.0", server_port=7860)