import torch
import torch.nn.functional as F
from torch import nn
import timm


def create_backbone(backbone_name, pretrained=True):
    """Create a timm backbone that exposes its intermediate feature maps."""
    backbone = timm.create_model(
        backbone_name,
        pretrained=pretrained,
        features_only=True,
    )
    # Channel count of the deepest feature map, used to size the linear heads.
    feature_dim = backbone.feature_info[-1]['num_chs']
    return backbone, feature_dim


class PoseEncoder(nn.Module):
    """Predicts 3 head-pose (rotation) parameters and 3 camera parameters."""

    def __init__(self) -> None:
        super().__init__()

        self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100')

        self.pose_cam_layers = nn.Sequential(
            nn.Linear(feature_dim, 6)  # 3 pose + 3 camera
        )

        self.init_weights()

    def init_weights(self):
        # Scale the head down so training starts close to the neutral pose.
        self.pose_cam_layers[-1].weight.data *= 0.001
        self.pose_cam_layers[-1].bias.data *= 0.001
        # Output index 3 is the first camera parameter: zero its weights and
        # set its bias so the camera prediction starts at 7.
        self.pose_cam_layers[-1].weight.data[3] = 0
        self.pose_cam_layers[-1].bias.data[3] = 7

    def forward(self, img):
        features = self.encoder(img)[-1]
        # Global average pooling: (B, C, H, W) -> (B, C).
        features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)

        pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1)

        outputs = {}
        outputs['pose_params'] = pose_cam[..., :3]
        outputs['cam'] = pose_cam[..., 3:]

        return outputs


class ShapeEncoder(nn.Module):
    """Predicts shape (identity) coefficients."""

    def __init__(self, n_shape=300) -> None:
        super().__init__()

        self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')

        self.shape_layers = nn.Sequential(
            nn.Linear(feature_dim, n_shape)
        )

        self.init_weights()

    def init_weights(self):
        # Zero-initialize so the predicted shape starts at the mean shape.
        self.shape_layers[-1].weight.data *= 0
        self.shape_layers[-1].bias.data *= 0

    def forward(self, img):
        features = self.encoder(img)[-1]
        # Global average pooling: (B, C, H, W) -> (B, C).
        features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)

        parameters = self.shape_layers(features).reshape(img.size(0), -1)
        return {'shape_params': parameters}


class ExpressionEncoder(nn.Module):
    """Predicts expression, eyelid, and jaw parameters."""

    def __init__(self, n_exp=50) -> None:
        super().__init__()

        self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')

        self.expression_layers = nn.Sequential(
            nn.Linear(feature_dim, n_exp + 2 + 3)  # n_exp expressions + 2 eyelid + 3 jaw
        )

        self.n_exp = n_exp

        self.init_weights()

    def init_weights(self):
        # Damp the initial predictions toward a neutral expression.
        self.expression_layers[-1].weight.data *= 0.1
        self.expression_layers[-1].bias.data *= 0.1

    def forward(self, img):
        features = self.encoder(img)[-1]
        # Global average pooling: (B, C, H, W) -> (B, C).
        features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)

        parameters = self.expression_layers(features).reshape(img.size(0), -1)

        outputs = {}
        outputs['expression_params'] = parameters[..., :self.n_exp]
        # Eyelid closure is constrained to [0, 1].
        outputs['eyelid_params'] = torch.clamp(parameters[..., self.n_exp:self.n_exp + 2], 0, 1)
        # Jaw opening (first component) must be non-negative; the two lateral
        # jaw components are clamped to [-0.2, 0.2].
        outputs['jaw_params'] = torch.cat(
            [
                F.relu(parameters[..., self.n_exp + 2].unsqueeze(-1)),
                torch.clamp(parameters[..., self.n_exp + 3:self.n_exp + 5], -.2, .2),
            ],
            dim=-1,
        )

        return outputs


class SmirkEncoder(nn.Module):
    """Combines the pose, shape, and expression encoders into one module."""

    def __init__(self, n_exp=50, n_shape=300) -> None:
        super().__init__()

        self.pose_encoder = PoseEncoder()
        self.shape_encoder = ShapeEncoder(n_shape=n_shape)
        self.expression_encoder = ExpressionEncoder(n_exp=n_exp)

    def forward(self, img):
        pose_outputs = self.pose_encoder(img)
        shape_outputs = self.shape_encoder(img)
        expression_outputs = self.expression_encoder(img)

        outputs = {}
        outputs.update(pose_outputs)
        outputs.update(shape_outputs)
        outputs.update(expression_outputs)

        return outputs
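

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the original module:
    # run a random batch through SmirkEncoder and print each output's shape.
    # The 224x224 input size is an assumption (any reasonable resolution
    # works, since each head global-average-pools its feature map); note that
    # constructing the encoders downloads pretrained timm weights on first use.
    encoder = SmirkEncoder(n_exp=50, n_shape=300)
    encoder.eval()

    img = torch.rand(2, 3, 224, 224)  # batch of 2 RGB crops

    with torch.no_grad():
        outputs = encoder(img)

    # Expected: pose_params (2, 3), cam (2, 3), shape_params (2, 300),
    # expression_params (2, 50), eyelid_params (2, 2), jaw_params (2, 3).
    for name, value in outputs.items():
        print(f"{name}: {tuple(value.shape)}")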