import logging
import os
from argparse import ArgumentParser
from ast import literal_eval
from types import SimpleNamespace
from typing import List

from robustnessgym import Dataset, Spacy, CachedOperation
from robustnessgym.core.constants import CACHEDOPS
from robustnessgym.core.tools import strings_as_json
from robustnessgym.logging.utils import set_logging_level
from spacy import load
from spacy.attrs import DEP, IS_ALPHA, IS_PUNCT, IS_STOP, LEMMA, LOWER, TAG, SENT_END, \
    SENT_START, ORTH, POS, ENT_IOB
from spacy.tokens import Doc

from align import BertscoreAligner, NGramAligner, StaticEmbeddingAligner
from utils import preprocess_text

# Keep console output quiet: set Robustness Gym's logging level to 'critical'
# and restrict this module's logger to CRITICAL messages.
set_logging_level('critical')
logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)


def _spacy_encode(self, x):
    """Serialize a spaCy ``Doc`` into a plain dict for caching.

    The token attributes (dependency label, alpha/punct/stop flags, lemma,
    lowercase form, tag, sentence boundaries, orth, POS and entity IOB tag)
    are exported with ``Doc.to_array``. The resulting 2-D array is stored
    flattened next to its shape, together with the token texts needed to
    rebuild the ``Doc``. Entity *types* are not part of the attribute list,
    so only IOB tags survive the round trip.

    ``self`` is unused; the function is patched onto ``Spacy`` as a method.
    """
    attributes = [DEP, IS_ALPHA, IS_PUNCT, IS_STOP, LEMMA, LOWER, TAG,
                  SENT_END, SENT_START, ORTH, POS, ENT_IOB]
    attribute_matrix = x.to_array(attributes)
    token_texts = [token.text for token in x]
    return dict(
        arr=attribute_matrix.flatten(),
        shape=list(attribute_matrix.shape),
        words=token_texts,
    )


def _spacy_decode(self, x):
    """Rebuild a spaCy ``Doc`` from the dict produced by ``_spacy_encode``.

    A new ``Doc`` is created over ``self.nlp.vocab`` from ``x['words']``.
    The token attributes are then restored from ``x['arr']`` after it is
    reshaped to ``x['shape']``. The attribute order must match the one used
    when encoding.
    """
    attributes = [DEP, IS_ALPHA, IS_PUNCT, IS_STOP, LEMMA, LOWER,
                  TAG, SENT_END, SENT_START, ORTH, POS, ENT_IOB]
    bare_doc = Doc(self.nlp.vocab, words=x['words'])
    attribute_matrix = x['arr'].reshape(x['shape'])
    return bare_doc.from_array(attributes, attribute_matrix)


# Monkey-patch Robustness Gym's `Spacy` cached operation so parsed docs are
# serialized and restored with the attribute-array helpers defined above.
Spacy.encode = _spacy_encode
Spacy.decode = _spacy_decode


class AlignerCap(CachedOperation):
    """Cached operation that runs an aligner over spaCy-parsed text columns.

    The first column passed to the operation is the source text. Every
    remaining column is a target that the source is aligned against.
    """

    def __init__(
            self,
            aligner,
            spacy,
            **kwargs,
    ):
        super(AlignerCap, self).__init__(**kwargs)
        # Cached spaCy operation used to retrieve the parsed doc of each column
        self.spacy = spacy
        # Object whose `align(source_doc, target_docs)` produces the alignments
        self.aligner = aligner

    @classmethod
    def encode(cls, x):
        """Encode a list of alignment dicts for caching.

        Keys are stringified. Values are converted to lists of (int, float)
        pairs so that NumPy scalars become built-in types before the parent
        class encodes them.
        """
        # Convert to built-in types from np.int / np.float
        return super(AlignerCap, cls).encode([
            {str(k): [(int(t[0]), float(t[1])) for t in v] for k, v in d.items()}
            for d in x
        ])

    @classmethod
    def decode(cls, x):
        """Decode cached alignments, turning stringified keys back into literals."""
        x = super(AlignerCap, cls).decode(x)
        # `literal_eval` only evaluates Python literals, reversing `str(k)` in `encode`
        x = [{literal_eval(k): v for k, v in d.items()} for d in x]
        return x

    def apply(self, batch, columns, *args, **kwargs):
        """Align the source column against all target columns.

        Only the first example of ``batch`` is used, so callers must run this
        operation with ``batch_size=1``.

        Raises:
            ValueError: if fewer than two columns (source + target) are given.
        """
        if len(columns) < 2:
            raise ValueError(
                "AlignerCap requires a source column and at least one target column."
            )
        # Run the aligner on the first example of the batch
        source_doc = self.spacy.retrieve(batch, columns[0])[0]
        target_docs = [self.spacy.retrieve(batch, col)[0] for col in columns[1:]]
        return [self.aligner.align(source_doc, target_docs)]


