import pandas as pd
import numpy as np
import pptx
from pptx import Presentation
from pptx.chart.data import CategoryChartData, ChartData
from pptx.enum.chart import XL_CHART_TYPE, XL_LEGEND_POSITION, XL_LABEL_POSITION
from pptx.enum.chart import XL_TICK_LABEL_POSITION
from pptx.util import Inches, Pt
import os
import pickle
from pathlib import Path
from sklearn.metrics import (
    mean_absolute_error,
    r2_score,
    mean_absolute_percentage_error,
)
import streamlit as st
from collections import OrderedDict
from utilities import get_metrics_names, initialize_data, retrieve_pkl_object_without_warning
from io import BytesIO
from pptx.dml.color import RGBColor
from post_gres_cred import db_cred
# Database schema name read from the project's Postgres credentials.
schema = db_cred["schema"]

from constants import (
    TITLE_FONT_SIZE,
    AXIS_LABEL_FONT_SIZE,
    CHART_TITLE_FONT_SIZE,
    AXIS_TITLE_FONT_SIZE,
    DATA_LABEL_FONT_SIZE,
    LEGEND_FONT_SIZE,
    PIE_LEGEND_FONT_SIZE
)


def format_response_metric(target):
    """Turn a raw response-metric column name into a human-readable label.

    If the name starts with ``response_metric_``, every occurrence of that
    marker is removed; then underscores become spaces and the result is
    title-cased, e.g. ``'response_metric_total_revenue'`` -> ``'Total Revenue'``.
    """
    marker = 'response_metric_'
    cleaned = target.replace(marker, '') if target.startswith(marker) else target
    return cleaned.replace("_", " ").title()


def smape(actual, forecast):
    """Return the symmetric mean absolute percentage error of *forecast* vs *actual*.

    SMAPE avoids two shortcomings of plain MAPE:
      1. MAPE becomes extremely large when actual values are close to 0.
      2. MAPE favours under-forecasts over over-forecasts.

    This variant averages ``|F - A| / (|A| + |F|)`` without the usual factor
    of 2, so every term with a non-zero denominator lies in [0, 1].  A period
    where both actual and forecast are 0 produces a 0/0 term (with NumPy
    array inputs this makes the whole result NaN).

    Args:
        actual: observed values (array-like supporting ``len`` and arithmetic).
        forecast: predicted values, same length as *actual*.
    """
    abs_error = np.abs(forecast - actual)
    scale = np.abs(actual) + np.abs(forecast)
    n_obs = len(actual)
    return (1 / n_obs) * np.sum(abs_error / scale)


def safe_num_to_per(num):
    """Format *num* as a whole-number percentage, e.g. ``0.256`` -> ``'26%'``.

    Values that the ``%`` presentation type cannot format (strings, ``None``,
    ...) are returned unchanged, so mixed columns can be passed through.
    """
    try:
        return "{:.0%}".format(num)
    except (TypeError, ValueError):
        # Not a number: leave the value as-is instead of failing the caller.
        return num


def convert_number_to_abbreviation(number):
    """Abbreviate large numbers for display, e.g. ``2_500_000`` -> ``'2.5 M'``.

    Magnitudes of at least one million are shown in millions (``' M'``) and of
    at least one thousand in thousands (``' K'``), with one decimal place.  The
    sign is kept, so ``-1500`` -> ``'-1.5 K'``.  Smaller values are returned as
    the string form of the float (``12`` -> ``'12.0'``).  Anything that cannot
    be converted to ``float`` (labels, ``None``, ...) is returned unchanged.
    """
    try:
        value = float(number)
    except (TypeError, ValueError):
        return number

    # Compare the magnitude so negative changes are abbreviated too.
    magnitude = abs(value)
    if magnitude >= 1_000_000:
        return f'{value / 1_000_000:.1f} M'
    if magnitude >= 1_000:
        return f'{value / 1_000:.1f} K'
    return str(value)


def round_off(x, round_off_decimal=0):
    """Round *x* for display while keeping small values readable.

    Values with magnitude >= 1 (and 0) are rounded to ``round_off_decimal``
    places.  Non-zero values strictly between -1 and 1 are rounded to
    ``floor(|log10(|x|)|) + max(round_off_decimal, 1)`` places so their
    leading significant digit survives (``0.0456`` -> ``0.05``).  The sign of
    *x* is always preserved.

    Args:
        x: value to round; anything ``float()`` rejects is returned unchanged.
        round_off_decimal: number of decimals for values with magnitude >= 1.

    Returns:
        The rounded value as a NumPy float, or *x* itself if it is not numeric.
    """
    try:
        value = float(x)
    except (TypeError, ValueError):
        return x

    if 0 < abs(value) < 1:
        # Add decimals according to the order of magnitude so that e.g. 0.004
        # is not rounded down to 0.
        round_off_decimal = int(np.floor(np.abs(np.log10(abs(value))))) + max(round_off_decimal, 1)
    return np.round(value, round_off_decimal)


def fill_table_placeholder(table_placeholder, slide, df, column_width=None, table_height=None):
    """Replace *table_placeholder* on *slide* with a table holding *df*.

    The table is drawn at the placeholder's position and width.  Row 0 holds
    the column names, and every cell is rendered with ``str``.  The placeholder
    shape is removed from the slide afterwards.

    Args:
        table_placeholder: placeholder shape whose geometry the table takes over.
        slide: slide the table is added to.
        df: DataFrame to render.
        column_width: optional ``{column_index: width_in_inches}`` overrides.
        table_height: table height in the same units as the placeholder's
            ``height``; defaults to the placeholder height.
    """
    n_rows, n_cols = df.shape
    if table_height is None:
        table_height = table_placeholder.height

    table = slide.shapes.add_table(
        n_rows + 1, n_cols,
        table_placeholder.left, table_placeholder.top,
        table_placeholder.width, table_height,
    ).table

    # Header row.
    for col_idx, header in enumerate(df.columns):
        table.cell(0, col_idx).text = str(header)
    # Data rows start at 1 because row 0 is the header.
    for row_idx, row in enumerate(df.values, start=1):
        for col_idx, value in enumerate(row):
            table.cell(row_idx, col_idx).text = str(value)

    if column_width is not None:
        # Distinct loop name so the ``column_width`` mapping is not shadowed.
        for col_idx, width_in_inches in column_width.items():
            table.columns[col_idx].width = Inches(width_in_inches)

    table_placeholder._element.getparent().remove(table_placeholder._element)


