import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader
import pickle

from ..bert import BERT
from ..seq_model import BERTSM
from ..classifier_model import BERTForClassification
from ..optim_schedule import ScheduledOptim

import tqdm
import sys
import time

import numpy as np
# import visualization

from sklearn.metrics import precision_score, recall_score, f1_score

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict
import os

class ECE(nn.Module):

    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(ECE, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        labels = torch.argmax(labels,1)
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            # Calculated |confidence - accuracy| in each bin
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

def accurate_nb(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = np.argmax(labels, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat)

class BERTTrainer:
    """
    BERTTrainer pretrains BERT model on input sequence of strategies.
    BERTTrainer make the pretrained BERT model with one training method objective.
        1. Masked Strategy Modelling : 3.3.1 Task #1: Masked SM
    """

    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, val_dataloader: DataLoader = None, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=5000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, same_student_prediction = False,
                workspace_name=None, code=None):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param with_cuda: traning with cuda
        :param log_freq: logging frequency of the batch iteration
        """

        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        print(cuda_condition, " Device used = ", self.device)
        
        available_gpus = list(range(torch.cuda.device_count()))

        # This BERT model will be saved every epoch
        self.bert = bert.to(self.device)
        # Initialize the BERT Language Model, with BERT model
        self.model = BERTSM(bert, vocab_size).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=available_gpus)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.val_data = val_dataloader
        self.test_data = test_dataloader

        # Setting the Adam optimizer with hyper-param
        self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
        self.optim_schedule = ScheduledOptim(self.optim, self.bert.hidden, n_warmup_steps=warmup_steps)

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)
    
        self.log_freq = log_freq
        self.same_student_prediction = same_student_prediction
        self.workspace_name = workspace_name
        self.save_model = False
        self.code = code
        self.avg_loss = 10000
        self.start_time = time.time()

        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def val(self, epoch):
        self.iteration(epoch, self.val_data, phase="val")
        
    def test(self, epoch):
        self.iteration(epoch, self.test_data, phase="test")

    def iteration(self, epoch, data_loader, phase="train"):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        and also auto save the model every peoch

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        # str_code = "train" if train else "test"
        # code = "masked_prediction" if self.same_student_prediction else "masked"
        
        self.log_file = f"{self.workspace_name}/logs/{self.code}/log_{phase}_pretrained.txt"
        # bert_hidden_representations = []
        if epoch == 0:
            f = open(self.log_file, 'w')
            f.close()
            if phase == "val":
                self.avg_loss = 10000
        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (phase, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss_mask = 0.0
        total_correct_mask = 0
        total_element_mask = 0
        
        avg_loss_pred = 0.0
        total_correct_pred = 0
        total_element_pred = 0
        
        avg_loss = 0.0
        
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()
        with open(self.log_file, 'a') as f:
            sys.stdout = f
            for i, data in data_iter:
                # 0. batch_data will be sent into the device(GPU or cpu)
                data = {key: value.to(self.device) for key, value in data.items()}
                # if i == 0:
                #     print(f"data : {data[0]}")
                # 1. forward the next_sentence_prediction and masked_lm model
                # next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"])
                if self.same_student_prediction:
                    bert_hidden_rep, mask_lm_output, same_student_output = self.model.forward(data["bert_input"], data["segment_label"], self.same_student_prediction)
                else:
                    bert_hidden_rep, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"], self.same_student_prediction)

                # embeddings = [h for h in bert_hidden_rep.cpu().detach().numpy()]
                # bert_hidden_representations.extend(embeddings)


                # 2-2. NLLLoss of predicting masked token word
                mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"])

                # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
                if self.same_student_prediction:
                    # 2-1. NLL(negative log likelihood) loss of is_next classification result
                    same_student_loss = self.criterion(same_student_output, data["is_same_student"])
                    loss = same_student_loss + mask_loss
                else:
                    loss = mask_loss

                # 3. backward and optimization only in train
                if phase == "train":
                    self.optim_schedule.zero_grad()
                    loss.backward()
                    self.optim_schedule.step_and_update_lr()


                # print(f"mask_lm_output : {mask_lm_output}")
                # non_zero_mask = (data["bert_label"] != 0).float()
                # print(f"bert_label : {data['bert_label']}")
                non_zero_mask = (data["bert_label"] != 0).float()
                predictions = torch.argmax(mask_lm_output, dim=-1)
                # print(f"predictions : {predictions}")
                predicted_masked = predictions*non_zero_mask
                # print(f"predicted_masked : {predicted_masked}")
                mask_correct = ((data["bert_label"] == predicted_masked)*non_zero_mask).sum().item()
                # print(f"mask_correct : {mask_correct}")
                # print(f"non_zero_mask.sum().item() : {non_zero_mask.sum().item()}")

                avg_loss_mask += loss.item()
                total_correct_mask += mask_correct
                total_element_mask += non_zero_mask.sum().item()
                # total_element_mask += data["bert_label"].sum().item()

                torch.cuda.empty_cache()
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss_mask / (i + 1),
                    "avg_acc_mask": (total_correct_mask / total_element_mask * 100) if total_element_mask != 0 else 0,
                    "loss": loss.item()
                }

                # next sentence prediction accuracy
                if self.same_student_prediction:
                    correct = same_student_output.argmax(dim=-1).eq(data["is_same_student"]).sum().item()
                    avg_loss_pred += loss.item()
                    total_correct_pred += correct
                    total_element_pred += data["is_same_student"].nelement()
                # correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item()
                    post_fix["avg_loss"] = avg_loss_pred / (i + 1)
                    post_fix["avg_acc_pred"] = total_correct_pred / total_element_pred * 100
                    post_fix["loss"] = loss.item()

                avg_loss +=loss.item()

                if i % self.log_freq == 0:
                    data_iter.write(str(post_fix))
                # if not train and epoch > 20 :
                #     pickle.dump(mask_lm_output.cpu().detach().numpy(), open(f"logs/mask/mask_out_e{epoch}_{i}.pkl","wb"))
                #     pickle.dump(data["bert_label"].cpu().detach().numpy(), open(f"logs/mask/label_e{epoch}_{i}.pkl","wb"))
            end_time = time.time()
            final_msg = {
                "epoch": f"EP{epoch}_{phase}",
                "avg_loss": avg_loss / len(data_iter),
                "total_masked_acc": total_correct_mask * 100.0 / total_element_mask if total_element_mask != 0 else 0,
                "time_taken_from_start": end_time - self.start_time
            }

            if self.same_student_prediction:
                final_msg["total_prediction_acc"] = total_correct_pred * 100.0 / total_element_pred

            print(final_msg)
            
            f.close()
        sys.stdout = sys.__stdout__
        
        if phase == "val":
            self.save_model = False
            if self.avg_loss > (avg_loss / len(data_iter)):
                self.save_model = True
                self.avg_loss = (avg_loss / len(data_iter))

        # pickle.dump(bert_hidden_representations, open(f"embeddings/{code}/{str_code}_embeddings_{epoch}.pkl","wb"))

        

    def save(self, epoch, file_path="output/bert_trained.model"):
        """
        Saving the current BERT model on file_path

        :param epoch: current epoch number
        :param file_path: model output path which gonna be file_path+"ep%d" % epoch
        :return: final_output_path
        """
#         if self.code:
#             fpath = file_path.split("/")
#             # output_path = fpath[0]+ "/"+ fpath[1]+f"/{self.code}/" + fpath[2] + ".ep%d" % epoch
#             output_path = "/",join(fpath[0]+ "/"+ fpath[1]+f"/{self.code}/" + fpath[-1] + ".ep%d" % epoch

#         else:
        output_path = file_path + ".ep%d" % epoch
            
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path


class BERTFineTuneTrainer:
    
    def __init__(self, bert: BERT, vocab_size: int,
                 train_dataloader: DataLoader, test_dataloader: DataLoader = None,
                 lr: float = 1e-4, betas=(0.9, 0.999), weight_decay: float = 0.01, warmup_steps=10000,
                 with_cuda: bool = True, cuda_devices=None, log_freq: int = 10, workspace_name=None, 
                 num_labels=2, finetune_task=""):
        """
        :param bert: BERT model which you want to train
        :param vocab_size: total word vocab size
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adam optimizer betas
        :param weight_decay: Adam optimizer weight decay param
        :param with_cuda: traning with cuda
        :param log_freq: logging frequency of the batch iteration
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        print(with_cuda, cuda_condition, " Device used = ", self.device)

        # This BERT model will be saved every epoch
        self.bert = bert
        for param in self.bert.parameters():
            param.requires_grad = False
        # Initialize the BERT Language Model, with BERT model
        self.model = BERTForClassification(self.bert, vocab_size, num_labels).to(self.device)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if with_cuda and torch.cuda.device_count() > 1:
            print("Using %d GPUS for BERT" % torch.cuda.device_count())
            self.model = nn.DataParallel(self.model, device_ids=cuda_devices)

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader
    
        self.optim = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay) #, eps=1e-9
        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2, factor=0.1)
        
        if num_labels == 1:
            self.criterion = nn.MSELoss()
        elif num_labels == 2:
            self.criterion = nn.BCEWithLogitsLoss()
            # self.criterion = nn.CrossEntropyLoss()
        elif num_labels > 2:
            self.criterion = nn.CrossEntropyLoss()
            # self.criterion = nn.BCEWithLogitsLoss()
        
        # self.ece_criterion = ECE().to(self.device)
        
        self.log_freq = log_freq
        self.workspace_name = workspace_name
        self.finetune_task = finetune_task
        self.save_model = False
        self.avg_loss = 10000
        self.start_time = time.time()
        self.probability_list = []
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        and also auto save the model every peoch

        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"
        
        self.log_file = f"{self.workspace_name}/logs/{self.finetune_task}/log_{str_code}_finetuned.txt"

        if epoch == 0:
            f = open(self.log_file, 'w')
            f.close()
            if not train:
                self.avg_loss = 10000
            
        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}")

        avg_loss = 0.0
        total_correct = 0
        total_element = 0
        plabels = []
        tlabels = []
        
        eval_accurate_nb = 0
        nb_eval_examples = 0
        logits_list = []
        labels_list = []
        
        if train:
            self.model.train()
        else:
            self.model.eval()
        self.probability_list = []
        with open(self.log_file, 'a') as f:
            sys.stdout = f
            
            for i, data in data_iter:
                # 0. batch_data will be sent into the device(GPU or cpu)
                data = {key: value.to(self.device) for key, value in data.items()}
                if train:
                    h_rep, logits = self.model.forward(data["bert_input"], data["segment_label"])
                else:
                    with torch.no_grad():
                        h_rep, logits = self.model.forward(data["bert_input"], data["segment_label"])
                    # print(logits, logits.shape)
                    logits_list.append(logits.cpu())
                    labels_list.append(data["progress_status"].cpu())
                # print(">>>>>>>>>>>>", progress_output)
                # print(f"{epoch}---nelement--- {data['progress_status'].nelement()}")
                # print(data["progress_status"].shape, logits.shape)
                progress_loss = self.criterion(logits, data["progress_status"])
                loss = progress_loss
                
                if torch.cuda.device_count() > 1:
                    loss = loss.mean()

                # 3. backward and optimization only in train
                if train:
                    self.optim.zero_grad()
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optim.step()

                # progress prediction accuracy
                # correct = progress_output.argmax(dim=-1).eq(data["progress_status"]).sum().item()
                probs = nn.LogSoftmax(dim=-1)(logits)
                self.probability_list.append(probs)
                predicted_labels = torch.argmax(probs, dim=-1)
                true_labels = torch.argmax(data["progress_status"], dim=-1)
                plabels.extend(predicted_labels.cpu().numpy())
                tlabels.extend(true_labels.cpu().numpy())

                # Compare predicted labels to true labels and calculate accuracy
                correct = (predicted_labels == true_labels).sum().item()
                avg_loss += loss.item()
                total_correct += correct
                # total_element += true_labels.nelement()
                total_element += data["progress_status"].nelement()
                # print(">>>>>>>>>>>>>>", predicted_labels, true_labels, correct, total_correct, total_element)
                
                # if train: 
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "avg_acc": total_correct / total_element * 100,
                    "loss": loss.item()
                }
