File size: 4,347 Bytes
2af6ba4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import dataclasses
import json
import warnings
from dataclasses import dataclass, MISSING
from functools import partial
from typing import Optional, Any


@partial(dataclass, frozen=True, kw_only=True)
class JsonComparable:
    def to_json(self) -> str:
        return json.dumps(dataclasses.asdict(self))

    def __eq__(self, other: "JsonComparable") -> bool:
        return self.to_json() == other.to_json()

    def __hash__(self) -> int:
        return hash(self.to_json())

    def __lt__(self, other: "JsonComparable") -> bool:
        return self.to_json() < other.to_json()


@partial(dataclass, frozen=True, kw_only=True)
class SubblockConfig(JsonComparable):
    """
    Options shared by all subblock configurations.

    ``no_op`` and ``replace_with_linear`` may not both be true. This is
    enforced with an ``assert`` in ``__post_init__``, so the check is skipped
    under ``python -O``.

    NOTE(review): because this frozen class is decorated with ``@dataclass``
    (default ``eq=True``), it receives a generated field-tuple ``__hash__``.
    Hashing an instance whose ``sparsify`` is a ``list`` therefore raises
    ``TypeError``. Consider a tuple if instances must be hashable.
    """

    no_op: bool = False
    replace_with_linear: bool = False
    sparsify: Optional[list[str]] = None

    def _force_setattr(self, name: str, value: Any) -> None:
        """
        Assign ``value`` to attribute ``name`` despite ``frozen=True``.

        Bypasses the frozen dataclass ``__setattr__`` by calling
        ``object.__setattr__`` directly. Intended for use from
        ``__post_init__`` only.
        """
        object.__setattr__(self, name, value)

    def __post_init__(self):
        # At most one of the two "replace this subblock" switches may be on.
        assert not self.no_op or not self.replace_with_linear


@partial(dataclass, frozen=True, kw_only=True)
class AttentionConfig(SubblockConfig):
    """
    Configuration of an attention subblock.

    Fields (in addition to those inherited from SubblockConfig):
        n_heads_in_group: must be set unless ``no_op`` or
            ``replace_with_linear`` is set, in which case it is cleared to None.
        window_length: window length, or None. Cleared to None for
            no_op / replace_with_linear subblocks.
        num_sink_tokens: number of sink tokens, or None. Together with
            ``window_length`` it makes this a sink subblock (see ``is_sink``).
            Cleared to None for no_op / replace_with_linear subblocks.
        use_prefill_window_in_sink_attention: for sink subblocks, whether
            ``prefill_sliding_window`` still reports ``window_length``.
        unshifted_sink: may not be combined with
            ``use_prefill_window_in_sink_attention``; required on a sink
            subblock with ``num_sink_tokens == 0``.

    Raises:
        AssertionError: on inconsistent combinations of the above. The checks
            are ``assert`` statements and are skipped under ``python -O``.
    """
    n_heads_in_group: Optional[int] = None
    window_length: Optional[int] = None
    num_sink_tokens: Optional[int] = None
    use_prefill_window_in_sink_attention: bool = False
    unshifted_sink: bool = False

    def __post_init__(self):
        # The parent already rejects no_op together with replace_with_linear,
        # so that check is not repeated here.
        super().__post_init__()

        if self.no_op or self.replace_with_linear:
            # Head/window/sink parameters are cleared for subblocks that are
            # skipped or replaced by a linear layer.
            for irrelevant_att in ["n_heads_in_group", "window_length", "num_sink_tokens"]:
                self._force_setattr(irrelevant_att, None)
        else:
            assert self.n_heads_in_group is not None

        # Evaluated after the clearing above, so no_op / linear subblocks are
        # never treated as sink subblocks.
        if self.is_sink:
            assert not (self.unshifted_sink and self.use_prefill_window_in_sink_attention), \
                ("Unshifted sink uses its own kind of explicit masking, not standard window. "
                 "Set use_prefill_window_in_sink_attention to False.")
            assert not (self.num_sink_tokens == 0 and not self.unshifted_sink), \
                "Fake sink attention with 0 sink tokens is only supported with unshifted_sink=True"

    @property
    def prefill_sliding_window(self) -> Optional[int]:
        """
        Window length applicable at prefill, or None.

        Non-sink subblocks report ``window_length`` directly. Sink subblocks
        report it only when ``use_prefill_window_in_sink_attention`` is set.
        """
        if self.window_length is not None:
            if not self.is_sink or self.use_prefill_window_in_sink_attention:
                return self.window_length
        return None

    @property
    def is_sliding(self) -> bool:
        """True when ``prefill_sliding_window`` is not None."""
        return self.prefill_sliding_window is not None

    @property
    def is_sink(self) -> bool:
        """True when both ``window_length`` and ``num_sink_tokens`` are set."""
        return self.window_length is not None and self.num_sink_tokens is not None


@partial(dataclass, frozen=True, kw_only=True)
class FFNConfig(SubblockConfig):
    """
    Configuration of a feed-forward (FFN) subblock.

    ``ffn_mult`` must be provided for a regular subblock and is stored rounded
    to 6 decimal places. When ``no_op`` or ``replace_with_linear`` is set it is
    forced to None regardless of the value passed in.
    """

    ffn_mult: Optional[float] = None

    def __post_init__(self):
        super().__post_init__()
        replaced = self.no_op or self.replace_with_linear
        if not replaced:
            assert self.ffn_mult is not None
            # Presumably rounded to absorb float noise so equal multipliers
            # serialize/compare identically. TODO confirm intent.
            normalized = round(self.ffn_mult, 6)
        else:
            normalized = None
        self._force_setattr("ffn_mult", normalized)


@partial(dataclass, frozen=True, kw_only=True)
class BlockConfig(JsonComparable):
    """
    Configuration of one block: an attention subblock plus an FFN subblock.

    Both fields default to ``MISSING``, i.e. they have no default and must be
    passed (as keywords). Either may be given as a plain ``dict``;
    ``__post_init__`` then builds the annotated config class from it. Unknown
    keys are dropped with a warning.
    """
    attention: AttentionConfig = MISSING
    ffn: FFNConfig = MISSING

    def __post_init__(self):
        """
        Replace dict-valued subblock fields with instances of their annotated type.

        Keys that are not fields of the target config class are removed and
        reported through ``warnings.warn`` before construction.
        """
        for block_field in dataclasses.fields(self):
            value = getattr(self, block_field.name)
            if not isinstance(value, dict):
                continue

            # The field annotation is the config class to build.
            target_cls = block_field.type
            known_names = {f.name for f in dataclasses.fields(target_cls)}
            dropped = [key for key in value if key not in known_names]
            if dropped:
                warnings.warn(f"Removed unsupported fields {dropped} from {target_cls.__name__}")
            kept = {key: val for key, val in value.items() if key in known_names}

            # frozen=True forbids normal assignment, so bypass it explicitly.
            object.__setattr__(self, block_field.name, target_cls(**kept))