import torch
import torch.nn as nn


class BodyModelParams(nn.Module):
    """Per-frame body model parameters stored as embedding tables indexed by frame id."""

    def __init__(self, num_frames, model_type='smpl'):
        super(BodyModelParams, self).__init__()
        self.num_frames = num_frames
        self.model_type = model_type
        self.params_dim = {
            'betas': 10,
            'global_orient': 3,
            'transl': 3,
        }
        if model_type == 'smpl':
            self.params_dim.update({
                'body_pose': 69,
            })
        else:
            raise ValueError(f'Unknown model type {model_type}, exiting!')
        self.param_names = self.params_dim.keys()

        for param_name in self.param_names:
            if param_name == 'betas':
                # Shape parameters are shared across all frames: a single embedding row.
                param = nn.Embedding(1, self.params_dim[param_name])
            else:
                # Orientation, translation, and body pose vary per frame: one row per frame.
                param = nn.Embedding(num_frames, self.params_dim[param_name])
            param.weight.data.fill_(0)
            param.weight.requires_grad = False
            setattr(self, param_name, param)

    def init_parameters(self, param_name, data, requires_grad=False):
        getattr(self, param_name).weight.data = data[..., :self.params_dim[param_name]]
        getattr(self, param_name).weight.requires_grad = requires_grad

    def set_requires_grad(self, param_name, requires_grad=True):
        getattr(self, param_name).weight.requires_grad = requires_grad

    def forward(self, frame_ids):
        params = {}
        for param_name in self.param_names:
            if param_name == 'betas':
                # betas are shared, so every frame looks up row 0 of the embedding.
                params[param_name] = getattr(self, param_name)(torch.zeros_like(frame_ids))
            else:
                params[param_name] = getattr(self, param_name)(frame_ids)
        return params
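

# Minimal usage sketch (an assumption, not part of the original file): it shows how the
# embedding tables can be initialized and queried with a batch of frame ids. The frame
# count, tensor values, and printed shapes are illustrative only.
if __name__ == '__main__':
    body_params = BodyModelParams(num_frames=100, model_type='smpl')

    # Initialize the shared shape code and enable optimization of the per-frame poses.
    body_params.init_parameters('betas', torch.zeros(1, 10), requires_grad=False)
    body_params.set_requires_grad('body_pose', requires_grad=True)

    frame_ids = torch.tensor([0, 1, 2])
    params = body_params(frame_ids)
    # Expected shapes: betas (3, 10), global_orient (3, 3), transl (3, 3), body_pose (3, 69).
    for name, value in params.items():
        print(name, tuple(value.shape))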