import torch.nn as nn

from .bert import BERT


class BERTSM(nn.Module):
    """
    BERT Sequence Model
    Masked Sequence Model
    """

    def __init__(self, bert: BERT, vocab_size):
        """
        :param bert: BERT model to be trained
        :param vocab_size: total vocab size for the masked sequence model
        """

        super().__init__()
        self.bert = bert
        self.mask_lm = MaskedSequenceModel(self.bert.hidden, vocab_size)
        
    def forward(self, x, segment_label):
        x = self.bert(x, segment_label)
        # return per-token vocabulary log-probabilities plus the
        # first-token (sentence-level) hidden state
        return self.mask_lm(x), x[:, 0]

    
class MaskedSequenceModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # project hidden states to vocabulary size and return log-probabilities
        return self.softmax(self.linear(x))
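

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module API: it assumes the local
    # BERT encoder can be constructed as BERT(vocab_size, hidden=..., n_layers=...,
    # attn_heads=...) and that its forward accepts (token_ids, segment_labels).
    # Shapes and hyperparameters below are illustrative only.
    import torch

    vocab_size, seq_len, batch_size = 1000, 16, 2
    bert = BERT(vocab_size, hidden=256, n_layers=2, attn_heads=4)
    model = BERTSM(bert, vocab_size)

    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    segments = torch.ones(batch_size, seq_len, dtype=torch.long)

    mask_lm_output, first_token_state = model(tokens, segments)
    # mask_lm_output: (batch_size, seq_len, vocab_size) log-probabilities,
    # which pair with nn.NLLLoss; first_token_state: (batch_size, hidden).
    targets = torch.randint(0, vocab_size, (batch_size, seq_len))
    loss = nn.NLLLoss()(mask_lm_output.transpose(1, 2), targets)
    print(mask_lm_output.shape, first_token_state.shape, loss.item())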