GeradeHouse committed on
Commit c83344b · verified · 1 Parent(s): 29a7230

Update app.py

Files changed (1)
  app.py  +15 -15
app.py CHANGED
@@ -4,8 +4,6 @@ Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
 Author: <your-handle>
 """
 
-import os
-import tempfile
 import numpy as np
 import torch
 import gradio as gr
@@ -19,12 +17,12 @@ import torchvision.transforms.functional as TF
 # CONFIG ----------------------------------------------------------------
 MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"  # switch to 1.3B if needed
 DTYPE = torch.float16  # or torch.bfloat16 on AMP-friendly GPUs
-MAX_AREA = 1280 * 720  # keep ≤ 720 p
-DEFAULT_FRAMES = 81  # ≈ 5 s at 16 fps
+MAX_AREA = 1280 * 720  # keep ≤ 720p
+DEFAULT_FRAMES = 81  # ≈ 5s at 16 fps
 # ----------------------------------------------------------------------
 
 def load_pipeline():
-    """Lazy-load the huge model once per process."""
+    """Lazyload the huge model once per process."""
     # image encoder in full precision
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
@@ -41,8 +39,8 @@ def load_pipeline():
     )
 
     # memory helpers for ≤ 24 GB cards / HF T4-medium
-    pipe.enable_model_cpu_offload()  # paged UNet blocks
-    pipe.vae.enable_slicing()  # reduce VAE peak RAM
+    pipe.enable_model_cpu_offload()  # page UNet blocks off GPU
+    pipe.vae.enable_slicing()  # reduce VAE peak RAM
     # Optional: if you have xformers installed
     # pipe.enable_xformers_memory_efficient_attention()
 
@@ -53,7 +51,7 @@ PIPE = load_pipeline()
 # ----------------------------------------------------------------------
 # UTILS ----------------------------------------------------------------
 def aspect_resize(img: Image.Image, max_area=MAX_AREA):
-    """Resize while respecting model patch size (multiple of 8*transformer patch)."""
+    """Resize while respecting model patch size (multiple of transformer patch)."""
     ar = img.height / img.width
     mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
@@ -61,6 +59,7 @@ def aspect_resize(img: Image.Image, max_area=MAX_AREA):
     return img.resize((w, h), Image.LANCZOS), h, w
 
 def center_crop_resize(img: Image.Image, h, w):
+    """Center-crop & resize to target H×W."""
     ratio = max(w / img.width, h / img.height)
     img = img.resize(
         (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
@@ -76,13 +75,13 @@ def generate(first_frame, last_frame, prompt, negative_prompt, steps,
     seed = torch.seed()
     generator = torch.Generator(device=PIPE.device).manual_seed(seed)
 
-    # preprocess
+    # preprocess inputs
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
 
-    # run pipeline
-    result = PIPE(
+    # run the pipeline
+    output = PIPE(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
@@ -94,9 +93,9 @@
         guidance_scale=guidance,
         generator=generator,
     )
-    frames = result.frames[0]  # list of PIL images
+    frames = output.frames[0]  # list[PIL.Image]
 
-    # export
+    # export to .mp4
     video_path = export_to_video(frames, fps=fps)
     return video_path, seed
 
@@ -109,8 +108,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")
 
-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+
     with gr.Accordion("Advanced parameters", open=False):
        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
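Note on the aspect_resize hunk above: the edited docstring describes rounding both target dimensions down to a multiple of the VAE/transformer patch stride. Here is a minimal standalone sketch of that arithmetic; the mod value of 16 and the width formula are illustrative assumptions, since only the height line appears in the diff.

# Sketch only — not part of the commit. Illustrates the multiple-of-`mod`
# rounding that aspect_resize performs on the target resolution.
import numpy as np

MAX_AREA = 1280 * 720   # pixel budget from the CONFIG block
mod = 16                # assumed: vae_scale_factor_spatial * patch_size[1]

ar = 600 / 800          # aspect ratio (h / w) of a hypothetical 800x600 input
h = round(np.sqrt(MAX_AREA * ar)) // mod * mod   # height, as in the hunk
w = round(np.sqrt(MAX_AREA / ar)) // mod * mod   # width, assumed symmetric (not shown)

print(int(h), int(w), h * w <= MAX_AREA)         # 816 1104 True

Both results land on multiples of 16 while keeping the area under the 720p budget, which is why the pipeline can consume them directly as height/width.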
 
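Note on center_crop_resize: the commit only adds a docstring, and the crop step itself sits outside the hunk. A hypothetical completion follows, assuming a centered crop box; it is a sketch for readers following along, not the committed code.

# Hypothetical completion — only the scale-to-cover resize appears in the diff;
# the centered crop box below is an assumption.
from PIL import Image

def center_crop_resize(img: Image.Image, h: int, w: int) -> Image.Image:
    """Center-crop & resize to target H×W."""
    # scale so the image fully covers the w x h target (as in the hunk)
    ratio = max(w / img.width, h / img.height)
    img = img.resize(
        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
    )
    # assumed step: crop the centered w x h window out of the enlarged image
    left = (img.width - w) // 2
    top = (img.height - h) // 2
    return img.crop((left, top, left + w, top + h))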