class BertscoreAlignerCap(AlignerCap):
    """`AlignerCap` backed by a `BertscoreAligner`.

    `threshold` and `top_k` configure the aligner and are also forwarded to
    the cached operation as keyword arguments.
    """

    def __init__(
            self,
            threshold: float,
            top_k: int,
            spacy,
    ):
        bertscore_aligner = BertscoreAligner(threshold=threshold, top_k=top_k)
        super().__init__(
            aligner=bertscore_aligner,
            spacy=spacy,
            threshold=threshold,
            top_k=top_k,
        )


class NGramAlignerCap(AlignerCap):
    """`AlignerCap` backed by an `NGramAligner` constructed without arguments."""

    def __init__(
            self,
            spacy,
    ):
        ngram_aligner = NGramAligner()
        super().__init__(aligner=ngram_aligner, spacy=spacy)


class StaticEmbeddingAlignerCap(AlignerCap):
    """`AlignerCap` backed by a `StaticEmbeddingAligner`.

    `threshold` and `top_k` configure the aligner and are also forwarded to
    the cached operation as keyword arguments.
    """

    def __init__(
            self,
            threshold: float,
            top_k: int,
            spacy,
    ):
        embedding_aligner = StaticEmbeddingAligner(threshold=threshold, top_k=top_k)
        super().__init__(
            aligner=embedding_aligner,
            spacy=spacy,
            threshold=threshold,
            top_k=top_k,
        )


def _split_alignment_column(
        dataset: Dataset,
        aligner: CachedOperation,
        source_column: str,
        target_columns: List[str],
):
    """Replace one combined (source, *targets) alignment column with per-pair columns.

    Mutates ``dataset`` in place. One column is added per (source, target)
    pair. The combined column is removed together with its entry in the
    cached-operation history.
    """
    combined_columns = [source_column] + target_columns

    # Retrieve the combined (source, targets) column
    combined_alignments = aligner.retrieve(
        dataset[:],
        combined_columns,
    )[tuple(combined_columns)]

    for i, col in enumerate(target_columns):
        # Add as a new column after encoding with the aligner's `encode` method
        dataset.add_column(
            column=str(aligner.identifier(columns=[source_column, col])),
            values=[aligner.encode([row[i]]) for row in combined_alignments],
        )

    # Remove the combined column and forget it in the cached-op history
    dataset.remove_column(str(aligner.identifier(columns=combined_columns)))
    del dataset.interactions[CACHEDOPS].history[
        (
            aligner.identifier,
            strings_as_json(strings=combined_columns),
        )
    ]


def _run_aligners(
        dataset: Dataset,
        aligners: List[CachedOperation],
        doc_column: str,
        reference_column: str,
        summary_columns: List[str] = None,
):
    """Run every aligner on (document, summaries) and (reference, summaries).

    For each aligner:

    1. It aligns the document against the reference (if any) and all summaries.
    2. If there is a reference and at least one summary, it also aligns the
       reference against the summaries.
    3. Each combined output column that covers more than one target is then
       split into one column per (source, target) pair.

    Args:
        dataset: dataset holding the spaCy-parsed text columns.
        aligners: cached aligner operations to run, in order.
        doc_column: name of the document column.
        reference_column: name of the reference summary column, or None.
        summary_columns: names of the other summary columns.

    Returns:
        The dataset with the alignment columns added.
    """
    if not summary_columns:
        summary_columns = []

    to_columns = []
    if reference_column is not None:
        to_columns.append(reference_column)
    to_columns.extend(summary_columns)

    for aligner in aligners:

        # Run the aligner on (document, summary) pairs
        dataset = aligner(
            dataset,
            [doc_column] + to_columns,
            # Must use `batch_size = 1`
            batch_size=1,
        )

        if reference_column is not None and len(summary_columns):
            # Run the aligner on (reference, summary) pairs
            dataset = aligner(
                dataset,
                [reference_column] + summary_columns,
                # Must use `batch_size = 1`
                batch_size=1,
            )

        if len(to_columns) > 1:
            # Split the (document, summary) column into one column per comparison
            _split_alignment_column(dataset, aligner, doc_column, to_columns)

        if reference_column is not None and len(summary_columns) > 1:
            # Split the (reference, summary) column into one column per comparison
            _split_alignment_column(dataset, aligner, reference_column, summary_columns)

    return dataset


