import os
import torch
import spacy
import spaces
import numpy as np
import pandas as pd
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import PercentFormatter
import matplotlib.colors as mcolors
import plotly.express as px
import seaborn as sns
from tqdm import tqdm

# Storage root for model/cache files; at least 150GB storage needs to be attached.
PATH = "/data/"
# NOTE(review): these variables are assigned *after* `transformers` and
# `spaces` are imported at the top of the file. Libraries that read them at
# import time (likely including the Hugging Face cache location) will not see
# them — confirm, and move this above the imports if /data/ must be used.
os.environ.update(
    dict.fromkeys(
        ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "TORCH_HOME"),
        PATH,
    )
)

# Custom CSS passed to gr.Blocks. NOTE(review): in ".info" the `!important`
# follows the semicolon, so browsers discard it as an invalid declaration;
# it was probably meant to read "font-size: 3em !important;".
css = """
.info {font-size: 3em; !important}
.title_ {text-align: center;}
"""

# Hugging Face Hub read token; raises KeyError at import if "hf_read" is unset.
HF_TOKEN = os.environ["hf_read"]

# Three-class sentiment names. NOTE(review): not referenced in the code shown —
# confirm whether it is still needed.
SENTIMENT_LABEL_NAMES = dict(
    enumerate(("Negative", "No sentiment or Neutral sentiment", "Positive"))
)
# Choices for the language dropdown; mapped to a model by build_huggingface_path.
LANGUAGES = ["Czech", "English", "French", "German", "Hungarian", "Polish", "Slovakian"]

# Class index -> emotion name; used to name entries of the probability vector
# returned by predict.
id2label = dict(
    enumerate(("Anger", "Fear", "Disgust", "Sadness", "Joy", "None of Them"))
)

# Emotion name -> hex colour used by the bar chart and the heatmap.
emotion_colors = {
    "Anger": "#D96459",
    "Fear": "#6A8EAE",
    "Disgust": "#A4C639",
    "Sadness": "#9DBCD4",
    "Joy": "#F3E9A8",
    "None of Them": "#C0C0C0",
}


def load_spacy_model(model_name="xx_sent_ud_sm"):
    """Return the spaCy pipeline *model_name*, downloading it if needed.

    Args:
        model_name: Name of the installable spaCy package to load.

    Returns:
        The loaded spaCy ``Language`` object.
    """
    try:
        return spacy.load(model_name)
    except OSError:
        # spacy.load raised OSError (e.g. package not installed): fetch the
        # package, then retry the load once.
        spacy.cli.download(model_name)
        return spacy.load(model_name)


def split_sentences(text, model):
    """Split *text* into sentence strings using the pipeline's ``senter``.

    Args:
        text: Raw input text.
        model: A loaded spaCy ``Language`` that contains a ``senter`` component.

    Returns:
        List of sentence strings in document order.

    Side effect: every pipe that was enabled on *model* is disabled and stays
    disabled after return; only ``senter`` is (re-)enabled.
    """
    # Disable all currently enabled pipes. `select_pipes(disable=...)` is the
    # non-deprecated equivalent of `disable_pipes(...)`, which emits a
    # deprecation warning (W096) on every call in spaCy v3.
    model.select_pipes(disable=model.pipe_names)
    # Then enable the sentence splitter only.
    model.enable_pipe("senter")

    doc = model(text)
    return [sent.text for sent in doc.sents]


def build_huggingface_path(language: str):
    """Return the Hugging Face repo id of the emotion model for *language*.

    Czech and Slovakian use a dedicated model; any other value falls back to
    the pooled multilingual model.
    """
    if language in ("Czech", "Slovakian"):
        model_repo = "visegradmedia-emotion/Emotion_RoBERTa_pooled_V4"
    else:
        model_repo = "poltextlab/xlm-roberta-large-pooled-emotions6"
    return model_repo


@spaces.GPU
def predict(text, model_id, tokenizer_id):
    """Return the class probabilities the model assigns to *text*.

    Args:
        text: A single input string (in this app, one sentence); tokenised
            with truncation to at most 64 tokens.
        model_id: Hugging Face repo id of the sequence-classification model.
        tokenizer_id: Hugging Face repo id of the tokenizer.

    Returns:
        1-D numpy array of softmax probabilities, one entry per class.

    NOTE(review): the model and tokenizer are loaded on every call, and the
    caller invokes this once per sentence — a likely performance hot spot.
    """
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    )
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    inputs = tokenizer(
        text, max_length=64, truncation=True, padding="do_not_pad", return_tensors="pt"
    )
    # The tokenizer returns CPU tensors, but device_map="auto" can place the
    # model on a GPU. For a single-device placement no input-moving hooks are
    # installed, so move the inputs explicitly to avoid a device mismatch
    # (a no-op when the model is on the CPU).
    inputs = inputs.to(model.device)
    model.eval()

    with torch.no_grad():
        logits = model(**inputs).logits

    # Single input -> logits of shape (1, num_classes); flatten to 1-D.
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    return probs


