from .networks import ImplicitNet, RenderingNet
from .density import LaplaceDensity, AbsDensity
from .ray_sampler import ErrorBoundSampler
from .deformer import SMPLDeformer
from .smpl import SMPLServer
from .sampler import PointInSpace
from ..utils import utils

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad
import hydra
import kaolin
from kaolin.ops.mesh import index_vertices_by_faces


class V2A(nn.Module):
    """Foreground/background implicit model composited by volume rendering.

    The foreground is an implicit SDF network plus a rendering network that
    are evaluated on points mapped to canonical space by ``SMPLDeformer``.
    The background is a second implicit/rendering network pair conditioned on
    a per-frame latent code. ``forward`` renders both along camera rays and
    blends them using the foreground transmittance.
    """

    def __init__(self, opt, betas_path, gender, num_training_frames):
        """Build all sub-networks, samplers and the canonical SMPL mesh.

        Args:
            opt: configuration object. The attributes read here are
                ``implicit_network``, ``rendering_network``,
                ``bg_implicit_network``, ``bg_rendering_network`` (including
                its ``dim_frame_encoding``), ``use_smpl_deformer``,
                ``density``, ``ray_sampler`` and ``smpl_init``.
            betas_path: path of the array loaded with ``np.load`` and passed
                as ``betas`` to the SMPL deformer and server.
            gender: forwarded to ``SMPLDeformer`` and ``SMPLServer``.
            num_training_frames: number of rows of the frame-latent embedding.
        """
        super().__init__()

        # Foreground networks
        self.implicit_network = ImplicitNet(opt.implicit_network)
        self.rendering_network = RenderingNet(opt.rendering_network)

        # Background networks
        self.bg_implicit_network = ImplicitNet(opt.bg_implicit_network)
        self.bg_rendering_network = RenderingNet(opt.bg_rendering_network)

        # Frame latent encoder
        self.frame_latent_encoder = nn.Embedding(
            num_training_frames, opt.bg_rendering_network.dim_frame_encoding)
        self.sampler = PointInSpace()

        betas = np.load(betas_path)
        self.use_smpl_deformer = opt.use_smpl_deformer
        self.gender = gender
        # The deformer attribute only exists when it is enabled in the config.
        if self.use_smpl_deformer:
            self.deformer = SMPLDeformer(betas=betas, gender=self.gender)

        # pre-defined bounding sphere
        self.sdf_bounding_sphere = 3.0

        # threshold for the out-surface points
        self.threshold = 0.05

        self.density = LaplaceDensity(**opt.density)
        self.bg_density = AbsDensity()

        self.ray_sampler = ErrorBoundSampler(
            self.sdf_bounding_sphere, inverse_sphere_bg=True, **opt.ray_sampler)
        self.smpl_server = SMPLServer(gender=self.gender, betas=betas)

        # Optionally start the foreground SDF from a stored state dict.
        if opt.smpl_init:
            smpl_model_state = torch.load(
                hydra.utils.to_absolute_path('../assets/smpl_init.pth'))
            self.implicit_network.load_state_dict(smpl_model_state["model_state_dict"])

        # Canonical SMPL vertices/faces; the ``mesh_*`` copies are the ones
        # used for the point-to-mesh queries in
        # ``check_off_in_surface_points_cano_mesh``.
        self.smpl_v_cano = self.smpl_server.verts_c
        self.smpl_f_cano = torch.tensor(
            self.smpl_server.smpl.faces.astype(np.int64),
            device=self.smpl_v_cano.device)

        self.mesh_v_cano = self.smpl_server.verts_c
        self.mesh_f_cano = torch.tensor(
            self.smpl_server.smpl.faces.astype(np.int64),
            device=self.smpl_v_cano.device)
        self.mesh_face_vertices = index_vertices_by_faces(self.mesh_v_cano, self.mesh_f_cano)

    def sdf_func_with_smpl_deformer(self, x, cond, smpl_tfs, smpl_verts):
        """Evaluate the foreground implicit network on deformed points.

        The points are passed through ``self.deformer.forward`` with
        ``inverse=True`` and the result is fed to the implicit network.

        Args:
            x: query points handed to the deformer.
            cond: conditioning passed to the implicit network.
            smpl_tfs: SMPL transformations handed to the deformer.
            smpl_verts: SMPL vertices handed to the deformer.

        Returns:
            Tuple ``(sdf, x_c, feature)``: the first output column of the
            implicit network, the points returned by the deformer, and the
            remaining output columns. In evaluation mode the SDF of points
            flagged by the deformer's outlier mask is overwritten with 4.

        Raises:
            RuntimeError: if the model was built without an SMPL deformer.
        """
        # Previously a missing deformer surfaced as an UnboundLocalError on
        # the return statement; fail with an explicit message instead.
        if not hasattr(self, "deformer"):
            raise RuntimeError(
                "sdf_func_with_smpl_deformer requires the SMPL deformer, but "
                "the model was created with use_smpl_deformer disabled")

        x_c, outlier_mask = self.deformer.forward(
            x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
        output = self.implicit_network(x_c, cond)[0]
        sdf = output[:, 0:1]
        feature = output[:, 1:]
        if not self.training:
            sdf[outlier_mask] = 4.  # set a large SDF value for outlier points

        return sdf, x_c, feature

    def check_off_in_surface_points_cano_mesh(self, x_cano, N_samples, threshold=0.05):
        """Classify rays by the signed distance of their samples to the mesh.

        Args:
            x_cano: flat tensor of points; consecutive groups of
                ``N_samples`` rows are treated as one ray.
            N_samples: number of samples per ray.
            threshold: a ray is "off surface" when even its closest sample is
                farther than this outside the mesh.

        Returns:
            Tuple ``(index_off_surface, index_in_surface)`` of per-ray boolean
            masks: minimum signed distance ``> threshold`` and ``<= 0``
            respectively.
        """
        distance, _, _ = kaolin.metrics.trianglemesh.point_to_mesh_distance(
            x_cano.unsqueeze(0).contiguous(), self.mesh_face_vertices)

        distance = torch.sqrt(distance)  # kaolin outputs squared distance
        sign = kaolin.ops.mesh.check_sign(
            self.mesh_v_cano, self.mesh_f_cano, x_cano.unsqueeze(0)).float()
        # check_sign marks points inside the mesh, so this maps
        # inside -> -1 and outside -> +1.
        sign = 1 - 2 * sign
        signed_distance = sign * distance

        # Group the flat samples back into (rays, samples per ray, 1).
        num_rays = x_cano.shape[0] // N_samples
        signed_distance = signed_distance.reshape(num_rays, N_samples, 1)

        # Smallest signed distance along each ray.
        minimum = torch.min(signed_distance, 1)[0]
        index_off_surface = (minimum > threshold).squeeze(1)
        index_in_surface = (minimum <= 0.).squeeze(1)
        return index_off_surface, index_in_surface

    def forward(self, input):
        """Render foreground and background along the rays of ``input``.

        Args:
            input: mapping with the keys ``"intrinsics"``, ``"pose"``,
                ``"uv"``, ``"smpl_params"``, ``"smpl_pose"``,
                ``"smpl_shape"``, ``"smpl_trans"`` and ``"idx"``; optionally
                ``"image_id"`` (preferred over ``"idx"`` for the frame latent
                code). In training mode ``"current_epoch"`` and
                ``"index_outside"`` are read as well.

        Returns:
            dict. Always contains ``acc_map``, ``rgb_values``,
            ``normal_values`` and ``sdf_output``. Training adds ``points``,
            ``index_outside``, ``index_off_surface``, ``index_in_surface``,
            ``grad_theta`` and ``epoch``; evaluation adds ``fg_rgb_values``
            (foreground composited over a constant background of ones).
        """
        # Autograd is switched on globally (and not restored): normals are
        # obtained with torch.autograd.grad in both training and evaluation.
        torch.set_grad_enabled(True)

        # Parse model input
        intrinsics = input["intrinsics"]
        pose = input["pose"]
        uv = input["uv"]

        scale = input['smpl_params'][:, 0]
        smpl_pose = input["smpl_pose"]
        smpl_shape = input["smpl_shape"]
        smpl_trans = input["smpl_trans"]
        smpl_output = self.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)

        smpl_tfs = smpl_output['smpl_tfs']

        # Pose conditioning: everything after the first three pose entries,
        # scaled by 1/pi.
        cond = {'smpl': smpl_pose[:, 3:]/np.pi}
        if self.training:
            # Zero the pose conditioning during the first 20 epochs and on
            # every epoch that is a multiple of 20.
            if input['current_epoch'] < 20 or input['current_epoch'] % 20 == 0:
                cond = {'smpl': smpl_pose[:, 3:] * 0.}

        # Ray directions and camera location
        ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics)
        batch_size, num_pixels, _ = ray_dirs.shape

        # One camera origin / direction row per ray.
        cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
        ray_dirs = ray_dirs.reshape(-1, 3)

        z_vals, _ = self.ray_sampler.get_z_vals(
            ray_dirs, cam_loc, self, cond, smpl_tfs,
            eval_mode=True, smpl_verts=smpl_output['smpl_verts'])

        # Foreground and background depths; the last foreground column is the
        # far bound used to close the final interval in volume_rendering.
        z_vals, z_vals_bg = z_vals
        z_max = z_vals[:, -1]
        z_vals = z_vals[:, :-1]
        N_samples = z_vals.shape[1]

        # Sample positions along each ray: origin + depth * direction.
        points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
        points_flat = points.reshape(-1, 3)

        dirs = ray_dirs.unsqueeze(1).repeat(1, N_samples, 1)
        sdf_output, canonical_points, feature_vectors = self.sdf_func_with_smpl_deformer(
            points_flat, cond, smpl_tfs, smpl_output['smpl_verts'])

        sdf_output = sdf_output.unsqueeze(1)  # insert a singleton axis at dim 1

        # NOTE(review): the reshapes to (num_pixels, N_samples, ...) below
        # only succeed when the number of rays equals num_pixels, which
        # presumably means a batch size of 1 - confirm against the data loader.
        if self.training:
            index_off_surface, index_in_surface = self.check_off_in_surface_points_cano_mesh(
                canonical_points, N_samples, threshold=self.threshold)
            canonical_points = canonical_points.reshape(num_pixels, N_samples, 3)

            canonical_points = canonical_points.reshape(-1, 3)

            # sample canonical SMPL surface pnts for the eikonal loss
            smpl_verts_c = self.smpl_server.verts_c.repeat(batch_size, 1, 1)

            # Random subset of at most num_pixels vertex indices. The
            # permutation is still drawn on the CPU; only the transfer now
            # targets the vertices' device instead of a hard-coded .cuda().
            indices = torch.randperm(smpl_verts_c.shape[1])[:num_pixels].to(smpl_verts_c.device)
            verts_c = torch.index_select(smpl_verts_c, 1, indices)
            # global_ratio=0. is passed to the sampler for these points.
            sample = self.sampler.get_points(verts_c, global_ratio=0.)

            sample.requires_grad_()
            local_pred = self.implicit_network(sample, cond)[..., 0:1]
            grad_theta = gradient(sample, local_pred)

            differentiable_points = canonical_points

        else:
            differentiable_points = canonical_points.reshape(num_pixels, N_samples, 3).reshape(-1, 3)
            grad_theta = None

        # Flattened to shape (num_pixels * N_samples, 1).
        sdf_output = sdf_output.reshape(num_pixels, N_samples, 1).reshape(-1, 1)
        # View direction is the negated ray direction of every sample.
        view = -dirs.reshape(-1, 3)

        # NOTE(review): if no points were sampled, fg_rgb_flat and
        # normal_values are never bound and the reshapes below raise.
        if differentiable_points.shape[0] > 0:
            fg_rgb_flat, others = self.get_rbg_value(
                points_flat, differentiable_points, view,
                cond, smpl_tfs, feature_vectors=feature_vectors, is_training=self.training)
            normal_values = others['normals']

        if 'image_id' in input:
            frame_latent_code = self.frame_latent_encoder(input['image_id'])
        else:
            frame_latent_code = self.frame_latent_encoder(input['idx'])

        fg_rgb = fg_rgb_flat.reshape(-1, N_samples, 3)
        normal_values = normal_values.reshape(-1, N_samples, 3)
        weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf_output)

        fg_rgb_values = torch.sum(weights.unsqueeze(-1) * fg_rgb, 1)

        # Background rendering
        if input['idx'] is not None:
            N_bg_samples = z_vals_bg.shape[1]
            z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ])  # 1--->0

            bg_dirs = ray_dirs.unsqueeze(1).repeat(1, N_bg_samples, 1)
            bg_locs = cam_loc.unsqueeze(1).repeat(1, N_bg_samples, 1)

            bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg)  # [..., N_samples, 4]
            bg_points_flat = bg_points.reshape(-1, 4)
            bg_dirs_flat = bg_dirs.reshape(-1, 3)
            bg_output = self.bg_implicit_network(bg_points_flat, {'frame': frame_latent_code})[0]
            bg_sdf = bg_output[:, :1]
            bg_feature_vectors = bg_output[:, 1:]

            bg_rendering_output = self.bg_rendering_network(
                None, None, bg_dirs_flat, None, bg_feature_vectors, frame_latent_code)
            if bg_rendering_output.shape[-1] == 4:
                # A fourth channel is treated as a shadow factor that
                # attenuates the background colour.
                bg_rgb_flat = bg_rendering_output[..., :-1]
                shadow_r = bg_rendering_output[..., -1]
                bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
                shadow_r = shadow_r.reshape(-1, N_bg_samples, 1)
                bg_rgb = (1 - shadow_r) * bg_rgb
            else:
                bg_rgb_flat = bg_rendering_output
                bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
            bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)
            bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)
        else:
            bg_rgb_values = torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)

        # Composite foreground and background
        bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
        rgb_values = fg_rgb_values + bg_rgb_values

        normal_values = torch.sum(weights.unsqueeze(-1) * normal_values, 1)

        if self.training:
            output = {
                'points': points,
                'rgb_values': rgb_values,
                'normal_values': normal_values,
                'index_outside': input['index_outside'],
                'index_off_surface': index_off_surface,
                'index_in_surface': index_in_surface,
                'acc_map': torch.sum(weights, -1),
                'sdf_output': sdf_output,
                'grad_theta': grad_theta,
                'epoch': input['current_epoch'],
            }
        else:
            # Foreground composited over a constant background of ones.
            fg_output_rgb = fg_rgb_values + bg_transmittance.unsqueeze(-1) * torch.ones_like(
                fg_rgb_values, device=fg_rgb_values.device)
            output = {
                'acc_map': torch.sum(weights, -1),
                'rgb_values': rgb_values,
                'fg_rgb_values': fg_output_rgb,
                'normal_values': normal_values,
                'sdf_output': sdf_output,
            }
        return output

    def get_rbg_value(self, x, points, view_dirs, cond, tfs, feature_vectors, is_training=True):
        """Compute foreground colours and normals for canonical points.

        Args:
            x: forwarded unchanged to ``forward_gradient``.
            points: canonical points the networks are evaluated on.
            view_dirs: per-point view directions for the rendering network.
            cond: conditioning dict; ``cond['smpl']`` goes to the rendering
                network.
            tfs: SMPL transformations used by ``forward_gradient``.
            feature_vectors: accepted for interface compatibility; it is
                replaced by the features returned from ``forward_gradient``.
            is_training: used as both ``create_graph`` and ``retain_graph``
                for the gradient computation.

        Returns:
            Tuple ``(rgb_vals, others)`` where ``rgb_vals`` holds the first
            three rendering-network channels and ``others['normals']`` the
            normalised gradients.
        """
        pnts_c = points
        others = {}

        _, gradients, feature_vectors = self.forward_gradient(
            x, pnts_c, cond, tfs, create_graph=is_training, retain_graph=is_training)
        # ensure the gradient is normalized
        normals = nn.functional.normalize(gradients, dim=-1, eps=1e-6)
        fg_rendering_output = self.rendering_network(
            pnts_c, normals, view_dirs, cond['smpl'], feature_vectors)

        rgb_vals = fg_rendering_output[:, :3]
        others['normals'] = normals
        return rgb_vals, others

    def forward_gradient(self, x, pnts_c, cond, tfs, create_graph=True, retain_graph=True):
        """Return the skinning Jacobian, transformed SDF gradient and features.

        Args:
            x: unused by this method.
            pnts_c: canonical points; marked as requiring grad in place.
            cond: conditioning passed to the implicit network.
            tfs: SMPL transformations passed to ``forward_skinning``.
            create_graph: forwarded to ``torch.autograd.grad``.
            retain_graph: forwarded to the last ``torch.autograd.grad`` call
                of each computation.

        Returns:
            Tuple of the flattened per-point Jacobian of the skinned points
            w.r.t. ``pnts_c``, the normalised product of the SDF gradient
            with the inverse Jacobian, and the implicit-network features.
        """
        # NOTE(review): this early exit returns a single tensor whereas the
        # normal path returns a 3-tuple (and get_rbg_value unpacks three
        # values). Left unchanged; confirm the intended empty-input result.
        if pnts_c.shape[0] == 0:
            return pnts_c.detach()
        pnts_c.requires_grad_(True)

        pnts_d = self.deformer.forward_skinning(pnts_c.unsqueeze(0), None, tfs).squeeze(0)
        num_dim = pnts_d.shape[-1]

        # Build the Jacobian d(pnts_d)/d(pnts_c) one output coordinate at a
        # time by back-propagating a one-hot selector for that coordinate.
        grads = []
        for i in range(num_dim):
            d_out = torch.zeros_like(pnts_d, requires_grad=False, device=pnts_d.device)
            d_out[:, i] = 1
            grad_i = torch.autograd.grad(
                outputs=pnts_d,
                inputs=pnts_c,
                grad_outputs=d_out,
                create_graph=create_graph,
                # Keep the graph alive until the last coordinate is done.
                retain_graph=True if i < num_dim - 1 else retain_graph,
                only_inputs=True)[0]
            grads.append(grad_i)
        grads = torch.stack(grads, dim=-2)
        grads_inv = grads.inverse()

        output = self.implicit_network(pnts_c, cond)[0]
        sdf = output[:, :1]

        feature = output[:, 1:]
        d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
        gradients = torch.autograd.grad(
            outputs=sdf,
            inputs=pnts_c,
            grad_outputs=d_output,
            create_graph=create_graph,
            retain_graph=retain_graph,
            only_inputs=True)[0]

        return (
            grads.reshape(grads.shape[0], -1),
            torch.nn.functional.normalize(
                torch.einsum('bi,bij->bj', gradients, grads_inv), dim=1),
            feature,
        )

    def volume_rendering(self, z_vals, z_max, sdf):
        """Turn foreground SDF values into per-sample rendering weights.

        Args:
            z_vals: per-ray sample depths, one row per ray.
            z_max: per-ray far depth closing the last interval.
            sdf: SDF values, reshaped to the layout of ``z_vals``.

        Returns:
            Tuple ``(weights, bg_transmittance)``: per-sample weights and the
            transmittance remaining after the last sample.
        """
        density_flat = self.density(sdf)
        density = density_flat.reshape(-1, z_vals.shape[1])  # (batch_size * num_pixels) x N_samples

        # included also the dist from the sphere intersection
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)

        # LOG SPACE
        free_energy = dists * density
        # add 0 for transparency 1 at t_0; created on the inputs' device
        # rather than via a hard-coded .cuda().
        shifted_free_energy = torch.cat(
            [torch.zeros(dists.shape[0], 1, device=dists.device), free_energy], dim=-1)
        alpha = 1 - torch.exp(-free_energy)  # probability of it is not empty here
        # probability of everything is empty up to now
        transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
        fg_transmittance = transmittance[:, :-1]
        weights = alpha * fg_transmittance  # probability of the ray hits something here
        bg_transmittance = transmittance[:, -1]  # factor to be multiplied with the bg volume rendering

        return weights, bg_transmittance

    def bg_volume_rendering(self, z_vals_bg, bg_sdf):
        """Turn background SDF values into per-sample rendering weights.

        Args:
            z_vals_bg: per-ray background depths, one row per ray.
            bg_sdf: background SDF values, reshaped to that layout.

        Returns:
            Per-sample background weights.
        """
        bg_density_flat = self.bg_density(bg_sdf)
        bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1])  # (batch_size * num_pixels) x N_samples

        device = z_vals_bg.device
        bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
        # The last interval gets a very large length (1e10).
        bg_dists = torch.cat(
            [bg_dists,
             torch.tensor([1e10], device=device).unsqueeze(0).repeat(bg_dists.shape[0], 1)],
            -1)

        # LOG SPACE
        bg_free_energy = bg_dists * bg_density
        # shift one step
        bg_shifted_free_energy = torch.cat(
            [torch.zeros(bg_dists.shape[0], 1, device=device), bg_free_energy[:, :-1]], dim=-1)
        bg_alpha = 1 - torch.exp(-bg_free_energy)  # probability of it is not empty here
        # probability of everything is empty up to now
        bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1))
        bg_weights = bg_alpha * bg_transmittance  # probability of the ray hits something here

        return bg_weights

    def depth2pts_outside(self, ray_o, ray_d, depth):
        '''
        ray_o, ray_d: [..., 3]
        depth: [...]; inverse of distance to sphere origin

        Returns pts: [..., 4], a unit-length 3-vector concatenated with depth.
        '''
        # Ray / bounding-sphere intersection.
        o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
        under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.sdf_bounding_sphere ** 2)
        d_sphere = torch.sqrt(under_sqrt) - o_dot_d
        p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
        # Point on the ray closest to the origin.
        p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
        p_mid_norm = torch.norm(p_mid, dim=-1)

        rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
        rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
        phi = torch.asin(p_mid_norm / self.sdf_bounding_sphere)
        theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]
        rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]

        # now rotate p_sphere
        # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
        p_sphere_new = p_sphere * torch.cos(rot_angle) + \
                       torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
                       rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
        p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
        pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

        return pts


def gradient(inputs, outputs):
    """Differentiate ``outputs`` w.r.t. ``inputs`` with a graph kept for reuse.

    The gradient is taken with an all-ones ``grad_outputs`` and
    ``create_graph=True`` so it can itself be differentiated. Only the last
    three entries of the final axis of the (3-D) result are returned.
    """
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0][:, :, -3:]
    return points_grad