GeradeHouse committed
Commit f40229f · verified · 1 Parent(s): c83344b

Update app.py

Files changed (1):
  1. app.py  +32 -26
app.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
-Author: <your-handle>
+Author: GeradeHouse
 """

 import numpy as np
@@ -9,28 +9,29 @@ import torch
 import gradio as gr
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video
-from transformers import CLIPVisionModel
+from transformers import CLIPVisionModel, CLIPImageProcessor
 from PIL import Image
 import torchvision.transforms.functional as TF

 # ---------------------------------------------------------------------
 # CONFIG ----------------------------------------------------------------
-MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"  # switch to 1.3B if needed
-DTYPE = torch.float16        # or torch.bfloat16 on AMP-friendly GPUs
-MAX_AREA = 1280 * 720        # keep 720p
-DEFAULT_FRAMES = 81          # 5s at 16 fps
+MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"  # or switch to 1.3B
+DTYPE = torch.float16        # or bfloat16
+MAX_AREA = 1280 * 720        # ≤720p
+DEFAULT_FRAMES = 81          # ~5s @16 fps
 # ----------------------------------------------------------------------

 def load_pipeline():
-    """Lazy-load the huge model once per process."""
-    # image encoder in full precision
+    """Lazy-load & configure the pipeline once per process."""
+    # 1) load the CLIP image encoder (full precision)
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # VAE in reduced precision
+    # 2) load the VAE (half precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
+    # 3) load the video pipeline
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
@@ -38,11 +39,14 @@ def load_pipeline():
         torch_dtype=DTYPE,
     )

-    # memory helpers for 24 GB cards / HF T4-medium
-    pipe.enable_model_cpu_offload()   # page UNet blocks off GPU
-    pipe.vae.enable_slicing()         # reduce VAE peak RAM
-    # Optional: if you have xformers installed
-    # pipe.enable_xformers_memory_efficient_attention()
+    # 4) override the processor with the fast Rust implementation
+    pipe.image_processor = CLIPImageProcessor.from_pretrained(
+        MODEL_ID, subfolder="image_processor", use_fast=True
+    )
+
+    # 5) memory helpers (offload UNet to CPU as needed)
+    # pipe.enable_model_cpu_offload()
+    # (Removed pipe.vae.enable_slicing() — not supported on AutoencoderKLWan)

     return pipe.to("cuda" if torch.cuda.is_available() else "cpu")

@@ -51,7 +55,7 @@ PIPE = load_pipeline()
 # ----------------------------------------------------------------------
 # UTILS ----------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-    """Resize while respecting model patch size (multiple of transformer patch)."""
+    """Resize while keeping aspect & respecting patch multiples."""
     ar = img.height / img.width
     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
@@ -59,10 +63,11 @@ def aspect_resize(img: Image.Image, max_area=MAX_AREA):
     return img.resize((w, h), Image.LANCZOS), h, w

 def center_crop_resize(img: Image.Image, h, w):
-    """Center-crop & resize to target H×W."""
+    """Center-crop & resize to H×W."""
     ratio = max(w / img.width, h / img.height)
     img = img.resize(
-        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
+        (round(img.width * ratio), round(img.height * ratio)),
+        Image.LANCZOS
     )
     return TF.center_crop(img, [h, w])

@@ -71,11 +76,12 @@ def center_crop_resize(img: Image.Image, h, w):
 def generate(first_frame, last_frame, prompt, negative_prompt, steps,
              guidance, num_frames, seed, fps):

+    # seed handling
     if seed == -1:
         seed = torch.seed()
-    generator = torch.Generator(device=PIPE.device).manual_seed(seed)
+    gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-    # preprocess inputs
+    # preprocess frames
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
@@ -91,11 +97,11 @@ def generate(first_frame, last_frame, prompt, negative_prompt, steps,
         num_frames=num_frames,
         num_inference_steps=steps,
         guidance_scale=guidance,
-        generator=generator,
+        generator=gen,
     )
-    frames = output.frames[0]   # list[PIL.Image]
+    frames = output.frames[0]   # list of PIL Image

-    # export to .mp4
+    # export to MP4
     video_path = export_to_video(frames, fps=fps)
     return video_path, seed

@@ -108,8 +114,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     first_img = gr.Image(label="First frame", type="pil")
     last_img = gr.Image(label="Last frame", type="pil")

-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")

     with gr.Accordion("Advanced parameters", open=False):
         steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
@@ -118,8 +124,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
         seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")

-    run_btn = gr.Button("Generate")
-    video = gr.Video(label="Result (.mp4)")
+    run_btn = gr.Button("Generate")
+    video = gr.Video(label="Result (.mp4)")
     used_seed = gr.Number(label="Seed used", interactive=False)

     run_btn.click(
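
Note on the aspect_resize change above: the function keeps the input aspect ratio while snapping height and width down to multiples of the transformer patch size so the latent grid divides evenly. A standalone sketch of that rounding, assuming mod = 16 (the real value comes from PIPE.vae_scale_factor_spatial * patch_size), assuming the elided `w` line mirrors the `h` computation, and using math.sqrt in place of np.sqrt to keep the sketch dependency-free:

import math

MAX_AREA = 1280 * 720            # same 720p area cap as the app
mod = 16                         # assumed patch multiple for illustration

width, height = 1024, 768        # hypothetical input frame
ar = height / width              # 0.75

h = round(math.sqrt(MAX_AREA * ar)) // mod * mod   # 831 -> 816
w = round(math.sqrt(MAX_AREA / ar)) // mod * mod   # 1109 -> 1104

print(h, w, h * w <= MAX_AREA)   # 816 1104 True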
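The new step 4 swaps the pipeline's image processor for a fast variant. A minimal way to try the same idea in isolation (assumptions: the repo really ships an image_processor subfolder as the diff implies, and the installed transformers version provides a fast processor class for it; AutoImageProcessor is used here instead of CLIPImageProcessor so the library picks the fast class itself):

from transformers import AutoImageProcessor

MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"

processor = AutoImageProcessor.from_pretrained(
    MODEL_ID, subfolder="image_processor", use_fast=True
)
print(type(processor).__name__)   # e.g. a *Fast processor class when available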
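The seed handling added to generate() makes runs reproducible: -1 draws a fresh seed via torch.seed(), any other value is used as-is, and the seed is returned so a result can be replayed. A small sketch of the same pattern outside the app (make_generator is a hypothetical helper, not part of app.py):

import torch

def make_generator(seed: int, device: str = "cpu"):
    if seed == -1:
        seed = torch.seed()          # fresh random 64-bit seed
    gen = torch.Generator(device=device).manual_seed(seed)
    return gen, seed

g1, used = make_generator(-1)
g2, _ = make_generator(used)         # replay with the reported seed
print(torch.equal(torch.rand(3, generator=g1), torch.rand(3, generator=g2)))  # True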