def get_most_probable_label(probs, idx=1):
    """Return the label ranked *idx*-th by probability and its percentage.

    Args:
        probs: numpy array of class probabilities indexed like ``id2label``.
        idx: 1-based rank (1 = most probable, 2 = runner-up, ...). Note that
            ``idx=0`` wraps around to the least probable class.

    Returns:
        ``(label, probability)`` where *probability* is a string of the
        percentage rounded to two decimals followed by ``"%"``.
    """
    # Class indices ordered from most to least probable.
    ranked = probs.argsort()[::-1]
    winner = ranked[idx - 1]
    percentage = round(100 * probs[winner], 2)
    return id2label[winner], f"{percentage}%"


def prepare_heatmap_data(data):
    """Build an emotion-by-sentence table of per-sentence probabilities.

    Args:
        data: List of dicts with a ``"sentence"`` string and an ``"emotions"``
            numpy array of probabilities indexed like ``id2label``.

    Returns:
        DataFrame indexed by the emotion names of ``id2label``, with one column
        per sentence labelled by its first 18 characters plus ``"..."``; cells
        hold the probabilities rounded to 4 decimals.
    """
    table = pd.DataFrame(0.0, index=id2label.values(), columns=range(len(data)))

    for column, entry in enumerate(data):
        for class_id, score in enumerate(entry["emotions"].tolist()):
            table.at[id2label[class_id], column] = round(score, 4)

    # Short column labels; different sentences can end up with equal labels.
    table.columns = [entry["sentence"][:18] + "..." for entry in data]
    return table


def plot_emotion_heatmap(heatmap_data):
    """Render a sentence-by-emotion heatmap tinted with each emotion's colour.

    Args:
        heatmap_data: DataFrame with emotion names as the index and one column
            per sentence, as returned by ``prepare_heatmap_data``. Sentence
            labels are truncated prefixes and may repeat.

    Returns:
        A matplotlib Figure. Rows are sentences, columns are emotions; each
        cell blends from white (score 0) to the emotion's colour from
        ``emotion_colors`` (reached at the sentence's highest score).
    """
    # Transpose: now rows = sentences, columns = emotions
    sentence_rows = heatmap_data.T
    emotions = list(sentence_rows.columns)

    # Normalise each row (sentence-wise) by its maximum; rows whose maximum is
    # not positive become all zeros. This is done positionally on the raw
    # values: two sentences sharing the same 18-character prefix get the same
    # label, and label-based `.loc[...]` on a duplicated label returns a
    # Series/DataFrame instead of a scalar, which crashed the old per-label
    # loops ("truth value of a Series is ambiguous").
    values = sentence_rows.to_numpy(dtype=float)
    row_max = sentence_rows.max(axis=1).to_numpy(dtype=float)[:, np.newaxis]
    normalized = np.divide(
        values, row_max, out=np.zeros_like(values), where=row_max > 0
    )

    # Blend from white (1.0 per channel) to each emotion's base colour:
    # channel = 1 - value * (1 - base_channel). Result shape: (rows, emotions, 3).
    base_rgb = np.array([mcolors.to_rgb(emotion_colors[e]) for e in emotions])
    color_matrix = 1 - normalized[:, :, np.newaxis] * (1 - base_rgb[np.newaxis, :, :])

    # NOTE(review): the figure is created via pyplot and never closed, so
    # pyplot keeps a reference per request; consider closing it once rendered.
    fig, ax = plt.subplots(
        figsize=(
            len(emotions) * 0.8 + 2,
            len(sentence_rows.index) * 0.5 + 2,
        )
    )
    ax.imshow(color_matrix, aspect="auto")

    # Set ticks and labels
    ax.set_xticks(np.arange(len(emotions)))
    ax.set_xticklabels(emotions, rotation=45, ha="right", fontsize=10)

    ax.set_yticks(np.arange(len(sentence_rows.index)))
    ax.set_yticklabels(sentence_rows.index, rotation=0, fontsize=10)

    ax.set_xlabel("Emotions")
    ax.set_ylabel("Sentences")

    fig.tight_layout()
    return fig


def plot_average_emotion_barplot(heatmap_data):
    """Plot each emotion's mean probability across all sentences.

    Args:
        heatmap_data: Non-empty list of dicts, each holding an ``"emotions"``
            probability array for one sentence (the per-sentence results, not
            the DataFrame from ``prepare_heatmap_data`` despite the name).

    Returns:
        A matplotlib Figure with one bar per emotion, coloured from
        ``emotion_colors`` and annotated with its percentage.
    """
    per_sentence = np.array([entry["emotions"] for entry in heatmap_data])
    averages = per_sentence.mean(axis=0)

    emotion_names = [id2label[k] for k in range(len(averages))]
    palette = [emotion_colors[name] for name in emotion_names]

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.barplot(x=list(averages), y=list(emotion_names), palette=palette, ax=ax)
    ax.xaxis.set_major_formatter(PercentFormatter(xmax=1.0, decimals=0))

    # Write the percentage just to the right of each bar (bars sit at y = 0, 1, ...).
    for position, value in enumerate(averages):
        ax.text(value + 0.01, position, f"{value*100:.1f}%", va="center")

    ax.set_title("Which emotions showed up most in the text?", fontsize=14)
    ax.set_xlabel("Average Confidence")
    ax.set_ylabel("Emotions")
    plt.tight_layout()

    return fig