def bar_chart(chart_placeholder, slide, chart_data, titles=None, min_y=None, max_y=None, type='V', legend=True,
              label_type=None, xaxis_pos=None):
    """Replace *chart_placeholder* on *slide* with a clustered column/bar chart.

    Args:
        chart_placeholder: placeholder shape whose geometry the chart takes
            over; it is removed from the slide afterwards.
        slide: slide the chart is added to.
        chart_data: python-pptx chart data (e.g. ``CategoryChartData``).
        titles: optional dict with any of ``'chart_title'``, ``'x_axis'`` and
            ``'y_axis'`` text.
        min_y, max_y: optional value-axis bounds.
        type: ``'V'`` for vertical columns, ``'H'`` for horizontal bars.
        legend: when true, show the legend below the plot.
        label_type: data-label number format key: ``'per'``, ``'$'``,
            ``'$1'``, ``'M'``, ``'M1'`` or ``'K'``.  Any other value keeps the
            default format.
        xaxis_pos: ``'low'`` sets the category tick labels to the LOW position.

    Raises:
        ValueError: if *type* is neither ``'V'`` nor ``'H'``.
    """
    titles = titles or {}
    chart_types = {
        'V': XL_CHART_TYPE.COLUMN_CLUSTERED,
        'H': XL_CHART_TYPE.BAR_CLUSTERED,
    }
    if type not in chart_types:
        raise ValueError(f"type must be 'V' or 'H', got {type!r}")

    graphic_frame = slide.shapes.add_chart(
        chart_types[type],
        chart_placeholder.left, chart_placeholder.top,
        chart_placeholder.width, chart_placeholder.height,
        chart_data,
    )
    chart = graphic_frame.chart
    category_axis = chart.category_axis
    value_axis = chart.value_axis

    # Chart title.
    if 'chart_title' in titles:
        chart.has_title = True
        chart.chart_title.text_frame.text = titles['chart_title']
        chart.chart_title.text_frame.paragraphs[0].runs[0].font.size = Pt(CHART_TITLE_FONT_SIZE)

    # Axis titles.
    for title_key, axis in (('x_axis', category_axis), ('y_axis', value_axis)):
        if title_key in titles:
            axis.has_title = True
            axis.axis_title.text_frame.text = titles[title_key]
            axis.axis_title.text_frame.paragraphs[0].runs[0].font.size = Pt(AXIS_TITLE_FONT_SIZE)

    if xaxis_pos == 'low':
        category_axis.tick_label_position = XL_TICK_LABEL_POSITION.LOW

    if legend:
        chart.has_legend = True
        chart.legend.position = XL_LEGEND_POSITION.BOTTOM
        chart.legend.font.size = Pt(LEGEND_FONT_SIZE)
        chart.legend.include_in_layout = False

    # Axis tick label font sizes.
    category_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
    value_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)

    if min_y is not None:
        value_axis.minimum_scale = min_y
    if max_y is not None:
        value_axis.maximum_scale = max_y

    plot = chart.plots[0]
    plot.has_data_labels = True
    data_labels = plot.data_labels

    # Excel number formats for the supported label types.
    number_formats = {
        'per': '0"%"',
        '$': '$[>=1000000]#,##0.0,,"M";$[>=1000]#,##0.0,"K";$#,##0',
        '$1': '$[>=1000000]#,##0,,"M";$[>=1000]#,##0,"K";$#,##0',
        'M': '#0.0,,"M"',
        'M1': '#0.00,,"M"',
        'K': '#0.0,"K"',
    }
    if label_type in number_formats:
        data_labels.number_format = number_formats[label_type]

    data_labels.font.size = Pt(DATA_LABEL_FONT_SIZE)

    chart_placeholder._element.getparent().remove(chart_placeholder._element)


def line_chart(chart_placeholder, slide, chart_data, titles=None, min_y=None, max_y=None):
    """Replace *chart_placeholder* on *slide* with a line chart.

    The legend is shown at the bottom.  When the data has a second series,
    that series gets a fixed RGB(141, 47, 0) line colour.  The placeholder
    shape is removed from the slide afterwards.

    Args:
        chart_placeholder: placeholder shape whose geometry the chart takes over.
        slide: slide the chart is added to.
        chart_data: python-pptx chart data (e.g. ``CategoryChartData``).
        titles: optional dict with any of ``'chart_title'``, ``'x_axis'`` and
            ``'y_axis'`` text.
        min_y, max_y: optional value-axis bounds.  When both are given, the
            major unit is set to about half the range.
    """
    titles = titles or {}
    chart = slide.shapes.add_chart(
        XL_CHART_TYPE.LINE,
        chart_placeholder.left, chart_placeholder.top,
        chart_placeholder.width, chart_placeholder.height,
        chart_data,
    ).chart

    chart.has_legend = True
    chart.legend.position = XL_LEGEND_POSITION.BOTTOM
    chart.legend.font.size = Pt(LEGEND_FONT_SIZE)

    category_axis = chart.category_axis
    value_axis = chart.value_axis

    if min_y is not None:
        value_axis.minimum_scale = min_y
    if max_y is not None:
        value_axis.maximum_scale = max_y

    if min_y is not None and max_y is not None:
        # Roughly two gridline intervals.  Whole-number units are kept for wide
        # ranges.  A span below 2 would truncate to 0, which python-pptx rejects
        # as a major unit, so the fractional half-span is used instead.  A zero
        # span leaves the unit on automatic.
        span = max_y - min_y
        major_unit = int(span / 2)
        if major_unit <= 0:
            major_unit = span / 2
        if major_unit > 0:
            value_axis.major_unit = major_unit

    if 'chart_title' in titles:
        chart.has_title = True
        chart.chart_title.text_frame.text = titles['chart_title']
        chart.chart_title.text_frame.paragraphs[0].runs[0].font.size = Pt(CHART_TITLE_FONT_SIZE)

    for title_key, axis in (('x_axis', category_axis), ('y_axis', value_axis)):
        if title_key in titles:
            axis.has_title = True
            axis.axis_title.text_frame.text = titles[title_key]
            axis.axis_title.text_frame.paragraphs[0].runs[0].font.size = Pt(AXIS_TITLE_FONT_SIZE)

    # Axis tick label font sizes.
    category_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)
    value_axis.tick_labels.font.size = Pt(AXIS_LABEL_FONT_SIZE)

    plot = chart.plots[0]
    # Only recolour the second series when the data actually has one.
    if len(plot.series) > 1:
        plot.series[1].format.line.color.rgb = RGBColor(141, 47, 0)

    chart_placeholder._element.getparent().remove(chart_placeholder._element)


def pie_chart(chart_placeholder, slide, chart_data, title):
    """Replace *chart_placeholder* on *slide* with a titled pie chart.

    The legend sits on the right, outside the plot layout.  Data labels use
    the ``'0%'`` number format and are placed at the outside end of each
    slice.  The placeholder shape is removed from the slide afterwards.
    """
    left, top = chart_placeholder.left, chart_placeholder.top
    width, height = chart_placeholder.width, chart_placeholder.height
    chart = slide.shapes.add_chart(XL_CHART_TYPE.PIE, left, top, width, height, chart_data).chart

    # Legend: right-hand side, not taking space from the plot.
    chart.has_legend = True
    legend = chart.legend
    legend.position = XL_LEGEND_POSITION.RIGHT
    legend.include_in_layout = False
    legend.font.size = Pt(PIE_LEGEND_FONT_SIZE)

    # Percentage labels outside the slices.
    plot = chart.plots[0]
    plot.has_data_labels = True
    labels = plot.data_labels
    labels.number_format = '0%'
    labels.position = XL_LABEL_POSITION.OUTSIDE_END
    labels.font.size = Pt(DATA_LABEL_FONT_SIZE)

    chart.has_title = True
    title_frame = chart.chart_title.text_frame
    title_frame.text = title
    title_frame.paragraphs[0].runs[0].font.size = Pt(CHART_TITLE_FONT_SIZE)

    chart_placeholder._element.getparent().remove(chart_placeholder._element)


