boris committed
Commit: 80d791a
1 Parent(s): b6f5026

feat(train): allow editing dropout during training

Files changed (1):
  tools/train/train.py (+18, -0)
tools/train/train.py CHANGED

@@ -106,6 +106,18 @@ class ModelArguments:
             "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
         },
     )
+    dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Dropout rate. Overwrites config."},
+    )
+    activation_dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Activation dropout rate. Overwrites config."},
+    )
+    attention_dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Attention dropout rate. Overwrites config."},
+    )
 
     def __post_init__(self):
         if self.tokenizer_name is None:
@@ -674,6 +686,9 @@ def main():
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
+        config.dropout = model_args.dropout
+        config.activation_dropout = model_args.activation_dropout
+        config.attention_dropout = model_args.attention_dropout
     else:
         config = None
 
@@ -686,6 +701,9 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
+            dropout=model_args.dropout,
+            activation_dropout=model_args.activation_dropout,
+            attention_dropout=model_args.attention_dropout,
         )
     else:
         model = DalleBart(
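
For context, here is a minimal standalone sketch (not part of the commit) of how the new flags behave, assuming the script parses its argument dataclasses with transformers.HfArgumentParser as Hugging Face training scripts typically do; the file name and example command below are illustrative only:

# Illustrative sketch (not from this commit): how the new dropout flags
# parse from the command line via transformers.HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Dropout rate. Overwrites config."},
    )
    activation_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Activation dropout rate. Overwrites config."},
    )
    attention_dropout: Optional[float] = field(
        default=None,
        metadata={"help": "Attention dropout rate. Overwrites config."},
    )


if __name__ == "__main__":
    # e.g. python sketch.py --dropout 0.1 --activation_dropout 0.0 --attention_dropout 0.0
    (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
    # A flag that is omitted stays None; the commit assigns these values to
    # the model config as-is.
    print(model_args.dropout, model_args.activation_dropout, model_args.attention_dropout)

The None default lets the parser distinguish "flag not passed" from an explicit 0.0 for each dropout rate.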