"""
"""

from datetime import datetime
from typing import Any
from typing import Callable
from typing import ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig

from optimization_utils import capture_component_call
from optimization_utils import aoti_compile


P = ParamSpec('P')


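# The latent frame count (dim 2 of `hidden_states`) is exported as a dynamic
# dimension so a single compiled artifact serves multiple video lengths.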
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)

TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}

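# Inductor settings geared toward inference latency: aggressive autotuning
# plus CUDA graphs to cut per-launch overhead.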
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}


def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
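    """Quantize and AOT-compile `pipeline.transformer` in place.

    `args`/`kwargs` must form a representative pipeline call; it is executed
    once under `capture_component_call` to record example transformer inputs
    for export.
    """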

    t0 = datetime.now()

    @spaces.GPU(duration=1500)
    def compile_transformer():
        nonlocal t0
        
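        # `-(t0 - (t0 := datetime.now()))` prints the time elapsed since the
        # previous checkpoint while resetting `t0` in the same expression.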
        print('compile_transformer', -(t0 - (t0 := datetime.now())))

        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)

        print('capture_component_call', -(t0 - (t0 := datetime.now())))
        
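        # Build a dynamic-shapes spec mirroring the captured kwargs pytree:
        # every tensor/bool leaf maps to None (fully static), then
        # `hidden_states` is overridden with the dynamic num_frames dimension.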
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

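        # In-place FP8 dynamic-activation / FP8-weight quantization (torchao)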
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())

        print('quantize_', -(t0 - (t0 := datetime.now())))
        
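        # torch.export specializes the height/width latent dims, so build one
        # example input per orientation by transposing the captured latents,
        # then export the transformer twice.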
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states

        exported_landscape = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )

        print('exported_landscape', -(t0 - (t0 := datetime.now())))
        
        exported_portrait = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )

        print('exported_portrait', -(t0 - (t0 := datetime.now())))

        compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
        print('compiled_landscape', -(t0 - (t0 := datetime.now())))

        compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
        print('compiled_portrait', -(t0 - (t0 := datetime.now())))

        # Avoid weights duplication when serializing back to main process
        compiled_portrait.weights = compiled_landscape.weights

        return compiled_landscape, compiled_portrait

    compiled_landscape, compiled_portrait = compile_transformer()
    print('compiled', -(t0 - (t0 := datetime.now())))

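    # The compiled artifacts come back pickled from the GPU worker, so point
    # the portrait program at the landscape weights again in this process.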
    compiled_portrait.weights = compiled_landscape.weights

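    # Route each call to the artifact compiled for its aspect ratio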
    def combined_transformer(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return compiled_landscape(*args, **kwargs)
        else:
            return compiled_portrait(*args, **kwargs)

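    # Preserve attributes that downstream pipeline code reads off the transformer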
    transformer_config = pipeline.transformer.config
    transformer_dtype = pipeline.transformer.dtype
    pipeline.transformer = combined_transformer
    pipeline.transformer.config = transformer_config # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype # pyright: ignore[reportAttributeAccessIssue]
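

# Usage sketch (hypothetical call site; the pipeline class and arguments below
# are assumptions, not part of this module):
#
#     pipeline = DiffusionPipeline.from_pretrained(...).to('cuda')
#     optimize_pipeline_(pipeline, prompt='warm-up prompt')
#     video = pipeline(prompt='...')  # now dispatches to the AOT-compiled graphs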