|
import torch |
|
import torch.nn.functional as F |
|
from .smpl import SMPLServer |
|
from pytorch3d import ops |
|
|
|
class SMPLDeformer:
    """Skin points using weights borrowed from their nearest SMPL vertices.

    A canonical SMPL mesh is built once in ``__init__``. The skinning weights
    for an arbitrary query point are the weights of its ``K`` nearest mesh
    vertices, blended with normalised ``exp(-squared_distance)`` confidences.
    """

    def __init__(self, max_dist=0.1, K=1, gender='female', betas=None):
        """Build the canonical SMPL mesh used for the weight lookup.

        Args:
            max_dist: query points whose nearest SMPL vertex is farther than
                this (Euclidean) distance are flagged as outliers.
            K: number of nearest SMPL vertices blended per query point.
            gender: forwarded to ``SMPLServer``.
            betas: optional shape coefficients written into columns 76 onwards
                of the canonical parameter vector. When ``None``, the betas
                already stored in ``SMPLServer.param_canonical`` are kept.
        """
        super().__init__()
        self.max_dist = max_dist
        self.K = K
        self.smpl = SMPLServer(gender=gender)

        smpl_params_canonical = self.smpl.param_canonical.clone()
        if betas is not None:
            # as_tensor accepts lists/arrays/tensors without the copy-construct
            # warning torch.tensor() emits for tensor inputs.
            smpl_params_canonical[:, 76:] = torch.as_tensor(
                betas, dtype=torch.float32, device=self.smpl.param_canonical.device)

        # Parameter layout: [scale(1) | transl(3) | thetas(72) | betas(10)].
        cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(
            smpl_params_canonical, [1, 3, 72, 10], dim=1)
        smpl_output = self.smpl(cano_scale, cano_transl, cano_thetas, cano_betas)
        self.smpl_verts = smpl_output['smpl_verts']
        self.smpl_weights = smpl_output['smpl_weights']

    def forward(self, x, smpl_tfs, return_weights=True, inverse=False, smpl_verts=None):
        """Query skinning weights for ``x``, or skin ``x`` with ``smpl_tfs``.

        Args:
            x: query points without a batch dim (a batch dim is added
                internally), presumably (P, 3).
            smpl_tfs: bone transformation matrices passed to ``skinning``;
                unused when ``return_weights`` is True.
            return_weights: if True, return only the blended weights.
            inverse: forwarded to ``skinning``.
            smpl_verts: optional batched vertices overriding the canonical mesh;
                only ``smpl_verts[0]`` is used.

        Returns:
            If ``return_weights``: weights of shape
            (self.smpl_weights.shape[0], P, J).
            Otherwise: ``(x_transformed, outlier_mask)`` with shapes (P, D)
            and (P,).
        """
        if x.shape[0] == 0:
            # Return empty results with the same structure as the non-empty
            # path, so callers can unpack or index uniformly.
            if return_weights:
                return self.smpl_weights.new_zeros(
                    (self.smpl_weights.shape[0], 0, self.smpl_weights.shape[-1]))
            return x, torch.zeros(0, dtype=torch.bool, device=x.device)

        verts = self.smpl_verts if smpl_verts is None else smpl_verts
        weights, outlier_mask = self.query_skinning_weights_smpl_multi(
            x[None], smpl_verts=verts[0], smpl_weights=self.smpl_weights)
        if return_weights:
            return weights

        x_transformed = skinning(x.unsqueeze(0), weights, smpl_tfs, inverse=inverse)
        return x_transformed.squeeze(0), outlier_mask

    def forward_skinning(self, xc, cond, smpl_tfs):
        """Skin canonical points ``xc`` with ``smpl_tfs`` (forward direction).

        Args:
            xc: batched canonical points. They are matched against a batch-1
                vertex set, so presumably (1, P, 3).
            cond: unused; kept for interface compatibility.
            smpl_tfs: bone transformation matrices passed to ``skinning``.

        Returns:
            Skinned points with the same batch layout as ``xc``.
        """
        weights, _ = self.query_skinning_weights_smpl_multi(
            xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
        return skinning(xc, weights, smpl_tfs, inverse=False)

    def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):
        """Blend the skinning weights of the K nearest SMPL vertices per point.

        Args:
            pts: batched query points. The batch size must match
                ``smpl_verts.unsqueeze(0)``, i.e. 1.
            smpl_verts: unbatched SMPL vertices, (V, 3).
            smpl_weights: per-vertex skinning weights, indexed as
                ``smpl_weights[:, vertex_idx, :]``.

        Returns:
            ``(weights, outlier_mask)``:
            - ``weights``: detached blended weights.
            - ``outlier_mask``: boolean (P,), True where the nearest vertex is
              farther than ``self.max_dist``.
        """
        # knn_points returns squared distances; the neighbour coordinates are
        # not needed, so skip computing them.
        sq_dists, nn_idx, _ = ops.knn_points(
            pts, smpl_verts.unsqueeze(0), K=self.K, return_nn=False)

        # exp(-d^2) confidences. The clamp presumably keeps exp() from
        # underflowing to 0 for far points, so the normalisation never divides
        # by zero.
        weights_conf = torch.exp(-torch.clamp(sq_dists, max=4))
        weights_conf = weights_conf / weights_conf.sum(-1, keepdim=True)

        neighbour_weights = smpl_weights[:, nn_idx[0], :]
        weights = torch.sum(neighbour_weights * weights_conf.unsqueeze(-1), dim=-2).detach()

        # Use the unclamped distance here. The clamp above caps it at 2, which
        # would hide every outlier whenever max_dist >= 2.
        outlier_mask = (torch.sqrt(sq_dists[..., 0]) > self.max_dist)[0]
        return weights, outlier_mask

    def query_weights(self, xc, cond=None):
        """Return skinning weights for the points ``xc``.

        ``cond`` is accepted and ignored, so callers that pass a conditioning
        argument still work.
        """
        return self.forward(xc, None, return_weights=True, inverse=False)

    def forward_skinning_normal(self, xc, normal, cond, tfs, inverse=False):
        """Transform normals by each point's blended bone transform.

        Normals are padded with a homogeneous 0, so the translation part of
        each transform has no effect. Computation runs in float64 and the
        result is cast to float32.

        NOTE(review): normals are multiplied by the blended matrix itself
        rather than its inverse-transpose. This is exact only when the blended
        linear part is orthonormal; confirm the approximation is intended.

        Args:
            xc: batched points; only ``xc[0]`` is used for the weight lookup.
            normal: normals, (P, 3) or (B, P, 3).
            cond: unused; kept for interface compatibility.
            tfs: bone transforms, (B, J, 4, 4) for 3-D normals.
            inverse: if True, apply the inverse of each blended transform.

        Returns:
            Transformed normals, (B, P, 3).
        """
        if normal.ndim == 2:
            normal = normal.unsqueeze(0)
        w = self.query_weights(xc[0])
        p_h = F.pad(normal, (0, 1), value=0)

        if inverse:
            tf_w = torch.einsum('bpn,bnij->bpij', w.double(), tfs.double())
            p_h = torch.einsum('bpij,bpj->bpi', tf_w.inverse(), p_h.double()).float()
        else:
            p_h = torch.einsum('bpn,bnij,bpj->bpi', w.double(), tfs.double(), p_h.double()).float()

        return p_h[:, :, :3]
|
|
|
def skinning(x, w, tfs, inverse=False):
    """Linear blend skinning.

    Args:
        x (tensor): points to transform (canonical points in the forward
            direction). shape: [B, N, D]
        w (tensor): per-point skinning weights. shape: [B, N, J]
        tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
        inverse (bool): if True, apply the inverse of each point's blended
            transform instead of the transform itself.

    Returns:
        x (tensor): skinned points. shape: [B, N, D]
    """
    num_dims = x.shape[-1]
    # Homogeneous coordinates so the translation column of each transform applies.
    x_h = F.pad(x, (0, 1), value=1.0)

    if inverse:
        # Blend the matrices per point first, then invert the blended matrix.
        w_tf = torch.einsum("bpn,bnij->bpij", w, tfs)
        x_h = torch.einsum("bpij,bpj->bpi", w_tf.inverse(), x_h)
    else:
        x_h = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)

    # Drop the homogeneous coordinate. Slicing by D rather than a hard-coded 3
    # keeps the [B, N, D] contract for any D.
    return x_h[:, :, :num_dims]