GeradeHouse committed on
Commit 1c8aab2 · verified · 1 Parent(s): 9c8f4c5

Update app.py

Files changed (1)
  1. app.py +49 -45
app.py CHANGED
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 """
 Gradio demo for Wan2.1 FLF2V – full streaming progress
+No globals: pipeline, resize utils all use the local `pipe`.
 Author: <your-handle>
 """
 
@@ -22,17 +23,20 @@ DEFAULT_FRAMES = 81
 # ----------------------------------------------------------------------
 
 def load_pipeline(progress):
-    """Load model components with progress updates."""
-    # 0% 5%: start loading
-    progress(0.0, desc="Initializing model load…")
+    """Load & shard the pipeline across CPU/GPU with streaming progress."""
+    progress(0.00, desc="Init: loading image encoder…")
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
-    progress(0.02, desc="Image encoder loaded")
+    progress(0.10, desc="Loaded image encoder")
+
+    progress(0.10, desc="Loading VAE…")
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
-    progress(0.04, desc="VAE loaded")
+    progress(0.20, desc="Loaded VAE")
+
+    progress(0.20, desc="Assembling pipeline…")
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
@@ -41,81 +45,82 @@ def load_pipeline(progress):
         low_cpu_mem_usage=True,
         device_map="balanced",
     )
-    progress(0.06, desc="Pipeline assembled")
+    progress(0.30, desc="Pipeline assembled")
+
+    progress(0.30, desc="Loading fast image processor…")
     pipe.image_processor = CLIPImageProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
-    progress(0.08, desc="Processor ready")
-    return pipe
+    progress(0.40, desc="Processor ready")
 
-# Preload nothing here—model loads in-function to stream progress.
+    return pipe
 
-# ----------------------------------------------------------------------
-# UTILS ----------------------------------------------------------------
-def aspect_resize(img: Image.Image, max_area=MAX_AREA):
+def aspect_resize(img: Image.Image, pipe, max_area=MAX_AREA):
+    """Resize while respecting model patch multiples, using `pipe` for scale."""
     ar = img.height / img.width
-    mod = PIPE.vae_scale_factor_spatial * PIPE.transformer.config.patch_size[1]
+    mod = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
     h = round(np.sqrt(max_area * ar)) // mod * mod
     w = round(np.sqrt(max_area / ar)) // mod * mod
     return img.resize((w, h), Image.LANCZOS), h, w
 
-def center_crop_resize(img: Image.Image, h, w):
+def center_crop_resize(img: Image.Image, pipe, h, w):
+    """Center-crop & resize to H×W, using same Lanczos filter."""
     ratio = max(w / img.width, h / img.height)
-    img = img.resize((round(img.width * ratio), round(img.height * ratio)), Image.LANCZOS)
+    img = img.resize(
+        (round(img.width * ratio), round(img.height * ratio)),
+        Image.LANCZOS
+    )
     return TF.center_crop(img, [h, w])
 
-# ----------------------------------------------------------------------
-# GENERATE --------------------------------------------------------------
 def generate(first_frame, last_frame, prompt, negative_prompt,
              steps, guidance, num_frames, seed, fps,
-             progress=gr.Progress()):  # ← inject Gradio progress tracker 3
+             progress=gr.Progress()):  # Gradio progress hook
 
-    # 1) Load the pipeline with streaming
+    # 1) Load & shard pipeline
     pipe = load_pipeline(progress)
 
-    # 2) Preprocess images
-    progress(0.10, desc="Preprocessing frames…")
-    first_frame, h, w = aspect_resize(first_frame)
+    # 2) Preprocess
+    progress(0.45, desc="Preprocessing first frame…")
+    first_frame, h, w = aspect_resize(first_frame, pipe)
     if last_frame.size != first_frame.size:
-        last_frame = center_crop_resize(last_frame, h, w)
-    progress(0.12, desc="Frames ready")
+        progress(0.50, desc="Preprocessing last frame…")
+        last_frame = center_crop_resize(last_frame, pipe, h, w)
+    progress(0.55, desc="Frames ready")
 
-    # 3) Inference with per-step updates
+    # 3) Run inference with per-step callbacks
    if seed == -1:
         seed = torch.seed()
     gen = torch.Generator(device=pipe.device).manual_seed(seed)
 
-    def _callback(step, timestep, latents):
-        # Map step to [0.12…0.90] fraction of bar
-        frac = 0.12 + 0.78 * (step + 1) / steps
-        progress(frac, desc=f"Inference: step {step+1}/{steps}")
+    def _cb(step, timestep, latents):
+        frac = 0.55 + 0.35 * ((step + 1) / steps)
+        progress(frac, desc=f"Inference step {step+1}/{steps}")
 
-    progress(0.12, desc="Starting inference…")
+    progress(0.55, desc="Starting inference…")
     output = pipe(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
         negative_prompt=negative_prompt or None,
-        height=h, width=w,
+        height=h,
+        width=w,
         num_frames=num_frames,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
-        callback_on_step_end=_callback,
-        callback_steps=1,  # call our callback every step 4
+        callback_on_step_end=_cb,
+        callback_steps=1,
     )
     frames = output.frames[0]
 
-    # 4) Export
-    progress(0.92, desc="Building video…")
+    # 4) Export video
+    progress(0.92, desc="Exporting video…")
     video_path = export_to_video(frames, fps=fps)
 
-    # 5) Complete!
-    progress(1.0, desc="Done!")
+    # 5) Done
+    progress(1.0, desc="Complete!")
     return video_path
 
-# ----------------------------------------------------------------------
-# UI --------------------------------------------------------------------
 with gr.Blocks() as demo:
     gr.Markdown("## Wan2.1 FLF2V – Full Streaming Progress")
 
@@ -123,8 +128,8 @@ with gr.Blocks() as demo:
     first_img = gr.Image(label="First frame", type="pil")
     last_img = gr.Image(label="Last frame", type="pil")
 
-    prompt = gr.Textbox(label="Prompt")
-    negative = gr.Textbox(label="Negative prompt (optional)")
+    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
+    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
 
     with gr.Accordion("Advanced parameters", open=False):
         steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
@@ -135,13 +140,12 @@ with gr.Blocks() as demo:
 
     video = gr.Video(label="Result (.mp4)")
 
-    # bind generator to button; progress bar overlays on the video output
-    run_btn = gr.Button("Generate")
-    run_btn.click(
+    btn = gr.Button("Generate")
+    btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
         outputs=[video],
     )
 
-demo.queue()  # enable progress tracking
+demo.queue()  # enable streaming updates
 demo.launch()
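
Note on the per-step progress hook: the committed `_cb(step, timestep, latents)` follows the older diffusers `callback=` / `callback_steps=` convention. In recent diffusers releases, `callback_on_step_end` is instead invoked as `callback(pipe, step, timestep, callback_kwargs)` and is expected to return the `callback_kwargs` dict, and `callback_steps` may not be accepted alongside it. The following is only a minimal sketch of that variant, reusing names from `generate()` above (`steps`, `progress`, `pipe`, the frame and sampling variables) and assuming a current diffusers version; verify against the installed release:

    def _cb(pipeline, step, timestep, callback_kwargs):
        # Map denoising steps onto the 0.55–0.90 band of the Gradio progress bar.
        frac = 0.55 + 0.35 * ((step + 1) / steps)
        progress(frac, desc=f"Inference step {step + 1}/{steps}")
        return callback_kwargs  # the pipeline reads its tensors back from this dict

    output = pipe(
        image=first_frame,
        last_image=last_frame,
        prompt=prompt,
        negative_prompt=negative_prompt or None,
        height=h,
        width=w,
        num_frames=num_frames,
        num_inference_steps=steps,
        guidance_scale=guidance,
        generator=gen,
        callback_on_step_end=_cb,  # fires on every step; no callback_steps argument needed
    )

With that signature the 0.55–0.90 band lines up with the rest of the progress plan in this commit (0.00–0.40 model loading, 0.45–0.55 preprocessing, 0.92 export, 1.0 done), and `demo.queue()` remains necessary so Gradio can stream those updates to the client.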