fix: allow non-scanned models (#168)
tools/train/config/medium/config.json
DELETED
@@ -1,31 +0,0 @@
-{
-  "activation_dropout": 0.0,
-  "activation_function": "gelu",
-  "attention_dropout": 0.0,
-  "bos_token_id": 16385,
-  "d_model": 1408,
-  "decoder_attention_heads": 16,
-  "decoder_ffn_dim": 4096,
-  "decoder_layerdrop": 0.0,
-  "decoder_layers": 14,
-  "decoder_start_token_id": 16384,
-  "dropout": 0.0,
-  "encoder_attention_heads": 16,
-  "encoder_ffn_dim": 4096,
-  "encoder_layerdrop": 0.0,
-  "encoder_layers": 14,
-  "encoder_vocab_size": 50264,
-  "eos_token_id": 16385,
-  "gradient_checkpointing": false,
-  "image_length": 256,
-  "image_vocab_size": 16384,
-  "init_std": 0.01,
-  "is_encoder_decoder": true,
-  "max_text_length": 64,
-  "model_type": "dallebart",
-  "normalize_text": true,
-  "pad_token_id": 16385,
-  "scale_embedding": false,
-  "tie_word_embeddings": false,
-  "use_cache": true
-}
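Note: the file removed above is a plain JSON model configuration; its keys describe the transformer shape (d_model, encoder_layers, decoder_layers, attention heads, FFN dims) and the token spaces (encoder_vocab_size for text, image_vocab_size and image_length for the generated image tokens). A minimal sketch of inspecting such a file with only the standard library (the path is the one deleted here; the project's own config class is not shown in this diff and is not assumed):

# Stdlib-only sketch: read a DalleBart-style config.json and look at the
# fields that define the architecture. Nothing project-specific is assumed.
import json

with open("tools/train/config/medium/config.json") as f:  # path removed by this commit
    cfg = json.load(f)

print(cfg["d_model"], cfg["encoder_layers"], cfg["decoder_layers"])
print(cfg["encoder_vocab_size"], cfg["image_vocab_size"], cfg["image_length"])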
tools/train/train.py
CHANGED
@@ -536,6 +536,8 @@ def split_params(data):
             split["scanned_decoder"][k] = v
         else:
             split["standard"][k] = v
+    # remove empty keys
+    split = {k: v for k, v in split.items() if v}
     for k, v in split.items():
         split[k] = freeze(traverse_util.unflatten_dict(v))
     return split
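Note on the split_params change: for a non-scanned model nothing is ever written into the "scanned_encoder" and "scanned_decoder" buckets, so the empty buckets are dropped before being unflattened and frozen. A small self-contained sketch of that filtering step (toy parameter tree, not the repository's full split_params):

# Toy sketch of the split/filter step, assuming only flax is installed.
# For a non-scanned model every parameter falls into "standard", and the
# empty scanned buckets are removed before freezing.
from flax import traverse_util
from flax.core.frozen_dict import freeze

params = {"model": {"decoder": {"embed_tokens": {"embedding": 0}}}}  # toy tree

split = {"standard": {}, "scanned_encoder": {}, "scanned_decoder": {}}
for k, v in traverse_util.flatten_dict(params).items():
    split["standard"][k] = v  # no scanned layers in this toy tree

# remove empty keys (the new line in the hunk above)
split = {k: v for k, v in split.items() if v}
split = {k: freeze(traverse_util.unflatten_dict(v)) for k, v in split.items()}
print(list(split))  # ['standard'] — only non-empty buckets remain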
@@ -544,7 +546,8 @@
 def unsplit_params(data):
     flat = {}
     for k in ["standard", "scanned_encoder", "scanned_decoder"]:
-        flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
+        if k in data:
+            flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
     return freeze(traverse_util.unflatten_dict(flat))


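Note on the unsplit_params change: the new "if k in data" guard means the reverse operation no longer assumes all three buckets exist, which would raise a KeyError for a non-scanned model once the empty buckets are filtered out above. A round-trip sketch mirroring the patched function:

# Round-trip sketch mirroring the patched unsplit_params: buckets that were
# dropped for a non-scanned model are simply skipped.
from flax import traverse_util
from flax.core.frozen_dict import freeze, unfreeze

def unsplit(data):
    flat = {}
    for k in ["standard", "scanned_encoder", "scanned_decoder"]:
        if k in data:  # missing buckets are skipped instead of raising KeyError
            flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
    return freeze(traverse_util.unflatten_dict(flat))

data = {"standard": freeze({"model": {"shared": {"embedding": 0}}})}
print(unsplit(data))  # FrozenDict({'model': {'shared': {'embedding': 0}}})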
@@ -1483,7 +1486,7 @@ def main():
     logger.info(" Ready to start training")
     with mesh:
         for epoch in epochs:
-            state.replace(epoch=epoch)
+            state = state.replace(epoch=epoch)
             local_state["epoch"] = epoch
             # ======================== Training ================================
             metrics_logger.update_state_metrics(local_state)
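Note on the main() change: the train state here is (presumably) a flax struct dataclass, so .replace() returns a new instance rather than mutating the existing one; without the assignment, the updated epoch was silently discarded. A tiny illustration of that behavior (DummyState is illustrative only, not the project's TrainState):

# Sketch of why the assignment is needed: flax struct dataclasses are
# immutable, so .replace() returns a new object instead of mutating in place.
from flax import struct

@struct.dataclass
class DummyState:
    epoch: int

state = DummyState(epoch=0)
state.replace(epoch=1)          # new object returned and discarded: a no-op
print(state.epoch)              # 0
state = state.replace(epoch=1)  # rebinding the name keeps the update
print(state.epoch)              # 1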