Spaces:
Running on Zero
cudagraph
Browse files- optimization.py +2 -1
- optimization_utils.py +54 -0
optimization.py
CHANGED
@@ -13,6 +13,7 @@ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
|
13 |
|
14 |
from optimization_utils import capture_component_call
|
15 |
from optimization_utils import aoti_compile
|
|
|
16 |
|
17 |
|
18 |
P = ParamSpec('P')
|
@@ -57,7 +58,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
|
|
57 |
dynamic_shapes=dynamic_shapes,
|
58 |
)
|
59 |
|
60 |
-
return aoti_compile(exported, INDUCTOR_CONFIGS)
|
61 |
|
62 |
transformer_config = pipeline.transformer.config
|
63 |
pipeline.transformer = compile_transformer()
|
|
|
13 |
|
14 |
from optimization_utils import capture_component_call
|
15 |
from optimization_utils import aoti_compile
|
16 |
+
from optimization_utils import cudagraph
|
17 |
|
18 |
|
19 |
P = ParamSpec('P')
|
|
|
58 |
dynamic_shapes=dynamic_shapes,
|
59 |
)
|
60 |
|
61 |
+
return cudagraph(aoti_compile(exported, INDUCTOR_CONFIGS))
|
62 |
|
63 |
transformer_config = pipeline.transformer.config
|
64 |
pipeline.transformer = compile_transformer()
|
optimization_utils.py
CHANGED
@@ -4,16 +4,24 @@ import contextlib
|
|
4 |
from contextvars import ContextVar
|
5 |
from io import BytesIO
|
6 |
from typing import Any
|
|
|
|
|
|
|
7 |
from typing import cast
|
8 |
from unittest.mock import patch
|
9 |
|
10 |
import torch
|
|
|
11 |
from torch._inductor.package.package import package_aoti
|
12 |
from torch.export.pt2_archive._package import AOTICompiledModel
|
13 |
from torch.export.pt2_archive._package_weights import TensorProperties
|
14 |
from torch.export.pt2_archive._package_weights import Weights
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
17 |
INDUCTOR_CONFIGS_OVERRIDES = {
|
18 |
'aot_inductor.package_constants_in_so': False,
|
19 |
'aot_inductor.package_constants_on_disk': True,
|
@@ -64,6 +72,48 @@ def aoti_compile(
|
|
64 |
return ZeroGPUCompiledModel(archive_file, weights)
|
65 |
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
@contextlib.contextmanager
|
68 |
def capture_component_call(
|
69 |
pipeline: Any,
|
@@ -94,3 +144,7 @@ def capture_component_call(
|
|
94 |
except CapturedCallException as e:
|
95 |
captured_call.args = e.args
|
96 |
captured_call.kwargs = e.kwargs
|
|
|
|
|
|
|
|
|
|
4 |
from contextvars import ContextVar
|
5 |
from io import BytesIO
|
6 |
from typing import Any
|
7 |
+
from typing import Callable
|
8 |
+
from typing import ParamSpec
|
9 |
+
from typing import TypeVar
|
10 |
from typing import cast
|
11 |
from unittest.mock import patch
|
12 |
|
13 |
import torch
|
14 |
+
from torch.utils._pytree import tree_map_only
|
15 |
from torch._inductor.package.package import package_aoti
|
16 |
from torch.export.pt2_archive._package import AOTICompiledModel
|
17 |
from torch.export.pt2_archive._package_weights import TensorProperties
|
18 |
from torch.export.pt2_archive._package_weights import Weights
|
19 |
|
20 |
|
21 |
+
P = ParamSpec('P')
|
22 |
+
T = TypeVar('T')
|
23 |
+
|
24 |
+
|
25 |
INDUCTOR_CONFIGS_OVERRIDES = {
|
26 |
'aot_inductor.package_constants_in_so': False,
|
27 |
'aot_inductor.package_constants_on_disk': True,
|
|
|
72 |
return ZeroGPUCompiledModel(archive_file, weights)
|
73 |
|
74 |
|
75 |
+
def cudagraph(fn: Callable[P, list[torch.Tensor]]):
    """Wrap *fn* so repeated calls replay a captured CUDA graph.

    On the first call for a given set of keyword-tensor shapes, the wrapper
    warms up *fn*, captures its execution into a ``torch.cuda.CUDAGraph``,
    and caches the graph together with its static input/output buffers.
    Subsequent calls with the same shapes copy the new inputs into the
    static buffers and replay the graph instead of re-running *fn*.

    Assumes *fn* is shape-stable and returns a flat ``list[torch.Tensor]``
    — TODO confirm against the call sites.
    """

    # Cache: shape-key -> (replay wrapper, graph, static args, static kwargs, static outputs).
    # Entries are kept alive for the lifetime of the closure so the captured
    # buffers are never freed while the graph may still be replayed.
    graphs = {}

    def fn_(*args: P.args, **kwargs: P.kwargs):

        # Key on the shapes of keyword tensor arguments, in sorted-key order
        # for determinism.
        # NOTE(review): positional tensor args do not contribute to the key —
        # presumably tensors are always passed as kwargs here; verify against
        # callers, otherwise differently-shaped positional tensors would
        # collide on the same cached graph.
        key = hash(tuple(
            tuple(kwarg.shape)
            for a in sorted(kwargs.keys())
            if isinstance((kwarg := kwargs[a]), torch.Tensor)
        ))

        # Fast path: a graph for these shapes was already captured.
        if key in graphs:
            wrapped, *_ = graphs[key]
            return wrapped(*args, **kwargs)

        graph = torch.cuda.CUDAGraph()
        # Clone every tensor input: these clones become the graph's static
        # input buffers, into which later calls copy their data before replay.
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        # Runtime no-op; only re-asserts the original (args, kwargs) static
        # type for the checker, since tree_map_only widens it.
        in_args, in_kwargs = _cast_as((args, kwargs), (in_args, in_kwargs))

        # Warm-up run outside the graph (required before CUDA graph capture,
        # e.g. to trigger lazy initialization / autotuning), then capture.
        fn(*in_args, **in_kwargs)
        with torch.cuda.graph(graph):
            out_tensors = fn(*in_args, **in_kwargs)

        def wrapped(*args: P.args, **kwargs: P.kwargs):
            # Copy fresh positional tensor inputs into the static buffers.
            for a, b in zip(in_args, args):
                if isinstance(a, torch.Tensor):
                    assert isinstance(b, torch.Tensor)
                    a.copy_(b)
            # Same for keyword tensor inputs.
            for key in kwargs:
                if isinstance((kwarg := kwargs[key]), torch.Tensor):
                    assert isinstance((in_kwarg := in_kwargs[key]), torch.Tensor)
                    in_kwarg.copy_(kwarg)
            graph.replay()
            # Clone outputs so callers get tensors that the next replay
            # cannot overwrite.
            return [tensor.clone() for tensor in out_tensors]

        graphs[key] = (wrapped, graph, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)

    return fn_
|
115 |
+
|
116 |
+
|
117 |
@contextlib.contextmanager
|
118 |
def capture_component_call(
|
119 |
pipeline: Any,
|
|
|
144 |
except CapturedCallException as e:
|
145 |
captured_call.args = e.args
|
146 |
captured_call.kwargs = e.kwargs
|
147 |
+
|
148 |
+
|
149 |
+
def _cast_as(type_from: T, value: Any) -> T:
    """Static-typing shim: return *value* unchanged, typed like *type_from*.

    A runtime no-op. Its only purpose is to tell the type checker that
    *value* has the same type as *type_from* (e.g. after a pytree map has
    widened the inferred type). *type_from* is never inspected at runtime.
    """
    return value
|