# Importing necessary libraries
import streamlit as st

# Configure the page: browser-tab title and icon, wide layout, sidebar collapsed.
# NOTE(review): presumably kept ahead of the remaining imports so that it runs
# before any other Streamlit command is issued — confirm before moving it.
st.set_page_config(
    page_title="AI Model Transformations",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)

import sys
import pickle
import traceback
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from post_gres_cred import db_cred
from log_application import log_message
from utilities import (
    set_header,
    load_local_css,
    update_db,
    project_selection,
    delete_entries,
    retrieve_pkl_object,
)
from constants import (
    predefined_defaults,
    lead_min_value,
    lead_max_value,
    lead_step,
    lag_min_value,
    lag_max_value,
    lag_step,
    moving_average_min_value,
    moving_average_max_value,
    moving_average_step,
    saturation_min_value,
    saturation_max_value,
    saturation_step,
    power_min_value,
    power_max_value,
    power_step,
    adstock_min_value,
    adstock_max_value,
    adstock_step,
    display_max_col,
)


# Database schema name taken from the credentials mapping
schema = db_cred["schema"]
# Apply the local stylesheet and render the page header
load_local_css("styles.css")
set_header()


# Initialize project name session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# Fetch project dictionary
# Without a project dictionary the rest of the page cannot run, so hand over
# to project_selection() and halt this script run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Display Username and Project Name
if "username" in st.session_state and st.session_state["username"] is not None:

    # Two columns in a 2:1 width ratio: greeting on the left, project on the right
    cols1 = st.columns([2, 1])

    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")

# Load saved data from project dictionary
# A missing imputed dataframe means the Data Import page was never saved:
# warn the user, log the same message and halt this script run.
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
    st.warning(
        "The data import is incomplete. Please go back to the Data Import page and complete the save.",
        icon="🔙",
    )

    # Log message
    log_message(
        "warning",
        "The data import is incomplete. Please go back to the Data Import page and complete the save.",
        "Transformations",
    )

    st.stop()
else:
    # Work on copies so the objects stored under "data_import" are not the
    # ones this page passes around
    final_df_loaded = st.session_state["project_dct"]["data_import"][
        "imputed_tool_df"
    ].copy()
    bin_dict_loaded = st.session_state["project_dct"]["data_import"][
        "category_dict"
    ].copy()
    unique_panels = st.session_state["project_dct"]["data_import"][
        "unique_panels"
    ].copy()

# Initialize project dictionary data
if st.session_state["project_dct"]["transformations"]["final_df"] is None:
    st.session_state["project_dct"]["transformations"][
        "final_df"
    ] = final_df_loaded  # Default as original dataframe

# Extract original columns for specified categories
# Categories absent from the loaded category dictionary are skipped
original_columns = {
    category: bin_dict_loaded[category]
    for category in ["Media", "Internal", "Exogenous"]
    if category in bin_dict_loaded
}

# Retrieve panel columns: the "panel" column is used only when more than one
# unique panel exists, otherwise the list is empty
panel = ["panel"] if len(unique_panels) > 1 else []


# Function to clear model metadata
def clear_pages():
    """Reset the model-related sections of the project dictionary.

    Overwrites the ``model_build``, ``model_tuning`` and
    ``saved_model_results`` entries of ``st.session_state["project_dct"]``
    with their default values and removes the cached model result objects
    from the session state when they are present.
    """
    project = st.session_state["project_dct"]

    # Defaults for the model build page
    project["model_build"] = dict(
        sel_target_col=None,
        all_iters_check=False,
        iterations=0,
        build_button=False,
        show_results_check=False,
        session_state_saved={},
    )

    # Defaults for the model tuning page
    project["model_tuning"] = dict(
        sel_target_col=None,
        sel_model={},
        flag_expander=False,
        start_date_default=None,
        end_date_default=None,
        repeat_default="No",
        flags={},
        select_all_flags_check={},
        selected_flags={},
        trend_check=False,
        week_num_check=False,
        sine_cosine_check=False,
        session_state_saved={},
    )

    # Defaults for the saved model results page
    project["saved_model_results"] = dict(
        selected_options=None,
        model_grid_sel=[1],
    )

    # Drop cached model outputs, if any
    for stale_key in ("model_results_df", "model_results_data", "coefficients_df"):
        if stale_key in st.session_state:
            del st.session_state[stale_key]


# Function to update transformation change
def transformation_change(category, transformation, key):
    """Store the widget value held under ``key`` as the ``transformation``
    setting of ``category`` in the project dictionary."""
    widget_value = st.session_state[key]
    category_settings = st.session_state["project_dct"]["transformations"][category]
    category_settings[transformation] = widget_value