#                 else:
#                     logits = logits.detach().cpu().numpy()
#                     label_ids = data["progress_status"].to('cpu').numpy()
#                     tmp_eval_nb = accurate_nb(logits, label_ids)

#                     eval_accurate_nb += tmp_eval_nb
#                     nb_eval_examples += label_ids.shape[0]

#                     # total_element += data["progress_status"].nelement()
#                     # avg_loss += loss.item()

#                     post_fix = {
#                         "epoch": epoch,
#                         "iter": i,
#                         "avg_loss": avg_loss / (i + 1),
#                         "avg_acc": tmp_eval_nb / total_element * 100,
#                         "loss": loss.item()
#                     }


                if i % self.log_freq == 0:
                    data_iter.write(str(post_fix))
            
            # precisions = precision_score(plabels, tlabels, average="weighted")
            # recalls = recall_score(plabels, tlabels, average="weighted")
            f1_scores = f1_score(plabels, tlabels, average="weighted")
            # if train:
            end_time = time.time()
            final_msg = {
                "epoch": f"EP{epoch}_{str_code}",
                "avg_loss": avg_loss / len(data_iter),
                "total_acc": total_correct * 100.0 / total_element,
                # "precisions": precisions,
                # "recalls": recalls,
                "f1_scores": f1_scores,
                "time_taken_from_start": end_time - self.start_time
            }
