import glob
import os

import hydra
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

from v2a_model import V2AModel
from lib.datasets import create_dataset


@hydra.main(config_path="confs", config_name="base")
def main(opt):
    pl.seed_everything(42)
    print("Working dir:", os.getcwd())

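    # Persist checkpoints under checkpoints/ and stream metrics to Weights & Biases.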
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints/",
        filename="{epoch:04d}-{loss}",
        save_on_train_epoch_end=True,
        save_last=True,
    )
    logger = WandbLogger(project=opt.project_name, name=f"{opt.exp}/{opt.run}")

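    # Single-GPU Trainer (PyTorch Lightning 1.x style: gpus=1 alongside accelerator="gpu").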
    trainer = pl.Trainer(
        gpus=1,
        accelerator="gpu",
        callbacks=[checkpoint_callback],
        max_epochs=8000,
        check_val_every_n_epoch=50,
        logger=logger,
        log_every_n_steps=1,
        num_sanity_val_steps=0,
    )

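    # Evaluate on the test split using the lexicographically last checkpoint in
    # checkpoints/ (raises IndexError if no checkpoint has been saved yet).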
    model = V2AModel(opt)
    checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
    testset = create_dataset(opt.dataset.metainfo, opt.dataset.test)

    trainer.test(model, testset, ckpt_path=checkpoint)


if __name__ == '__main__':
    main()