boris committed on
Commit
89b4c45
1 Parent(s): 23c1ef6

feat(train): arg to offset lr for resumed runs

Browse files
Files changed (1) hide show
  1. tools/train/train.py +7 -2
tools/train/train.py CHANGED
@@ -406,7 +406,12 @@ class TrainingArguments:
406
  "help": "Whether to use staircase or continuous learning rate when using exponential decay."
407
  },
408
  )
409
-
 
 
 
 
 
410
  logging_steps: int = field(
411
  default=40, metadata={"help": "Log every X updates steps."}
412
  )
@@ -781,7 +786,7 @@ def main():
781
  transition_steps=training_args.warmup_steps + 1, # ensure not 0
782
  )
783
  # offset step when resuming
784
- if model_metadata.get("step", 0):
785
  warmup_fn = optax.join_schedules(
786
  schedules=[optax.constant_schedule(0.0), warmup_fn],
787
  boundaries=[model_metadata["step"]],
 
406
  "help": "Whether to use staircase or continuous learning rate when using exponential decay."
407
  },
408
  )
409
+ lr_resume_offset: bool = field(
410
+ default=False,
411
+ metadata={
412
+ "help": "Whether to offset the learning rate function with current step when resuming a run."
413
+ },
414
+ )
415
  logging_steps: int = field(
416
  default=40, metadata={"help": "Log every X updates steps."}
417
  )
 
786
  transition_steps=training_args.warmup_steps + 1, # ensure not 0
787
  )
788
  # offset step when resuming
789
+ if model_metadata.get("step", 0) and training_args.lr_resume_offset:
790
  warmup_fn = optax.join_schedules(
791
  schedules=[optax.constant_schedule(0.0), warmup_fn],
792
  boundaries=[model_metadata["step"]],