#!/usr/bin/env python
"""
Gradio demo for Wan2.1-FLF2V – First & Last Frame → Video
"""
import os
# Persist HF cache between launches
os.environ["HF_HOME"] = "/mnt/data/huggingface"
import torch
import numpy as np
import gradio as gr
from PIL import Image
import torchvision.transforms.functional as TF
from transformers import CLIPVisionModel, CLIPImageProcessor
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
# -----------------------------------------------------------------------------
# CONFIGURATION
# -----------------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
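# Note: 81 frames at the default 16 fps is roughly a 5-second clip. Wan-style
# video VAEs compress time by 4x, so frame counts of the form 4k + 1 (e.g. 81)
# are the safest choice; other values may be adjusted or rejected by the pipeline.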
# -----------------------------------------------------------------------------
# PIPELINE LOADING (ONCE)
# -----------------------------------------------------------------------------
def load_pipeline():
    # 1) Vision encoder (fp32)
    clip_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE (reduced precision)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) CLIPImageProcessor (exactly the type Wan expects)
    img_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32", use_fast=True
    )
    # 4) Load the Wan image-to-video pipeline, balanced across GPU & CPU
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=clip_encoder,
        vae=vae,
        image_processor=img_processor,
        torch_dtype=DTYPE,
        device_map="balanced",
    )
    # 5) Enable VAE slicing to cut VRAM spikes during decoding
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        pass
    return pipe
# instantiate once
PIPE = load_pipeline()
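# (A lower-memory alternative to device_map="balanced" on a single GPU would be
# diffusers' pipe.enable_model_cpu_offload(); that is a different trade-off and
# not what this demo uses.)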
# -----------------------------------------------------------------------------
# IMAGE RESIZE HELPERS
# -----------------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize while preserving aspect ratio, capped at max_area and snapped to the model's grid."""
    ar = img.height / img.width
    # Height/width must be multiples of the VAE spatial factor times the patch size
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = int(np.sqrt(max_area * ar)) // mod * mod
    w = int(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
def center_crop_resize(img: Image.Image, h: int, w: int):
    """Scale the image so it covers (w, h), then center-crop to exactly that size."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img, [h, w])
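# Worked example (assuming Wan's usual 8x spatial VAE factor and a patch size of
# 2, i.e. mod = 16): a 1080x1920 portrait input has ar ≈ 1.78, so
# h = sqrt(921600 * 1.78) = 1280 and w = sqrt(921600 / 1.78) = 720, both already
# multiples of 16, and the video is generated at 720x1280.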
# -----------------------------------------------------------------------------
# GENERATION (STREAMING)
# -----------------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    # Seed management: -1 means "pick a fresh random seed"
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # Preprocessing update
    progress((0, steps), desc="Preprocessing images")
    f0, h, w = aspect_resize(first_frame)
    if last_frame.size != f0.size:
        last_frame = center_crop_resize(last_frame, h, w)
    # Per-step progress callback (diffusers' callback_on_step_end interface)
    def cb(pipe, step, timestep, callback_kwargs):
        progress((step + 1, steps), desc=f"Inference step {step + 1}/{steps}")
        return callback_kwargs
    # Run the pipeline
    out = PIPE(
        image=f0,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=cb,
    )
    # Export the frames of the first (only) video to an .mp4 file
    video_path = export_to_video(out.frames[0], fps=fps)
    return video_path, seed
# -----------------------------------------------------------------------------
# GRADIO APP
# -----------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt_box = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative_box = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run_btn = gr.Button("Generate")
    video_out = gr.Video(label="Result (.mp4)")
    seed_out = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt_box, negative_box,
                steps, guidance, num_frames, seed_input, fps],
        outputs=[video_out, seed_out],
    )
demo.launch()