# Function to update specific transformation change
def transformation_specific_change(channel_name, transformation, key):
    """Store the widget value held under ``key`` as the channel-specific
    ``transformation`` setting for ``channel_name``."""
    widget_value = st.session_state[key]
    specific_settings = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_settings[transformation][channel_name] = widget_value


# Function to update transformations to apply change
def transformations_to_apply_change(category, key):
    """Copy the widget value held under ``key`` into the ``category`` section
    of the project dictionary, using the same ``key``."""
    widget_value = st.session_state[key]
    category_settings = st.session_state["project_dct"]["transformations"][category]
    category_settings[key] = widget_value


# Function to update channel select specific change
def channel_select_specific_change():
    """Copy the ``channel_select_specific`` widget value into the "Specific"
    section of the project dictionary."""
    widget_key = "channel_select_specific"
    widget_value = st.session_state[widget_key]
    specific_settings = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_settings[widget_key] = widget_value


# Function to update specific transformation change
def specific_transformation_change(specific_transformation_key):
    """Copy the widget value held under ``specific_transformation_key`` into
    the "Specific" section of the project dictionary under the same key."""
    widget_value = st.session_state[specific_transformation_key]
    specific_settings = st.session_state["project_dct"]["transformations"]["Specific"]
    specific_settings[specific_transformation_key] = widget_value


# Function to build transformation widgets
def transformation_widgets(category, transform_params, date_granularity):
    """Render the transformation controls for one variable category.

    Shows an expander with a multiselect of the transformations available
    for ``category`` and, for every selected transformation, a slider laid
    out over two columns. Each slider's selection is expanded into an array
    of parameter values and written to
    ``transform_params[category][<transformation>]``.

    Args:
        category: "Media", "Internal" or "Exogenous".
        transform_params: Dict updated in place; ``transform_params[category]``
            must already exist.
        date_granularity: Text shown next to the Lead, Lag and Moving Average
            headings.
    """
    # Transformations offered for each category
    options_by_category = {
        "Media": ["Lag", "Moving Average", "Saturation", "Power", "Adstock"],
        "Internal": ["Lead", "Lag", "Moving Average"],
        "Exogenous": ["Lead", "Lag", "Moving Average"],
    }

    # Per-transformation slider settings:
    # (heading suffix, slider label, min value, max value, step)
    slider_specs = {
        "Lead": (
            f" ({date_granularity})",
            "Lead periods",
            lead_min_value,
            lead_max_value,
            lead_step,
        ),
        "Lag": (
            f" ({date_granularity})",
            "Lag periods",
            lag_min_value,
            lag_max_value,
            lag_step,
        ),
        "Moving Average": (
            f" ({date_granularity})",
            "Window size for Moving Average",
            moving_average_min_value,
            moving_average_max_value,
            moving_average_step,
        ),
        "Saturation": (
            " (%)",
            "Saturation Percentage",
            saturation_min_value,
            saturation_max_value,
            saturation_step,
        ),
        "Power": (
            "",
            "Power",
            power_min_value,
            power_max_value,
            power_step,
        ),
        "Adstock": (
            "",
            "Decay Factor",
            adstock_min_value,
            adstock_max_value,
            adstock_step,
        ),
    }

    def create_transformation_widgets(column, transformations):
        """Draw one slider per transformation inside ``column``."""
        with column:
            for name in transformations:
                widget_key = f"{name}_{category}"

                # Previously saved selection, or the predefined default
                current_range = st.session_state["project_dct"]["transformations"][
                    category
                ].get(name, predefined_defaults[name])

                if name not in slider_specs:
                    continue

                heading_suffix, slider_label, low, high, step = slider_specs[name]
                st.markdown(f"**{name}{heading_suffix}**")

                selected = st.slider(
                    label=slider_label,
                    min_value=low,
                    max_value=high,
                    step=step,
                    value=current_range,
                    key=widget_key,
                    label_visibility="collapsed",
                    on_change=transformation_change,
                    args=(category, name, widget_key),
                )

                # The selection is indexed as a (start, end) pair; expand it
                # into every value from start up to and including end
                start = selected[0]
                end = selected[1]
                grid = np.arange(start, end + step, step)

                # Adstock values are rounded to three decimals
                if name == "Adstock":
                    grid = np.array([round(a, 3) for a in grid])

                transform_params[category][name] = grid

    with st.expander(f"All {category} Transformations", expanded=True):

        transformation_key = f"transformation_{category}"
        allowed = options_by_category[category]

        # Previously selected transformations for this category
        sel_transformations = st.session_state["project_dct"]["transformations"][
            category
        ].get(transformation_key, [])

        # Reset the saved selection if it holds an entry that is no longer offered
        if any(choice not in allowed for choice in sel_transformations):
            st.session_state["project_dct"]["transformations"][category][
                transformation_key
            ] = []
            sel_transformations = []

        transformations_to_apply = st.multiselect(
            label="Select transformations to apply",
            options=allowed,
            default=sel_transformations,
            key=transformation_key,
            on_change=transformations_to_apply_change,
            args=(category, transformation_key),
        )

        # First column takes the larger half when the count is odd
        first_column_count = (len(transformations_to_apply) + 1) // 2

        left_column, right_column = st.columns(2)

        create_transformation_widgets(
            left_column, transformations_to_apply[:first_column_count]
        )
        create_transformation_widgets(
            right_column, transformations_to_apply[first_column_count:]
        )