def title_and_table(slide, title, df, column_width=None, custom_table_height=False):
    """Fill a title-plus-table slide.

    The first placeholder of ``slide.placeholders`` receives *title*.  The
    second is replaced by a table built from *df* via
    ``fill_table_placeholder``.

    Args:
        slide: slide whose layout provides a title and a table placeholder.
        title: slide title text.
        df: DataFrame rendered as the table.
        column_width: optional ``{column_index: width_in_inches}`` overrides.
        custom_table_height: when true and *df* has fewer than 4 rows, the
            table gets half of the placeholder height.

    Returns:
        The same *slide*.
    """
    ph_idx = [ph.placeholder_format.idx for ph in slide.placeholders]

    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = title
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    table_placeholder = slide.placeholders[ph_idx[1]]
    table_height = None
    if custom_table_height and len(df) < 4:
        # Keep a short table from spanning the full placeholder height.
        table_height = int(np.ceil(table_placeholder.height / 2))

    fill_table_placeholder(table_placeholder, slide, df, column_width, table_height)
    return slide


def data_import(data, bin_dict):
    """Summarise the imported data set as a two-column ``Category`` / ``Value`` table.

    Rows:
      - the date range of ``data['date']``;
      - the comma-joined response metrics, media variables and spend
        variables from *bin_dict*;
      - the exogenous variables, when that bin exists and is non-empty.

    Args:
        data: DataFrame with a datetime ``'date'`` column.
        bin_dict: mapping with ``'Response Metrics'``, ``'Media'`` and
            ``'Spends'`` lists of column names and an optional
            ``'Exogenous'`` list.

    Returns:
        DataFrame with columns ``['Category', 'Value']``.
    """
    date_start = data['date'].min().date()
    date_end = data['date'].max().date()

    rows = [
        ('Date Range', str(date_start) + ' - ' + str(date_end)),
        ('Response Metrics', ', '.join(bin_dict['Response Metrics'])),
        ('Media Variables', ', '.join(bin_dict['Media'])),
        ('Spend Variables', ', '.join(bin_dict['Spends'])),
    ]
    # The exogenous bin is optional: skip it when missing, None or empty.
    exogenous = bin_dict.get('Exogenous')
    if exogenous:
        rows.append(('Exogenous Variables', ', '.join(exogenous)))

    return pd.DataFrame(rows, columns=['Category', 'Value'])


def channel_groups_df(channel_groups_dct=None, bin_dict=None):
    """Tabulate which media and spend variables belong to each channel.

    Args:
        channel_groups_dct: mapping ``channel -> list of variable names``.
        bin_dict: mapping with ``'Media'`` and ``'Spends'`` lists of variable
            names, used to split each channel's variables.

    Returns:
        DataFrame with columns ``['Channel', 'Media Variables',
        'Spend Variables']``, one row per channel.  Variable cells are
        comma-joined names, de-duplicated, in the order they appear in the
        channel's list.
    """
    columns = ['Channel', 'Media Variables', 'Spend Variables']
    channel_groups_dct = channel_groups_dct or {}
    bin_dict = bin_dict or {}
    if not channel_groups_dct:
        return pd.DataFrame(columns=columns)

    media_set = set(bin_dict["Media"])
    spend_set = set(bin_dict["Spends"])

    rows = []
    for channel, channel_vars in channel_groups_dct.items():
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # output is stable across runs (iterating a set of strings is not).
        unique_vars = list(dict.fromkeys(channel_vars))
        rows.append((
            channel,
            ", ".join(v for v in unique_vars if v in media_set),
            ", ".join(v for v in unique_vars if v in spend_set),
        ))
    return pd.DataFrame(rows, columns=columns)


def transformations(transform_dict):
    """List the transformation ranges configured for media and exogenous variables.

    For each of ``'Media'`` and ``'Exogenous'``, ``transform_dict[category]``
    may hold a ``'transformation_<category>'`` list of transformation names.
    For each listed name it holds a ``[low, high]`` pair.  Missing categories,
    or categories without that list, contribute no rows.

    Returns:
        DataFrame with columns ``['Category', 'Transformation', 'Value']``,
        where ``Value`` is ``"<low> - <high>"``.
    """
    rows = []
    for category in ['Media', 'Exogenous']:
        category_dict = transform_dict.get(category, {})
        # Key holding the ordered list of transformation names for this category.
        names_key = f'transformation_{category}'
        for name in category_dict.get(names_key, []):
            bounds = category_dict[name]
            rows.append((category, name, str(bounds[0]) + ' - ' + str(bounds[1])))
    return pd.DataFrame(rows, columns=['Category', 'Transformation', 'Value'])


def model_metrics(model_dict, is_panel):
    """Compute fit statistics for every saved model.

    Args:
        model_dict: mapping ``"<model>__<target>" -> model info``.  Each value
            holds ``'X_train_tuned'`` / ``'X_test_tuned'`` DataFrames (which
            include the *target* column), ``'feature_set'`` (feature column
            names) and a fitted ``'Model_object'``.
        is_panel: when true, predictions come from ``mdf_predict`` with the
            model's random effects; otherwise from ``Model_object.predict``.

    Returns:
        DataFrame with one row per model and columns ``Response Metric``,
        ``Model``, ``R2``, ``ADJR2``, ``Train MAPE`` and ``Test MAPE``, with
        numeric values rounded to 2 decimals.  The "MAPE" columns hold the
        ``smape`` score.
    """
    columns = ["Response Metric", "Model", "R2", "ADJR2", "Train MAPE", "Test MAPE"]
    rows = []
    for key, model_info in model_dict.items():
        key_parts = key.split("__")
        model_name, target = key_parts[0], key_parts[1]
        model = model_info["Model_object"]
        feature_set = model_info["feature_set"]
        X_train = model_info["X_train_tuned"]
        X_test = model_info["X_test_tuned"]

        y_train = X_train[target]
        y_test = X_test[target]

        if is_panel:
            # NOTE(review): get_random_effects, mdf_predict, media_data and
            # panel_col are not defined in this function and must come from
            # module scope elsewhere -- confirm they exist before using panels.
            random_df = get_random_effects(media_data, panel_col, model)
            pred_train = mdf_predict(X_train, model, random_df)["pred"]
            pred_test = mdf_predict(X_test, model, random_df)["pred"]
        else:
            pred_train = model.predict(X_train[feature_set])
            pred_test = model.predict(X_test[feature_set])

        r2 = r2_score(y_train, pred_train)
        n_obs, n_features = len(y_train), len(feature_set)
        # Derive adjusted R2 from the unrounded R2 so display rounding does not
        # compound into it.
        adj_r2 = 1 - (1 - r2) * (n_obs - 1) / (n_obs - n_features - 1)

        rows.append({
            "Response Metric": format_response_metric(target),
            "Model": model_name,
            "R2": np.round(r2, 2),
            "ADJR2": np.round(adj_r2, 2),
            "Train MAPE": np.round(smape(y_train, pred_train), 2),
            "Test MAPE": np.round(smape(y_test, pred_test), 2),
        })

    metrics_df = pd.DataFrame(rows, columns=columns)
    return metrics_df.round(2)


