import streamlit as st
from scenario import numerize
import pandas as pd
from utilities import (
    format_numbers,
    load_local_css,
    set_header,
    name_formating,
    project_selection,
)
import pickle
import yaml
from yaml import SafeLoader
from scenario import class_from_dict
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots
import sqlite3
from utilities import update_db
from collections import OrderedDict
import os

# Page chrome: wide layout, local stylesheet and the shared app header are
# applied before any page content is rendered.
st.set_page_config(layout="wide")
load_local_css("styles.css")
set_header()

# NOTE(review): this bare placeholder is never referenced — confirm it is needed.
st.empty()
st.header("AI Model Media Recommendation")

# def get_saved_scenarios_dict():
#     # Path to the saved scenarios file
#     saved_scenarios_dict_path = os.path.join(
#         st.session_state["project_path"], "saved_scenarios.pkl"
#     )

#     # Load existing scenarios if the file exists
#     if os.path.exists(saved_scenarios_dict_path):
#         with open(saved_scenarios_dict_path, "rb") as f:
#             saved_scenarios_dict = pickle.load(f)
#     else:
#         saved_scenarios_dict = OrderedDict()

#     return saved_scenarios_dict


# # Function to format values based on their size
# def format_value(value):
#     return round(value, 4) if value < 1 else round(value, 1)


# # Function to recursively convert non-serializable types to serializable ones
# def convert_to_serializable(obj):
#     if isinstance(obj, np.ndarray):
#         return obj.tolist()
#     elif isinstance(obj, dict):
#         return {key: convert_to_serializable(value) for key, value in obj.items()}
#     elif isinstance(obj, list):
#         return [convert_to_serializable(element) for element in obj]
#     elif isinstance(obj, (int, float, str, bool, type(None))):
#         return obj
#     else:
#         # Fallback: convert the object to a string
#         return str(obj)


