import torch
import torch.nn as nn


class Model(nn.Module):
    """Score every image token against a single text embedding.

    Both encoder outputs are passed through a kernel-size-1 ``Conv1d`` into a
    shared ``embed_dim``-wide space. Each remaining image token is then
    dot-multiplied with the first text token, and the scores go through a
    1-channel kernel-size-1 ``Conv1d`` that is initialised to the identity
    (weight 1, bias 0).
    """

    def __init__(
        self,
        vit: nn.Module,
        roberta: nn.Module,
        tokenizer,
        device,
        *,
        vit_dim: int = 768,
        text_dim: int = 768,
        embed_dim: int = 768,
    ):
        """Build the projection layers and register the two encoders.

        Args:
            vit: Image encoder, called as ``vit(image)``.
            roberta: Text encoder, called with the tokenizer output unpacked
                as keyword arguments; its result must expose
                ``last_hidden_state``.
            tokenizer: Object providing ``encode_plus``.
            device: Device the tokenizer output is moved to in ``forward``.
            vit_dim: Feature width of the ``vit`` output (default 768).
            text_dim: Feature width of ``roberta``'s hidden states
                (default 768).
            embed_dim: Width of the shared space both are projected into
                (default 768).
        """
        super().__init__()
        # Kernel-size-1 convolutions act as per-token linear projections.
        self.bertmap = nn.Conv1d(text_dim, embed_dim, 1)
        self.vitmap = nn.Conv1d(vit_dim, embed_dim, 1)
        self.conv1d = nn.Conv1d(1, 1, 1)
        self.add_module("vit", vit)
        self.add_module("roberta", roberta)
        self.tokenizer = tokenizer
        # Start the output layer as an identity map (weight 1, bias 0); both
        # values remain trainable parameters.
        self.conv1d.weight = nn.Parameter(torch.ones(1, 1, 1))
        self.conv1d.bias = nn.Parameter(torch.zeros(1))
        # NOTE(review): this is only set here, so it is not updated if the
        # module is later moved with ``.to(...)`` -- confirm callers keep the
        # two in sync.
        self.device = device

    @staticmethod
    def _project(conv: nn.Conv1d, tokens: torch.Tensor) -> torch.Tensor:
        """Apply ``conv`` along the last axis of a (batch, tokens, features) tensor."""
        # Conv1d expects channels on axis 1, so swap in and back out.
        return conv(tokens.transpose(1, 2)).transpose(1, 2)

    def forward(self, image, cats) -> torch.Tensor:
        """Return one score per image token for the text ``cats``.

        Args:
            image: Input forwarded unchanged to ``self.vit``.
            cats: Text forwarded unchanged to ``tokenizer.encode_plus``.

        Returns:
            Tensor with the size-1 channel axis squeezed out, i.e.
            (batch, image tokens minus the first).
        """
        # assumes vit returns (batch, tokens, vit_dim) -- TODO confirm.
        # The first token is dropped (presumably the class token).
        vit_out = self.vit(image)[:, 1:]
        vit_out = self._project(self.vitmap, vit_out)

        token_out = self.tokenizer.encode_plus(
            cats,
            padding=True,
            add_special_tokens=True,
            return_token_type_ids=True,
            return_tensors='pt',
        ).to(self.device)
        bert_out = self.roberta(**token_out)
        hidden_state = self._project(self.bertmap, bert_out.last_hidden_state)

        # Use the first text token as the pooled text vector, shaped as a
        # column so matmul yields one scalar per image token.
        pooled_bert_out = hidden_state[:, 0].unsqueeze(2)

        out = torch.matmul(vit_out, pooled_bert_out)
        # Move the singleton score axis to the channel position for Conv1d.
        out = out.transpose(1, 2)
        return torch.squeeze(self.conv1d(out), dim=1)