def deanonymize_dataset(
        rg_path: str,
        standardized_dataset: Dataset,
        processed_dataset_path: str = None,
        n_samples: int = None,
):
    """Restore the original text columns of an anonymized dataset and re-process them.

    The Robustness Gym dataset at ``rg_path`` receives two columns copied from
    ``standardized_dataset``: ``document`` and, when present,
    ``summary:reference``. These columns are preprocessed into
    ``preprocessed_<col>``, parsed with spaCy, and the result is saved to
    ``processed_dataset_path``.

    If ``n_samples`` is truthy, both datasets are restricted to their first
    ``n_samples`` rows.

    Returns:
        The deanonymized, processed dataset.
    """
    assert processed_dataset_path is not None, \
        "Please specify a path to save the dataset."

    dataset = Dataset.load_from_disk(rg_path)

    if n_samples:
        for restricted in (dataset, standardized_dataset):
            restricted.set_visible_rows(list(range(n_samples)))

    # Columns copied back from the standardized dataset
    restored_columns = ['document']
    if 'summary:reference' in standardized_dataset.column_names:
        restored_columns.append('summary:reference')
    for column in restored_columns:
        dataset.add_column(column, standardized_dataset[column])

    dataset = dataset.update(
        lambda x: {f'preprocessed_{k}': preprocess_text(x[k]) for k in restored_columns}
    )

    # Prefer the large English model, falling back to the small one if it is missing
    try:
        nlp = load('en_core_web_lg')
    except OSError:
        nlp = load('en_core_web_sm')
    nlp.add_pipe('sentencizer', before="parser")

    spacy = Spacy(nlp=nlp)
    dataset = spacy(
        dataset,
        [f'preprocessed_{column}' for column in restored_columns],
        batch_size=100,
    )

    dataset.save_to_disk(processed_dataset_path)
    return dataset


