#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
"""
import os
# Persist HF cache on /mnt/data so it survives across launches
os.environ["HF_HOME"] = "/mnt/data/huggingface"
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF
# ----------------------------------------------------------------------
# CONFIG
# ----------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
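# At the default export rate of 16 fps (set in the UI below), the default 81
# frames is roughly a five-second clip (81 / 16 ≈ 5.1 s).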
# ----------------------------------------------------------------------
# PIPELINE LOADING
# ----------------------------------------------------------------------
def load_pipeline():
    # 1) load the CLIP image encoder in full precision
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) load the VAE in reduced precision
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) load the WanImageToVideo pipeline in reduced precision
    #    (no `device_map` here: a pipeline-level device map cannot be combined
    #    with `enable_model_cpu_offload()` below)
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=DTYPE,
    )
    # 4) reduce VAE memory peaks and offload idle sub-models to the CPU
    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()
    return pipe
# create once, at import time
PIPE = load_pipeline()
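# Note: with model CPU offload, each sub-model (text encoder, image encoder,
# transformer, VAE) stays on the CPU and is moved to the GPU only while it is
# actually running, trading some speed for a much lower peak VRAM footprint.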
# ----------------------------------------------------------------------
# IMAGE PREPROCESSING UTILS
# ----------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize to roughly `max_area` pixels while keeping the aspect ratio."""
    ar = img.height / img.width
    # width/height must be multiples of the VAE spatial stride times the patch size
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = round(np.sqrt(max_area * ar)) // mod * mod
    w = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
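# Worked example (assuming this checkpoint uses a VAE spatial stride of 8 and a
# transformer patch size of 2, i.e. mod == 16): a 1920x1080 input gives
# ar = 0.5625, h = round(sqrt(921600 * 0.5625)) // 16 * 16 = 720 and
# w = round(sqrt(921600 / 0.5625)) // 16 * 16 = 1280, i.e. a 1280x720 resize.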
def center_crop_resize(img: Image.Image, h, w):
    """Scale the image to cover (w, h), then center-crop to exactly that size."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img, [h, w])
# ----------------------------------------------------------------------
# GENERATION FUNCTION
# ----------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    # seed (-1 means: draw a fresh random seed)
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # initial progress
    progress((0, steps), desc="Preprocessing images")
    # resize the first frame, then crop the last frame to the same size
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        last_frame = center_crop_resize(last_frame, h, w)
    # callback to update the progress bar after every denoising step
    def progress_callback(pipe, step, timestep, callback_kwargs):
        progress((step + 1, steps), desc=f"Inference step {step + 1}/{steps}")
        return callback_kwargs
    # run the pipeline (streams progress via the step-end callback)
    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=progress_callback,
    )
    # assemble and export the generated frames to an .mp4
    frames = result.frames[0]  # frames of the first (and only) generated video
    video_path = export_to_video(frames, fps=fps)
    # return the video path and the seed that was used (Gradio serves the file)
    return video_path, seed
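# Illustrative sketch (not executed): calling the pipeline directly without the
# Gradio UI, assuming two local files "first.png" and "last.png" exist:
#
#   first, h, w = aspect_resize(Image.open("first.png").convert("RGB"))
#   last = center_crop_resize(Image.open("last.png").convert("RGB"), h, w)
#   out = PIPE(image=first, last_image=last, prompt="a blue bird takes off",
#              height=h, width=w, num_frames=DEFAULT_FRAMES,
#              num_inference_steps=30, guidance_scale=5.5)
#   export_to_video(out.frames[0], "preview.mp4", fps=16)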
# ----------------------------------------------------------------------
# GRADIO UI
# ----------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Number of frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run_btn = gr.Button("Generate")
    video_out = gr.Video(label="Result (.mp4)")
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative,
                steps, guidance, num_frames, seed, fps],
        outputs=[video_out, used_seed],
    )
# Gradio queues events by default, so gr.Progress works without extra queue setup
demo.launch()
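# To run locally (assumes diffusers/transformers/gradio are installed and a CUDA
# GPU with enough VRAM for the 14B checkpoint is available): `python app.py`,
# then open the printed local URL. On a Hugging Face Space, app.py is launched
# automatically.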