def model_result(slide, model_key, model_dict, model_metrics_df, date_col):
    """Fill a model-result slide: title, accuracy table and actual-vs-predicted chart.

    Args:
        slide: slide whose placeholders are, in order, the title, the metrics
            table and the chart.
        model_key: ``"<model>__<target>"`` key of the model.
        model_dict: info for this one model: ``'X_train_tuned'``,
            ``'X_test_tuned'``, ``'feature_set'`` and ``'Model_object'``.
        model_metrics_df: metrics table with the columns produced by
            ``model_metrics``.  Rows are selected on the ``Model`` column.
        date_col: name of the date column in the train/test frames.

    Returns:
        The same *slide*.
    """
    key_parts = model_key.split('__')
    model_name, target = key_parts[0], key_parts[1]
    ph_idx = [ph.placeholder_format.idx for ph in slide.placeholders]

    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = model_name
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    # Metrics table.  Accuracy = 1 - train SMAPE, shown as a percentage.  The
    # R2/ADJR2/MAPE columns were removed from the slide as requested by Ioannis.
    metrics_table_placeholder = slide.placeholders[ph_idx[1]]
    metrics_df = model_metrics_df[model_metrics_df['Model'] == model_name].reset_index(drop=True)
    accuracy = 100 * (1 - metrics_df['Train MAPE'])
    metrics_df['Accuracy'] = accuracy.apply(lambda x: f'{np.round(x, 0)}%')
    metrics_df = metrics_df.drop(columns=['R2', 'ADJR2', 'Train MAPE', 'Test MAPE'])
    fill_table_placeholder(metrics_table_placeholder, slide, metrics_df)

    # Actual vs predicted over the combined train + test period.
    chart_placeholder = slide.placeholders[ph_idx[2]]
    full_df = pd.concat([model_dict['X_train_tuned'], model_dict['X_test_tuned']])
    full_df['Predicted'] = model_dict['Model_object'].predict(full_df[model_dict['feature_set']])
    # Non-inplace rename yields an independent frame (an in-place rename on a
    # column selection triggers pandas' SettingWithCopyWarning).
    pred_df = full_df[[date_col, target, 'Predicted']].rename(columns={target: 'Actual'})

    chart_data = CategoryChartData()
    chart_data.categories = pred_df[date_col]
    chart_data.add_series('Actual', pred_df['Actual'])
    chart_data.add_series('Predicted', pred_df['Predicted'])

    # Y-axis range covering both series, widened to whole numbers.
    min_y = np.floor(min(pred_df['Actual'].min(), pred_df['Predicted'].min()))
    max_y = np.ceil(max(pred_df['Actual'].max(), pred_df['Predicted'].max()))

    line_chart(chart_placeholder=chart_placeholder,
               slide=slide,
               chart_data=chart_data,
               titles={'chart_title': 'Actual VS Predicted',
                       'x_axis': 'Date',
                       # Same label formatting as the metrics table's Response Metric.
                       'y_axis': format_response_metric(target),
                       },
               min_y=min_y,
               max_y=max_y)

    return slide


def metrics_contributions(slide, contributions_excels_dict, panel_col):
    """Fill a slide with a column chart of each channel's % contribution per response metric.

    Metrics are ordered by total contribution (largest first), except that a
    single metric named "revenue" (case-insensitive) is always placed last.
    Channels are ordered by their share of that last metric, with ``'base'``
    first.

    Args:
        slide: slide whose first placeholder is the title and second the chart.
        contributions_excels_dict: mapping ``target -> workbook``, where
            ``workbook['CONTRIBUTION MMM']`` is a DataFrame of contributions
            per channel column.  Optional ``'Date'`` and *panel_col* columns
            are ignored.  The input frames are not modified.
        panel_col: name of the panel column to ignore.

    Returns:
        The same *slide*.
    """
    all_contribution_df = pd.DataFrame(columns=['Channel'])
    target_sums = {}
    for target, workbook in contributions_excels_dict.items():
        contribution_df = workbook['CONTRIBUTION MMM'].copy()
        for ignored_col in ('Date', panel_col):
            if ignored_col in contribution_df.columns:
                contribution_df = contribution_df.drop(columns=[ignored_col])

        # Total contribution per channel, expressed as % of the metric total.
        contribution_df = pd.DataFrame(np.sum(contribution_df, axis=0)).reset_index()
        contribution_df.columns = ['Channel', target]
        target_sum = contribution_df[target].sum()
        target_sums[target] = target_sum
        contribution_df[target] = 100 * contribution_df[target] / target_sum

        all_contribution_df = pd.merge(all_contribution_df, contribution_df, on='Channel', how='outer')

    # Largest metrics first; a single "revenue" metric is moved to the end.
    ordered_targets = sorted(target_sums, key=target_sums.get, reverse=True)
    revenue_metrics = [metric for metric in ordered_targets if metric.lower() == 'revenue']
    if len(revenue_metrics) == 1:
        ordered_targets.remove(revenue_metrics[0])
        ordered_targets.append(revenue_metrics[0])
    all_contribution_df = all_contribution_df[['Channel'] + ordered_targets].copy()

    # Sort channels by contribution to the bottom-funnel (last) metric, base first.
    bottom_funnel_metric = ordered_targets[-1]
    all_contribution_df['rank'] = all_contribution_df[bottom_funnel_metric].rank(ascending=False)
    all_contribution_df.loc[all_contribution_df['Channel'] == 'base', 'rank'] = 0
    all_contribution_df = all_contribution_df.sort_values(by='rank').drop(columns=['rank'])

    ph_idx = [ph.placeholder_format.idx for ph in slide.placeholders]
    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = "Response Metrics Contributions"
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    for target in ordered_targets:
        all_contribution_df[target] = all_contribution_df[target].astype(float)

    chart_data = CategoryChartData()
    chart_data.categories = all_contribution_df['Channel']
    for target in ordered_targets:
        chart_data.add_series(format_response_metric(target), all_contribution_df[target])

    # Axis bounds over every contribution value.  nanmin/nanmax on the raw
    # array give one scalar on any pandas version (np.min on a DataFrame is
    # per-column on pandas < 2).  They also skip the NaNs left by the outer
    # merge for channels missing from some metric.
    values = all_contribution_df[ordered_targets].to_numpy(dtype=float)
    min_y = np.floor(np.nanmin(values))
    max_y = np.ceil(np.nanmax(values))

    bar_chart(chart_placeholder=slide.placeholders[ph_idx[1]],
              slide=slide,
              chart_data=chart_data,
              titles={'chart_title': 'Response Metrics Contributions',
                      'y_axis': 'Contributions'},
              min_y=min_y,
              max_y=max_y,
              type='V',
              label_type='per')

    return slide


