#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V).

Loads the large model once at import, distributes it with diffusers'
"balanced" device map, reports per-denoising-step progress to the UI, and
offers the resulting .mp4 for download.
"""
import os
import random

import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF

# --------------------------------------------------------------------
# CONFIG
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
# Half precision for the transformer / text encoder.
# NOTE(review): Wan2.1 is commonly run in torch.bfloat16; fp16 is kept here
# (e.g. for GPUs without bf16) — confirm output quality on your hardware.
DTYPE = torch.float16
MAX_AREA = 1280 * 720  # ≤720p
DEFAULT_FRAMES = 81  # ≈5s @16fps; Wan expects num_frames of the form 4k+1
# --------------------------------------------------------------------


def load_pipeline():
    """Load the FLF2V pipeline with separately-configured components.

    Returns:
        WanImageToVideoPipeline: the pipeline, placed with device_map="balanced".
    """
    # 1) CLIP image encoder in full precision.
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE in full precision: the diffusers Wan docs recommend keeping
    #    AutoencoderKLWan in float32 for decoding quality.
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=torch.float32
    )
    # 3) The processor config lives in the repo's "image_processor" subfolder
    #    (the repo root only holds model_index.json). Processors carry no
    #    weights, so no dtype is passed.
    image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor"
    )
    # 4) Remaining components in DTYPE. "balanced" lets diffusers spread the
    #    pipeline components across the available devices.
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        image_processor=image_processor,
        torch_dtype=DTYPE,
        device_map="balanced",
    )
    return pipe


# load once at import
PIPE = load_pipeline()


# --------------------------------------------------------------------
# UTILS
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize *img* to about *max_area* pixels, keeping its aspect ratio.

    Both sides are floored to a multiple of (VAE spatial factor x transformer
    patch size) so the latent grid divides evenly.

    Returns:
        tuple: (resized image, height, width).
    """
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    # int() guards against NumPy versions where round(np.float64) stays a
    # float (PIL needs ints); max() stops extreme aspect ratios from
    # collapsing a side to 0.
    h = max(mod, int(round(np.sqrt(max_area * ar))) // mod * mod)
    w = max(mod, int(round(np.sqrt(max_area / ar))) // mod * mod)
    return img.resize((w, h), Image.LANCZOS), h, w


def center_crop_resize(img: Image.Image, h, w):
    """Scale *img* to cover (h, w), then center-crop to exactly (h, w)."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
    )
    return TF.center_crop(img, [h, w])


# --------------------------------------------------------------------
# GENERATE (with per-step progress streaming)
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),  # gradio's built-in progress callback
):
    """Generate a video interpolating between *first_frame* and *last_frame*.

    Args:
        first_frame / last_frame: PIL images from the UI (None if not uploaded).
        prompt / negative_prompt: text conditioning; empty negative -> None.
        steps: number of denoising steps.
        guidance: classifier-free guidance scale.
        num_frames: frame count (the pipeline rounds non-4k+1 values).
        seed: RNG seed; negative or empty means "pick a random one".
        fps: frame rate of the exported .mp4.
        progress: Gradio progress tracker.

    Returns:
        tuple: (path to the exported .mp4, seed actually used).

    Raises:
        gr.Error: if either frame is missing.
    """
    if first_frame is None or last_frame is None:
        raise gr.Error("Please upload both a first and a last frame.")

    # UI widgets may deliver floats; the pipeline and exporter want ints.
    steps, num_frames, fps = int(steps), int(num_frames), int(fps)

    # Negative/empty seed -> draw a 32-bit one. It round-trips exactly through
    # the UI's JS number, so the displayed seed reproduces the video
    # (torch.seed() yields up to 64 bits and also reseeds the global RNG).
    if seed is None or seed < 0:
        seed = random.randint(0, 2**32 - 1)
    seed = int(seed)
    # A CPU generator works regardless of where device_map placed the modules
    # and yields the same noise for a given seed across hardware.
    gen = torch.Generator(device="cpu").manual_seed(seed)

    # 0→20%: resize
    progress(0.0, desc="Resizing first frame…")
    first_frame, h, w = aspect_resize(first_frame.convert("RGB"))
    last_frame = last_frame.convert("RGB")
    if last_frame.size != first_frame.size:
        progress(0.1, desc="Resizing last frame…")
        last_frame = center_crop_resize(last_frame, h, w)

    total_steps = max(steps, 1)

    def _report_step(pipe, step_index, timestep, callback_kwargs):
        # Map denoising steps onto the 20%→80% band of the progress bar.
        done = step_index + 1
        progress(
            0.2 + 0.6 * min(done / total_steps, 1.0),
            desc=f"Denoising step {done}/{total_steps}…",
        )
        # Returning the kwargs unchanged leaves the latents untouched.
        return callback_kwargs

    # 20→80%: video inference
    progress(0.2, desc="Starting video inference…")
    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=_report_step,
    )

    # 80→100%: export
    progress(0.8, desc="Assembling video file…")
    video_path = export_to_video(result.frames[0], fps=fps)
    progress(1.0, desc="Done!")

    # return path so gr.File offers immediate download, plus seed used
    return video_path, seed


# --------------------------------------------------------------------
# UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")

    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        # Wan needs num_frames = 4k+1; stepping 17..129 by 4 only offers valid
        # values instead of letting the pipeline silently round them.
        num_frames = gr.Slider(17, 129, value=DEFAULT_FRAMES, step=4, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")

    run_btn = gr.Button("Generate")
    download = gr.File(label="Download video", interactive=False)
    used_seed = gr.Number(label="Seed used", precision=0, interactive=False)

    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
        outputs=[download, used_seed],
    )

# queue tasks so users see the little task-queue progress bar
demo.queue().launch(server_name="0.0.0.0", server_port=7860)