import torch
import torch.nn as nn


class BodyModelParams(nn.Module):
    """Per-frame SMPL body model parameters stored as embedding tables."""

    def __init__(self, num_frames, model_type='smpl'):
        super(BodyModelParams, self).__init__()
        self.num_frames = num_frames
        self.model_type = model_type
        self.params_dim = {
            'betas': 10,
            'global_orient': 3,
            'transl': 3,
        }
        if model_type == 'smpl':
            self.params_dim.update({
                'body_pose': 69,
            })
        else:
            raise ValueError(f'Unknown model type {model_type}, exiting!')
        self.param_names = self.params_dim.keys()

        for param_name in self.param_names:
            if param_name == 'betas':
                # Shape parameters are shared by all frames, so a single entry suffices.
                param = nn.Embedding(1, self.params_dim[param_name])
            else:
                # Pose, global orientation, and translation get one entry per frame.
                param = nn.Embedding(num_frames, self.params_dim[param_name])
            param.weight.data.fill_(0)
            param.weight.requires_grad = False
            setattr(self, param_name, param)

    def init_parameters(self, param_name, data, requires_grad=False):
        getattr(self, param_name).weight.data = data[..., :self.params_dim[param_name]]
        getattr(self, param_name).weight.requires_grad = requires_grad

    def set_requires_grad(self, param_name, requires_grad=True):
        getattr(self, param_name).weight.requires_grad = requires_grad

    def forward(self, frame_ids):
        params = {}
        for param_name in self.param_names:
            if param_name == 'betas':
                # Every frame looks up the single shared betas entry at index 0.
                params[param_name] = getattr(self, param_name)(torch.zeros_like(frame_ids))
            else:
                params[param_name] = getattr(self, param_name)(frame_ids)
        return params
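

# Minimal usage sketch, not part of the original module: the frame count,
# frame indices, and the choice to optimize only 'body_pose' are assumptions
# made purely for illustration.
if __name__ == '__main__':
    body_params = BodyModelParams(num_frames=100, model_type='smpl')
    # Allow the per-frame body pose to be optimized; all other tables stay frozen.
    body_params.set_requires_grad('body_pose', requires_grad=True)

    frame_ids = torch.tensor([0, 1, 2], dtype=torch.long)
    params = body_params(frame_ids)
    # Expected shapes: betas (3, 10), body_pose (3, 69), global_orient (3, 3), transl (3, 3).
    print({name: tuple(p.shape) for name, p in params.items()})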