def predict_wrapper(text, language):
    """Run sentence-level emotion detection on *text* for the Gradio UI.

    Args:
        text: Raw text from the input box.
        language: Dropdown value; selects the model via ``build_huggingface_path``.

    Returns:
        ``(results, bar_figure, heatmap_figure, info_html)`` feeding the
        ``[result_table, plot, heatmap, model_info]`` outputs. ``results`` is a
        list of ``[sentence, label1, confidence1, label2, confidence2]`` rows.

    Raises:
        gr.Error: if the input is empty/whitespace or yields no sentences. The
            plotting code cannot handle zero sentences (the mean over an empty
            array is not a per-emotion vector), so report it in the UI instead.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to analyse.")

    model_id = build_huggingface_path(language)
    tokenizer_id = "xlm-roberta-large"

    spacy_model = load_spacy_model()
    sentences = split_sentences(text, spacy_model)
    if not sentences:
        raise gr.Error("No sentences were found in the input text.")

    results = []
    results_heatmap = []
    for sentence in tqdm(sentences):
        probs = predict(sentence, model_id, tokenizer_id)
        label1, probability1 = get_most_probable_label(probs, 1)
        label2, probability2 = get_most_probable_label(probs, 2)
        results.append([sentence, label1, probability1, label2, probability2])
        results_heatmap.append({"sentence": sentence, "emotions": probs})

    # Debug dump of the raw results to stdout (ends up in the server logs).
    print(results)
    print(results_heatmap)

    figure = plot_average_emotion_barplot(results_heatmap)
    heatmap = plot_emotion_heatmap(prepare_heatmap_data(results_heatmap))
    output_info = f'Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model. '
    funding_info = "The research was funded by European Union’s Horizon 2020 research and innovation program, “MORES” project (Grant No.: 101132601)"
    return results, figure, heatmap, output_info + funding_info


# Gradio UI: text input + language picker, then a bar chart, a per-sentence
# table and a heatmap (each with an explanatory note); Submit runs
# predict_wrapper and fills those outputs plus the model/funding info line.
with gr.Blocks(css=css) as demo:
    # NOTE(review): `placeholder` is not referenced below; the input Textbox
    # uses its own placeholder text.
    placeholder = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
    introduction = """
    This application detects and visualises emotions in text. The model behind it operates using a 6-label codebook, including the following labels: `Anger`, `Fear`, `Disgust`, `Sadness`, `Joy`, and `None of Them`.
    The [model](https://huggingface.co/poltextlab/xlm-roberta-large-pooled-emotions6) is optimised for sentence-level analysis and makes predictions in the following languages: Czech, English, French, German, Hungarian, Polish, and Slovak.
    The text you enter in the input box is automatically divided into sentences, and the analysis is performed on each sentence. Depending on the length of the text, this process may take a few seconds, but for longer texts, it can take up to 2-3 minutes.
    Read our Q&A about Pulse [here](https://cms.mores-horizon.eu/uploads/MORES_Pulse_Q_and_A_33f61ea348.pdf).
    """

    gr.HTML("<h1>MORES Pulse</h1>", elem_classes="title_")
    gr.Markdown(introduction, elem_classes="info")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                lines=6, label="Input", placeholder="Enter your text here..."
            )
        with gr.Column():
            with gr.Row():
                language_choice = gr.Dropdown(
                    choices=LANGUAGES, label="Language", value="English"
                )
            with gr.Row():
                predict_button = gr.Button("Submit")

    with gr.Row():
        with gr.Column(scale=7):
            plot = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(
                "The chart gives an overview of the main emotions found in the text and how strongly each one is present.",
                elem_classes="info",
            )

    with gr.Row():
        with gr.Column(scale=7):
            result_table = gr.Dataframe(
                headers=["Sentence", "Prediction (1)", "Confidence (1)", "Prediction (2)", "Confidence (2)"],
                column_widths=["46%", "17%", "10%", "17%", "10%"],
                wrap=True,  # important
            )
        with gr.Column(scale=3):
            gr.Markdown(
                "This table shows the two most probable emotions detected in each sentence, along with how confident our predictions are. For all emotions, check the heatmap below.",
                elem_classes="info",
            )

    with gr.Row():
        with gr.Column(scale=7):
            heatmap = gr.Plot()
        with gr.Column(scale=3):
            gr.Markdown(
                "This heatmap shows how strongly each emotion appears in every sentence. Darker colours mean stronger presence.",
                elem_classes="info",
            )

    with gr.Row():
        model_info = gr.Markdown()

    # Output order must match the tuple returned by predict_wrapper.
    predict_button.click(
        fn=predict_wrapper,
        inputs=[input_text, language_choice],
        outputs=[result_table, plot, heatmap, model_info],
    )

if __name__ == "__main__":
    demo.launch()