Update app.py
app.py
CHANGED
@@ -4,7 +4,7 @@ Gradio demo for Wan2.1-FLF2V – First & Last Frame → Video
 """

 import os
-# Persist
+# Persist HF cache between launches
 os.environ["HF_HOME"] = "/mnt/data/huggingface"

 import torch
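A note on the ordering this hunk preserves: `HF_HOME` has to be exported before `transformers`/`diffusers` (and therefore `huggingface_hub`) are imported, because the hub resolves its cache paths at import time. A quick way to confirm the persistent cache is being picked up (a hedged check, not part of the commit):

import os
os.environ["HF_HOME"] = "/mnt/data/huggingface"   # must come before the hub imports

from huggingface_hub import constants
print(constants.HF_HOME)       # expect: /mnt/data/huggingface
print(constants.HF_HUB_CACHE)  # expect: /mnt/data/huggingface/hub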
@@ -12,7 +12,7 @@ import numpy as np
 import gradio as gr
 from PIL import Image
 import torchvision.transforms.functional as TF
-from transformers import CLIPVisionModel,
+from transformers import CLIPVisionModel, CLIPImageProcessor
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video

@@ -20,7 +20,7 @@ from diffusers.utils import export_to_video
 # CONFIGURATION
 # -----------------------------------------------------------------------------
 MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
-DTYPE = torch.float16
+DTYPE = torch.float16
 MAX_AREA = 1280 * 720
 DEFAULT_FRAMES = 81

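`MAX_AREA` feeds the `aspect_resize` helper, which this diff leaves untouched, so its body is not visible here. The usual Wan recipe scales the first frame so that height times width stays within the pixel budget while both sides stay multiples of the model's spatial stride; a rough sketch of that kind of helper (an assumption for illustration, not the code from this Space):

import math
from PIL import Image

def aspect_resize_sketch(img: Image.Image, max_area: int = 1280 * 720, mod: int = 16):
    # Keep the aspect ratio, cap the pixel count at max_area, and round both
    # sides down to a multiple of `mod` so the latent grid divides evenly.
    ar = img.height / img.width
    h = int(math.sqrt(max_area * ar)) // mod * mod
    w = int(math.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w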
@@ -28,36 +28,36 @@ DEFAULT_FRAMES = 81
 # PIPELINE LOADING (ONCE)
 # -----------------------------------------------------------------------------
 def load_pipeline():
-    # 1)
+    # 1) Vision encoder (fp32)
     clip_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) VAE
+    # 2) VAE (reduced precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3)
-
+    # 3) CLIPImageProcessor (exactly the type Wan expects)
+    img_processor = CLIPImageProcessor.from_pretrained(
         "openai/clip-vit-base-patch32", use_fast=True
     )
-    # 4)
+    # 4) Load the Wan‐to‐Video pipeline, balanced across GPU & CPU
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         image_encoder=clip_encoder,
         vae=vae,
-        image_processor=
+        image_processor=img_processor,
         torch_dtype=DTYPE,
         device_map="balanced",
     )
-    # 5)
+    # 5) Slice the VAE to cut VRAM spikes
     try:
         pipe.vae.enable_slicing()
     except AttributeError:
         pass
-    pipe.enable_model_cpu_offload()
     return pipe

-
+# instantiate once
+PIPE = load_pipeline()

 # -----------------------------------------------------------------------------
 # IMAGE RESIZE HELPERS
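One behavioural change in this hunk worth flagging: `pipe.enable_model_cpu_offload()` is dropped and only `device_map="balanced"` remains, which is consistent with diffusers refusing to attach CPU-offload hooks to a pipeline that has already been sharded with a device map. A small post-load sanity check (a sketch, not part of the commit; attribute names assume a recent diffusers release):

pipe = load_pipeline()
print(type(pipe.image_encoder).__name__)     # CLIPVisionModel
print(type(pipe.image_processor).__name__)   # CLIPImageProcessor
print(pipe.vae.dtype)                        # torch.float16 (DTYPE)
print(getattr(pipe, "hf_device_map", None))  # how "balanced" placed the modules, if exposed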
@@ -78,37 +78,37 @@ def center_crop_resize(img: Image.Image, h: int, w: int):
     return TF.center_crop(img, [h, w])

 # -----------------------------------------------------------------------------
-# GENERATION (STREAMING
+# GENERATION (STREAMING)
 # -----------------------------------------------------------------------------
 def generate(
-    first_frame:
-    last_frame:
-    prompt:
-    negative:
-    steps:
-    guidance:
-    num_frames:
-    seed:
-    fps:
-    progress=
+    first_frame: Image.Image,
+    last_frame: Image.Image,
+    prompt: str,
+    negative: str,
+    steps: int,
+    guidance: float,
+    num_frames: int,
+    seed: int,
+    fps: int,
+    progress=gr.Progress()
 ):
-    # Seed
+    # Seed management
     if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-    # Preprocessing
+    # Preprocessing update
     progress(0, steps, desc="Preprocessing images")
     f0, h, w = aspect_resize(first_frame)
     if last_frame.size != f0.size:
         last_frame = center_crop_resize(last_frame, h, w)

-    #
+    # Step callback
     def cb(step, timestep, latents):
         progress(step, steps, desc=f"Inference step {step}/{steps}")

-    # Run pipeline
-
+    # Run the pipeline
+    out = PIPE(
         image=f0,
         last_image=last_frame,
         prompt=prompt,
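This hunk keeps the legacy `callback=cb` / `(step, timestep, latents)` interface. Recent diffusers releases route step callbacks through `callback_on_step_end` instead, and the Wan pipelines may only accept that form; if the call fails with an unexpected-keyword error, an equivalent wiring would look roughly like this (a sketch reusing the names from this hunk, with the unchanged arguments elided):

def cb_on_step_end(pipeline, step, timestep, callback_kwargs):
    # Same progress reporting as cb(), adapted to the newer callback signature.
    progress(step, steps, desc=f"Inference step {step}/{steps}")
    return callback_kwargs

out = PIPE(
    image=f0,
    last_image=last_frame,
    prompt=prompt,
    # ... unchanged arguments elided ...
    callback_on_step_end=cb_on_step_end,
)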
@@ -122,12 +122,12 @@ def generate(
         callback=cb
     )

-    # Export
-    video_path = export_to_video(
+    # Export video
+    video_path = export_to_video(out.frames[0], fps=fps)
     return video_path, seed

 # -----------------------------------------------------------------------------
-# GRADIO
+# GRADIO APP
 # -----------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
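For reference, `export_to_video(out.frames[0], fps=fps)` writes the generated frames to a temporary .mp4 and returns its path, which is what `generate` hands back to Gradio; an explicit filename can also be passed (a usage sketch, not part of the commit):

from diffusers.utils import export_to_video

# out.frames is a batch; frames[0] is the frame sequence for this one video.
video_path = export_to_video(out.frames[0], "wan_flf2v_preview.mp4", fps=16)
print(video_path)  # wan_flf2v_preview.mp4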