File size: 3,931 Bytes
6325697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import torch
import torch.nn.functional as F
from .smpl import SMPLServer
from pytorch3d import ops
class SMPLDeformer():
    """Deform points with skinning weights borrowed from nearby SMPL vertices.

    On construction the canonical SMPL body (optionally with custom shape
    coefficients) is evaluated once via ``SMPLServer``; its vertices and
    per-vertex skinning weights are cached and later looked up by
    nearest-neighbour search (``pytorch3d.ops.knn_points``).
    """

    def __init__(self, max_dist=0.1, K=1, gender='female', betas=None):
        """
        Args:
            max_dist: query points whose nearest SMPL vertex is farther than
                this are flagged in the returned outlier mask.
            K: number of nearest SMPL vertices blended per query point.
            gender: forwarded to ``SMPLServer``.
            betas: optional shape coefficients written into columns ``76:`` of
                the canonical parameter vector. When ``None`` the canonical
                betas provided by ``SMPLServer`` are kept unchanged.
        """
        super().__init__()
        self.max_dist = max_dist
        self.K = K
        self.smpl = SMPLServer(gender=gender)

        smpl_params_canonical = self.smpl.param_canonical.clone()
        # torch.tensor(None) cannot be built, so only override when given.
        if betas is not None:
            smpl_params_canonical[:, 76:] = torch.as_tensor(
                betas, dtype=torch.float32, device=smpl_params_canonical.device
            )
        # Parameter layout: [scale(1), transl(3), thetas(72), betas(10)].
        cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(
            smpl_params_canonical, [1, 3, 72, 10], dim=1
        )
        smpl_output = self.smpl(cano_scale, cano_transl, cano_thetas, cano_betas)
        self.smpl_verts = smpl_output['smpl_verts']
        self.smpl_weights = smpl_output['smpl_weights']

    def forward(self, x, smpl_tfs, return_weights=True, inverse=False, smpl_verts=None):
        """Query skinning weights for ``x`` and optionally skin the points.

        Args:
            x: query points; a batch dim is prepended internally (assumed
                (N, 3) -- TODO confirm with callers).
            smpl_tfs: per-bone transforms handed to ``skinning``; unused when
                ``return_weights`` is True.
            return_weights: if True, return only the blended weights.
            inverse: forwarded to ``skinning`` (apply the inverse of each
                point's blended transform).
            smpl_verts: optional vertex set to query instead of the cached
                canonical vertices; its first batch entry is used.

        Returns:
            If ``return_weights``: blended weights laid out like
            ``self.smpl_weights`` with the N query points along dim 1.
            Otherwise: ``(x_transformed, outlier_mask)`` where
            ``outlier_mask`` is a bool vector of length N.
        """
        if x.shape[0] == 0:
            # Keep the return type consistent with the non-empty paths.
            if return_weights:
                return self.smpl_weights.new_zeros(
                    (self.smpl_weights.shape[0], 0, self.smpl_weights.shape[-1])
                )
            return x, torch.zeros(0, dtype=torch.bool, device=x.device)

        verts = self.smpl_verts if smpl_verts is None else smpl_verts
        weights, outlier_mask = self.query_skinning_weights_smpl_multi(
            x[None], smpl_verts=verts[0], smpl_weights=self.smpl_weights
        )
        if return_weights:
            return weights
        x_transformed = skinning(x.unsqueeze(0), weights, smpl_tfs, inverse=inverse)
        return x_transformed.squeeze(0), outlier_mask

    def forward_skinning(self, xc, cond, smpl_tfs):
        """Skin canonical points ``xc`` forward with canonical SMPL weights.

        ``xc`` is passed to ``knn_points`` as-is, so it must already carry a
        batch dim (presumably (1, N, 3) -- verify against callers).
        ``cond`` is accepted for interface compatibility and ignored.
        """
        weights, _ = self.query_skinning_weights_smpl_multi(
            xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights
        )
        return skinning(xc, weights, smpl_tfs, inverse=False)

    def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):
        """Blend the skinning weights of the K nearest SMPL vertices.

        Args:
            pts: batched query points for ``knn_points``; the vertex set is
                given a batch of 1 here, so ``pts`` is expected to have batch 1.
            smpl_verts: vertex positions without a batch dim (one is added).
            smpl_weights: per-vertex weights, indexed by vertex id along dim 1.

        Returns:
            ``(weights, outlier_mask)``: the neighbours' weights averaged with
            normalised ``exp(-squared_distance)`` confidences (detached from
            the graph), and a bool vector marking points whose nearest vertex
            lies farther than ``self.max_dist``.
        """
        # Only distances and indices are used, so skip gathering neighbours.
        sq_dist, nn_idx, _ = ops.knn_points(
            pts, smpl_verts.unsqueeze(0), K=self.K, return_nn=False
        )
        # knn_points reports squared distances. Clamping keeps exp() from
        # underflowing to 0 for far points, which would make the
        # normalisation below divide 0 by 0.
        sq_dist = torch.clamp(sq_dist, max=4)
        conf = torch.exp(-sq_dist)
        conf = conf / conf.sum(-1, keepdim=True)
        dist = torch.sqrt(sq_dist)

        nn_idx = nn_idx[0]
        weights = smpl_weights[:, nn_idx, :]
        weights = torch.sum(weights * conf.unsqueeze(-1), dim=-2).detach()
        outlier_mask = (dist[..., 0] > self.max_dist)[0]
        return weights, outlier_mask

    def query_weights(self, xc, cond=None):
        """Return canonical skinning weights for unbatched points ``xc``.

        ``cond`` is accepted (and ignored) because ``forward_skinning_normal``
        passes it positionally.
        """
        return self.forward(xc, None, return_weights=True, inverse=False)

    def forward_skinning_normal(self, xc, normal, cond, tfs, inverse=False):
        """Transform normals by the blended bone transforms at points ``xc``.

        Normals are padded with a homogeneous 0, so only the linear (upper-left
        3x3) part of each blended transform acts on them; translation is
        ignored. The blend is computed in float64 and cast back to float32.
        NOTE(review): for non-rigid blended transforms the strictly correct
        normal transform is the inverse-transpose -- confirm this is intended.

        Args:
            xc: canonical points, (N, 3) or batched with batch 1.
            normal: normals, (N, 3) or batched with batch 1.
            cond: forwarded to ``query_weights`` (ignored there).
            tfs: per-bone transforms, indexed as (B, J, 4, 4).
            inverse: apply the inverse of each point's blended transform.

        Returns:
            Transformed normals with shape (B, N, 3).
        """
        if xc.ndim == 2:
            xc = xc.unsqueeze(0)
        if normal.ndim == 2:
            normal = normal.unsqueeze(0)
        w = self.query_weights(xc[0], cond)
        n_h = F.pad(normal, (0, 1), value=0)
        if inverse:
            # p: num_point, n: num_bone, i, j: num_dim + 1
            tf_w = torch.einsum('bpn,bnij->bpij', w.double(), tfs.double())
            n_h = torch.einsum('bpij,bpj->bpi', tf_w.inverse(), n_h.double()).float()
        else:
            n_h = torch.einsum(
                'bpn,bnij,bpj->bpi', w.double(), tfs.double(), n_h.double()
            ).float()
        return n_h[:, :, :3]
def skinning(x, w, tfs, inverse=False):
    """Linear blend skinning.

    Args:
        x (tensor): canonical points. shape: [B, N, D]
        w (tensor): per-point skinning weights. shape: [B, N, J]
        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
        inverse (bool): if True, apply the inverse of each point's blended
            transform instead of the transform itself.

    Returns:
        x (tensor): skinned points. shape: [B, N, D]
    """
    # Homogeneous coordinates: append a 1 so translations apply.
    x_h = F.pad(x, (0, 1), value=1.0)
    if inverse:
        # p: n_point, n: n_bone, i, j: n_dim + 1
        w_tf = torch.einsum("bpn,bnij->bpij", w, tfs)
        x_h = torch.einsum("bpij,bpj->bpi", w_tf.inverse(), x_h)
    else:
        x_h = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)
    # Drop the homogeneous coordinate; slicing to -1 keeps all D spatial dims
    # (identical to [:3] for D == 3).
    return x_h[:, :, :-1]