IF3D / code /lib /model /networks.py
leobcc's picture
vid2avatar baseline
6325697
raw
history blame
7.21 kB
import torch.nn as nn
import torch
import numpy as np
from .embedders import get_embedder
class ImplicitNet(nn.Module):
    """Fully connected network over (optionally embedded) input points.

    The layer widths are ``[d_in] + opt.dims + [d_out + feature_vector_size]``.
    When ``opt.multires > 0`` the input is first passed through the embedder
    returned by ``get_embedder`` and the first width becomes the embedder's
    output size.  A conditioning vector can be concatenated to the input of
    the layers listed in ``self.cond_layer`` (layer 0 for both supported
    conditioning modes).

    Fields read from ``opt``: ``d_in``, ``dims``, ``d_out``,
    ``feature_vector_size``, ``skip_in``, ``multires``, ``embedder_mode``
    (only if ``multires > 0``), ``cond``, ``dim_frame_encoding`` (only if
    ``cond == 'frame'``), ``init``, ``bias`` (only for geometry init) and
    ``weight_norm``.
    """

    def __init__(self, opt):
        """Build and initialise the linear layers ``lin0 .. lin{K-1}``.

        Raises:
            ValueError: if ``opt.cond`` is not ``'smpl'``, ``'frame'`` or
                ``'none'``.
        """
        super().__init__()

        dims = [opt.d_in] + list(opt.dims) + [opt.d_out + opt.feature_vector_size]
        self.num_layers = len(dims)
        self.skip_in = opt.skip_in
        self.embed_fn = None
        self.opt = opt

        if opt.multires > 0:
            embed_fn, input_ch = get_embedder(opt.multires, input_dims=opt.d_in, mode=opt.embedder_mode)
            self.embed_fn = embed_fn
            dims[0] = input_ch

        self.cond = opt.cond
        if self.cond == 'smpl':
            self.cond_layer = [0]
            # 69-d conditioning vector; presumably the SMPL body pose without
            # global orientation -- confirm against the caller.
            self.cond_dim = 69
        elif self.cond == 'frame':
            self.cond_layer = [0]
            self.cond_dim = opt.dim_frame_encoding
        elif self.cond != 'none':
            # Previously this surfaced later as an AttributeError on
            # ``self.cond_layer``; fail early with an explicit message instead.
            raise ValueError(
                f"Unsupported cond {self.cond!r}; expected 'smpl', 'frame' or 'none'"
            )

        # Optional learned projection of the conditioning vector.  It is
        # hard-coded off (0), so the branch below is currently never taken.
        self.dim_pose_embed = 0
        if self.dim_pose_embed > 0:
            self.lin_p0 = nn.Linear(self.cond_dim, self.dim_pose_embed)
            self.cond_dim = self.dim_pose_embed

        for l in range(0, self.num_layers - 1):
            # A layer feeding a skip connection outputs fewer channels so that
            # re-appending the (embedded) input restores the nominal width.
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
            else:
                out_dim = dims[l + 1]

            if self.cond != 'none' and l in self.cond_layer:
                lin = nn.Linear(dims[l] + self.cond_dim, out_dim)
            else:
                lin = nn.Linear(dims[l], out_dim)

            if opt.init == 'geometry':
                # "Geometric" initialisation; presumably chosen so the initial
                # first output channel resembles a sphere SDF with radius
                # ``opt.bias`` -- confirm against the referenced method.
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight,
                                          mean=np.sqrt(np.pi) / np.sqrt(dims[l]),
                                          std=0.0001)
                    torch.nn.init.constant_(lin.bias, -opt.bias)
                elif opt.multires > 0 and l == 0:
                    # Only the first three input columns start non-zero;
                    # every later column (embedding and any conditioning)
                    # starts at zero.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0,
                                          np.sqrt(2) / np.sqrt(out_dim))
                elif opt.multires > 0 and l in self.skip_in:
                    # Zero the trailing ``dims[0] - 3`` columns, i.e. all but
                    # the first three channels of the re-appended input.
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0,
                                          np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0,
                                          np.sqrt(2) / np.sqrt(out_dim))

            if opt.init == 'zero':
                # Only the last layer is touched: near-zero weights, zero bias.
                init_val = 1e-5
                if l == self.num_layers - 2:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.uniform_(lin.weight, -init_val, init_val)

            if opt.weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.softplus = nn.Softplus(beta=100)

    def forward(self, input, cond, current_epoch=None):
        """Evaluate the network.

        Args:
            input: tensor of shape ``(num_point, d_in)`` or
                ``(num_batch, num_point, d_in)``; a 2-D input gets a leading
                batch axis of size 1.
            cond: mapping indexed by ``self.cond`` whose entry is a 2-D tensor
                ``(num_batch, num_cond)``; not accessed when
                ``self.cond == 'none'``.
            current_epoch: unused.

        Returns:
            Tensor of shape ``(num_batch, num_point, C)`` where ``C`` is the
            width of the last layer.  NOTE(review): if the input holds no
            points it is returned as-is (last axis stays ``d_in``).
        """
        if input.ndim == 2:
            input = input.unsqueeze(0)

        num_batch, num_point, num_dim = input.shape

        if num_batch * num_point == 0:
            return input

        input = input.reshape(num_batch * num_point, num_dim)

        if self.cond != 'none':
            # ``num_batch`` is re-read from the conditioning tensor here.
            num_batch, num_cond = cond[self.cond].shape

            # Repeat the per-batch condition for every point, then flatten to
            # line up with the flattened points.
            input_cond = cond[self.cond].unsqueeze(1).expand(num_batch, num_point, num_cond)
            input_cond = input_cond.reshape(num_batch * num_point, num_cond)

            if self.dim_pose_embed:
                input_cond = self.lin_p0(input_cond)

        if self.embed_fn is not None:
            input = self.embed_fn(input)

        x = input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if self.cond != 'none' and l in self.cond_layer:
                x = torch.cat([x, input_cond], dim=-1)

            if l in self.skip_in:
                # Skip connection: re-append the (embedded) input and rescale.
                x = torch.cat([x, input], 1) / np.sqrt(2)

            x = lin(x)

            # Softplus after every layer except the last.
            if l < self.num_layers - 2:
                x = self.softplus(x)

        x = x.reshape(num_batch, num_point, -1)

        return x

    def gradient(self, x, cond):
        """Return the gradient of output channel 0 with respect to ``x``.

        ``x`` is switched to ``requires_grad`` in place.  The result has the
        shape of ``x`` with an extra singleton axis inserted at position 1,
        and is itself differentiable (``create_graph=True``).
        """
        x.requires_grad_(True)
        # forward() returns (batch, point, channel), so channel 0 has to be
        # selected on the LAST axis.  Slicing ``[:, :1]`` would instead keep
        # only the first point and mix all of its channels.
        y = self.forward(x, cond)[..., :1]
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(outputs=y,
                                        inputs=x,
                                        grad_outputs=d_output,
                                        create_graph=True,
                                        retain_graph=True)[0]
        return gradients.unsqueeze(1)
