import torch.nn as nn
import torch
import numpy as np

from .embedders import get_embedder


class ImplicitNet(nn.Module):
    """Softplus MLP mapping query points (plus an optional condition vector)
    to ``opt.d_out + opt.feature_vector_size`` output channels.

    Channel 0 is the quantity differentiated by :meth:`gradient`; the last
    ``opt.feature_vector_size`` channels are a feature vector. Layers are
    stored as attributes ``lin0``, ``lin1``, ...; the raw (or embedded) input
    is re-concatenated before every layer index listed in ``opt.skip_in``.
    """

    def __init__(self, opt):
        """Build the layers from the config object ``opt``.

        Reads ``opt.d_in``, ``opt.dims``, ``opt.d_out``,
        ``opt.feature_vector_size``, ``opt.skip_in``, ``opt.multires``,
        ``opt.embedder_mode``, ``opt.cond``, ``opt.dim_frame_encoding`` (only
        when ``opt.cond == 'frame'``), ``opt.init``, ``opt.bias`` (only when
        ``opt.init == 'geometry'``) and ``opt.weight_norm``.
        """
        super().__init__()

        # Layer widths: input, hidden layers, then output + feature channels.
        dims = [opt.d_in] + list(opt.dims) + [opt.d_out + opt.feature_vector_size]
        self.num_layers = len(dims)
        self.skip_in = opt.skip_in
        self.embed_fn = None
        self.opt = opt

        if opt.multires > 0:
            embed_fn, input_ch = get_embedder(opt.multires, input_dims=opt.d_in, mode=opt.embedder_mode)
            self.embed_fn = embed_fn
            # The first layer consumes the embedded input, not the raw one.
            dims[0] = input_ch

        # Conditioning: 'smpl' and 'frame' concatenate a per-sample vector to
        # the input of layer 0; 'none' disables conditioning.
        # NOTE(review): any other value leaves cond_layer/cond_dim undefined,
        # which only surfaces as an error inside forward().
        self.cond = opt.cond
        if self.cond == 'smpl':
            self.cond_layer = [0]
            # presumably the SMPL body pose without global orientation — confirm
            self.cond_dim = 69
        elif self.cond == 'frame':
            self.cond_layer = [0]
            self.cond_dim = opt.dim_frame_encoding

        # Optional linear projection of the condition vector. It is hard-wired
        # off (0), so lin_p0 is never created with the current setting.
        self.dim_pose_embed = 0
        if self.dim_pose_embed > 0:
            self.lin_p0 = nn.Linear(self.cond_dim, self.dim_pose_embed)
            self.cond_dim = self.dim_pose_embed

        for layer in range(self.num_layers - 1):
            # A layer that feeds a skip layer leaves room for the input that
            # forward() re-concatenates before the skip layer.
            if layer + 1 in self.skip_in:
                out_dim = dims[layer + 1] - dims[0]
            else:
                out_dim = dims[layer + 1]

            if self.cond != 'none' and layer in self.cond_layer:
                lin = nn.Linear(dims[layer] + self.cond_dim, out_dim)
            else:
                lin = nn.Linear(dims[layer], out_dim)

            if opt.init == 'geometry':
                # Geometric initialisation (SAL/IDR-style SDF init): the last
                # layer starts close to a sphere-like field offset by -opt.bias.
                if layer == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[layer]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -opt.bias)
                elif opt.multires > 0 and layer == 0:
                    # Only the first 3 input columns (raw coordinates) start
                    # with non-zero weights; embedding/condition columns are
                    # zeroed.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif opt.multires > 0 and layer in self.skip_in:
                    # Zero the weights of the re-concatenated embedding
                    # columns (all but the 3 raw coordinates).
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if opt.init == 'zero':
                # Near-zero output layer; hidden layers keep PyTorch defaults.
                init_val = 1e-5
                if layer == self.num_layers - 2:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.uniform_(lin.weight, -init_val, init_val)

            if opt.weight_norm:
                # Legacy weight_norm keeps the weight_g/weight_v state-dict keys
                # that existing checkpoints rely on.
                lin = nn.utils.weight_norm(lin)

            setattr(self, f"lin{layer}", lin)

        self.softplus = nn.Softplus(beta=100)

    def forward(self, input, cond, current_epoch=None):
        """Evaluate the network at ``input`` points.

        Args:
            input: points of shape (N, D) or (B, N, D); a 2-D tensor is
                treated as a single batch (B = 1).
            cond: mapping indexed by ``self.cond`` (ignored when
                ``self.cond == 'none'``). ``cond[self.cond]`` must be 2-D
                (B, C); it is broadcast to every point of its batch.
            current_epoch: accepted for interface compatibility; unused.

        Returns:
            Tensor of shape (B, N, d_out + feature_vector_size). If
            B * N == 0, the (possibly unsqueezed) input is returned unchanged,
            so its last dimension is the input width, not the output width.
        """
        if input.ndim == 2:
            input = input.unsqueeze(0)

        num_batch, num_point, num_dim = input.shape

        if num_batch * num_point == 0:
            return input

        input = input.reshape(num_batch * num_point, num_dim)

        if self.cond != 'none':
            num_batch, num_cond = cond[self.cond].shape
            # Repeat the per-batch condition for every point, then flatten to
            # match the flattened points.
            input_cond = cond[self.cond].unsqueeze(1).expand(num_batch, num_point, num_cond)
            input_cond = input_cond.reshape(num_batch * num_point, num_cond)
            if self.dim_pose_embed:
                input_cond = self.lin_p0(input_cond)

        if self.embed_fn is not None:
            input = self.embed_fn(input)

        x = input
        for layer in range(self.num_layers - 1):
            lin = getattr(self, f"lin{layer}")
            if self.cond != 'none' and layer in self.cond_layer:
                x = torch.cat([x, input_cond], dim=-1)
            if layer in self.skip_in:
                # Skip connection; 1/sqrt(2) keeps the activation scale stable.
                x = torch.cat([x, input], 1) / np.sqrt(2)
            x = lin(x)
            if layer < self.num_layers - 2:
                x = self.softplus(x)

        return x.reshape(num_batch, num_point, -1)

    def gradient(self, x, cond):
        """Return d(output channel 0) / d(x), keeping the graph.

        Args:
            x: query points, (N, D) or (B, N, D). ``requires_grad`` is
                enabled on it in place.
            cond: conditioning mapping, as for :meth:`forward`.

        Returns:
            Gradient with the shape of ``x`` plus a singleton axis inserted at
            dim 1, e.g. (N, 1, D) for a 2-D ``x``. ``create_graph=True`` keeps
            it differentiable (e.g. for gradient-based losses).
        """
        x.requires_grad_(True)
        # forward() always returns (B, N, C): slice the channel axis (last),
        # not the point axis, so every point contributes its channel-0 value.
        y = self.forward(x, cond)[..., :1]
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients.unsqueeze(1)


