|
import torch.nn as nn |
|
import torch |
|
import numpy as np |
|
from .embedders import get_embedder |
|
|
|
class ImplicitNet(nn.Module):
    """Fully connected network with optional input embedding, conditioning and skips.

    The layer widths are ``[opt.d_in] + opt.dims + [opt.d_out +
    opt.feature_vector_size]``.  Layers are stored as attributes ``lin0``,
    ``lin1``, ... and every layer except the last is followed by a
    ``Softplus(beta=100)``.

    Options read from ``opt``: ``d_in``, ``dims``, ``d_out``,
    ``feature_vector_size``, ``skip_in``, ``multires``, ``embedder_mode``
    (only when ``multires > 0``), ``cond``, ``dim_frame_encoding`` (only when
    ``cond == 'frame'``), ``init``, ``bias`` (only for the geometry init),
    and ``weight_norm``.
    """

    def __init__(self, opt):
        """Build the layers described by ``opt``.

        Raises:
            ValueError: if ``opt.cond`` is not ``'smpl'``, ``'frame'`` or
                ``'none'``.
        """
        super().__init__()

        dims = [opt.d_in] + list(opt.dims) + [opt.d_out + opt.feature_vector_size]
        self.num_layers = len(dims)
        self.skip_in = opt.skip_in
        self.embed_fn = None
        self.opt = opt

        if opt.multires > 0:
            embed_fn, input_ch = get_embedder(
                opt.multires, input_dims=opt.d_in, mode=opt.embedder_mode)
            self.embed_fn = embed_fn
            # The first layer consumes the embedded input, not the raw one.
            dims[0] = input_ch

        self.cond = opt.cond
        if self.cond == 'smpl':
            self.cond_layer = [0]
            self.cond_dim = 69
        elif self.cond == 'frame':
            self.cond_layer = [0]
            self.cond_dim = opt.dim_frame_encoding
        elif self.cond != 'none':
            # Fail early with a clear message instead of an AttributeError
            # on ``self.cond_layer`` inside the layer loop below.
            raise ValueError(
                f"Unsupported cond mode {self.cond!r}; "
                "expected 'smpl', 'frame' or 'none'")

        # Optional linear embedding of the condition vector; disabled while
        # the size is 0 (``forward`` checks this attribute as well).
        self.dim_pose_embed = 0
        if self.dim_pose_embed > 0:
            self.lin_p0 = nn.Linear(self.cond_dim, self.dim_pose_embed)
            self.cond_dim = self.dim_pose_embed

        for layer_idx in range(self.num_layers - 1):
            if layer_idx + 1 in self.skip_in:
                # The next layer gets the network input concatenated to its
                # input, so this layer emits correspondingly fewer channels.
                out_dim = dims[layer_idx + 1] - dims[0]
            else:
                out_dim = dims[layer_idx + 1]

            in_dim = dims[layer_idx]
            if self.cond != 'none' and layer_idx in self.cond_layer:
                in_dim += self.cond_dim
            lin = nn.Linear(in_dim, out_dim)

            self._init_layer(lin, layer_idx, dims, out_dim)

            if opt.weight_norm:
                lin = nn.utils.weight_norm(lin)
            setattr(self, "lin" + str(layer_idx), lin)

        self.softplus = nn.Softplus(beta=100)

    def _init_layer(self, lin, layer_idx, dims, out_dim):
        """Apply the ``opt.init`` scheme ('geometry' and/or 'zero') to one layer."""
        opt = self.opt
        is_last = layer_idx == self.num_layers - 2

        if opt.init == 'geometry':
            if is_last:
                torch.nn.init.normal_(lin.weight,
                                      mean=np.sqrt(np.pi) / np.sqrt(dims[layer_idx]),
                                      std=0.0001)
                torch.nn.init.constant_(lin.bias, -opt.bias)
            elif opt.multires > 0 and layer_idx == 0:
                # Only the first three input columns start non-zero.
                torch.nn.init.constant_(lin.bias, 0.0)
                torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                torch.nn.init.normal_(lin.weight[:, :3], 0.0,
                                      np.sqrt(2) / np.sqrt(out_dim))
            elif opt.multires > 0 and layer_idx in self.skip_in:
                torch.nn.init.constant_(lin.bias, 0.0)
                torch.nn.init.normal_(lin.weight, 0.0,
                                      np.sqrt(2) / np.sqrt(out_dim))
                # Zero the trailing ``dims[0] - 3`` input columns.
                torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
            else:
                torch.nn.init.constant_(lin.bias, 0.0)
                torch.nn.init.normal_(lin.weight, 0.0,
                                      np.sqrt(2) / np.sqrt(out_dim))

        if opt.init == 'zero':
            init_val = 1e-5
            if is_last:
                torch.nn.init.constant_(lin.bias, 0.0)
                torch.nn.init.uniform_(lin.weight, -init_val, init_val)

    def forward(self, input, cond, current_epoch=None):
        """Evaluate the network.

        Args:
            input: tensor of shape ``(num_point, num_dim)`` or
                ``(num_batch, num_point, num_dim)``; a 2-D input is treated
                as a batch of one.
            cond: mapping indexed by ``self.cond`` that yields a 2-D tensor
                ``(num_batch, num_cond)``.  Not accessed when
                ``self.cond == 'none'``.
            current_epoch: unused.

        Returns:
            Tensor of shape ``(num_batch, num_point, C)`` where ``C`` is the
            width of the last layer.  If the input holds no points it is
            returned unchanged (as a 3-D tensor).
        """
        if input.ndim == 2:
            input = input.unsqueeze(0)

        num_batch, num_point, num_dim = input.shape

        if num_batch * num_point == 0:
            return input

        input = input.reshape(num_batch * num_point, num_dim)

        if self.cond != 'none':
            # Note: the batch size is re-read from the condition tensor.
            num_batch, num_cond = cond[self.cond].shape

            # Repeat the per-batch condition for every point, then flatten.
            input_cond = cond[self.cond].unsqueeze(1).expand(num_batch, num_point, num_cond)
            input_cond = input_cond.reshape(num_batch * num_point, num_cond)

            if self.dim_pose_embed:
                input_cond = self.lin_p0(input_cond)

        if self.embed_fn is not None:
            input = self.embed_fn(input)

        x = input

        for layer_idx in range(self.num_layers - 1):
            lin = getattr(self, "lin" + str(layer_idx))
            if self.cond != 'none' and layer_idx in self.cond_layer:
                x = torch.cat([x, input_cond], dim=-1)
            if layer_idx in self.skip_in:
                x = torch.cat([x, input], 1) / np.sqrt(2)
            x = lin(x)
            if layer_idx < self.num_layers - 2:
                x = self.softplus(x)

        x = x.reshape(num_batch, num_point, -1)

        return x

    def gradient(self, x, cond):
        """Return d(output channel 0)/d(x) for every input point.

        ``x`` is switched to ``requires_grad`` in place.  The graph is kept
        (``create_graph=True``) so the result can itself be differentiated.

        Returns:
            The gradient, shaped like ``x`` with an extra axis inserted at
            position 1.
        """
        x.requires_grad_(True)
        # ``forward`` returns (num_batch, num_point, C); the channel to
        # differentiate is the first entry of the LAST axis.  Slicing with
        # ``[:, :1]`` would pick the first point instead and leave every
        # other point with a zero gradient.
        y = self.forward(x, cond)[..., :1]
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(outputs=y,
                                        inputs=x,
                                        grad_outputs=d_output,
                                        create_graph=True,
                                        retain_graph=True)[0]
        return gradients.unsqueeze(1)