class RenderingNet(nn.Module):
    """ReLU MLP with a sigmoid output.

    The layer widths are
    ``[d_in + feature_vector_size] + opt.dims + [d_out]``.  Two input layouts
    are supported, selected by ``opt.mode``:

    * ``'nerf_frame_encoding'``: input is
      ``[view_dirs, frame_latent_code, feature_vectors]``; the first width is
      widened by ``opt.dim_frame_encoding``.
    * ``'pose'``: input is ``[points, normals, pose_embedding,
      feature_vectors]`` where the 8-d pose embedding comes from a learned
      linear projection of the 69-d body pose.  The first width is NOT
      widened for this mode, so ``opt.d_in`` has to account for points,
      normals and the 8 embedding channels already.

    NOTE(review): when ``opt.multires_view > 0`` the first width grows by
    ``input_ch - 3`` for every mode, but ``forward`` only embeds the view
    directions in ``'nerf_frame_encoding'`` mode -- confirm configs that
    combine ``'pose'`` with a view embedding.
    """

    def __init__(self, opt):
        """Create layers ``lin0 .. lin{K-1}`` from ``opt``.

        Fields read from ``opt``: ``mode``, ``d_in``, ``feature_vector_size``,
        ``dims``, ``d_out``, ``multires_view``, ``dim_frame_encoding`` (only in
        ``'nerf_frame_encoding'`` mode) and ``weight_norm``.
        """
        super().__init__()

        self.mode = opt.mode
        dims = [opt.d_in + opt.feature_vector_size] + list(opt.dims) + [opt.d_out]

        self.embedview_fn = None
        if opt.multires_view > 0:
            embedview_fn, input_ch = get_embedder(opt.multires_view)
            self.embedview_fn = embedview_fn
            # The raw 3 view-direction channels are replaced by the embedding.
            dims[0] += (input_ch - 3)

        if self.mode == 'nerf_frame_encoding':
            dims[0] += opt.dim_frame_encoding

        if self.mode == 'pose':
            self.dim_cond_embed = 8
            self.cond_dim = 69  # dimension of the body pose, global orientation excluded.
            # lower the condition dimension
            self.lin_pose = torch.nn.Linear(self.cond_dim, self.dim_cond_embed)

        self.num_layers = len(dims)

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if opt.weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, points, normals, view_dirs, body_pose, feature_vectors, frame_latent_code=None):
        """Run the MLP on the per-point inputs selected by ``self.mode``.

        Args:
            points, normals: used only in ``'pose'`` mode; ``points.shape[0]``
                is taken as the number of points.
            view_dirs: used only in ``'nerf_frame_encoding'`` mode (embedded
                first when a view embedder is configured).
            body_pose: used only in ``'pose'`` mode.  NOTE(review): the
                expand/reshape below yields ``(num_points, 69)`` only when the
                leading (batch) axis has size 1 -- confirm with callers.
            feature_vectors: per-point features, concatenated last.
            frame_latent_code: required in ``'nerf_frame_encoding'`` mode; it
                is expanded along dim 0 to ``view_dirs.shape[0]`` rows.

        Returns:
            Sigmoid of the last layer's output.

        Raises:
            ValueError: ``'nerf_frame_encoding'`` mode without
                ``frame_latent_code``.
            NotImplementedError: any other unsupported ``self.mode``.
        """
        if self.mode == 'nerf_frame_encoding':
            if frame_latent_code is None:
                # The parameter defaults to None; without this guard the call
                # below fails with an opaque AttributeError on NoneType.
                raise ValueError(
                    "frame_latent_code is required when mode is 'nerf_frame_encoding'"
                )
            if self.embedview_fn is not None:
                view_dirs = self.embedview_fn(view_dirs)
            frame_latent_code = frame_latent_code.expand(view_dirs.shape[0], -1)
            rendering_input = torch.cat([view_dirs, frame_latent_code, feature_vectors], dim=-1)
        elif self.mode == 'pose':
            num_points = points.shape[0]
            body_pose = body_pose.unsqueeze(1).expand(-1, num_points, -1).reshape(num_points, -1)
            body_pose = self.lin_pose(body_pose)
            rendering_input = torch.cat([points, normals, body_pose, feature_vectors], dim=-1)
        else:
            raise NotImplementedError

        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            # ReLU after every layer except the last.
            if l < self.num_layers - 2:
                x = self.relu(x)

        x = self.sigmoid(x)
        return x