GeradeHouse committed (verified)
Commit 5516eb1 · Parent(s): 4e367ef

Update app.py

Files changed (1)
  1. app.py +54 -39
app.py CHANGED
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
-Uses Accelerate’s balanced device mapping for optimal CPU/GPU placement.
+– shows streaming status updates
+– auto-downloads the generated video
 Author: <your-handle>
 """
 
@@ -16,37 +17,32 @@ import torchvision.transforms.functional as TF
 
 # ---------------------------------------------------------------------
 # CONFIG ----------------------------------------------------------------
-MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"  # or switch to 1.3B
-DTYPE = torch.float16  # or torch.bfloat16
-MAX_AREA = 1280 * 720  # ≤720p
-DEFAULT_FRAMES = 81  # ~5s @16fps
+MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
+DTYPE = torch.float16
+MAX_AREA = 1280 * 720
+DEFAULT_FRAMES = 81
 # ----------------------------------------------------------------------
 
 def load_pipeline():
-    """Load & auto-map the pipeline across CPU/GPU with low CPU memory usage."""
-    # 1) load vision encoder (full precision)
+    """Load & shard the pipeline across CPU/GPU with Accelerate."""
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) load VAE (half precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3) load the video pipeline with Accelerate helpers
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
         image_encoder=image_encoder,
         torch_dtype=DTYPE,
-        low_cpu_mem_usage=True,  # lazy-load weights into CPU RAM
-        device_map="balanced",   # balanced CPU/GPU sharding
+        low_cpu_mem_usage=True,  # lazy-load to CPU RAM
+        device_map="balanced",   # shard across CPU/GPU
     )
-
-    # 4) use the fast Rust-backed processor
+    # switch to the fast Rust processor
     pipe.image_processor = CLIPImageProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
-
     return pipe
 
 PIPE = load_pipeline()
@@ -54,7 +50,6 @@ PIPE = load_pipeline()
 # ----------------------------------------------------------------------
 # UTILS ----------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-    """Resize while keeping aspect and patch-size multiples."""
     ar = img.height / img.width
     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
@@ -62,29 +57,25 @@ def aspect_resize(img: Image.Image, max_area=MAX_AREA):
     return img.resize((w, h), Image.LANCZOS), h, w
 
 def center_crop_resize(img: Image.Image, h, w):
-    """Center-crop & resize to target H×W."""
     ratio = max(w / img.width, h / img.height)
-    img = img.resize(
-        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
-    )
+    img = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
     return TF.center_crop(img, [h, w])
 
 # ----------------------------------------------------------------------
-# GENERATE --------------------------------------------------------------
-def generate(first_frame, last_frame, prompt, negative_prompt, steps,
-             guidance, num_frames, seed, fps):
-
-    # handle seed
-    if seed == -1:
-        seed = torch.seed()
-    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
-
-    # preprocess frames
+# GENERATE (streaming) --------------------------------------------------
+def generate(first_frame, last_frame, prompt, negative_prompt,
+             steps, guidance, num_frames, seed, fps):
+    # 1) Preprocess
+    yield None, None, "Preprocessing images..."
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
 
-    # inference
+    # 2) Inference
+    yield None, None, f"Running inference ({steps} steps)..."
+    if seed == -1:
+        seed = torch.seed()
+    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
     output = PIPE(
         image=first_frame,
         last_image=last_frame,
@@ -97,20 +88,42 @@ def generate(first_frame, last_frame, prompt, negative_prompt, steps,
         guidance_scale=guidance,
         generator=gen,
     )
-    frames = output.frames[0]  # list[PIL.Image]
+    frames = output.frames[0]
 
-    # export to mp4
+    # 3) Export
+    yield None, None, "Exporting video..."
     video_path = export_to_video(frames, fps=fps)
-    return video_path, seed
+
+    # 4) Done
+    yield video_path, seed, "Done! Your browser will download the video."
 
 # ----------------------------------------------------------------------
 # UI --------------------------------------------------------------------
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## Wan 2.1 FLF2V – First & Last Frame → Video")
+with gr.Blocks() as demo:
+    # inject JS for auto-download
+    gr.HTML("""
+    <script>
+    function downloadVideo() {
+        const container = document.getElementById('output_video');
+        if (!container) return;
+        const vid = container.querySelector('video');
+        if (!vid) return;
+        const src = vid.currentSrc;
+        const a = document.createElement('a');
+        a.href = src;
+        a.download = 'output.mp4';
+        document.body.appendChild(a);
+        a.click();
+        document.body.removeChild(a);
+    }
+    </script>
+    """)
+
+    gr.Markdown("## Wan 2.1 FLF2V – Streaming progress + auto-download")
 
     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
-        last_img = gr.Image(label="Last frame", type="pil")
+        last_img = gr.Image(label="Last frame", type="pil")
 
     prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
@@ -123,14 +136,16 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
 
     run_btn = gr.Button("Generate")
-    video = gr.Video(label="Result (.mp4)")
+    status = gr.Textbox(label="Status", interactive=False)
+    video = gr.Video(label="Result", elem_id="output_video")
     used_seed = gr.Number(label="Seed used", interactive=False)
 
     run_btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
-        outputs=[video, used_seed]
+        outputs=[video, used_seed, status],
+        _js="downloadVideo"
     )
 
-if __name__ == "__main__":
-    demo.launch()
+demo.queue()
+demo.launch()
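
The updated generate() is a Python generator: each yield pushes an intermediate (video, seed, status) tuple into the bound output components, and demo.queue() is what lets Gradio stream those partial results to the browser. Below is a minimal standalone sketch of that streaming pattern, not part of this commit; the names slow_task, n_steps, sketch and the dummy arithmetic are purely illustrative.

import time
import gradio as gr

def slow_task(n_steps):
    # Generator event handler: each yield updates (status, result) in the UI.
    for i in range(int(n_steps)):
        yield f"step {i + 1}/{int(n_steps)}...", None
        time.sleep(0.5)
    yield "done", int(n_steps) * 2   # final yield is the finished result

with gr.Blocks() as sketch:
    n_steps = gr.Number(value=4, label="Steps")
    status = gr.Textbox(label="Status")
    result = gr.Number(label="Result")
    gr.Button("Run").click(fn=slow_task, inputs=n_steps, outputs=[status, result])

sketch.queue()    # queueing enables streaming of generator outputs
sketch.launch()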