from typing import Dict, List, Union

import numpy as np
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate

# Maps a lower-case language name to the language code used in this module.
# The same code serves two purposes: the `wikiann` dataset config name
# (load_wikiann_dataset) and the source/target code handed to
# EasyGoogleTranslate (_translate_instruction, _translate_example).
# NOTE(review): the "LANGAUGE" spelling is kept because other code may refer
# to this name.
# NOTE(review): some entries (e.g. "bam", "ewe", "yor") are three-letter codes
# while most are two-letter; confirm each code is valid for every consumer.
LANGAUGE_TO_PREFIX = dict(
    chinese_simplified="zh-CN",
    arabic="ar",
    hindi="hi",
    indonesian="id",
    amharic="am",
    bengali="bn",
    burmese="my",
    uzbek="uz",
    nepali="ne",
    japanese="ja",
    spanish="es",
    turkish="tr",
    persian="fa",
    azerbaijani="az",
    korean="ko",
    hebrew="he",
    telugu="te",
    german="de",
    greek="el",
    tamil="ta",
    assamese="as",
    russian="ru",
    romanian="ro",
    malayalam="ml",
    swahili="sw",
    bulgarian="bg",
    thai="th",
    urdu="ur",
    polish="pl",
    dutch="nl",
    danish="da",
    norwegian="no",
    finnish="fi",
    hungarian="hu",
    czech="cs",
    ukrainian="uk",
    bambara="bam",
    ewe="ewe",
    fon="fon",
    hausa="hau",
    igbo="ibo",
    kinyarwanda="kin",
    chichewa="nya",
    twi="twi",
    yoruba="yor",
    slovak="sk",
    serbian="sr",
    swedish="sv",
    vietnamese="vi",
    italian="it",
    portuguese="pt",
    chinese="zh",
    english="en",
    french="fr",
)

def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate an English instruction into ``target_language``.

    Args:
        basic_instruction: Instruction text written in English.
        target_language: Language name; must be a key of LANGAUGE_TO_PREFIX
            (an unknown name raises KeyError before any request is made).

    Returns:
        The translator's output for ``basic_instruction``.
    """
    target_code = LANGAUGE_TO_PREFIX[target_language]
    return EasyGoogleTranslate(
        source_language="en",
        target_language=target_code,
        timeout=10,
    ).translate(basic_instruction)


def create_instruction(lang: str, instruction_language: str, expected_output: str):
    """Build the NER task instruction in ``instruction_language``.

    Args:
        lang: Language of the data being evaluated. Not used to pick the
            instruction language; kept for interface compatibility.
        instruction_language: Language name (a key of LANGAUGE_TO_PREFIX) the
            instruction should be written in. ``"english"`` skips translation.
        expected_output: Language name inserted into the instruction to tell
            the model which language the entities should be in.

    Returns:
        The instruction text (machine-translated unless English).
    """
    basic_instruction = f"""You are an NLP assistant whose
                            purpose is to perform Named Entity Recognition
                            (NER). You will need to give each entity a tag, from the following:
                            PER means a person, ORG means organization.
                            LOC means a location entity.
                            The output should be a list of tuples of the format:
                            ['Tag: Entity', 'Tag: Entity'] for each entity in the sentence. 
                            The entities should be in {expected_output} language"""

    # The instruction language is configured independently of the data
    # language, so it alone decides whether and where to translate. The
    # English case must return the instruction text, not the language name.
    if instruction_language == "english":
        return basic_instruction
    return _translate_instruction(
        basic_instruction, target_language=instruction_language
    )


def load_wikiann_dataset(lang, split, limit):
    """Load at most ``limit`` rows from one split of the WikiANN dataset.

    Args:
        lang: Language name; its LANGAUGE_TO_PREFIX code is used as the
            ``wikiann`` dataset config name.
        split: Split name, e.g. ``"train"`` or ``"test"``.
        limit: Maximum number of leading rows to keep. If the split is
            shorter, all of its rows are returned.

    Returns:
        A ``datasets.Dataset`` holding the first ``min(limit, len(split))`` rows.
    """
    dataset = load_dataset("wikiann", LANGAUGE_TO_PREFIX[lang])[split]
    # Clamp so a split smaller than ``limit`` does not request out-of-range
    # indices from ``select``.
    return dataset.select(range(min(limit, len(dataset))))


def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    """Translate the ``tokens`` and ``ner_tags`` fields of one example.

    Args:
        example: Mapping with ``"tokens"`` and ``"ner_tags"`` entries.
            ``tokens`` may be a sentence string or a list of tokens;
            ``ner_tags`` is stringified before translation.
        src_language: Language name of ``example`` (key of LANGAUGE_TO_PREFIX).
        target_language: Language name to translate into.

    Returns:
        Dict with the translator's output for ``"tokens"`` and ``"ner_tags"``.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGAUGE_TO_PREFIX[src_language],
        target_language=LANGAUGE_TO_PREFIX[target_language],
        timeout=30,
    )

    tokens = example["tokens"]
    # Send a token list as a plain sentence rather than its Python repr
    # (brackets, quotes, commas), which would otherwise be translated too.
    if isinstance(tokens, (list, tuple)):
        tokens = " ".join(map(str, tokens))

    return {
        "tokens": translator.translate(str(tokens)),
        "ner_tags": translator.translate(str(example["ner_tags"])),
    }