def run_workflow(
        jsonl_path: str = None,
        dataset: Dataset = None,
        doc_column: str = None,
        reference_column: str = None,
        summary_columns: List[str] = None,
        bert_aligner_threshold: float = 0.5,
        bert_aligner_top_k: int = 3,
        embedding_aligner_threshold: float = 0.5,
        embedding_aligner_top_k: int = 3,
        processed_dataset_path: str = None,
        n_samples: int = None,
        anonymize: bool = False,
):
    """Preprocess a dataset, parse it with spaCy and run all three aligners.

    Exactly one of ``jsonl_path`` or ``dataset`` must be given. Column names
    that are not given are inferred:

    - the document is "document";
    - the reference summary is "summary:reference", if that column exists;
    - the other summaries are every "summary:"-prefixed column except the
      reference.

    Args:
        jsonl_path: path to a jsonl dataset to load.
        dataset: an already-loaded dataset.
        doc_column: name of the document column.
        reference_column: name of the reference summary column.
        summary_columns: names of the other summary columns.
        bert_aligner_threshold: threshold for the BERTScore aligner.
        bert_aligner_top_k: top-k for the BERTScore aligner.
        embedding_aligner_threshold: threshold for the static-embedding aligner.
        embedding_aligner_top_k: top-k for the static-embedding aligner.
        processed_dataset_path: where to save the processed dataset. When
            anonymizing, ".anonymized" is appended to this path.
        n_samples: if truthy, only the first ``n_samples`` rows are processed.
        anonymize: if True, remove the document and reference columns (raw,
            preprocessed and spaCy-parsed) before saving.

    Returns:
        The processed dataset.

    Raises:
        AssertionError: on invalid source/path arguments or a missing
            "document" column.
        ValueError: if no reference and no other summary column is available.
    """
    assert (jsonl_path is None) != (dataset is None), \
        "One of `jsonl_path` and `dataset` must be specified."
    assert processed_dataset_path is not None, \
        "Please specify a path to save the dataset."

    # Load the dataset
    if jsonl_path is not None:
        dataset = Dataset.from_jsonl(jsonl_path)

    if doc_column is None:
        # Assume `doc_column` is called "document"
        doc_column = 'document'
        assert doc_column in dataset.column_names, \
            f"`doc_column={doc_column}` is not a column in dataset."
        print("Assuming `doc_column` is called 'document'.")

    if reference_column is None:
        # Assume `reference_column` is called "summary:reference"
        reference_column = 'summary:reference'
        print("Assuming `reference_column` is called 'summary:reference'.")
        if reference_column not in dataset.column_names:
            print("No reference summary loaded")
            reference_column = None

    if not summary_columns:
        # Assume `summary_columns` are prefixed by "summary:". Exclude the
        # reference column, even if the caller gave it a custom name, so the
        # reference is never also treated as a model summary.
        excluded = {"summary:reference", reference_column}
        summary_columns = [
            col for col in dataset.column_names
            if col.startswith("summary:") and col not in excluded
        ]
        print(f"Reading summary columns from dataset. Found {summary_columns}.")

    if len(summary_columns) == 0 and reference_column is None:
        raise ValueError("At least one summary is required")

    # Set visible rows to restrict to the first `n_samples`
    if n_samples:
        dataset.set_visible_rows(list(range(n_samples)))

    # Combine the text columns into one list
    text_columns = [doc_column] + ([reference_column] if reference_column else []) + summary_columns

    # Preprocessing all the text columns
    dataset = dataset.update(
        lambda x: {f'preprocessed_{k}': preprocess_text(x[k]) for k in text_columns}
    )

    # Run the Spacy pipeline on all preprocessed text columns
    nlp = load('en_core_web_lg')
    nlp.add_pipe('sentencizer', before="parser")
    spacy = Spacy(nlp=nlp)
    dataset = spacy(
        dataset,
        [f'preprocessed_{col}' for col in text_columns],
        batch_size=100,
    )

    # Run the 3 align pipelines
    bert_aligner = BertscoreAlignerCap(
        threshold=bert_aligner_threshold,
        top_k=bert_aligner_top_k,
        spacy=spacy,
    )

    embedding_aligner = StaticEmbeddingAlignerCap(
        threshold=embedding_aligner_threshold,
        top_k=embedding_aligner_top_k,
        spacy=spacy,
    )

    ngram_aligner = NGramAlignerCap(
        spacy=spacy,
    )

    dataset = _run_aligners(
        dataset=dataset,
        aligners=[bert_aligner, embedding_aligner, ngram_aligner],
        doc_column=f'preprocessed_{doc_column}',
        reference_column=f'preprocessed_{reference_column}' if reference_column else None,
        summary_columns=[f'preprocessed_{col}' for col in summary_columns],
    )

    # Save the dataset
    if anonymize:
        # Remove certain columns to anonymize and save to disk
        for col in [doc_column, reference_column]:
            if col is not None:
                dataset.remove_column(col)
                dataset.remove_column(f'preprocessed_{col}')
                dataset.remove_column(
                    str(spacy.identifier(columns=[f'preprocessed_{col}']))
                )
                del dataset.interactions[CACHEDOPS].history[
                    (spacy.identifier, f'preprocessed_{col}')
                ]
        dataset.save_to_disk(f'{processed_dataset_path}.anonymized')
    else:
        # Directly save to disk
        dataset.save_to_disk(processed_dataset_path)

    return dataset


def parse_prediction_jsonl_name(prediction_jsonl: str):
    """Parse the name of a prediction file to extract its metadata.

    Prediction files must be named
    ``<model>-<training dataset>.<eval dataset>.<eval split>.results.anonymized``.

    The name is split from the right, so the model name may contain dots or
    hyphens (e.g. ``bart-large-xsum.xsum.test.results.anonymized``). The
    training dataset name must not contain a hyphen.

    Args:
        prediction_jsonl: path to the prediction file; only its basename is used.

    Returns:
        SimpleNamespace with ``model_train_dataset``, ``model``,
        ``train_dataset``, ``eval_dataset`` and ``eval_split``.

    Raises:
        NotImplementedError: if the filename does not follow the format above.
    """
    filename = os.path.basename(prediction_jsonl)
    suffix = ".results.anonymized"

    if filename.endswith(suffix):
        # <model>-<training dataset>.<eval dataset>.<eval split>
        stem_parts = filename[:-len(suffix)].rsplit(".", 2)
        if len(stem_parts) == 3:
            model_train_dataset, eval_dataset, eval_split = stem_parts
            # Split on the last hyphen so hyphenated model names stay whole
            model_and_train = model_train_dataset.rsplit("-", 1)
            if len(model_and_train) == 2:
                model, train_dataset = model_and_train
                return SimpleNamespace(
                    model_train_dataset=model_train_dataset,
                    model=model,
                    train_dataset=train_dataset,
                    eval_dataset=eval_dataset,
                    eval_split=eval_split,
                )

    raise NotImplementedError(
        "Prediction files must be named "
        "<model>-<training dataset>.<eval dataset>.<eval split>.results.anonymized. "
        f"Please rename the prediction file {filename} and run again."
    )


