from .networks import ImplicitNet, RenderingNet
from .density import LaplaceDensity, AbsDensity
from .ray_sampler import ErrorBoundSampler
from .deformer import SMPLDeformer
from .smpl import SMPLServer

from .sampler import PointInSpace

from ..utils import utils

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import grad
import hydra
import kaolin
from kaolin.ops.mesh import index_vertices_by_faces


class V2A(nn.Module):
    def __init__(self, opt, betas_path, gender, num_training_frames):
        super().__init__()

        # Foreground networks
        self.implicit_network = ImplicitNet(opt.implicit_network)
        self.rendering_network = RenderingNet(opt.rendering_network)

        # Background networks
        self.bg_implicit_network = ImplicitNet(opt.bg_implicit_network)
        self.bg_rendering_network = RenderingNet(opt.bg_rendering_network)

        # Frame latent encoder
        self.frame_latent_encoder = nn.Embedding(num_training_frames, opt.bg_rendering_network.dim_frame_encoding)
        self.sampler = PointInSpace()

        betas = np.load(betas_path)
        self.use_smpl_deformer = opt.use_smpl_deformer
        self.gender = gender
        if self.use_smpl_deformer:
            self.deformer = SMPLDeformer(betas=betas, gender=self.gender)
        
        # pre-defined bounding sphere
        self.sdf_bounding_sphere = 3.0
        
        # threshold for the out-surface points
        self.threshold = 0.05
        
        self.density = LaplaceDensity(**opt.density)
        self.bg_density = AbsDensity()

        self.ray_sampler = ErrorBoundSampler(self.sdf_bounding_sphere, inverse_sphere_bg=True, **opt.ray_sampler)
        self.smpl_server = SMPLServer(gender=self.gender, betas=betas)

        if opt.smpl_init:
            smpl_model_state = torch.load(hydra.utils.to_absolute_path('../assets/smpl_init.pth'))
            self.implicit_network.load_state_dict(smpl_model_state["model_state_dict"])

        self.smpl_v_cano = self.smpl_server.verts_c
        self.smpl_f_cano = torch.tensor(self.smpl_server.smpl.faces.astype(np.int64), device=self.smpl_v_cano.device)

        self.mesh_v_cano = self.smpl_server.verts_c
        self.mesh_f_cano = torch.tensor(self.smpl_server.smpl.faces.astype(np.int64), device=self.smpl_v_cano.device)
        self.mesh_face_vertices = index_vertices_by_faces(self.mesh_v_cano, self.mesh_f_cano)
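        # index_vertices_by_faces gathers per-face vertex coordinates once, so the
        # canonical mesh can be fed directly to kaolin's point_to_mesh_distance in
        # check_off_in_surface_points_cano_mesh below.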

    def sdf_func_with_smpl_deformer(self, x, cond, smpl_tfs, smpl_verts):
        """Compute SDF values for input points using the SMPL deformer and the implicit network.

        Deformed-space points are warped back into the canonical space, passed
        through the implicit network, and split into SDF values and feature
        vectors; points the deformer flags as outliers are pushed off the
        surface at evaluation time.
        """
        if not hasattr(self, "deformer"):
            raise AttributeError("sdf_func_with_smpl_deformer requires use_smpl_deformer=True")
        x_c, outlier_mask = self.deformer.forward(x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
        output = self.implicit_network(x_c, cond)[0]
        sdf = output[:, 0:1]
        feature = output[:, 1:]
        if not self.training:
            # a large positive SDF (outside the 3.0-radius bounding sphere) gives
            # outlier points near-zero density during evaluation
            sdf[outlier_mask] = 4.

        return sdf, x_c, feature
    
    def check_off_in_surface_points_cano_mesh(self, x_cano, N_samples, threshold=0.05):
        """Classify rays as off-surface or in-surface with respect to the canonical mesh.

        Computes signed distances from the canonical-space samples to the canonical
        SMPL mesh and reduces them per ray; the resulting indices identify rays that
        lie outside (or inside) the reconstructed avatar's surface.
        """
        distance, _, _ = kaolin.metrics.trianglemesh.point_to_mesh_distance(x_cano.unsqueeze(0).contiguous(), self.mesh_face_vertices)

        distance = torch.sqrt(distance)  # kaolin returns squared distances
        # check_sign returns 1 for points inside the mesh, so after the mapping
        # below the sign is -1 inside and +1 outside (the usual SDF convention)
        sign = kaolin.ops.mesh.check_sign(self.mesh_v_cano, self.mesh_f_cano, x_cano.unsqueeze(0)).float()
        sign = 1 - 2 * sign
        signed_distance = sign * distance
        batch_size = x_cano.shape[0] // N_samples
        signed_distance = signed_distance.reshape(batch_size, N_samples, 1)   # one signed distance per sample along each ray

        minimum = torch.min(signed_distance, 1)[0]   # closest sample per ray
        index_off_surface = (minimum > threshold).squeeze(1)   # even the closest sample is more than `threshold` off the surface
        index_in_surface = (minimum <= 0.).squeeze(1)   # at least one sample lies inside the mesh
        return index_off_surface, index_in_surface

    def forward(self, input):
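        """Render a batch of rays.

        As consumed below, `input` is expected to carry: camera `intrinsics`,
        `pose`, and pixel coordinates `uv`; SMPL parameters `smpl_params` (for
        the scale), `smpl_pose`, `smpl_shape`, and `smpl_trans`;
        `current_epoch` and `index_outside` during training; and `image_id` or
        `idx` to select the per-frame latent code.
        """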
        # Parse the model input and prepare the SMPL parameters for the subsequent calculations
        torch.set_grad_enabled(True)
        intrinsics = input["intrinsics"]
        pose = input["pose"]
        uv = input["uv"]

        scale = input['smpl_params'][:, 0]
        smpl_pose = input["smpl_pose"]
        smpl_shape = input["smpl_shape"]
        smpl_trans = input["smpl_trans"]
        smpl_output = self.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)   # invoke the SMPL server to obtain the transformations for pose and shape changes

        smpl_tfs = smpl_output['smpl_tfs']

        cond = {'smpl': smpl_pose[:, 3:] / np.pi}
        if self.training:
            # zero the pose conditioning for the first 20 epochs and on every 20th
            # epoch thereafter, so the network is periodically trained without
            # pose-dependent conditioning
            if input['current_epoch'] < 20 or input['current_epoch'] % 20 == 0:
                cond = {'smpl': smpl_pose[:, 3:] * 0.}
        ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics)   # get the ray directions and camera location
        batch_size, num_pixels, _ = ray_dirs.shape

        cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)   # one camera location per ray
        ray_dirs = ray_dirs.reshape(-1, 3)

        z_vals, _ = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self, cond, smpl_tfs, eval_mode=True, smpl_verts=smpl_output['smpl_verts'])   # get the z values for each pixel  

        z_vals, z_vals_bg = z_vals   # unpack the z values for the foreground and the background
        z_max = z_vals[:,-1]   # get the maximum z value
        z_vals = z_vals[:,:-1]   # get the z values for the foreground
        N_samples = z_vals.shape[1]   # get the number of samples

        points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)   # 3D samples along each ray: (batch_size * num_pixels, N_samples, 3)
        points_flat = points.reshape(-1, 3)   # flatten to (batch_size * num_pixels * N_samples, 3)

        dirs = ray_dirs.unsqueeze(1).repeat(1, N_samples, 1)   # repeat ray directions per sample: (batch_size * num_pixels, N_samples, 3)
        sdf_output, canonical_points, feature_vectors = self.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_output['smpl_verts'])   # SDF values, canonical points, and features for all samples

        sdf_output = sdf_output.unsqueeze(1)

        if self.training:
            index_off_surface, index_in_surface = self.check_off_in_surface_points_cano_mesh(canonical_points, N_samples, threshold=self.threshold)
            canonical_points = canonical_points.reshape(num_pixels, N_samples, 3).reshape(-1, 3)   # flatten to (num_pixels * N_samples, 3)

            # sample canonical SMPL surface points for the eikonal loss
            smpl_verts_c = self.smpl_server.verts_c.repeat(batch_size, 1, 1)   # repeat the canonical SMPL vertices across the batch

            indices = torch.randperm(smpl_verts_c.shape[1])[:num_pixels].cuda()   # randomly select num_pixels canonical vertices
            verts_c = torch.index_select(smpl_verts_c, 1, indices)
            sample = self.sampler.get_points(verts_c, global_ratio=0.)   # sample points around the selected vertices (global_ratio=0. means purely local sampling)

            sample.requires_grad_()   # gradients w.r.t. the sampled points are needed for the eikonal term
            local_pred = self.implicit_network(sample, cond)[..., 0:1]   # SDF predictions at the sampled points
            grad_theta = gradient(sample, local_pred)   # spatial SDF gradients used by the eikonal loss

            differentiable_points = canonical_points

        else:
            differentiable_points = canonical_points.reshape(num_pixels, N_samples, 3).reshape(-1, 3)
            grad_theta = None

        sdf_output = sdf_output.reshape(num_pixels, N_samples, 1).reshape(-1, 1)   # flatten to (num_pixels * N_samples, 1)
        view = -dirs.reshape(-1, 3)   # view directions are the negated ray directions

        if differentiable_points.shape[0] > 0:   # gradient information is available
            fg_rgb_flat, others = self.get_rgb_value(points_flat, differentiable_points, view,
                                                     cond, smpl_tfs, feature_vectors=feature_vectors, is_training=self.training)
            normal_values = others['normals']

        if 'image_id' in input.keys():
            frame_latent_code = self.frame_latent_encoder(input['image_id'])
        else:
            frame_latent_code = self.frame_latent_encoder(input['idx'])

        fg_rgb = fg_rgb_flat.reshape(-1, N_samples, 3)
        normal_values = normal_values.reshape(-1, N_samples, 3)
        weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf_output)

        fg_rgb_values = torch.sum(weights.unsqueeze(-1) * fg_rgb, 1)

        # Background rendering
        if input['idx'] is not None:
            N_bg_samples = z_vals_bg.shape[1]
            z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ])  # 1--->0

            bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)
            bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)

            bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg)  # [..., N_samples, 4]
            bg_points_flat = bg_points.reshape(-1, 4)
            bg_dirs_flat = bg_dirs.reshape(-1, 3)
            bg_output = self.bg_implicit_network(bg_points_flat, {'frame': frame_latent_code})[0]
            bg_sdf = bg_output[:, :1]
            bg_feature_vectors = bg_output[:, 1:]
            
            bg_rendering_output = self.bg_rendering_network(None, None, bg_dirs_flat, None, bg_feature_vectors, frame_latent_code)
            if bg_rendering_output.shape[-1] == 4:
                bg_rgb_flat = bg_rendering_output[..., :-1]
                shadow_r = bg_rendering_output[..., -1]
                bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
                shadow_r = shadow_r.reshape(-1, N_bg_samples, 1)
                bg_rgb = (1 - shadow_r) * bg_rgb
            else:
                bg_rgb_flat = bg_rendering_output
                bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
            bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)
            bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)
        else:
            bg_rgb_values = torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)

        # Composite foreground and background
        bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
        rgb_values = fg_rgb_values + bg_rgb_values

        normal_values = torch.sum(weights.unsqueeze(-1) * normal_values, 1)

        if self.training:
            output = {
                'points': points,
                'rgb_values': rgb_values,
                'normal_values': normal_values,
                'index_outside': input['index_outside'],
                'index_off_surface': index_off_surface,
                'index_in_surface': index_in_surface,
                'acc_map': torch.sum(weights, -1),
                'sdf_output': sdf_output,
                'grad_theta': grad_theta,
                'epoch': input['current_epoch'],
            }
        else:
            fg_output_rgb = fg_rgb_values + bg_transmittance.unsqueeze(-1) * torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)
            output = {
                'acc_map': torch.sum(weights, -1),
                'rgb_values': rgb_values,
                'fg_rgb_values': fg_output_rgb,
                'normal_values': normal_values,
                'sdf_output': sdf_output,
            }
        return output

    def get_rgb_value(self, x, points, view_dirs, cond, tfs, feature_vectors, is_training=True):
        pnts_c = points
        others = {}

        _, gradients, feature_vectors = self.forward_gradient(x, pnts_c, cond, tfs, create_graph=is_training, retain_graph=is_training)
        # ensure the gradient is normalized
        normals = nn.functional.normalize(gradients, dim=-1, eps=1e-6)
        fg_rendering_output = self.rendering_network(pnts_c, normals, view_dirs, cond['smpl'],
                                                     feature_vectors)
        
        rgb_vals = fg_rendering_output[:, :3]
        others['normals'] = normals
        return rgb_vals, others

    def forward_gradient(self, x, pnts_c, cond, tfs, create_graph=True, retain_graph=True):
        if pnts_c.shape[0] == 0:
            # return one empty tensor per output so callers can still unpack three values
            empty = pnts_c.detach()
            return empty, empty, empty
        pnts_c.requires_grad_(True)

        pnts_d = self.deformer.forward_skinning(pnts_c.unsqueeze(0), None, tfs).squeeze(0)
        num_dim = pnts_d.shape[-1]
        grads = []
        for i in range(num_dim):
            d_out = torch.zeros_like(pnts_d, requires_grad=False, device=pnts_d.device)
            d_out[:, i] = 1
            grad = torch.autograd.grad(
                outputs=pnts_d,
                inputs=pnts_c,
                grad_outputs=d_out,
                create_graph=create_graph,
                retain_graph=True if i < num_dim - 1 else retain_graph,
                only_inputs=True)[0]
            grads.append(grad)
        grads = torch.stack(grads, dim=-2)
        grads_inv = grads.inverse()

        output = self.implicit_network(pnts_c, cond)[0]
        sdf = output[:, :1]
        
        feature = output[:, 1:]
        d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
        gradients = torch.autograd.grad(
            outputs=sdf,
            inputs=pnts_c,
            grad_outputs=d_output,
            create_graph=create_graph,
            retain_graph=retain_graph,
            only_inputs=True)[0]
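        # `grads` holds the Jacobian J of the forward deformation, with
        # J[i, j] = d pnts_d_i / d pnts_c_j, so the einsum below computes
        # (canonical SDF gradient) @ J^{-1}: the chain rule that maps the SDF
        # gradient into deformed space before it is normalized into a normal.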

        return grads.reshape(grads.shape[0], -1), torch.nn.functional.normalize(torch.einsum('bi,bij->bj', gradients, grads_inv), dim=1), feature

    def volume_rendering(self, z_vals, z_max, sdf):
        density_flat = self.density(sdf)
        density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples

        # also include the distance from the last sample to the sphere intersection
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)

        # LOG SPACE
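        # With free energy F_i = sigma_i * delta_i for segment i of length delta_i:
        #     alpha_i = 1 - exp(-F_i)          (opacity of segment i)
        #     T_i     = exp(-sum_{j<i} F_j)    (transmittance up to segment i)
        #     w_i     = alpha_i * T_i          (rendering weight of segment i)
        # These satisfy sum_i w_i + T_N = 1, with T_N returned as bg_transmittance.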
        free_energy = dists * density
        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy], dim=-1)  # prepend 0 so transmittance starts at 1
        alpha = 1 - torch.exp(-free_energy)  # probability that this segment is not empty
        transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))  # probability that everything up to this segment is empty
        fg_transmittance = transmittance[:, :-1]
        weights = alpha * fg_transmittance  # probability that the ray first hits something in this segment
        bg_transmittance = transmittance[:, -1]  # factor multiplied into the background rendering

        return weights, bg_transmittance

    def bg_volume_rendering(self, z_vals_bg, bg_sdf):
        bg_density_flat = self.bg_density(bg_sdf)
        bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples

        bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
        bg_dists = torch.cat([bg_dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)
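        # z_vals_bg runs from far (depth 1) to near (depth 0), so the reversed
        # subtraction above yields positive interval lengths in the inverse-depth
        # parameterization; the appended 1e10 makes the final interval effectively
        # infinite, so the last background segment absorbs any remaining transmittance.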

        # LOG SPACE
        bg_free_energy = bg_dists * bg_density
        bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1).cuda(), bg_free_energy[:, :-1]], dim=-1)  # shift one step so transmittance starts at 1
        bg_alpha = 1 - torch.exp(-bg_free_energy)  # probability that this segment is not empty
        bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1))  # probability that everything up to this segment is empty
        bg_weights = bg_alpha * bg_transmittance  # probability that the ray first hits something in this segment

        return bg_weights
    
    def depth2pts_outside(self, ray_o, ray_d, depth):
        '''
        ray_o, ray_d: [..., 3]
        depth: [...]; inverse of distance to sphere origin
        '''
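        # Inverted-sphere background parameterization (NeRF++-style): a point at
        # distance r > sdf_bounding_sphere from the origin is encoded as the unit
        # vector where its ray exits the bounding sphere, rotated into place via
        # Rodrigues' formula below, together with its bounded inverse depth 1/r.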

        o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
        under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.sdf_bounding_sphere ** 2)
        d_sphere = torch.sqrt(under_sqrt) - o_dot_d
        p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
        p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
        p_mid_norm = torch.norm(p_mid, dim=-1)

        rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
        rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
        phi = torch.asin(p_mid_norm / self.sdf_bounding_sphere)
        theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]
        rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]

        # now rotate p_sphere
        # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
        p_sphere_new = p_sphere * torch.cos(rot_angle) + \
                       torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
                       rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
        p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
        pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

        return pts

def gradient(inputs, outputs):
    d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
    points_grad = grad(
        outputs=outputs,
        inputs=inputs,
        grad_outputs=d_points,
        create_graph=True,
        retain_graph=True,
        only_inputs=True)[0][:, :, -3:]
    return points_grad
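

if __name__ == "__main__":
    # Illustrative smoke test (a minimal sketch, not exercised by the training
    # pipeline): the `gradient` helper applied to a toy unit-sphere SDF, whose
    # analytic gradient is x / ||x||; inputs follow the (batch, num_points, 3)
    # layout assumed by the [:, :, -3:] slice above.
    pts = torch.randn(2, 8, 3, requires_grad=True)
    toy_sdf = pts.norm(dim=-1, keepdim=True) - 1.0  # SDF of the unit sphere
    g = gradient(pts, toy_sdf)
    expected = torch.nn.functional.normalize(pts, dim=-1)
    print(torch.allclose(g, expected, atol=1e-5))  # expected: True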