IF3D / preprocessing /loss.py
leobcc's picture
vid2avatar baseline
6325697
raw
history blame
842 Bytes
from preprocessing_utils import GMoF
import torch
num_joints = 25
# Indices of joints excluded from the 2D joint loss (their weight is zeroed).
# With 25 joints these presumably follow OpenPose BODY_25 ordering
# (1 = neck, 9 / 12 = right / left hip) -- TODO confirm against the keypoint source.
joints_to_ign = [1, 9, 12]
# Per-joint weight column vector of shape (num_joints, 1): 1.0 for joints that
# contribute to the loss, 0.0 for the ignored ones.
joint_weights = torch.ones(num_joints)
joint_weights[joints_to_ign] = 0
joint_weights = joint_weights.reshape((-1, 1))
# Move to the GPU only when one exists. An unconditional .cuda() would make
# this module fail to import on CPU-only machines.
if torch.cuda.is_available():
    joint_weights = joint_weights.cuda()
# Robust penalty applied to the 2D joint residuals in joints_2d_loss.
# GMoF comes from preprocessing_utils; presumably a Geman-McClure style
# estimator with scale rho (as in SMPLify) -- confirm there.
_ROBUSTIFIER_RHO = 100
robustifier = GMoF(rho=_ROBUSTIFIER_RHO)
def get_loss_weights():
    """Return the per-term loss weighting functions.

    Each value is a callable ``(cst, it)`` that scales the raw loss value
    ``cst`` by a fixed factor. The iteration index ``it`` is accepted for
    interface compatibility but is not used.

    Returns:
        dict mapping loss-term name ('J2D_Loss', 'Temporal_Loss') to its
        weighting callable.
    """
    def _weight_j2d(cst, it):
        # 2D joint reprojection term: scaled down by 1e-2.
        return 1e-2 * cst

    def _weight_temporal(cst, it):
        # Temporal pose-smoothness term: scaled up by 6.
        return 6e0 * cst

    weights = dict()
    weights['J2D_Loss'] = _weight_j2d
    weights['Temporal_Loss'] = _weight_temporal
    return weights
def joints_2d_loss(gt_joints_2d=None, joints_2d=None, joint_confidence=None):
    """Confidence-weighted, robustified 2D joint loss.

    Args:
        gt_joints_2d: target 2D joints. Assumed shape (..., num_joints, 2);
            TODO confirm against the caller.
        joints_2d: predicted 2D joints, broadcast-compatible with gt_joints_2d.
        joint_confidence: per-joint confidence whose last axis must broadcast
            against the module-level per-joint weights (length num_joints).

    Returns:
        Scalar tensor: the mean over all elements of
        ``(joint_confidence * joint_weight) ** 2 * robustifier(gt - pred)``.
        Ignored joints have weight 0, but they still count in the mean's
        denominator.
    """
    joint_diff = robustifier(gt_joints_2d - joints_2d)
    # The module-level weights live on a fixed device. Align them with the
    # inputs so that CPU and GPU tensors both work. This is a no-op when the
    # devices already match.
    per_joint_weight = joint_weights[:, 0].to(joint_confidence.device)
    # Squared (confidence * mask) weight, with a trailing axis added so it
    # broadcasts over the coordinate dimension of joint_diff.
    weight = (joint_confidence * per_joint_weight).unsqueeze(-1) ** 2
    return torch.mean(weight * joint_diff)
def pose_temporal_loss(last_pose, param_pose):
    """Mean squared difference between two pose tensors.

    Args:
        last_pose: reference pose tensor.
        param_pose: current pose tensor, broadcast-compatible with last_pose.

    Returns:
        Scalar tensor: mean over all elements of (last_pose - param_pose) ** 2.
    """
    delta = last_pose - param_pose
    return delta.pow(2).mean()