|
import torch |
|
import hydra |
|
import numpy as np |
|
from ..smpl.body_models import SMPL |
|
|
|
class SMPLServer(torch.nn.Module):
    """Thin wrapper around an SMPL body-model layer.

    Runs SMPL from a packed parameter layout
    ``[scale (1) | transl (3) | thetas (72) | betas (10)]`` and, at
    construction time, caches the mesh vertices, joints and inverse bone
    transforms of a fixed canonical pose (``verts_c``, ``joints_c``,
    ``tfs_c_inv``).
    """

    def __init__(self, gender='neutral', betas=None, v_template=None,
                 model_path='lib/smpl/smpl_model', device='cuda'):
        """Build the SMPL layer and precompute canonical-pose quantities.

        Args:
            gender: gender forwarded to the SMPL layer.
            betas: optional shape coefficients (array-like or tensor with
                10 elements, e.g. shape [10] or [1, 10]). Used for the
                canonical pose only when ``v_template`` is not given.
            v_template: optional custom template vertices (array-like or
                tensor). When set, ``forward`` replaces betas with zeros.
            model_path: SMPL model directory, resolved with
                ``hydra.utils.to_absolute_path``.
            device: device for the SMPL layer and all cached tensors.
                Defaults to 'cuda', which matches the previous hard-coded
                ``.cuda()`` behavior.
        """
        super().__init__()

        self.smpl = SMPL(model_path=hydra.utils.to_absolute_path(model_path),
                         gender=gender,
                         batch_size=1,
                         use_hands=False,
                         use_feet_keypoints=False,
                         dtype=torch.float32).to(device)

        self.bone_parents = self.smpl.bone_parents.astype(int)
        self.bone_parents[0] = -1  # the root bone has no parent
        self.faces = self.smpl.faces
        # (parent, child) index pairs for the 24 pose joints (72 = 24 * 3).
        self.bone_ids = [[self.bone_parents[i], i] for i in range(24)]

        self.v_template = (None if v_template is None
                           else self._to_float_tensor(v_template, device))
        self.betas = (None if betas is None
                      else self._to_float_tensor(betas, device))

        # Packed canonical parameters: [scale | transl | thetas | betas].
        param_canonical = torch.zeros((1, 86), dtype=torch.float32, device=device)
        param_canonical[0, 0] = 1  # unit scale
        # Offsets 9 and 12 are thetas[5] and thetas[8], the z-components of
        # the axis-angle rotations of joints 1 and 2. These are presumably
        # the hips in standard SMPL ordering, which would give a legs-apart
        # canonical pose (TODO confirm).
        param_canonical[0, 9] = np.pi / 6
        param_canonical[0, 12] = -np.pi / 6
        if self.betas is not None and self.v_template is None:
            param_canonical[0, -10:] = self.betas
        self.param_canonical = param_canonical

        output = self.forward(*torch.split(self.param_canonical, [1, 3, 72, 10], dim=1),
                              absolute=True)
        self.verts_c = output['smpl_verts']
        self.joints_c = output['smpl_jnts']
        # Inverse canonical bone transforms; forward(absolute=False) composes
        # with these to express transforms relative to the canonical pose.
        self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse()

    @staticmethod
    def _to_float_tensor(value, device):
        """Return a fresh float32 copy of ``value`` (array-like or tensor) on ``device``."""
        if torch.is_tensor(value):
            # torch.tensor(<tensor>) warns; clone().detach() is the supported copy.
            value = value.detach().clone()
        else:
            value = torch.tensor(value)
        return value.float().to(device)

    def forward(self, scale, transl, thetas, betas, absolute=False):
        """Return SMPL output from params.

        Args:
            scale: scale factor. shape: [B, 1]
            transl: translation, added before scaling. shape: [B, 3]
            thetas: axis-angle pose. ``thetas[:, :3]`` is the global
                orientation and ``thetas[:, 3:]`` the body pose. shape: [B, 72]
            betas: shape coefficients. shape: [B, 10]. Replaced by zeros
                when the server was built with a custom ``v_template``.
            absolute (bool): if True, return smpl_tfs with respect to
                thetas=0. Otherwise they are composed with ``tfs_c_inv``,
                i.e. expressed with respect to the canonical pose.

        Returns:
            dict with
                smpl_verts: vertices, (verts + transl) * scale.
                    shape: [B, 6890, 3]
                smpl_jnts: joint positions, (joints + transl) * scale.
                    shape: [B, J, 3], with J as emitted by the SMPL layer
                    (previously documented as 25; TODO confirm).
                smpl_tfs: bone transformations with scale and translation
                    applied. shape: [B, 24, 4, 4]
                smpl_weights: skinning weights, passed through unchanged
                    from the SMPL layer.
        """
        if self.v_template is not None:
            # Presumably the custom template already encodes body shape.
            betas = torch.zeros_like(betas)

        # Translation and scale are applied below, so run SMPL with zero transl.
        smpl_output = self.smpl.forward(betas=betas,
                                        transl=torch.zeros_like(transl),
                                        body_pose=thetas[:, 3:],
                                        global_orient=thetas[:, :3],
                                        return_verts=True,
                                        return_full_pose=True,
                                        v_template=self.v_template)

        scale_b = scale.unsqueeze(1)            # [B, 1, 1]
        offset = transl.unsqueeze(1) * scale_b  # [B, 1, 3]

        output = {
            'smpl_verts': smpl_output.vertices * scale_b + offset,
            'smpl_jnts': smpl_output.joints * scale_b + offset,
        }

        # Clone: the in-place edits below must not modify the SMPL layer's output.
        tf_mats = smpl_output.T.clone()
        tf_mats[:, :, :3, :] = tf_mats[:, :, :3, :] * scale_b.unsqueeze(1)
        tf_mats[:, :, :3, 3] = tf_mats[:, :, :3, 3] + offset

        if not absolute:
            tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv)

        output['smpl_tfs'] = tf_mats
        output['smpl_weights'] = smpl_output.weights
        return output