#!/usr/bin/env python
"""
Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
"""
import os
import random

# Persist HF cache on /mnt/data so it survives across launches.
# Set before the Hugging Face imports below so the libraries pick it up.
os.environ["HF_HOME"] = "/mnt/data/huggingface"

import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF

# ----------------------------------------------------------------------
# CONFIG
# ----------------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16
MAX_AREA = 1280 * 720
DEFAULT_FRAMES = 81
# Random seeds are drawn from a 32-bit range so the value shown in the
# "Seed used" field can be typed back in without losing precision.
MAX_SEED = 2**32 - 1


# ----------------------------------------------------------------------
# PIPELINE LOADING
# ----------------------------------------------------------------------
def load_pipeline():
    """Build the Wan FLF2V pipeline with its image encoder and VAE.

    Returns:
        The configured ``WanImageToVideoPipeline`` with VAE slicing and
        model CPU offload enabled.
    """
    # 1) load CLIP image encoder in full precision
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID,
        subfolder="image_encoder",
        torch_dtype=torch.float32,
    )
    # 2) load VAE in reduced precision
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID,
        subfolder="vae",
        torch_dtype=DTYPE,
    )
    # 3) load the WanImageToVideo pipeline.
    #    No `device_map` here: diffusers refuses `enable_model_cpu_offload()`
    #    on a pipeline that was loaded with a device map, so the two
    #    placement strategies cannot be combined.
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=DTYPE,
    )
    # 4) reduce VAE peaks & enable CPU offload for everything else
    pipe.enable_vae_slicing()
    pipe.enable_model_cpu_offload()
    return pipe


# create once, at import time
PIPE = load_pipeline()


# ----------------------------------------------------------------------
# IMAGE PREPROCESSING UTILS
# ----------------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize ``img`` to roughly ``max_area`` pixels, keeping its aspect ratio.

    Both output dimensions are rounded down to a multiple of the VAE spatial
    scale factor times the transformer patch size.

    Returns:
        A tuple ``(resized_image, height, width)``.
    """
    ar = img.height / img.width
    # ensure multiple of patch size
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = round(np.sqrt(max_area * ar)) // mod * mod
    w = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w


def center_crop_resize(img: Image.Image, h, w):
    """Scale ``img`` so it covers ``w`` x ``h``, then center-crop to that size."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS,
    )
    return TF.center_crop(img, [h, w])


# ----------------------------------------------------------------------
# GENERATION FUNCTION
# ----------------------------------------------------------------------
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),
):
    """Generate a video that interpolates between two frames.

    Args:
        first_frame: Image used as the first frame of the video.
        last_frame: Image used as the last frame of the video.
        prompt: Text prompt passed to the pipeline.
        negative: Negative prompt; an empty value is passed as ``None``.
        steps: Number of inference steps.
        guidance: Guidance scale.
        num_frames: Number of frames to generate.
        seed: RNG seed; ``-1`` (or an empty field) picks a random one.
        fps: Frame rate used when exporting the video.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        A tuple ``(video_path, seed)`` with the exported video file and the
        seed that was actually used.

    Raises:
        gr.Error: If either input frame is missing.
    """
    if first_frame is None or last_frame is None:
        raise gr.Error("Please provide both a first and a last frame.")

    # UI widgets may hand over floats; the pipeline and exporter want ints.
    steps = int(steps)
    num_frames = int(num_frames)
    fps = int(fps)

    # seed
    if seed is None or int(seed) == -1:
        seed = random.randint(0, MAX_SEED)
    seed = int(seed)
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

    # initial progress; an (index, total) tuple is how Gradio takes a total
    progress((0, steps), desc="Preprocessing images")

    # resize / crop
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        last_frame = center_crop_resize(last_frame, h, w)

    # callback to update progress bar on each denoising step
    def progress_callback(pipe, step, timestep, callback_kwargs):
        done = step + 1  # the pipeline reports zero-based step indices
        progress((done, steps), desc=f"Inference step {done}/{steps}")
        # the pipeline reads tensors back from the returned dict
        return callback_kwargs

    # run the pipeline (streams progress via callback)
    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=progress_callback,
    )

    # assemble and export to video
    # frames of the first (only) generated video; no output_type is
    # requested here, so these are presumably numpy frames, not PIL images
    frames = result.frames[0]
    video_path = export_to_video(frames, fps=fps)

    # return video and seed used
    return video_path, seed


# ----------------------------------------------------------------------
# GRADIO UI
# ----------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")

    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")

    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(
        label="Negative prompt (optional)", placeholder="ugly, blurry"
    )

    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(
            16, 129, value=DEFAULT_FRAMES, step=1, label="Number of frames"
        )
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

    run_btn = gr.Button("Generate")
    video_out = gr.Video(label="Result (.mp4)")
    used_seed = gr.Number(label="Seed used", interactive=False)

    run_btn.click(
        fn=generate,
        inputs=[
            first_img,
            last_img,
            prompt,
            negative,
            steps,
            guidance,
            num_frames,
            seed,
            fps,
        ],
        outputs=[video_out, used_seed],
    )

# no special queue args needed
demo.launch()