# Define a helper function to create widgets for each specific transformation
def create_specific_transformation_widgets(
    column,
    transformations,
    channel_name,
    date_granularity,
    specific_transform_params,
):
    """Draw the channel-specific transformation sliders inside ``column``.

    For every entry of ``transformations`` a slider is rendered and its
    selection is expanded into an array of parameter values, stored in
    ``specific_transform_params[channel_name][<transformation>]``.

    Args:
        column: Streamlit container used as a context manager.
        transformations: Names of the transformations to render.
        channel_name: Channel the sliders belong to.
        date_granularity: Text shown next to the Lead, Lag and Moving Average
            headings.
        specific_transform_params: Dict updated in place;
            ``specific_transform_params[channel_name]`` must already exist.
    """
    # Per-transformation slider settings:
    # (heading suffix, slider label, min value, max value, step)
    slider_specs = {
        "Lead": (
            f" ({date_granularity})",
            "Lead periods",
            lead_min_value,
            lead_max_value,
            lead_step,
        ),
        "Lag": (
            f" ({date_granularity})",
            "Lag periods",
            lag_min_value,
            lag_max_value,
            lag_step,
        ),
        "Moving Average": (
            f" ({date_granularity})",
            "Window size for Moving Average",
            moving_average_min_value,
            moving_average_max_value,
            moving_average_step,
        ),
        "Saturation": (
            " (%)",
            "Saturation Percentage",
            saturation_min_value,
            saturation_max_value,
            saturation_step,
        ),
        "Power": (
            "",
            "Power",
            power_min_value,
            power_max_value,
            power_step,
        ),
        "Adstock": (
            "",
            "Decay Factor",
            adstock_min_value,
            adstock_max_value,
            adstock_step,
        ),
    }

    with column:
        for name in transformations:
            widget_key = f"{name}_{channel_name}_specific"

            specific_settings = st.session_state["project_dct"]["transformations"][
                "Specific"
            ]

            # Make sure a per-channel mapping exists for this transformation
            if name not in specific_settings:
                specific_settings[name] = {}

            # Previously saved selection for the channel, or the predefined default
            current_range = specific_settings[name].get(
                channel_name, predefined_defaults[name]
            )

            if name not in slider_specs:
                continue

            heading_suffix, slider_label, low, high, step = slider_specs[name]
            st.markdown(f"**{name}{heading_suffix}**")

            selected = st.slider(
                label=slider_label,
                min_value=low,
                max_value=high,
                step=step,
                value=current_range,
                key=widget_key,
                label_visibility="collapsed",
                on_change=transformation_specific_change,
                args=(channel_name, name, widget_key),
            )

            # The selection is indexed as a (start, end) pair; expand it into
            # every value from start up to and including end
            start = selected[0]
            end = selected[1]
            grid = np.arange(start, end + step, step)

            # Adstock values are rounded to three decimals
            if name == "Adstock":
                grid = np.array([round(a, 3) for a in grid])

            specific_transform_params[channel_name][name] = grid


# Function to apply Lag transformation
def apply_lag(df, lag):
    """Return ``df`` shifted forward by ``lag`` periods (earlier values move
    to later positions)."""
    lagged = df.shift(periods=lag)
    return lagged


# Function to apply Lead transformation
def apply_lead(df, lead):
    """Return ``df`` shifted backward by ``lead`` periods (later values move
    to earlier positions)."""
    led = df.shift(periods=-lead)
    return led


# Function to apply Moving Average transformation
def apply_moving_average(df, window_size):
    """Return the rolling mean of ``df`` over a window of ``window_size``
    observations."""
    rolling_window = df.rolling(window=window_size)
    return rolling_window.mean()


