import re
import pandas
import seaborn
import gradio
import pathlib

import torch
import matplotlib
matplotlib.use('agg')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import numpy
from sklearn.metrics.pairwise import cosine_distances

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForMaskedLM,
)

## Rollout Helper Function
def compute_joint_attention(att_mat, res=True):
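    """Aggregate per-layer token-to-token scores with attention rollout.

    att_mat has shape (num_layers, seq_len, seq_len). When res=True, an identity
    (residual) term is added and rows are re-normalized before the layer-by-layer
    matrix products.
    """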
    if res:
        residual_att = numpy.eye(att_mat.shape[1])[None,...]
        att_mat = att_mat + residual_att
        att_mat = att_mat / att_mat.sum(axis=-1)[...,None]

    joint_attentions = numpy.zeros(att_mat.shape)
    layers = joint_attentions.shape[0]
    joint_attentions[0] = att_mat[0]
    for i in numpy.arange(1,layers):
        joint_attentions[i] = att_mat[i].dot(joint_attentions[i-1])

    return joint_attentions

def create_plot(all_tokens, score_data):
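    """Draw a 6x2 grid of per-layer heatmaps, with tokens labelling both axes.

    Assumes a 12-layer model such as bert-base or roberta-base.
    """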
    LAYERS = list(range(12))
    fig, axs = plt.subplots(6, 2, figsize=(8, 24))
    plt.subplots_adjust(top=0.98, bottom=0.05, hspace=0.5, wspace=0.5)
    for layer in LAYERS:
        a = layer // 2
        b = layer % 2
        seaborn.heatmap(
                ax=axs[a, b],
                data=pandas.DataFrame(score_data[layer], index=all_tokens, columns=all_tokens),
                cmap="Blues",
                annot=False,
                cbar=False
            )
        axs[a, b].set_title(f"Layer: {layer+1}")
    return fig


DISTANCE_FUNC = {
    'cosine': cosine_distances
}
MODEL_PATH = {
    'bert': 'bert-base-uncased',
    'roberta': 'roberta-base',
}

MODEL_NAME = 'bert'
#MODEL_NAME = 'roberta'
METRIC = 'cosine'

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)


def run(mname, sent):
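    """Swap in the requested model if needed, compute value-zeroing scores for
    `sent`, and return (rolled-out figure, per-layer figure)."""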
    global MODEL_NAME, config, model, tokenizer
    if mname != MODEL_NAME:
        MODEL_NAME = mname
        config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
        model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)
    # replace the [MASK] / <mask> placeholder in the input with the model-specific mask token
    sent = re.sub(r"[\[<]mask[\]>]", tokenizer.mask_token, sent, flags=re.IGNORECASE)
    inputs = tokenizer(sent, return_token_type_ids=True, return_tensors="pt")

    ## Compute layer-wise value-zeroing scores
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        token_type_ids=inputs['token_type_ids'],
                        output_hidden_states=True, output_attentions=False)

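    # hidden states for all layers: (num_layers + 1, seq_len, hidden_size) after dropping the batch dim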
    org_hidden_states = torch.stack(outputs['hidden_states']).squeeze(1)
    input_shape = inputs['input_ids'].size()
    batch_size, seq_length = input_shape

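    # Value zeroing: for every layer l and token t, re-run that layer with token t's
    # value vectors zeroed out and record how far each output representation moves.
    # getattr(model, MODEL_NAME) selects the base encoder (model.bert / model.roberta).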
    score_matrix = numpy.zeros((config.num_hidden_layers, seq_length, seq_length))
    for l, layer_module in enumerate(getattr(model, MODEL_NAME).encoder.layer):
        for t in range(seq_length):
            extended_blanking_attention_mask: torch.Tensor = getattr(model, MODEL_NAME).get_extended_attention_mask(inputs['attention_mask'], input_shape, device)
            with torch.no_grad():
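                # note: zero_value_index is not part of stock transformers; it assumes the
                # Value Zeroing-modified layer implementation bundled with this demo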
                layer_outputs = layer_module(org_hidden_states[l].unsqueeze(0), # previous layer's original output
                                            attention_mask=extended_blanking_attention_mask,
                                            output_attentions=False,
                                            zero_value_index=t,
                                            )
            hidden_states = layer_outputs[0].squeeze().detach().cpu().numpy()
            # cosine distance between the original layer output and the output
            # obtained with token t's value vectors zeroed
            x = hidden_states
            y = org_hidden_states[l+1].detach().cpu().numpy()

            distances = DISTANCE_FUNC[METRIC](x, y).diagonal()
            score_matrix[l, :, t] = distances

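    # normalize each row to a distribution over source tokens, then aggregate
    # across layers with attention rollout (no extra residual term here)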
    valuezeroing_scores = score_matrix / numpy.sum(score_matrix, axis=-1, keepdims=True)
    rollout_valuezeroing_scores = compute_joint_attention(valuezeroing_scores, res=False)


    # Plot: one heatmap per layer, tokens on both axes
    all_tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    rollout_fig = create_plot(all_tokens, rollout_valuezeroing_scores)
    value_fig = create_plot(all_tokens, valuezeroing_scores)

    return rollout_fig, value_fig

examples = pandas.read_csv("examples.csv").to_numpy().tolist()

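# Gradio UI: sentence input and model choice on top, rolled-out and raw
# value-zeroing heatmaps side by side below.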
with gradio.Blocks(
        title="Differences with/without zero-valuing",
        css= ".output-image > img {height: 2000px !important; max-height: none !important;} "
) as iface:
    gradio.Markdown(pathlib.Path("description.md").read_text)
    with gradio.Row(equal_height=True):
        with gradio.Column(scale=4):
            sent = gradio.Textbox(label="Input sentence")
        with gradio.Column(scale=1):
            model_choice = gradio.Dropdown(choices=['bert', 'roberta'], value="bert", label="Model")
            but = gradio.Button("Submit")
    gradio.Examples(examples, [sent])
    with gradio.Row(equal_height=True):
        with gradio.Column():
            gradio.Markdown("### With Rollout")
            rollout_result = gradio.Plot()
        with gradio.Column():
            gradio.Markdown("### Without Rollout")
            value_result = gradio.Plot()
    with gradio.Accordion("Some more details"):
        gradio.Markdown(pathlib.Path("notice.md").read_text)

    but.click(run,
            inputs=[model_choice, sent],
            outputs=[rollout_result, value_result]
        )


iface.launch()