GeradeHouse committed (verified)
Commit d8d26ca · Parent: fa6cab1

Update app.py

Files changed (1)
  1. app.py +31 -31
app.py CHANGED
@@ -4,7 +4,7 @@ Gradio demo for Wan2.1-FLF2V – First & Last Frame → Video
 """

 import os
-# Persist the HF cache between launches
+# Persist HF cache between launches
 os.environ["HF_HOME"] = "/mnt/data/huggingface"

 import torch
@@ -12,7 +12,7 @@ import numpy as np
 import gradio as gr
 from PIL import Image
 import torchvision.transforms.functional as TF
-from transformers import CLIPVisionModel, CLIPProcessor
+from transformers import CLIPVisionModel, CLIPImageProcessor
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video

@@ -20,7 +20,7 @@ from diffusers.utils import export_to_video
 # CONFIGURATION
 # -----------------------------------------------------------------------------
 MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"
-DTYPE = torch.float16 # or torch.bfloat16
+DTYPE = torch.float16
 MAX_AREA = 1280 * 720
 DEFAULT_FRAMES = 81

@@ -28,36 +28,36 @@ DEFAULT_FRAMES = 81
 # PIPELINE LOADING (ONCE)
 # -----------------------------------------------------------------------------
 def load_pipeline():
-    # 1) CLIP vision encoder in fp32
+    # 1) Vision encoder (fp32)
     clip_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) VAE in reduced precision
+    # 2) VAE (reduced precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 3) CLIPProcessor (inherits ProcessorMixin) from a standard CLIP repo
-    clip_processor = CLIPProcessor.from_pretrained(
+    # 3) CLIPImageProcessor (exactly the type Wan expects)
+    img_processor = CLIPImageProcessor.from_pretrained(
         "openai/clip-vit-base-patch32", use_fast=True
     )
-    # 4) Build Wan2video pipeline with balanced device_map
+    # 4) Load the Wan-to-Video pipeline, balanced across GPU & CPU
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         image_encoder=clip_encoder,
         vae=vae,
-        image_processor=clip_processor,
+        image_processor=img_processor,
         torch_dtype=DTYPE,
         device_map="balanced",
     )
-    # 5) Reduce VAE peaks and offload other modules
+    # 5) Slice the VAE to cut VRAM spikes
     try:
         pipe.vae.enable_slicing()
     except AttributeError:
         pass
-    pipe.enable_model_cpu_offload()
     return pipe

-PIPE = load_pipeline() # single load
+# instantiate once
+PIPE = load_pipeline()

 # -----------------------------------------------------------------------------
 # IMAGE RESIZE HELPERS
@@ -78,37 +78,37 @@ def center_crop_resize(img: Image.Image, h: int, w: int):
     return TF.center_crop(img, [h, w])

 # -----------------------------------------------------------------------------
-# GENERATION (STREAMING PROGRESS)
+# GENERATION (STREAMING)
 # -----------------------------------------------------------------------------
 def generate(
-    first_frame: Image.Image,
-    last_frame: Image.Image,
-    prompt: str,
-    negative: str,
-    steps: int,
-    guidance: float,
-    num_frames: int,
-    seed: int,
-    fps: int,
-    progress= gr.Progress()
+    first_frame: Image.Image,
+    last_frame: Image.Image,
+    prompt: str,
+    negative: str,
+    steps: int,
+    guidance: float,
+    num_frames: int,
+    seed: int,
+    fps: int,
+    progress= gr.Progress()
 ):
-    # Seed
+    # Seed management
     if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=PIPE.device).manual_seed(seed)

-    # Preprocessing
+    # Preprocessing update
     progress(0, steps, desc="Preprocessing images")
     f0, h, w = aspect_resize(first_frame)
     if last_frame.size != f0.size:
         last_frame = center_crop_resize(last_frame, h, w)

-    # Callback for each denoising step
+    # Step callback
     def cb(step, timestep, latents):
         progress(step, steps, desc=f"Inference step {step}/{steps}")

-    # Run pipeline
-    output = PIPE(
+    # Run the pipeline
+    out = PIPE(
         image=f0,
         last_image=last_frame,
         prompt=prompt,
@@ -122,12 +122,12 @@ def generate(
         callback=cb
     )

-    # Export
-    video_path = export_to_video(output.frames[0], fps=fps)
+    # Export video
+    video_path = export_to_video(out.frames[0], fps=fps)
     return video_path, seed

 # -----------------------------------------------------------------------------
-# GRADIO UI
+# GRADIO APP
 # -----------------------------------------------------------------------------
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("## Wan2.1 FLF2V – First & Last Frame → Video")
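For readability, this is how the pipeline-loading block reads after the commit, assembled from the added lines in the hunks above (4-space indentation is assumed, since the diff view strips leading whitespace):

def load_pipeline():
    # 1) Vision encoder (fp32)
    clip_encoder = CLIPVisionModel.from_pretrained(
        MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
    )
    # 2) VAE (reduced precision)
    vae = AutoencoderKLWan.from_pretrained(
        MODEL_ID, subfolder="vae", torch_dtype=DTYPE
    )
    # 3) CLIPImageProcessor (exactly the type Wan expects)
    img_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-base-patch32", use_fast=True
    )
    # 4) Load the Wan-to-Video pipeline, balanced across GPU & CPU
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        image_encoder=clip_encoder,
        vae=vae,
        image_processor=img_processor,
        torch_dtype=DTYPE,
        device_map="balanced",
    )
    # 5) Slice the VAE to cut VRAM spikes
    try:
        pipe.vae.enable_slicing()
    except AttributeError:
        pass
    return pipe

# instantiate once
PIPE = load_pipeline()

Note that pipe.enable_model_cpu_offload() is dropped along the way, presumably because explicit CPU offload is redundant (or conflicting) once device_map="balanced" decides module placement; the commit message does not say.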
 
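As background for the CLIPProcessor → CLIPImageProcessor swap in the imports, here is a small standalone illustration (not part of the commit) of how the two transformers classes differ; the checkpoint id is the same one used above:

from transformers import CLIPProcessor, CLIPImageProcessor

# CLIPProcessor is a ProcessorMixin wrapper that bundles a tokenizer together
# with an image processor; it is not itself an image processor.
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
print(isinstance(proc, CLIPImageProcessor))        # False

# CLIPImageProcessor is only the image-preprocessing component, i.e. the type
# passed to the pipeline's image_processor argument in this commit.
image_proc = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
print(isinstance(image_proc, CLIPImageProcessor))  # True

This matches the commit's own comment that CLIPImageProcessor is "exactly the type Wan expects" for the pipeline's image_processor component.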