GeradeHouse committed
Commit 29a7230 · verified · 1 Parent(s): 5a1047f

Update app.py

Files changed (1):
  app.py  +23 -10
app.py CHANGED
@@ -4,7 +4,11 @@ Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
 Author: <your-handle>
 """
 
-import os, tempfile, numpy as np, torch, gradio as gr
+import os
+import tempfile
+import numpy as np
+import torch
+import gradio as gr
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video
 from transformers import CLIPVisionModel
@@ -21,9 +25,11 @@ DEFAULT_FRAMES = 81 # ≈ 5 s at 16 fps
 
 def load_pipeline():
     """Lazy-load the huge model once per process."""
+    # image encoder in full precision
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
+    # VAE in reduced precision
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
@@ -35,9 +41,11 @@ def load_pipeline():
     )
 
     # memory helpers for ≤ 24 GB cards / HF T4-medium
-    pipe.enable_model_cpu_offload()   # paged UNet blocks
-    pipe.enable_vae_slicing()         # reduces VAE RAM spikes
-    # Optional (needs xformers): pipe.enable_xformers_memory_efficient_attention()
+    pipe.enable_model_cpu_offload()   # paged UNet blocks
+    pipe.vae.enable_slicing()         # reduce VAE peak RAM
+    # Optional: if you have xformers installed
+    # pipe.enable_xformers_memory_efficient_attention()
+
     return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
 PIPE = load_pipeline()
@@ -54,9 +62,10 @@ def aspect_resize(img: Image.Image, max_area=MAX_AREA):
 
 def center_crop_resize(img: Image.Image, h, w):
     ratio = max(w / img.width, h / img.height)
-    img = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
-    img = TF.center_crop(img, [h, w])
-    return img
+    img = img.resize(
+        (round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS
+    )
+    return TF.center_crop(img, [h, w])
 
 # ----------------------------------------------------------------------
 # GENERATE --------------------------------------------------------------
@@ -67,11 +76,13 @@ def generate(first_frame, last_frame, prompt, negative_prompt, steps,
     seed = torch.seed()
     generator = torch.Generator(device=PIPE.device).manual_seed(seed)
 
+    # preprocess
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
 
-    out = PIPE(
+    # run pipeline
+    result = PIPE(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
@@ -82,9 +93,11 @@
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=generator,
-    ).frames[0]                       # list[pillow]
-
-    video_path = export_to_video(out, fps=fps)
+    )
+    frames = result.frames[0]         # list of PIL images
+
+    # export
+    video_path = export_to_video(frames, fps=fps)
    return video_path, seed
 
 # ----------------------------------------------------------------------
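For context, a minimal standalone sketch of the first/last-frame flow that load_pipeline() and generate() implement after this commit, runnable outside Gradio. The checkpoint id, DTYPE (bf16), resolution, step count, guidance value, and file names below are assumptions for illustration; the real values live elsewhere in app.py and are not shown in these hunks.

import torch
from PIL import Image
from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel

MODEL_ID = "Wan-AI/Wan2.1-FLF2V-14B-720P-diffusers"   # assumed checkpoint id
DTYPE = torch.bfloat16                                # assumed reduced precision

# image encoder in full precision, VAE in reduced precision (as in this commit)
image_encoder = CLIPVisionModel.from_pretrained(
    MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=DTYPE)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID, image_encoder=image_encoder, vae=vae, torch_dtype=DTYPE
)
pipe.enable_model_cpu_offload()   # paged blocks, keeps a ≤ 24 GB card within budget
pipe.vae.enable_slicing()         # reduce VAE peak RAM

first = Image.open("first.png").convert("RGB")   # placeholder input frames
last = Image.open("last.png").convert("RGB")

result = pipe(
    image=first,
    last_image=last,
    prompt="smooth cinematic transition between the two frames",
    height=480, width=832,          # illustrative; the app derives these via aspect_resize
    num_frames=81,                  # ≈ 5 s at 16 fps (DEFAULT_FRAMES)
    num_inference_steps=30,
    guidance_scale=5.0,
    generator=torch.Generator(device="cpu").manual_seed(0),
)
print(export_to_video(result.frames[0], "flf2v_demo.mp4", fps=16))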