fix: apply learning rate offset only when requested
tools/train/train.py (+3 -1)
@@ -786,11 +786,13 @@ def main():
             transition_steps=training_args.warmup_steps + 1,  # ensure not 0
         )
         # offset step when resuming
+        last_boundary = training_args.warmup_steps
         if model_metadata.get("step", 0) and training_args.lr_resume_offset:
             warmup_fn = optax.join_schedules(
                 schedules=[optax.constant_schedule(0.0), warmup_fn],
                 boundaries=[model_metadata["step"]],
             )
+            last_boundary += model_metadata["step"]
         if training_args.lr_decay is None:
             return warmup_fn
         elif training_args.lr_decay == "linear":
@@ -811,7 +813,7 @@
         )
         schedule_fn = optax.join_schedules(
             schedules=[warmup_fn, decay_fn],
-            boundaries=[training_args.warmup_steps + model_metadata.get("step", 0)],
+            boundaries=[last_boundary],
         )
         return schedule_fn
 
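
For context, the sketch below (not part of the commit; the resume step, warmup length, and learning-rate values are invented for illustration) shows the boundary arithmetic the fix relies on. optax.join_schedules evaluates a schedule placed after a boundary at step - boundary, so once the warmup is prefixed with a constant-0 schedule up to the resume step, the boundary between warmup and decay must shift by the same amount; that is what last_boundary tracks.

import optax

resume_step = 100   # stands in for model_metadata["step"] (invented value)
warmup_steps = 10   # stands in for training_args.warmup_steps (invented value)
peak_lr = 1e-3      # stands in for training_args.learning_rate (invented value)

# Linear warmup from 0 to peak_lr.
warmup_fn = optax.linear_schedule(
    init_value=0.0,
    end_value=peak_lr,
    transition_steps=warmup_steps + 1,  # ensure not 0
)

# Offset the warmup when resuming: schedules placed after a boundary are
# called with (step - boundary), so the warmup starts counting at resume_step.
warmup_fn = optax.join_schedules(
    schedules=[optax.constant_schedule(0.0), warmup_fn],
    boundaries=[resume_step],
)
last_boundary = warmup_steps + resume_step  # warmup now ends here

# Linear decay after the (shifted) warmup.
decay_fn = optax.linear_schedule(
    init_value=peak_lr,
    end_value=0.0,
    transition_steps=1_000,
)
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn],
    boundaries=[last_boundary],
)

print(schedule_fn(resume_step))    # 0.0: warmup has just begun
print(schedule_fn(last_boundary))  # peak_lr: decay takes over at the peak

The commit makes this shift conditional on training_args.lr_resume_offset, so the decay boundary only moves when the warmup itself was offset by the resume step.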
|