boris committed on
Commit
8ae9176
unverified
1 Parent(s): 3500e67

fix: allow non-scanned models (#168)

Browse files
tools/train/config/medium/config.json DELETED
@@ -1,31 +0,0 @@
1
- {
2
- "activation_dropout": 0.0,
3
- "activation_function": "gelu",
4
- "attention_dropout": 0.0,
5
- "bos_token_id": 16385,
6
- "d_model": 1408,
7
- "decoder_attention_heads": 16,
8
- "decoder_ffn_dim": 4096,
9
- "decoder_layerdrop": 0.0,
10
- "decoder_layers": 14,
11
- "decoder_start_token_id": 16384,
12
- "dropout": 0.0,
13
- "encoder_attention_heads": 16,
14
- "encoder_ffn_dim": 4096,
15
- "encoder_layerdrop": 0.0,
16
- "encoder_layers": 14,
17
- "encoder_vocab_size": 50264,
18
- "eos_token_id": 16385,
19
- "gradient_checkpointing": false,
20
- "image_length": 256,
21
- "image_vocab_size": 16384,
22
- "init_std": 0.01,
23
- "is_encoder_decoder": true,
24
- "max_text_length": 64,
25
- "model_type": "dallebart",
26
- "normalize_text": true,
27
- "pad_token_id": 16385,
28
- "scale_embedding": false,
29
- "tie_word_embeddings": false,
30
- "use_cache": true
31
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
tools/train/train.py CHANGED
@@ -536,6 +536,8 @@ def split_params(data):
536
  split["scanned_decoder"][k] = v
537
  else:
538
  split["standard"][k] = v
 
 
539
  for k, v in split.items():
540
  split[k] = freeze(traverse_util.unflatten_dict(v))
541
  return split
@@ -544,7 +546,8 @@ def split_params(data):
544
  def unsplit_params(data):
545
  flat = {}
546
  for k in ["standard", "scanned_encoder", "scanned_decoder"]:
547
- flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
 
548
  return freeze(traverse_util.unflatten_dict(flat))
549
 
550
 
@@ -1483,7 +1486,7 @@ def main():
1483
  logger.info(" Ready to start training")
1484
  with mesh:
1485
  for epoch in epochs:
1486
- state.replace(epoch=epoch)
1487
  local_state["epoch"] = epoch
1488
  # ======================== Training ================================
1489
  metrics_logger.update_state_metrics(local_state)
 
536
  split["scanned_decoder"][k] = v
537
  else:
538
  split["standard"][k] = v
539
+ # remove empty keys
540
+ split = {k: v for k, v in split.items() if v}
541
  for k, v in split.items():
542
  split[k] = freeze(traverse_util.unflatten_dict(v))
543
  return split
 
546
  def unsplit_params(data):
547
  flat = {}
548
  for k in ["standard", "scanned_encoder", "scanned_decoder"]:
549
+ if k in data:
550
+ flat.update(traverse_util.flatten_dict(unfreeze(data[k])))
551
  return freeze(traverse_util.unflatten_dict(flat))
552
 
553
 
 
1486
  logger.info(" Ready to start training")
1487
  with mesh:
1488
  for epoch in epochs:
1489
+ state = state.replace(epoch=epoch)
1490
  local_state["epoch"] = epoch
1491
  # ======================== Training ================================
1492
  metrics_logger.update_state_metrics(local_state)