|
|
|
|
|
class RenderingNet(nn.Module):
    """MLP with ReLU hidden activations and a sigmoid on the output.

    Layer widths are ``[opt.d_in + opt.feature_vector_size] + opt.dims +
    [opt.d_out]``; the first width is enlarged for the view embedding and for
    the frame encoding as described in ``__init__``.  Layers are stored as
    attributes ``lin0``, ``lin1``, ...
    """

    def __init__(self, opt):
        """Create the layers from ``opt``.

        Options read: ``mode``, ``d_in``, ``feature_vector_size``, ``dims``,
        ``d_out``, ``multires_view``, ``weight_norm`` and, in
        ``'nerf_frame_encoding'`` mode, ``dim_frame_encoding``.
        """
        super().__init__()

        self.mode = opt.mode
        widths = [opt.d_in + opt.feature_vector_size, *opt.dims, opt.d_out]

        self.embedview_fn = None
        if opt.multires_view > 0:
            self.embedview_fn, embedded_ch = get_embedder(opt.multires_view)
            # Account for the channels the embedding adds beyond the first 3.
            widths[0] += embedded_ch - 3

        if self.mode == 'nerf_frame_encoding':
            widths[0] += opt.dim_frame_encoding

        if self.mode == 'pose':
            self.dim_cond_embed = 8
            self.cond_dim = 69
            # Projects the 69-dim pose vector down to 8 channels.
            self.lin_pose = torch.nn.Linear(self.cond_dim, self.dim_cond_embed)

        self.num_layers = len(widths)
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            layer = nn.Linear(fan_in, fan_out)
            if opt.weight_norm:
                layer = nn.utils.weight_norm(layer)
            setattr(self, f"lin{idx}", layer)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, points, normals, view_dirs, body_pose, feature_vectors, frame_latent_code=None):
        """Run the MLP on the inputs selected by ``self.mode``.

        * ``'nerf_frame_encoding'``: concatenates ``view_dirs`` (embedded
          first when an embedder is configured), ``frame_latent_code``
          expanded to ``view_dirs.shape[0]`` rows, and ``feature_vectors``.
        * ``'pose'``: concatenates ``points``, ``normals``, the linearly
          projected ``body_pose`` and ``feature_vectors``.

        Raises:
            NotImplementedError: for any other mode.
        """
        if self.mode == 'nerf_frame_encoding':
            if self.embedview_fn is not None:
                view_dirs = self.embedview_fn(view_dirs)
            row_count = view_dirs.shape[0]
            frame_latent_code = frame_latent_code.expand(row_count, -1)
            pieces = [view_dirs, frame_latent_code, feature_vectors]
        elif self.mode == 'pose':
            num_points = points.shape[0]
            # Repeat the pose for every point, then flatten to one row per point.
            tiled_pose = body_pose.unsqueeze(1).expand(-1, num_points, -1)
            body_pose = self.lin_pose(tiled_pose.reshape(num_points, -1))
            pieces = [points, normals, body_pose, feature_vectors]
        else:
            raise NotImplementedError

        hidden = torch.cat(pieces, dim=-1)

        final_idx = self.num_layers - 2
        for idx in range(self.num_layers - 1):
            hidden = getattr(self, f"lin{idx}")(hidden)
            if idx != final_idx:
                hidden = self.relu(hidden)

        return self.sigmoid(hidden)
|
|