feat(train): arg to offset lr for resumed runs
tools/train/train.py  (+7 -2)

@@ -406,7 +406,12 @@ class TrainingArguments:
             "help": "Whether to use staircase or continuous learning rate when using exponential decay."
         },
     )
-
+    lr_resume_offset: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to offset the learning rate function with current step when resuming a run."
+        },
+    )
     logging_steps: int = field(
         default=40, metadata={"help": "Log every X updates steps."}
     )
@@ -781,7 +786,7 @@ def main():
         transition_steps=training_args.warmup_steps + 1,  # ensure not 0
     )
     # offset step when resuming
-    if model_metadata.get("step", 0):
+    if model_metadata.get("step", 0) and training_args.lr_resume_offset:
         warmup_fn = optax.join_schedules(
             schedules=[optax.constant_schedule(0.0), warmup_fn],
             boundaries=[model_metadata["step"]],
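
For context, a minimal sketch (not part of the commit) of the optax.join_schedules pattern used above: the learning rate is held at 0.0 until the resumed step, and the original warmup schedule then restarts from that step. The warmup_steps, peak_lr, and resume_step values below are hypothetical and chosen only for illustration.

# Sketch of the joined-schedule behavior; values are hypothetical.
import optax

warmup_steps = 100
peak_lr = 1e-3
resume_step = 50  # stands in for model_metadata["step"] of the resumed run

# Plain warmup schedule, as created just before the hunk shown above.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps + 1
)

# Joined schedule: hold 0.0 until resume_step, then run the warmup from scratch,
# so the warmup effectively starts counting at the resumed step.
offset_fn = optax.join_schedules(
    schedules=[optax.constant_schedule(0.0), warmup_fn],
    boundaries=[resume_step],
)

print(offset_fn(resume_step - 1))   # 0.0, still in the constant segment
print(offset_fn(resume_step))       # warmup_fn(0) == 0.0, warmup restarts here
print(offset_fn(resume_step + 10))  # warmup_fn(10), 10 steps into the warmup

With lr_resume_offset left at its default of False, the if-block is skipped and the schedule is evaluated at the global step directly, so a resumed run whose step is already past warmup_steps continues at the post-warmup learning rate.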