boris committed
Commit 89bc9d4
1 Parent(s): 65bb95f

fix(train): overwrite dropout only when specified

Files changed (1):
  1. tools/train/train.py +9 -7
tools/train/train.py CHANGED
@@ -131,7 +131,7 @@ class ModelArguments:
             ), "Restoring state only available with W&B artifact reference"
 
     def get_metadata(self):
-        if ":" in self.model_name_or_path:
+        if self.model_name_or_path is not None and ":" in self.model_name_or_path:
             if jax.process_index() == 0:
                 artifact = wandb.run.use_artifact(self.model_name_or_path)
             else:
@@ -685,12 +685,16 @@ def main():
     )
 
     # Set up our new model config
+    config_args = {
+        k: getattr(model_args, k)
+        for k in ["dropout", "activation_dropout", "attention_dropout"]
+        if getattr(model_args, k) is not None
+    }
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
-        config.dropout = model_args.dropout
-        config.activation_dropout = model_args.activation_dropout
-        config.attention_dropout = model_args.attention_dropout
+        for k, v in config_args.items():
+            setattr(config, k, v)
     else:
         config = None
 
@@ -703,9 +707,7 @@
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
-            dropout=model_args.dropout,
-            activation_dropout=model_args.activation_dropout,
-            attention_dropout=model_args.attention_dropout,
+            **config_args,
         )
     else:
         model = DalleBart(
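
In short, the dropout arguments now default to None and are copied onto the model config only when explicitly provided, so a pretrained checkpoint's dropout settings are no longer silently overwritten. A minimal, standalone sketch of that pattern; ModelArgs, Config, and apply_overrides below are hypothetical stand-ins, not the real ModelArguments and DalleBartConfig classes:

# Minimal sketch of the "override only when specified" pattern used above
# (hypothetical stand-in classes, not the real training code).
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    # None means "keep whatever the pretrained config already uses"
    dropout: Optional[float] = None
    activation_dropout: Optional[float] = None
    attention_dropout: Optional[float] = None


@dataclass
class Config:
    # values as loaded from a pretrained checkpoint
    dropout: float = 0.1
    activation_dropout: float = 0.0
    attention_dropout: float = 0.0


def apply_overrides(config: Config, model_args: ModelArgs) -> Config:
    # Keep only the arguments the user actually set on the command line.
    config_args = {
        k: getattr(model_args, k)
        for k in ["dropout", "activation_dropout", "attention_dropout"]
        if getattr(model_args, k) is not None
    }
    for k, v in config_args.items():
        setattr(config, k, v)
    return config


if __name__ == "__main__":
    # Only attention_dropout is specified; the pretrained dropout and
    # activation_dropout values are left untouched.
    cfg = apply_overrides(Config(), ModelArgs(attention_dropout=0.05))
    print(cfg)  # Config(dropout=0.1, activation_dropout=0.0, attention_dropout=0.05)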