def choose_few_shot_examples(
    train_dataset: "Dataset",
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Selects few-shot examples and renders them in per-slot languages.

    Args:
        train_dataset (Dataset): Pool to draw from; each row needs a
            ``"tokens"`` sequence of strings and a ``"spans"`` entry.
        few_shot_size (int): Number of examples to draw.
        context (List[str]): Target language for each drawn example, in order.
            An example whose target equals ``lang`` is used as-is; otherwise
            it is translated from ``lang`` into that language.
        selection_criteria (str): ``"first_k"`` takes the leading rows;
            ``"random"`` samples rows, without replacement whenever the pool
            has at least ``few_shot_size`` rows.
        lang (str): Language of ``train_dataset``.

    Returns:
        List[Dict[str, Union[str, int]]]: ``{"tokens", "ner_tags"}`` dicts,
        one per drawn example that has a matching ``context`` entry.

    Raises:
        ValueError: If ``selection_criteria`` is not recognised.
    """
    pool_size = len(train_dataset)

    if selection_criteria == "first_k":
        # Clamp so a pool smaller than few_shot_size does not raise IndexError.
        example_idxs = list(range(min(few_shot_size, pool_size)))
    elif selection_criteria == "random":
        # Avoid duplicated demonstrations; only sample with replacement when
        # more examples are requested than the pool contains.
        example_idxs = (
            np.random.choice(
                pool_size, size=few_shot_size, replace=few_shot_size > pool_size
            )
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    ic_examples = [
        {"tokens": " ".join(row["tokens"]), "ner_tags": row["spans"]}
        for row in (train_dataset[idx] for idx in example_idxs)
    ]

    selected_examples = []
    # zip pairs each example with its target language and stops at the
    # shorter sequence instead of indexing past the end of either.
    for example, ic_language in zip(ic_examples, context):
        if ic_language == lang:
            selected_examples.append(example)
        else:
            selected_examples.append(
                _translate_example(
                    example=example,
                    src_language=lang,
                    target_language=ic_language,
                )
            )

    return selected_examples


def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    dataset: str,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
):
    """Build the NER prompt string for one test example.

    Args:
        instruction: Instruction text. If empty, one is built with
            ``create_instruction(lang, config["prefix"], config["output"])``.
        test_example: Row to label; must contain ``"tokens"`` (a token list
            or a sentence string), plus ``"ner_tags"`` when it is translated.
        zero_shot: If True, the prompt contains no demonstrations.
        dataset: Not used in this function; kept for interface compatibility.
        num_examples: Number of demonstrations in the few-shot prompt.
        lang: Language name of ``test_example``.
        config: Language settings. Reads ``"prefix"`` and ``"output"`` (for
            the instruction), ``"context"`` (language of demonstrations) and
            ``"input"`` (language the test sentence is presented in).

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If WikiANN data for ``lang`` cannot be loaded in few-shot mode.
    """
    if not instruction:
        instruction = create_instruction(lang, config["prefix"], config["output"])

    # The instruction is embedded in an f-string-style template, so literal
    # braces (possible after machine translation) must be escaped.
    escaped_instruction = instruction.replace("{", "{{").replace("}", "}}")

    if zero_shot:
        prompt = PromptTemplate(
            input_variables=["text"],
            template=escaped_instruction + "\n Sentence: {text} ",
        )
    else:
        # Draw demonstrations from the train split: sampling them from the
        # test split can place the evaluated example (with its gold tags)
        # into its own prompt.
        try:
            train_data = load_wikiann_dataset(lang=lang, split="train", limit=500)
        except Exception as e:
            raise KeyError(
                f"{lang} is not supported in 'wikiAnn' dataset, choose supported language in few-shot"
            ) from e

        ic_examples = choose_few_shot_examples(
            train_dataset=train_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        example_prompt = PromptTemplate(
            input_variables=["tokens", "ner_tags"],
            template="Sentence: {tokens}\nNer Tags: {ner_tags}",
        )
        # The query uses the same "Sentence:/Ner Tags:" layout as the
        # demonstrations so the model continues the pattern.
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=escaped_instruction,
            example_prompt=example_prompt,
            suffix="Sentence: {text}\nNer Tags:",
            input_variables=["text"],
        )

    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )

    text = test_example["tokens"]
    # Untranslated dataset rows carry a token list; render it as a sentence,
    # like the demonstrations, rather than as a Python list repr.
    if isinstance(text, (list, tuple)):
        text = " ".join(map(str, text))
    return prompt.format(text=text)