import json
import torch
from torch import nn
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the model state_dict from safetensors
def load_model_safetensors(model, load_path="model.safetensors"):
    """Populate ``model`` with the tensors stored in a safetensors file.

    Args:
        model: Object whose ``load_state_dict`` receives the loaded tensors.
        load_path: Path of the ``.safetensors`` file to read.

    Returns:
        The same ``model`` object, after its state dict has been loaded.
    """
    model.load_state_dict(load_file(load_path))
    return model

###################
# JINA EMBEDDINGS
###################

# Jina configuration: token length for the Jina pipeline.
# NOTE(review): nothing visible in this file reads this constant;
# embeddings_predict_relevance uses the same value (1024) for the tokenizer's
# max_length -- keep the two in sync.
JINA_CONTEXT_LEN = 1024

# Adapter for embeddings
class Adapter(nn.Module):
    """Residual bottleneck: project to half width, apply ReLU, project back."""

    def __init__(self, hidden_size):
        """Build the down/up projections for inputs whose last dim is ``hidden_size``."""
        super().__init__()
        bottleneck = hidden_size // 2
        self.down_project = nn.Linear(hidden_size, bottleneck)
        self.activation = nn.ReLU()
        self.up_project = nn.Linear(bottleneck, hidden_size)

    def forward(self, x):
        """Return the bottleneck transformation of ``x`` added back onto ``x``."""
        residual = x
        transformed = self.up_project(self.activation(self.down_project(x)))
        return transformed + residual

# Pool by attention score
class AttentionPooling(nn.Module):
    """Collapse dim 0 of a tensor using a learned softmax weighting."""

    def __init__(self, hidden_size):
        """Create the learnable scoring vector of length ``hidden_size``."""
        super().__init__()
        self.attention_weights = nn.Parameter(torch.randn(hidden_size))

    def forward(self, hidden_states):
        """Pool ``hidden_states`` over its first dimension.

        Args:
            hidden_states: Tensor laid out as [seq_len, batch_size, hidden_size].

        Returns:
            Tensor with dim 0 summed out, i.e. [batch_size, hidden_size] for
            the layout above.
        """
        # One scalar score per position: dot product with the learned vector.
        position_scores = hidden_states @ self.attention_weights
        # Normalise across dim 0 (the sequence axis in the expected layout).
        position_probs = position_scores.softmax(dim=0).unsqueeze(-1)
        return (position_probs * hidden_states).sum(dim=0)

# Custom bi-encoder model with MLP layers for interaction
class CrossEncoderWithSharedBase(nn.Module):
    """Pair classifier built on one shared encoder plus cross-attention.

    Both inputs are encoded by the same ``base_model``, passed through their
    own adapter, attended against each other in both directions, pooled,
    concatenated, projected and finally classified.

    Submodule attribute names and creation order determine the state-dict
    keys, so they must stay as they are for saved weights to keep loading.

    Args:
        base_model: Encoder exposing ``config.hidden_size`` and returning an
            object with ``last_hidden_state`` when called.
        num_labels: Size of the output logit vector.
        num_heads: Number of heads in each ``nn.MultiheadAttention`` layer.
    """

    def __init__(self, base_model, num_labels=2, num_heads=8):
        super().__init__()
        # One encoder instance serves both inputs.
        self.shared_encoder = base_model
        width = self.shared_encoder.config.hidden_size

        # Input-specific adapters.
        self.adapter1 = Adapter(width)
        self.adapter2 = Adapter(width)

        # One cross-attention layer per direction.
        self.cross_attention_1_to_2 = nn.MultiheadAttention(width, num_heads)
        self.cross_attention_2_to_1 = nn.MultiheadAttention(width, num_heads)

        # One pooling layer per direction.
        self.attn_pooling_1_to_2 = AttentionPooling(width)
        self.attn_pooling_2_to_1 = AttentionPooling(width)

        # Fuse the two pooled vectors back down to the encoder width.
        self.projection_layer = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.ReLU(),
        )

        # Three linear layers (two hidden + output) with dropout in between.
        half, quarter = width // 2, width // 4
        self.classifier = nn.Sequential(
            nn.Linear(width, half),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(half, quarter),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(quarter, num_labels),
        )

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        """Return classification logits for the (input 1, input 2) pair.

        NOTE(review): the attention masks are handed to the shared encoder
        only; the cross-attention and pooling layers are called without any
        padding mask.
        """
        encoded1 = self.shared_encoder(input_ids1, attention_mask=attention_mask1)
        encoded2 = self.shared_encoder(input_ids2, attention_mask=attention_mask2)

        # Swap the first two axes for the attention layers; assumes
        # last_hidden_state is [batch, seq, hidden] -- TODO confirm.
        seq1 = self.adapter1(encoded1.last_hidden_state).transpose(0, 1)
        seq2 = self.adapter2(encoded2.last_hidden_state).transpose(0, 1)

        # Each side queries the other; the attention-weight outputs are unused.
        attended_1_to_2 = self.cross_attention_1_to_2(seq1, seq2, seq2)[0]
        attended_2_to_1 = self.cross_attention_2_to_1(seq2, seq1, seq1)[0]

        pooled_pair = torch.cat(
            (
                self.attn_pooling_1_to_2(attended_1_to_2),
                self.attn_pooling_2_to_1(attended_2_to_1),
            ),
            dim=1,
        )
        return self.classifier(self.projection_layer(pooled_pair))

