#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
Loads the huge model once, uses balanced device placement,
streams high-level progress, and auto-offers the .mp4 for download.
"""
import os
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF
# --------------------------------------------------------------------
# CONFIG
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE = torch.float16 # half-precision
MAX_AREA = 1280 * 720 # ≤720p
DEFAULT_FRAMES = 81 # ≈5s @16fps
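# NOTE (assumption): Wan-style video models generally expect frame counts of the
# form 4k + 1 (17, 49, 81, …), which is why 81 is the default here.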
# --------------------------------------------------------------------
def load_pipeline():
    # 1) image encoder in full precision
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE in reduced precision
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) explicit CLIPImageProcessor so the pipeline gets the right class
    image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="", torch_dtype=DTYPE
    )
    # 4) load everything else with a balanced device map
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        image_processor=image_processor,
        torch_dtype=DTYPE,
        device_map="balanced",  # splits weights across GPU(s) and CPU
    )
    return pipe
# load once at import
PIPE = load_pipeline()
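# NOTE (assumption): pipelines loaded with a device_map record where each
# sub-module landed in `hf_device_map`; printing it once at startup makes
# out-of-memory issues on mixed CPU/GPU setups easier to diagnose.
if hasattr(PIPE, "hf_device_map"):
    print("Device placement:", PIPE.hf_device_map)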
# --------------------------------------------------------------------
# UTILS
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
"""Resize while respecting multiples of the model’s patch size."""
ar = img.height / img.width
mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
h = round(np.sqrt(max_area * ar)) // mod * mod
w = round(np.sqrt(max_area / ar)) // mod * mod
return img.resize((w, h), Image.LANCZOS), h, w
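# Worked example for aspect_resize (assuming vae_scale_factor_spatial == 8 and
# patch_size[1] == 2, i.e. mod == 16): a 1024x768 input has ar = 0.75, so
# h = round(sqrt(921600 * 0.75)) // 16 * 16 = 816 and
# w = round(sqrt(921600 / 0.75)) // 16 * 16 = 1104, giving a 1104x816 frame
# that stays within the 1280x720 pixel budget.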
def center_crop_resize(img: Image.Image, h, w):
"""Crop-and-resize to exactly (h, w)."""
ratio = max(w / img.width, h / img.height)
img = img.resize(
(round(img.width * ratio), round(img.height * ratio)),
Image.LANCZOS
)
return TF.center_crop(img, [h, w])
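# Worked example for center_crop_resize (continuing the sketch above with
# h = 816, w = 1104): a 1920x1080 last frame gives ratio = max(1104/1920,
# 816/1080) ≈ 0.756, is resized to roughly 1451x816, and is then
# center-cropped to exactly 1104x816 so it matches the first frame.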
# --------------------------------------------------------------------
# GENERATE (with simple progress streaming)
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),  # gradio’s built-in progress callback
):
    # pick or set seed
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    # 0→10%: resize
    progress(0.0, desc="Resizing first frame…")
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        progress(0.1, desc="Resizing last frame…")
        last_frame = center_crop_resize(last_frame, h, w)
    # 10→20%: ready to run
    progress(0.2, desc="Starting video inference…")
    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
    # 80→100%: export
    progress(0.8, desc="Assembling video file…")
    video_path = export_to_video(result.frames[0], fps=fps)
    progress(1.0, desc="Done!")
    # return path so gr.File offers immediate download, plus the seed used
    return video_path, seed
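# NOTE (assumption): export_to_video writes to a temporary .mp4 when no output
# path is given and returns that path, which is exactly what gr.File expects.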
# --------------------------------------------------------------------
# UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
with gr.Row():
first_img = gr.Image(label="First frame", type="pil")
last_img = gr.Image(label="Last frame", type="pil")
prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
with gr.Accordion("Advanced parameters", open=False):
steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
run_btn = gr.Button("Generate")
download = gr.File(label="Download video", interactive=False)
used_seed = gr.Number(label="Seed used", interactive=False)
run_btn.click(
fn=generate,
inputs=[first_img, last_img, prompt, negative,
steps, guidance, num_frames, seed, fps],
outputs=[download, used_seed],
)
# queue tasks so users see the little task-queue progress bar
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
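# NOTE (assumption): with a single heavy pipeline it can be worth capping the
# queue, e.g. demo.queue(max_size=8), so extra visitors are told the queue is
# full instead of waiting indefinitely.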