GeradeHouse committed · Commit f956532 · verified · 1 Parent(s): f6d3581

Update app.py

Files changed (1)
  app.py +16 -15
app.py CHANGED
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 FLF2V – First & Last Frame → Video
-Streams all HF-Hub & Diffusers tqdm bars, caches the model,
+Streams all HF-Hub & Diffusers tqdm bars, caches the model,
 and provides a direct download link for the MP4.
 """
 
@@ -10,7 +10,7 @@ import numpy as np
 import torch
 import gradio as gr
 from PIL import Image
-from transformers import CLIPVisionModel, CLIPImageProcessor
+from transformers import CLIPVisionModel, CLIPProcessor
 from diffusers import WanImageToVideoPipeline, AutoencoderKLWan
 from diffusers.utils import export_to_video
 import torchvision.transforms.functional as TF
@@ -34,15 +34,15 @@ def load_pipeline():
     vision = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    # 2) fast processor
-    processor = CLIPImageProcessor.from_pretrained(
+    # 2) unified CLIP processor (fast Rust-backed + tokenizer stub)
+    processor = CLIPProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
     # 3) VAE (half precision)
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    # 4) pipeline assembly
+    # 4) assemble pipeline
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
@@ -52,6 +52,7 @@ def load_pipeline():
     )
     # 5) CPU offload for large models
     pipe.enable_model_cpu_offload()
+    # return on correct device
    return pipe.to("cuda" if torch.cuda.is_available() else "cpu")
 
 # -----------------------------------------------------------------------------
@@ -70,7 +71,7 @@ def center_crop_resize(img: Image.Image, h: int, w: int):
     return TF.center_crop(img2, [h, w])
 
 # -----------------------------------------------------------------------------
-# GENERATION (stream all tqdm → Gradio)
+# GENERATION (streams all tqdm → Gradio)
 # -----------------------------------------------------------------------------
 def generate(
     first_frame: Image.Image,
@@ -85,23 +86,23 @@
     progress=gr.Progress(track_tqdm=True),
 ):
     global PIPE
-    # lazy load
+    # lazy load once
     if PIPE is None:
         progress(0, desc="Loading model…")
         PIPE = load_pipeline()
 
-    # seed
+    # ensure reproducibility
     if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=PIPE.device).manual_seed(seed)
 
     # preprocess
-    progress(0, desc="Preprocessing…")
+    progress(0, desc="Preprocessing frames…")
     frame1, h, w = aspect_resize(first_frame)
     if last_frame.size != frame1.size:
         last_frame = center_crop_resize(last_frame, h, w)
 
-    # inference (all tqdm bars appear in progress)
+    # inference (all internal tqdm bars streamed)
     result = PIPE(
         image=frame1,
         last_image=last_frame,
@@ -116,7 +117,7 @@ def generate(
     )
     frames = result.frames[0]
 
-    # export
+    # export to MP4
     progress(1.0, desc="Exporting video…")
     out_path = export_to_video(frames, fps=fps)
     return out_path, seed
@@ -135,11 +136,11 @@ with gr.Blocks() as demo:
     negative = gr.Textbox(label="Negative prompt (optional)")
 
     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
+        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
         guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
         num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
-        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
-        seed = gr.Number(value=-1, precision=0, label="Seed")
+        fps = gr.Slider(4, 30, value=16, step=1, label="FPS")
+        seed = gr.Number(value=-1, precision=0, label="Seed")
 
     run_btn = gr.Button("Generate")
     download = gr.File(label="Download video (.mp4)")
@@ -152,5 +153,5 @@ with gr.Blocks() as demo:
         concurrency_limit=1
     )
 
-# enable progress streaming
+# enable queue + tqdm streaming
 demo.queue().launch()
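For context on the import swap above: `CLIPImageProcessor` handles image resize and normalization only, while `CLIPProcessor` is a wrapper that bundles that image processor together with the CLIP tokenizer. A minimal sketch of the practical difference, using the public `openai/clip-vit-base-patch32` checkpoint purely for illustration (the Wan2.1 repo layout is not assumed here):

```python
from PIL import Image
from transformers import CLIPImageProcessor, CLIPProcessor

ckpt = "openai/clip-vit-base-patch32"  # illustrative public checkpoint
img = Image.new("RGB", (512, 512))

# Bare image processor: resize + normalize of images, nothing else.
image_proc = CLIPImageProcessor.from_pretrained(ckpt)
pixels = image_proc(images=img, return_tensors="pt")
print(pixels["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])

# Combined processor: the image processor plus the CLIP tokenizer, so
# from_pretrained must also be able to resolve tokenizer files.
processor = CLIPProcessor.from_pretrained(ckpt)
batch = processor(text=["a photo"], images=img, return_tensors="pt")
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids', 'pixel_values']
```

Because of that second point, it is worth verifying that `subfolder="image_processor"` in the Wan2.1 repo also carries the tokenizer files `CLIPProcessor.from_pretrained` looks for; if it does not, the bare `CLIPImageProcessor` load is the safer of the two.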
 
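The seed handling in the updated `generate` follows a common replay pattern: `-1` means "draw a fresh seed", any other value is replayed exactly, and the seed actually used is returned so the UI can show a value the user can paste back in to reproduce a run. A standalone sketch of the same idea (the helper name `make_generator` is illustrative, not from app.py):

```python
import torch

def make_generator(seed: int, device: str = "cpu"):
    """Return a seeded torch.Generator plus the seed actually used."""
    if seed == -1:
        seed = torch.seed()  # draws (and returns) a fresh 64-bit seed
    gen = torch.Generator(device=device).manual_seed(seed)
    return gen, seed

gen, used = make_generator(-1)    # fresh, nondeterministic run
a = torch.randn(2, generator=gen)
gen2, _ = make_generator(used)    # replay with the reported seed
b = torch.randn(2, generator=gen2)
assert torch.equal(a, b)          # same seed -> same samples
```

Returning the seed alongside the video path is what lets the demo report which seed produced a given clip.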
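On the progress side, `gr.Progress(track_tqdm=True)` mirrors any tqdm bar that ticks inside the handler into the Gradio UI (which is how the HF-Hub download and Diffusers denoising bars reach the browser), and `demo.queue()` enables the event queue those updates stream over. A toy sketch, independent of the Wan2.1 app:

```python
import time

import gradio as gr
from tqdm import tqdm

def slow_task(n, progress=gr.Progress(track_tqdm=True)):
    # Any tqdm loop inside the handler is mirrored into the web UI.
    for _ in tqdm(range(int(n)), desc="Working"):
        time.sleep(0.05)
    return f"done after {int(n)} steps"

with gr.Blocks() as demo:
    n = gr.Slider(1, 100, value=20, step=1, label="Steps")
    out = gr.Textbox(label="Result")
    gr.Button("Run").click(slow_task, inputs=n, outputs=out)

demo.queue().launch()  # queue the events so progress can stream
```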