boris committed on
Commit 79a3849
1 Parent(s): d08bf8d

feat: update shampoo

tools/train/scalable_shampoo/distributed_shampoo.py CHANGED
@@ -98,15 +98,19 @@ class LocalShardedParameterStats:
 def init_training_metrics(num_statistics):
     # Since the downstream apis expect a jnp.array - we create a dummy one if
     # num_statistics=0.
-    n = 1 if not num_statistics else num_statistics
-    return TrainingMetrics(jnp.zeros([n], jnp.float32))
+    if not num_statistics:
+        return TrainingMetrics(jnp.array(0, jnp.float32))
+    else:
+        return TrainingMetrics(jnp.zeros([num_statistics], jnp.float32))
 
 
 def init_training_metrics_shapes(num_statistics):
     # Since the downstream apis expect a jnp.array - we create a dummy one if
     # num_statistics=0.
-    n = 1 if not num_statistics else num_statistics
-    return TrainingMetrics([[n], jnp.float32])
+    if not num_statistics:
+        return TrainingMetrics([[], jnp.float32])
+    else:
+        return TrainingMetrics([[num_statistics], jnp.float32])
 
 
 def init_training_metrics_pspec():
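The hunk above replaces the dummy length-1 metrics array with a true scalar when a parameter has no statistics, and init_training_metrics_shapes now declares the matching empty shape. A minimal sketch of the invariant the two helpers keep in step (the assert harness is illustrative, not part of the patch):

import jax.numpy as jnp

for num_statistics in (0, 3):
    # Value produced by init_training_metrics...
    value = (
        jnp.array(0, jnp.float32)
        if not num_statistics
        else jnp.zeros([num_statistics], jnp.float32)
    )
    # ...must match the shape declared by init_training_metrics_shapes.
    declared_shape = [] if not num_statistics else [num_statistics]
    assert list(value.shape) == declared_shape
    assert value.dtype == jnp.float32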
@@ -297,7 +301,7 @@ def matrix_inverse_pth_root(
 
     if matrix_size == 1:
         resultant_mat_h = (matrix + ridge_epsilon) ** alpha
-        error = 0
+        error = jnp.array(0, jnp.float32)
     else:
         damped_matrix = matrix + ridge_epsilon * identity
 
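Returning a plain Python 0 on the 1x1 fast path gave the error a weak integer type, while the iterative branch produces a float32 array; the new line pins both paths to the same scalar float32 so downstream stacking and metrics stay dtype-consistent. A small stand-in for the two branches (pth_root_error is hypothetical, not from the file):

import jax.numpy as jnp

def pth_root_error(matrix_size):
    if matrix_size == 1:
        return jnp.array(0, jnp.float32)  # was: error = 0 (weak Python int)
    return jnp.array(1e-6, jnp.float32)   # stands in for the iterative error

# Both paths now agree on a scalar float32.
assert pth_root_error(1).shape == pth_root_error(2).shape == ()
assert pth_root_error(1).dtype == pth_root_error(2).dtype == jnp.float32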
@@ -688,9 +692,12 @@ def _add_error_into_local_stats(local_stats, errors, inverse_failure_threshold):
     """Adds errors back into local statistics."""
     new_local_stats = []
     for local_stat in local_stats:
-        index_start = int(local_stat.index_start)
-        index_end = int(len(local_stat.sizes)) + index_start
-        per_stat_error = errors[index_start:index_end]
+        if local_stat.sizes:
+            index_start = int(local_stat.index_start)
+            index_end = int(len(local_stat.sizes)) + index_start
+            per_stat_error = errors[index_start:index_end]
+        else:
+            per_stat_error = jnp.array(0, jnp.float32)
         if local_stat.sizes:
             per_stat_error = jnp.where(
                 jnp.logical_and(
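Previously the slice into the flat error vector was computed unconditionally; for a parameter with no preconditioned dimensions (local_stat.sizes == []) it produced a zero-length array, whereas the scalar metrics introduced above expect a scalar placeholder. A toy reproduction of the two shapes (values hypothetical):

import jax.numpy as jnp

errors = jnp.zeros([4], jnp.float32)  # hypothetical flat error vector
index_start, sizes = 2, []
old_per_stat_error = errors[index_start : index_start + len(sizes)]
print(old_per_stat_error.shape)  # (0,) -- empty slice from the old code path
new_per_stat_error = jnp.array(0, jnp.float32)
print(new_per_stat_error.shape)  # ()   -- scalar, matching the new metrics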
@@ -1077,18 +1084,22 @@ def distributed_shampoo(
 
         diagonal_statistics_pspec = []
         diagonal_statistics_scale_pspec = []
+        diagonal_statistics_shape = []
         if _graft_type_has_diagonal_statistics():
             # Identically shaped param.
             diagonal_statistics_pspec = param_pspec
+            diagonal_statistics_shape = list(param.shape)
             if quantized_dtype_for_diagonal_statistics_buffers() != jnp.float32:
                 diagonal_statistics_scale_pspec = (
                     _remove_leading_sharding_annotation(param_pspec)
                 )
 
         m1_pspec = []
+        m1_shape = []
         m1_scale_pspec = []
         if _graft_type_has_diagonal_momentum_states():
             m1_pspec = param_pspec
+            m1_shape = list(param.shape)
             if quantized_dtype_for_momentum_buffers() != jnp.float32:
                 m1_scale_pspec = _remove_leading_sharding_annotation(m1_pspec)
 
@@ -1105,7 +1116,7 @@ def distributed_shampoo(
                 diagonal_statistics_scale_pspec,
                 quantized_dtype_for_diagonal_statistics_buffers(),
                 False,
-                list(param.shape),
+                diagonal_statistics_shape,
             ),
             QuantizedValue(
                 m1_pspec,
@@ -1113,7 +1124,7 @@ def distributed_shampoo(
                 m1_scale_pspec,
                 quantized_dtype_for_momentum_buffers(),
                 False,
-                list(param.shape),
+                m1_shape,
             ),
             QuantizedValue(
                 m2_pspec,
@@ -1463,9 +1474,11 @@ def distributed_shampoo(
         # Partition the concatenated statistics matrix across all cores.
         pspec_for_partition = preconditioner_partition_spec
         partitioned_xs = pjit.with_sharding_constraint(xs, pspec_for_partition)
-        partitioned_ps = pjit.with_sharding_constraint(
-            ps, pjit.PartitionSpec(preconditioner_partition_spec[0])
-        )
+        if preconditioner_partition_spec:
+            partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
+        else:
+            partitioned_ps_spec = None
+        partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)
         # Run matrix inverse pth root on each shard.
         partitioned_preconditioners, partitioned_errors = _matrix_inverse_pth_root_vmap(
             partitioned_xs, partitioned_ps
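With an empty preconditioner_partition_spec there is no axis name to take, and preconditioner_partition_spec[0] would raise an IndexError at trace time; the rewrite falls back to an unconstrained spec instead. The guard in isolation (assuming the jax.experimental.pjit import this file already uses; the constraint call is shown commented because it needs a live mesh):

from jax.experimental import pjit

preconditioner_partition_spec = ()  # hypothetical: nothing to shard over
if preconditioner_partition_spec:
    partitioned_ps_spec = pjit.PartitionSpec(preconditioner_partition_spec[0])
else:
    partitioned_ps_spec = None  # leave placement to the compiler
# partitioned_ps = pjit.with_sharding_constraint(ps, partitioned_ps_spec)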
@@ -1581,7 +1594,7 @@ def distributed_shampoo(
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
-                errors_for_states.append([])
+                errors_for_states.append(jnp.array(0, jnp.float32))
             else:
                 preconditioners_for_state = new_preconditioners_flat[
                     idx : idx + num_statistics
@@ -1809,7 +1822,7 @@ def distributed_shampoo(
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
-                errors_for_states.append([])
+                errors_for_states.append(jnp.array(0, jnp.float32))
             else:
                 quantized_preconditioners_for_state = (
                     new_quantized_preconditioners_flat[idx : idx + num_statistics]
@@ -1962,7 +1975,7 @@ def distributed_shampoo(
         for num_statistics, state in zip(num_statistics_per_state, states):
             if num_statistics == 0:
                 preconditioners_for_states.append([])
-                errors_for_states.append([])
+                errors_for_states.append(jnp.array(0, jnp.float32))
             else:
                 preconditioners_for_state = new_preconditioners_flat[
                     idx : idx + num_statistics
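The same one-line change lands in all three update paths (pjit-sharded, quantized, and default): a state with zero statistics now records a scalar float32 error rather than an empty list, so the errors pytree keeps one leaf per state and mirrors the scalar produced by init_training_metrics. A quick check of the resulting leaves (state contents hypothetical):

import jax
import jax.numpy as jnp

errors_for_states = []
for num_statistics in (0, 2):
    if num_statistics == 0:
        errors_for_states.append(jnp.array(0, jnp.float32))  # scalar leaf
    else:
        errors_for_states.append(jnp.zeros([num_statistics], jnp.float32))

# The empty state contributes a scalar leaf instead of vanishing entirely.
print(jax.tree_util.tree_map(jnp.shape, errors_for_states))  # [(), (2,)]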
 