GeradeHouse committed (verified)
Commit b75a45c · 1 Parent(s): 2c7ebd6

Update app.py

Files changed (1)
  1. app.py +58 -57
app.py CHANGED
@@ -4,75 +4,70 @@ Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
 """
 
 import os
-import torch
+# Persist HF cache on /mnt/data so it survives across launches
+os.environ["HF_HOME"] = "/mnt/data/huggingface"
+
 import numpy as np
+import torch
 import gradio as gr
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
-from transformers import CLIPProcessor, CLIPVisionModel
 from diffusers.utils import export_to_video
+from transformers import CLIPVisionModel
 from PIL import Image
 import torchvision.transforms.functional as TF
 
 # ----------------------------------------------------------------------
 # CONFIG
 # ----------------------------------------------------------------------
-MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
-DTYPE = torch.float16  # switch to torch.bfloat16 if you have AMP-friendly GPU
-MAX_AREA = 1280 * 720  # ≤ 720p
-DEFAULT_FRAMES = 81  # ~5s @ 16fps
+MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+DTYPE = torch.float16
+MAX_AREA = 1280 * 720
+DEFAULT_FRAMES = 81
 
 # ----------------------------------------------------------------------
-# PIPELINE LOADING (once)
+# PIPELINE LOADING
 # ----------------------------------------------------------------------
 def load_pipeline():
-    # 1) image encoder in fp32 for stability
+    # 1) load CLIP image encoder in full precision
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) VAE in reduced precision
+    # 2) load VAE in reduced precision
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3) use the unified CLIPProcessor (inherits ProcessorMixin) in fast mode
-    processor = CLIPProcessor.from_pretrained(MODEL_ID, use_fast=True)
-
-    # 4) assemble pipeline, overriding the default processor
+    # 3) load the WanImageToVideo pipeline, balanced across GPU/CPU
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
         image_encoder=image_encoder,
-        processor=processor,
         torch_dtype=DTYPE,
+        device_map="balanced",  # auto-offload large modules to CPU
     )
-
-    # 5) offload to CPU / reduce footprint
+    # 4) reduce VAE peaks & enable CPU offload for everything else
+    pipe.enable_vae_slicing()
     pipe.enable_model_cpu_offload()
-
-    # 6) safe VAE slicing if available
-    try:
-        pipe.vae.enable_slicing()
-    except (AttributeError, TypeError):
-        pass
-
     return pipe
 
-pipe = load_pipeline()
+# create once, at import time
+PIPE = load_pipeline()
 
 # ----------------------------------------------------------------------
-# IMAGE RESIZING HELPERS
+# IMAGE PREPROCESSING UTILS
 # ----------------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
     ar = img.height / img.width
-    # align to VAE & transformer patch grid
-    mod = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+    # ensure multiple of patch size
+    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
     w = round(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w
 
 def center_crop_resize(img: Image.Image, h, w):
     ratio = max(w / img.width, h / img.height)
-    img = img.resize(
-        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
+    img = img.resize(
+        (round(img.width * ratio), round(img.height * ratio)),
+        Image.LANCZOS
     )
     return TF.center_crop(img, [h, w])
 
@@ -83,49 +78,55 @@ def generate(
     first_frame: Image.Image,
     last_frame: Image.Image,
     prompt: str,
-    negative_prompt: str,
+    negative: str,
     steps: int,
     guidance: float,
     num_frames: int,
     seed: int,
     fps: int,
+    progress=gr.Progress()
 ):
-    # randomize seed if requested
+    # seed
     if seed == -1:
         seed = torch.seed()
-    gen = torch.Generator(device=pipe.device).manual_seed(seed)
+    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
 
-    # preprocess inputs
+    # initial progress
+    progress(0, steps, desc="Preprocessing images")
+
+    # resize / crop
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
 
-    # set up streaming progress
-    progress = gr.Progress(track_tqdm=True)
+    # callback to update progress bar on each denoising step
+    def progress_callback(step, timestep, latents):
+        progress(step, steps, desc=f"Inference step {step}/{steps}")
 
-    # run the pipeline, streaming progress every step
-    result = pipe(
+    # run the pipeline (streams progress via callback)
+    result = PIPE(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
-        negative_prompt=negative_prompt or None,
+        negative_prompt=negative or None,
         height=h,
         width=w,
         num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
-        callback=progress,
-        callback_steps=1,
+        callback=progress_callback,
    )
 
-    # export to video and return path + seed used
-    frames = result.frames[0]
+    # assemble and export to video
+    frames = result.frames[0]  # list of PIL images
     video_path = export_to_video(frames, fps=fps)
+
+    # return video and seed used (Gradio will auto-download the .mp4)
     return video_path, seed
 
 # ----------------------------------------------------------------------
-# GRADIO APP
+# GRADIO UI
 # ----------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")
@@ -134,26 +135,26 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")
 
-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
 
     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, label="Sampling steps")
+        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
-        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, label="Frames")
-        fps = gr.Slider(4, 30, value=16, label="FPS (export)")
-        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
+        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Number of frames")
+        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
+        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
 
-    run_btn = gr.Button("Generate")
-    video_out = gr.Video(label="Result (.mp4)")
-    used_seed = gr.Number(label="Seed used", interactive=False)
+    run_btn = gr.Button("Generate")
+    video_out = gr.Video(label="Result (.mp4)")
+    used_seed = gr.Number(label="Seed used", interactive=False)
 
     run_btn.click(
         fn=generate,
-        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed_input, fps],
-        outputs=[video_out, used_seed],
-        show_progress=True,  # hook into Gradio’s built-in progress UI
+        inputs=[ first_img, last_img, prompt, negative,
+                 steps, guidance, num_frames, seed, fps ],
+        outputs=[ video_out, used_seed ]
    )
 
-demo.queue()  # serialize GPU calls
-demo.launch()
+# no special queue args needed
+demo.launch()
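
For reference, the rounding in aspect_resize() snaps both output dimensions down to a multiple of mod while keeping the area at or below MAX_AREA. The short worked example below is illustrative only: the mod value is assumed, whereas app.py derives it at runtime from the pipeline's VAE scale factor and transformer patch size.

import numpy as np

max_area = 1280 * 720   # MAX_AREA from app.py
ar = 1080 / 1920        # height / width of a 1920x1080 input frame
mod = 16                # assumed for illustration; app.py computes it from the pipeline

h = round(np.sqrt(max_area * ar)) // mod * mod   # 720
w = round(np.sqrt(max_area / ar)) // mod * mod   # 1280
print(h, w)  # -> 720 1280, already mod-aligned and within the 720p budget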
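
A note on progress reporting: generate() wires progress_callback(step, timestep, latents) into the pipeline through the legacy callback= argument. Depending on the installed diffusers version, WanImageToVideoPipeline may instead expose only the newer callback_on_step_end hook. The sketch below shows the equivalent wiring under that assumption; it reuses the names defined inside generate() (PIPE, steps, progress, first_frame, last_frame, negative, h, w, gen, and so on) and is not part of the commit.

# Assumes a diffusers release where callback_on_step_end replaces callback=.
def on_step_end(pipeline, step, timestep, callback_kwargs):
    # advance the Gradio progress bar once per denoising step
    progress(step + 1, steps, desc=f"Inference step {step + 1}/{steps}")
    # the hook must return the callback kwargs to the pipeline
    return callback_kwargs

result = PIPE(
    image=first_frame,
    last_image=last_frame,
    prompt=prompt,
    negative_prompt=negative or None,
    height=h,
    width=w,
    num_frames=num_frames,
    num_inference_steps=steps,
    guidance_scale=guidance,
    generator=gen,
    callback_on_step_end=on_step_end,
)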