cbensimon (HF Staff) committed
Commit 288103a · Parent: 3df4fd5

Unclutter a bit

Files changed (1):
  optimization.py (+28 -30)
optimization.py CHANGED
@@ -10,43 +10,41 @@ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
 from zerogpu import aoti_compile
 
 
+def _example_tensor(*shape):
+    return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
+
+
 def optimize_pipeline_(pipeline: FluxPipeline):
 
+    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
+    seq_length = 256 if is_timestep_distilled else 512
+
+    transformer_kwargs = {
+        'hidden_states': _example_tensor(1, 4096, 64),
+        'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
+        'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
+        'pooled_projections': _example_tensor(1, 768),
+        'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
+        'txt_ids': _example_tensor(seq_length, 3),
+        'img_ids': _example_tensor(4096, 3),
+        'joint_attention_kwargs': {},
+        'return_dict': False,
+    }
+
+    inductor_configs = {
+        'conv_1x1_as_mm': True,
+        'epilogue_fusion': False,
+        'coordinate_descent_tuning': True,
+        'coordinate_descent_check_all_directions': True,
+        'max_autotune': True,
+        'triton.cudagraphs': True,
+    }
+
     @spaces.GPU(duration=1500)
     def compile_transformer():
-
         pipeline.transformer.fuse_qkv_projections()
         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
-
-        def _example_tensor(*shape):
-            return torch.randn(*shape, device='cuda', dtype=torch.bfloat16)
-
-        is_timestep_distilled = not pipeline.transformer.config.guidance_embeds
-        seq_length = 256 if is_timestep_distilled else 512
-
-        transformer_kwargs = {
-            'hidden_states': _example_tensor(1, 4096, 64),
-            'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
-            'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16),
-            'pooled_projections': _example_tensor(1, 768),
-            'encoder_hidden_states': _example_tensor(1, seq_length, 4096),
-            'txt_ids': _example_tensor(seq_length, 3),
-            'img_ids': _example_tensor(4096, 3),
-            'joint_attention_kwargs': {},
-            'return_dict': False,
-        }
-
-        inductor_configs = {
-            'conv_1x1_as_mm': True,
-            'epilogue_fusion': False,
-            'coordinate_descent_tuning': True,
-            'coordinate_descent_check_all_directions': True,
-            'max_autotune': True,
-            'triton.cudagraphs': True,
-        }
-
         exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs)
-
         return aoti_compile(exported, inductor_configs)
 
     transformer_config = pipeline.transformer.config
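
For context on the example inputs: torch.export.export traces the module once with representative tensors, so the dummy transformer_kwargs above only need to match the shapes and dtypes the transformer will see at inference time; the values are irrelevant. A minimal self-contained sketch of that pattern (the toy module and CPU tensors are illustrations, not code from this Space):

import torch

class TinyBlock(torch.nn.Module):
    def forward(self, hidden_states, timestep):
        # Stand-in for the Flux transformer: any computation over the inputs
        return hidden_states * timestep

example_kwargs = {
    'hidden_states': torch.randn(1, 4096, 64),  # shape-only stand-in, like _example_tensor
    'timestep': torch.tensor([1.]),
}

# Tracing specializes the graph to these shapes/dtypes; values don't matter.
exported = torch.export.export(TinyBlock(), args=(), kwargs=example_kwargs)
print(exported)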
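
Similarly, torchao's quantize_ works in place, which is why the diff can fuse, quantize, and then export pipeline.transformer directly. A sketch on a plain model (the toy layers are an assumption; float8 kernels generally need a recent GPU, e.g. compute capability 8.9+):

import torch
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096),
    torch.nn.GELU(),
    torch.nn.Linear(4096, 64),
).to('cuda', dtype=torch.bfloat16)

# Mutates the module in place: Linear weights become float8 tensor subclasses,
# and activations are quantized dynamically at runtime.
quantize_(model, Float8DynamicActivationFloat8WeightConfig())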
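
Finally, a hypothetical call site, not shown in this commit (the checkpoint id and dtype are assumptions; the trailing underscore in optimize_pipeline_ signals in-place modification):

import torch
from diffusers import FluxPipeline
from optimization import optimize_pipeline_

# FLUX.1-schnell is timestep-distilled (no guidance embeds), so the example
# inputs above would take the seq_length == 256 branch.
pipeline = FluxPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-schnell', torch_dtype=torch.bfloat16
).to('cuda')

optimize_pipeline_(pipeline)  # compiles and quantizes pipeline.transformer in place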