#####################################
# Packages & Dependencies
#####################################
import param
import panel as pn

import torch
import numpy as np
import plotly.graph_objects as go

from . import canvas
from app_utils import styles

import sys, os
# APP_PATH is the parent of the directory that contains this file.
APP_PATH = os.path.dirname(os.path.dirname(__file__)) # Path to the digit-classifier-app directory
# Extend the import search path with <APP_PATH>/model_training; the plain
# (non-relative) `import data_setup, model` further down presumably relies on this.
sys.path.append(APP_PATH + '/model_training')

# Imports from model_training
import data_setup, model


#####################################
# Plotly Panels
#####################################
# Config dict handed to every Plotly pane created in this module:
# the mode bar is displayed, the Plotly logo is hidden, and the
# listed mode-bar buttons are removed.
PLOTLY_CONFIGS = {
    'displayModeBar': True,
    'displaylogo': False,
    'modeBarButtonsToRemove': [
        'autoScale',
        'lasso',
        'select',
        'toImage',
        'pan',
        'zoom',
        'zoomIn',
        'zoomOut',
    ],
}

class PlotPanels(param.Parameterized):
    '''
    Contains all Plotly pane objects for the application.
    This includes the probability bar chart and the MNIST preprocessed image heat map.

    Args:
        canvas_info (param.ClassSelector): A Canvas class object to get the data URI of the drawn image.
                                           It must be supplied, since its uri parameter is watched in __init__.
        mod_path (str): The absolute path to the saved TinyVGG model.
        mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model.
                           This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes

    Attributes:
        class_labels (np.ndarray): The class labels 0 through 9 (x-values of the bar chart).
        cnn_mod: The TinyVGG model with the state dict from mod_path loaded onto the CPU.
        img_pane (pn.pane.Plotly): Pane holding the heat map of the current model input.
        prob_pane (pn.pane.Plotly): Pane holding the prediction probability bar chart.
        pred_txt (pn.pane.HTML): Pane holding the prediction / probability text.
        input_img (torch.Tensor): Current preprocessed model input (a 28 x 28 zero tensor when the URI is empty).
        pred_probs (np.ndarray): Current softmax probabilities (ten zeros when the URI is empty).
        pred_label: Index of the largest entry of pred_probs, or None when the URI is empty.
    '''

    canvas_info = param.ClassSelector(class_ = canvas.Canvas)    # Canvas object to get the data URI 
    
    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
        super().__init__(**params)
        self.class_labels = np.arange(0, 10)
        self.cnn_mod = model.TinyVGG(**mod_kwargs)
        # NOTE(review): torch.load unpickles the file, so mod_path must point to a trusted checkpoint.
        # Consider weights_only = True if the installed PyTorch version supports it.
        self.cnn_mod.load_state_dict(torch.load(mod_path, map_location = 'cpu'))
        
        self.img_pane = pn.pane.Plotly(
            name = 'image_plot',
            config = PLOTLY_CONFIGS,
            sizing_mode = 'stretch_both',
            margin = 0,
        )

        self.prob_pane = pn.pane.Plotly(
            name = 'prob_plot',
            config = PLOTLY_CONFIGS,
            sizing_mode = 'stretch_both',
            margin = 0
        )
        
        self.pred_txt = pn.pane.HTML(
            styles = {'margin':'0rem', 'color':styles.CLRS['pred_txt'], 
                      'font-size':styles.FONTSIZES['pred_txt'],
                      'font-family':styles.FONTFAMILY}
        )

        # Initialize plotly figures
        self._update_prediction()

        # Set up watchers that update based on data URI changes
        self.canvas_info.param.watch(self._update_prediction, 'uri')

    def _update_prediction(self, *event):
        '''
        Performs all prediction-related updates for the application.
        This function is connected to the URI parameter of canvas_info through a watcher.
        Any time the URI changes, a class prediction is made immediately.
        Following this, the prediction text, the model input heatmap, and the probability bar chart are updated as well.

        Args:
            *event: Event object(s) passed in by the watcher; unused.

        Any exception raised by one of the update steps is reported (message and traceback) 
        and then suppressed, so a failed update does not propagate to the caller.
        '''
        try:
            self._update_preprocessed_tensor()
            self._update_pred_txt()
            self._update_img_plot()
            self._update_prob_plot()
        except Exception as e:
            # Deliberately best-effort: report the failure instead of raising.
            # The traceback is printed as well, since the message alone does not show where the update failed.
            import traceback
            print(f'[Errored] {e}')
            traceback.print_exc()

    def _update_preprocessed_tensor(self):
        '''
        Transforms the data URI (string) from canvas_info into a preprocessed tensor.
        This is done by having it undergo the MNIST preprocessing pipeline (see mnist_preprocess in data_setup for details).
        Additionally, a prediction is made for the preprocessed tensor to get its class label. 
        The corresponding set of prediction probabilities are stored.

        When the URI is empty, a blank state is stored instead: 
        a 28 x 28 zero tensor, ten zero probabilities, and a label of None.
        '''
        uri = self.canvas_info.uri

        # Check if uri is non-empty
        if uri:
            input_img = data_setup.mnist_preprocess(uri)

            self.cnn_mod.eval() # Set CNN to eval & inference mode
            with torch.inference_mode():
                pred_logits = self.cnn_mod(input_img.unsqueeze(0)) # unsqueeze adds a leading (batch) dimension
                pred_probs = torch.softmax(pred_logits, dim = 1)[0].numpy()
            pred_label = np.argmax(pred_probs)
        else:
            input_img = torch.zeros((28, 28))
            pred_probs = np.zeros(10)
            pred_label = None

        # Commit all three together, only after every step succeeded, 
        # so a failure above can never leave a new image paired with a stale prediction.
        self.input_img = input_img
        self.pred_probs = pred_probs
        self.pred_label = pred_label

    def _update_pred_txt(self):
        '''
        Updates the prediction and probability HTML text to reflect the current prediction state.
        Both values are shown as N/A when there is no prediction (pred_label is None).
        '''
        # Keyed on the stored prediction (like the two plot updates) rather than on the URI, 
        # so the text always matches the state that is actually indexed here.
        if self.pred_label is not None:
            pred, prob = self.pred_label, f'{self.pred_probs[self.pred_label]:.3f}'
        else:
            pred, prob = 'N/A', 'N/A'

        self.pred_txt.object = f'''
            <div style="text-align: left;">
                <b>Prediction:</b> {pred}
                <br>
                <b>Probability:</b> {prob}
            </div>
        '''

    def _update_prob_plot(self):
        '''
        Updates the probability bar chart to showcase the softmax output probability distribution
        obtained from the prediction in _update_preprocessed_tensor.
        The bar of the predicted class (if any) is drawn with the highlight colors.
        '''
        # Marker fill and outline color for bar plot
        mkr_clrs = [styles.CLRS['base_bar']] * len(self.class_labels)
        mkr_line_clrs = [styles.CLRS['base_bar_line']] * len(self.class_labels)
        if self.pred_label is not None:
            mkr_clrs[self.pred_label] = styles.CLRS['pred_bar']
            mkr_line_clrs[self.pred_label] = styles.CLRS['pred_bar_line']
            
        fig = go.Figure()
        # Bar plot
        fig.add_trace(
            go.Bar(x = self.class_labels, y = self.pred_probs, 
                   marker_color = mkr_clrs, marker_line_color = mkr_line_clrs,
                   marker_line_width = 1.5, showlegend = False,
                   text = self.pred_probs, textposition = 'outside',
                   textfont = dict(color = styles.CLRS['plot_txt'],
                                   size = styles.FONTSIZES['plot_bar_txt'], family = styles.FONTFAMILY), 
                   texttemplate = '%{text:.3f}', 
                   customdata = self.pred_probs * 100, # Probabilities as percentages for the hover label
                   hoverlabel_font = dict(family = styles.FONTFAMILY),
                   hovertemplate = '<b>Class Label:</b> %{x}' +
                                   '<br><b>Probability:</b> %{customdata:.2f} %' +
                                   '<extra></extra>'
            )
        )
        # Used to fix axis limits: two fully transparent, hover-less markers at y = 0.1 and y = 1.01
        # (presumably so the y-axis extent does not depend on the bar heights -- TODO confirm)
        fig.add_trace(
            go.Scatter(
                x = [0.5, 0.5], y = [0.1, 1.01],
                marker = dict(color = 'rgba(0, 0, 0, 0)', size = 10),
                mode = 'markers', 
                hoverinfo = 'skip', 
                showlegend = False
            )
        )
        fig.update_yaxes(
            title = dict(text = 'Prediction Probability', standoff = 0,
                         font = dict(color = styles.CLRS['plot_txt'],
                                     size = styles.FONTSIZES['plot_labels'], 
                                     family = styles.FONTFAMILY)),
            tickfont = dict(size = styles.FONTSIZES['plot_ticks'], 
                            family = styles.FONTFAMILY),
            ticks = 'outside', ticklen = 0,
            tickvals = np.arange(0, 1.1, 0.1), 
            gridcolor = styles.CLRS['prob_plot_grid']
        )
        fig.update_xaxes(
            title = dict(text = 'Class Label', standoff = 6,
                         font = dict(color = styles.CLRS['plot_txt'],
                                     size = styles.FONTSIZES['plot_labels'], 
                                     family = styles.FONTFAMILY)),
            dtick = 1, tickfont = dict(size = styles.FONTSIZES['plot_ticks'], 
                                       family = styles.FONTFAMILY),
        )
        fig.update_layout(
            paper_bgcolor = styles.CLRS['prob_plot_bg'],
            plot_bgcolor = styles.CLRS['prob_plot_bg'],
            margin = dict(l = 60, r = 0, t = 5, b = 45),
        )

        self.prob_pane.object = fig
        
    def _update_img_plot(self):
        '''
        Updates the heat map to showcase the current model input, i.e. the preprocessed canvas drawing.
        '''
        img_np = self.input_img.squeeze().numpy()

        # With a prediction, the color range follows the image's own min / max; 
        # without one (blank input), a fixed 0-1 range is used.
        if self.pred_label is not None:
            zmin, zmax = np.min(img_np), np.max(img_np)
        else:
            zmin, zmax = 0, 1

        fig = go.Figure(
            data = go.Heatmap(
                z = img_np,
                colorscale = 'gray',
                showscale = False,
                zmin = zmin,
                zmax = zmax,
                hoverlabel_font = dict(family = styles.FONTFAMILY),
                hovertemplate = '<b>Pixel Position:</b> (%{x}, %{y})' +
                                '<br><b>Pixel Value:</b> %{z:.3f}' + 
                                '<extra></extra>'
            )
        )

        # Reverse the y-axis so that row 0 of the image is drawn at the top
        fig.update_yaxes(autorange = 'reversed') 
        fig.update_layout(
            plot_bgcolor = styles.CLRS['img_plot_bg'],
            margin = dict(l = 0, r = 0, t = 0, b = 0),
            xaxis = dict(showticklabels = False),
            yaxis = dict(showticklabels = False),
        )

        self.img_pane.object = fig