def join_predictions(
        dataset_jsonl: str = None,
        prediction_jsonls: List[str] = None,
        save_jsonl_path: str = None,
):
    """Join model predictions with a dataset.

    Each prediction file's name is parsed with ``parse_prediction_jsonl_name``.
    Its ``<model>-<training dataset>`` prefix names the columns added for that
    file:

    - the ``decoded`` column becomes ``summary:<prefix>``;
    - every other column, except the indexing columns ``index``, ``ix`` and
      ``id``, becomes ``<prefix>:<column>``.

    Args:
        dataset_jsonl: path to the dataset jsonl file.
        prediction_jsonls: one or more prediction jsonl paths (must be non-empty).
        save_jsonl_path: where to write the joined dataset; if not given, it is
            not saved.

    Returns:
        The dataset with prediction columns added.
    """
    assert prediction_jsonls, "Must have prediction jsonl files."

    print(
        "> Warning: please inspect the prediction .jsonl file to make sure that "
        "predictions are aligned with the examples in the dataset. "
        "Use `get_dataset` to inspect the dataset."
    )

    # Load the dataset
    dataset = get_dataset(dataset_jsonl=dataset_jsonl)

    # Parse all file names first, so a badly named file fails before any
    # prediction file is loaded
    metadata = [
        parse_prediction_jsonl_name(prediction_jsonl)
        for prediction_jsonl in prediction_jsonls
    ]

    # Load the predictions
    predictions = [
        Dataset.from_jsonl(json_path=prediction_jsonl)
        for prediction_jsonl in prediction_jsonls
    ]

    for file_metadata, prediction_data in zip(metadata, predictions):
        # Prefix for the columns added to the dataset for this prediction file
        prefix = file_metadata.model_train_dataset

        for col in prediction_data.column_names:
            # Don't add the indexing information since the dataset has it already
            if col in {'index', 'ix', 'id'}:
                continue
            # `add_column` will automatically ensure that column lengths match
            if col == 'decoded':  # rename decoded to summary
                dataset.add_column(f'summary:{prefix}', prediction_data[col])
            else:
                dataset.add_column(f'{prefix}:{col}', prediction_data[col])

    # Save the dataset back to disk
    if save_jsonl_path:
        dataset.to_jsonl(save_jsonl_path)
    else:
        print("Dataset with predictions was not saved since `save_jsonl_path` "
              "was not specified.")

    return dataset


