# -*- coding: utf-8 -*-
# file: app.py
# time: 17:08 2023/3/6
# author: YANG, HENG <hy345@exeter.ac.uk> (杨恒)
# github: https://github.com/yangheng95
# huggingface: https://huggingface.co/yangheng
# google scholar: https://scholar.google.com/citations?user=NPq5a_0AAAAJ&hl=en
# Copyright (C) 2023. All Rights Reserved.
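"""Gradio demo for PyABSA: aspect term extraction and polarity classification (ATEPC),
aspect sentiment triplet extraction (ASTE), and aspect-category-opinion-sentiment
quadruple extraction (ACOS), all using the multilingual checkpoints."""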

import random

import gradio as gr
import pandas as pd
from pyabsa import (
    download_all_available_datasets,
    TaskCodeOption,
)
from pyabsa.utils.data_utils.dataset_manager import detect_infer_dataset
import requests

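# Fetch the bundled ABSA datasets so that example sentences can be sampled below.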
download_all_available_datasets()


def get_atepc_example(dataset):
    task = TaskCodeOption.Aspect_Polarity_Classification
    dataset_files = detect_infer_dataset(atepc_dataset_items[dataset], task)
    if isinstance(dataset_files, str):
        dataset_files = [dataset_files]

    lines = []
    for f in dataset_files:
        print("loading: {}".format(f))
        with open(f, "r", encoding="utf-8") as fin:
            lines.extend(fin.readlines())

    # Strip the $LABEL$ suffix and the aspect boundary markers, keeping the raw sentence
    lines = [
        line[: line.find("$LABEL$")]
        .replace("[B-ASP]", "")
        .replace("[E-ASP]", "")
        .strip()
        for line in lines
    ]
    # Deduplicate while preserving the original order
    return sorted(set(lines), key=lines.index)


def get_aste_example(dataset):
    task = TaskCodeOption.Aspect_Sentiment_Triplet_Extraction
    dataset_files = detect_infer_dataset(aste_dataset_items[dataset], task)
    if isinstance(dataset_files, str):
        dataset_files = [dataset_files]

    lines = []
    for f in dataset_files:
        print("loading: {}".format(f))
        with open(f, "r", encoding="utf-8") as fin:
            lines.extend(fin.readlines())

    # Deduplicate while preserving the original order
    return sorted(set(lines), key=lines.index)


def get_acos_example(dataset):
    task = "ACOS"
    dataset_files = detect_infer_dataset(acos_dataset_items[dataset], task)
    if isinstance(dataset_files, str):
        dataset_files = [dataset_files]

    lines = []
    for f in dataset_files:
        print("loading: {}".format(f))
        with open(f, "r", encoding="utf-8") as fin:
            lines.extend(fin.readlines())

    # Keep only the sentence before the "####" label separator
    lines = [line.split("####")[0] for line in lines]
    # Deduplicate while preserving the original order
    return sorted(set(lines), key=lines.index)


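# Each demo section degrades gracefully: if a checkpoint or its example datasets cannot
# be loaded, the extractor is set to None and the corresponding UI panel is skipped.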
try:
    from pyabsa import AspectTermExtraction as ATEPC

    atepc_dataset_items = {
        dataset.name: dataset for dataset in ATEPC.ATEPCDatasetList()
    }
    atepc_dataset_dict = {
        dataset.name: get_atepc_example(dataset.name)
        for dataset in ATEPC.ATEPCDatasetList()
    }
    aspect_extractor = ATEPC.AspectExtractor(checkpoint="multilingual")
except Exception as e:
    print(e)
    atepc_dataset_items = {}
    atepc_dataset_dict = {}
    aspect_extractor = None

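# Load the multilingual ASTE checkpoint and example sentences for each ASTE dataset.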
try:
    from pyabsa import AspectSentimentTripletExtraction as ASTE

    aste_dataset_items = {dataset.name: dataset for dataset in ASTE.ASTEDatasetList()}
    aste_dataset_dict = {
        dataset.name: get_aste_example(dataset.name)
        for dataset in ASTE.ASTEDatasetList()[:-1]
    }
    triplet_extractor = ASTE.AspectSentimentTripletExtractor(checkpoint="multilingual")
except Exception as e:
    print(e)
    aste_dataset_items = {}
    aste_dataset_dict = {}
    triplet_extractor = None

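# Load the multilingual ACOS (quadruple extraction) generator and its example sentences.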
try:
    from pyabsa import ABSAInstruction

    acos_dataset_items = {
        dataset.name: dataset for dataset in ABSAInstruction.ACOSDatasetList()
    }
    acos_dataset_dict = {
        dataset.name: get_acos_example(dataset.name)
        for dataset in ABSAInstruction.ACOSDatasetList()
    }
    quadruple_extractor = ABSAInstruction.ABSAGenerator("multilingual")
except Exception as e:
    print(e)
    acos_dataset_items = {}
    acos_dataset_dict = {}
    quadruple_extractor = None


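# Local inference helpers: when the text box is left blank, a random example is drawn
# from the selected dataset before running the corresponding predictor.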
def perform_atepc_inference(text, dataset):
    if not text:
        text = random.choice(atepc_dataset_dict[dataset])

    result = aspect_extractor.predict(text, pred_sentiment=True)

    result = pd.DataFrame(
        {
            "aspect": result["aspect"],
            "sentiment": result["sentiment"],
            # 'probability': result[0]['probs'],
            "confidence": [round(x, 4) for x in result["confidence"]],
            "position": result["position"],
        }
    )
    return result, text


def perform_aste_inference(text, dataset):
    if not text:
        text = random.choice(aste_dataset_dict[dataset])

    result = triplet_extractor.predict(text)

    pred_triplets = pd.DataFrame(result["Triplets"])
    true_triplets = (
        pd.DataFrame(result["True Triplets"]) if result["True Triplets"] else None
    )
    return pred_triplets, true_triplets, text.split("####")[0]


def perform_acos_inference(text, dataset):
    if not text:
        text = random.choice(acos_dataset_dict[dataset])

    raw_output = quadruple_extractor.predict(text.split("####")[0], max_length=128)

    result = pd.DataFrame(raw_output["Quadruples"])
    return result, text


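# run_demo first tries the hosted inference API and falls back to local prediction on
# any failure. The remote endpoint is assumed to return a JSON payload with the same
# keys the local predictors produce ("aspect"/"sentiment"/"confidence"/"position" for
# ATEPC, "pred_triplets"/"true_triplets" for ASTE, "Quadruples" for ACOS) plus "text".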
def run_demo(text, dataset, task):
    if len(text) > 3000:
        raise RuntimeError("Text is too long! Please keep it under 3000 characters.")
    try:
        data = {
            "text": text,
            "dataset": dataset,
            "task": task,
        }
        # Time out instead of hanging indefinitely if the remote demo server is unreachable
        response = requests.post(
            "https://pyabsa.pagekite.me/api/inference", json=data, timeout=30
        )
        result = response.json()
        print(result)
        if task == "ATEPC":
            return (
                pd.DataFrame(
                    {
                        "aspect": result["aspect"],
                        "sentiment": result["sentiment"],
                        # 'probability': result[0]['probs'],
                        "confidence": [round(x, 4) for x in result["confidence"]],
                        "position": result["position"],
                    }
                ),
                result["text"],
            )
        elif task == "ASTE":
            return (
                pd.DataFrame(result["pred_triplets"]),
                pd.DataFrame(result["true_triplets"]),
                result["text"],
            )
        elif task == "ACOS":
            return pd.DataFrame(result["Quadruples"]), result["text"]

    except Exception as e:
        print(e)
        print("Failed to connect to the server, running locally...")
        return inference(text, dataset, task)


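# Dispatch purely local inference for the requested task.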
def inference(text, dataset, task):
    if task == "ATEPC":
        return perform_atepc_inference(text, dataset)
    elif task == "ASTE":
        return perform_aste_inference(text, dataset)
    elif task == "ACOS":
        return perform_acos_inference(text, dataset)
    else:
        raise ValueError("No such task: {}".format(task))


if __name__ == "__main__":
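    # Build the Gradio UI: one panel per successfully loaded extractor (ACOS, ASTE, ATEPC).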
    demo = gr.Blocks()

    with demo:
        with gr.Row():
            if quadruple_extractor:
                with gr.Row():
                    with gr.Column():
                        gr.Markdown(
                            "# <p align='center'> ABSA Quadruple Extraction (Experimental) </p>"
                        )

                        acos_input_sentence = gr.Textbox(
                            placeholder="Leave this box blank and choose a dataset will give you a random example...",
                            label="Example:",
                        )
                        acos_dataset_ids = gr.Radio(
                            choices=[
                                dataset.name
                                for dataset in ABSAInstruction.ACOSDatasetList()
                            ],
                            value="Laptop14",
                            label="Datasets",
                        )
                        acos_inference_button = gr.Button("Let's go!")

                        acos_output_text = gr.TextArea(label="Example:")
                        acos_output_pred_df = gr.DataFrame(label="Predicted Quadruples:")

                        acos_inference_button.click(
                            fn=run_demo,
                            inputs=[
                                acos_input_sentence,
                                acos_dataset_ids,
                                gr.Text("ACOS", visible=False),
                            ],
                            outputs=[acos_output_pred_df, acos_output_text],
                        )
        with gr.Row():
            if triplet_extractor:
                with gr.Column():
                    gr.Markdown(
                        "# <p align='center'>Aspect Sentiment Triplet Extraction !</p>"
                    )

                    with gr.Row():
                        with gr.Column():
                            aste_input_sentence = gr.Textbox(
                                placeholder="Leave this box blank and choose a dataset will give you a random example...",
                                label="Example:",
                            )
                            gr.Markdown(
                                "You can find code and dataset at [ASTE examples](https://github.com/yangheng95/PyABSA/tree/v2/examples-v2/aspect_sentiment_triplet_extration)"
                            )
                            aste_dataset_ids = gr.Radio(
                                choices=[
                                    dataset.name
                                    for dataset in ASTE.ASTEDatasetList()[:-1]
                                ],
                                value="Restaurant14",
                                label="Datasets",
                            )
                            aste_inference_button = gr.Button("Let's go!")

                            aste_output_text = gr.TextArea(label="Example:")
                            aste_output_pred_df = gr.DataFrame(
                                label="Predicted Triplets:"
                            )
                            aste_output_true_df = gr.DataFrame(
                                label="Original Triplets:"
                            )

                            aste_inference_button.click(
                                fn=run_demo,
                                inputs=[
                                    aste_input_sentence,
                                    aste_dataset_ids,
                                    gr.Text("ASTE", visible=False),
                                ],
                                outputs=[
                                    aste_output_pred_df,
                                    aste_output_true_df,
                                    aste_output_text,
                                ],
                            )
            if aspect_extractor:
                with gr.Column():
                    gr.Markdown(
                        "# <p align='center'>Multilingual Aspect-based Sentiment Analysis !</p>"
                    )
                    with gr.Row():
                        with gr.Column():
                            atepc_input_sentence = gr.Textbox(
                                placeholder="Leave this box blank and choose a dataset will give you a random example...",
                                label="Example:",
                            )
                            gr.Markdown(
                                "You can find the datasets at [github.com/yangheng95/ABSADatasets](https://github.com/yangheng95/ABSADatasets/tree/v1.2/datasets/text_classification)"
                            )
                            atepc_dataset_ids = gr.Radio(
                                choices=[
                                    dataset.name
                                    for dataset in ATEPC.ATEPCDatasetList()[:-1]
                                ],
                                value="Laptop14",
                                label="Datasets",
                            )
                            atepc_inference_button = gr.Button("Let's go!")

                            atepc_output_text = gr.TextArea(label="Example:")
                            atepc_output_df = gr.DataFrame(label="Prediction Results:")

                            atepc_inference_button.click(
                                fn=run_demo,
                                inputs=[
                                    atepc_input_sentence,
                                    atepc_dataset_ids,
                                    gr.Text("ATEPC", visible=False),
                                ],
                                outputs=[atepc_output_df, atepc_output_text],
                            )

        gr.Markdown(
            """### GitHub Repo: [PyABSA V2](https://github.com/yangheng95/PyABSA)
            ### Author: [Heng Yang](https://github.com/yangheng95) (杨恒)
            [![Downloads](https://pepy.tech/badge/pyabsa)](https://pepy.tech/project/pyabsa) 
            [![Downloads](https://pepy.tech/badge/pyabsa/month)](https://pepy.tech/project/pyabsa)
            """
        )

    demo.launch()