# Seed the identity keys with None so later reads of them always succeed.
for _state_key in ("username", "project_name"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# Without a loaded project dictionary there is nothing to show: render the
# project picker and end this script run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()
# if   "project_path" not in st.session_state:
#     st.stop()
# if 'username' in st.session_state and st.session_state['username'] is not None:

#     data_path = os.path.join(st.session_state["project_path"], "data_import.pkl")

#     try:
#         with open(data_path, "rb") as f:
#             data = pickle.load(f)
#     except Exception as e:
#         st.error(f"Please import data from the Data Import Page")
#         st.stop()
# # Get saved scenarios dictionary and scenario name list
# saved_scenarios_dict = get_saved_scenarios_dict()
# scenarios_list = list(saved_scenarios_dict.keys())

# #st.write(saved_scenarios_dict)
# # Check if the list of saved scenarios is empty
# if len(scenarios_list) == 0:
#     # Display a warning message if no scenarios are saved
#     st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
#     st.stop()

# # Display a dropdown saved scenario list
# selected_scenario = st.selectbox(
#     "Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
# )
# selected_scenario_data = saved_scenarios_dict[selected_scenario]

# # Scenarios Name
# metrics_name = selected_scenario_data["metrics_selected"]
# panel_name = selected_scenario_data["panel_selected"]
# optimization_name = selected_scenario_data["optimization"]

# # Display the scenario details with bold "Metric," "Panel," and "Optimization"

# # Create columns for download and delete buttons
# download_col, delete_col = st.columns(2)


# channels_list = list(selected_scenario_data["channels"].keys())

# # List to hold data for all channels
# channels_data = []

# # Iterate through each channel and gather required data
# for channel in channels_list:
#     channel_conversion_rate = selected_scenario_data["channels"][channel][
#         "conversion_rate"
#     ]
#     channel_actual_spends = (
#         selected_scenario_data["channels"][channel]["actual_total_spends"]
#         * channel_conversion_rate
#     )
#     channel_optimized_spends = (
#         selected_scenario_data["channels"][channel]["modified_total_spends"]
#         * channel_conversion_rate
#     )

#     channel_actual_metrics = selected_scenario_data["channels"][channel][
#         "actual_total_sales"
#     ]
#     channel_optimized_metrics = selected_scenario_data["channels"][channel][
#         "modified_total_sales"
#     ]

#     channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]

#     # Extract the ROI and MROI data
#     actual_roi = channel_roi_mroi_data["actual_roi"]
#     optimized_roi = channel_roi_mroi_data["optimized_roi"]
#     actual_mroi = channel_roi_mroi_data["actual_mroi"]
#     optimized_mroi = channel_roi_mroi_data["optimized_mroi"]

#     # Calculate spends per metric
#     spends_per_metrics_actual = channel_actual_spends / channel_actual_metrics
#     spends_per_metrics_optimized = channel_optimized_spends / channel_optimized_metrics

#     # Append the collected data as a dictionary to the list
#     channels_data.append(
#         {
#             "Channel Name": channel,
#             "Spends Actual": channel_actual_spends,
#             "Spends Optimized": channel_optimized_spends,
#             f"{metrics_name} Actual": channel_actual_metrics,
#             f"{name_formating(metrics_name)} Optimized": numerize(
#                 channel_optimized_metrics
#             ),
#             "ROI Actual": format_value(actual_roi),
#             "ROI Optimized": format_value(optimized_roi),
#             "MROI Actual": format_value(actual_mroi),
#             "MROI Optimized": format_value(optimized_mroi),
#             f"Spends per {name_formating(metrics_name)} Actual": numerize(
#                 spends_per_metrics_actual
#             ),
#             f"Spends per {name_formating(metrics_name)} Optimized": numerize(
#                 spends_per_metrics_optimized
#             ),
#         }
#     )

# # Create a DataFrame from the collected data

##NEW CODE##########

# Placeholder reserved near the top of the page; it is filled in later with the
# selected scenario's metric/panel/fix/timeframe summary.
scenarios_name_placeholder = st.empty()


# Function to get saved scenarios dictionary
def get_saved_scenarios_dict():
    """Return the saved-scenarios mapping of the currently loaded project.

    The mapping lives at
    ``st.session_state["project_dct"]["saved_scenarios"]["saved_scenarios_dict"]``.
    Missing keys are not handled here.
    """
    project = st.session_state["project_dct"]
    saved_section = project["saved_scenarios"]
    return saved_section["saved_scenarios_dict"]


# Function to format values based on their size
def format_value(value):
    """Round ``value`` for display according to its magnitude.

    Values whose magnitude is below 1 keep four decimal places so small
    ratios stay readable; all other values are rounded to one decimal place.
    The magnitude (``abs``) is compared, so a large negative number such as
    ``-12.3456`` is rounded like ``12.3456`` instead of keeping four decimals
    merely because it compares below 1.

    Args:
        value: A real number, e.g. an ROI or MROI figure.

    Returns:
        The rounded number.
    """
    return round(value, 4) if abs(value) < 1 else round(value, 1)


# Function to recursively convert non-serializable types to serializable ones
def convert_to_serializable(obj):
    """Recursively convert ``obj`` into JSON-friendly built-in types.

    - NumPy arrays become (nested) lists via ``tolist()``.
    - NumPy integer, floating and bool scalars become the equivalent Python
      scalar via ``.item()`` rather than being stringified.
    - Dicts are converted value-by-value (keys are left unchanged).
    - Lists and tuples are converted element-by-element into lists.
    - ``int``, ``float``, ``str``, ``bool`` and ``None`` pass through as-is.
    - Anything else falls back to ``str(obj)``.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        # np.int64 / np.bool_ are not subclasses of int / bool, so without
        # this branch they would hit the str() fallback and lose their type.
        return obj.item()
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(element) for element in obj]
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    # Fallback: convert the object to a string
    return str(obj)


# Get saved scenarios dictionary and scenario name list
saved_scenarios_dict = get_saved_scenarios_dict()
scenarios_list = list(saved_scenarios_dict.keys())

# Nothing to review until at least one scenario has been saved
if not scenarios_list:
    # Display a warning message if no scenarios are saved
    st.warning("No scenarios saved. Please save a scenario to load.", icon="⚠️")
    st.stop()

# Display a dropdown saved scenario list
selected_scenario = st.selectbox(
    "Pick a Scenario", sorted(scenarios_list), key="selected_scenario"
)
selected_scenario_data = saved_scenarios_dict[selected_scenario]

# Scenario-level settings stored with the saved scenario
metrics_name = selected_scenario_data["metrics_selected"]
panel_name = selected_scenario_data["panel_selected"]
optimization_name = selected_scenario_data["optimization"]
# Spend and metric totals in the per-channel table are divided by this value
multiplier = selected_scenario_data["multiplier"]
timeframe = selected_scenario_data["timeframe"]

# Display the scenario details with bold "Metric," "Panel," "Fix" and "Timeframe"
scenarios_name_placeholder.markdown(
    f"**Metric**: {name_formating(metrics_name)}; **Panel**: {name_formating(panel_name)}; **Fix**: {name_formating(optimization_name)}; **Timeframe**: {name_formating(timeframe)}"
)

# Create columns for download and delete buttons
download_col, delete_col = st.columns(2)

# Channel List
channels_list = list(selected_scenario_data["channels"].keys())


def _spends_per_metric(spends, metric_value):
    """Return ``spends / metric_value``, or NaN when ``metric_value`` is 0.

    A channel with zero attributed metric would otherwise raise
    ``ZeroDivisionError`` for plain Python numbers (or yield ``inf`` for
    NumPy scalars). Assumes scalar totals — TODO confirm against the
    saved-scenario schema.
    """
    if metric_value == 0:
        return np.nan
    return spends / metric_value


# Display label of the selected metric, used in several column names below
metric_label = name_formating(metrics_name)

# One summary row (dict) per channel
channels_data = []
for channel in channels_list:
    channel_info = selected_scenario_data["channels"][channel]

    # Stored spend totals are scaled by the channel's conversion rate
    channel_conversion_rate = channel_info["conversion_rate"]
    channel_actual_spends = (
        channel_info["actual_total_spends"] * channel_conversion_rate
    )
    channel_optimized_spends = (
        channel_info["modified_total_spends"] * channel_conversion_rate
    )

    channel_actual_metrics = channel_info["actual_total_sales"]
    channel_optimized_metrics = channel_info["modified_total_sales"]

    # Extract the ROI and MROI data
    channel_roi_mroi_data = selected_scenario_data["channel_roi_mroi"][channel]
    actual_roi = channel_roi_mroi_data["actual_roi"]
    optimized_roi = channel_roi_mroi_data["optimized_roi"]
    actual_mroi = channel_roi_mroi_data["actual_mroi"]
    optimized_mroi = channel_roi_mroi_data["optimized_mroi"]

    # Spend per unit of metric (NaN instead of a crash when the metric is 0)
    spends_per_metrics_actual = _spends_per_metric(
        channel_actual_spends, channel_actual_metrics
    )
    spends_per_metrics_optimized = _spends_per_metric(
        channel_optimized_spends, channel_optimized_metrics
    )

    channels_data.append(
        {
            "Channel Name": channel,
            "Spends Actual": (channel_actual_spends / multiplier),
            "Spends Optimized": (channel_optimized_spends / multiplier),
            f"{metric_label} Actual": (channel_actual_metrics / multiplier),
            f"{metric_label} Optimized": (channel_optimized_metrics / multiplier),
            "ROI Actual": format_value(actual_roi),
            "ROI Optimized": format_value(optimized_roi),
            "MROI Actual": format_value(actual_mroi),
            "MROI Optimized": format_value(optimized_mroi),
            f"Spends per {metric_label} Actual": round(spends_per_metrics_actual, 2),
            f"Spends per {metric_label} Optimized": round(
                spends_per_metrics_optimized, 2
            ),
        }
    )

# Create a DataFrame from the collected data
# df = pd.DataFrame(channels_data)

# # Display the DataFrame
# st.dataframe(df, hide_index=True)

# Per-channel summary ordered by optimized spend (ascending), extended with the
# absolute change in spend and that change as a percentage of actual spend.
summary_df_sorted = (
    pd.DataFrame(channels_data)
    .sort_values(by=["Spends Optimized"])
    .assign(
        Delta=lambda frame: frame["Spends Optimized"] - frame["Spends Actual"]
    )
    .assign(
        Delta_percent=lambda frame: np.round(
            frame["Delta"] / frame["Spends Actual"] * 100, 2
        )
    )
)

# spends_data = pd.read_excel("Overview_data_test.xlsx")


st.header("Optimized Media Spend Overview")

channel_colors = px.colors.qualitative.Plotly

# Panels laid out left to right in subplot columns 1..3:
# (summary column plotted, formatter for the bar text labels).
_overview_panels = (
    ("Spends Actual", format_numbers),
    ("Spends Optimized", format_numbers),
    ("Delta_percent", lambda pct: f"{pct:.0f}%"),
)

fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Actual Spend", "Spends Optimized", "Delta"),
    horizontal_spacing=0.05,
)

for color_idx, channel_name in enumerate(summary_df_sorted["Channel Name"].unique()):
    channel_rows = summary_df_sorted[summary_df_sorted["Channel Name"] == channel_name]
    bar_color = channel_colors[color_idx % len(channel_colors)]

    for panel_col, (value_col, label_fmt) in enumerate(_overview_panels, start=1):
        # Only the first panel's trace leaves the legend flag at its default
        legend_kwargs = {} if panel_col == 1 else {"showlegend": False}
        fig.add_trace(
            go.Bar(
                x=channel_rows[value_col],
                y=channel_rows["Channel Name"],
                text=channel_rows[value_col].apply(label_fmt),
                marker_color=bar_color,
                orientation="h",
                **legend_kwargs,
            ),
            row=1,
            col=panel_col,
        )

fig.update_layout(height=600, width=900, title="", showlegend=False)

# Channel names appear only on the first panel's y-axis; x-axis tick labels are
# hidden everywhere because every bar carries its own text label.
for panel_col in (2, 3):
    fig.update_yaxes(showticklabels=False, row=1, col=panel_col)
for panel_col in (1, 2, 3):
    fig.update_xaxes(showticklabels=False, row=1, col=panel_col)


st.plotly_chart(fig, use_container_width=True)


# Share of total optimized spend per channel (0-1 scale, rounded to 2 decimals)
summary_df_sorted["Perc_alloted"] = np.round(
    summary_df_sorted["Spends Optimized"] / summary_df_sorted["Spends Optimized"].sum(),
    2,
)
st.header("Optimized Media Spend Allocation")

fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Spends Optimized", "% Split"),
    horizontal_spacing=0.05,
)

for i, channel in enumerate(summary_df_sorted["Channel Name"].unique()):
    channel_df = summary_df_sorted[summary_df_sorted["Channel Name"] == channel]
    channel_color = channel_colors[i % len(channel_colors)]

    fig.add_trace(
        go.Bar(
            x=channel_df["Spends Optimized"],
            y=channel_df["Channel Name"],
            text=channel_df["Spends Optimized"].apply(format_numbers),
            marker_color=channel_color,
            orientation="h",
        ),
        row=1,
        col=1,
    )

    fig.add_trace(
        go.Bar(
            x=channel_df["Perc_alloted"],
            y=channel_df["Channel Name"],
            # Stored as a fraction; shown as a whole-number percentage
            text=channel_df["Perc_alloted"].apply(lambda x: f"{100*x:.0f}%"),
            marker_color=channel_color,
            orientation="h",
            showlegend=False,
        ),
        row=1,
        col=2,
    )

fig.update_layout(height=600, width=900, title="", showlegend=False)

# This figure has only two subplot columns: hide y tick labels on the second
# panel and x tick labels on both (each bar carries its own text label).
fig.update_yaxes(showticklabels=False, row=1, col=2)
fig.update_xaxes(showticklabels=False, row=1, col=1)
fig.update_xaxes(showticklabels=False, row=1, col=2)


st.plotly_chart(fig, use_container_width=True)


# Cache the imported (imputed) data and its column categories in session state
st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
    "imputed_tool_df"
]
st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
    "category_dict"
]

# Columns categorised as "Response Metrics" (empty when the category is absent)
response_metrics = list(
    st.session_state["category_dict"].get("Response Metrics", [])
)

# Total of each response metric over the whole cleaned dataset
effectiveness_overall = (
    st.session_state["cleaned_data"][response_metrics].sum().reset_index()
)

effectiveness_overall.columns = ["ResponseMetricName", "ResponseMetricValue"]

# Efficiency = metric total per unit of total optimized spend; the spend total
# is computed once and divided in a single vectorised step.
# NOTE(review): "Spends Optimized" was divided by the scenario multiplier while
# ResponseMetricValue is the raw data total — confirm mixing these is intended.
total_optimized_spend = summary_df_sorted["Spends Optimized"].sum()
effectiveness_overall["Efficiency"] = (
    effectiveness_overall["ResponseMetricValue"] / total_optimized_spend
)


# NOTE(review): these three columns are never filled — confirm they are needed.
columns6 = st.columns(3)

effectiveness_overall.sort_values(
    by=["ResponseMetricValue"], ascending=False, inplace=True
)
# Rounds the numeric columns only; the name column is left as-is
effectiveness_overall = np.round(effectiveness_overall, 2)

# Layout for the effectiveness (left) and efficiency (right) charts
columns4 = st.columns([0.55, 0.45])

# effectiveness_overall = effectiveness_overall.sort_values(by=["ResponseMetricValue"])

# with columns4[0]:
#     fig = px.funnel(
#         effectiveness_overall,
#         x="ResponseMetricValue",
#         y="ResponseMetricName",
#         color="ResponseMetricName",
#         title="Effectiveness",
#     )
#     fig.update_layout(
#         showlegend=False,
#         yaxis=dict(tickmode="array"),
#     )
#     fig.update_traces(
#         textinfo="value",
#         textposition="inside",
#         texttemplate="%{x:.2s} ",
#         hoverinfo="y+x+percent initial",
#     )
#     fig.update_traces(
#         marker=dict(line=dict(color="black", width=2)),
#         selector=dict(marker=dict(color="blue")),
#     )

#     st.plotly_chart(fig, use_container_width=True)

# with columns4[1]:
#     fig1 = px.bar(
#         effectiveness_overall.sort_values(by=["ResponseMetricValue"], ascending=False),
#         x="Efficiency",
#         y="ResponseMetricName",
#         color="ResponseMetricName",
#         text_auto=True,
#         title="Efficiency",
#     )

#     # Update layout and traces
#     fig1.update_traces(
#         customdata=effectiveness_overall["Efficiency"], textposition="auto"
#     )
#     fig1.update_layout(showlegend=False)
#     fig1.update_yaxes(title="", showticklabels=False)
#     fig1.update_xaxes(title="", showticklabels=False)
#     fig1.update_xaxes(tickfont=dict(size=20))
#     fig1.update_yaxes(tickfont=dict(size=20))
#     st.plotly_chart(fig1, use_container_width=True)

# Function to format metric names
def format_metric_name(metric_name):
    """Turn a raw response-metric column name into a display label.

    The value is stringified and lower-cased, every ``"response_metric_"``
    occurrence is removed, underscores become spaces, surrounding whitespace
    is stripped and the result is title-cased.
    """
    label = str(metric_name).lower()
    label = label.replace("response_metric_", "")
    label = label.replace("_", " ")
    return label.strip().title()

# Apply the formatting function to the 'ResponseMetricName' column
# Apply the formatting function to the 'ResponseMetricName' column
effectiveness_overall["FormattedMetricName"] = effectiveness_overall[
    "ResponseMetricName"
].apply(format_metric_name)

# Multiselect widget with all options as default, using the formatted names for display
all_metrics = effectiveness_overall["FormattedMetricName"].unique()
selected_metrics = st.multiselect(
    "Select Metrics to Display",
    options=all_metrics,
    default=all_metrics
)

# Colour map built from ALL metrics, in ascending-value order (the order the
# funnel uses), so each metric keeps the same colour in both charts no matter
# which subset is selected. Building it from the filtered subset reshuffled
# colours whenever a metric was deselected.
metric_palette = px.colors.qualitative.Plotly
color_map = {
    metric: metric_palette[i % len(metric_palette)]
    for i, metric in enumerate(
        effectiveness_overall.sort_values(by=["ResponseMetricValue"])[
            "FormattedMetricName"
        ].unique()
    )
}

# Filter the data based on the selected metrics (using formatted names)
if selected_metrics:
    filtered_data = effectiveness_overall[
        effectiveness_overall["FormattedMetricName"].isin(selected_metrics)
    ]

    # Ascending order for the funnel plot
    filtered_data = filtered_data.sort_values(by=["ResponseMetricValue"])

    # First plot: Funnel
    with columns4[0]:
        fig = px.funnel(
            filtered_data,
            x="ResponseMetricValue",
            y="FormattedMetricName",  # Use formatted names for y-axis
            color="FormattedMetricName",  # Use formatted names for color
            color_discrete_map=color_map,  # Ensure consistent colors
            title="Effectiveness",
        )
        fig.update_layout(
            showlegend=False,
            yaxis=dict(title="Response Metric", tickmode="array"),  # Set y-axis label to 'Response Metric'
        )
        fig.update_traces(
            textinfo="value",
            textposition="inside",
            texttemplate="%{x:.2s} ",
            hoverinfo="y+x+percent initial",
        )
        fig.update_traces(
            marker=dict(line=dict(color="black", width=2)),
            # NOTE(review): traces are coloured from the Plotly qualitative
            # palette via color_map, so this "blue" selector matches no trace
            # and the outline is never applied — confirm the intent.
            selector=dict(marker=dict(color="blue")),
        )

        st.plotly_chart(fig, use_container_width=True)

    # Second plot: Bar chart
    with columns4[1]:
        fig1 = px.bar(
            filtered_data.sort_values(by=["ResponseMetricValue"], ascending=False),
            x="Efficiency",
            y="FormattedMetricName",  # Use formatted names for y-axis
            color="FormattedMetricName",  # Use formatted names for color
            color_discrete_map=color_map,  # Ensure consistent colors
            text_auto=True,
            title="Efficiency",
        )

        # Axis titles and tick labels are hidden; each bar shows its own value
        fig1.update_traces(textposition="auto")
        fig1.update_layout(showlegend=False)
        fig1.update_yaxes(title="", showticklabels=False)
        fig1.update_xaxes(title="", showticklabels=False)
        st.plotly_chart(fig1, use_container_width=True)
else:
    st.info("Please select at least one response metric to display the charts.")

st.header("Return Forecast by Media Channel")

with st.expander("Return Forecast by Media Channel"):

    # Normalise column names: lower case, underscores replaced by spaces
    summary_df_sorted.columns = [
        col.lower().replace("_", " ") for col in summary_df_sorted.columns
    ]

    # Key of the "<metric> Actual" column created per channel, run through the
    # same normalisation as the columns so it always matches (rebuilding it
    # from metrics_name alone breaks whenever name_formating changes more than
    # case/underscores).
    metric = f"{name_formating(metrics_name)} Actual".lower().replace("_", " ")

    # NOTE(review): this pairs the *actual* metric with the *optimized* spend —
    # confirm that is the intended definition of "Efficiency" here.
    summary_df_sorted["Efficiency"] = (
        summary_df_sorted[metric] / summary_df_sorted["spends optimized"]
    )

    channel_colors = px.colors.qualitative.Plotly

    fig = make_subplots(
        rows=1,
        cols=3,
        subplot_titles=("Optimized Spends", "Effectiveness", "Efficiency"),
        horizontal_spacing=0.05,
    )

    for i, channel in enumerate(summary_df_sorted["channel name"].unique()):
        channel_df = summary_df_sorted[summary_df_sorted["channel name"] == channel]
        channel_color = channel_colors[i % len(channel_colors)]

        fig.add_trace(
            go.Bar(
                x=channel_df["spends optimized"],
                y=channel_df["channel name"],
                text=channel_df["spends optimized"].apply(format_numbers),
                marker_color=channel_color,
                orientation="h",
            ),
            row=1,
            col=1,
        )

        fig.add_trace(
            go.Bar(
                x=channel_df[metric],
                y=channel_df["channel name"],
                text=channel_df[metric].apply(format_numbers),
                marker_color=channel_color,
                orientation="h",
                showlegend=False,
            ),
            row=1,
            col=2,
        )

        fig.add_trace(
            go.Bar(
                x=channel_df["Efficiency"],
                y=channel_df["channel name"],
                text=channel_df["Efficiency"].apply(lambda x: f"{x:.2f}"),
                marker_color=channel_color,
                orientation="h",
                showlegend=False,
            ),
            row=1,
            col=3,
        )

    fig.update_layout(
        height=600,
        width=900,
        title="Media Channel Performance",
        showlegend=False,
    )

    # Channel names only on the first panel; bars carry their own text labels
    fig.update_yaxes(showticklabels=False, row=1, col=2)
    fig.update_yaxes(showticklabels=False, row=1, col=3)

    fig.update_xaxes(showticklabels=False, row=1, col=1)
    fig.update_xaxes(showticklabels=False, row=1, col=2)
    fig.update_xaxes(showticklabels=False, row=1, col=3)

    st.plotly_chart(fig, use_container_width=True)