import torch
import torch.nn.functional as F
from .smpl import SMPLServer
from pytorch3d import ops

class SMPLDeformer():
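    """Deforms points between canonical and posed space via linear blend
    skinning, with weights queried from the nearest vertices of a canonical
    SMPL mesh."""
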
    def __init__(self, max_dist=0.1, K=1, gender='female', betas=None):
        super().__init__()

        self.max_dist = max_dist
        self.K = K
        self.smpl = SMPLServer(gender=gender)
        # Canonical SMPL parameters: [scale (1), translation (3), pose thetas (72), shape betas (10)]
        smpl_params_canonical = self.smpl.param_canonical.clone()
        smpl_params_canonical[:, 76:] = torch.tensor(betas).float().to(self.smpl.param_canonical.device)
        cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(smpl_params_canonical, [1, 3, 72, 10], dim=1)
        smpl_output = self.smpl(cano_scale, cano_transl, cano_thetas, cano_betas)
        self.smpl_verts = smpl_output['smpl_verts']
        self.smpl_weights = smpl_output['smpl_weights']

    def forward(self, x, smpl_tfs, return_weights=True, inverse=False, smpl_verts=None):
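        """Query skinning weights for `x` and optionally apply them.

        If `return_weights` is True, returns only the blended weights;
        otherwise returns the skinned points and an outlier mask.
        """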
        # Nothing to deform for an empty point set.
        if x.shape[0] == 0:
            return x
        if smpl_verts is None:
            weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
        else:
            weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=smpl_verts[0], smpl_weights=self.smpl_weights)
        if return_weights:
            return weights

        x_transformed = skinning(x.unsqueeze(0), weights, smpl_tfs, inverse=inverse)

        return x_transformed.squeeze(0), outlier_mask

    def forward_skinning(self, xc, cond, smpl_tfs):
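        """Skin canonical points `xc` into posed space (`cond` is unused here
        and appears to be kept for interface compatibility)."""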
        weights, _ = self.query_skinning_weights_smpl_multi(xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
        x_transformed = skinning(xc, weights, smpl_tfs, inverse=False)

        return x_transformed

    def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):
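        """Blend SMPL skinning weights over the K nearest SMPL vertices of `pts`.

        Returns the blended weights and a mask marking points farther than
        `max_dist` from their nearest vertex.
        """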

        # knn_points returns *squared* distances to the K nearest SMPL vertices.
        distance_batch, index_batch, neighbor_points = ops.knn_points(pts, smpl_verts.unsqueeze(0),
                                                                      K=self.K, return_nn=True)
        distance_batch = torch.clamp(distance_batch, max=4)
        # Neighbor confidences decay exponentially with squared distance ...
        weights_conf = torch.exp(-distance_batch)
        distance_batch = torch.sqrt(distance_batch)  # back to Euclidean distance for the outlier test
        # ... and are normalized over the K neighbors.
        weights_conf = weights_conf / weights_conf.sum(-1, keepdim=True)
        index_batch = index_batch[0]
        # Blend the neighbors' skinning weights by their confidences.
        weights = smpl_weights[:, index_batch, :]
        weights = torch.sum(weights * weights_conf.unsqueeze(-1), dim=-2).detach()

        # Points farther than `max_dist` from their nearest SMPL vertex are outliers.
        outlier_mask = (distance_batch[..., 0] > self.max_dist)[0]
        return weights, outlier_mask

    def query_weights(self, xc):
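        """Convenience wrapper: return only the skinning weights for `xc`."""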
        weights = self.forward(xc, None, return_weights=True, inverse=False)
        return weights

    def forward_skinning_normal(self, xc, normal, cond, tfs, inverse=False):
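        """Transform normals with the blended bone transforms of their canonical points."""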
        if normal.ndim == 2:
            normal = normal.unsqueeze(0)
        w = self.query_weights(xc[0])

        # Pad with 0, not 1: normals are directions, so translations must not apply.
        p_h = F.pad(normal, (0, 1), value=0)

        if inverse:
            # p:num_point, n:num_bone, i,j: num_dim+1
            tf_w = torch.einsum('bpn,bnij->bpij', w.double(), tfs.double())
            p_h = torch.einsum('bpij,bpj->bpi', tf_w.inverse(), p_h.double()).float()
        else:
            p_h = torch.einsum('bpn, bnij, bpj->bpi', w.double(), tfs.double(), p_h.double()).float()
        
        return p_h[:, :, :3]

def skinning(x, w, tfs, inverse=False):
    """Linear blend skinning
    Args:
        x (tensor): canonical points. shape: [B, N, D]
        w (tensor): conditional input. [B, N, J]
        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
    Returns:
        x (tensor): skinned points. shape: [B, N, D]
    """
    # Homogeneous coordinates: pad points with 1 so translations apply.
    x_h = F.pad(x, (0, 1), value=1.0)

    if inverse:
        # p:n_point, n:n_bone, i,k: n_dim+1
        w_tf = torch.einsum("bpn,bnij->bpij", w, tfs)
        x_h = torch.einsum("bpij,bpj->bpi", w_tf.inverse(), x_h)
    else:
        x_h = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)
    return x_h[:, :, :3]
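

# A minimal self-check sketch (not part of the original module): it exercises the
# standalone `skinning` function with identity bone transforms, so forward and
# inverse skinning should both reproduce the input points. It avoids SMPLServer,
# whose construction needs repo-specific SMPL model files. Because of the relative
# import above, run this as a module (e.g. `python -m <package>.<this_module>`).
if __name__ == "__main__":
    B, N, J, D = 1, 4, 24, 3
    x = torch.rand(B, N, D)
    # Uniform weights over the J bones; each row sums to 1 as LBS expects.
    w = torch.full((B, N, J), 1.0 / J)
    # Identity transforms for every bone.
    tfs = torch.eye(D + 1).expand(B, J, D + 1, D + 1).clone()
    x_posed = skinning(x, w, tfs, inverse=False)
    x_back = skinning(x_posed, w, tfs, inverse=True)
    assert torch.allclose(x, x_posed, atol=1e-5)
    assert torch.allclose(x, x_back, atol=1e-5)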