boris committed on
Commit
c6263f3
1 Parent(s): 209ade7

fix: apply learning rate offset only when requested

Browse files
Files changed (1) hide show
  1. tools/train/train.py +3 -1
tools/train/train.py CHANGED
@@ -786,11 +786,13 @@ def main():
786
  transition_steps=training_args.warmup_steps + 1, # ensure not 0
787
  )
788
  # offset step when resuming
 
789
  if model_metadata.get("step", 0) and training_args.lr_resume_offset:
790
  warmup_fn = optax.join_schedules(
791
  schedules=[optax.constant_schedule(0.0), warmup_fn],
792
  boundaries=[model_metadata["step"]],
793
  )
 
794
  if training_args.lr_decay is None:
795
  return warmup_fn
796
  elif training_args.lr_decay == "linear":
@@ -811,7 +813,7 @@ def main():
811
  )
812
  schedule_fn = optax.join_schedules(
813
  schedules=[warmup_fn, decay_fn],
814
- boundaries=[model_metadata.get("step", 0) + training_args.warmup_steps],
815
  )
816
  return schedule_fn
817
 
 
786
  transition_steps=training_args.warmup_steps + 1, # ensure not 0
787
  )
788
  # offset step when resuming
789
+ last_boundary = training_args.warmup_steps
790
  if model_metadata.get("step", 0) and training_args.lr_resume_offset:
791
  warmup_fn = optax.join_schedules(
792
  schedules=[optax.constant_schedule(0.0), warmup_fn],
793
  boundaries=[model_metadata["step"]],
794
  )
795
+ last_boundary += model_metadata["step"]
796
  if training_args.lr_decay is None:
797
  return warmup_fn
798
  elif training_args.lr_decay == "linear":
 
813
  )
814
  schedule_fn = optax.join_schedules(
815
  schedules=[warmup_fn, decay_fn],
816
+ boundaries=[last_boundary],
817
  )
818
  return schedule_fn
819