def model_media_performance(slide, target, contributions_excels_dict, date_col='Date', is_panel=False,
                            panel_col='panel'):
    """Fill a media-performance slide with contribution and media-spend pie charts.

    Args:
        slide: slide whose placeholders are, in order, the title, the spend
            chart and the contribution chart.
        target: response metric name.
        contributions_excels_dict: mapping ``target -> workbook``, where
            ``workbook['CONTRIBUTION MMM']`` holds contributions per channel
            column.  The input frame is not modified.
        date_col: date column ignored (when present) before summing.
        is_panel: not used by this function.
        panel_col: panel column ignored (when present) before summing.

    Returns:
        ``(contribution_df, spends_df)``: each channel's share (fractions that
        sum to 1) of the total contribution and of the total media spend.

    Side effects:
        Calls ``initialize_data(panel='aggregated', metrics=target)`` and reads
        ``st.session_state['scenario']`` and ``st.session_state['channels_list']``.
    """
    ph_idx = [ph.placeholder_format.idx for ph in slide.placeholders]
    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = "Media Performance - " + target.title().replace("_", " ")
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    # --- Contribution pie -------------------------------------------------
    metric_label = format_response_metric(target)
    contribution_df = contributions_excels_dict[target]['CONTRIBUTION MMM']
    for ignored_col in (panel_col, date_col):
        # Non-inplace drop so the caller's workbook DataFrame stays untouched.
        if ignored_col in contribution_df.columns:
            contribution_df = contribution_df.drop(columns=[ignored_col])
    contribution_df = pd.DataFrame(np.sum(contribution_df, axis=0)).reset_index()
    contribution_df.columns = ['Channel', metric_label]
    contribution_df['Channel'] = contribution_df['Channel'].apply(lambda x: x.title())
    contribution_df[metric_label] = contribution_df[metric_label] / contribution_df[metric_label].sum()
    contribution_df.sort_values(by=['Channel'], ascending=False, inplace=True)

    contribution_chart_data = ChartData()
    contribution_chart_data.categories = contribution_df['Channel']
    contribution_chart_data.add_series('Contribution', contribution_df[metric_label])
    pie_chart(chart_placeholder=slide.placeholders[ph_idx[2]],
              slide=slide,
              chart_data=contribution_chart_data,
              title='Contribution')

    # --- Media spend pie --------------------------------------------------
    initialize_data(panel='aggregated', metrics=target)
    scenario = st.session_state["scenario"]
    spends_values = {
        channel_name: round(
            scenario.channels[channel_name].actual_total_spends
            * scenario.channels[channel_name].conversion_rate,
            1,
        )
        for channel_name in st.session_state["channels_list"]
    }
    spends_df = pd.DataFrame({
        'Channel': list(spends_values.keys()),
        'Media Spend': list(spends_values.values()),
    })
    spends_df['Media Spend'] = spends_df['Media Spend'] / spends_df['Media Spend'].sum()
    spends_df['Channel'] = spends_df['Channel'].apply(lambda x: x.title())
    spends_df.sort_values(by='Channel', ascending=False, inplace=True)

    spends_chart_data = ChartData()
    spends_chart_data.categories = spends_df['Channel']
    spends_chart_data.add_series('Media Spend', spends_df['Media Spend'])
    pie_chart(chart_placeholder=slide.placeholders[ph_idx[1]],
              slide=slide,
              chart_data=spends_chart_data,
              title='Media Spend')

    return contribution_df, spends_df


# def get_saved_scenarios_dict(project_path):
#     # Path to the saved scenarios file
#     saved_scenarios_dict_path = os.path.join(
#         project_path, "saved_scenarios.pkl"
#     )
#
#     # Load existing scenarios if the file exists
#     if os.path.exists(saved_scenarios_dict_path):
#         with open(saved_scenarios_dict_path, "rb") as f:
#             saved_scenarios_dict = pickle.load(f)
#     else:
#         saved_scenarios_dict = OrderedDict()
#
#     return saved_scenarios_dict

def _optimization_summary_table(scenario, multiplier):
    """Build the Actual / Simulated / Change table for spend, metric, CPA and ROI.

    Spend and metric totals are scaled by *multiplier*.  CPA and ROI are
    ratios and are not scaled.  Numeric cells go through ``round_off`` (1
    decimal) and ``convert_number_to_abbreviation`` for display.
    """
    actual_spends = scenario['actual_total_spends']
    modified_spends = scenario['modified_total_spends']
    actual_sales = scenario['actual_total_sales']
    modified_sales = scenario['modified_total_sales']

    actual_cpa = actual_spends / actual_sales
    modified_cpa = modified_spends / modified_sales
    actual_roi = actual_sales / actual_spends
    modified_roi = modified_sales / modified_spends

    rows = [
        ('Media Spend',
         actual_spends * multiplier,
         modified_spends * multiplier,
         (modified_spends - actual_spends) * multiplier),
        (scenario['metrics_selected'].replace("_", " ").title(),
         actual_sales * multiplier,
         modified_sales * multiplier,
         (modified_sales - actual_sales) * multiplier),
        ('CPA', actual_cpa, modified_cpa, modified_cpa - actual_cpa),
        ('ROI', actual_roi, modified_roi, modified_roi - actual_roi),
    ]
    table = pd.DataFrame(rows, columns=['Category', 'Actual', 'Simulated', 'Change'])
    # Only numeric columns are formatted; category labels are left verbatim.
    for col in ('Actual', 'Simulated', 'Change'):
        table[col] = table[col].apply(lambda x: convert_number_to_abbreviation(round_off(x, 1)))
    return table


def _optimization_channel_spends(scenario, multiplier):
    """Per-channel actual vs optimized spend, sorted by actual spend (largest first).

    Each value is ``total_spends * conversion_rate * multiplier``, rounded
    with ``round_off(x, 0)``.
    """
    rows = []
    for channel in scenario['channels'].values():
        conversion_rate = channel["conversion_rate"]
        rows.append((
            channel['name'],
            (channel["actual_total_spends"] * conversion_rate) * multiplier,
            (channel["modified_total_spends"] * conversion_rate) * multiplier,
        ))
    spends = pd.DataFrame(rows, columns=['Channel', 'Actual Spends', 'Optimized Spends'])
    # Round only the spend columns so channel names are never coerced to numbers.
    for col in ('Actual Spends', 'Optimized Spends'):
        spends[col] = spends[col].astype('float').apply(lambda x: round_off(x, 0))
    return spends.sort_values(by='Actual Spends', ascending=False)


def optimization_summary(slide, scenario, scenario_name):
    """Fill the optimization-summary slide for one saved scenario.

    Placeholders, in ``slide.placeholders`` order: title, summary table,
    channel-spend chart, scenario-details text.

    Args:
        slide: slide to fill.
        scenario: saved-scenario dict.  Reads ``multiplier``, ``optimization``,
            ``metrics_selected``, the actual/modified total spends and sales,
            and ``channels`` (per-channel dicts with ``name``,
            ``conversion_rate``, ``actual_total_spends`` and
            ``modified_total_spends``).
        scenario_name: name shown in the details box.

    Returns:
        The same *slide*.
    """
    ph_idx = [ph.placeholder_format.idx for ph in slide.placeholders]
    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = 'Optimization Summary'
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    # The multiplier reflects the selected time frame.  Totals are divided by
    # it; ratios (CPA, ROI) are unaffected.
    multiplier = 1 / float(scenario['multiplier'])

    opt_on = scenario['optimization']
    if opt_on.lower() == 'spends':
        opt_on = 'Media Spend'

    details_ph = slide.placeholders[ph_idx[3]]
    details_ph.text = (
        'Scenario Name: ' + scenario_name
        + '\nResponse Metric: ' + str(scenario['metrics_selected']).replace("_", " ").title()
        + '\nOptimized on: ' + str(opt_on).replace("_", " ").title()
    )

    scenario_df = _optimization_summary_table(scenario, multiplier)
    fill_table_placeholder(slide.placeholders[ph_idx[1]], slide, scenario_df)

    channel_spends_df = _optimization_channel_spends(scenario, multiplier)
    chart_data = CategoryChartData()
    chart_data.categories = channel_spends_df['Channel']
    for col in ['Actual Spends', 'Optimized Spends']:
        chart_data.add_series(col, channel_spends_df[col])

    # Upper bound over BOTH spend columns.  np.nanmax on the raw values gives
    # a scalar on every pandas version and cannot clip the optimized bars.
    spend_values = channel_spends_df[['Actual Spends', 'Optimized Spends']].to_numpy(dtype=float)
    max_y = np.ceil(np.nanmax(spend_values))

    bar_chart(chart_placeholder=slide.placeholders[ph_idx[2]],
              slide=slide,
              chart_data=chart_data,
              titles={'chart_title': 'Channel Wise Spends',
                      'y_axis': 'Spends'},
              min_y=0,
              max_y=max_y,
              label_type='$')

    return slide


