|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import collections |
|
import os |
|
import pickle |
|
import warnings |
|
|
|
import hydra |
|
import numpy as np |
|
import torch |
|
from nerf.dataset import get_nerf_datasets, trivial_collate |
|
from nerf.nerf_renderer import RadianceFieldRenderer, visualize_nerf_outputs |
|
from nerf.stats import Stats |
|
from omegaconf import DictConfig |
|
from visdom import Visdom |
|
|
|
|
|
# Absolute directory containing this script, with symlinks resolved, so the
# config lookup does not depend on the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# Directory holding the Hydra YAML configs consumed by `@hydra.main` below.
CONFIG_DIR = os.path.join(_SCRIPT_DIR, "configs")
|
|
|
|
|
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
def main(cfg: DictConfig):
    """Train a ``RadianceFieldRenderer`` (NeRF) as described by a Hydra config.

    The run seeds NumPy/PyTorch, builds the model and an Adam optimizer with an
    exponentially decaying learning rate, optionally resumes from
    ``cfg.checkpoint_path``, and then alternates training epochs with periodic
    validation, optional Visdom visualization and checkpointing.

    Args:
        cfg: Config composed by ``@hydra.main`` from ``CONFIG_DIR`` (default
            config name ``"lego"``). Keys read here: ``seed``, ``resume``,
            ``checkpoint_path``, ``stats_print_interval``,
            ``validation_epoch_interval``, ``checkpoint_epoch_interval`` and the
            ``data``, ``raysampler``, ``implicit_function``, ``optimizer`` and
            ``visualization`` sections.
    """
    # Seed the RNGs for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Pick the device; CPU execution is allowed but warned about.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # The fine pass used to be given `n_pts_per_ray`. Honour a dedicated
    # `n_pts_per_ray_fine` key when the config has one, otherwise keep the
    # previous behaviour. The membership test does not raise on struct configs.
    raysampler_cfg = cfg.raysampler
    if "n_pts_per_ray_fine" in raysampler_cfg:
        n_pts_per_ray_fine = raysampler_cfg.n_pts_per_ray_fine
    else:
        n_pts_per_ray_fine = raysampler_cfg.n_pts_per_ray

    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=raysampler_cfg.n_pts_per_ray,
        n_pts_per_ray_fine=n_pts_per_ray_fine,
        n_rays_per_image=raysampler_cfg.n_rays_per_image,
        min_depth=raysampler_cfg.min_depth,
        max_depth=raysampler_cfg.max_depth,
        stratified=raysampler_cfg.stratified,
        stratified_test=raysampler_cfg.stratified_test,
        chunk_size_test=raysampler_cfg.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
        visualization=cfg.visualization.visdom,
    )
    model.to(device)

    # State restored from a checkpoint, if any; these defaults mean "fresh run".
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    # Hydra may change the working directory, so resolve the checkpoint path
    # against the directory the script was launched from.
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)
    if len(cfg.checkpoint_path) > 0:
        # Make sure the directory that will hold the checkpoint exists.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            # map_location lets a checkpoint saved on a GPU machine be resumed
            # on a CPU-only machine (and vice versa).
            loaded_data = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(loaded_data["model"])
            # NOTE(review): unpickling can execute arbitrary code; only resume
            # from checkpoints produced by a trusted run.
            stats = pickle.loads(loaded_data["stats"])
            print(f" => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)

    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Exponential learning-rate decay, applied per epoch:
    #   lr(epoch) = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # last_epoch=start_epoch - 1 continues the schedule of a resumed run.
    # `verbose` is not passed: False is the default and the argument is
    # deprecated in recent PyTorch releases.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1
    )

    # Bounded history of recent training outputs used for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Hand every train/val camera to the model's ray cache, keyed by
        # camera_idx; the forward passes below pass camera_idx when enabled.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # Random validation views drawn with replacement. A fresh iterator is
    # created at each validation step and only its first sample is used.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

    model.train()

    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()
        for iteration, batch in enumerate(train_dataloader):
            # NOTE(review): unpacking `.values()` relies on the dataset dict's
            # key order (presumably image, camera, camera_idx); confirm in
            # nerf.dataset.
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # Optimise the sum of the coarse and fine MSE terms.
            loss = metrics["mse_coarse"] + metrics["mse_fine"]
            loss.backward()
            optimizer.step()

            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )
            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Keep detached CPU copies so the cache does not pin GPU memory.
            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # The learning rate is decayed once per epoch.
        lr_scheduler.step()

        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
            val_batch = next(iter(val_dataloader))
            val_image, val_camera, val_camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    val_camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            model.train()

        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and len(cfg.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            # Save to a sibling temp file, then atomically swap it in. An
            # interruption mid-save then cannot leave a truncated checkpoint
            # that would break the next resume.
            tmp_checkpoint_path = checkpoint_path + ".tmp"
            torch.save(data_to_store, tmp_checkpoint_path)
            os.replace(tmp_checkpoint_path, checkpoint_path)
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point. `main` is called without arguments because its
    # `@hydra.main` decorator builds and supplies the `cfg` argument.
    main()
|
|