GeradeHouse committed (verified)
Commit 9c8f4c5 · 1 Parent(s): 5516eb1

Update app.py

Files changed (1)
  1. app.py +53 -57
app.py CHANGED
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 """
-Gradio demo for Wan2.1 First-Last-Frame-to-Video (FLF2V)
-  – shows streaming status updates
-  – auto-downloads the generated video
+Gradio demo for Wan2.1 FLF2V – full streaming progress
 Author: <your-handle>
 """

@@ -23,29 +21,34 @@ MAX_AREA = 1280 * 720
 DEFAULT_FRAMES = 81
 # ----------------------------------------------------------------------

-def load_pipeline():
-    """Load & shard the pipeline across CPU/GPU with Accelerate."""
+def load_pipeline(progress):
+    """Load model components with progress updates."""
+    # 0% → 5%: start loading
+    progress(0.0, desc="Initializing model load…")
     image_encoder = CLIPVisionModel.from_pretrained(
         MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float32
     )
+    progress(0.02, desc="Image encoder loaded")
     vae = AutoencoderKLWan.from_pretrained(
         MODEL_ID, subfolder="vae", torch_dtype=DTYPE
     )
+    progress(0.04, desc="VAE loaded")
     pipe = WanImageToVideoPipeline.from_pretrained(
         MODEL_ID,
         vae=vae,
         image_encoder=image_encoder,
         torch_dtype=DTYPE,
-        low_cpu_mem_usage=True,    # lazy-load to CPU RAM
-        device_map="balanced",     # shard across CPU/GPU
+        low_cpu_mem_usage=True,
+        device_map="balanced",
     )
-    # switch to the fast Rust processor
+    progress(0.06, desc="Pipeline assembled")
     pipe.image_processor = CLIPImageProcessor.from_pretrained(
         MODEL_ID, subfolder="image_processor", use_fast=True
     )
+    progress(0.08, desc="Processor ready")
     return pipe

-PIPE = load_pipeline()
+# Preload nothing here—model loads in-function to stream progress.

 # ----------------------------------------------------------------------
 # UTILS ----------------------------------------------------------------
@@ -62,90 +65,83 @@ def center_crop_resize(img: Image.Image, h, w):
     return TF.center_crop(img, [h, w])

 # ----------------------------------------------------------------------
-# GENERATE (streaming) --------------------------------------------------
+# GENERATE --------------------------------------------------------------
 def generate(first_frame, last_frame, prompt, negative_prompt,
-             steps, guidance, num_frames, seed, fps):
-    # 1) Preprocess
-    yield None, None, "Preprocessing images..."
+             steps, guidance, num_frames, seed, fps,
+             progress=gr.Progress()):      # inject Gradio progress tracker
+
+    # 1) Load the pipeline with streaming
+    pipe = load_pipeline(progress)
+
+    # 2) Preprocess images
+    progress(0.10, desc="Preprocessing frames…")
     first_frame, h, w = aspect_resize(first_frame)
     if last_frame.size != first_frame.size:
         last_frame = center_crop_resize(last_frame, h, w)
+    progress(0.12, desc="Frames ready")

-    # 2) Inference
-    yield None, None, f"Running inference ({steps} steps)..."
+    # 3) Inference with per-step updates
     if seed == -1:
         seed = torch.seed()
-    gen = torch.Generator(device=PIPE.device).manual_seed(seed)
-    output = PIPE(
+    gen = torch.Generator(device=pipe.device).manual_seed(seed)
+
+    def _callback(step, timestep, latents):
+        # Map step to [0.12…0.90] fraction of bar
+        frac = 0.12 + 0.78 * (step + 1) / steps
+        progress(frac, desc=f"Inference: step {step+1}/{steps}")
+
+    progress(0.12, desc="Starting inference…")
+    output = pipe(
         image=first_frame,
         last_image=last_frame,
         prompt=prompt,
         negative_prompt=negative_prompt or None,
-        height=h,
-        width=w,
+        height=h, width=w,
         num_frames=num_frames,
         num_inference_steps=steps,
         guidance_scale=guidance,
         generator=gen,
+        callback_on_step_end=_callback,
+        callback_steps=1,      # call our callback every step
     )
     frames = output.frames[0]

-    # 3) Export
-    yield None, None, "Exporting video..."
+    # 4) Export
+    progress(0.92, desc="Building video")
     video_path = export_to_video(frames, fps=fps)

-    # 4) Done
-    yield video_path, seed, "Done! Your browser will download the video."
+    # 5) Complete!
+    progress(1.0, desc="Done!")
+    return video_path

 # ----------------------------------------------------------------------
 # UI --------------------------------------------------------------------
 with gr.Blocks() as demo:
-    # inject JS for auto-download
-    gr.HTML("""
-    <script>
-    function downloadVideo() {
-      const container = document.getElementById('output_video');
-      if (!container) return;
-      const vid = container.querySelector('video');
-      if (!vid) return;
-      const src = vid.currentSrc;
-      const a = document.createElement('a');
-      a.href = src;
-      a.download = 'output.mp4';
-      document.body.appendChild(a);
-      a.click();
-      document.body.removeChild(a);
-    }
-    </script>
-    """)
-
-    gr.Markdown("## Wan 2.1 FLF2V – Streaming progress + auto-download")
+    gr.Markdown("## Wan2.1 FLF2V Full Streaming Progress")

     with gr.Row():
         first_img = gr.Image(label="First frame", type="pil")
         last_img = gr.Image(label="Last frame", type="pil")

-    prompt = gr.Textbox(label="Prompt", placeholder="A blue bird takes off…")
-    negative = gr.Textbox(label="Negative prompt (optional)", placeholder="ugly, blurry")
+    prompt = gr.Textbox(label="Prompt")
+    negative = gr.Textbox(label="Negative prompt (optional)")

     with gr.Accordion("Advanced parameters", open=False):
-        steps = gr.Slider(10, 50, value=30, step=1, label="Sampling steps")
-        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance scale")
-        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, step=1, label="Frames")
-        fps = gr.Slider(4, 30, value=16, step=1, label="FPS (export)")
-        seed = gr.Number(value=-1, precision=0, label="Seed (-1 = random)")
+        steps = gr.Slider(10, 50, value=30, step=1, label="Steps")
+        guidance = gr.Slider(0.0, 10.0, value=5.5, step=0.1, label="Guidance")
+        num_frames = gr.Slider(16, 129, value=DEFAULT_FRAMES, label="Frames")
+        fps = gr.Slider(4, 30, value=16, label="FPS")
+        seed = gr.Number(value=-1, precision=0, label="Seed")

-    run_btn = gr.Button("Generate")
-    status = gr.Textbox(label="Status", interactive=False)
-    video = gr.Video(label="Result", elem_id="output_video")
-    used_seed = gr.Number(label="Seed used", interactive=False)
+    video = gr.Video(label="Result (.mp4)")

+    # bind generator to button; progress bar overlays on the video output
+    run_btn = gr.Button("Generate")
     run_btn.click(
         fn=generate,
         inputs=[first_img, last_img, prompt, negative, steps, guidance, num_frames, seed, fps],
-        outputs=[video, used_seed, status],
-        _js="downloadVideo"
+        outputs=[video],
     )

-demo.queue()
+demo.queue()      # enable progress tracking
 demo.launch()
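
Note on the per-step callback: recent diffusers releases expect `callback_on_step_end` to receive `(pipeline, step_index, timestep, callback_kwargs)` and to return the `callback_kwargs` dict, while the three-argument `(step, timestep, latents)` form used in this commit matches the older, deprecated `callback=` / `callback_steps=` interface. Below is a minimal sketch, not part of the commit, of how the progress wiring could look under the newer convention; `pipe`, the step-to-fraction mapping, and the function name are illustrative assumptions.

# Sketch only: assumes Gradio's gr.Progress injection and the newer diffusers
# callback_on_step_end convention. `pipe` stands in for an already loaded pipeline.
import gradio as gr

def generate_with_progress(prompt, steps, progress=gr.Progress()):
    progress(0.0, desc="Starting inference…")

    def on_step_end(pipeline, step_index, timestep, callback_kwargs):
        # Map denoising steps onto the 0.1 to 0.9 slice of the progress bar.
        progress(0.1 + 0.8 * (step_index + 1) / steps,
                 desc=f"Step {step_index + 1}/{steps}")
        return callback_kwargs      # the newer API requires returning the kwargs dict

    output = pipe(
        prompt=prompt,
        num_inference_steps=steps,
        callback_on_step_end=on_step_end,   # no callback_steps argument needed here
    )
    progress(1.0, desc="Done")
    return output.frames[0]

If the Space pins an older diffusers version that still accepts `callback=` and `callback_steps=`, the committed `_callback(step, timestep, latents)` signature is the matching one; on newer versions the pipeline call may reject `callback_steps` as an unexpected argument, so the callback wiring should be checked against the pinned release.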