|
from preprocessing_utils import GMoF |
|
import torch |
|
# Number of joints covered by the per-joint weight vector.
num_joints = 25

# Joint indices whose weight is zeroed, so they are masked out wherever
# ``joint_weights`` is used as a multiplier.
joints_to_ign = [1, 9, 12]

joint_weights = torch.ones(num_joints)
joint_weights[joints_to_ign] = 0

# Stored as a (num_joints, 1) column. Prefer the GPU, but fall back to the
# CPU so that merely importing this module does not fail on machines
# without CUDA.
_joint_weights_device = torch.device(
    'cuda' if torch.cuda.is_available() else 'cpu'
)
joint_weights = joint_weights.reshape((-1, 1)).to(_joint_weights_device)
|
|
|
# Shared robust penalty applied to 2D joint residuals.
# NOTE(review): GMoF is presumably a Geman-McClure style robustifier and
# ``rho`` its scale parameter -- confirm against preprocessing_utils.
_ROBUSTIFIER_RHO = 100
robustifier = GMoF(rho=_ROBUSTIFIER_RHO)
|
|
|
def get_loss_weights():
    """Return the weighting schedule for each loss term.

    Each value is a callable ``(cst, it)`` that returns the weighted loss
    for the term named by its key. ``it`` is accepted but currently unused
    (presumably an iteration index -- confirm against the caller); the
    weight is a constant multiplier applied to ``cst``.
    """

    def _constant_scale(factor):
        # Build a weighting callable that ignores ``it``.
        return lambda cst, it: factor * cst

    schedule = {}
    schedule['J2D_Loss'] = _constant_scale(1e-2)
    schedule['Temporal_Loss'] = _constant_scale(6e0)
    return schedule
|
|
|
def joints_2d_loss(gt_joints_2d=None, joints_2d=None, joint_confidence=None):
    """Return the confidence-weighted robust 2D joint loss.

    The residual ``gt_joints_2d - joints_2d`` is passed through the
    module-level ``robustifier``. It is then scaled by the square of
    ``joint_confidence * joint_weights[:, 0]`` (broadcast over a trailing
    axis) and averaged over all elements with ``torch.mean``.

    Args:
        gt_joints_2d: Target 2D joints.
            NOTE(review): presumably shaped (..., num_joints, 2) -- confirm.
        joints_2d: Predicted 2D joints, broadcast-compatible with
            ``gt_joints_2d``.
        joint_confidence: Per-joint confidences whose last axis must be
            broadcast-compatible with the ``num_joints`` weight vector.

    Returns:
        The mean of the weighted robustified residuals.
    """
    residual = robustifier(gt_joints_2d - joints_2d)

    per_joint_weight = joint_weights[:, 0]
    # ``joint_weights`` lives on a device fixed at import time. Align it
    # with the confidences so the product does not fail with a device
    # mismatch when inputs live elsewhere (a no-op on the same device).
    if isinstance(joint_confidence, torch.Tensor):
        per_joint_weight = per_joint_weight.to(joint_confidence.device)

    # Ignored joints have weight 0, so they contribute nothing to the mean's
    # numerator (they still count in its denominator, as in the original).
    scale = (joint_confidence * per_joint_weight).unsqueeze(-1) ** 2
    return torch.mean(scale * residual)
|
|
|
def pose_temporal_loss(last_pose, param_pose):
    """Return the mean squared difference between two pose tensors.

    Args:
        last_pose: Pose parameters to compare against (presumably those of
            the previous frame -- confirm against the caller).
        param_pose: Current pose parameters, broadcast-compatible with
            ``last_pose``.

    Returns:
        The mean over all elements of ``(last_pose - param_pose) ** 2``.
    """
    pose_delta = last_pose - param_pose
    squared_delta = torch.square(pose_delta)
    return squared_delta.mean()