# multimodalart's picture
# Upload 83 files
# 38e20ed verified
import torch
import torch.nn.functional as F
from torch import nn
import timm
def create_backbone(backbone_name, pretrained=True):
    """Build a timm feature-extraction backbone.

    Args:
        backbone_name: Model name handed to ``timm.create_model``.
        pretrained: Forwarded unchanged as the ``pretrained`` argument.

    Returns:
        A ``(backbone, feature_dim)`` tuple. ``feature_dim`` is the
        ``'num_chs'`` entry of the last item in ``backbone.feature_info``.
    """
    model_kwargs = {"pretrained": pretrained, "features_only": True}
    net = timm.create_model(backbone_name, **model_kwargs)
    last_stage_info = net.feature_info[-1]
    return net, last_stage_info['num_chs']
class PoseEncoder(nn.Module):
    """Regress six pose/camera values from an image.

    The last feature map of a ``tf_mobilenetv3_small_minimal_100`` backbone is
    average-pooled to one vector per sample and fed to a single linear layer
    with 6 outputs. Outputs 0-2 are returned as ``'pose_params'`` and outputs
    3-5 as ``'cam'``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder, feature_dim = create_backbone('tf_mobilenetv3_small_minimal_100')
        self.pose_cam_layers = nn.Sequential(
            nn.Linear(feature_dim, 6)
        )
        self.init_weights()

    def init_weights(self):
        """Scale the head's initial weights by 0.001 and pin output 3 to 7."""
        head = self.pose_cam_layers[-1]
        # Edit the parameters in place under no_grad() rather than through
        # the discouraged `.data` attribute.
        with torch.no_grad():
            head.weight.mul_(0.001)
            head.bias.mul_(0.001)
            # Output 3 (the first 'cam' value) starts independent of the
            # input: zero weights and a bias of 7.
            # NOTE(review): presumably an initial camera scale -- confirm.
            head.weight[3].zero_()
            head.bias[3] = 7

    def forward(self, img):
        """Return ``{'pose_params': ..., 'cam': ...}`` for a batch of images.

        Args:
            img: Input tensor whose first dimension is used as the batch
                size (assumed to be (N, C, H, W) -- TODO confirm).

        Returns:
            Dict with ``'pose_params'`` (first 3 head outputs) and ``'cam'``
            (remaining head outputs), each with leading dimension
            ``img.size(0)``.
        """
        # Only the deepest feature map from the backbone is used.
        features = self.encoder(img)[-1]
        # Global average pool, then drop the two singleton spatial dims.
        features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
        pose_cam = self.pose_cam_layers(features).reshape(img.size(0), -1)
        return {
            'pose_params': pose_cam[..., :3],
            'cam': pose_cam[..., 3:],
        }
class ShapeEncoder(nn.Module):
    """Regress ``n_shape`` shape values from an image.

    The last feature map of a ``tf_mobilenetv3_large_minimal_100`` backbone is
    average-pooled to one vector per sample and fed to a single linear layer
    with ``n_shape`` outputs. The layer is initialised to all zeros, so a
    freshly constructed encoder outputs zeros for every input.
    """

    def __init__(self, n_shape=300) -> None:
        super().__init__()
        self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
        self.shape_layers = nn.Sequential(
            nn.Linear(feature_dim, n_shape)
        )
        self.init_weights()

    def init_weights(self):
        """Zero the weight and bias of the output layer."""
        head = self.shape_layers[-1]
        # Explicit zero fill: avoids `.data`, and unlike `*= 0` it cannot
        # leave NaN behind where a value was non-finite (inf * 0 is NaN).
        nn.init.zeros_(head.weight)
        nn.init.zeros_(head.bias)

    def forward(self, img):
        """Return ``{'shape_params': tensor}`` for a batch of images.

        Args:
            img: Input tensor whose first dimension is used as the batch
                size (assumed to be (N, C, H, W) -- TODO confirm).

        Returns:
            Dict with a single key ``'shape_params'`` holding the head
            outputs reshaped to ``(img.size(0), -1)``.
        """
        # Only the deepest feature map from the backbone is used.
        features = self.encoder(img)[-1]
        # Global average pool, then drop the two singleton spatial dims.
        features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
        parameters = self.shape_layers(features).reshape(img.size(0), -1)
        return {'shape_params': parameters}
