#!/usr/bin/env python
"""
Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
Loads the huge model once, uses balanced device placement,
streams high-level progress, and auto-offers the .mp4 for download.
"""
import numpy as np
import torch
import gradio as gr
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPImageProcessor, CLIPVisionModel
from PIL import Image
import torchvision.transforms.functional as TF

# --------------------------------------------------------------------
# CONFIG
MODEL_ID     = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
DTYPE        = torch.float16              # half-precision
MAX_AREA     = 1280 * 720                 # ≤720p
DEFAULT_FRAMES = 81                       # ≈5s @16fps
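# (Wan2.1 generates clips of 4k+1 frames due to 4x temporal VAE compression,
#  hence 81 = 4*20 + 1; the frame slider below therefore steps in fours)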
# --------------------------------------------------------------------

def load_pipeline():
    # 1) image encoder in full precision
    image_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE in reduced precision
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) image processor, loaded explicitly so the pipeline gets the right class
    image_processor = CLIPImageProcessor.from_pretrained(
        MODEL_ID, subfolder="image_processor"
    )
    # 4) load everything with a balanced device map
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        vae=vae,
        image_encoder=image_encoder,
        image_processor=image_processor,
        torch_dtype=DTYPE,
        device_map="balanced",           # splits weights CPU/GPU
    )
    return pipe

# load once at import
PIPE = load_pipeline()
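
# Sketch of a lower-memory alternative if the balanced map still overruns VRAM
# (don't combine enable_model_cpu_offload with device_map; load without one):
#
#   pipe = WanImageToVideoPipeline.from_pretrained(
#       MODEL_ID, vae=vae, image_encoder=image_encoder,
#       image_processor=image_processor, torch_dtype=DTYPE,
#   )
#   pipe.enable_model_cpu_offload()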


# --------------------------------------------------------------------
# UTILS
def aspect_resize(img: Image.Image, max_area=MAX_AREA):
    """Resize while respecting multiples of the model’s patch size."""
    ar   = img.height / img.width
    mod  = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
    h    = round(np.sqrt(max_area * ar)) // mod * mod
    w    = round(np.sqrt(max_area / ar)) // mod * mod
    return img.resize((w, h), Image.LANCZOS), h, w
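
# Worked example, assuming Wan's spatial VAE factor of 8 and patch size 2
# (mod = 16): a 1024x768 input has ar = 0.75, so
#   h = round(sqrt(921600 * 0.75)) // 16 * 16 = 831  // 16 * 16 = 816
#   w = round(sqrt(921600 / 0.75)) // 16 * 16 = 1109 // 16 * 16 = 1104
# i.e. the frame becomes 1104x816 (900864 px <= MAX_AREA).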

def center_crop_resize(img: Image.Image, h, w):
    """Crop-and-resize to exactly (h, w)."""
    ratio = max(w / img.width, h / img.height)
    img   = img.resize(
        (round(img.width * ratio), round(img.height * ratio)),
        Image.LANCZOS
    )
    return TF.center_crop(img, [h, w])


# --------------------------------------------------------------------
# GENERATE (with simple progress streaming)
def generate(
    first_frame: Image.Image,
    last_frame: Image.Image,
    prompt: str,
    negative_prompt: str,
    steps: int,
    guidance: float,
    num_frames: int,
    seed: int,
    fps: int,
    progress=gr.Progress(),   # gradio’s built-in progress callback
):
    # pick or set seed
    if seed == -1:
        seed = torch.seed()
    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

    # 0→10%: resize
    progress(0.0, desc="Resizing first frame…")
    first_frame, h, w = aspect_resize(first_frame)
    if last_frame.size != first_frame.size:
        progress(0.1, desc="Resizing last frame…")
        last_frame = center_crop_resize(last_frame, h, w)

    # 10→20%: ready to run
    progress(0.2, desc="Starting video inference…")
    result = PIPE(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
    )

    # 80→100%: export
    progress(0.8, desc="Assembling video file…")
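    # (export_to_video writes to a temporary .mp4 when no output path is given)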
    video_path = export_to_video(result.frames[0], fps=fps)
    progress(1.0, desc="Done!")

    # return path so gr.File offers immediate download, plus seed used
    return video_path, seed


# --------------------------------------------------------------------
# UI
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

    with gr.Row():
        first_img = gr.Image(label="First frame", type="pil")
        last_img  = gr.Image(label="Last frame",  type="pil")

    prompt          = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative        = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

    with gr.Accordion("Advanced parameters", open=False):
        steps      = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance   = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
        num_frames = gr.Slider(17, 129, value=DEFAULT_FRAMES, step=4, label="Frames (4k+1)")
        fps        = gr.Slider(4, 30, value=16, step=1, label="FPS")
        seed       = gr.Number(value=-1, precision=0, label="Seed (-1=random)")

    run_btn     = gr.Button("Generate")
    download    = gr.File(label="Download video", interactive=False)
    used_seed   = gr.Number(label="Seed used", interactive=False)
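    # (a gr.Video component could additionally preview the same path inline;
    #  gr.File is used so the browser offers the .mp4 as a direct download)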

    run_btn.click(
        fn=generate,
        inputs=[first_img, last_img, prompt, negative,
                steps, guidance, num_frames, seed, fps],
        outputs=[download, used_seed],
    )

# queue tasks so users see the little task-queue progress bar
demo.queue().launch(server_name="0.0.0.0", server_port=7860)
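# (pass share=True to launch() for a temporary public Gradio URL;
#  server_name="0.0.0.0" binds on all local network interfaces)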