#             else:
#                 eval_accuracy = eval_accurate_nb/nb_eval_examples

#                 logits_ece = torch.cat(logits_list)
#                 labels_ece = torch.cat(labels_list)
#                 ece = self.ece_criterion(logits_ece, labels_ece).item()
#                 end_time = time.time()
#                 final_msg = {
#                     "epoch": f"EP{epoch}_{str_code}",
#                     "eval_accuracy": eval_accuracy,
#                     "ece": ece,
#                     "avg_loss": avg_loss / len(data_iter),
#                     "precisions": precisions,
#                     "recalls": recalls,
#                     "f1_scores": f1_scores,
#                     "time_taken_from_start": end_time - self.start_time
#                 }
#                 if self.save_model:
#                     conf_hist = visualization.ConfidenceHistogram()
#                     plt_test = conf_hist.plot(np.array(logits_ece), np.array(labels_ece), title= f"Confidence Histogram {epoch}")
#                     plt_test.savefig(f"{self.workspace_name}/plots/confidence_histogram/{self.finetune_task}/conf_histogram_test_{epoch}.png",bbox_inches='tight')
#                     plt_test.close()

#                     rel_diagram = visualization.ReliabilityDiagram()
#                     plt_test_2 = rel_diagram.plot(np.array(logits_ece), np.array(labels_ece),title=f"Reliability Diagram {epoch}")
#                     plt_test_2.savefig(f"{self.workspace_name}/plots/confidence_histogram/{self.finetune_task}/rel_diagram_test_{epoch}.png",bbox_inches='tight')
#                     plt_test_2.close()
            print(final_msg)

            # print("EP%d_%s, avg_loss=" % (epoch, str_code), avg_loss / len(data_iter), "total_acc=", total_correct * 100.0 / total_element)
            f.close()
        sys.stdout = sys.__stdout__
        self.save_model = False
        if self.avg_loss > (avg_loss / len(data_iter)):
            self.save_model = True
            self.avg_loss = (avg_loss / len(data_iter))
            
    def iteration_1(self, epoch_idx, data):
        try:
            data = {key: value.to(self.device) for key, value in data.items()}
            logits = self.model(data['input_ids'], data['segment_label'])
            # Ensure logits is a tensor, not a tuple
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, data['labels'])

            # Backpropagation and optimization
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            if self.log_freq > 0 and epoch_idx % self.log_freq == 0:
                print(f"Epoch {epoch_idx}: Loss = {loss.item()}")

            return loss

        except Exception as e:
            print(f"Error during iteration: {e}")
            raise




        
                # plt_test.show()
        # print("EP%d_%s, " % (epoch, str_code))

    def save(self, epoch, file_path="output/bert_fine_tuned_trained.model"):
        """
        Saving the current BERT model on file_path

        :param epoch: current epoch number
        :param file_path: model output path which gonna be file_path+"ep%d" % epoch
        :return: final_output_path
        """
        if self.finetune_task:
            fpath = file_path.split("/")
            output_path = fpath[0]+ "/"+ fpath[1]+f"/{self.finetune_task}/" + fpath[2] + ".ep%d" % epoch
        else:
            output_path = file_path + ".ep%d" % epoch
        torch.save(self.model.cpu(), output_path)
        self.model.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path        


