feat: update shampoo
tools/train/scalable_shampoo/distributed_shampoo.py (changed)
@@ -98,15 +98,19 @@ class LocalShardedParameterStats:
 def init_training_metrics(num_statistics):
     # Since the downstream apis expect a jnp.array - we create a dummy one if
     # num_statistics=0.
-
-
+    if not num_statistics:
+        return TrainingMetrics(jnp.array(0, jnp.float32))
+    else:
+        return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))


 def init_training_metrics_shapes(num_statistics):
     # Since the downstream apis expect a jnp.array - we create a dummy one if
     # num_statistics=0.
-
-
+    if not num_statistics:
+        return TrainingMetrics([[], jnp.float32])
+    else:
+        return TrainingMetrics([[num_statistics], jnp.float32])


 def init_training_metrics_pspec():
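Note (not part of the diff): a minimal sketch of what the new branches buy, assuming the TrainingMetrics container defined earlier in this file holds a single error array. With num_statistics == 0 the dummy leaf is a float32 scalar; otherwise there is one float32 slot per statistic, matching the shapes declared by init_training_metrics_shapes.

    import jax
    import jax.numpy as jnp

    # Inspect the pytree leaves produced by the two branches.
    leaves_empty = jax.tree_util.tree_leaves(init_training_metrics(0))
    leaves_full = jax.tree_util.tree_leaves(init_training_metrics(4))
    print([(l.shape, l.dtype) for l in leaves_empty])  # expected: [((), float32)]
    print([(l.shape, l.dtype) for l in leaves_full])   # expected: [((4,), float32)]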
@@ -297,7 +301,7 @@ def matrix_inverse_pth_root(

     if matrix_size == 1:
         resultant_mat_h = (matrix + ridge_epsilon) ** alpha
-        error = 0
+        error = jnp.array(0, jnp.float32)
     else:
         damped_matrix = matrix + ridge_epsilon * identity

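Note (my reading of the change, not from the commit message): the 1x1 fast path previously produced a Python int 0 while the iterative path produces a float32 array, so the two branches could disagree on dtype/shape once the per-statistic errors are stacked or traced downstream. A stand-alone sketch of that constraint, with a placeholder for the iterative error:

    import jax.numpy as jnp

    def error_like(matrix_size: int):
        # Stand-in, not the real matrix_inverse_pth_root; the point is only
        # that both branches return a float32 scalar array.
        if matrix_size == 1:
            return jnp.array(0, jnp.float32)   # previously the Python int 0
        return jnp.array(0.37, jnp.float32)    # placeholder for the iterative error

    stacked = jnp.stack([error_like(1), error_like(4)])  # one float32 error per statistic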
@@ -688,9 +692,12 @@ def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
     """Adds errors back into local statistics."""
     new_local_stats = []
     for local_stat in local_stats:
-
-
-
+        if local_stat.sizes:
+            index_start = int(local_stat.index_start)
+            index_end = int(len(local_stat.sizes)) + index_start
+            per_stat_error = errors[index_start:index_end]
+        else:
+            per_stat_error = jnp.array(0, jnp.float32)
         if local_stat.sizes:
             per_stat_error = jnp.where(
                 jnp.logical_and(
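Note (illustration only, hypothetical numbers): each local_stat owns a contiguous slice of the flat errors vector, addressed by its index_start and the number of preconditioners it contributes (len(sizes)); parameters with no statistics fall back to a float32 scalar dummy.

    import jax.numpy as jnp

    errors = jnp.array([0.1, 0.2, 0.3, 0.4, 0.5], jnp.float32)
    index_start, sizes = 2, [128, 256]       # hypothetical per-parameter block sizes
    index_end = len(sizes) + index_start
    per_stat_error = errors[index_start:index_end]   # -> [0.3, 0.4]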
@@ -1077,18 +1084,22 @@ def distributed_shampoo(

         diagonal_statistics_pspec = []
         diagonal_statistics_scale_pspec = []
+        diagonal_statistics_shape = []
         if _graft_type_has_diagonal_statistics():
             # Identically shaped param.
             diagonal_statistics_pspec = param_pspec
+            diagonal_statistics_shape = list(param.shape)
             if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
                 diagonal_statistics_scale_pspec = (
                     _remove_leading_sharding_annotation(param_pspec)
                 )

         m1_pspec = []
+        m1_shape = []
         m1_scale_pspec = []
         if _graft_type_has_diagonal_momentum_states():
             m1_pspec = param_pspec
+            m1_shape = list(param.shape)
             if quantized_dtype_for_momentum_buffers() != jnp.float32:
                 m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)

@@ -1105,7 +1116,7 @@ def distributed_shampoo(
                 diagonal_statistics_scale_pspec,
                 quantized_dtype_for_diagonal_statistics_buffers(),
                 False,
-
+                diagonal_statistics_shape,
             ),
             QuantizedValue(
                 m1_pspec,
@@ -1113,7 +1124,7 @@ def distributed_shampoo(
                 m1_scale_pspec,
                 quantized_dtype_for_momentum_buffers(),
                 False,
-
+                m1_shape,
             ),
             QuantizedValue(
                 m2_pspec,
@@ -1463,9 +1474,11 @@ def distributed_shampoo(
         # Partition the concatenated statistics matrix across all cores.
         pspec_for_partition = preconditioner_partition_spec
         partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
-
-
-
+        if preconditioner_partition_spec:
+            partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
+        else:
+            partitioned_ps_spec = None
+        partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)
         # Run matrix inverse pth root on each shard.
         partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
             partitioned_xs, partitioned_ps
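Note (sketch of the sharding intent; the mesh axis name "data" is a hypothetical stand-in, the real spec arrives via preconditioner_partition_spec): assuming xs is the stacked statistics tensor with a leading per-statistic axis and ps carries one entry per statistic, this change gives ps a 1-D spec built from the first axis of the xs spec, so both vmapped inputs share the same leading-axis partitioning.

    from jax.experimental import pjit  # same alias the file already uses

    preconditioner_partition_spec = pjit.PartitionSpec("data", None, None)  # spec for xs
    if preconditioner_partition_spec:
        # Reuse only the leading (per-statistic) axis for the 1-D ps input.
        partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
    else:
        partitioned_ps_spec = None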
@@ -1581,7 +1594,7 @@ def distributed_shampoo(
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
-                errors_for_states.append(
+                errors_for_states.append(jnp.array(0, jnp.float32))
             else:
                 preconditioners_for_state = new_preconditioners_flat[
                     idx : idx + num_statistics
@@ -1809,7 +1822,7 @@ def distributed_shampoo(
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
-                errors_for_states.append(
+                errors_for_states.append(jnp.array(0, jnp.float32))
             else:
                 quantized_preconditioners_for_state = (
                     new_quantized_preconditioners_flat[idx : idx + num_statistics]
@@ -1962,7 +1975,7 @@ def distributed_shampoo(
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
-                errors_for_states.append(
+                errors_for_states.append(jnp.array(0, jnp.float32))
             else:
                 preconditioners_for_state = new_preconditioners_flat[
                     idx : idx + num_statistics
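Note on the three errors_for_states hunks above (illustration, not the file's surrounding code): states that contribute no statistics now record a concrete float32 scalar, mirroring the dummy produced by init_training_metrics(0), so the per-state error entries keep a uniform array type.

    import jax.numpy as jnp

    errors_for_states = []
    for num_statistics in (0, 3):              # hypothetical per-state statistic counts
        if num_statistics == 0:
            errors_for_states.append(jnp.array(0, jnp.float32))
        else:
            errors_for_states.append(jnp.zeros([num_statistics], jnp.float32))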