# Function to apply Saturation transformation
def apply_saturation(df, saturation_percent_100):
    """Apply the saturation transformation to a series.

    Args:
        df: Series of values to transform.
        saturation_percent_100: Saturation level as a percentage; it is
            clamped to the range [0.01, 99.99] before use.

    Returns:
        The transformed series. When all values are equal, each value is
        simply multiplied by the saturation fraction.
    """
    # Clamp the percentage, then express it as a fraction
    clamped_percent = max(saturation_percent_100, 0.01)
    clamped_percent = min(clamped_percent, 99.99)
    fraction = clamped_percent / 100.0

    lowest, highest = df.min(), df.max()

    # Constant data: scale directly
    if lowest == highest:
        return df.apply(lambda value: value * fraction)

    # Saturation point derived from the minimum and the scaled maximum
    midpoint = (lowest + fraction * highest) / 2

    # Steepness of the saturation curve
    steepness = np.log((1 / fraction) - 1) / np.log(midpoint / highest)

    def _saturate(value):
        # A zero value is replaced by a tiny number in the ratio only,
        # to avoid dividing by zero
        safe_value = value if value != 0 else 1e-9
        return (1 / (1 + (midpoint / safe_value) ** steepness)) * value

    return df.apply(_saturate)


# Function to apply Power transformation
def apply_power(df, power):
    """Return ``df`` raised element-wise to ``power``."""
    powered = pow(df, power)
    return powered


# Function to apply Adstock transformation
def apply_adstock(df, factor):
    """Apply the Adstock carry-over to a series.

    Each output value equals the previous output multiplied by ``factor``
    plus the current input value; the carry-over starts at 0. The result
    keeps the index of ``df``.
    """
    carry_over = 0
    adstocked = []
    for current in df:
        carry_over = carry_over * factor + current
        adstocked.append(carry_over)
    return pd.Series(adstocked, index=df.index)


# Function to generate transformed columns names
@st.cache_resource(show_spinner=False)
def generate_transformed_columns(
    original_columns, transform_params, specific_transform_params
):
    """Build transformed column names and a textual summary per column.

    Args:
        original_columns: Mapping of category -> list of column names.
        transform_params: Mapping of category -> {transformation: values}.
        specific_transform_params: Mapping of column -> {transformation: values}
            for columns with channel-specific settings.

    Returns:
        A tuple ``(transformed_columns, summary_string)``:
        ``transformed_columns`` maps each column to its generated names of the
        form ``"<column>@<transformation>_<value>"``; ``summary_string`` is a
        numbered, newline-separated description of the transformations per
        column.
    """

    def _expand(column, transformation, values):
        """Return the generated names and the summary entry for one transformation."""
        names = [f"{column}@{transformation}_{value}" for value in values]

        # Render the values as "a, b and c"
        if len(values) > 1:
            listed = ", ".join(map(str, values[:-1])) + " and " + str(values[-1])
        else:
            listed = str(values[0])

        return names, f"{transformation} ({listed})"

    transformed_columns = {}
    summary = {}

    for category, columns in original_columns.items():
        for column in columns:
            names_for_column = []
            transformed_columns[column] = names_for_column
            summary_details = []  # Transformation details for the current column

            has_specific_entry = column in specific_transform_params.keys()

            if has_specific_entry and len(specific_transform_params[column]) > 0:
                # Channel-specific settings take precedence over the category ones
                for transformation, values in specific_transform_params[column].items():
                    names, detail = _expand(column, transformation, values)
                    names_for_column.extend(names)
                    summary_details.append(detail)

            elif category in transform_params:
                for transformation, values in transform_params[category].items():
                    if has_specific_entry:
                        # Column has a specific entry but it holds no transformations
                        summary_details = ["No transformation selected"]
                        continue

                    names, detail = _expand(column, transformation, values)
                    names_for_column.extend(names)
                    summary_details.append(detail)

            # Columns without any details are left out of the summary
            if summary_details:
                joined_details = "⮕ ".join(summary_details)
                # <strong> tags make the column name bold
                summary[column] = f"<strong>{column}</strong>: {joined_details}"

    # Number the per-column summaries, one per line
    summary_string = "\n".join(
        f"{position}. {details}"
        for position, details in enumerate(summary.values(), start=1)
    )

    return transformed_columns, summary_string


