import streamlit as st
import gc
from collections import defaultdict
import torch
from transformers import pipeline
from lingua import Language,LanguageDetectorBuilder
#__version__ = "0.1.0"

# Device index for transformers.pipeline: 0 selects the first CUDA GPU,
# -1 selects the CPU.
device_tag = 0 if torch.cuda.is_available() else -1

# Multilingual model used for most of the languages listed below.
_MULTILINGUAL_MODEL = "lxyuan/distilbert-base-multilingual-cased-sentiments-student"

# Default HuggingFace sentiment model path for each lingua.Language.
default_models = {
    Language.ENGLISH: _MULTILINGUAL_MODEL,
    Language.JAPANESE: _MULTILINGUAL_MODEL,
    Language.ARABIC: "Ammar-alhaj-ali/arabic-MARBERT-sentiment",
    Language.GERMAN: _MULTILINGUAL_MODEL,
    Language.SPANISH: _MULTILINGUAL_MODEL,
    Language.FRENCH: _MULTILINGUAL_MODEL,
    Language.CHINESE: _MULTILINGUAL_MODEL,
    Language.INDONESIAN: _MULTILINGUAL_MODEL,
    Language.HINDI: _MULTILINGUAL_MODEL,
    Language.ITALIAN: _MULTILINGUAL_MODEL,
    Language.MALAY: _MULTILINGUAL_MODEL,
    Language.PORTUGUESE: _MULTILINGUAL_MODEL,
    Language.SWEDISH: "KBLab/robust-swedish-sentiment-multiclass",
    Language.FINNISH: "fergusq/finbert-finnsentiment",
}

# Language detector covering every language lingua supports; built once at
# import time.
language_detector = LanguageDetectorBuilder.from_all_languages().build()


def split_message(message, max_length):
    """
    Cut ``message`` into consecutive slices of at most ``max_length`` items.

    The final slice holds whatever remains and may be shorter; an empty
    message yields an empty list.
    """
    offsets = range(0, len(message), max_length)
    return [message[start:start + max_length] for start in offsets]


def _chunk_batch(batch, max_length):
    """
    Split every (index, message) pair of a batch into chunks.
    Returns:
    tuple (chunks, message_map): the flat list of chunks, and a dict mapping
    each message index to the range of positions its chunks occupy in chunks
    """
    chunks = []
    message_map = {}
    for idx, message in batch:
        first = len(chunks)
        chunks.extend(split_message(message, max_length))
        message_map[idx] = range(first, len(chunks))
    return chunks, message_map


def _merge_chunk_sentiments(chunk_sentiments):
    """
    Combine the pipeline outputs for the chunks of one message.
    Scores are summed per lower-cased label; the label with the largest sum
    wins and its score is that sum divided by the number of chunks.
    Params:
    chunk_sentiments: non-empty list of {"label", "score"} dicts
    Returns:
    dict {"label": str, "score": float}
    """
    # "neutral" is seeded so it is always a candidate label.
    sum_scores = {"neutral": 0}
    for chunk in chunk_sentiments:
        # Lower-case before summing so spellings like "Positive"/"positive"
        # from the same model accumulate under one label.
        label = chunk["label"].lower()
        sum_scores[label] = sum_scores.get(label, 0) + chunk["score"]
    best_sentiment = max(sum_scores, key=sum_scores.get)
    return {
        "label": best_sentiment,
        "score": sum_scores[best_sentiment] / len(chunk_sentiments),
    }