def channel_wise_spends(slide, scenario):
    """Fill the 'Channel Spends and Impact' slide for one saved scenario.

    Placeholders are addressed by their position in ``slide.placeholders``:
    0 = title, 1 = spend split % chart, 2 = incremental impact chart,
    3 = spend delta % chart.

    Args:
        slide: python-pptx slide created from the channel-spends template layout.
        scenario: saved scenario dict. Reads ``'multiplier'``, ``'metrics_selected'``
            and ``'channels'`` (a mapping whose values carry ``'name'``,
            ``'conversion_rate'``, ``'actual_total_spends'``,
            ``'modified_total_spends'``, ``'actual_total_sales'`` and
            ``'modified_total_sales'``).
    """
    placeholders = slide.placeholders
    ph_idx = [ph.placeholder_format.idx for ph in placeholders]
    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = 'Channel Spends and Impact'
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    channels = list(scenario['channels'].values())
    multiplier = 1 / float(scenario['multiplier'])

    # Spends are converted with each channel's conversion rate and scaled by 1 / multiplier.
    channel_spends_df = pd.DataFrame(
        [
            {
                'Channel': channel['name'],
                'Actual Spends': channel['actual_total_spends'] * channel['conversion_rate'] * multiplier,
                'Optimized Spends': channel['modified_total_spends'] * channel['conversion_rate'] * multiplier,
            }
            for channel in channels
        ],
        columns=['Channel', 'Actual Spends', 'Optimized Spends'],
    )
    spend_cols = ['Actual Spends', 'Optimized Spends']
    channel_spends_df[spend_cols] = channel_spends_df[spend_cols].astype('float')

    # Shares are taken against the unrounded totals.
    actual_sum = channel_spends_df['Actual Spends'].sum()
    opt_sum = channel_spends_df['Optimized Spends'].sum()

    for col in channel_spends_df.columns:
        channel_spends_df[col] = channel_spends_df[col].apply(lambda x: round_off(x, 0))

    pct_cols = ['Actual Spends %', 'Optimized Spends %']
    channel_spends_df['Actual Spends %'] = np.round(100 * (channel_spends_df['Actual Spends'] / actual_sum))
    channel_spends_df['Optimized Spends %'] = np.round(100 * (channel_spends_df['Optimized Spends'] / opt_sum))

    # All three charts share this channel order.
    channel_spends_df.sort_values(by='Actual Spends %', inplace=True)

    # --- Spend split % (actual vs optimized share of total spend) ---
    split_chart_data = CategoryChartData()
    split_chart_data.categories = channel_spends_df['Channel']
    for col in pct_cols:
        split_chart_data.add_series(col, channel_spends_df[col])

    # Axis maximum across BOTH series. Reduce over the raw values: np.max(DataFrame)
    # returns a per-column Series on pandas < 2 and a scalar on pandas >= 2.
    pct_max = np.nanmax(channel_spends_df[pct_cols].to_numpy(dtype=float))
    bar_chart(chart_placeholder=slide.placeholders[ph_idx[1]],
              slide=slide,
              chart_data=split_chart_data,
              titles={'chart_title': 'Spend Split %',
                      'y_axis': 'Spend %'},
              min_y=0,
              max_y=np.ceil(pct_max),
              type='H',
              legend=True,
              label_type='per',
              xaxis_pos='low'
              )

    # --- Spend delta % (relative change from actual to optimized spend) ---
    channel_spends_df['Delta %'] = 100 * (channel_spends_df['Optimized Spends'] - channel_spends_df['Actual Spends']) / \
                                   channel_spends_df['Actual Spends']
    channel_spends_df['Delta %'] = channel_spends_df['Delta %'].apply(lambda x: round_off(x, 0))

    delta_chart_data = CategoryChartData()
    delta_chart_data.categories = channel_spends_df['Channel']
    delta_chart_data.add_series('Delta %', channel_spends_df['Delta %'])

    # Cast so min/max are plain floats whatever dtype round_off produced.
    delta_values = channel_spends_df['Delta %'].astype(float)
    bar_chart(chart_placeholder=slide.placeholders[ph_idx[3]],
              slide=slide,
              chart_data=delta_chart_data,
              titles={'chart_title': 'Spend Delta %',
                      'y_axis': 'Spend Delta %'},
              min_y=np.floor(delta_values.min()),
              max_y=np.ceil(delta_values.max()),
              type='H',
              legend=False,
              label_type='per',
              xaxis_pos='low'
              )

    # --- Incremental impact (optimized - actual response per channel) ---
    increments = []
    for channel in channels:
        impact = round_off(channel['modified_total_sales'] - channel['actual_total_sales'], 0)
        # Treat numerical noise as "no change".
        increments.append({'Channel': channel['name'],
                           'Increment': impact if abs(impact) > 0.0001 else 0})
    channel_inc_df = pd.DataFrame(increments, columns=['Channel', 'Increment'])

    # Merge so the impact chart follows the same channel order as the spend charts.
    channel_inc_df_1 = pd.merge(channel_spends_df, channel_inc_df, how='left', on='Channel')

    impact_chart_data = CategoryChartData()
    impact_chart_data.categories = channel_inc_df_1['Channel']
    impact_chart_data.add_series('Increment', channel_inc_df_1['Increment'])

    # Choose a number-abbreviation style for data labels from the magnitude range;
    # None means "no abbreviated labels".
    abs_increment = channel_inc_df_1['Increment'].astype(float).abs()
    smallest, largest = abs_increment.min(), abs_increment.max()
    if smallest > 100000:  # 0.1M
        label_type = 'M'
    elif smallest > 10000 and largest > 1000000:
        label_type = 'M1'
    elif smallest > 100 and largest > 1000:
        label_type = 'K'
    else:
        label_type = None

    impact_kwargs = {} if label_type is None else {'label_type': label_type}
    bar_chart(chart_placeholder=slide.placeholders[ph_idx[2]],
              slide=slide,
              chart_data=impact_chart_data,
              titles={'chart_title': 'Incremental Impact',
                      'y_axis': format_response_metric(scenario['metrics_selected'])},
              type='H',
              legend=False,
              xaxis_pos='low',
              **impact_kwargs
              )