# Prediction function for embeddings relevance
def embeddings_predict_relevance(sentence1, sentence2, model, tokenizer, device, *, max_length=1024):
    """Classify one (sentence1, sentence2) pair with the two-input model.

    Args:
        sentence1: First text; tokenized on its own.
        sentence2: Second text; tokenized on its own.
        model: Module called with ``input_ids1``, ``attention_mask1``,
            ``input_ids2`` and ``attention_mask2`` that returns logits of
            shape [batch, num_labels].
        tokenizer: Callable returning a mapping with ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device on which inference runs. The model is moved there too,
            so it always matches its inputs.
        max_length: Length both inputs are truncated/padded to. Defaults to
            1024, the value that was previously hard-coded.

    Returns:
        ``(predicted_label, probabilities)``: the arg-max class index as a
        Python int, and the softmax probabilities as a NumPy array.
        ``.item()`` is used for the label, so one pair per call.
    """
    # Inputs are placed on `device` below; put the model there as well so a
    # model still sitting on another device does not trigger a mismatch error.
    # This is a no-op when the model is already on `device`.
    model.to(device)
    model.eval()

    def _encode(text):
        # Same tokenizer settings for both inputs.
        batch = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )
        return batch['input_ids'].to(device), batch['attention_mask'].to(device)

    input_ids1, attention_mask1 = _encode(sentence1)
    input_ids2, attention_mask2 = _encode(sentence2)

    with torch.no_grad():
        logits = model(input_ids1=input_ids1, attention_mask1=attention_mask1,
                       input_ids2=input_ids2, attention_mask2=attention_mask2)
        probabilities = torch.softmax(logits, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Resolve the local path of the Jina off-topic classifier's config.json via the Hub
jina_repo_path = "govtech/jina-embeddings-v2-small-en-off-topic"
jina_config_path = hf_hub_download(repo_id=jina_repo_path, filename="config.json")
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# locale default encoding
with open(jina_config_path, 'r', encoding='utf-8') as f:
    jina_config = json.load(f)

# The config names both the base encoder and the fine-tuned weights file
JINA_MODEL_NAME = jina_config['classifier']['embedding']['model_name']
jina_model_weights_fp = jina_config['classifier']['embedding']['model_weights_fp']

# Build the tokenizer and base encoder, then wrap the encoder in the classifier
jina_tokenizer = AutoTokenizer.from_pretrained(JINA_MODEL_NAME)
jina_base_model = AutoModel.from_pretrained(JINA_MODEL_NAME)
jina_model = CrossEncoderWithSharedBase(jina_base_model, num_labels=2)

# Load the weights file named in the config (safetensors) into the wrapper
jina_model_weights_path = hf_hub_download(repo_id=jina_repo_path, filename=jina_model_weights_fp)
jina_model = load_model_safetensors(jina_model, jina_model_weights_path)

#################
# CROSS-ENCODER
#################

# STSB configuration: token length for the STSB cross-encoder pipeline.
# NOTE(review): nothing visible in this file reads this constant;
# cross_encoder_predict_relevance uses the same value (512) for the
# tokenizer's max_length -- keep the two in sync.
STSB_CONTEXT_LEN = 512

class CrossEncoderWithMLP(nn.Module):
    """Base encoder followed by a two-layer MLP and a linear classifier head.

    Attribute names (``base_model``, ``mlp``, ``classifier``) determine the
    state-dict keys, so they must stay as they are for saved weights to keep
    loading.

    Args:
        base_model: Encoder exposing ``config.hidden_size`` and returning an
            object with ``pooler_output`` when called.
        num_labels: Size of the output logit vector.
    """

    def __init__(self, base_model, num_labels=2):
        super().__init__()
        self.base_model = base_model
        width = base_model.config.hidden_size
        half, quarter = width // 2, width // 4

        # Shrinks the pooled encoder output: width -> width/2 -> width/4.
        self.mlp = nn.Sequential(
            nn.Linear(width, half),
            nn.ReLU(),
            nn.Linear(half, quarter),
            nn.ReLU(),
        )
        # Maps the MLP features to class logits.
        self.classifier = nn.Linear(quarter, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return logits computed from the base model's ``pooler_output``."""
        # Both tensors are passed positionally to the base model.
        pooled = self.base_model(input_ids, attention_mask).pooler_output
        return self.classifier(self.mlp(pooled))

# Prediction function for cross-encoder
def cross_encoder_predict_relevance(sentence1, sentence2, model, tokenizer, device, *, max_length=512):
    """Classify one (sentence1, sentence2) pair with the single-pass cross-encoder.

    Args:
        sentence1: First text of the pair.
        sentence2: Second text of the pair; tokenized together with the first.
        model: Module called with ``input_ids`` and ``attention_mask`` that
            returns logits of shape [batch, num_labels].
        tokenizer: Callable accepting the two texts and returning a mapping
            with ``input_ids`` and ``attention_mask`` tensors.
        device: Device on which inference runs. The model is moved there too,
            so it always matches its inputs.
        max_length: Length the joint encoding is truncated/padded to.
            Defaults to 512, the value that was previously hard-coded.

    Returns:
        ``(predicted_label, probabilities)``: the arg-max class index as a
        Python int, and the softmax probabilities as a NumPy array.
        ``.item()`` is used for the label, so one pair per call.
    """
    # Inputs are placed on `device` below; put the model there as well so a
    # model still sitting on another device does not trigger a mismatch error.
    # This is a no-op when the model is already on `device`.
    model.to(device)
    model.eval()

    # Tokenize the two sentences together as a single pair
    encoding = tokenizer(
        sentence1, sentence2,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_token_type_ids=False
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Turn raw logits into per-class probabilities, then pick the arg-max
        probabilities = torch.softmax(logits, dim=1)
        predicted_label = torch.argmax(probabilities, dim=1).item()
    return predicted_label, probabilities.cpu().numpy()

# Resolve the local path of the STSB off-topic classifier's config.json via the Hub
stsb_repo_path = "govtech/stsb-roberta-base-off-topic"
stsb_config_path = hf_hub_download(repo_id=stsb_repo_path, filename="config.json")
# Decode explicitly as UTF-8 so parsing does not depend on the platform's
# locale default encoding
with open(stsb_config_path, 'r', encoding='utf-8') as f:
    stsb_config = json.load(f)

# The config names both the base encoder and the fine-tuned weights file
STSB_MODEL_NAME = stsb_config['classifier']['embedding']['model_name']
stsb_model_weights_fp = stsb_config['classifier']['embedding']['model_weights_fp']

# Build the tokenizer (slow implementation requested via use_fast=False) and
# base encoder, then wrap the encoder in the MLP classifier
stsb_tokenizer = AutoTokenizer.from_pretrained(STSB_MODEL_NAME, use_fast=False)
stsb_base_model = AutoModel.from_pretrained(STSB_MODEL_NAME)
stsb_model = CrossEncoderWithMLP(stsb_base_model, num_labels=2)

# Load the weights file named in the config (safetensors) into the wrapper
stsb_model_weights_path = hf_hub_download(repo_id=stsb_repo_path, filename=stsb_model_weights_fp)
stsb_model = load_model_safetensors(stsb_model, stsb_model_weights_path)