boris committed on
Commit
65bb95f
1 Parent(s): a4d2af8

feat: allow eval only

Browse files
Files changed (1) hide show
  1. tools/train/train.py +87 -70
tools/train/train.py CHANGED
@@ -525,6 +525,8 @@ class TrainingArguments:
525
  self.per_device_eval_batch_size = self.per_device_train_batch_size
526
  if self.log_norm_steps is True:
527
  self.log_norm_steps = self.logging_steps
 
 
528
  if (
529
  os.path.exists(self.output_dir)
530
  and os.listdir(self.output_dir)
@@ -1354,6 +1356,8 @@ def main():
1354
  # init variables
1355
  start_time = time.perf_counter() - local_state["train_time"]
1356
  train_metrics = None
 
 
1357
  metrics_logger = MetricsLogger(local_state["step"])
1358
  epochs = tqdm(
1359
  range(local_state["epoch"], num_epochs),
@@ -1532,85 +1536,98 @@ def main():
1532
  metrics_logger.update_state_metrics(local_state)
1533
  metrics_logger.log({})
1534
 
1535
- # load data - may be replicated on multiple nodes
1536
- node_groups = max(1, training_args.mp_devices // jax.local_device_count())
1537
- loader_bs = batch_size_per_node * node_groups
1538
- train_loader = dataset.dataloader(
1539
- "train",
1540
- loader_bs,
1541
- epoch,
1542
- )
1543
- # train
1544
- for batch in tqdm(
1545
- train_loader,
1546
- desc="Training...",
1547
- position=1,
1548
- leave=False,
1549
- total=steps_per_epoch,
1550
- disable=jax.process_index() > 0,
1551
- ):
1552
- # calculate delta time (we have a lag of one step but it's ok)
1553
- train_time = time.perf_counter() - start_time
1554
-
1555
- # set correct shape to batch
1556
- # - add grad_step dim if gradient_accumulation_steps > 1
1557
- bs_shape = (
1558
- (batch_size_per_node_per_grad_step * node_groups,)
1559
- if not use_vmap_trick
1560
- else (
1561
- jax.local_device_count()
1562
- * node_groups
1563
- // training_args.mp_devices, # local dp devices
1564
- training_args.per_device_train_batch_size,
1565
- )
1566
  )
1567
- if training_args.gradient_accumulation_steps > 1:
1568
- # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1569
- # to avoid any data redistribution when sharding
1570
- bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape
1571
-
1572
- # reshape batch
1573
- batch = jax.tree_map(
1574
- lambda x: x.reshape(bs_shape + x.shape[1:]),
1575
- batch,
1576
  )
1577
- # freeze batch to pass safely to jax transforms
1578
- batch = freeze(batch)
1579
-
1580
- # train step
1581
- state, train_metrics = p_train_step(state, batch, train_time)
1582
- local_state["step"] += 1
1583
- local_state["train_time"] = train_time
1584
- local_state["train_samples"] += batch_size_per_step
1585
-
1586
- if (
1587
- local_state["step"] % training_args.logging_steps == 0
1588
- and jax.process_index() == 0
1589
  ):
1590
- metrics_logger.update_state_metrics(local_state)
1591
- metrics_logger.log(train_metrics, prefix="train")
1592
-
1593
- eval_metrics = None
1594
- if local_state["step"] % training_args.eval_steps == 0:
1595
- eval_metrics = run_evaluation()
1596
 
1597
- if local_state["step"] % training_args.save_steps == 0:
1598
- run_save_model(state, eval_metrics)
 
1599
 
1600
- # log final train metrics
1601
- if train_metrics is not None:
1602
- metrics_logger.update_state_metrics(state)
1603
- metrics_logger.log(train_metrics, prefix="train")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1604
 
1605
- epochs.write(
1606
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1607
- )
1608
 
1609
- # Final evaluation
1610
- eval_metrics = run_evaluation()
 
1611
 
1612
  # save checkpoint after each epoch
1613
- run_save_model(state, eval_metrics)
 
1614
 
1615
 
1616
  if __name__ == "__main__":
 
525
  self.per_device_eval_batch_size = self.per_device_train_batch_size
526
  if self.log_norm_steps is True:
527
  self.log_norm_steps = self.logging_steps
528
+ if not self.do_train:
529
+ self.num_train_epochs = 1
530
  if (
531
  os.path.exists(self.output_dir)
532
  and os.listdir(self.output_dir)
 
1356
  # init variables
1357
  start_time = time.perf_counter() - local_state["train_time"]
1358
  train_metrics = None
1359
+ evaluation_ran = False
1360
+ save_model_ran = False
1361
  metrics_logger = MetricsLogger(local_state["step"])
1362
  epochs = tqdm(
1363
  range(local_state["epoch"], num_epochs),
 
1536
  metrics_logger.update_state_metrics(local_state)
1537
  metrics_logger.log({})
1538
 
1539
+ if training_args.do_train:
1540
+ # load data - may be replicated on multiple nodes
1541
+ node_groups = max(
1542
+ 1, training_args.mp_devices // jax.local_device_count()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1543
  )
1544
+ loader_bs = batch_size_per_node * node_groups
1545
+ train_loader = dataset.dataloader(
1546
+ "train",
1547
+ loader_bs,
1548
+ epoch,
 
 
 
 
1549
  )
1550
+ # train
1551
+ for batch in tqdm(
1552
+ train_loader,
1553
+ desc="Training...",
1554
+ position=1,
1555
+ leave=False,
1556
+ total=steps_per_epoch,
1557
+ disable=jax.process_index() > 0,
 
 
 
 
1558
  ):
1559
+ # calculate delta time (we have a lag of one step but it's ok)
1560
+ train_time = time.perf_counter() - start_time
 
 
 
 
1561
 
1562
+ # reset control variables
1563
+ evaluation_ran = False
1564
+ save_model_ran = False
1565
 
1566
+ # set correct shape to batch
1567
+ # - add grad_step dim if gradient_accumulation_steps > 1
1568
+ bs_shape = (
1569
+ (batch_size_per_node_per_grad_step * node_groups,)
1570
+ if not use_vmap_trick
1571
+ else (
1572
+ jax.local_device_count()
1573
+ * node_groups
1574
+ // training_args.mp_devices, # local dp devices
1575
+ training_args.per_device_train_batch_size,
1576
+ )
1577
+ )
1578
+ if training_args.gradient_accumulation_steps > 1:
1579
+ # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
1580
+ # to avoid any data redistribution when sharding
1581
+ bs_shape = (
1582
+ training_args.gradient_accumulation_steps,
1583
+ ) + bs_shape
1584
+
1585
+ # reshape batch
1586
+ batch = jax.tree_map(
1587
+ lambda x: x.reshape(bs_shape + x.shape[1:]),
1588
+ batch,
1589
+ )
1590
+ # freeze batch to pass safely to jax transforms
1591
+ batch = freeze(batch)
1592
+
1593
+ # train step
1594
+ state, train_metrics = p_train_step(state, batch, train_time)
1595
+ local_state["step"] += 1
1596
+ local_state["train_time"] = train_time
1597
+ local_state["train_samples"] += batch_size_per_step
1598
+
1599
+ if (
1600
+ local_state["step"] % training_args.logging_steps == 0
1601
+ and jax.process_index() == 0
1602
+ ):
1603
+ metrics_logger.update_state_metrics(local_state)
1604
+ metrics_logger.log(train_metrics, prefix="train")
1605
+
1606
+ eval_metrics = None
1607
+ if local_state["step"] % training_args.eval_steps == 0:
1608
+ eval_metrics = run_evaluation()
1609
+ evaluation_ran = True
1610
+
1611
+ if local_state["step"] % training_args.save_steps == 0:
1612
+ run_save_model(state, eval_metrics)
1613
+ save_model_ran = True
1614
+
1615
+ # log final train metrics
1616
+ if train_metrics is not None:
1617
+ metrics_logger.update_state_metrics(state)
1618
+ metrics_logger.log(train_metrics, prefix="train")
1619
 
1620
+ epochs.write(
1621
+ f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
1622
+ )
1623
 
1624
+ # Final evaluation at the end of each epoch
1625
+ if not evaluation_ran:
1626
+ eval_metrics = run_evaluation()
1627
 
1628
  # save checkpoint after each epoch
1629
+ if not save_model_ran:
1630
+ run_save_model(state, eval_metrics)
1631
 
1632
 
1633
  if __name__ == "__main__":