feat(train): allow editing dropout during training
tools/train/train.py (+18 -0)
CHANGED
@@ -106,6 +106,18 @@ class ModelArguments:
             "help": "Restore optimizer and training state. Can be True (will retrieve associated wandb artifact), a local directory or a Google bucket path."
         },
     )
+    dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Dropout rate. Overwrites config."},
+    )
+    activation_dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Activation dropout rate. Overwrites config."},
+    )
+    attention_dropout: Optional[float] = field(
+        default=None,
+        metadata={"help": "Attention dropout rate. Overwrites config."},
+    )
 
     def __post_init__(self):
         if self.tokenizer_name is None:
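The three new arguments follow the same Optional[float] field pattern as the rest of ModelArguments and default to None. As a minimal, self-contained sketch (an assumption, not code from this commit: it presumes the dataclass is parsed with transformers.HfArgumentParser, and the cut-down class below is hypothetical), this is how the fields surface as command-line flags:

from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # Hypothetical cut-down version of the dataclass this commit extends.
    dropout: Optional[float] = field(
        default=None, metadata={"help": "Dropout rate. Overwrites config."}
    )
    activation_dropout: Optional[float] = field(
        default=None, metadata={"help": "Activation dropout rate. Overwrites config."}
    )
    attention_dropout: Optional[float] = field(
        default=None, metadata={"help": "Attention dropout rate. Overwrites config."}
    )


parser = HfArgumentParser(ModelArguments)
# Equivalent to: python tools/train/train.py --dropout 0.1 --attention_dropout 0.0
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--dropout", "0.1", "--attention_dropout", "0.0"]
)
print(model_args.dropout, model_args.activation_dropout, model_args.attention_dropout)
# 0.1 None 0.0

Any flag that is omitted simply stays at None.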
@@ -674,6 +686,9 @@ def main():
     if model_args.config_name:
         config = DalleBartConfig.from_pretrained(model_args.config_name)
         config.gradient_checkpointing = training_args.gradient_checkpointing
+        config.dropout = model_args.dropout
+        config.activation_dropout = model_args.activation_dropout
+        config.attention_dropout = model_args.attention_dropout
     else:
         config = None
 
@@ -686,6 +701,9 @@ def main():
             dtype=getattr(jnp, model_args.dtype),
             _do_init=False,  # we overwrite them with loaded checkpoint
             gradient_checkpointing=training_args.gradient_checkpointing,
+            dropout=model_args.dropout,
+            activation_dropout=model_args.activation_dropout,
+            attention_dropout=model_args.attention_dropout,
         )
     else:
         model = DalleBart(
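Both construction paths receive the same three values: the config path writes them straight onto the freshly loaded DalleBartConfig, while the from_pretrained path forwards them as keyword arguments. In the config path the assignments are unconditional, so a flag left at its default of None replaces whatever rate the pretrained config carried. A guarded variant that overrides only the rates that were explicitly provided could look like the following sketch (a hypothetical alternative for illustration, not what this commit does):

# Hypothetical helper, not part of the commit: copy a dropout rate onto the
# config only when the corresponding CLI argument was actually supplied.
def apply_dropout_overrides(config, model_args):
    for name in ("dropout", "activation_dropout", "attention_dropout"):
        value = getattr(model_args, name)
        if value is not None:
            setattr(config, name, value)
    return config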