#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
Author: GeradeHouse
"""
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
import torchvision.transforms.functional as TF
# ---------------------------------------------------------------------
# CONFIG ----------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers" # or switch to 1.3B
DTYPE = torch.float16 # or bfloat16
MAX_AREA = 1280 * 720 # ≤720p
DEFAULT_FRAMES = 81 # ~5s @16 fps
# ----------------------------------------------------------------------
def load_pipeline():
    """Lazy-load & configure the pipeline once per process."""
    # 1) load the CLIP image encoder (full-precision)
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) load the VAE (half-precision)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) load the video pipeline
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=DTYPE,
    )
    # 4) reload the CLIP image processor, requesting the fast variant where available
    pipe.image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor", use_fast=True
    )
    # 5) memory helpers (offload model components to CPU as needed)
    # pipe.enable_model_cpu_offload()
    # (Removed pipe.vae.enable_slicing() - not supported on AutoencoderKLWan)
    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
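# Loaded once at import time so every Gradio request reuses the same pipeline
# (and its GPU weights) instead of re-initialising it per call.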
PIPE = load_pipeline()
# ----------------------------------------------------------------------
# UTILS ----------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize while keeping aspect & respecting patch multiples."""
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = round(np.sqrt(max_area * ar)) // mod * mod
    w = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
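# Worked example (assuming the 14B checkpoint, where vae_scale_factor_spatial is 8
# and the transformer's spatial patch size is 2, i.e. mod == 16): a 1920x1080 input
# gives ar = 0.5625, so sqrt(MAX_AREA * ar) = 720 and sqrt(MAX_AREA / ar) = 1280,
# both already multiples of 16 -> the frame is resized to 1280x720.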
def center_crop_resize(img: Image.Image, h, w):
    """Center-crop & resize to H×W."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img, [h, w])
# ----------------------------------------------------------------------
# GENERATE --------------------------------------------------------------
def generate(first_frame, last_frame, prompt, negative_prompt, steps,
             guidance, num_frames, seed, fps):
    # seed handling
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # preprocess frames
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        last_frame = center_crop_resize(last_frame, h, w)
    # run the pipeline
    output = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
    frames = output.frames[0]  # list of PIL Images
    # export to MP4
    video_path = export_to_video(frames, fps=fps)
    return video_path, seed
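# Minimal sketch of calling `generate` directly, without the UI (the file names
# "first.png" / "last.png" and all parameter values below are illustrative only):
#
#   video_path, used_seed = generate(
#       Image.open("first.png"), Image.open("last.png"),
#       prompt="A blue bird takes off", negative_prompt="",
#       steps=30, guidance=5.5, num_frames=DEFAULT_FRAMES, seed=42, fps=16,
#   )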
# ----------------------------------------------------------------------
# UI --------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run_btn = gr.Button("Generate")
    video = gr.Video(label="Result (.mp4)")
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
        outputs=[video, used_seed]
    )
if __name__ == "__main__":
    demo.launch()