def standardize_dataset(
        dataset_name: str = None,
        dataset_version: str = None,
        dataset_split: str = 'test',
        dataset_jsonl: str = None,
        doc_column: str = None,
        reference_column: str = None,
        save_jsonl_path: str = None,
        no_save: bool = False,
):
    """Load a dataset and rename its text columns to the standard names.

    The dataset is loaded with ``get_dataset`` (from Huggingface or a jsonl
    file). Its document column is renamed to ``document`` and its reference
    column, if any, to ``summary:reference``.

    If ``doc_column`` is not given, it and the reference column are looked up
    from known defaults for 'cnn_dailymail' and 'xsum'.

    Saving behavior:

    - the result is written to ``save_jsonl_path`` if given;
    - otherwise it is written to an auto-generated path under
      ``preprocessing/``, unless ``no_save`` is set.

    Returns:
        The standardized dataset.

    Raises:
        ValueError: if ``reference_column`` is given without ``doc_column``.
        NotImplementedError: if ``doc_column`` is missing and the dataset has
            no known defaults.
    """
    # Load the dataset from Huggingface
    dataset = get_dataset(
        dataset_name=dataset_name,
        dataset_version=dataset_version,
        dataset_split=dataset_split,
        dataset_jsonl=dataset_jsonl,
    )

    if doc_column is None:
        if reference_column is not None:
            raise ValueError("You must specify `doc_column` if you specify `reference_column`")
        default_columns = {
            'cnn_dailymail': ('article', 'highlights'),
            'xsum': ('document', 'summary'),
        }
        try:
            doc_column, reference_column = default_columns[dataset_name]
        except KeyError:
            raise NotImplementedError(
                "Please specify `doc_column`."
            ) from None

    # Rename the columns; skip renames that are no-ops, since adding a column
    # under its own name and then removing the source would drop it
    if doc_column != 'document':
        dataset.add_column('document', dataset[doc_column])
        dataset.remove_column(doc_column)
    if reference_column is not None and reference_column != 'summary:reference':
        dataset.add_column('summary:reference', dataset[reference_column])
        dataset.remove_column(reference_column)

    # Save the dataset back to disk
    if save_jsonl_path:
        dataset.to_jsonl(save_jsonl_path)

    elif (save_jsonl_path is None) and not no_save:
        # Auto-create a path to save the standardized dataset
        os.makedirs('preprocessing', exist_ok=True)
        if not dataset_jsonl:
            dataset.to_jsonl(
                f'preprocessing/'
                f'standardized_{dataset_name}_{dataset_version}_{dataset_split}.jsonl'
            )
        else:
            dataset.to_jsonl(
                f'preprocessing/'
                f'standardized_{dataset_jsonl.split("/")[-1]}'
            )

    return dataset


def get_dataset(
        dataset_name: str = None,
        dataset_version: str = None,
        dataset_split: str = 'test',
        dataset_jsonl: str = None,
):
    """Load a dataset from Huggingface or from a jsonl file.

    Exactly one of ``dataset_name`` (Huggingface) or ``dataset_jsonl`` (local
    file) must be given. ``dataset_version`` and ``dataset_split`` are only
    used for Huggingface datasets.
    """
    from_hub = dataset_name is not None
    from_file = dataset_jsonl is not None
    assert from_hub != from_file, \
        "Specify one of `dataset_name` or `dataset_jsonl`."

    if not from_hub:
        return Dataset.from_jsonl(json_path=dataset_jsonl)
    return get_hf_dataset(dataset_name, dataset_version, dataset_split)


def get_hf_dataset(name: str, version: str = None, split: str = 'test'):
    """Load split ``split`` of Huggingface dataset ``name``.

    ``version`` is passed as the dataset configuration only when it is truthy.
    """
    load_args = (name, version) if version else (name,)
    return Dataset.load_dataset(*load_args, split=split)


