#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
– shows streaming status updates
– auto-downloads the generated video
Author: <your-handle>
"""
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel, CLIPImageProcessor
from PIL import Image
import torchvision.transforms.functional as TF
# ---------------------------------------------------------------------
# CONFIG ----------------------------------------------------------------
MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"  # HF Hub repo id of the FLF2V model
DTYPE = torch.float16          # compute dtype for the VAE / transformer
MAX_AREA = 1280 * 720          # cap the output resolution at 720p worth of pixels
DEFAULT_FRAMES = 81            # default clip length (~5 s at 16 fps)
# ----------------------------------------------------------------------
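# PIPELINE LOADING ------------------------------------------------------
# The CLIP image encoder is loaded in float32, the VAE and the remaining
# components in DTYPE; device_map="balanced" lets Accelerate shard the
# weights across GPU and CPU RAM (see the comments inside from_pretrained).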
def load_pipeline():
    """Load & shard the pipeline across CPU/GPU with Accelerate."""
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,   # lazy-load to CPU RAM
        device_map="balanced",    # shard across CPU/GPU
    )
    # switch to the fast image processor
    pipe.image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor", use_fast=True
    )
    return pipe
PIPE = load_pipeline()
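# The 14B pipeline is built once at startup and reused for every request,
# so the first launch takes a while but later generations share the weights.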
# ----------------------------------------------------------------------
# UTILS ----------------------------------------------------------------
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize `img` to at most `max_area` pixels, keeping its aspect ratio and
    snapping height/width down to the model's spatial granularity."""
    ar = img.height / img.width
    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h = round(np.sqrt(max_area * ar)) // mod * mod
    w = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w

def center_crop_resize(img: Image.Image, h, w):
    """Scale `img` so it covers (w, h), then center-crop to exactly (w, h)."""
    ratio = max(w / img.width, h / img.height)
    img = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
    return TF.center_crop(img, [h, w])
# ----------------------------------------------------------------------
# GENERATE (streaming) --------------------------------------------------
def generate(first_frame, last_frame, prompt, negative_prompt,
             steps, guidance, num_frames, seed, fps):
    """Generator: each `yield` streams an intermediate (video, seed, status) update."""
    # 1) Preprocess
    yield None, None, "Preprocessing images..."
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        last_frame = center_crop_resize(last_frame, h, w)
    # 2) Inference
    yield None, None, f"Running inference ({steps} steps)..."
    if seed == -1:
        seed = torch.seed()  # draw a fresh random seed
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
    output = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )
    frames = output.frames[0]
    # 3) Export
    yield None, None, "Exporting video..."
    video_path = export_to_video(frames, fps=fps)
    # 4) Done
    yield video_path, seed, "Done! Your browser will download the video."
# ----------------------------------------------------------------------
# UI --------------------------------------------------------------------
with gr.Blocks() as demo:
    # inject JS for auto-download
    # (note: some Gradio versions do not execute <script> tags rendered via
    #  gr.HTML; if the download never fires, moving this snippet to
    #  gr.Blocks(head=...) is an alternative)
    gr.HTML("""
    <script>
    function downloadVideo() {
        const container = document.getElementById('output_video');
        if (!container) return;
        const vid = container.querySelector('video');
        if (!vid) return;
        const src = vid.currentSrc;
        const a = document.createElement('a');
        a.href = src;
        a.download = 'output.mp4';
        document.body.appendChild(a);
        a.click();
        document.body.removeChild(a);
    }
    </script>
    """)
    gr.Markdown("## Wan 2.1 FLF2V – Streaming progress + auto-download")
    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img = gr.Image(label="Last frame", type="pil")
    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
    with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
    run_btn = gr.Button("Generate")
    status = gr.Textbox(label="Status", interactive=False)
    video = gr.Video(label="Result", elem_id="output_video")
    used_seed = gr.Number(label="Seed used", interactive=False)
    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
        outputs=[video, used_seed, status],
    ).then(
        # chained so the download fires only after generation has finished
        # (requires a Gradio version with .then(); the keyword is `js` on Gradio 4+)
        fn=None,
        _js="downloadVideo",
    )
# queuing is what lets the intermediate `yield` updates stream to the UI
demo.queue()
demo.launch()