#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
Streams all HF-Hub & Diffusers tqdm bars, caches the model,
and provides a direct download link for the MP4.
"""
import ftfy
import numpy as np
import torch
import gradio as gr
from PIL import Image
from transformers import CLIPVisionModel, CLIPImageProcessor
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
import torchvision.transforms.functional as TF
# -----------------------------------------------------------------------------
# CONFIG
# -----------------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
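# 81 frames at the UI's default 16 fps works out to roughly a 5-second clip.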
DEFAULT_FRAMES = 81
# -----------------------------------------------------------------------------
# GLOBAL CACHED PIPELINE
# -----------------------------------------------------------------------------
PIPE = None  # populated lazily by the first call to generate()
def load_pipeline():
"""Load & cache the pipeline (once)."""
# 1) CLIP vision encoder (fp32)
vision = CLIPVisionModel.from_pretrained(
MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
)
    # 2) CLIP image processor
    processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor"
    )
    # 3) VAE (fp16 to save memory; the reference Diffusers example loads the Wan VAE in fp32 for best quality)
vae = AutoencoderKLWan.from_pretrained(
MODEL_ID, subfolder="vae", torch_dtype=DTYPE
)
# 4) pipeline assembly
pipe = WanImageToVideoPipeline.from_pretrained(
MODEL_ID,
vae=vae,
image_encoder=vision,
image_processor=processor,
torch_dtype=DTYPE,
)
    # 5) CPU offload for large models: each sub-module is moved to the GPU only
    #    while it runs. Do not call .to("cuda") afterwards, as that bypasses the
    #    offload hooks; on a CPU-only machine the pipeline simply stays on CPU.
    if torch.cuda.is_available():
        pipe.enable_model_cpu_offload()
    return pipe
# -----------------------------------------------------------------------------
# IMAGE RESIZE HELPERS
# -----------------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
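    """Resize img to fit within max_area while keeping its aspect ratio; both sides
    are snapped down to the model's spatial granularity (patch size x VAE scale factor)."""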
ar = img.height / img.width
mod = PIPE.transformer.config.patch_size[1] * PIPE.vae_scale_factor_spatial
h = (int(np.sqrt(max_area * ar)) // mod) * mod
w = (int(np.sqrt(max_area / ar)) // mod) * mod
return img.resize((w, h), Image.LANCZOS), h, w
def center_crop_resize(img: Image.Image, h: int, w: int):
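    """Scale img so it covers (w, h), then center-crop it to exactly h x w."""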
ratio = max(w / img.width, h / img.height)
img2 = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
return TF.center_crop(img2, [h, w])
# -----------------------------------------------------------------------------
# GENERATION (stream all tqdm → Gradio)
# -----------------------------------------------------------------------------
def generate(
first_frame: Image.Image,
last_frame: Image.Image,
prompt: str,
negative_prompt: str,
steps: int,
guidance: float,
num_frames: int,
seed: int,
fps: int,
progress=gr.Progress(track_tqdm=True),
):
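    """Generate an FLF2V clip from the two key frames and return (mp4_path, seed_used)."""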
global PIPE
# lazy load
if PIPE is None:
progress(0, desc="Loading model…")
PIPE = load_pipeline()
    # seed: -1 means "pick a random seed"; the seed actually used is returned to the UI
if seed == -1:
seed = torch.seed()
gen = torch.Generator(device=PIPE.device).manual_seed(seed)
# preprocess
progress(0, desc="Preprocessing…")
frame1, h, w = aspect_resize(first_frame)
if last_frame.size != frame1.size:
last_frame = center_crop_resize(last_frame, h, w)
# inference (all tqdm bars appear in progress)
result = PIPE(
image=frame1,
last_image=last_frame,
prompt=ftfy.fix_text(prompt),
negative_prompt=negative_prompt or None,
height=h,
width=w,
num_frames=num_frames,
num_inference_steps=steps,
guidance_scale=guidance,
generator=gen,
)
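    # result.frames is a batch of clips (one per prompt); take the single generated clip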
frames = result.frames[0]
# export
progress(1.0, desc="Exporting video…")
out_path = export_to_video(frames, fps=fps)
return out_path, seed
# -----------------------------------------------------------------------------
# GRADIO UI
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
with gr.Row():
first_img = gr.Image(label="First frame", type="pil")
last_img = gr.Image(label="Last frame", type="pil")
prompt = gr.Textbox(label="Prompt")
negative = gr.Textbox(label="Negative prompt (optional)")
with gr.Accordion("Advanced parameters", open=False):
steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
seed = gr.Number(value=-1, precision=0, label="Seed")
run_btn = gr.Button("Generate")
download = gr.File(label="Download video (.mp4)")
    used_seed = gr.Number(label="Seed used", interactive=False)
run_btn.click(
fn=generate,
inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
outputs=[download, used_seed],
concurrency_limit=1
)
# enable progress streaming
demo.queue().launch()