if __name__ == '__main__':
    # Command-line entry point. Each action flag (--standardize,
    # --join_predictions, --workflow, --deanonymize) is checked independently
    # and in that order, so several actions may run in one invocation.
    parser = ArgumentParser()
    # --- Dataset sources ---
    parser.add_argument('--dataset', type=str, choices=['cnn_dailymail', 'xsum'],
                        help="Huggingface dataset name.")
    parser.add_argument('--version', type=str,
                        help="Huggingface dataset version.")
    parser.add_argument('--split', type=str, default='test',
                        help="Huggingface dataset split.")
    parser.add_argument('--dataset_jsonl', type=str,
                        help="Path to a jsonl file for the dataset.")
    parser.add_argument('--dataset_rg', type=str,
                        help="Path to a dataset stored in the Robustness Gym format. "
                             "All processed datasets are stored in this format.")
    parser.add_argument('--prediction_jsonls', nargs='+', default=[],
                        help="Path to one or more jsonl files for the predictions.")
    parser.add_argument('--save_jsonl_path', type=str,
                        help="Path to save the processed jsonl dataset.")

    # --- Column names ---
    parser.add_argument('--doc_column', type=str,
                        help="Name of the document column in the dataset.")
    parser.add_argument('--reference_column', type=str,
                        help="Name of the reference summary column in the dataset.")
    parser.add_argument('--summary_columns', nargs='+', default=[],
                        help="Name of other summary columns in/added to the dataset.")

    # --- Aligner hyperparameters and processing options ---
    # NOTE: these CLI defaults (0.1 / 10) differ from run_workflow's own
    # defaults (0.5 / 3); the CLI values always win when run from here.
    parser.add_argument('--bert_aligner_threshold', type=float, default=0.1,
                        help="Minimum threshold for BERT alignment.")
    parser.add_argument('--bert_aligner_top_k', type=int, default=10,
                        help="Top-k for BERT alignment.")
    parser.add_argument('--embedding_aligner_threshold', type=float, default=0.1,
                        help="Minimum threshold for embedding alignment.")
    parser.add_argument('--embedding_aligner_top_k', type=int, default=10,
                        help="Top-k for embedding alignment.")
    parser.add_argument('--processed_dataset_path', type=str,
                        help="Path to store the final processed dataset.")
    parser.add_argument('--n_samples', type=int,
                        help="Number of dataset samples to process.")

    # --- Action flags ---
    parser.add_argument('--workflow', action='store_true', default=False,
                        help="Whether to run the preprocessing workflow.")
    parser.add_argument('--standardize', action='store_true', default=False,
                        help="Whether to standardize the dataset and save to jsonl.")
    parser.add_argument('--join_predictions', action='store_true', default=False,
                        help="Whether to add predictions to the dataset and save to "
                             "jsonl.")
    parser.add_argument('--try_it', action='store_true', default=False,
                        help="`Try it` mode is faster and runs processing on 10 "
                             "examples.")
    parser.add_argument('--deanonymize', action='store_true', default=False,
                        help="Deanonymize the dataset provided by summvis.")
    parser.add_argument('--anonymize', action='store_true', default=False,
                        help="Anonymize by removing document and reference summary "
                             "columns of the original dataset.")

    args = parser.parse_args()

    if args.standardize:
        # Dump a dataset to jsonl on disk after standardizing it
        # NOTE(review): the returned dataset is discarded; later actions read
        # their inputs from the CLI paths, not from this result.
        standardize_dataset(
            dataset_name=args.dataset,
            dataset_version=args.version,
            dataset_split=args.split,
            dataset_jsonl=args.dataset_jsonl,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            save_jsonl_path=args.save_jsonl_path,
        )

    if args.join_predictions:
        # Join the predictions with the dataset
        dataset = join_predictions(
            dataset_jsonl=args.dataset_jsonl,
            prediction_jsonls=args.prediction_jsonls,
            save_jsonl_path=args.save_jsonl_path,
        )

    if args.workflow:
        # Run the processing workflow
        # NOTE(review): this resets `dataset`, so the joined dataset from
        # --join_predictions above is not used by the workflow.
        dataset = None
        # Check if `args.dataset_rg` was passed in
        if args.dataset_rg:
            # Load the dataset directly
            dataset = Dataset.load_from_disk(args.dataset_rg)

        # run_workflow requires exactly one of `jsonl_path` / `dataset`, so
        # passing both --dataset_rg and --dataset_jsonl fails its assertion.
        # --try_it overrides --n_samples with 10.
        run_workflow(
            jsonl_path=args.dataset_jsonl,
            dataset=dataset,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            summary_columns=args.summary_columns,
            bert_aligner_threshold=args.bert_aligner_threshold,
            bert_aligner_top_k=args.bert_aligner_top_k,
            embedding_aligner_threshold=args.embedding_aligner_threshold,
            embedding_aligner_top_k=args.embedding_aligner_top_k,
            processed_dataset_path=args.processed_dataset_path,
            n_samples=args.n_samples if not args.try_it else 10,
            anonymize=args.anonymize,
        )

    if args.deanonymize:
        # Deanonymize an anonymized dataset
        # Check if `args.dataset_rg` was passed in
        assert args.dataset_rg is not None, \
            "Must specify `dataset_rg` path to be deanonymized."
        assert args.dataset_rg.endswith('anonymized'), \
            "`dataset_rg` must end in 'anonymized'."
        assert (args.dataset is None) != (args.dataset_jsonl is None), \
            "`dataset_rg` points to an anonymized dataset that will be " \
            "deanonymized. Please pass in relevant arguments: either " \
            "`dataset`, `version` and `split` OR `dataset_jsonl`."

        # Load the standardized dataset (kept in memory only; not written to disk)
        standardized_dataset = standardize_dataset(
            dataset_name=args.dataset,
            dataset_version=args.version,
            dataset_split=args.split,
            dataset_jsonl=args.dataset_jsonl,
            doc_column=args.doc_column,
            reference_column=args.reference_column,
            no_save=True,
        )
        # Use it to deanonymize
        dataset = deanonymize_dataset(
            rg_path=args.dataset_rg,
            standardized_dataset=standardized_dataset,
            processed_dataset_path=args.processed_dataset_path,
            n_samples=args.n_samples if not args.try_it else 10,
        )