# https://www.mixedbread.ai/blog/mxbai-embed-large-v1
# https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1

import os
import pandas as pd
import numpy as np
from typing import Dict

import torch
from transformers import AutoModel, AutoTokenizer
from sentence_transformers.util import cos_sim
from accelerate import Accelerator
from scipy.stats import zscore

# Set up environment variables for Hugging Face caching
os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"

# Initialize the Accelerator
accelerator = Accelerator()

# Use the device managed by Accelerator
device = accelerator.device
print("Using accelerator device =", device)


# 1. Load the model and tokenizer
model_id_Retriever = 'mixedbread-ai/mxbai-embed-large-v1'
tokenizer_Retriever = AutoTokenizer.from_pretrained(model_id_Retriever)
modelRetriever = AutoModel.from_pretrained(model_id_Retriever)

# Accelerate prepares the model (e.g., moves to the appropriate device)
modelRetriever = accelerator.prepare(modelRetriever)
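
# Note: from_pretrained already returns the model in eval mode; accelerator.prepare
# mainly handles device placement here (and DDP wrapping when run via 'accelerate launch').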




# Define the transform_query function
def transform_query(queryText: str) -> str:
    """For retrieval, add the prompt for queryText (not for documents)."""
    return f'Represent this sentence for searching relevant passages: {queryText}'
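
# Illustrative example: transform_query('what is bread?') returns the string
# 'Represent this sentence for searching relevant passages: what is bread?'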

# Define the pooling function: the model works well with cls pooling (default) but also with mean pooling
def pooling(outputs: torch.Tensor, inputs: Dict, strategy: str = 'cls') -> np.ndarray:
    if strategy == 'cls':
        outputs = outputs[:, 0]
    elif strategy == 'mean':
        outputs = torch.sum(
            outputs * inputs["attention_mask"][:, :, None], dim=1
        ) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
    else:
        raise NotImplementedError
    return outputs.detach().cpu().numpy()
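
# Minimal usage sketch (assumes 'outputs' and 'inputs' come from a forward pass like
# the one in retrievePassageSimilarities below); both strategies yield one vector per input:
#
#   emb_cls = pooling(outputs, inputs, 'cls')    # embedding of the [CLS] token
#   emb_mean = pooling(outputs, inputs, 'mean')  # attention-masked mean over tokens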


def retrievePassageSimilarities(queryText, passages):
    # Create the docs list by adding the transformed queryText and then the passages
    docs = [transform_query(queryText)] + passages

    # 2. Encode the inputs
    inputs = tokenizer_Retriever(docs, padding=True, truncation=True, return_tensors='pt')

    # Move inputs to the right device using accelerator
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = modelRetriever(**inputs).last_hidden_state
    embeddings = pooling(outputs, inputs, 'cls')

    similarities = cos_sim(embeddings[0], embeddings[1:])

    #print('similarities:', similarities)

    return similarities
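
# Hedged sketch, not part of the original pipeline: encoding every passage in a single
# tokenizer call can exhaust GPU memory for large corpora. The hypothetical helper below
# encodes 'batch_size' passages at a time; the query is embedded once and compared
# against each batch, and the per-batch similarity rows are concatenated.
def retrievePassageSimilaritiesBatched(queryText, passages, batch_size=32):
    # Embed the query once
    query_inputs = tokenizer_Retriever([transform_query(queryText)], padding=True,
                                       truncation=True, return_tensors='pt')
    query_inputs = {k: v.to(device) for k, v in query_inputs.items()}
    with torch.no_grad():
        query_emb = pooling(modelRetriever(**query_inputs).last_hidden_state,
                            query_inputs, 'cls')
    # Embed the passages in chunks and accumulate one row of similarities per chunk
    sims = []
    for i in range(0, len(passages), batch_size):
        batch = passages[i:i + batch_size]
        inputs = tokenizer_Retriever(batch, padding=True, truncation=True,
                                     return_tensors='pt')
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            emb = pooling(modelRetriever(**inputs).last_hidden_state, inputs, 'cls')
        sims.append(cos_sim(query_emb, emb))  # tensor of shape (1, len(batch))
    return torch.cat(sims, dim=1)  # shape (1, len(passages)), like cos_sim above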



def RAG_retrieval_Base(queryText, passages, min_threshold=0.0, max_num_passages=None):
    try:
        similarities=retrievePassageSimilarities(queryText, passages)

        #Create a DataFrame
        df = pd.DataFrame({
            'Passage': passages,
            'Similarity': similarities.flatten()  # Flatten the similarity tensor/array to ensure compatibility
        })

        # Filter the DataFrame based on the similarity threshold
        df_filtered = df[df['Similarity'] >= min_threshold]

        # If max_num_passages is specified, limit the number of passages returned
        if max_num_passages is not None:
            df_filtered = df_filtered.nlargest(max_num_passages, 'Similarity')

        df_filtered = df_filtered.sort_values(by='Similarity', ascending=False)

        # Return the filtered DataFrame
        return df_filtered

    except Exception as e:
        # Log the exception message or handle it as needed
        print(f"An error occurred: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of error
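
# Note: cos_sim returns cosine similarities in [-1, 1], so the default min_threshold=0.0
# only discards passages whose embedding points away from the query embedding.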



def RAG_retrieval_Z_scores(queryText, passages, z_threshold=1.0, max_num_passages=None, min_threshold=0.5):
    try:
        # Encode the query and passages and compute cosine similarities
        similarities = retrievePassageSimilarities(queryText, passages)

        # Calculate z-scores for similarities
        z_scores = zscore(similarities.flatten())

        # Create a DataFrame with passages, similarities, and z-scores
        df = pd.DataFrame({
            'Passage': passages,
            'Similarity': similarities.flatten(),
            'Z-Score': z_scores
        })

        # Filter passages based on z-score threshold
        df_filtered = df[df['Z-Score'] >= z_threshold]

        if min_threshold is not None:
            # Also enforce the minimum absolute similarity, on top of the z-score filter
            df_filtered = df_filtered[df_filtered['Similarity'] >= min_threshold]

        # If max_num_passages is specified, limit the number of passages returned
        if max_num_passages is not None:
            df_filtered = df_filtered.nlargest(max_num_passages, 'Similarity')

        # Sort by similarity (or z-score if preferred)
        df_filtered = df_filtered.sort_values(by='Similarity', ascending=False)

        return df_filtered

    except Exception as e:
        # Log the exception message or handle it as needed
        print(f"An error occurred: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of error
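
# Note on the z-score filter: zscore standardizes the similarity distribution, so
# z_threshold=1.0 keeps passages scoring more than one standard deviation above the mean.
# For example, similarities [0.9, 0.5, 0.4, 0.4] have mean 0.55 and std ~0.21, so only
# the 0.9 passage (z ~1.7) clears z_threshold=1.0.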




def RAG_retrieval_Percentile(queryText, passages, percentile=90, max_num_passages=None, min_threshold=0.5):
    try:
        # Encode the query and passages and compute cosine similarities
        similarities = retrievePassageSimilarities(queryText, passages)

        # Determine threshold based on percentile
        threshold = np.percentile(similarities.flatten(), percentile)

        # Create a DataFrame
        df = pd.DataFrame({
            'Passage': passages,
            'Similarity': similarities.flatten()
        })

        # Filter using percentile threshold
        df_filtered = df[df['Similarity'] >= threshold]

        if min_threshold is not None:
            # Also enforce the minimum absolute similarity, on top of the percentile filter
            df_filtered = df_filtered[df_filtered['Similarity'] >= min_threshold]

        # If max_num_passages is specified, limit the number of passages returned
        if max_num_passages is not None:
            df_filtered = df_filtered.nlargest(max_num_passages, 'Similarity')

        # Sort by similarity
        df_filtered = df_filtered.sort_values(by='Similarity', ascending=False)

        return df_filtered

    except Exception as e:
        # Log the exception message or handle it as needed
        print(f"An error occurred: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of error
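
# Note: np.percentile interpolates, so with percentile=90 and only four passages the
# threshold lands between the two highest similarities and typically only the single
# best passage survives the percentile filter.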



def RAG_retrieval_TopK(queryText, passages, top_fraction=0.1, max_num_passages=None, min_threshold=0.5):
    try:
        # Encode the query and passages and compute cosine similarities
        similarities = retrievePassageSimilarities(queryText, passages)

        # Calculate the number of passages to select based on top fraction
        num_passages_TopFraction = max(1, int(top_fraction * len(passages)))

        # Create a DataFrame
        df = pd.DataFrame({
            'Passage': passages,
            'Similarity': similarities.flatten()
        })

        # Select the top passages dynamically
        df_filtered = df.nlargest(num_passages_TopFraction, 'Similarity')

        if min_threshold is not None:
            # Filter the selected passages on the minimum similarity threshold
            df_filtered = df_filtered[df_filtered['Similarity'] >= min_threshold]

        # If max_num_passages is specified, limit the number of passages returned
        if max_num_passages is not None:
            df_filtered = df_filtered.nlargest(max_num_passages, 'Similarity')

        # Sort by similarity
        df_filtered = df_filtered.sort_values(by='Similarity', ascending=False)

        return df_filtered

    except Exception as e:
        # Log the exception message or handle it as needed
        print(f"An error occurred: {e}")
        return pd.DataFrame()  # Return an empty DataFrame in case of error
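
# Note: num_passages_TopFraction is floored at 1, so with top_fraction=0.1 and four
# passages max(1, int(0.1 * 4)) == 1 and the single best passage is still considered
# before the min_threshold filter is applied.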



if __name__ == '__main__':

    queryText = 'A man is eating a piece of bread'

    # Define the passages list
    passages = [
        "A man is eating food.",
        "A man is eating pasta.",
        "The girl is carrying a baby.",
        "A man is riding a horse.",
    ]

    #df_retrieved = RAG_retrieval_Base(queryText, passages)
    #df_retrieved = RAG_retrieval_Base(queryText, passages, min_threshold=0.5)
    #df_retrieved = RAG_retrieval_Base(queryText, passages, max_num_passages=3)
    df_retrieved = RAG_retrieval_Base(queryText, passages, min_threshold=0.5, max_num_passages=3)

    #df_retrieved = RAG_retrieval_Z_scores(queryText, passages, z_threshold=1.0)
    #df_retrieved = RAG_retrieval_Z_scores(queryText, passages, z_threshold=1.0,max_num_passages=3)

    #df_retrieved = RAG_retrieval_Percentile(queryText, passages, percentile=80)
    # df_retrieved = RAG_retrieval_Percentile(queryText, passages, percentile=80, max_num_passages=3)

    #df_retrieved = RAG_retrieval_TopK(queryText, passages, top_fraction=0.2)
    #df_retrieved = RAG_retrieval_TopK(queryText, passages, top_fraction=0.2, max_num_passages=3)


    print(df_retrieved)

    #labelTriplesLIST_RAGGED = df_retrieved['Passage'].apply(lambda x: (x,)).tolist()


    print("end of computations")

# VERSION WITHOUT ACCELERATE
#
# #https://www.mixedbread.ai/blog/mxbai-embed-large-v1
# #https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
#
# import os
#
# os.environ["HF_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
# os.environ["HUGGINGFACE_HUB_CACHE"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
# os.environ["HF_HOME"] = "/eos/jeodpp/home/users/consose/cache/huggingface/hub"
#
# import pandas as pd
# import numpy as np
#
# from typing import Dict
#
# import torch
# from transformers import AutoModel, AutoTokenizer
# from sentence_transformers.util import cos_sim
#
# # For retrieval you need to pass this prompt. Please find our more in our blog post.
# def transform_queryText(queryText: str) -> str:
#     """ For retrieval, add the prompt for queryText (not for documents).
#     """
#     return f'Represent this sentence for searching relevant passages: {queryText}'
#
# # The model works really well with cls pooling (default) but also with mean pooling.
# def pooling(outputs: torch.Tensor, inputs: Dict,  strategy: str = 'cls') -> np.ndarray:
#     if strategy == 'cls':
#         outputs = outputs[:, 0]
#     elif strategy == 'mean':
#         outputs = torch.sum(
#             outputs * inputs["attention_mask"][:, :, None], dim=1) / torch.sum(inputs["attention_mask"], dim=1, keepdim=True)
#     else:
#         raise NotImplementedError
#     return outputs.detach().cpu().numpy()
#
# # 1. load model
# model_id = 'mixedbread-ai/mxbai-embed-large-v1'
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModel.from_pretrained(model_id).cuda()
#
# queryText = 'A man is eating a piece of bread'
#
# # Define the passages list
# passages = [
#     "A man is eating food.",
#     "A man is eating pasta.",
#     "The girl is carrying a baby.",
#     "A man is riding a horse.",
# ]
#
# # Create the docs list by adding the transformed queryText and then the passages
# docs = [transform_queryText(queryText)] + passages
#
# # 2. encode
# inputs = tokenizer(docs, padding=True, return_tensors='pt')
# for k, v in inputs.items():
#     inputs[k] = v.cuda()
# with torch.no_grad():
#     outputs = model(**inputs).last_hidden_state
# embeddings = pooling(outputs, inputs, 'cls')
#
# similarities = cos_sim(embeddings[0], embeddings[1:])
#
# print('similarities:', similarities)
#
#
# # Create a DataFrame
# df = pd.DataFrame({
#     'Passage': passages,
#     'Similarity': similarities.flatten()  # Flatten the similarity tensor/array to ensure compatibility
# })
#
# # Display the DataFrame
# print(df)
#
#
# print("end of computations")