import sys
import traceback
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
import statsmodels.api as sm
import streamlit as st
from sklearn.metrics import r2_score

from data_prep import plot_actual_vs_predicted
from log_application import log_message
from post_gres_cred import db_cred
from utilities import (
    set_header,
    load_local_css,
    update_db,
    project_selection,
    retrieve_pkl_object,
)

schema = db_cred["schema"]

sys.setrecursionlimit(10**6)

# Truncate the temporary stdout log left over from previous runs
open("temp_stdout.txt", "w").close()

st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()


## DEFINE ALL FUNCTIONS
def plot_residual_predicted(actual, predicted, df_):
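    """Scatter plot of standardized residuals against predicted values.

    Assumes `actual` and `predicted` share a 0-based index so the element-wise
    subtraction aligns. Points outside the red ±2 bands flag observations with
    unusually large residuals.
    """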
    df_["Residuals"] = actual - pd.Series(predicted)
    df_["StdResidual"] = (df_["Residuals"] - df_["Residuals"].mean()) / df_[
        "Residuals"
    ].std()

    # Create a Plotly scatter plot
    fig = px.scatter(
        df_,
        x=predicted,
        y="StdResidual",
        opacity=0.5,
        color_discrete_sequence=["#11B6BD"],
    )

    # Add horizontal lines
    fig.add_hline(y=0, line_dash="dash", line_color="darkorange")
    fig.add_hline(y=2, line_color="red")
    fig.add_hline(y=-2, line_color="red")

    fig.update_xaxes(title="Predicted")
    fig.update_yaxes(title="Standardized Residuals (Actual - Predicted)")

    # Set the same width and height for both figures
    fig.update_layout(
        title="Residuals over Predicted Values",
        autosize=False,
        width=600,
        height=400,
    )

    return fig


def residual_distribution(actual, predicted):
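    """Histogram (with KDE overlay) of the residuals, drawn with Seaborn.

    Returns the `matplotlib.pyplot` module so the caller can hand it to
    `st.pyplot`.
    """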
    residuals = actual - pd.Series(predicted)

    # Create a Seaborn distribution plot
    sns.set(style="whitegrid")
    plt.figure(figsize=(6, 4))
    sns.histplot(residuals, kde=True, color="#11B6BD")

    plt.title("Distribution of Residuals")
    plt.xlabel("Residuals")
    plt.ylabel("Frequency")

    return plt


def qqplot(actual, predicted):
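    """QQ plot of standardized residuals against theoretical normal quantiles.

    A roughly diagonal point cloud indicates approximately normal residuals;
    the fixed [-2, 2] reference line may need adjusting for heavy-tailed data.
    """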
    residuals = pd.Series(actual - pd.Series(predicted))
    resid_std = (residuals - residuals.mean()) / residuals.std()

    # Create a QQ plot using Plotly with custom colors
    probplot = sm.ProbPlot(resid_std)  # build once instead of twice
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=probplot.theoretical_quantiles,
            y=probplot.sample_quantiles,
            mode="markers",
            marker=dict(size=5, color="#11B6BD"),
            name="QQ Plot",
        )
    )

    # Add the 45-degree reference line
    diagonal_line = go.Scatter(
        x=[
            -2,
            2,
        ],  # Adjust the x values as needed to fit the range of your data
        y=[-2, 2],  # Adjust the y values accordingly
        mode="lines",
        line=dict(color="red"),  # Customize the line color and style
        name=" ",
    )
    fig.add_trace(diagonal_line)

    # Customize the layout
    fig.update_layout(
        title="QQ Plot of Residuals",
        title_x=0.5,
        autosize=False,
        width=600,
        height=400,
        xaxis_title="Theoretical Quantiles",
        yaxis_title="Sample Quantiles",
    )

    return fig


def get_random_effects(media_data, panel_col, mdf):
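    """Collect the estimated random intercept for every panel in `media_data`.

    `mdf` is a fitted statsmodels MixedLM result; its `random_effects`
    attribute maps each group (panel) to its estimated random effect.
    """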
    random_eff_df = pd.DataFrame(columns=[panel_col, "random_effect"])
    for i, market in enumerate(media_data[panel_col].unique()):
        print(i, end="\r")
        intercept = mdf.random_effects[market].values[0]
        random_eff_df.loc[i, "random_effect"] = intercept
        random_eff_df.loc[i, panel_col] = market

    return random_eff_df


def mdf_predict(X_df, mdf, random_eff_df):
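    """Predict with a fitted mixed-effects model: the fixed-effects prediction
    plus each row's panel-level random intercept.

    Relies on the module-level `panel_col` to join `random_eff_df` onto
    `X_df`.
    """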
    X = X_df.copy()
    X = pd.merge(
        X,
        random_eff_df[[panel_col, "random_effect"]],
        on=panel_col,
        how="left",
    )
    X["pred_fixed_effect"] = mdf.predict(X)

    X["pred"] = X["pred_fixed_effect"] + X["random_effect"]
    X.drop(columns=["pred_fixed_effect", "random_effect"], inplace=True)
    return X


def metrics_df_panel(model_dict, is_panel):
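    """Summarise every saved tuned model in a single DataFrame: R-squared,
    adjusted R-squared, train/test weighted MAPE (WMAPE), the statsmodels
    summary and the fitted model object.

    For panel models this relies on the module-level `media_data` and
    `panel_col` globals to rebuild the random-effects table.
    """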
    def wmape(actual, forecast):
        # Weighted MAPE (WMAPE) avoids two shortcomings of MAPE & SMAPE:
        ## 1. MAPE blows up when the actual value is close to 0
        ## 2. MAPE favours under-forecasting over over-forecasting
        return np.sum(np.abs(actual - forecast)) / np.sum(np.abs(actual))

    metrics_df = pd.DataFrame(
        columns=[
            "Model",
            "R2",
            "ADJR2",
            "Train Mape",
            "Test Mape",
            "Summary",
            "Model_object",
        ]
    )
    i = 0
    for key in model_dict.keys():
        target = key.split("__")[1]
        metrics_df.at[i, "Model"] = target
        y = model_dict[key]["X_train_tuned"][target]

        feature_set = model_dict[key]["feature_set"]

        if is_panel:
            random_df = get_random_effects(
                media_data, panel_col, model_dict[key]["Model_object"]
            )
            pred = mdf_predict(
                model_dict[key]["X_train_tuned"],
                model_dict[key]["Model_object"],
                random_df,
            )["pred"]
        else:
            pred = model_dict[key]["Model_object"].predict(
                model_dict[key]["X_train_tuned"][feature_set]
            )

        ytest = model_dict[key]["X_test_tuned"][target]
        if is_panel:

            predtest = mdf_predict(
                model_dict[key]["X_test_tuned"],
                model_dict[key]["Model_object"],
                random_df,
            )["pred"]

        else:
            predtest = model_dict[key]["Model_object"].predict(
                model_dict[key]["X_test_tuned"][feature_set]
            )

        metrics_df.at[i, "R2"] = r2_score(y, pred)
        metrics_df.at[i, "ADJR2"] = 1 - (1 - metrics_df.loc[i, "R2"]) * (len(y) - 1) / (
            len(y) - len(model_dict[key]["feature_set"]) - 1
        )
        # metrics_df.at[i, "Train Mape"] = mean_absolute_percentage_error(y, pred)
        # metrics_df.at[i, "Test Mape"] = mean_absolute_percentage_error(
        #     ytest, predtest
        # )
        metrics_df.at[i, "Train Mape"] = wmape(y, pred)
        metrics_df.at[i, "Test Mape"] = wmape(ytest, predtest)
        metrics_df.at[i, "Summary"] = model_dict[key]["Model_object"].summary()
        metrics_df.at[i, "Model_object"] = model_dict[key]["Model_object"]
        i += 1
    metrics_df = np.round(metrics_df, 2)

    metrics_df.rename(
        columns={"R2": "R-squared", "ADJR2": "Adj. R-squared"}, inplace=True
    )
    return metrics_df


def map_channel(transformed_var, channel_dict):
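    """Map a transformed variable name back to its channel group using the
    raw-variable lists in `channel_dict`; returns the input unchanged if no
    raw variable matches.
    """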
    for key, value_list in channel_dict.items():
        if any(raw_var in transformed_var for raw_var in value_list):
            return key
    return transformed_var  # Return the original value if no match is found


def contributions_nonpanel(model_dict):
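    """Compute each channel's percentage contribution to every response metric
    for non-panel (OLS) models.

    The intercept, flag, tuning (trend/seasonality) and exogenous
    contributions are folded into a single "base" bucket before normalising
    each column to sum to 100.
    """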

    channels = st.session_state["project_dct"]["data_import"]["group_dict"]  # db
    media_data = st.session_state["media_data"]
    contribution_df = pd.DataFrame(columns=["Channel"])

    for key in model_dict.keys():

        best_feature_set = model_dict[key]["feature_set"]
        model = model_dict[key]["Model_object"]
        target = key.split("__")[1]
        X_train = model_dict[key]["X_train_tuned"]

        coef_df = pd.DataFrame(model.params)
        coef_df.reset_index(inplace=True)
        coef_df.columns = ["feature", "coef"]
        x_train_contribution = X_train.copy()
        x_train_contribution["pred"] = model.predict(X_train[best_feature_set])

        for i in range(len(coef_df)):

            coef = coef_df.loc[i, "coef"]
            col = coef_df.loc[i, "feature"]
            if col != "const":
                x_train_contribution[str(col) + "_contr"] = (
                    coef * x_train_contribution[col]
                )
            else:
                x_train_contribution["const"] = coef

        tuning_cols = [
            c
            for c in x_train_contribution.filter(regex="contr").columns
            if c
            in [
                "day_of_week_contr",
                "Trend_contr",
                "sine_wave_contr",
                "cosine_wave_contr",
            ]
        ]
        flag_cols = [
            c
            for c in x_train_contribution.filter(regex="contr").columns
            if "_flag" in c
        ]

        # add exogenous contribution to base
        all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
        all_exog_vars = [
            var.lower()
            .replace(".", "_")
            .replace("@", "_")
            .replace(" ", "_")
            .replace("-", "")
            .replace(":", "")
            .replace("__", "_")
            for var in all_exog_vars
        ]
        exog_cols = []
        if len(all_exog_vars) > 0:
            for col in x_train_contribution.filter(regex="contr").columns:
                if any(exog_var in col for exog_var in all_exog_vars):
                    exog_cols.append(col)

        base_cols = ["const"] + flag_cols + tuning_cols + exog_cols

        x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
        x_train_contribution.drop(columns=base_cols, inplace=True)

        contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))

        contri_df.reset_index(inplace=True)
        contri_df.columns = ["Channel", target]
        contri_df["Channel"] = contri_df["Channel"].apply(
            lambda x: map_channel(x, channels)
        )
        contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
        contri_df["Channel"].replace("base_contr", "base", inplace=True)
        contribution_df = pd.merge(
            contribution_df, contri_df, on="Channel", how="outer"
        )

    return contribution_df


def contributions_panel(model_dict):
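    """Compute each channel's percentage contribution to every response metric
    for panel (mixed-effects) models.

    The panel effect (fixed intercept plus random intercept), flag, tuning
    and exogenous contributions are folded into a single "base" bucket before
    normalising each column to sum to 100.
    """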
    channels = st.session_state["project_dct"]["data_import"]["group_dict"]  # db
    media_data = st.session_state["media_data"]
    contribution_df = pd.DataFrame(columns=["Channel"])
    for key in model_dict.keys():
        best_feature_set = model_dict[key]["feature_set"]
        model = model_dict[key]["Model_object"]
        target = key.split("__")[1]
        X_train = model_dict[key]["X_train_tuned"]

        random_eff_df = get_random_effects(media_data, panel_col, model)
        random_eff_df["fixed_effect"] = model.fe_params["Intercept"]
        random_eff_df["panel_effect"] = (
            random_eff_df["random_effect"] + random_eff_df["fixed_effect"]
        )

        coef_df = pd.DataFrame(model.fe_params)
        coef_df.reset_index(inplace=True)
        coef_df.columns = ["feature", "coef"]

        x_train_contribution = X_train.copy()
        x_train_contribution = mdf_predict(x_train_contribution, model, random_eff_df)

        x_train_contribution = pd.merge(
            x_train_contribution,
            random_eff_df[[panel_col, "panel_effect"]],
            on=panel_col,
            how="left",
        )
        for i in range(len(coef_df)):
            coef = coef_df.loc[i, "coef"]
            col = coef_df.loc[i, "feature"]
            if col.lower() != "intercept":
                x_train_contribution[str(col) + "_contr"] = (
                    coef * x_train_contribution[col]
                )

        tuning_cols = [
            c
            for c in x_train_contribution.filter(regex="contr").columns
            if c
            in [
                "day_of_week_contr",
                "Trend_contr",
                "sine_wave_contr",
                "cosine_wave_contr",
            ]
        ]
        flag_cols = [
            c
            for c in x_train_contribution.filter(regex="contr").columns
            if "_flag" in c
        ]

        # add exogenous contribution to base
        all_exog_vars = st.session_state["bin_dict"]["Exogenous"]
        all_exog_vars = [
            var.lower()
            .replace(".", "_")
            .replace("@", "_")
            .replace(" ", "_")
            .replace("-", "")
            .replace(":", "")
            .replace("__", "_")
            for var in all_exog_vars
        ]
        exog_cols = []
        if len(all_exog_vars) > 0:
            for col in x_train_contribution.filter(regex="contr").columns:
                if any(exog_var in col for exog_var in all_exog_vars):
                    exog_cols.append(col)

        base_cols = ["panel_effect"] + flag_cols + tuning_cols + exog_cols

        x_train_contribution["base_contr"] = x_train_contribution[base_cols].sum(axis=1)
        x_train_contribution.drop(columns=base_cols, inplace=True)

        contri_df = pd.DataFrame(x_train_contribution.filter(regex="contr").sum(axis=0))
        contri_df.reset_index(inplace=True)
        contri_df.columns = ["Channel", target]

        contri_df[target] = 100 * contri_df[target] / contri_df[target].sum()
        contri_df["Channel"] = contri_df["Channel"].apply(
            lambda x: map_channel(x, channels)
        )

        contri_df["Channel"].replace("base_contr", "base", inplace=True)
        contribution_df = pd.merge(
            contribution_df, contri_df, on="Channel", how="outer"
        )
    # st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)
    return contribution_df


def create_grouped_bar_plot(contribution_df, contribution_selections):
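    """Grouped bar chart of channel contributions, one bar per selected
    response metric and channel, with "Base Sales" pinned as the first
    category and channels sorted by average contribution.
    """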
    # Extract the 'Channel' names
    channel_names = contribution_df["Channel"].tolist()

    # Dictionary to store all contributions except 'const' and 'base'
    all_contributions = {
        name: [] for name in channel_names if name not in ["const", "base"]
    }

    # Dictionary to store base sales for each selection
    base_sales_dict = {}

    # Accumulate contributions for each channel from each selection
    for selection in contribution_selections:
        contributions = contribution_df[selection].values.astype(float)
        base_sales = 0  # Initialize base sales for the current selection

        for channel_name, contribution in zip(channel_names, contributions):
            if channel_name in all_contributions:
                all_contributions[channel_name].append(contribution)
            elif channel_name == "base":
                base_sales = (
                    contribution  # Capture base sales for the current selection
                )

        # Store base sales for each selection
        base_sales_dict[selection] = base_sales

    # Calculate the average of contributions and sort by this average
    sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
    sorted_channel_names = [name for name, _ in sorted_channels]
    sorted_channel_names = [
        "Base Sales"
    ] + sorted_channel_names  # Adding 'Base Sales' at the start

    trace_data = []
    max_value = 0  # Initialize max_value to find the highest bar for y-axis adjustment

    # Create traces for the grouped bar chart
    for i, selection in enumerate(contribution_selections):
        display_name = sorted_channel_names
        display_contribution = [base_sales_dict[selection]] + [
            all_contributions[name][i] for name in sorted_channel_names[1:]
        ]  # Start with base sales for the current selection

        # Generating text labels for each bar
        text_values = [
            f"{val}%" for val in np.round(display_contribution, 0).astype(int)
        ]

        # Find the max value for y-axis calculation
        max_contribution = max(display_contribution)
        if max_contribution > max_value:
            max_value = max_contribution

        # Create a bar trace for each selection
        trace = go.Bar(
            x=display_name,
            y=display_contribution,
            name=selection,
            text=text_values,
            textposition="outside",
        )
        trace_data.append(trace)

    # Define layout for the bar chart
    layout = go.Layout(
        title="Metrics Contribution by Channel (Train)",
        xaxis=dict(title="Channel Name"),
        yaxis=dict(
            title="Metrics Contribution", range=[0, max_value * 1.2]
        ),  # Set y-axis 20% higher than the max bar
        barmode="group",
        plot_bgcolor="white",
    )

    # Create the figure with trace data and layout
    fig = go.Figure(data=trace_data, layout=layout)

    return fig


def preprocess_and_plot(contribution_df, contribution_selections):
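    """Waterfall chart of channel contributions for each selected response
    metric, starting from base sales and ordered by average contribution.
    """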
    # Extract the 'Channel' names
    channel_names = contribution_df["Channel"].tolist()

    # Dictionary to store all contributions except 'const' and 'base'
    all_contributions = {
        name: [] for name in channel_names if name not in ["const", "base"]
    }

    # Dictionary to store base sales for each selection
    base_sales_dict = {}

    # Accumulate contributions for each channel from each selection
    for selection in contribution_selections:
        contributions = contribution_df[selection].values.astype(float)
        base_sales = 0  # Initialize base sales for the current selection

        for channel_name, contribution in zip(channel_names, contributions):
            if channel_name in all_contributions:
                all_contributions[channel_name].append(contribution)
            elif channel_name == "base":
                base_sales = (
                    contribution  # Capture base sales for the current selection
                )

        # Store base sales for each selection
        base_sales_dict[selection] = base_sales

    # Calculate the average of contributions and sort by this average
    sorted_channels = sorted(all_contributions.items(), key=lambda x: -np.mean(x[1]))
    sorted_channel_names = [name for name, _ in sorted_channels]
    sorted_channel_names = [
        "Base Sales"
    ] + sorted_channel_names  # Adding 'Base Sales' at the start

    # Initialize a Plotly figure
    fig = go.Figure()

    for i, selection in enumerate(contribution_selections):
        display_name = ["Base Sales"] + sorted_channel_names[
            1:
        ]  # Channel names for the plot
        display_contribution = [
            base_sales_dict[selection]
        ]  # Start with base sales for the current selection

        # Append average contributions for other channels
        for name in sorted_channel_names[1:]:
            display_contribution.append(all_contributions[name][i])

        # Generating text labels for each bar
        text_values = [
            f"{val}%" for val in np.round(display_contribution, 0).astype(int)
        ]

        # Add a waterfall trace for each selection
        fig.add_trace(
            go.Waterfall(
                orientation="v",
                measure=["relative"] * len(display_contribution),
                x=display_name,
                text=text_values,
                textposition="outside",
                y=display_contribution,
                increasing={"marker": {"color": "green"}},
                decreasing={"marker": {"color": "red"}},
                totals={"marker": {"color": "blue"}},
                name=selection,
            )
        )

    # Update layout of the figure
    fig.update_layout(
        title="Metrics Contribution by Channel (Train)",
        xaxis={"title": "Channel Name"},
        yaxis=dict(title="Metrics Contribution", range=[0, 100 * 1.2]),
    )

    return fig


def selection_change():
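    """Sync row edits made in the `st.data_editor` widget (keyed
    "project_selection") back into the `gd_table` stored in session state.
    """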
    edited_rows: dict = st.session_state.project_selection["edited_rows"]
    st.session_state["selected_row_index_gd_table"] = next(iter(edited_rows))
    st.session_state["gd_table"] = st.session_state["gd_table"].assign(selected=False)

    update_dict = dict(edited_rows)

    st.session_state["gd_table"].update(
        pd.DataFrame.from_dict(update_dict, orient="index")
    )


if "username" not in st.session_state:
    st.session_state["username"] = None

if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

try:
    st.session_state["bin_dict"] = st.session_state["project_dct"]["data_import"][
        "category_dict"
    ]  # db

except Exception:
    st.warning("Save at least one tuned model to proceed")
    log_message("warning", "No tuned models available", "AI Model Results")
    st.stop()


if "gd_table" not in st.session_state:
    st.session_state["gd_table"] = pd.DataFrame()

try:
    if "username" in st.session_state and st.session_state["username"] is not None:

        if (
            retrieve_pkl_object(
                st.session_state["project_number"],
                "Model_Tuning",
                "tuned_model",
                schema,
            )
            is None
        ):

            st.error("Please save a tuned model")
            st.stop()

        if (
            "session_state_saved"
            in st.session_state["project_dct"]["model_tuning"].keys()
            and st.session_state["project_dct"]["model_tuning"]["session_state_saved"]
            != []
        ):
            for key in ["used_response_metrics", "media_data", "bin_dict"]:
                if key not in st.session_state:
                    st.session_state[key] = st.session_state["project_dct"][
                        "model_tuning"
                    ]["session_state_saved"][key]
                # st.session_state["bin_dict"] = st.session_state["project_dct"][
                #     "model_build"
                # ]["session_state_saved"]["bin_dict"]

        media_data = st.session_state["media_data"]

        # set the panel column
        panel_col = "panel"
        is_panel = st.session_state["media_data"][panel_col].nunique() > 1

        date_col = "date"

        transformed_data = st.session_state["project_dct"]["transformations"][
            "final_df"
        ]  # db
        tuned_model_dict = retrieve_pkl_object(
            st.session_state["project_number"], "Model_Tuning", "tuned_model", schema
        )  # db

        feature_set_dct = {
            key.split("__")[1]: key_dict["feature_set"]
            for key, key_dict in tuned_model_dict.items()
        }

        # """ the above part should be modified so that we are fetching features set from the saved model"""

        if "contribution_df" not in st.session_state:
            st.session_state["contribution_df"] = None

        metrics_table = metrics_df_panel(tuned_model_dict, is_panel)

        cols1 = st.columns([2, 1])
        with cols1[0]:
            st.markdown(f"**Welcome {st.session_state['username']}**")
        with cols1[1]:
            st.markdown(f"**Current Project: {st.session_state['project_name']}**")

        st.title("AI Model Validation")

        st.header("Contribution Overview")

        # Get list of response metrics
        st.session_state["used_response_metrics"] = list(
            set([model.split("__")[1] for model in tuned_model_dict.keys()])
        )
        options = st.session_state["used_response_metrics"]

        if len(options) == 0:
            st.error("Please save and tune a model")
            st.stop()
        options = [
            opt.lower()
            .replace(" ", "_")
            .replace("-", "")
            .replace(":", "")
            .replace("__", "_")
            for opt in options
        ]

        default_options = st.session_state["project_dct"]["saved_model_results"].get(
            "selected_options"
        )
        if default_options is None:
            default_options = [options[-1]]
        # Drop saved defaults that are no longer valid options (avoids mutating
        # the list while iterating over it)
        default_options = [opt for opt in default_options if opt in options]

        def remove_response_metric(name):
            # Convert the name to a lowercase string and remove any leading or trailing spaces
            name_str = str(name).lower().strip()

            # Check if the name starts with "response metric" or "response_metric"
            if name_str.startswith("response metric"):
                return name[len("response metric") :].replace("_", " ").strip().title()
            elif name_str.startswith("response_metric"):
                return name[len("response_metric") :].replace("_", " ").strip().title()
            else:
                return name

        contribution_selections = st.multiselect(
            "Select the Response Metrics to compare contributions",
            options,
            default=default_options,
            format_func=remove_response_metric,
        )

        if is_panel:
            st.session_state["contribution_df"] = contributions_panel(tuned_model_dict)

        else:
            st.session_state["contribution_df"] = contributions_nonpanel(
                tuned_model_dict
            )

        # Display the chart in Streamlit
        st.plotly_chart(
            create_grouped_bar_plot(
                st.session_state["contribution_df"], contribution_selections
            ),
            use_container_width=True,
        )

        ############################################ Waterfall Chart ############################################

        st.plotly_chart(
            preprocess_and_plot(
                st.session_state["contribution_df"], contribution_selections
            ),
            use_container_width=True,
        )

        ############################################ Waterfall Chart ############################################
        st.header("Analysis of Models Result")
        gd_table = metrics_table.iloc[:, :-2]
        target_column = gd_table.at[0, "Model"]  # sprint8
        st.session_state["gd_table"] = gd_table

        with st.container():
            table = st.data_editor(
                st.session_state["gd_table"],
                hide_index=True,
                # on_change=selection_change,
                key="project_selection",
                use_container_width=True,
            )

        target_column = st.selectbox(
            "Select a Model to analyse its results",
            options=st.session_state.used_response_metrics,
            placeholder=options[0],
        )
        feature_set = feature_set_dct[target_column]

        model = metrics_table[metrics_table["Model"] == target_column][
            "Model_object"
        ].iloc[0]
        target = metrics_table[metrics_table["Model"] == target_column]["Model"].iloc[0]
        st.header("Model Summary")
        st.write(model.summary())

        sel_dict = tuned_model_dict[
            [k for k in tuned_model_dict.keys() if k.split("__")[1] == target][0]
        ]

        feature_set = sel_dict["feature_set"]
        X_train = sel_dict["X_train_tuned"]
        y_train = X_train[target]

        if is_panel:
            random_effects = get_random_effects(media_data, panel_col, model)
            pred = mdf_predict(X_train, model, random_effects)["pred"]
        else:
            pred = model.predict(X_train[feature_set])

        X_test = sel_dict["X_test_tuned"]
        y_test = X_test[target]
        if is_panel:
            predtest = mdf_predict(X_test, model, random_effects)["pred"]
        else:
            predtest = model.predict(X_test[feature_set])

        metrics_table_train, _, fig_train = plot_actual_vs_predicted(
            X_train[date_col],
            y_train,
            pred,
            model,
            target_column=target,
            flag=None,
            repeat_all_years=False,
            is_panel=is_panel,
        )

        metrics_table_test, _, fig_test = plot_actual_vs_predicted(
            X_test[date_col],
            y_test,
            predtest,
            model,
            target_column=target,
            flag=None,
            repeat_all_years=False,
            is_panel=is_panel,
        )

        metrics_table_train = metrics_table_train.set_index("Metric").transpose()
        metrics_table_train.index = ["Train"]
        metrics_table_test = metrics_table_test.set_index("Metric").transpose()
        metrics_table_test.index = ["Test"]
        metrics_table = np.round(
            pd.concat([metrics_table_train, metrics_table_test]), 2
        )

        st.markdown("Result Overview")
        st.dataframe(metrics_table, use_container_width=True)

        st.header("Model Accuracy")
        st.subheader("Actual vs Predicted Plot (Train)")

        st.plotly_chart(fig_train, use_container_width=True)
        st.subheader("Actual vs Predicted Plot (Test)")
        st.plotly_chart(fig_test, use_container_width=True)

        st.markdown("## Residual Analysis (Train)")
        columns = st.columns(2)

        Xtrain1 = X_train.copy()
        with columns[0]:
            fig = plot_residual_predicted(y_train, pred, Xtrain1)
            st.plotly_chart(fig)

        with columns[1]:
            st.empty()
            fig = qqplot(y_train, pred)
            st.plotly_chart(fig)

        with columns[0]:
            fig = residual_distribution(y_train, pred)
            st.pyplot(fig)

        if st.button("Save this session", use_container_width=True):
            project_dct_pkl = pickle.dumps(st.session_state["project_dct"])

            update_db(
                st.session_state["project_number"],
                "AI_Model_Results",
                "project_dct",
                project_dct_pkl,
                schema,
                # resp_mtrc=None,
            )  # db

            log_message("info", "Session saved!", "AI Model Results")
            st.success("Session Saved!")
except Exception:
    exc_type, exc_value, exc_traceback = sys.exc_info()
    error_message = "".join(
        traceback.format_exception(exc_type, exc_value, exc_traceback)
    )
    log_message("error", f"Error: {error_message}", "AI Model Results")
    st.warning("An error occured, please try again", icon="⚠️")