def channel_wise_roi(slide, scenario):
    """Fill the 'Channel ROIs' slide with actual vs optimized ROI per channel.

    Placeholders by position in ``slide.placeholders``: 0 = title, 1 = ROI bar chart.

    Args:
        slide: python-pptx slide created from the channel-ROI template layout.
        scenario: saved scenario dict; reads ``scenario['channel_roi_mroi']``, a mapping
            of channel name -> dict with ``'actual_roi'`` and ``'optimized_roi'``.
    """
    channel_roi_mroi = scenario['channel_roi_mroi']

    # Add title
    placeholders = slide.placeholders
    ph_idx = [ph.placeholder_format.idx for ph in placeholders]
    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = 'Channel ROIs'
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    roi_cols = ['Actual ROI', 'Optimized ROI']
    channel_roi_df = pd.DataFrame(
        [
            {'Channel': channel,
             'Actual ROI': rois['actual_roi'],
             'Optimized ROI': rois['optimized_roi']}
            for channel, rois in channel_roi_mroi.items()
        ],
        columns=['Channel'] + roi_cols,
    )
    channel_roi_df[roi_cols] = channel_roi_df[roi_cols].astype('float')

    for col in channel_roi_df.columns:
        channel_roi_df[col] = channel_roi_df[col].apply(lambda x: round_off(x, 2))

    # Create chart data
    chart_data = CategoryChartData()
    chart_data.categories = channel_roi_df['Channel']
    for col in roi_cols:
        chart_data.add_series(col, channel_roi_df[col])

    # Axis maximum across BOTH ROI series. Reduce over the raw values: np.max(DataFrame)
    # returns a per-column Series on pandas < 2 (whose first element is only the Actual
    # ROI max) and a scalar on pandas >= 2.
    roi_max = np.nanmax(channel_roi_df[roi_cols].to_numpy(dtype=float))

    bar_chart(chart_placeholder=slide.placeholders[ph_idx[1]],
              slide=slide,
              chart_data=chart_data,
              titles={'chart_title': 'Channel Wise ROI',
                      'y_axis': 'ROI'},
              min_y=0,
              max_y=roi_max
              )

    # The mROI chart is intentionally not rendered (removed based on stakeholder feedback).


def effictiveness_efficiency(slide, final_data, bin_dct, scenario):
    """Fill the 'Effectiveness and Efficiency' slide.

    Effectiveness is the column-wise total of each response metric in ``final_data``.
    Efficiency is that total divided by ``scenario['modified_total_spends']``,
    rounded to one decimal.

    Placeholders by position in ``slide.placeholders``: 0 = title,
    1 = effectiveness chart, 2 = efficiency chart, 3 and 4 = definition texts.

    Args:
        slide: python-pptx slide created from the effectiveness/efficiency layout.
        final_data: DataFrame that contains the response-metric columns.
        bin_dct: category dict; ``bin_dct['Response Metrics']`` lists those columns.
        scenario: saved scenario dict; reads ``'modified_total_spends'``.
    """
    ph_idx = [placeholder.placeholder_format.idx for placeholder in slide.placeholders]
    title_ph = slide.placeholders[ph_idx[0]]
    title_ph.text = 'Effectiveness and Efficiency'
    title_ph.text_frame.paragraphs[0].font.size = Pt(TITLE_FONT_SIZE)

    metric_totals = final_data[bin_dct['Response Metrics']].sum(axis=0)
    kpi_df = metric_totals.reset_index()
    kpi_df.columns = ['Response Metric', 'Effectiveness']
    total_spends = scenario['modified_total_spends']
    kpi_df['Efficiency'] = (kpi_df['Effectiveness'] / total_spends).apply(lambda value: round_off(value, 1))
    kpi_df.sort_values(by='Effectiveness', inplace=True)
    kpi_df['Response Metric'] = kpi_df['Response Metric'].apply(format_response_metric)

    # (series / title, placeholder position, extra bar_chart options)
    chart_specs = (
        ('Effectiveness', 1, {'label_type': 'M'}),
        ('Efficiency', 2, {}),
    )
    for series_name, position, extra_options in chart_specs:
        series_chart = CategoryChartData()
        series_chart.categories = kpi_df['Response Metric']
        series_chart.add_series(series_name, kpi_df[series_name])
        bar_chart(chart_placeholder=slide.placeholders[ph_idx[position]],
                  slide=slide,
                  chart_data=series_chart,
                  titles={'chart_title': series_name},
                  min_y=0,
                  type='H',
                  **extra_options)

    definition_texts = (
        (3, 'Effectiveness is measured as the total sum of the Response Metric'),
        (4, 'Efficiency is measured as the ratio of sum of the Response Metric and sum of Media Spend'),
    )
    for position, text in definition_texts:
        slide.placeholders[ph_idx[position]].text = text


def load_pickle(path):
    """Deserialize and return the first object stored in the pickle file at ``path``.

    Unpickling can execute arbitrary code, so only pass trusted files.
    """
    with open(path, mode="rb") as handle:
        return pickle.load(handle)


def read_all_files():
    """Collect the project artefacts needed to build the summary deck.

    Returns:
        An empty list when no imputed data has been imported yet. Otherwise a list
        with fixed positions, so callers can index it safely:

            0: imputed tool dataframe (copy)
            1: category (bin) dictionary (copy)
            2: channel group dictionary, or None
            3: transformations dict, or None
            4: tuned model dict, or None
            5: {metric: {'CONTRIBUTION MMM': contribution_df}}, or None
            6: saved scenarios dict, or None (only looked up when model outputs exist)
    """
    project_dct = st.session_state["project_dct"]
    data_import = project_dct["data_import"]

    if data_import["imputed_tool_df"] is None:
        return []

    final_df_loaded = data_import["imputed_tool_df"].copy()
    bin_dict_loaded = data_import["category_dict"].copy()
    channels = data_import.get("group_dict")

    transform_dict = None
    tuned_model_dict = None
    transformations = project_dct["transformations"]
    if transformations["final_df"] is not None:
        transform_dict = transformations
        # Single DB round trip; None when no tuned model has been saved yet.
        tuned_model_dict = retrieve_pkl_object_without_warning(
            st.session_state['project_number'], "Model_Tuning", "tuned_model", schema
        )

    contributions_excels_dict = None
    saved_scenarios = None
    model_outputs = project_dct["current_media_performance"]["model_outputs"]
    if model_outputs:  # model outputs exist for at least one metric
        contributions_excels_dict = {
            metric: {'CONTRIBUTION MMM': outputs["contribution_data"]}
            for metric, outputs in model_outputs.items()
        }
        saved_scenarios_dict = project_dct["saved_scenarios"]["saved_scenarios_dict"]
        if saved_scenarios_dict:
            saved_scenarios = saved_scenarios_dict

    return [
        final_df_loaded,
        bin_dict_loaded,
        channels,
        transform_dict,
        tuned_model_dict,
        contributions_excels_dict,
        saved_scenarios,
    ]



'''

    Template Layout (slide index in ppt/template.pptx -> purpose)

    0 : Title
    1 : Model Details section divider {no changes required}
    2 : Data Import
    3 : Data Import - Channel Groups
    4 : Model Results {duplicated for each model}
    5 : Metrics Contribution
    6 : Media Performance {duplicated for each target metric}
    7 : Media Performance Tabular View {duplicated for each target metric}
    8 : Optimization Details section divider {no changes required}
    9 : Optimization Summary {duplicated for each saved scenario}
    10 : Channel Spends {duplicated for each saved scenario}
    11 : Channel Wise ROI {duplicated for each saved scenario}
    12 : Effectiveness & Efficiency {duplicated for each saved scenario}
    13 : Appendix section divider
    14 : Transformations (layout currently not used)
    15 : Model Summary
    16 : Thank You Slide (taken from the last template slide)

'''