# Function to transform Dataframe slice
def transform_slice(
    transform_params,
    transformation_functions,
    panel,
    df,
    df_slice,
    category,
    category_df,
):
    """Apply the transformations of ``category`` to a DataFrame slice, in order.

    Each transformation is applied to the output of the previous one, so the
    column names accumulate ``@<transformation>_<parameter>`` suffixes.
    Missing values in the transformed columns are replaced with 0.

    Args:
        transform_params: Mapping of category -> {transformation: parameters}.
        transformation_functions: Mapping of transformation name -> callable
            taking ``(series, parameter)``.
        panel: List of panel column names; empty when there is no panel.
        df: Full DataFrame, used only to read the panel columns.
        df_slice: Columns to transform (plus the panel columns, if any).
        category: Key of ``transform_params`` to process.
        category_df: DataFrame that collects transformed columns.

    Returns:
        Tuple ``(category_df, df, df_slice)`` after the last transformation.
    """
    has_panel = len(panel) > 0

    for transformation, parameters in transform_params[category].items():
        transformation_function = transformation_functions[transformation]

        if has_panel:
            # Transform within each panel group, one frame per parameter
            grouped_frames = []
            for p in parameters:
                transformed = df_slice.groupby(panel).transform(
                    transformation_function, p
                )
                grouped_frames.append(
                    transformed.add_suffix(f"@{transformation}_{p}")
                )
            category_df = pd.concat(grouped_frames, axis=1)
        else:
            for p in parameters:
                # Transform every column, then tag the column names
                transformed = df_slice.apply(
                    lambda series: transformation_function(series, p), axis=0
                )
                tagged = transformed.rename(
                    lambda name: f"{name}@{transformation}_{p}",
                    axis="columns",
                )
                # Collect the transformed columns alongside the existing ones
                category_df = pd.concat([category_df, tagged], axis=1)

        # Replace all NaN or null values with 0
        category_df.fillna(0, inplace=True)

        # The next transformation runs on the columns produced so far
        df_slice = pd.concat([df[panel], category_df], axis=1)

    return category_df, df, df_slice


# Function to apply transformations to DataFrame slices based on specified categories and parameters
@st.cache_resource(show_spinner=False)
def apply_category_transformations(
    df_main, bin_dict, transform_params, panel, specific_transform_params
):
    """Apply category-level and channel-specific transformations to ``df_main``.

    Args:
        df_main: Source DataFrame; it is not modified.
        bin_dict: Mapping of category -> list of column names.
        transform_params: Mapping of category -> {transformation: parameters}.
        panel: List of panel column names; empty when there is no panel.
        specific_transform_params: Mapping of column -> {transformation:
            parameters}; these columns are excluded from the category-level
            transformations and transformed on their own.

    Returns:
        A new DataFrame holding the columns of ``df_main`` plus the
        transformed columns. For every base variable only the columns with
        the longest ``@`` chain are kept; shorter (intermediate) ones are
        dropped.
    """
    # Dictionary for function mapping
    transformation_functions = {
        "Lead": apply_lead,
        "Lag": apply_lag,
        "Moving Average": apply_moving_average,
        "Saturation": apply_saturation,
        "Power": apply_power,
        "Adstock": apply_adstock,
    }

    # List to collect all transformed DataFrames
    transformed_dfs = []

    # Category-level transformations
    for category in ["Media", "Exogenous", "Internal"]:
        if (
            category not in transform_params
            or category not in bin_dict
            or not transform_params[category]
        ):
            continue  # Skip categories without transformations

        # Slice the columns of the current category (plus panel columns)
        df_slice = df_main[bin_dict[category] + panel].copy()

        # Columns with channel-specific settings are handled separately below
        df_slice = df_slice.drop(
            columns=list(specific_transform_params.keys()), errors="ignore"
        ).copy()

        category_df, _, _ = transform_slice(
            transform_params.copy(),
            transformation_functions.copy(),
            panel,
            df_main.copy(),
            df_slice.copy(),
            category,
            pd.DataFrame(),
        )

        # Keep the transformed category DataFrame if it's not empty
        if not category_df.empty:
            transformed_dfs.append(category_df)

    # Channel-specific transformations
    for channel_specific in specific_transform_params:
        df_slice_specific = df_main[[channel_specific] + panel].copy()

        # transform_slice looks parameters up by category key, so the
        # channel's settings are wrapped under a single "Media" key
        transform_params_specific = {
            "Media": specific_transform_params[channel_specific]
        }

        category_df, _, _ = transform_slice(
            transform_params_specific.copy(),
            transformation_functions.copy(),
            panel,
            df_main.copy(),
            df_slice_specific.copy(),
            "Media",
            pd.DataFrame(),
        )

        # Keep the transformed channel DataFrame if it's not empty
        if not category_df.empty:
            transformed_dfs.append(category_df)

    # Attach the transformed columns to the original DataFrame
    if transformed_dfs:
        final_df = pd.concat([df_main] + transformed_dfs, axis=1)
    else:
        # If no transformations were applied, use the original DataFrame
        final_df = df_main

    # Length of the '@' chain of every transformed column
    chain_depth = {
        col: len(col.split("@")) for col in final_df.columns if "@" in col
    }

    # Longest chain per base variable. The base name is compared exactly:
    # a prefix match would let e.g. "tv_spend@..." columns cause valid
    # "tv@..." columns to be dropped.
    deepest = {}
    for col, depth in chain_depth.items():
        base_name = col.split("@")[0]
        deepest[base_name] = max(depth, deepest.get(base_name, 0))

    # Intermediate columns: a longer chain exists for the same base variable
    columns_to_drop = [
        col
        for col, depth in chain_depth.items()
        if depth < deepest[col.split("@")[0]]
    ]

    # Non-inplace drop so that df_main is never mutated when it is returned
    # as-is (no transformations applied)
    return final_df.drop(columns=columns_to_drop)


