#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
Author: GeradeHouse
"""
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
import torchvision.transforms.functional as TF
# ---------------------------------------------------------------------
# CONFIG ----------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers" # or switch to 1.3B
DTYPE = torch.float16 # or bfloat16
MAX_AREA = 1280 * 720 # ≤720p
DEFAULT_FRAMES = 81 # ~5s @16 fps
# ----------------------------------------------------------------------
def load_pipeline():
"""Lazy‐load & configure the pipeline once per process."""
# 1) load the CLIP image encoder (full-precision)
image_encoder = CLIPVisionModel.from_pretrained(
MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
)
# 2) load the VAE (half-precision)
vae = AutoencoderKLWan.from_pretrained(
MODEL_ID, subfolder="vae", torch_dtype=DTYPE
)
# 3) load the video pipeline
pipe = WanImageToVideoPipeline.from_pretrained(
MODEL_ID,
vae=vae,
image_encoder=image_encoder,
torch_dtype=DTYPE,
)
    # 4) override the processor with the fast image-processor implementation
    pipe.image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor", use_fast=True
    )
    # 5) memory helpers (offload model components to CPU as needed)
    # pipe.enable_model_cpu_offload()
    # (Removed pipe.vae.enable_slicing() — not supported on AutoencoderKLWan)
    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
PIPE = load_pipeline()
# ----------------------------------------------------------------------
# UTILS ----------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
"""Resize while keeping aspect & respecting patch multiples."""
ar = img.height / img.width
mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
h = round(np.sqrt(max_area * ar)) // mod * mod
w = round(np.sqrt(max_area / ar)) // mod * mod
return img.resize((w, h), Image.LANCZOS), h, w
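# Worked example (assuming vae_scale_factor_spatial == 8 and a transformer
# patch_size of (1, 2, 2), i.e. mod == 16, which is what the 720P checkpoint
# typically reports): a square input maps to sqrt(1280*720) = 960 -> 960x960,
# and a 9:16 portrait input maps to 720 wide x 1280 tall, both already
# multiples of 16 so no further snapping occurs.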
def center_crop_resize(img: Image.Image, h, w):
"""Center‐crop & resize to H×W."""
ratio = max(w / img.width, h / img.height)
img = img.resize(
(round(img.width * ratio), round(img.height * ratio)),
Image.LANCZOS
)
return TF.center_crop(img, [h, w])
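# Note: aspect_resize() fixes the working resolution from the *first* frame;
# center_crop_resize() then forces the last frame onto that exact H×W, so the
# two conditioning images handed to the pipeline always share the same shape.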
# ----------------------------------------------------------------------
# GENERATE --------------------------------------------------------------
def generate(first_frame, last_frame, prompt, negative_prompt, steps,
             guidance, num_frames, seed, fps):
    # seed handling
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # preprocess frames
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        last_frame = center_crop_resize(last_frame, h, w)
    # run the pipeline
    output = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
    frames = output.frames[0]  # frames of the first (and only) generated video
    # export to MP4
    video_path = export_to_video(frames, fps=fps)
    return video_path, seed
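# Programmatic use (illustrative sketch only; "first.png" / "last.png" are
# placeholder filenames, not files shipped with this demo):
#
#   path, used_seed = generate(Image.open("first.png"), Image.open("last.png"),
#                              "a blue bird takes off", "", steps=30,
#                              guidance=5.5, num_frames=DEFAULT_FRAMES,
#                              seed=42, fps=16)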
# ----------------------------------------------------------------------
# UI --------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run_btn = gr.Button("Generate")
    video = gr.Video(label="Result (.mp4)")
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
        outputs=[video, used_seed]
    )
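# For long generations it can be worth enabling Gradio's request queue,
# e.g. `demo.queue()` before `demo.launch()`, so concurrent requests are
# serialized on a single GPU; it is omitted here to keep the demo minimal.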
if __name__ == "__main__":
    demo.launch()