GeradeHouse committed on
Commit c078b58 · verified · 1 Parent(s): b75a45c

Update app.py

Files changed (1)
  1. app.py +71 -68
app.py CHANGED
@@ -1,69 +1,77 @@
 #!/usr/bin/env python
 """
-Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
+Gradio demo for Wan2.1-FLF2V – First & Last Frame → Video
 """

 import os
-# Persist HF cache on /mnt/data so it survives across launches
+
+# Persist HF cache across runs
 os.environ["HF_HOME"] = "/mnt/data/huggingface"

-import numpy as np
 import torch
+import numpy as np
 import gradio as gr
-from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
-from diffusers.utils import export_to_video
-from transformers import CLIPVisionModel
 from PIL import Image
 import torchvision.transforms.functional as TF
+from transformers import CLIPVisionModel, CLIPImageProcessor
+from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
+from diffusers.utils import export_to_video

-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
 # CONFIG
-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
 MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
-DTYPE = torch.float16
-MAX_AREA = 1280 * 720
-DEFAULT_FRAMES = 81
+DTYPE = torch.float16          # use bfloat16 if your GPU supports AMP
+MAX_AREA = 1280 * 720          # cap at 720p
+DEFAULT_FRAMES = 81            # ~5s at 16fps

-# ----------------------------------------------------------------------
-# PIPELINE LOADING
-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
+# LOAD & CACHE PIPELINE (once)
+# -----------------------------------------------------------------------------
 def load_pipeline():
-    # 1) load CLIP image encoder in full precision
-    image_encoder = CLIPVisionModel.from_pretrained(
+    # 1) CLIP vision encoder (fp32)
+    clip_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) load VAE in reduced precision
+    # 2) VAE in reduced precision
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3) load the WanImageToVideo pipeline, balanced across GPU/CPU
+    # 3) Standard CLIPImageProcessor (needs its own config files)
+    clip_processor = CLIPImageProcessor.from_pretrained(
+        "openai/clip-vit-base-patch32",   # uses a known CLIP repo
+        use_fast=True
+    )
+    # 4) Build the Wan-to-video pipeline, balanced across GPU/CPU
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
+        image_encoder=clip_encoder,
         vae=vae,
-        image_encoder=image_encoder,
+        image_processor=clip_processor,
         torch_dtype=DTYPE,
-        device_map="balanced",            # auto-offload large modules to CPU
+        device_map="balanced",            # auto-offload large submodules to CPU
     )
-    # 4) reduce VAE peaks & enable CPU offload for everything else
-    pipe.enable_vae_slicing()
+    # 5) Slice VAE & offload rest
+    try:
+        pipe.vae.enable_slicing()
+    except AttributeError:
+        pass
     pipe.enable_model_cpu_offload()
     return pipe

-# create once, at import time
 PIPE = load_pipeline()

-# ----------------------------------------------------------------------
-# IMAGE PREPROCESSING UTILS
-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
+# IMAGE RESIZE HELPERS
+# -----------------------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
     ar = img.height / img.width
-    # ensure multiple of patch size
     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
-    h = round(np.sqrt(max_area * ar)) // mod * mod
-    w = round(np.sqrt(max_area / ar)) // mod * mod
+    h = int(np.sqrt(max_area * ar)) // mod * mod
+    w = int(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w

-def center_crop_resize(img: Image.Image, h, w):
+def center_crop_resize(img: Image.Image, h: int, w: int):
     ratio = max(w / img.width, h / img.height)
     img = img.resize(
         (round(img.width * ratio), round(img.height * ratio)),
@@ -71,9 +79,9 @@ def center_crop_resize(img: Image.Image, h, w):
     )
     return TF.center_crop(img, [h, w])

-# ----------------------------------------------------------------------
-# GENERATION FUNCTION
-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
+# GENERATION FUNCTION (streams every step)
+# -----------------------------------------------------------------------------
 def generate(
     first_frame: Image.Image,
     last_frame: Image.Image,
@@ -86,26 +94,24 @@ def generate(
     fps: int,
     progress= gr.Progress()
 ):
-    # seed
+    # 1) Seed
     if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-    # initial progress
+    # 2) Preprocess
     progress(0, steps, desc="Preprocessing images")
-
-    # resize / crop
-    first_frame, h, w = aspect_resize(first_frame)
-    if last_frame.size != first_frame.size:
+    f0, h, w = aspect_resize(first_frame)
+    if last_frame.size != f0.size:
         last_frame = center_crop_resize(last_frame, h, w)

-    # callback to update progress bar on each denoising step
-    def progress_callback(step, timestep, latents):
+    # 3) Streaming callback
+    def cb(step, timestep, latents):
         progress(step, steps, desc=f"Inference step {step}/{steps}")

-    # run the pipeline (streams progress via callback)
-    result = PIPE(
-        image=first_frame,
+    # 4) Inference
+    output = PIPE(
+        image=f0,
         last_image=last_frame,
         prompt=prompt,
         negative_prompt=negative or None,
@@ -115,46 +121,43 @@ def generate(
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
-        callback=progress_callback,
+        callback=cb
     )

-    # assemble and export to video
-    frames = result.frames[0]             # list of PIL images
-    video_path = export_to_video(frames, fps=fps)
-
-    # return video and seed used (Gradio will auto-download the .mp4)
+    # 5) Export to MP4
+    video_path = export_to_video(output.frames[0], fps=fps)
     return video_path, seed

-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
 # GRADIO UI
-# ----------------------------------------------------------------------
+# -----------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")
+    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")

-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+    prompt_box = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative_box = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
-        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
-        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Number of frames")
-        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
-        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
+        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
+        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
+        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
+        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
+        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

-    run_btn = gr.Button("Generate")
-    video_out = gr.Video(label="Result (.mp4)")
-    used_seed = gr.Number(label="Seed used", interactive=False)
+    run_btn = gr.Button("Generate")
+    video_out = gr.Video(label="Result (.mp4)")
+    seed_out = gr.Number(label="Seed used", interactive=False)

     run_btn.click(
         fn=generate,
-        inputs=[ first_img, last_img, prompt, negative,
-                 steps, guidance, num_frames, seed, fps ],
-        outputs=[ video_out, used_seed ]
+        inputs=[ first_img, last_img, prompt_box, negative_box,
+                 steps, guidance, num_frames, seed_input, fps ],
+        outputs=[ video_out, seed_out ]
     )

-# no special queue args needed
-demo.launch()
+# no extra queue args needed; Gradio will serialize calls automatically
+demo.launch()
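
For reference, the "// mod * mod" rounding in aspect_resize picks a working resolution close to MAX_AREA while snapping height and width down to the spatial grid the transformer expects. Below is a standalone sketch of just that arithmetic; the function name target_size and the hard-coded mod=16 are illustrative assumptions, since in app.py the modulus comes from PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1] of the loaded pipeline.

# Standalone sketch of the aspect_resize arithmetic used above.
# target_size and mod=16 are illustrative assumptions; the real modulus is
# PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1].
import numpy as np

MAX_AREA = 1280 * 720

def target_size(width: int, height: int, max_area: int = MAX_AREA, mod: int = 16):
    ar = height / width
    h = int(np.sqrt(max_area * ar)) // mod * mod   # snap down to the patch grid
    w = int(np.sqrt(max_area / ar)) // mod * mod   # keeps roughly the same aspect ratio
    return w, h

print(target_size(1920, 1080))   # a 16:9 source maps to (1280, 720)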