File size: 1,889 Bytes
6325697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import torch.nn as nn

class BodyModelParams(nn.Module):
    """Body-model parameters stored as zero-initialised ``nn.Embedding`` tables.

    Every parameter is an embedding table filled with zeros and frozen
    (``requires_grad=False``) until :meth:`init_parameters` or
    :meth:`set_requires_grad` enables gradients. ``betas`` has a single row
    shared by all frames; every other parameter has one row per frame.

    Args:
        num_frames: Number of rows in each per-frame table.
        model_type: Body-model variant; only ``'smpl'`` is supported.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """

    def __init__(self, num_frames, model_type='smpl'):
        super(BodyModelParams, self).__init__()
        self.num_frames = num_frames
        self.model_type = model_type
        # Feature width (last-axis size) of each parameter table.
        self.params_dim = {
            'betas': 10,
            'global_orient': 3,
            'transl': 3,
        }
        if model_type == 'smpl':
            self.params_dim.update({
                'body_pose': 69,
            })
        else:
            # Fail fast: otherwise the module would be built without a pose
            # table and only break later, far from the real cause.
            raise ValueError(f'Unknown model type {model_type}, exiting!')

        self.param_names = self.params_dim.keys()

        for param_name, dim in self.params_dim.items():
            # `betas` is shared across all frames, so it needs only one row.
            num_rows = 1 if param_name == 'betas' else num_frames
            param = nn.Embedding(num_rows, dim)
            param.weight.data.fill_(0)
            param.weight.requires_grad = False
            setattr(self, param_name, param)

    def init_parameters(self, param_name, data, requires_grad=False):
        """Replace the table for ``param_name`` with ``data``.

        Only the first ``params_dim[param_name]`` entries along the last axis
        of ``data`` are kept.

        NOTE(review): the embedding's ``num_embeddings`` attribute is not
        updated here, so ``data`` presumably should have the same number of
        rows as the existing table. Confirm with callers.

        Args:
            param_name: One of ``param_names``.
            data: Tensor whose last axis is at least the parameter's width.
            requires_grad: Whether the new weight should receive gradients.
        """
        embedding = getattr(self, param_name)
        embedding.weight.data = data[..., :self.params_dim[param_name]]
        embedding.weight.requires_grad = requires_grad

    def set_requires_grad(self, param_name, requires_grad=True):
        """Enable or disable gradients for the table of ``param_name``."""
        getattr(self, param_name).weight.requires_grad = requires_grad

    def forward(self, frame_ids):
        """Look up every parameter for the given frame indices.

        Args:
            frame_ids: Integer index tensor into the per-frame tables.

        Returns:
            Dict mapping each name in ``param_names`` to a tensor of shape
            ``frame_ids.shape + (dim,)``. ``betas`` always reads row 0, so it
            is the same vector repeated for every requested frame.
        """
        params = {}
        for param_name in self.param_names:
            if param_name == 'betas':
                # Single shared row: look up index 0 for every frame.
                ids = torch.zeros_like(frame_ids)
            else:
                ids = frame_ids
            params[param_name] = getattr(self, param_name)(ids)
        return params