cbensimon HF Staff committed on
Commit
1d06ec0
·
1 Parent(s): e020d00

Float8DynamicActivation quantization

Browse files
Files changed (2) hide show
  1. optimization.py +4 -0
  2. requirements.txt +1 -0
optimization.py CHANGED
@@ -8,6 +8,8 @@ from typing import ParamSpec
8
  import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map_only
 
 
11
 
12
  from optimization_utils import capture_component_call
13
  from optimization_utils import aoti_compile
@@ -46,6 +48,8 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
46
 
47
  pipeline.transformer.fuse_qkv_projections()
48
 
 
 
49
  exported = torch.export.export(
50
  mod=pipeline.transformer,
51
  args=call.args,
 
8
  import spaces
9
  import torch
10
  from torch.utils._pytree import tree_map_only
11
+ from torchao.quantization import quantize_
12
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
13
 
14
  from optimization_utils import capture_component_call
15
  from optimization_utils import aoti_compile
 
48
 
49
  pipeline.transformer.fuse_qkv_projections()
50
 
51
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
52
+
53
  exported = torch.export.export(
54
  mod=pipeline.transformer,
55
  args=call.args,
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 
1
  transformers
2
  git+https://github.com/huggingface/diffusers.git
3
  accelerate
 
1
+ torchao
2
  transformers
3
  git+https://github.com/huggingface/diffusers.git
4
  accelerate