def _add_section_slide(prs, slide_layout, title, section_text):
    """Append a section-divider slide and set its title (placeholder 0) and section label (placeholder 1)."""
    section_slide = prs.slides.add_slide(slide_layout)
    ph_idx = [ph.placeholder_format.idx for ph in section_slide.placeholders]
    section_slide.placeholders[ph_idx[0]].text = title
    section_slide.placeholders[ph_idx[1]].text = section_text
    return section_slide


def _remove_template_slides(prs, num_template_slides):
    """Delete the first ``num_template_slides`` slides (the template's own sample slides)."""
    xml_slides = prs.slides._sldIdLst
    for sld_id in list(xml_slides)[:num_template_slides]:
        # Drop the presentation -> slide relationship as well. Removing only the
        # <p:sldId> entry keeps the slide part referenced, so it is still saved
        # into the file as a hidden orphan.
        prs.part.drop_rel(sld_id.rId)
        xml_slides.remove(sld_id)


def create_ppt(project_name, username, panel_col):
    """Build the project summary deck from the artefacts stored in the session.

    Args:
        project_name: shown on the title slide.
        username: shown on the title slide as the author.
        panel_col: panel column name, passed to the metrics-contribution slide.

    Returns:
        A ``BytesIO`` holding the saved .pptx, rewound to position 0. Returns
        ``False`` when no imported data is available.
    """
    files = read_all_files()
    if not files:
        return False

    # read_all_files() returns a positional list:
    # [data, bin_dict, channel_groups, transformations, tuned_model, contributions, saved_scenarios].
    # Missing trailing entries (pipeline stages not run yet) are treated as None.
    expected_slots = 7
    files = list(files) + [None] * max(0, expected_slots - len(files))
    (data, bin_dict, channel_groups_dct, _transform_dict, tuned_model_dict,
     contributions_excels_dict, saved_scenarios_loaded) = files[:expected_slots]

    template_path = 'ppt/template.pptx'
    prs = Presentation(template_path)
    # Template slides only supply layouts and are removed before saving; capture the
    # layouts up front (new slides are always appended after them).
    template_layouts = [template_slide.slide_layout for template_slide in prs.slides]
    num_slides = len(template_layouts)

    # Title slide: title & project name
    title_slide = prs.slides.add_slide(template_layouts[0])
    ph_idx = [ph.placeholder_format.idx for ph in title_slide.placeholders]
    title_slide.placeholders[ph_idx[0]].text = 'Media Mix Optimization Summary'
    title_slide.placeholders[ph_idx[1]].text = 'Project Name: ' + project_name + '\nCreated By: ' + username

    # Model Details section
    _add_section_slide(prs, template_layouts[1], 'Model Details', 'Section 1')

    # Data Import
    data_import_slide = prs.slides.add_slide(template_layouts[2])
    title_and_table(slide=data_import_slide,
                    title='Data Import',
                    df=data_import(data, bin_dict),
                    column_width={0: 2, 1: 7}
                    )

    # Channel Groups (only when channel groups have been defined)
    if channel_groups_dct is not None:
        channel_group_slide = prs.slides.add_slide(template_layouts[3])
        title_and_table(slide=channel_group_slide,
                        title='Channels - Media and Spend',
                        df=channel_groups_df(channel_groups_dct, bin_dict),
                        column_width={0: 2, 1: 5, 2: 2}
                        )

    if tuned_model_dict is not None:
        model_metrics_df = model_metrics(tuned_model_dict, False)

        # Model Results: one slide per model
        for model_key, model_dict in tuned_model_dict.items():
            model_result_slide = prs.slides.add_slide(template_layouts[4])
            model_result(slide=model_result_slide,
                         model_key=model_key,
                         model_dict=model_dict,
                         model_metrics_df=model_metrics_df,
                         date_col='date')

    if contributions_excels_dict is not None:
        # Metrics Contributions
        metrics_contributions_slide = prs.slides.add_slide(template_layouts[5])
        metrics_contributions(slide=metrics_contributions_slide,
                              contributions_excels_dict=contributions_excels_dict,
                              panel_col=panel_col
                              )

        # Media Performance: chart + tabular view per target metric
        for target in contributions_excels_dict:
            model_media_perf_slide = prs.slides.add_slide(template_layouts[6])
            contribution_df, spends_df = model_media_performance(slide=model_media_perf_slide,
                                                                 target=target,
                                                                 contributions_excels_dict=contributions_excels_dict
                                                                 )

            contri_spends_df = pd.merge(spends_df, contribution_df, on='Channel', how='outer').fillna(0)
            for col in contri_spends_df.columns:
                if col != 'Channel':
                    contri_spends_df[col] = contri_spends_df[col].apply(safe_num_to_per)

            media_performance_table_slide = prs.slides.add_slide(template_layouts[7])
            title_and_table(slide=media_performance_table_slide,
                            title='Media and Spends Channels Tabular View',
                            df=contri_spends_df,
                            )

    if saved_scenarios_loaded is not None:
        # Optimization Details section
        _add_section_slide(prs, template_layouts[8], 'Optimizations Details', 'Section 2')

        # One group of four slides per saved scenario
        for scenario_name, scenario in saved_scenarios_loaded.items():
            opt_summary_slide = prs.slides.add_slide(template_layouts[9])
            optimization_summary(opt_summary_slide, scenario, scenario_name)

            channel_spends_slide = prs.slides.add_slide(template_layouts[10])
            channel_wise_spends(channel_spends_slide, scenario)

            channel_roi_slide = prs.slides.add_slide(template_layouts[11])
            channel_wise_roi(channel_roi_slide, scenario)

            effictiveness_efficiency_slide = prs.slides.add_slide(template_layouts[12])
            effictiveness_efficiency(effictiveness_efficiency_slide,
                                     data,
                                     bin_dict,
                                     scenario)

    # Appendix section: only emitted when there is model content to put in it.
    if tuned_model_dict is not None:
        _add_section_slide(prs, template_layouts[13], 'Appendix', 'Section 3')

        # The Transformations slide (template layout 14) is currently not generated.

        # Model Summary
        model_summary_df = model_metrics(tuned_model_dict, False)
        model_summary_slide = prs.slides.add_slide(template_layouts[15])
        title_and_table(slide=model_summary_slide,
                        title='Model Summary',
                        df=model_summary_df,
                        custom_table_height=True
                        )

    # Last Slide, built from the template's final layout
    last_slide = prs.slides.add_slide(template_layouts[num_slides - 1])
    ph_idx = [ph.placeholder_format.idx for ph in last_slide.placeholders]
    last_slide.placeholders[ph_idx[0]].text = 'Thank You'

    _remove_template_slides(prs, num_slides)

    # Save into memory and rewind so callers can read() the stream directly.
    binary_output = BytesIO()
    prs.save(binary_output)
    binary_output.seek(0)

    return binary_output