#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
Loads the huge model once, uses balanced device placement,
streams high-level progress, and auto-offers the .mp4 for download.
"""
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF
# --------------------------------------------------------------------
# CONFIG
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16        # half-precision
MAX_AREA = 1280 * 720        # ≤720p
DEFAULT_FRAMES = 81          # ≈5s @16fps
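# NOTE: the reference Wan2.1 examples in the diffusers docs keep the VAE in
# float32 and run the rest of the pipeline in bfloat16; using float16 for the
# VAE and transformer here is a memory-saving trade-off, not a requirement.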
# --------------------------------------------------------------------
def load_pipeline():
    # 1) image encoder in full precision
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE in reduced precision
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) explicit CLIPImageProcessor so we get the right preprocessing class
    #    (image processors have no torch_dtype; config lives in its own subfolder)
    image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor"
    )
    # 4) load everything with a balanced device map
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        image_processor=image_processor,
        torch_dtype=DTYPE,
        device_map="balanced",  # spreads the pipeline's components across available devices
    )
    return pipe
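# With device_map="balanced", diffusers (via accelerate) decides where each
# component lives, so no explicit pipe.to("cuda") call is needed in this script.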
# load once at import
PIPE = load_pipeline()
# --------------------------------------------------------------------
# UTILS
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize while respecting multiples of the model’s patch size."""
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = round(np.sqrt(max_area * ar)) // mod * mod
    w = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
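# Example: with the usual Wan factors (vae_scale_factor_spatial = 8,
# patch_size[1] = 2, so mod = 16) a 1920x1080 input maps to exactly 1280x720,
# since both sides are already multiples of 16 after the area-bounded resize.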
def center_crop_resize(img: Image.Image, h, w):
    """Crop-and-resize to exactly (h, w)."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img, [h, w])
# --------------------------------------------------------------------
# GENERATE (with simple progress streaming)
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),  # gradio’s built-in progress callback
):
    # pick or set seed
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # 0→10%: resize
    progress(0.0, desc="Resizing first frame…")
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        progress(0.1, desc="Resizing last frame…")
        last_frame = center_crop_resize(last_frame, h, w)
    # 10→20%: ready to run
    progress(0.2, desc="Starting video inference…")
    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
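    # (Sketch, not wired up.) If WanImageToVideoPipeline exposes the standard
    # diffusers per-step callback, inference progress could be streamed more
    # finely by passing, for example:
    #   callback_on_step_end=lambda pipe, i, t, kw: (progress(0.2 + 0.6 * (i + 1) / steps), kw)[1]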
    # 80→100%: export
    progress(0.8, desc="Assembling video file…")
    video_path = export_to_video(result.frames[0], fps=fps)
    progress(1.0, desc="Done!")
    # return path so gr.File offers immediate download, plus seed used
    return video_path, seed
# --------------------------------------------------------------------
# UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(17, 129, value=DEFAULT_FRAMES, step=4, label="Frames")  # Wan expects 4k+1 frames
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
    run_btn = gr.Button("Generate")
    download = gr.File(label="Download video", interactive=False)
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative,
                steps, guidance, num_frames, seed, fps],
        outputs=[download, used_seed],
    )
# queue tasks so users see the little task-queue progress bar
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
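# Run this script directly or as a Hugging Face Space entry point; the UI is
# served on port 7860 and listens on all interfaces.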