feat: allow eval only
tools/train/train.py (+87 -70)
@@ -525,6 +525,8 @@ class TrainingArguments:
             self.per_device_eval_batch_size = self.per_device_train_batch_size
         if self.log_norm_steps is True:
             self.log_norm_steps = self.logging_steps
+        if not self.do_train:
+            self.num_train_epochs = 1
         if (
             os.path.exists(self.output_dir)
             and os.listdir(self.output_dir)
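This first hunk is what makes an eval-only run work end to end: when do_train is off, num_train_epochs collapses to 1, so the epoch loop below still makes exactly one pass for evaluation and checkpointing. A minimal sketch of that behaviour, using a hypothetical stripped-down stand-in rather than the project's actual TrainingArguments:

# Hedged sketch: simplified stand-in class showing how an eval-only run
# collapses to a single epoch (names and defaults are invented).
from dataclasses import dataclass


@dataclass
class MiniTrainingArguments:
    do_train: bool = True
    num_train_epochs: int = 3

    def __post_init__(self):
        # mirrors the two added lines: eval-only implies one pass
        if not self.do_train:
            self.num_train_epochs = 1


args = MiniTrainingArguments(do_train=False, num_train_epochs=10)
assert args.num_train_epochs == 1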
@@ -1354,6 +1356,8 @@ def main():
     # init variables
     start_time = time.perf_counter() - local_state["train_time"]
     train_metrics = None
+    evaluation_ran = False
+    save_model_ran = False
     metrics_logger = MetricsLogger(local_state["step"])
     epochs = tqdm(
         range(local_state["epoch"], num_epochs),
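The two flags are initialised here, before the epoch loop, so they are always defined when the end-of-epoch checks (if not evaluation_ran / if not save_model_ran) run, including in eval-only mode where the whole training branch is skipped. A tiny stand-alone illustration of that scoping concern, with made-up names:

# Hypothetical example: the flag must exist even when the training branch
# never executes, otherwise the epoch-end check would raise NameError.
def end_of_epoch_action(do_train: bool) -> str:
    evaluation_ran = False  # pre-initialised, as in the diff
    if do_train:
        evaluation_ran = True  # a training step may have just evaluated
    return "skip final eval" if evaluation_ran else "run final eval"


print(end_of_epoch_action(do_train=False))  # run final eval
print(end_of_epoch_action(do_train=True))   # skip final eval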
@@ -1532,85 +1536,98 @@ def main():
             metrics_logger.update_state_metrics(local_state)
             metrics_logger.log({})

-            # load data - may be replicated on multiple nodes
-            node_groups = max(1, training_args.mp_devices // jax.local_device_count())
-            loader_bs = batch_size_per_node * node_groups
-            train_loader = dataset.dataloader(
-                "train",
-                loader_bs,
-                epoch,
-            )
-            # train
-            for batch in tqdm(
-                train_loader,
-                desc="Training...",
-                position=1,
-                leave=False,
-                total=steps_per_epoch,
-                disable=jax.process_index() > 0,
-            ):
-                # calculate delta time (we have a lag of one step but it's ok)
-                train_time = time.perf_counter() - start_time
-
-                # set correct shape to batch
-                # - add grad_step dim if gradient_accumulation_steps > 1
-                bs_shape = (
-                    (batch_size_per_node_per_grad_step * node_groups,)
-                    if not use_vmap_trick
-                    else (
-                        jax.local_device_count()
-                        * node_groups
-                        // training_args.mp_devices,  # local dp devices
-                        training_args.per_device_train_batch_size,
-                    )
-                )
-                if training_args.gradient_accumulation_steps > 1:
-                    # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
-                    # to avoid any data redistribution when sharding
-                    bs_shape = (training_args.gradient_accumulation_steps,) + bs_shape
-
-                # reshape batch
-                batch = jax.tree_map(
-                    lambda x: x.reshape(bs_shape + x.shape[1:]),
-                    batch,
-                )
-                # freeze batch to pass safely to jax transforms
-                batch = freeze(batch)
-
-                # train step
-                state, train_metrics = p_train_step(state, batch, train_time)
-                local_state["step"] += 1
-                local_state["train_time"] = train_time
-                local_state["train_samples"] += batch_size_per_step
-
-                if (
-                    local_state["step"] % training_args.logging_steps == 0
-                    and jax.process_index() == 0
-                ):
-                    metrics_logger.update_state_metrics(local_state)
-                    metrics_logger.log(train_metrics, prefix="train")
-
-                eval_metrics = None
-                if local_state["step"] % training_args.eval_steps == 0:
-                    eval_metrics = run_evaluation()
-
-                if local_state["step"] % training_args.save_steps == 0:
-                    run_save_model(state, eval_metrics)
-
-            # log final train metrics
-            if train_metrics is not None:
-                metrics_logger.update_state_metrics(state)
-                metrics_logger.log(train_metrics, prefix="train")
-
-                epochs.write(
-                    f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
-                )
-
-            # Final evaluation
-            eval_metrics = run_evaluation()
-
-            # save checkpoint after each epoch
-            run_save_model(state, eval_metrics)
+            if training_args.do_train:
+                # load data - may be replicated on multiple nodes
+                node_groups = max(
+                    1, training_args.mp_devices // jax.local_device_count()
+                )
+                loader_bs = batch_size_per_node * node_groups
+                train_loader = dataset.dataloader(
+                    "train",
+                    loader_bs,
+                    epoch,
+                )
+                # train
+                for batch in tqdm(
+                    train_loader,
+                    desc="Training...",
+                    position=1,
+                    leave=False,
+                    total=steps_per_epoch,
+                    disable=jax.process_index() > 0,
+                ):
+                    # calculate delta time (we have a lag of one step but it's ok)
+                    train_time = time.perf_counter() - start_time
+
+                    # reset control variables
+                    evaluation_ran = False
+                    save_model_ran = False
+
+                    # set correct shape to batch
+                    # - add grad_step dim if gradient_accumulation_steps > 1
+                    bs_shape = (
+                        (batch_size_per_node_per_grad_step * node_groups,)
+                        if not use_vmap_trick
+                        else (
+                            jax.local_device_count()
+                            * node_groups
+                            // training_args.mp_devices,  # local dp devices
+                            training_args.per_device_train_batch_size,
+                        )
+                    )
+                    if training_args.gradient_accumulation_steps > 1:
+                        # reshape data into (gradient_accumulation_steps, batch_per_node, ...)
+                        # to avoid any data redistribution when sharding
+                        bs_shape = (
+                            training_args.gradient_accumulation_steps,
+                        ) + bs_shape
+
+                    # reshape batch
+                    batch = jax.tree_map(
+                        lambda x: x.reshape(bs_shape + x.shape[1:]),
+                        batch,
+                    )
+                    # freeze batch to pass safely to jax transforms
+                    batch = freeze(batch)
+
+                    # train step
+                    state, train_metrics = p_train_step(state, batch, train_time)
+                    local_state["step"] += 1
+                    local_state["train_time"] = train_time
+                    local_state["train_samples"] += batch_size_per_step
+
+                    if (
+                        local_state["step"] % training_args.logging_steps == 0
+                        and jax.process_index() == 0
+                    ):
+                        metrics_logger.update_state_metrics(local_state)
+                        metrics_logger.log(train_metrics, prefix="train")
+
+                    eval_metrics = None
+                    if local_state["step"] % training_args.eval_steps == 0:
+                        eval_metrics = run_evaluation()
+                        evaluation_ran = True
+
+                    if local_state["step"] % training_args.save_steps == 0:
+                        run_save_model(state, eval_metrics)
+                        save_model_ran = True
+
+                # log final train metrics
+                if train_metrics is not None:
+                    metrics_logger.update_state_metrics(state)
+                    metrics_logger.log(train_metrics, prefix="train")
+
+                    epochs.write(
+                        f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metrics['loss']}, Learning Rate: {train_metrics['learning_rate']})"
+                    )
+
+            # Final evaluation at the end of each epoch
+            if not evaluation_ran:
+                eval_metrics = run_evaluation()
+
+            # save checkpoint after each epoch
+            if not save_model_ran:
+                run_save_model(state, eval_metrics)


 if __name__ == "__main__":
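Taken together, this hunk gates the whole training loop on do_train while keeping evaluation and checkpointing at the end of each epoch, and uses the two flags to avoid repeating them when the last training step already triggered them. A self-contained sketch of that control flow, with invented step counts and prints standing in for run_evaluation / run_save_model:

# Hedged sketch of the epoch control flow introduced by this commit
# (step counts are made up; prints stand in for the real eval/save helpers).
def run_epoch(do_train: bool, num_steps: int, eval_steps: int, save_steps: int):
    evaluation_ran = False
    save_model_ran = False
    if do_train:
        for step in range(1, num_steps + 1):
            # reset control variables at every step
            evaluation_ran = False
            save_model_ran = False
            # ... the actual train step would run here ...
            if step % eval_steps == 0:
                print(f"eval at step {step}")
                evaluation_ran = True
            if step % save_steps == 0:
                print(f"save at step {step}")
                save_model_ran = True
    # end of epoch: only do what the last step did not already do
    if not evaluation_ran:
        print("final evaluation")
    if not save_model_ran:
        print("save checkpoint")


run_epoch(do_train=False, num_steps=0, eval_steps=1, save_steps=1)  # eval-only run
run_epoch(do_train=True, num_steps=4, eval_steps=2, save_steps=4)   # no duplicate eval/save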
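For context on the batch handling inside the loop (re-indented here but otherwise unchanged): with gradient accumulation enabled, the flat loader batch is reshaped so its leading axis becomes (gradient_accumulation_steps, batch_per_node, ...), which avoids redistributing data when sharding. A small illustration of that reshape with invented sizes, applied over a batch pytree:

# Hedged illustration of the bs_shape reshape (sizes are made up).
import jax
import numpy as np

# a flat loader batch of 8 examples, as a pytree (dict) of arrays
batch = {"input_ids": np.arange(8 * 16).reshape(8, 16)}

grad_acc_steps, per_step_bs = 2, 4
bs_shape = (grad_acc_steps, per_step_bs)

# same reshape pattern as in the diff, applied to every leaf of the pytree
batch = jax.tree_util.tree_map(lambda x: x.reshape(bs_shape + x.shape[1:]), batch)
print(batch["input_ids"].shape)  # (2, 4, 16)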