GeradeHouse committed on
Commit 2c7ebd6 · verified · 1 parent: f956532

Update app.py

Files changed (1): app.py (+96 −94)
app.py CHANGED
@@ -1,112 +1,113 @@
(previous version)
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
-Streams all HF-Hub & Diffusers tqdm bars, caches the model,
-and provides a direct download link for the MP4.
 """

-import ftfy
-import numpy as np
 import torch
 import gradio as gr
-from PIL import Image
-from transformers import CLIPVisionModel, CLIPProcessor
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video
 import torchvision.transforms.functional as TF

-# -----------------------------------------------------------------------------
 # CONFIG
-# -----------------------------------------------------------------------------
-MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
-DTYPE = torch.float16
-MAX_AREA = 1280 * 720
-DEFAULT_FRAMES = 81
-
-# -----------------------------------------------------------------------------
-# GLOBAL CACHED PIPELINE
-# -----------------------------------------------------------------------------
-PIPE = None
-
 def load_pipeline():
-    """Load & cache the pipeline (once)."""
-    # 1) CLIP vision encoder (fp32)
-    vision = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) unified CLIP processor (fast Rust-backed + tokenizer stub)
-    processor = CLIPProcessor.from_pretrained(
-        MODEL_ID, subfolder="image_processor", use_fast=True
-    )
-    # 3) VAE (half precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 4) assemble pipeline
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
-        image_encoder=vision,
-        image_processor=processor,
         torch_dtype=DTYPE,
     )
-    # 5) CPU offload for large models
     pipe.enable_model_cpu_offload()
-    # return on correct device
-    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")

-# -----------------------------------------------------------------------------
-# IMAGE RESIZE HELPERS
-# -----------------------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-    ar = img.height / img.width
-    mod = PIPE.transformer.config.patch_size[1] * PIPE.vae_scale_factor_spatial
-    h = (int(np.sqrt(max_area * ar)) // mod) * mod
-    w = (int(np.sqrt(max_area / ar)) // mod) * mod
     return img.resize((w, h), Image.LANCZOS), h, w

-def center_crop_resize(img: Image.Image, h: int, w: int):
     ratio = max(w / img.width, h / img.height)
-    img2 = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
-    return TF.center_crop(img2, [h, w])

-# -----------------------------------------------------------------------------
-# GENERATION (streams all tqdm → Gradio)
-# -----------------------------------------------------------------------------
 def generate(
     first_frame: Image.Image,
-    last_frame: Image.Image,
-    prompt: str,
     negative_prompt: str,
-    steps: int,
-    guidance: float,
-    num_frames: int,
-    seed: int,
-    fps: int,
-    progress=gr.Progress(track_tqdm=True),
 ):
-    global PIPE
-    # lazy load once
-    if PIPE is None:
-        progress(0, desc="Loading model…")
-        PIPE = load_pipeline()
-
-    # ensure reproducibility
     if seed == -1:
         seed = torch.seed()
-    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-    # preprocess
-    progress(0, desc="Preprocessing frames…")
-    frame1, h, w = aspect_resize(first_frame)
-    if last_frame.size != frame1.size:
         last_frame = center_crop_resize(last_frame, h, w)

-    # inference (all internal tqdm bars streamed)
-    result = PIPE(
-        image=frame1,
         last_image=last_frame,
-        prompt=ftfy.fix_text(prompt),
         negative_prompt=negative_prompt or None,
         height=h,
         width=w,
@@ -114,44 +115,45 @@ def generate(
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
     )
-    frames = result.frames[0]

-    # export to MP4
-    progress(1.0, desc="Exporting video…")
-    out_path = export_to_video(frames, fps=fps)
-    return out_path, seed

-# -----------------------------------------------------------------------------
-# GRADIO UI
-# -----------------------------------------------------------------------------
-with gr.Blocks() as demo:
-    gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")

-    prompt = gr.Textbox(label="Prompt")
-    negative = gr.Textbox(label="Negative prompt (optional)")

     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
-        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
-        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
-        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
-        seed = gr.Number(value=-1, precision=0, label="Seed")

-    run_btn = gr.Button("Generate")
-    download = gr.File(label="Download video (.mp4)")
-    used_seed = gr.Number(label="Seed used", interactive=False)

     run_btn.click(
         fn=generate,
-        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
-        outputs=[download, used_seed],
-        concurrency_limit=1
     )

-# enable queue + tqdm streaming
-demo.queue().launch()
(updated version)
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
 """

+import os
 import torch
+import numpy as np
 import gradio as gr
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
+from transformers import CLIPProcessor, CLIPVisionModel
 from diffusers.utils import export_to_video
+from PIL import Image
 import torchvision.transforms.functional as TF

+# ----------------------------------------------------------------------
 # CONFIG
+# ----------------------------------------------------------------------
+MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+DTYPE = torch.float16        # switch to torch.bfloat16 if you have an AMP-friendly GPU
+MAX_AREA = 1280 * 720        # ≤ 720p
+DEFAULT_FRAMES = 81          # ~5 s @ 16 fps
+
+# ----------------------------------------------------------------------
+# PIPELINE LOADING (once)
+# ----------------------------------------------------------------------
 def load_pipeline():
+    # 1) image encoder in fp32 for stability
+    image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
+    # 2) VAE in reduced precision
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
+    # 3) use the unified CLIPProcessor (inherits ProcessorMixin) in fast mode
+    processor = CLIPProcessor.from_pretrained(MODEL_ID, use_fast=True)
+
+    # 4) assemble pipeline, overriding the default processor
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
+        image_encoder=image_encoder,
+        processor=processor,
         torch_dtype=DTYPE,
     )
+
+    # 5) offload to CPU / reduce footprint
     pipe.enable_model_cpu_offload()

+    # 6) safe VAE slicing if available
+    try:
+        pipe.vae.enable_slicing()
+    except (AttributeError, TypeError):
+        pass
+
+    return pipe
+
+pipe = load_pipeline()
+
+# ----------------------------------------------------------------------
+# IMAGE RESIZING HELPERS
+# ----------------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
+    ar = img.height / img.width
+    # align to VAE & transformer patch grid
+    mod = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+    h = round(np.sqrt(max_area * ar)) // mod * mod
+    w = round(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w

+def center_crop_resize(img: Image.Image, h, w):
     ratio = max(w / img.width, h / img.height)
+    img = img.resize(
+        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
+    )
+    return TF.center_crop(img, [h, w])

+# ----------------------------------------------------------------------
+# GENERATION FUNCTION
+# ----------------------------------------------------------------------
 def generate(
     first_frame: Image.Image,
+    last_frame: Image.Image,
+    prompt: str,
     negative_prompt: str,
+    steps: int,
+    guidance: float,
+    num_frames: int,
+    seed: int,
+    fps: int,
 ):
+    # randomize seed if requested
     if seed == -1:
         seed = torch.seed()
+    gen = torch.Generator(device=pipe.device).manual_seed(seed)

+    # preprocess inputs
+    first_frame, h, w = aspect_resize(first_frame)
+    if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)

+    # set up streaming progress
+    progress = gr.Progress(track_tqdm=True)
+
+    # run the pipeline, streaming progress every step
+    result = pipe(
+        image=first_frame,
         last_image=last_frame,
+        prompt=prompt,
         negative_prompt=negative_prompt or None,
         height=h,
         width=w,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
+        callback=progress,
+        callback_steps=1,
     )

+    # export to video and return path + seed used
+    frames = result.frames[0]
+    video_path = export_to_video(frames, fps=fps)
+    return video_path, seed

+# ----------------------------------------------------------------------
+# GRADIO APP
+# ----------------------------------------------------------------------
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")

     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")

+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

     with gr.Accordion("Advanced parameters", open=False):
+        steps = gr.Slider(10, 50, value=30, label="Sampling steps")
+        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
+        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, label="Frames")
+        fps = gr.Slider(4, 30, value=16, label="FPS (export)")
+        seed_input = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

+    run_btn = gr.Button("Generate")
+    video_out = gr.Video(label="Result (.mp4)")
+    used_seed = gr.Number(label="Seed used", interactive=False)

     run_btn.click(
         fn=generate,
+        inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed_input, fps],
+        outputs=[video_out, used_seed],
+        show_progress=True,   # hook into Gradio’s built-in progress UI
     )

+demo.queue()    # serialize GPU calls
+demo.launch()
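
For reference, the aspect_resize helper in both versions fits the input into the MAX_AREA budget at its original aspect ratio, then snaps each side down to a multiple of the transformer patch size times the VAE spatial scale factor. A standalone sketch of that rounding, using a hypothetical mod=16 in place of the value read from the loaded pipeline config:

import numpy as np

def snap_to_grid(width: int, height: int, max_area: int = 1280 * 720, mod: int = 16):
    """Fit (width, height) into max_area at the same aspect ratio, then
    floor each side to a multiple of `mod` (hypothetical grid size)."""
    ar = height / width
    h = int(np.sqrt(max_area * ar)) // mod * mod
    w = int(np.sqrt(max_area / ar)) // mod * mod
    return w, h

print(snap_to_grid(1024, 768))  # -> (1104, 816): within the 1280*720 area and on the 16-px grid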
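
Both versions rely on gr.Progress(track_tqdm=True) to stream Diffusers' internal tqdm bars into the UI. In Gradio's documented pattern the Progress object is declared as a default argument of the event handler, and any tqdm loop running inside the call is mirrored in the interface. A minimal, self-contained sketch of that pattern (the handler, slider, and labels here are illustrative, not part of this Space):

import time
import gradio as gr
from tqdm import tqdm

def slow_task(n, progress=gr.Progress(track_tqdm=True)):
    # any tqdm bar raised inside the handler is mirrored in the Gradio UI
    for _ in tqdm(range(int(n)), desc="working"):
        time.sleep(0.05)
    return f"done after {int(n)} steps"

with gr.Blocks() as sketch_demo:
    steps = gr.Slider(1, 100, value=20, step=1, label="Steps")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(slow_task, inputs=steps, outputs=status)

# sketch_demo.queue().launch()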