def process_messages_in_batches(
    messages_with_languages, models=None, max_length=512, device=None
):
    """
    Process messages in batches, creating only one pipeline at a time, and
    maintain the original order.
    Messages are grouped by the model assigned to their language, so each
    HuggingFace pipeline is built once, applied to all of its messages, and
    released before the next one is loaded. Messages longer than max_length
    characters are split into chunks whose results are merged per message.
    Params:
    messages_with_languages: iterable of (message, language) tuples; language
        is a lingua.Language, or None when detection failed
    models: dict, model paths indexed by Language. Entries are merged over
        default_models (user entries win); None means use default_models.
    max_length: int, maximum number of characters per chunk (must be > 0)
    device: device passed to transformers.pipeline; None selects the first
        CUDA GPU (0) when available, otherwise the CPU (-1)
    Returns:
    list of dicts {"label": str, "score": float}, one per input message, in
    input order. Messages with no detected language, no model, or no text get
    {"label": "none", "score": 0}.
    Raises:
    ValueError: if max_length is not positive
    """
    if max_length <= 0:
        raise ValueError("max_length must be a positive integer")

    if models is None:
        models = default_models
    else:
        # dict.update() returns None, so build the merged mapping explicitly.
        models = {**default_models, **models}

    if device is None:
        device = 0 if torch.cuda.is_available() else -1

    messages_with_languages = list(messages_with_languages)
    results = {}

    # Group messages by model, remembering each message's original index.
    # Messages that cannot be classified are labelled "none" right away; empty
    # messages are included because they would produce no chunks at all.
    messages_by_model = defaultdict(list)
    for index, (message, language) in enumerate(messages_with_languages):
        model_name = models.get(language)
        if model_name and message:
            messages_by_model[model_name].append((index, message))
        else:
            results[index] = {"label": "none", "score": 0}

    for model_name, batch in messages_by_model.items():
        chunks, message_map = _chunk_batch(batch, max_length)

        sentiment_pipeline = pipeline(model=model_name, device=device)
        try:
            # Chunks are measured in characters, not tokens, so truncate
            # anything that still exceeds the model's token limit.
            chunk_sentiments = sentiment_pipeline(chunks, truncation=True)
        finally:
            # Release the model before loading the next one, even if
            # inference raised.
            del sentiment_pipeline
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        for idx, positions in message_map.items():
            results[idx] = _merge_chunk_sentiments(
                [chunk_sentiments[pos] for pos in positions]
            )

    return [results[i] for i in range(len(messages_with_languages))]


def sentiment(messages, models=None, max_length=512):
    """
    Estimate the sentiment of a list of messages (strings of text). The
    messages may be in different languages from each other.
    The language of every message is detected with the module-level lingua
    detector, then the messages are classified by process_messages_in_batches.
    Params:
    messages: iterable of message strings
    models: dict, huggingface model paths indexed by lingua.Language;
        forwarded to process_messages_in_batches (None means default_models)
    max_length: int, maximum number of characters per chunk sent to a model
    Returns:
    list of dicts {"label": str, "score": float}, one per message and in the
    same order as messages
    """
    messages_with_languages = [
        (message, language_detector.detect_language_of(message))
        for message in messages
    ]
    return process_messages_in_batches(
        messages_with_languages, models, max_length=max_length
    )


def main():
    """
    Render the Streamlit page: a text area for messages (one per line), an
    analyze button, and the sentiment label and score for each message.
    """
    st.title("Sentiment Analysis Pipeline")
    messages_input = st.text_area("Enter your messages (one per line):", height=200)
    # One message per non-blank line, with surrounding whitespace removed.
    messages = [line.strip() for line in messages_input.split('\n') if line.strip()]

    if st.button("Analyze Sentiments"):
        if not messages:
            st.warning("Please enter at least one message to analyze.")
            return

        # Models are loaded on demand, which can take a while.
        with st.spinner("Analyzing sentiments..."):
            results = sentiment(messages)

        st.write("## Results:")
        for message, result in zip(messages, results):
            sentiment_label = result["label"]
            sentiment_score = result["score"]
            st.write(f"**Message:** {message}")
            st.write(f"**Sentiment:** {sentiment_label.capitalize()} (Score: {sentiment_score:.2f})")


if __name__ == "__main__":
    # Build the Streamlit UI when this file is executed as the main script.
    main()