class BERTAttention:
    def __init__(self, bert: BERT, vocab_obj, train_dataloader: DataLoader, workspace_name=None, code=None, finetune_task=None, with_cuda=True):
        
        # available_gpus = list(range(torch.cuda.device_count()))

        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")
        print(with_cuda, cuda_condition, " Device used = ", self.device)
        self.bert = bert.to(self.device)
        
        # if with_cuda and torch.cuda.device_count() > 1:
        #     print("Using %d GPUS for BERT" % torch.cuda.device_count())
        #     self.bert = nn.DataParallel(self.bert, device_ids=available_gpus)
            
        self.train_dataloader = train_dataloader
        self.workspace_name = workspace_name
        self.code = code
        self.finetune_task = finetune_task
        self.vocab_obj = vocab_obj

    def getAttention(self):        
        # self.log_file = f"{self.workspace_name}/logs/{self.code}/log_attention.txt"

                
        labels = ['PercentChange', 'NumeratorQuantity2', 'NumeratorQuantity1', 'DenominatorQuantity1',
                  'OptionalTask_1', 'EquationAnswer', 'NumeratorFactor', 'DenominatorFactor',
                  'OptionalTask_2', 'FirstRow1:1', 'FirstRow1:2', 'FirstRow2:1', 'FirstRow2:2', 'SecondRow',
                  'ThirdRow', 'FinalAnswer','FinalAnswerDirection']
        df_all = pd.DataFrame(0.0, index=labels, columns=labels)
        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(self.train_dataloader),
                              desc="attention",
                              total=len(self.train_dataloader),
                              bar_format="{l_bar}{r_bar}")
        count = 0
        for i, data in data_iter:
            data = {key: value.to(self.device) for key, value in data.items()}
            a = self.bert.forward(data["bert_input"], data["segment_label"])
            non_zero = np.sum(data["segment_label"].cpu().detach().numpy())

            # Last Transformer Layer
            last_layer = self.bert.attention_values[-1].transpose(1,0,2,3)
            # print(last_layer.shape)
            head, d_model, s, s = last_layer.shape

            for d in range(d_model):
                seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
                # df_all = pd.DataFrame(0.0, index=seq_labels, columns=seq_labels)
                indices_to_choose = defaultdict(int)

                for k,s in enumerate(seq_labels):
                    if s in labels:
                        indices_to_choose[s] = k
                indices_chosen = list(indices_to_choose.values())
                selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
                # print(len(seq_labels), len(selected_seq_labels))
                for h in range(head):
                    # fig, ax = plt.subplots(figsize=(12, 12)) 
                    # seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])#[1:non_zero-1]
                    # seq_labels = self.vocab_obj.to_sentence(data["bert_input"].cpu().detach().numpy().tolist()[d])[1:non_zero-1]
