GeradeHouse committed
Commit 699b386 · verified · 1 Parent(s): 64a6a24

Update app.py

Files changed (1)
  app.py  +35 -54
app.py CHANGED
@@ -1,54 +1,45 @@
  #!/usr/bin/env python
  """
  Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
- Loads the huge model lazily (only once), streams **all** tqdm bars
- (from HF downloads, shard loading, to denoising) into Gradio's UI,
- and outputs a direct File download for the generated video.
+ Streams all HF-Hub & Diffusers tqdm bars into Gradio, caches the pipeline,
+ and outputs a direct download link.
  """

  import os
- import tempfile
-
  import ftfy
  import numpy as np
  import torch
  import gradio as gr
+ from PIL import Image
+ from transformers import CLIPVisionModel, CLIPImageProcessor
  from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
  from diffusers.utils import export_to_video
- from transformers import CLIPVisionModel, CLIPImageProcessor
- from PIL import Image
+ import torchvision.transforms.functional as TF

  # -----------------------------------------------------------------------------
  # CONFIG
  # -----------------------------------------------------------------------------
- MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
- DTYPE = torch.float16        # or torch.bfloat16 on AMP-friendly cards
- MAX_AREA = 1280 * 720        # ≤720p
- DEFAULT_FRAMES = 81          # ~5s @16fps
+ MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+ DTYPE = torch.float16
+ MAX_AREA = 1280 * 720
+ DEFAULT_FRAMES = 81

  # -----------------------------------------------------------------------------
- # GLOBAL PIPELINE (lazy)
+ # GLOBAL CACHED PIPELINE
  # -----------------------------------------------------------------------------
  PIPE = None

  def load_pipeline():
-     """
-     Load the Wan2.1-FLF2V pipeline once, with fast processor,
-     CPU-offload for large models, and in half-precision.
-     """
-     # 1) full-precision CLIP encoder
+     """Load & shard the pipeline once (CPU offload + fast processor)."""
      vision = CLIPVisionModel.from_pretrained(
          MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
      )
-     # 2) fast CLIP image processor
      processor = CLIPImageProcessor.from_pretrained(
          MODEL_ID, subfolder="preprocessor", use_fast=True
      )
-     # 3) reduced-precision VAE
      vae = AutoencoderKLWan.from_pretrained(
          MODEL_ID, subfolder="vae", torch_dtype=DTYPE
      )
-     # 4) assemble pipeline
      pipe = WanImageToVideoPipeline.from_pretrained(
          MODEL_ID,
          vae=vae,
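Note: the new module docstring relies on Gradio's tqdm tracking, where any tqdm bar opened while the handler runs (HF Hub downloads, shard loading, the denoising loop) is mirrored in the web UI, provided the handler declares a gr.Progress(track_tqdm=True) argument and the app has queuing active. A minimal, self-contained sketch of that mechanism (standalone example, not part of this repo):

    import time
    import gradio as gr
    from tqdm import tqdm

    def slow_task(n, progress=gr.Progress(track_tqdm=True)):
        # Any tqdm bar created while this handler runs is mirrored in the Gradio UI.
        for _ in tqdm(range(int(n)), desc="working"):
            time.sleep(0.05)
        return "done"

    demo = gr.Interface(slow_task, gr.Number(value=40), "text")
    demo.queue().launch()  # keep queuing enabled so progress updates can stream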
@@ -56,18 +47,13 @@ def load_pipeline():
          image_processor=processor,
          torch_dtype=DTYPE,
      )
-     # 5) offload to CPU/AutoDevice
      pipe.enable_model_cpu_offload()
-     # (we drop .enable_slicing() because it's unsupported here)
      return pipe.to("cuda" if torch.cuda.is_available() else "cpu")

  # -----------------------------------------------------------------------------
- # UTILS
+ # IMAGE RESIZE HELPERS
  # -----------------------------------------------------------------------------
  def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-     """
-     Resize while respecting the model's patch size (multiple of 8 * transformer patch).
-     """
      ar = img.height / img.width
      mod = PIPE.transformer.config.patch_size[1] * PIPE.vae_scale_factor_spatial
      h = (int(np.sqrt(max_area * ar)) // mod) * mod
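Note: enable_model_cpu_offload(), kept in the hunk above, comes from the diffusers/accelerate integration; it parks each sub-model on the CPU and moves it to the GPU only for its own forward pass, which helps the 14B checkpoint fit on a single card. A generic, hedged sketch of the same call on an unrelated public pipeline (accelerate must be installed; the model id is only an example):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    pipe.enable_model_cpu_offload()  # sub-models hop to the GPU only for their forward pass
    image = pipe("a photo of a red fox").images[0]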
@@ -75,15 +61,12 @@ def aspect_resize(img: Image.Image, max_area=MAX_AREA):
      return img.resize((w, h), Image.LANCZOS), h, w

  def center_crop_resize(img: Image.Image, h: int, w: int):
-     """
-     Center-crop + resize to exactly h×w.
-     """
      ratio = max(w / img.width, h / img.height)
      img2 = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
      return TF.center_crop(img2, [h, w])

  # -----------------------------------------------------------------------------
- # GENERATION (with full tqdm → Gradio progress streaming)
+ # GENERATION FUNCTION (with tqdm streaming)
  # -----------------------------------------------------------------------------
  def generate(
      first_frame: Image.Image,
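Note: in aspect_resize above, the height is snapped down to a multiple of mod = patch_size[1] * vae_scale_factor_spatial so the latent grid divides evenly; the width line sits outside the hunks, but it is presumably derived the same way from max_area / ar. A rough worked example, assuming patch_size[1] == 2 and vae_scale_factor_spatial == 8 (so mod == 16), which are typical Wan2.1 values but are not shown in this diff:

    import numpy as np

    max_area = 1280 * 720    # MAX_AREA from the config above
    mod = 2 * 8              # assumed patch_size[1] * vae_scale_factor_spatial

    ar = 720 / 1280          # e.g. a 1280x720 input: aspect ratio h / w
    h = (int(np.sqrt(max_area * ar)) // mod) * mod   # sqrt(518400) = 720  -> 720
    w = (int(np.sqrt(max_area / ar)) // mod) * mod   # assumed width formula -> 1280
    print(h, w)              # 720 1280, both multiples of 16, area <= MAX_AREA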
@@ -98,27 +81,27 @@
      progress=gr.Progress(track_tqdm=True),
  ):
      global PIPE
-     # lazy instantiate
+     # Lazy load pipeline
      if PIPE is None:
          progress(0, desc="Loading pipeline…")
          PIPE = load_pipeline()

-     # seeding
+     # Seed
      if seed == -1:
          seed = torch.seed()
      gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-     # preprocess
-     progress(0, desc="Preprocessing…")
+     # Preprocess
+     progress(0, desc="Preprocessing frames…")
      frame1, h, w = aspect_resize(first_frame)
      if last_frame.size != frame1.size:
          last_frame = center_crop_resize(last_frame, h, w)

-     # inference (all tqdm inside will stream to UI)
+     # Inference (tqdm bars streamed)
      result = PIPE(
          image=frame1,
          last_image=last_frame,
-         prompt=whitespace_clean(basic_clean(prompt)),
+         prompt=ftfy.fix_text(prompt),
          negative_prompt=negative_prompt or None,
          height=h,
          width=w,
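Note: the prompt cleaning switches from the CLIP-style whitespace_clean(basic_clean(prompt)) pair to a single ftfy.fix_text(prompt). fix_text repairs mojibake and normalizes odd unicode, but it does not collapse runs of internal whitespace the way the old helper presumably did; if that behaviour matters, a small equivalent (hypothetical helper, not part of this commit) could be:

    import re
    import ftfy

    def clean_prompt(text: str) -> str:
        # Repair bad unicode / mojibake, then collapse whitespace runs to single spaces.
        text = ftfy.fix_text(text)
        return re.sub(r"\s+", " ", text).strip()

    print(clean_prompt("A  small\n blue bird takes off…"))  # "A small blue bird takes off…"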
@@ -126,42 +109,40 @@
          num_inference_steps=steps,
          guidance_scale=guidance,
          generator=gen,
-         # no callback_steps here!
      )
-     frames = result.frames[0]   # list of PIL images
+     frames = result.frames[0]

-     # export to MP4
-     progress(1.0, desc="Assembling video…")
+     # Export
+     progress(1.0, desc="Exporting video…")
      out_path = export_to_video(frames, fps=fps)
      return out_path, seed

  # -----------------------------------------------------------------------------
- # BUILD UI
+ # GRADIO UI
  # -----------------------------------------------------------------------------
  with gr.Blocks() as demo:
-     gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video (Diffusers)")
+     gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
      with gr.Row():
          first_img = gr.Image(label="First frame", type="pil")
          last_img = gr.Image(label="Last frame", type="pil")
-     prompt = gr.Textbox(label="Prompt", placeholder="A small blue bird takes off…")
-     negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+     prompt = gr.Textbox(label="Prompt")
+     negative = gr.Textbox(label="Negative prompt (optional)")
      with gr.Accordion("Advanced parameters", open=False):
-         steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
-         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
+         steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
+         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
          num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
-         fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
-         seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
-     run_btn = gr.Button("Generate")
-     # **File** component for direct download link:
+         fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
+         seed = gr.Number(value=-1, precision=0, label="Seed")
+     run_btn = gr.Button("Generate")
      download = gr.File(label="Download video (.mp4)")
-     used_seed = gr.Number(label="Seed used", interactive=False)
+     used_seed= gr.Number(label="Seed used", interactive=False)

-     # queue() for async + progress
      run_btn.click(
          fn=generate,
          inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
          outputs=[download, used_seed],
+         concurrency_limit=1
      )

- # MUST call .queue() to enable gr.Progress()
- demo.queue(concurrency_count=1).launch()
+ # **Enable queuing** (uses default_concurrency_limit=1 under the hood)
+ demo.queue().launch()
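Note: the final change tracks the Gradio 4 queue API: queue(concurrency_count=...) was removed, so concurrency is now capped per event with concurrency_limit (or globally with default_concurrency_limit, which defaults to 1), while demo.queue() keeps queuing configured for gr.Progress streaming. A minimal sketch of the same wiring outside this app (assumes Gradio 4.x):

    import gradio as gr

    def echo(text):
        return text

    with gr.Blocks() as demo:
        box = gr.Textbox(label="Input")
        out = gr.Textbox(label="Output")
        btn = gr.Button("Run")
        # Per-event cap: at most one of these jobs runs at a time.
        btn.click(echo, inputs=box, outputs=out, concurrency_limit=1)

    # default_concurrency_limit covers events that do not set their own limit.
    demo.queue(default_concurrency_limit=1).launch()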
 