# Function to infer the granularity of the date column in a DataFrame
@st.cache_resource(show_spinner=False)
def infer_date_granularity(df):
    """Infer the sampling granularity of the ``date`` column of ``df``.

    The most common gap (in whole days) between consecutive distinct dates
    decides the result.

    Args:
        df: DataFrame with a datetime-like ``date`` column.

    Returns:
        One of ``"daily"``, ``"weekly"``, ``"monthly"`` or ``"irregular"``.
        ``"irregular"`` is also returned when fewer than two distinct dates
        are present, since no gap can be measured.
    """
    # Sort the distinct dates so that consecutive differences are meaningful
    # even when the rows are not ordered by date (e.g. data ordered by panel)
    unique_dates = pd.Series(df["date"].unique()).dropna().sort_values()

    # Gaps in days between consecutive distinct dates
    day_gaps = unique_dates.diff().dt.days.dropna()

    # No gap to inspect: avoid indexing into an empty mode() result
    if day_gaps.empty:
        return "irregular"

    # Find the most common difference
    common_freq = day_gaps.mode().iloc[0]

    # Map the most common difference to a granularity
    if common_freq == 1:
        return "daily"
    if common_freq == 7:
        return "weekly"
    if 28 <= common_freq <= 31:
        return "monthly"
    return "irregular"


# Function to clean display DataFrame
@st.cache_data(show_spinner=False)
def clean_display_df(df, display_max_col=500):
    """Build the DataFrame shown on the page from the transformed data.

    Duplicate column labels are removed (first occurrence kept), rows are
    sorted by ``panel`` then ``date`` (missing values first), and the columns
    are arranged as ``date``, ``panel`` followed by the remaining columns in
    sorted label order. When the frame has more than ``display_max_col``
    columns, only ``date``, ``panel`` and the first ``display_max_col - 2``
    remaining columns are kept.

    Args:
        df: DataFrame containing at least ``date`` and ``panel`` columns.
        display_max_col: Maximum number of columns to display.

    Returns:
        Tuple ``(display_df, exceeds_max_col)`` where ``exceeds_max_col`` is
        True when the de-duplicated frame has more than ``display_max_col``
        columns.
    """
    key_columns = ["date", "panel"]

    # Drop duplicate columns before sorting: sort_values raises when a sort
    # key label ('panel' / 'date') appears more than once in the frame
    deduped_df = df.loc[:, ~df.columns.duplicated()]

    # Sort by 'panel' and 'date'
    sorted_df = deduped_df.sort_values(
        by=["panel", "date"], ascending=True, na_position="first"
    )

    # Check if the DataFrame has more than display_max_col columns
    exceeds_max_col = sorted_df.shape[1] > display_max_col

    # Remaining columns (Index.difference returns them in sorted order)
    other_columns = sorted_df.columns.difference(key_columns).tolist()

    if exceeds_max_col:
        # 'date' and 'panel' already occupy 2 columns; clamp at zero so a
        # very small display_max_col cannot turn into a negative slice bound
        other_columns = other_columns[: max(display_max_col - 2, 0)]

    # Ensure 'date' and 'panel' are the first two columns
    display_df = sorted_df[key_columns + other_columns]

    # Return the display DataFrame and whether it exceeds display_max_col columns
    return display_df, exceeds_max_col


#########################################################################################################################################################
# User input for transformations
#########################################################################################################################################################

