IF3D / code /lib /model /body_model_params.py
leobcc's picture
vid2avatar baseline
6325697
raw
history blame
1.89 kB
import torch
import torch.nn as nn
class BodyModelParams(nn.Module):
    """Per-frame body-model parameter tables stored as ``nn.Embedding`` layers.

    One embedding is registered per parameter name, as an attribute of the
    same name. ``betas`` has a single row that is shared by every frame; all
    other parameters have one row per frame. Every table is zero-initialised
    and frozen (``requires_grad=False``) until ``init_parameters`` or
    ``set_requires_grad`` says otherwise.

    Args:
        num_frames: Number of rows allocated for each per-frame parameter.
        model_type: Body model flavour. Only ``'smpl'`` is supported.

    Raises:
        ValueError: If ``model_type`` is not ``'smpl'``.
    """

    def __init__(self, num_frames, model_type='smpl'):
        super().__init__()
        self.num_frames = num_frames
        self.model_type = model_type

        # Width (last dimension) of each parameter table.
        self.params_dim = {
            'betas': 10,
            'global_orient': 3,
            'transl': 3,
        }
        if model_type == 'smpl':
            self.params_dim.update({
                'body_pose': 69,
            })
        else:
            # Must be `raise`: `assert ValueError(...)` always passes because
            # an exception instance is truthy.
            raise ValueError(f'Unknown model type {model_type}, exiting!')

        # Stored as a list rather than a live dict_keys view, which cannot be
        # pickled or deep-copied together with the module.
        self.param_names = list(self.params_dim)

        for param_name in self.param_names:
            # 'betas' is shared across frames, so it needs a single row only.
            num_rows = 1 if param_name == 'betas' else num_frames
            param = nn.Embedding(num_rows, self.params_dim[param_name])
            param.weight.data.fill_(0)
            param.weight.requires_grad = False
            setattr(self, param_name, param)

    def init_parameters(self, param_name, data, requires_grad=False):
        """Overwrite the table of ``param_name`` with ``data``.

        ``data`` is truncated along its last dimension to the width
        configured for this parameter in ``self.params_dim``.

        Args:
            param_name: Name of a registered parameter table.
            data: Tensor holding the new values.
            requires_grad: Whether the table should be trainable afterwards.
        """
        weight = getattr(self, param_name).weight
        weight.data = data[..., :self.params_dim[param_name]]
        weight.requires_grad = requires_grad

    def set_requires_grad(self, param_name, requires_grad=True):
        """Enable or disable gradient tracking for one parameter table."""
        getattr(self, param_name).weight.requires_grad = requires_grad

    def forward(self, frame_ids):
        """Look up every parameter for the given frame indices.

        Args:
            frame_ids: Integer index tensor selecting rows of the per-frame
                tables.

        Returns:
            Dict mapping each parameter name to the embedding output for
            ``frame_ids``. ``betas`` always reads row 0, once per requested
            frame id.
        """
        params = {}
        for param_name in self.param_names:
            table = getattr(self, param_name)
            if param_name == 'betas':
                # Single shared row: index 0 for every requested frame.
                params[param_name] = table(torch.zeros_like(frame_ids))
            else:
                params[param_name] = table(frame_ids)
        return params