cbensimon HF Staff committed on
Commit
e9125ed
·
1 Parent(s): 7e1b70d

Fix default duration + remove timings

Browse files
Files changed (2) hide show
  1. app.py +6 -3
  2. optimization.py +1 -23
app.py CHANGED
@@ -40,7 +40,10 @@ MAX_SEED = np.iinfo(np.int32).max
40
 
41
  FIXED_FPS = 24
42
  MIN_FRAMES_MODEL = 8
43
- MAX_FRAMES_MODEL = 81
 
 
 
44
 
45
  optimize_pipeline_(pipe,
46
  image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
@@ -99,7 +102,7 @@ def generate_video(
99
  input_image,
100
  prompt,
101
  negative_prompt=default_negative_prompt,
102
- duration_seconds = 2,
103
  guidance_scale = 1,
104
  steps = 4,
105
  seed = 42,
@@ -178,7 +181,7 @@ with gr.Blocks() as demo:
178
  with gr.Column():
179
  input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
180
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
181
- duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=round(MAX_FRAMES_MODEL/FIXED_FPS,1), step=0.1, value=round(MAX_FRAMES_MODEL/FIXED_FPS,1), label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
182
 
183
  with gr.Accordion("Advanced Settings", open=False):
184
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
 
40
 
41
  FIXED_FPS = 24
42
  MIN_FRAMES_MODEL = 8
43
+ MAX_FRAMES_MODEL = 81
44
+
45
+ MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
46
+ MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
47
 
48
  optimize_pipeline_(pipe,
49
  image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
 
102
  input_image,
103
  prompt,
104
  negative_prompt=default_negative_prompt,
105
+ duration_seconds = MAX_DURATION,
106
  guidance_scale = 1,
107
  steps = 4,
108
  seed = 42,
 
181
  with gr.Column():
182
  input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
183
  prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
184
+ duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=MAX_DURATION, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
185
 
186
  with gr.Accordion("Advanced Settings", open=False):
187
  negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
optimization.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
  """
3
 
4
- from datetime import datetime
5
  from typing import Any
6
  from typing import Callable
7
  from typing import ParamSpec
@@ -39,25 +38,16 @@ INDUCTOR_CONFIGS = {
39
 
40
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
41
 
42
- t0 = datetime.now()
43
-
44
  @spaces.GPU(duration=1500)
45
  def compile_transformer():
46
- nonlocal t0
47
-
48
- print('compile_transformer', -(t0 - (t0 := datetime.now())))
49
 
50
  with capture_component_call(pipeline, 'transformer') as call:
51
  pipeline(*args, **kwargs)
52
-
53
- print('capture_component_call', -(t0 - (t0 := datetime.now())))
54
 
55
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
56
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
57
 
58
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
59
-
60
- print('quantize_', -(t0 - (t0 := datetime.now())))
61
 
62
  hidden_states: torch.Tensor = call.kwargs['hidden_states']
63
  hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
@@ -74,8 +64,6 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
74
  kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
75
  dynamic_shapes=dynamic_shapes,
76
  )
77
-
78
- print('exported_landscape', -(t0 - (t0 := datetime.now())))
79
 
80
  exported_portrait = torch.export.export(
81
  mod=pipeline.transformer,
@@ -84,23 +72,13 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
84
  dynamic_shapes=dynamic_shapes,
85
  )
86
 
87
- print('exported_portrait', -(t0 - (t0 := datetime.now())))
88
-
89
  compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
90
- print('compiled_landscape', -(t0 - (t0 := datetime.now())))
91
-
92
  compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
93
- print('compiled_portrait', -(t0 - (t0 := datetime.now())))
94
-
95
- # Avoid weights duplication when serializing back to main process
96
- compiled_portrait.weights = compiled_landscape.weights
97
 
98
  return compiled_landscape, compiled_portrait
99
 
100
  compiled_landscape, compiled_portrait = compile_transformer()
101
- print('compiled', -(t0 - (t0 := datetime.now())))
102
-
103
- compiled_portrait.weights = compiled_landscape.weights
104
 
105
  def combined_transformer(*args, **kwargs):
106
  hidden_states: torch.Tensor = kwargs['hidden_states']
 
1
  """
2
  """
3
 
 
4
  from typing import Any
5
  from typing import Callable
6
  from typing import ParamSpec
 
38
 
39
  def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
40
 
 
 
41
  @spaces.GPU(duration=1500)
42
  def compile_transformer():
 
 
 
43
 
44
  with capture_component_call(pipeline, 'transformer') as call:
45
  pipeline(*args, **kwargs)
 
 
46
 
47
  dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
48
  dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
49
 
50
  quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
 
 
51
 
52
  hidden_states: torch.Tensor = call.kwargs['hidden_states']
53
  hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
 
64
  kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
65
  dynamic_shapes=dynamic_shapes,
66
  )
 
 
67
 
68
  exported_portrait = torch.export.export(
69
  mod=pipeline.transformer,
 
72
  dynamic_shapes=dynamic_shapes,
73
  )
74
 
 
 
75
  compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
 
 
76
  compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
77
+ compiled_portrait.weights = compiled_landscape.weights # Avoid weights duplication when serializing back to main process
 
 
 
78
 
79
  return compiled_landscape, compiled_portrait
80
 
81
  compiled_landscape, compiled_portrait = compile_transformer()
 
 
 
82
 
83
  def combined_transformer(*args, **kwargs):
84
  hidden_states: torch.Tensor = kwargs['hidden_states']