class ExpressionEncoder(nn.Module):
    """Regress expression, eyelid and jaw values from an image.

    The last feature map of a ``tf_mobilenetv3_large_minimal_100`` backbone is
    average-pooled to one vector per sample and fed to a single linear layer
    with ``n_exp + 5`` outputs, laid out as:

    * ``[0, n_exp)``           -- expression values, returned unchanged
    * ``[n_exp, n_exp + 2)``   -- eyelid values, clamped to ``[0, 1]``
    * ``n_exp + 2``            -- first jaw value, passed through ReLU
    * ``[n_exp + 3, n_exp + 5)`` -- remaining jaw values, clamped to
      ``[-0.2, 0.2]``
    """

    def __init__(self, n_exp=50) -> None:
        super().__init__()
        self.encoder, feature_dim = create_backbone('tf_mobilenetv3_large_minimal_100')
        # Head width: n_exp expression values + 2 eyelid values + 3 jaw
        # values, in that order (see forward()).
        self.expression_layers = nn.Sequential(
            nn.Linear(feature_dim, n_exp + 2 + 3)
        )
        self.n_exp = n_exp
        self.init_weights()

    def init_weights(self):
        """Scale the output layer's initial weight and bias by 0.1."""
        head = self.expression_layers[-1]
        # Edit the parameters in place under no_grad() rather than through
        # the discouraged `.data` attribute.
        with torch.no_grad():
            head.weight.mul_(0.1)
            head.bias.mul_(0.1)

    def forward(self, img):
        """Return expression, eyelid and jaw values for a batch of images.

        Args:
            img: Input tensor whose first dimension is used as the batch
                size (assumed to be (N, C, H, W) -- TODO confirm).

        Returns:
            Dict with ``'expression_params'`` (``n_exp`` values),
            ``'eyelid_params'`` (2 values in ``[0, 1]``) and ``'jaw_params'``
            (3 values: one non-negative, two in ``[-0.2, 0.2]``).
        """
        # Only the deepest feature map from the backbone is used.
        features = self.encoder(img)[-1]
        # Global average pool, then drop the two singleton spatial dims.
        features = F.adaptive_avg_pool2d(features, (1, 1)).squeeze(-1).squeeze(-1)
        parameters = self.expression_layers(features).reshape(img.size(0), -1)

        n_exp = self.n_exp
        eyelid = parameters[..., n_exp:n_exp + 2]
        # Width-1 slice keeps the trailing dim so it can be concatenated.
        jaw_first = parameters[..., n_exp + 2:n_exp + 3]
        jaw_rest = parameters[..., n_exp + 3:n_exp + 5]

        return {
            'expression_params': parameters[..., :n_exp],
            'eyelid_params': torch.clamp(eyelid, 0, 1),
            'jaw_params': torch.cat(
                [F.relu(jaw_first), torch.clamp(jaw_rest, -.2, .2)], dim=-1
            ),
        }
class SmirkEncoder(nn.Module):
    """Run the pose, shape and expression encoders and merge their outputs.

    Args:
        n_exp: Passed to ``ExpressionEncoder`` as ``n_exp``.
        n_shape: Passed to ``ShapeEncoder`` as ``n_shape``.
    """

    def __init__(self, n_exp=50, n_shape=300) -> None:
        super().__init__()
        self.pose_encoder = PoseEncoder()
        self.shape_encoder = ShapeEncoder(n_shape=n_shape)
        self.expression_encoder = ExpressionEncoder(n_exp=n_exp)

    def forward(self, img):
        """Return one dict holding the outputs of all three sub-encoders.

        Each sub-encoder receives the same ``img``. Results are merged in
        the order pose, shape, expression, so on a key collision the later
        encoder's value wins.
        """
        branches = (
            self.pose_encoder,
            self.shape_encoder,
            self.expression_encoder,
        )
        merged = {}
        for branch in branches:
            merged.update(branch(img))
        return merged