#                     indices_to_choose = defaultdict(int)

#                     for k,s in enumerate(seq_labels):
#                         if s in labels:
#                             indices_to_choose[s] = k
#                     indices_chosen = list(indices_to_choose.values())
#                     selected_seq_labels = [s for l,s in enumerate(seq_labels) if l in indices_chosen]
                    # print(f"Chosen index: {seq_labels, indices_to_choose, indices_chosen, selected_seq_labels}")

                    df_cm = pd.DataFrame(last_layer[h][d][indices_chosen,:][:,indices_chosen], index = selected_seq_labels, columns = selected_seq_labels)
                    df_all = df_all.add(df_cm, fill_value=0)
                    count += 1
                    
                    # df_cm = pd.DataFrame(last_layer[h][d][1:non_zero-1,:][:,1:non_zero-1], index=seq_labels, columns=seq_labels)
                    # df_all = df_all.add(df_cm, fill_value=0)
                
                # df_all = df_all.reindex(index=seq_labels, columns=seq_labels)
                # sns.heatmap(df_all, annot=False)
                # plt.title("Attentions") #Probabilities
                # plt.xlabel("Steps")
                # plt.ylabel("Steps")
                # plt.grid(True)
                # plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
                # plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores_over_[{h}]_head_n_data[{d}].png", bbox_inches='tight')
                # plt.show()
                # plt.close()



        print(f"Count of total : {count, head * self.train_dataloader.dataset.len}")    
        df_all = df_all.div(count) # head * self.train_dataloader.dataset.len
        df_all = df_all.reindex(index=labels, columns=labels)
        sns.heatmap(df_all, annot=False)
        plt.title("Attentions") #Probabilities
        plt.xlabel("Steps")
        plt.ylabel("Steps")
        plt.grid(True)
        plt.tick_params(axis='x', bottom=False, top=True, labelbottom=False, labeltop=True, labelrotation=90)
        plt.savefig(f"{self.workspace_name}/plots/{self.code}/{self.finetune_task}_attention_scores.png", bbox_inches='tight')
        plt.show()
        plt.close()