fix(train): overwrite dropout only when specified
tools/train/train.py (+9 −7)
@@ -131,7 +131,7 @@ class ModelArguments:
             ), "Restoring state only available with W&B artifact reference"
 
     def get_metadata(self):
-        if ":" in self.model_name_or_path:
+        if self.model_name_or_path is not None and ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
@@ -685,12 +685,16 @@ def main():
         )
 
     # Set up our new model config
+    config_args = {
+        k: getattr(model_args, k)
+        for k in ["dropout", "activation_dropout", "attention_dropout"]
+        if getattr(model_args, k) is not None
+    }
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
-        config.dropout = model_args.dropout
-        config.activation_dropout = model_args.activation_dropout
-        config.attention_dropout = model_args.attention_dropout
+        for k, v in config_args.items():
+            setattr(config, k, v)
     else:
         config = None
 
@@ -703,9 +707,7 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
-            dropout=model_args.dropout,
-            activation_dropout=model_args.activation_dropout,
-            attention_dropout=model_args.attention_dropout,
+            **config_args,
         )
     else:
         model = DalleBart(
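The pattern this commit introduces, as a minimal self-contained sketch: collect only the dropout values the user actually specified, then apply just those. The simplified ModelArguments below is an assumption for illustration; the real dataclass in tools/train/train.py has many more fields.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArguments:
    # None means "keep whatever the pretrained config already uses".
    # Simplified stand-in for the repo's ModelArguments (assumption).
    dropout: Optional[float] = None
    activation_dropout: Optional[float] = None
    attention_dropout: Optional[float] = None


model_args = ModelArguments(attention_dropout=0.1)  # user set only one flag

# Keep only the values that were explicitly specified, as the commit does
# before applying them to the config / the from_pretrained kwargs.
config_args = {
    k: getattr(model_args, k)
    for k in ["dropout", "activation_dropout", "attention_dropout"]
    if getattr(model_args, k) is not None
}

print(config_args)  # {'attention_dropout': 0.1}

Before this change, all three values were written into the config unconditionally, so the dataclass defaults would overwrite the pretrained model's dropout settings even when the user passed nothing; filtering on `is not None` preserves the checkpoint's values.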