try:
    # The whole page body runs inside one try block so that any unexpected
    # failure is logged with its traceback and shown as a friendly warning.

    # Page Title
    st.title("AI Model Transformations")

    # Infer date granularity
    date_granularity = infer_date_granularity(final_df_loaded)

    # Initialize the main dictionary to store the transformation parameters for each category
    transform_params = {"Media": {}, "Internal": {}, "Exogenous": {}}

    st.markdown("### Select Transformations to Apply")

    with st.expander("Specific Media Transformations"):
        # Previously saved channel selection (used as the widget default)
        sel_channel_specific = st.session_state["project_dct"]["transformations"][
            "Specific"
        ].get("channel_select_specific", [])

        # Reset default selected channels list if options are changed:
        # a default that is not among the options is rejected by the widget
        if any(
            channel not in bin_dict_loaded["Media"]
            for channel in sel_channel_specific
        ):
            st.session_state["project_dct"]["transformations"]["Specific"][
                "channel_select_specific"
            ] = []
            sel_channel_specific = []

        select_specific_channels = st.multiselect(
            label="Select channel variable",
            default=sel_channel_specific,
            options=bin_dict_loaded["Media"],
            key="channel_select_specific",
            on_change=channel_select_specific_change,
            max_selections=30,
        )

        # Transformations offered for each individually selected channel
        transformations_options = [
            "Lag",
            "Moving Average",
            "Saturation",
            "Power",
            "Adstock",
        ]

        specific_transform_params = {}
        for select_specific_channel in select_specific_channels:
            specific_transform_params[select_specific_channel] = {}

            st.divider()
            channel_name = str(select_specific_channel).replace("_", " ").title()
            st.markdown(f"###### {channel_name}")

            specific_transformation_key = (
                f"specific_transformation_{select_specific_channel}_Media"
            )

            # Previously saved transformations for this channel
            sel_transformations = st.session_state["project_dct"]["transformations"][
                "Specific"
            ].get(specific_transformation_key, [])

            # Reset the saved transformations if any of them is no longer a
            # valid option. The default list passed to the widget must be
            # reset as well, otherwise the stale values are still used.
            if any(
                transformation not in transformations_options
                for transformation in sel_transformations
            ):
                st.session_state["project_dct"]["transformations"]["Specific"][
                    specific_transformation_key
                ] = []
                sel_transformations = []

            transformations_to_apply = st.multiselect(
                label="Select transformations to apply",
                options=transformations_options,
                default=sel_transformations,
                key=specific_transformation_key,
                on_change=specific_transformation_change,
                args=(specific_transformation_key,),
            )

            # Determine the number of transformations to put in each column
            # (ceiling of half, so the first column gets the extra one)
            transformations_per_column = (len(transformations_to_apply) + 1) // 2

            # Create two columns
            col1, col2 = st.columns(2)

            # Assign transformations to each column
            transformations_col1 = transformations_to_apply[:transformations_per_column]
            transformations_col2 = transformations_to_apply[transformations_per_column:]

            # Create widgets in each column
            for column, column_transformations in (
                (col1, transformations_col1),
                (col2, transformations_col2),
            ):
                create_specific_transformation_widgets(
                    column,
                    column_transformations,
                    select_specific_channel,
                    date_granularity,
                    specific_transform_params,
                )

    # Create Widgets
    for category in ["Media", "Internal", "Exogenous"]:
        # Skip Internal
        if category == "Internal":
            continue

        # Skip category if no column available
        if category not in bin_dict_loaded or len(bin_dict_loaded[category]) == 0:
            st.info(
                f"{str(category).title()} category has no column associated with it. Skipping transformation step for this category.",
                icon="💬",
            )
            continue

        transformation_widgets(category, transform_params, date_granularity)

    #########################################################################################################################################################
    # Apply transformations
    #########################################################################################################################################################

    button_col = st.columns(2)

    # Reset transformation selection to default
    with button_col[1]:
        if st.button("Reset to Default", use_container_width=True):
            for saved_section in ("Media", "Exogenous", "Internal", "Specific"):
                st.session_state["project_dct"]["transformations"][saved_section] = {}

            # Log message
            log_message(
                "info",
                "All persistent selections have been reset to their default settings and cleared.",
                "Transformations",
            )

            st.rerun()

    # Apply category-based transformations to the DataFrame
    with button_col[0]:
        if st.button("Accept and Proceed", use_container_width=True):
            with st.spinner("Applying transformations ..."):
                final_df = apply_category_transformations(
                    final_df_loaded.copy(),
                    bin_dict_loaded.copy(),
                    transform_params.copy(),
                    panel.copy(),
                    specific_transform_params.copy(),
                )

                # Generate a dictionary mapping original column names to lists of transformed column names
                transformed_columns_dict, summary_string = generate_transformed_columns(
                    original_columns, transform_params, specific_transform_params
                )

                # Store into transformed dataframe and summary session state
                st.session_state["project_dct"]["transformations"][
                    "final_df"
                ] = final_df
                st.session_state["project_dct"]["transformations"][
                    "summary_string"
                ] = summary_string

                # Display success message
                st.success("Transformation of the DataFrame is successful!", icon="✅")

                # Log message
                log_message(
                    "info",
                    "Transformation of the DataFrame is successful!",
                    "Transformations",
                )

    #########################################################################################################################################################
    # Display the transformed DataFrame and summary
    #########################################################################################################################################################

    # Display the transformed DataFrame in the Streamlit app
    st.markdown("### Transformed DataFrame")
    with st.spinner("Please wait while the transformed DataFrame is loading ..."):
        # Work on a copy so the stored DataFrame is never modified by display code
        final_df = st.session_state["project_dct"]["transformations"]["final_df"].copy()

        # Clean display DataFrame
        display_df, exceeds_max_col = clean_display_df(final_df, display_max_col)

        # Check the number of columns and show only the first display_max_col if there are more
        if exceeds_max_col:
            # Display an info message if the DataFrame has more than display_max_col columns
            st.info(
                f"The transformed DataFrame has more than {display_max_col} columns. Displaying only the first {display_max_col} columns.",
                icon="💬",
            )

        # Display Final DataFrame
        st.dataframe(
            display_df,
            hide_index=True,
            column_config={
                "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
            },
        )

        # Total rows and columns (of the full DataFrame, not the display subset)
        total_rows, total_columns = final_df.shape
        st.markdown(
            f"<p style='text-align: justify;'>The transformed DataFrame contains <strong>{total_rows}</strong> rows and <strong>{total_columns}</strong> columns.</p>",
            unsafe_allow_html=True,
        )

        # Display the summary of transformations as markdown
        summary_string = st.session_state["project_dct"]["transformations"].get(
            "summary_string"
        )
        if summary_string:
            with st.expander("Summary of Transformations"):
                st.markdown("### Summary of Transformations")
                st.markdown(summary_string, unsafe_allow_html=True)

    #########################################################################################################################################################
    # Correlation Plot
    #########################################################################################################################################################

    # Exclude the 'date' and 'panel' columns from the selectable variables
    variables = [
        col for col in final_df.columns if col.lower() not in ["date", "panel"]
    ]

    with st.expander("Transformed Variable Correlation Plot"):
        # Keep only saved selections that still exist as columns: the
        # transformed DataFrame may have changed since the selection was
        # saved, and a default outside the options is rejected by the widget
        saved_selection = (
            st.session_state["project_dct"]["transformations"][
                "correlation_plot_selection"
            ]
            or []
        )
        available_variables = set(variables)
        valid_selection = [
            var for var in saved_selection if var in available_variables
        ]

        selected_vars = st.multiselect(
            label="Choose variables for correlation plot:",
            options=variables,
            max_selections=30,
            default=valid_selection,
            key="correlation_plot_key",
        )

        # Calculate correlation
        if selected_vars:
            corr_df = final_df[selected_vars].corr()

            # Prepare text annotations with 2 decimal places
            annotations = []
            for i, row_label in enumerate(corr_df.index):
                for j, col_label in enumerate(corr_df.columns):
                    annotations.append(
                        go.layout.Annotation(
                            text=f"{corr_df.iloc[i, j]:.2f}",
                            x=col_label,
                            y=row_label,
                            showarrow=False,
                            font=dict(color="black"),
                        )
                    )

            # Plotly correlation plot using go
            heatmap = go.Heatmap(
                z=corr_df.values,
                x=corr_df.columns,
                y=corr_df.index,
                colorscale="RdBu",
                zmin=-1,
                zmax=1,
            )

            layout = go.Layout(
                title="Transformed Variable Correlation Plot",
                xaxis=dict(title="Variables"),
                yaxis=dict(title="Variables"),
                width=1000,
                height=1000,
                annotations=annotations,
            )

            fig = go.Figure(data=[heatmap], layout=layout)

            st.plotly_chart(fig)
        else:
            st.write("Please select at least one variable to plot.")

    #########################################################################################################################################################
    # Accept and Save
    #########################################################################################################################################################

    # Check for saved model
    if (
        retrieve_pkl_object(
            st.session_state["project_number"], "Model_Build", "best_models", schema
        )
        is not None
    ):  # db
        st.warning(
            "Saving transformations will overwrite existing ones and delete all saved models. To keep previous models, please start a new project.",
            icon="⚠️",
        )

    if st.button("Accept and Save", use_container_width=True):

        with st.spinner("Saving Changes"):
            # Update correlation plot selection
            st.session_state["project_dct"]["transformations"][
                "correlation_plot_selection"
            ] = st.session_state["correlation_plot_key"]

            # Clear model metadata
            clear_pages()

            # Update DB
            update_db(
                prj_id=st.session_state["project_number"],
                page_nam="Transformations",
                file_nam="project_dct",
                pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                schema=schema,
            )

            # Clear data from DB
            delete_entries(
                st.session_state["project_number"],
                ["Model_Build", "Model_Tuning"],
                db_cred,
                schema,
            )

            # Success message
            st.success("Saved Successfully!", icon="💾")
            st.toast("Saved Successfully!", icon="💾")

            # Log message
            log_message("info", "Saved Successfully!", "Transformations")

except Exception:
    # Capture the error details (full traceback of the active exception)
    error_message = traceback.format_exc()

    # Log message
    log_message("error", f"An error occurred: {error_message}.", "Transformations")

    # Display a warning message
    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )