cbensimon HF Staff committed on
Commit
dfac6b3
·
1 Parent(s): 1d06ec0
Files changed (2) hide show
  1. optimization.py +2 -1
  2. optimization_utils.py +54 -0
optimization.py CHANGED
@@ -13,6 +13,7 @@ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
13
 
14
  from optimization_utils import capture_component_call
15
  from optimization_utils import aoti_compile
 
16
 
17
 
18
  P = ParamSpec('P')
@@ -57,7 +58,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
57
  dynamic_shapes=dynamic_shapes,
58
  )
59
 
60
- return aoti_compile(exported, INDUCTOR_CONFIGS)
61
 
62
  transformer_config = pipeline.transformer.config
63
  pipeline.transformer = compile_transformer()
 
13
 
14
  from optimization_utils import capture_component_call
15
  from optimization_utils import aoti_compile
16
+ from optimization_utils import cudagraph
17
 
18
 
19
  P = ParamSpec('P')
 
58
  dynamic_shapes=dynamic_shapes,
59
  )
60
 
61
+ return cudagraph(aoti_compile(exported, INDUCTOR_CONFIGS))
62
 
63
  transformer_config = pipeline.transformer.config
64
  pipeline.transformer = compile_transformer()
optimization_utils.py CHANGED
@@ -4,16 +4,24 @@ import contextlib
4
  from contextvars import ContextVar
5
  from io import BytesIO
6
  from typing import Any
 
 
 
7
  from typing import cast
8
  from unittest.mock import patch
9
 
10
  import torch
 
11
  from torch._inductor.package.package import package_aoti
12
  from torch.export.pt2_archive._package import AOTICompiledModel
13
  from torch.export.pt2_archive._package_weights import TensorProperties
14
  from torch.export.pt2_archive._package_weights import Weights
15
 
16
 
 
 
 
 
17
  INDUCTOR_CONFIGS_OVERRIDES = {
18
  'aot_inductor.package_constants_in_so': False,
19
  'aot_inductor.package_constants_on_disk': True,
@@ -64,6 +72,48 @@ def aoti_compile(
64
  return ZeroGPUCompiledModel(archive_file, weights)
65
 
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  @contextlib.contextmanager
68
  def capture_component_call(
69
  pipeline: Any,
@@ -94,3 +144,7 @@ def capture_component_call(
94
  except CapturedCallException as e:
95
  captured_call.args = e.args
96
  captured_call.kwargs = e.kwargs
 
 
 
 
 
4
  from contextvars import ContextVar
5
  from io import BytesIO
6
  from typing import Any
7
+ from typing import Callable
8
+ from typing import ParamSpec
9
+ from typing import TypeVar
10
  from typing import cast
11
  from unittest.mock import patch
12
 
13
  import torch
14
+ from torch.utils._pytree import tree_map_only
15
  from torch._inductor.package.package import package_aoti
16
  from torch.export.pt2_archive._package import AOTICompiledModel
17
  from torch.export.pt2_archive._package_weights import TensorProperties
18
  from torch.export.pt2_archive._package_weights import Weights
19
 
20
 
21
+ P = ParamSpec('P')
22
+ T = TypeVar('T')
23
+
24
+
25
  INDUCTOR_CONFIGS_OVERRIDES = {
26
  'aot_inductor.package_constants_in_so': False,
27
  'aot_inductor.package_constants_on_disk': True,
 
72
  return ZeroGPUCompiledModel(archive_file, weights)
73
 
74
 
75
def cudagraph(fn: Callable[P, list[torch.Tensor]]):
    """Wrap *fn* so each distinct tensor-shape signature is captured once into a
    CUDA graph and replayed on subsequent calls instead of re-executed.

    Assumes *fn* is shape-stable: for fixed input shapes it always produces
    outputs of the same shapes into the same storages (true for AOTI-compiled
    models) — TODO confirm for any new caller. Output tensors are cloned on
    every call so callers own buffers that the next replay will not overwrite.
    """
    # One entry per input signature:
    # (replay wrapper, graph, static input args/kwargs, static output tensors).
    graphs: dict[int, tuple] = {}

    def fn_(*args: P.args, **kwargs: P.kwargs):
        # Key on the shapes of every tensor input — positional AND keyword —
        # so calls that differ only in a positional tensor's shape never
        # replay each other's graph. (Previously only kwarg shapes were
        # hashed, which silently aliased such calls.)
        key = hash((
            tuple(
                tuple(arg.shape)
                for arg in args
                if isinstance(arg, torch.Tensor)
            ),
            tuple(
                (name, tuple(kwarg.shape))
                for name in sorted(kwargs)
                if isinstance((kwarg := kwargs[name]), torch.Tensor)
            ),
        ))

        if key in graphs:
            wrapped, *_ = graphs[key]
            return wrapped(*args, **kwargs)

        graph = torch.cuda.CUDAGraph()
        # Static input buffers: capture records these exact storages, so each
        # replay must first copy the live inputs into them.
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        in_args, in_kwargs = _cast_as((args, kwargs), (in_args, in_kwargs))

        # Warm-up run outside capture (lazy init, autotuning, workspace allocation).
        fn(*in_args, **in_kwargs)
        with torch.cuda.graph(graph):
            out_tensors = fn(*in_args, **in_kwargs)

        def wrapped(*args: P.args, **kwargs: P.kwargs):
            # Refresh the static input buffers with the live values, then replay.
            for static, live in zip(in_args, args):
                if isinstance(static, torch.Tensor):
                    assert isinstance(live, torch.Tensor)
                    static.copy_(live)
            for name in kwargs:
                if isinstance((kwarg := kwargs[name]), torch.Tensor):
                    assert isinstance((in_kwarg := in_kwargs[name]), torch.Tensor)
                    in_kwarg.copy_(kwarg)
            graph.replay()
            # Clone: the static output buffers are overwritten by the next replay.
            return [tensor.clone() for tensor in out_tensors]

        graphs[key] = (wrapped, graph, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return fn_
115
+
116
+
117
  @contextlib.contextmanager
118
  def capture_component_call(
119
  pipeline: Any,
 
144
  except CapturedCallException as e:
145
  captured_call.args = e.args
146
  captured_call.kwargs = e.kwargs
147
+
148
+
149
def _cast_as(type_from: T, value: Any) -> T:
    """Runtime no-op that re-types *value*.

    Tells the static type checker to treat *value* as having the same type as
    *type_from*; the value itself is returned unchanged.
    """
    return value