GeradeHouse committed on
Commit 18358fb · verified · 1 Parent(s): 2b5109d

Update app.py

Files changed (1)
  1. app.py +77 -72
app.py CHANGED
@@ -1,8 +1,7 @@
  #!/usr/bin/env python
  """
- Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
- Loads the huge model once, uses balanced device placement,
- streams high-level progress, and auto-offers the .mp4 for download.
+ Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
+ Auto-loads the fast processor and avoids missing preprocessor_config.json.
  """
  import os
  import numpy as np
@@ -10,99 +9,104 @@ import torch
  import gradio as gr
  from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
  from diffusers.utils import export_to_video
- from transformers import CLIPImageProcessor, CLIPVisionModel
+ from transformers import CLIPVisionModel
  from PIL import Image
  import torchvision.transforms.functional as TF

- # --------------------------------------------------------------------
+ # -----------------------------------------------------------------------------
  # CONFIG
- MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
- DTYPE = torch.float16          # half-precision
- MAX_AREA = 1280 * 720          # ≤720p
- DEFAULT_FRAMES = 81            # ≈5s @16fps
- # --------------------------------------------------------------------
-
+ # -----------------------------------------------------------------------------
+ MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+ DTYPE = torch.float16
+ MAX_AREA = 1280 * 720
+ DEFAULT_FRAMES = 81
+
+ # Persist cache so safetensors only download once
+ os.environ["HF_HOME"] = "/mnt/data/huggingface"
+
+ # -----------------------------------------------------------------------------
+ # LOAD PIPELINE ONCE
+ # -----------------------------------------------------------------------------
  def load_pipeline():
-     # 1) image encoder in full precision
+     # 1) Image encoder (fp32)
      image_encoder = CLIPVisionModel.from_pretrained(
          MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
      )
-     # 2) VAE in reduced precision
+     # 2) VAE (half-precision) + slicing
      vae = AutoencoderKLWan.from_pretrained(
          MODEL_ID, subfolder="vae", torch_dtype=DTYPE
      )
-     # 3) CLIPImageProcessor so we get the right class
-     image_processor = CLIPImageProcessor.from_pretrained(
-         MODEL_ID, subfolder="", torch_dtype=DTYPE
-     )
-     # 4) load everything with a balanced device map
+     vae.enable_slicing()
+
+     # 3) Pipeline, balanced across GPU & CPU, fast processor by default
      pipe = WanImageToVideoPipeline.from_pretrained(
          MODEL_ID,
-         vae=vae,
          image_encoder=image_encoder,
-         image_processor=image_processor,
          torch_dtype=DTYPE,
-         device_map="balanced",  # splits weights CPU/GPU
+         vae=vae,
+         device_map="balanced",
+         use_fast=True,  # get the fast CLIPImageProcessor internally
      )
      return pipe

- # load once at import
  PIPE = load_pipeline()


- # --------------------------------------------------------------------
+ # -----------------------------------------------------------------------------
  # UTILS
+ # -----------------------------------------------------------------------------
  def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-     """Resize while respecting multiples of the model’s patch size."""
      ar = img.height / img.width
      mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
-     h = round(np.sqrt(max_area * ar)) // mod * mod
-     w = round(np.sqrt(max_area / ar)) // mod * mod
+     h = int(np.sqrt(max_area * ar)) // mod * mod
+     w = int(np.sqrt(max_area / ar)) // mod * mod
      return img.resize((w, h), Image.LANCZOS), h, w

  def center_crop_resize(img: Image.Image, h, w):
-     """Crop-and-resize to exactly (h, w)."""
      ratio = max(w / img.width, h / img.height)
-     img = img.resize(
+     img = img.resize(
          (round(img.width * ratio), round(img.height * ratio)),
          Image.LANCZOS
      )
      return TF.center_crop(img, [h, w])


- # --------------------------------------------------------------------
- # GENERATE (with simple progress streaming)
+ # -----------------------------------------------------------------------------
+ # GENERATION WITH PROGRESS STREAMING
+ # -----------------------------------------------------------------------------
  def generate(
      first_frame: Image.Image,
-     last_frame: Image.Image,
-     prompt: str,
-     negative_prompt: str,
-     steps: int,
-     guidance: float,
-     num_frames: int,
-     seed: int,
-     fps: int,
-     progress=gr.Progress(),  # gradio’s built-in progress callback
+     last_frame: Image.Image,
+     prompt: str,
+     negative: str,
+     steps: int,
+     guidance: float,
+     num_frames: int,
+     seed: int,
+     fps: int,
+     progress= gr.Progress(),  # built-in streamer
  ):
-     # pick or set seed
+     # seed
      if seed == -1:
          seed = torch.seed()
      gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-     # 0→10%: resize
+     # 0–15%: resize
      progress(0.0, desc="Resizing first frame…")
-     first_frame, h, w = aspect_resize(first_frame)
-     if last_frame.size != first_frame.size:
-         progress(0.1, desc="Resizing last frame…")
-         last_frame = center_crop_resize(last_frame, h, w)
-
-     # 10→20%: ready to run
-     progress(0.2, desc="Starting video inference…")
-     result = PIPE(
-         image=first_frame,
-         last_image=last_frame,
+     first_resized, h, w = aspect_resize(first_frame)
+     if last_frame.size != first_resized.size:
+         progress(0.15, desc="Resizing last frame…")
+         last_resized = center_crop_resize(last_frame, h, w)
+     else:
+         last_resized = first_resized  # same size
+
+     # 15–25%: setup
+     progress(0.25, desc="Launching pipeline…")
+     out = PIPE(
+         image=first_resized,
+         last_image=last_resized,
          prompt=prompt,
-         negative_prompt=negative_prompt or None,
+         negative_prompt=negative or None,
          height=h,
         width=w,
          num_frames=num_frames,
@@ -111,17 +115,18 @@ def generate(
          generator=gen,
      )

-     # 80→100%: export
-     progress(0.8, desc="Assembling video file…")
-     video_path = export_to_video(result.frames[0], fps=fps)
-     progress(1.0, desc="Done!")
+     # 25–90%: we assume the pipeline prints its own bars in console
+     progress(0.90, desc="Building video…")
+     video_path = export_to_video(out.frames[0], fps=fps)

-     # return path so gr.File offers immediate download, plus seed used
+     # done
+     progress(1.0, desc="Done!")
      return video_path, seed


- # --------------------------------------------------------------------
- # UI
+ # -----------------------------------------------------------------------------
+ # GRADIO UI
+ # -----------------------------------------------------------------------------
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
      gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")

@@ -129,26 +134,26 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
          first_img = gr.Image(label="First frame", type="pil")
          last_img = gr.Image(label="Last frame", type="pil")

-     prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-     negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+     prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+     negative = gr.Textbox(label="Negative prompt (optional)", placeholder="blurry, lowres")

      with gr.Accordion("Advanced parameters", open=False):
-         steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
-         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
+         steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
+         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
          num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
          fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
-         seed = gr.Number(value=-1, precision=0, label="Seed (-1=random)")
+         seed_input = gr.Number(value=-1, precision=0, label="Seed (-1=random)")

-     run_btn = gr.Button("Generate")
-     download = gr.File(label="Download video", interactive=False)
-     used_seed = gr.Number(label="Seed used", interactive=False)
+     run_btn = gr.Button("Generate")
+     download = gr.File(label="Download .mp4", interactive=False)
+     seed_used = gr.Number(label="Seed used", interactive=False)

      run_btn.click(
          fn=generate,
-         inputs=[first_img, last_img, prompt, negative,
-                 steps, guidance, num_frames, seed, fps],
-         outputs=[download, used_seed],
+         inputs=[ first_img, last_img, prompt, negative,
+                  steps, guidance, num_frames, seed_input, fps ],
+         outputs=[ download, seed_used ],
      )

- # queue tasks so users see the little task-queue progress bar
- demo.queue().launch(server_name="0.0.0.0", server_port=7860)
+ # queue() so tasks are serialized with a top-right mini-progress indicator
+ demo.queue().launch(server_name="0.0.0.0", server_port=7860)
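
For context on the sizing tweak in aspect_resize (round(...) replaced by int(...) before snapping height and width to the model's spatial modulus), here is a minimal standalone sketch of that math. The MOD value of 16 is a hypothetical stand-in for PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1], which in the app comes from the loaded pipeline, and snapped_size is an illustrative helper, not part of the commit:

import numpy as np

MAX_AREA = 1280 * 720   # same 720p area cap as the app
MOD = 16                # hypothetical modulus, e.g. vae_scale_factor_spatial * patch_size[1]

def snapped_size(height_px: int, width_px: int, max_area: int = MAX_AREA, mod: int = MOD):
    """Mirror of aspect_resize's sizing math: keep the aspect ratio,
    cap the area, then floor both sides down to a multiple of `mod`."""
    ar = height_px / width_px
    h = int(np.sqrt(max_area * ar)) // mod * mod
    w = int(np.sqrt(max_area / ar)) // mod * mod
    return h, w

print(snapped_size(1080, 1920))  # -> (720, 1280), both multiples of 16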