|
import re |
|
|
|
from flax.core.frozen_dict import freeze |
|
from flax.traverse_util import flatten_dict, unflatten_dict |
|
from jax.experimental import PartitionSpec as P |
|
|
|
|
|
|
|
_unmatched = object() |
|
|
|
|
|
empty_dict = object() |
|
|
|
|
|
def _match(qs, ks): |
|
"""Return True if regexes in qs match any window of strings in tuple ks.""" |
|
|
|
qts = tuple(map(lambda x: re.compile(x + "$"), qs)) |
|
for i in range(len(ks) - len(qs) + 1): |
|
matches = [x.match(y) for x, y in zip(qts, ks[i:])] |
|
if matches and all(matches): |
|
return True |
|
return False |
|
|
|
|
|
def _replacement_rules(rules): |
|
def replace(key, val): |
|
for rule, replacement in rules: |
|
if _match(rule, key): |
|
return replacement |
|
return val |
|
|
|
return replace |
|
|
|
|
|
def _get_partition_rules(): |
|
return [ |
|
|
|
(("embed_positions", "embedding"), P("mp", None)), |
|
(("embed_tokens", "embedding"), P("mp", None)), |
|
(("rel_bias", "embedding"), P(None, "mp")), |
|
|
|
(("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")), |
|
(("out_proj", "kernel"), P("mp", None)), |
|
|
|
(("Dense_0", "kernel"), P(None, "mp")), |
|
(("GLU.*", "Dense_1", "kernel"), P(None, "mp")), |
|
(("GLU.*", "Dense_2", "kernel"), P("mp", None)), |
|
(("FFN.*", "Dense_1", "kernel"), P("mp", None)), |
|
|
|
(("(bias|scale)",), None), |
|
(("lm_head", "kernel"), P(None, "mp")), |
|
|
|
(("(head_scale|tau)",), None), |
|
] |
|
|
|
|
|
def set_partitions(in_dict, use_scan): |
|
rules = _get_partition_rules() |
|
replace = _replacement_rules(rules) |
|
initd = {k: _unmatched for k in flatten_dict(in_dict)} |
|
result = {k: replace(k, v) for k, v in initd.items()} |
|
for k, v in result.items(): |
|
if v == _unmatched: |
|
print(f"Unmatched -> {k}") |
|
l = list(result.keys()) |
|
if use_scan: |
|
|
|
result = { |
|
k: (P(*(None,) + v) if v is not None else None) |
|
if any(x in k for x in ["FlaxBartEncoderLayers", "FlaxBartDecoderLayers"]) |
|
else v |
|
for k, v in result.items() |
|
} |
|
assert _unmatched not in result.values(), "Incomplete partition spec." |
|
return freeze(unflatten_dict(result)) |
|
|