class RenderingNet(nn.Module):
    """ReLU MLP with a sigmoid output, conditioned according to ``opt.mode``.

    Supported modes (checked in :meth:`forward`):
      * ``'nerf_frame_encoding'``: input is [view_dirs (optionally embedded),
        frame_latent_code, feature_vectors].
      * ``'pose'``: input is [points, normals, body pose projected to 8 dims,
        feature_vectors].
    """

    def __init__(self, opt):
        """Build the layers from ``opt``.

        Reads ``opt.mode``, ``opt.d_in``, ``opt.feature_vector_size``,
        ``opt.dims``, ``opt.d_out``, ``opt.multires_view``,
        ``opt.dim_frame_encoding`` (for 'nerf_frame_encoding' mode) and
        ``opt.weight_norm``.
        """
        super().__init__()

        self.mode = opt.mode
        dims = [opt.d_in + opt.feature_vector_size] + list(opt.dims) + [opt.d_out]

        self.embedview_fn = None
        if opt.multires_view > 0:
            embedview_fn, input_ch = get_embedder(opt.multires_view)
            self.embedview_fn = embedview_fn
            # Presumably input_ch includes the raw 3-D direction already
            # counted in opt.d_in — confirm against the embedder.
            # NOTE(review): the width grows for every mode, but forward() only
            # applies the embedder in 'nerf_frame_encoding' mode.
            dims[0] += (input_ch - 3)
        if self.mode == 'nerf_frame_encoding':
            dims[0] += opt.dim_frame_encoding
        if self.mode == 'pose':
            self.dim_cond_embed = 8
            self.cond_dim = 69  # dimension of the body pose, global orientation excluded.
            # lower the condition dimension
            # NOTE(review): dims[0] is not widened here; opt.d_in presumably
            # already counts points, normals and the 8-d pose embedding.
            self.lin_pose = torch.nn.Linear(self.cond_dim, self.dim_cond_embed)

        self.num_layers = len(dims)
        for layer in range(self.num_layers - 1):
            lin = nn.Linear(dims[layer], dims[layer + 1])
            if opt.weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, f"lin{layer}", lin)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, points, normals, view_dirs, body_pose, feature_vectors, frame_latent_code=None):
        """Run the MLP and return its sigmoid-activated output (width opt.d_out).

        Args:
            points: used in 'pose' mode; ``points.shape[0]`` is the point count.
            normals: used in 'pose' mode.
            view_dirs: used in 'nerf_frame_encoding' mode; embedded first when
                an embedder was configured.
            body_pose: used in 'pose' mode; projected by ``lin_pose``.
            feature_vectors: per-point features, appended last in both modes.
            frame_latent_code: required in 'nerf_frame_encoding' mode; expanded
                to ``view_dirs.shape[0]`` rows.

        Raises:
            NotImplementedError: if ``self.mode`` is not one of the two modes.
        """
        if self.embedview_fn is not None:
            if self.mode == 'nerf_frame_encoding':
                view_dirs = self.embedview_fn(view_dirs)

        if self.mode == 'nerf_frame_encoding':
            frame_latent_code = frame_latent_code.expand(view_dirs.shape[0], -1)
            rendering_input = torch.cat([view_dirs, frame_latent_code, feature_vectors], dim=-1)
        elif self.mode == 'pose':
            num_points = points.shape[0]
            # NOTE(review): reshape(num_points, -1) assumes body_pose has a
            # single row (batch size 1) — confirm against callers.
            body_pose = body_pose.unsqueeze(1).expand(-1, num_points, -1).reshape(num_points, -1)
            body_pose = self.lin_pose(body_pose)
            rendering_input = torch.cat([points, normals, body_pose, feature_vectors], dim=-1)
        else:
            raise NotImplementedError

        x = rendering_input
        for layer in range(self.num_layers - 1):
            lin = getattr(self, f"lin{layer}")
            x = lin(x)
            if layer < self.num_layers - 2:
                x = self.relu(x)
        return self.sigmoid(x)