from functools import partial

import pandas as pd
import streamlit as st
import torch
from datasets import Dataset, DatasetDict, load_dataset  # type: ignore
from torch.nn.functional import cross_entropy
from transformers import DataCollatorForTokenClassification  # type: ignore

from src.utils import device, tokenizer_hash_funcs


@st.cache(allow_output_mutation=True)
def get_data(
    ds_name: str, config_name: str, split_name: str, split_sample_size: int, randomize_sample: bool
) -> Dataset:
    """Loads a Dataset from the HuggingFace hub (if not already loaded).

    Uses `datasets.load_dataset` to load the dataset (see its documentation for additional details).

    Args:
        ds_name (str): Path or name of the dataset.
        config_name (str): Name of the dataset configuration.
        split_name (str): Which split of the data to load.
        split_sample_size (int): The number of examples to load from the split.

    Returns:
        Dataset: A Dataset object.
    """
    ds: DatasetDict = load_dataset(ds_name, name=config_name, use_auth_token=True).shuffle(
        seed=0 if randomize_sample else None
    )  # type: ignore
    split = ds[split_name].select(range(split_sample_size))
    return split


@st.cache(
    allow_output_mutation=True,
    hash_funcs=tokenizer_hash_funcs,
)
def get_collator(tokenizer) -> DataCollatorForTokenClassification:
    """Returns a DataCollator that will dynamically pad the inputs received, as well as the labels.

    Args:
        tokenizer ([PreTrainedTokenizer] or [PreTrainedTokenizerFast]): The tokenizer used for encoding the data.

    Returns:
        DataCollatorForTokenClassification: The DataCollatorForTokenClassification object.
    """
    return DataCollatorForTokenClassification(tokenizer)


def create_word_ids_from_input_ids(tokenizer, input_ids: list[int]) -> list[int]:
    """Takes a list of input_ids and return corresponding word_ids

    Args:
        tokenizer: The tokenizer that was used to obtain the input ids.
        input_ids (list[int]): List of token ids.

    Returns:
        list[int]: Word ids corresponding to the input ids.
    """
    word_ids = []
    wid = -1
    tokens = [tokenizer.convert_ids_to_tokens(i) for i in input_ids]

    for i, tok in enumerate(tokens):
        if tok in tokenizer.all_special_tokens:
            word_ids.append(-1)
            continue

        if not tokens[i - 1].endswith("@@") and tokens[i - 1] != "<unk>":
            wid += 1

        word_ids.append(wid)

    assert len(word_ids) == len(input_ids)
    return word_ids


def tokenize(batch, tokenizer) -> dict:
    """Tokenizes a batch of examples.

    Args:
        batch: The examples to tokenize
        tokenizer: The tokenizer to use

    Returns:
        dict: The tokenized batch
    """
    tokenized_inputs = tokenizer(batch["tokens"], truncation=True, is_split_into_words=True)
    labels = []
    wids = []

    for idx, label in enumerate(batch["ner_tags"]):
        try:
            word_ids = tokenized_inputs.word_ids(batch_index=idx)
        except ValueError:
            word_ids = create_word_ids_from_input_ids(
                tokenizer, tokenized_inputs["input_ids"][idx]
            )
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx == -1 or word_idx is None or word_idx == previous_word_idx:
                label_ids.append(-100)
            else:
                label_ids.append(label[word_idx])
            previous_word_idx = word_idx
        wids.append(word_ids)
        labels.append(label_ids)
    tokenized_inputs["word_ids"] = wids
    tokenized_inputs["labels"] = labels
    return tokenized_inputs


def stringify_ner_tags(batch: dict, tags) -> dict:
    """Stringifies a dataset batch's NER tags."""
    return {"ner_tags_str": [tags.int2str(idx) for idx in batch["ner_tags"]]}


def encode_dataset(split: Dataset, tokenizer):
    """Encodes a dataset split.

    Args:
        split (Dataset): A Dataset object.
        tokenizer: A PreTrainedTokenizer object.

    Returns:
        Dataset: A Dataset object with the encoded inputs.
    """

    tags = split.features["ner_tags"].feature
    split = split.map(partial(stringify_ner_tags, tags=tags), batched=True)
    remove_columns = split.column_names
    ids = split["id"]
    split = split.map(
        partial(tokenize, tokenizer=tokenizer),
        batched=True,
        remove_columns=remove_columns,
    )
    word_ids = [[id if id is not None else -1 for id in wids] for wids in split["word_ids"]]
    return split.remove_columns(["word_ids"]), word_ids, ids


def forward_pass_with_label(batch, model, collator, num_classes: int) -> dict:
    """Runs the forward pass for a batch of examples.

    Args:
        batch: The batch to process
        model: The model to process the batch with
        collator: A data collator
        num_classes (int): Number of classes

    Returns:
        dict: a dictionary containing `losses`, `preds` and `hidden_states`
    """

    # Convert dict of lists to list of dicts suitable for data collator
    features = [dict(zip(batch, t)) for t in zip(*batch.values())]

    # Pad inputs and labels and put all tensors on device
    batch = collator(features)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    with torch.no_grad():
        # Pass data through model
        output = model(input_ids, attention_mask, output_hidden_states=True)
        # logit.size: [batch_size, sequence_length, classes]

        # Predict class with largest logit value on classes axis
        preds = torch.argmax(output.logits, axis=-1).cpu().numpy()  # type: ignore

        # Calculate loss per token after flattening batch dimension with view
        loss = cross_entropy(
            output.logits.view(-1, num_classes), labels.view(-1), reduction="none"
        )

        # Unflatten batch dimension and convert to numpy array
        loss = loss.view(len(input_ids), -1).cpu().numpy()
        hidden_states = output.hidden_states[-1].cpu().numpy()

        # logits = output.logits.view(len(input_ids), -1).cpu().numpy()

    return {"losses": loss, "preds": preds, "hidden_states": hidden_states}


def predict(split_encoded: Dataset, model, tokenizer, collator, tags) -> pd.DataFrame:
    """Generates predictions for a given dataset split and returns the results as a dataframe.

    Args:
        split_encoded (Dataset): The dataset to process
        model: The model to process the dataset with
        tokenizer: The tokenizer to process the dataset with
        collator: The data collator to use
        tags: The tags used in the dataset

    Returns:
        pd.DataFrame: A dataframe containing token-level predictions.
    """

    split_encoded = split_encoded.map(
        partial(
            forward_pass_with_label,
            model=model,
            collator=collator,
            num_classes=tags.num_classes,
        ),
        batched=True,
        batch_size=8,
    )
    df: pd.DataFrame = split_encoded.to_pandas()  # type: ignore

    df["tokens"] = df["input_ids"].apply(
        lambda x: tokenizer.convert_ids_to_tokens(x)  # type: ignore
    )
    df["labels"] = df["labels"].apply(
        lambda x: ["IGN" if i == -100 else tags.int2str(int(i)) for i in x]
    )
    df["preds"] = df["preds"].apply(lambda x: [model.config.id2label[i] for i in x])
    df["preds"] = df.apply(lambda x: x["preds"][: len(x["input_ids"])], axis=1)
    df["losses"] = df.apply(lambda x: x["losses"][: len(x["input_ids"])], axis=1)
    df["hidden_states"] = df.apply(lambda x: x["hidden_states"][: len(x["input_ids"])], axis=1)
    df["total_loss"] = df["losses"].apply(sum)

    return df