import streamlit as st
import pandas as pd
import json
from scenario import Channel, Scenario
import numpy as np
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from scenario import class_to_dict
from collections import OrderedDict
import io
import plotly
from pathlib import Path
import pickle
import yaml
from yaml import SafeLoader
from streamlit.components.v1 import html
import smtplib
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from scenario import class_from_dict, class_convert_to_dict
import os
import base64
import sqlite3
import datetime
from scenario import numerize
import psycopg2

import re
import bcrypt
import glob

# # schema = db_cred["schema"]

color_palette = [
    "#F3F3F0",
    "#5E7D7E",
    "#2FA1FF",
    "#00EDED",
    "#00EAE4",
    "#304550",
    "#EDEBEB",
    "#7FBEFD",
    "#003059",
    "#A2F3F3",
    "#E1D6E2",
    "#B6B6B6",
]


CURRENCY_INDICATOR = "$"
db_cred = None
# database_file = r"DB/User.db"

# conn = sqlite3.connect(database_file, check_same_thread=False)  # connection with sql db
# c = conn.cursor()


# def query_excecuter_postgres(
#     query,
#     db_cred,
#     params=None,
#     insert=True,
#     insert_retrieve=False,
# ):
#     """
#     Executes a SQL query on a PostgreSQL database, handling both insert and select operations.

#     Parameters:
#     query (str): The SQL query to be executed.
#     params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
#     insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
#     insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.

#     """
#     # Database connection parameters
#     dbname = db_cred["dbname"]
#     user = db_cred["user"]
#     password = db_cred["password"]
#     host = db_cred["host"]
#     port = db_cred["port"]

#     try:
#         # Establish connection to the PostgreSQL database
#         conn = psycopg2.connect(
#             dbname=dbname, user=user, password=password, host=host, port=port
#         )
#     except psycopg2.Error as e:
#         st.warning(f"Unable to connect to the database: {e}")
#         st.stop()

#     # Create a cursor object to interact with the database
#     c = conn.cursor()

#     try:
#         # Execute the query with or without parameters
#         if params:
#             c.execute(query, params)
#         else:
#             c.execute(query)

#         if not insert:
#             # If not an insert operation, fetch and return the results
#             results = c.fetchall()
#             return results
#         elif insert_retrieve:
#             # If insert and retrieve operation, fetch and return the results
#             conn.commit()
#             return c.fetchall()
#         else:
#             conn.commit()

#     except Exception as e:
#         st.write(f"Error executing query: {e}")
#     finally:
#         conn.close()


db_path = os.path.join("imp_db.db")


def query_excecuter_postgres(
    query, db_path=None, params=None, insert=True, insert_retrieve=False, db_cred=None
):
    """
    Executes a SQL query on a SQLite database, handling both insert and select operations.

    Parameters:
    query (str): The SQL query to be executed.
    db_path (str): Path to the SQLite database file.
    params (tuple, optional): Parameters to pass into the SQL query for parameterized execution.
    insert (bool, default=True): Flag to determine if the query is an insert operation (default) or a select operation.
    insert_retrieve (bool, default=False): Flag to determine if the query should insert and then return the inserted ID.

    """
    try:
        # Construct a cross-platform path to the database
        db_dir = os.path.join("db")
        os.makedirs(db_dir, exist_ok=True)  # Make sure the directory exists
        db_path = os.path.join(db_dir, "imp_db.db")

        # Establish connection to the SQLite database
        conn = sqlite3.connect(db_path)
    except sqlite3.Error as e:
        st.warning(f"Unable to connect to the SQLite database: {e}")
        st.stop()

    # Create a cursor object to interact with the database
    c = conn.cursor()

    try:
        # Prepare the query with proper placeholders: expand a single `IN (?)`
        # clause to match the number of parameters
        if params:
            query = query.replace("IN (?)", f"IN ({','.join(['?' for _ in params])})")
            c.execute(query, params)
        else:
            c.execute(query)

        if not insert:
            # If not an insert operation, fetch and return the results
            results = c.fetchall()
            return results
        elif insert_retrieve:
            # If insert and retrieve operation, commit and return the last inserted row ID
            conn.commit()
            return c.lastrowid
        else:
            # For standard insert operations, commit the transaction
            conn.commit()

    except Exception as e:
        st.write(f"Error executing query: {e}")
    finally:
        conn.close()
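
# Illustrative usage of the helper above (a sketch; the parameter values are
# hypothetical, but mmo_users is one of the tables queried in this module):
#
#   rows = query_excecuter_postgres(
#       "SELECT emp_id, emp_nam FROM mmo_users WHERE emp_id = ?",
#       db_cred,
#       params=("emp001",),
#       insert=False,
#   )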


def update_summary_df():
    """
    Updates the 'project_summary_df' in the session state with the latest project
    summary information based on the most recent updates.

    This function executes a SQL query to retrieve project metadata from a database
    and stores the result in the session state.

    Uses:
    - query_excecuter_postgres(query, params=params, insert=False): A function that
      executes the provided SQL query on the project database.

    Modifies:
    - st.session_state['project_summary_df']: Updates the dataframe with columns:
      'Project Number', 'Project Name', 'Last Modified Page', 'Last Modified Time'.
    """

    query = f"""
            WITH LatestUpdates AS (
                SELECT
                    prj_id,
                    page_nam,
                    updt_dt_tm,
                    ROW_NUMBER() OVER (PARTITION BY prj_id ORDER BY updt_dt_tm DESC) AS rn
                FROM
                    mmo_project_meta_data
            )
            SELECT
                p.prj_id,
                p.prj_nam AS prj_nam,
                lu.page_nam,
                lu.updt_dt_tm
            FROM
                LatestUpdates lu
            RIGHT JOIN
                mmo_projects p ON lu.prj_id = p.prj_id
            WHERE
                p.prj_ownr_id = ? AND lu.rn = 1
            """

    params = (st.session_state["emp_id"],)  # Parameters for the SQL query

    # Execute the query and retrieve project summary data
    project_summary = query_excecuter_postgres(
        query, db_cred, params=params, insert=False
    )

    # Update the session state with the project summary dataframe
    st.session_state["project_summary_df"] = pd.DataFrame(
        project_summary,
        columns=[
            "Project Number",
            "Project Name",
            "Last Modified Page",
            "Last Modified Time",
        ],
    )

    st.session_state["project_summary_df"] = st.session_state[
        "project_summary_df"
    ].sort_values(by=["Last Modified Time"], ascending=False)

    return st.session_state["project_summary_df"]


from constants import default_dct


def ensure_project_dct_structure(session_state, default_dct):
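    """Recursively add any keys from default_dct that are missing in session_state (handles nested dicts)."""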
    for key, value in default_dct.items():
        if key not in session_state:
            session_state[key] = value
        elif isinstance(value, dict):
            ensure_project_dct_structure(session_state[key], value)


def project_selection():

    emp_id = st.text_input("employee id", key="emp1111").lower()
    password = st.text_input("Password", max_chars=15, type="password")

    if st.button("Login"):

        if "unique_ids" not in st.session_state:
            unique_users_query = f"""
                    SELECT DISTINCT emp_id, emp_nam, emp_typ from mmo_users;
                    """
            unique_users_result = query_excecuter_postgres(
                unique_users_query, db_cred, insert=False
            )  # retrieves all users who have access to the MMO tool
            st.session_state["unique_ids"] = {
                emp_id: (emp_nam, emp_type)
                for emp_id, emp_nam, emp_type in unique_users_result
            }

        if emp_id not in st.session_state["unique_ids"].keys() or len(password) == 0:
            st.warning("invalid id or password!")
            st.stop()

        if not is_pswrd_flag_set(emp_id):
            st.warning("Reset password in home page to continue")
            st.stop()

        elif not verify_password(emp_id, password):
            st.warning("Invalid user name or password")
            st.stop()

        else:
            st.session_state["emp_id"] = emp_id
            st.session_state["username"] = st.session_state["unique_ids"][
                st.session_state["emp_id"]
            ][0]

        with st.spinner("Loading Saved Projects"):
            st.session_state["project_summary_df"] = update_summary_df()

            # st.write(st.session_state["project_name"][0])
        if len(st.session_state["project_summary_df"]) == 0:
            st.warning("No projects found please create a project in Home page")
            st.stop()

        else:

            try:
                st.session_state["project_name"] = (
                    st.session_state["project_summary_df"]
                    .loc[
                        st.session_state["project_summary_df"]["Project Number"]
                        == st.session_state["project_summary_df"].iloc[0, 0],
                        "Project Name",
                    ]
                    .values[0]
                )  # fetching project name from project number stored in summary df

                project_dct_query = f"""
                SELECT pkl_obj FROM mmo_project_meta_data WHERE prj_id = ? AND file_nam = ?;
                """
                # Execute the query and retrieve the result

                project_number = int(st.session_state["project_summary_df"].iloc[0, 0])

                st.session_state["project_number"] = project_number

                project_dct_retrieved = query_excecuter_postgres(
                    project_dct_query,
                    db_cred,
                    params=(project_number, "project_dct"),
                    insert=False,
                )
                # retrieves project dict (meta data)  stored in db

                st.session_state["project_dct"] = pickle.loads(
                    project_dct_retrieved[0][0]
                )  # converting bytes data to the original object using pickle
                ensure_project_dct_structure(
                    st.session_state["project_dct"], default_dct
                )

                st.success("Project Loded")
                st.rerun()

            except Exception as e:

                st.write(
                    "Failed to load project metadata from the database. Please create a new project!"
                )
                st.stop()


def update_db(prj_id, page_nam, file_nam, pkl_obj, resp_mtrc="", schema=""):
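    """Insert or update the pickled object for a given project and file in mmo_project_meta_data (resp_mtrc and schema are currently unused)."""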

    # Check if an entry already exists

    check_query = f"""
    SELECT 1 FROM mmo_project_meta_data
    WHERE prj_id = ? AND file_nam =?;
    """

    check_params = (prj_id, file_nam)
    result = query_excecuter_postgres(
        check_query, db_cred, params=check_params, insert=False
    )

    # If entry exists, perform an update
    if result is not None and result:

        update_query = f"""
        UPDATE mmo_project_meta_data
        SET file_nam = ?, pkl_obj = ?, page_nam = ?, updt_dt_tm = datetime('now')
        WHERE prj_id = ? AND file_nam = ?;
        """

        update_params = (file_nam, pkl_obj, page_nam, prj_id, file_nam)

        query_excecuter_postgres(
            update_query, db_cred, params=update_params, insert=True
        )

    # If entry does not exist, perform an insert
    else:

        insert_query = f"""
        INSERT INTO mmo_project_meta_data
        (prj_id, page_nam, file_nam, pkl_obj,crte_by_uid, crte_dt_tm, updt_dt_tm)
        VALUES (?, ?, ?, ?, ?, datetime('now'), datetime('now'));
        """

        insert_params = (
            prj_id,
            page_nam,
            file_nam,
            pkl_obj,
            st.session_state["emp_id"],
        )

        query_excecuter_postgres(
            insert_query, db_cred, params=insert_params, insert=True
        )

        # st.success(f"Inserted project meta data for project {prj_id}, page {page_nam}")


def retrieve_pkl_object(prj_id, page_nam, file_nam, schema=""):
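    """Fetch and unpickle the stored object for the given project, page, and file name; returns None if not found."""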

    query = f"""
    SELECT pkl_obj FROM mmo_project_meta_data
    WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
    """

    params = (prj_id, page_nam, file_nam)
    result = query_excecuter_postgres(
        query, db_cred=db_cred, params=params, insert=False
    )

    if result and result[0] and result[0][0]:
        pkl_obj = result[0][0]
        # Deserialize the pickle object
        return pickle.loads(pkl_obj)
    else:
        return None


def validate_text(input_text):
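    """Return (True, message) if input_text is 2-30 characters of letters, digits, or underscores, else (False, reason)."""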

    # Check the length of the text
    if len(input_text) < 2:
        return False, "Input should be at least 2 characters long."
    if len(input_text) > 30:
        return False, "Input should not exceed 30 characters."

    # Check if the text contains only allowed characters
    if not re.match(r"^[A-Za-z0-9_]+$", input_text):
        return (
            False,
            "Input contains invalid characters. Only letters, numbers and underscores are allowed.",
        )

    return True, "Input is valid."


def delete_entries(prj_id, page_names, db_cred=None, schema=None):
    """
    Deletes all entries from the project_meta_data table based on prj_id and a list of page names.

    Parameters:
    prj_id (int): The project ID.
    page_names (list): A list of page names.
    db_cred (dict, optional): Kept for compatibility with the PostgreSQL helper; unused for SQLite.
    schema (str, optional): The schema name (unused for SQLite).
    """
    # Create placeholders for each page name in the list
    placeholders = ", ".join(["?"] * len(page_names))
    query = f"""
    DELETE FROM mmo_project_meta_data
    WHERE prj_id = ? AND page_nam IN ({placeholders});
    """

    # Combine prj_id and page_names into one list of parameters
    params = (prj_id, *page_names)

    query_excecuter_postgres(query, db_cred, params=params, insert=True)


# st.success(f"Deleted entries for project {prj_id}, page {page_name}")
def store_hashed_password(
    user_id,
    plain_text_password,
):
    """
    Hashes a plain text password using bcrypt, converts it to a UTF-8 string, and stores it as text.

    Parameters:
    user_id: The ID of the user whose password is being set.
    plain_text_password (str): The plain text password to be hashed.
    """
    # Hash the plain text password
    hashed_password = bcrypt.hashpw(
        plain_text_password.encode("utf-8"), bcrypt.gensalt()
    )

    # Convert the byte string to a regular string for storage
    hashed_password_str = hashed_password.decode("utf-8")

    # SQL query to update the pswrd_key for the specified user_id
    query = f"""
    UPDATE mmo_users
    SET pswrd_key = ?
    WHERE emp_id = ?;
    """

    # Execute the query using the existing query_excecuter_postgres function
    query_excecuter_postgres(
        query=query, db_cred=db_cred, params=(hashed_password_str, user_id), insert=True
    )


def verify_password(user_id, plain_text_password):
    """
    Verifies the plain text password against the stored hashed password for the specified user_id.

    Parameters:
    user_id (int): The ID of the user whose password is being verified.
    plain_text_password (str): The plain text password to verify.

    Returns:
    bool: True if the password matches the stored hash, otherwise False.
    """
    # SQL query to retrieve the hashed password for the user_id
    query = f"""
    SELECT pswrd_key FROM mmo_users WHERE emp_id = ?;
    """

    # Execute the query using the existing query_excecuter_postgres function
    result = query_excecuter_postgres(
        query=query, db_cred=db_cred, params=(user_id,), insert=False
    )

    if result:

        stored_hashed_password_str = result[0][0]
        # Convert the stored string back to bytes
        stored_hashed_password = stored_hashed_password_str.encode("utf-8")

        if bcrypt.checkpw(plain_text_password.encode("utf-8"), stored_hashed_password):

            return True
        else:

            return False
    else:

        return False


def update_password_in_db(user_id, plain_text_password):
    """
    Hashes the plain text password and updates the `pswrd_key`
    column for the given `emp_id` in the `mmo_users` table.

    Parameters:
    user_id: The ID of the user whose password needs to be updated.
    plain_text_password (str): The plain text password to be hashed and stored.
    """
    # Hash the plain text password using bcrypt
    hashed_password = bcrypt.hashpw(
        plain_text_password.encode("utf-8"), bcrypt.gensalt()
    )

    # Convert the hashed password from bytes to a string for storage
    hashed_password_str = hashed_password.decode("utf-8")

    # SQL query to update the password in the database
    query = f"""
    UPDATE mmo_users
    SET pswrd_key = ?
    WHERE emp_id = ?
    """

    # Parameters for the query
    params = (hashed_password_str, user_id)

    # Execute the query using the query_excecuter_postgres function
    query_excecuter_postgres(query, db_cred, params=params, insert=True)


def is_pswrd_flag_set(user_id):
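    """Return True if the pswrd_flag column is set to 1 for the given user in mmo_users."""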
    query = f"""
    SELECT pswrd_flag 
    FROM mmo_users 
    WHERE emp_id = ?;
    """

    # Execute the query
    result = query_excecuter_postgres(query, db_cred, params=(user_id,), insert=False)

    # Return True if the flag is 1, otherwise return False
    if result and result[0][0] == 1:
        return True
    else:
        return False


def set_pswrd_flag(user_id):
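    """Set the pswrd_flag column to 1 for the given user in mmo_users."""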
    query = f"""
    UPDATE mmo_users
    SET pswrd_flag = 1
    WHERE emp_id = ?;
    """

    # Execute the update query
    query_excecuter_postgres(query, db_cred, params=(user_id,), insert=True)


def retrieve_pkl_object_without_warning(prj_id, page_nam, file_nam, schema):
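    """Fetch and unpickle the stored object like retrieve_pkl_object, but without emitting a warning when nothing is found."""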

    query = f"""
    SELECT pkl_obj FROM mmo_project_meta_data
    WHERE prj_id = ? AND page_nam = ? AND file_nam = ?;
    """

    params = (prj_id, page_nam, file_nam)
    result = query_excecuter_postgres(
        query, db_cred=db_cred, params=params, insert=False
    )

    if result and result[0] and result[0][0]:
        pkl_obj = result[0][0]
        # Deserialize the pickle object
        return pickle.loads(pkl_obj)
    else:
        # st.warning(
        #     "Pickle object not found for the given project ID, page name, and file name."
        # )
        return None




# def load_authenticator():
#     with open("config.yaml") as file:
#         config = yaml.load(file, Loader=SafeLoader)
#         st.session_state["config"] = config
#     authenticator = stauth.Authenticate(
#         credentials=config["credentials"],
#         cookie_name=config["cookie"]["name"],
#         key=config["cookie"]["key"],
#         cookie_expiry_days=config["cookie"]["expiry_days"],
#         preauthorized=config["preauthorized"],
#     )
#     st.session_state["authenticator"] = authenticator
#     return authenticator


# Authentication
# def authenticator():
#     for k, v in st.session_state.items():
#         if k not in ["logout", "login", "config"] and not k.startswith(
#             "FormSubmitter"
#         ):
#             st.session_state[k] = v
#     with open("config.yaml") as file:
#         config = yaml.load(file, Loader=SafeLoader)
#         st.session_state["config"] = config
#     authenticator = stauth.Authenticate(
#         config["credentials"],
#         config["cookie"]["name"],
#         config["cookie"]["key"],
#         config["cookie"]["expiry_days"],
#         config["preauthorized"],
#     )
#     st.session_state["authenticator"] = authenticator
#     name, authentication_status, username = authenticator.login(
#         "Login", "main"
#     )
#     auth_status = st.session_state.get("authentication_status")

#     if auth_status == True:
#         authenticator.logout("Logout", "main")
#         is_state_initiaized = st.session_state.get("initialized", False)

#         if not is_state_initiaized:

#             if "session_name" not in st.session_state:
#                 st.session_state["session_name"] = None

#     return name


# def authentication():
#     with open("config.yaml") as file:
#         config = yaml.load(file, Loader=SafeLoader)

#         authenticator = stauth.Authenticate(
#             config["credentials"],
#             config["cookie"]["name"],
#             config["cookie"]["key"],
#             config["cookie"]["expiry_days"],
#             config["preauthorized"],
#         )

#     name, authentication_status, username = authenticator.login(
#         "Login", "main"
#     )
#     return authenticator, name, authentication_status, username


def nav_page(page_name, timeout_secs=3):
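    """Navigate to another Streamlit page by injecting JavaScript that clicks its sidebar link, retrying for up to timeout_secs seconds."""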
    nav_script = """
        <script type="text/javascript">
            function attempt_nav_page(page_name, start_time, timeout_secs) {
                var links = window.parent.document.getElementsByTagName("a");
                for (var i = 0; i < links.length; i++) {
                    if (links[i].href.toLowerCase().endsWith("/" + page_name.toLowerCase())) {
                        links[i].click();
                        return;
                    }
                }
                var elapsed = new Date() - start_time;
                if (elapsed < timeout_secs * 1000) {
                    setTimeout(attempt_nav_page, 100, page_name, start_time, timeout_secs);
                } else {
                    alert("Unable to navigate to page '" + page_name + "' after " + timeout_secs + " second(s).");
                }
            }
            window.addEventListener("load", function() {
                attempt_nav_page("%s", new Date(), %d);
            });
        </script>
    """ % (
        page_name,
        timeout_secs,
    )
    html(nav_script)


# def load_local_css(file_name):
#     with open(file_name) as f:
#         st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)


# def set_header():
#     return st.markdown(f"""<div class='main-header'>
#                     <h1>MMM LiME</h1>
#                     <img src="https://assets-global.website-files.com/64c8fffb0e95cbc525815b79/64df84637f83a891c1473c51_Vector%20(Stroke).svg   ">
#             </div>""", unsafe_allow_html=True)

path = os.path.dirname(__file__)

with open(f"{path}/logo.png", "rb") as file_:
    contents = file_.read()

data_url = base64.b64encode(contents).decode("utf-8")


DATA_PATH = "./data"

IMAGES_PATH = "./data/images_224_224"


def load_local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


# def set_header():

#     return st.markdown(f"""<div class='main-header'>

#                     <h1>H & M Recommendations</h1>

#                     <img src="data:image;base64,{data_url}", alt="Logo">

#             </div>""", unsafe_allow_html=True)
path1 = os.path.dirname(__file__)

# file_1 = open(f"{path}/willbank.png", "rb")

# contents1 = file_1.read()

# data_url1 = base64.b64encode(contents1).decode("utf-8")

# file_1.close()


DATA_PATH1 = "./data"

IMAGES_PATH1 = "./data/images_224_224"


def set_header():
    return st.markdown(
        f"""<div class='main-header'>
                    <!-- <h1></h1> -->
                       <div>
                    <img class='blend-logo' src="data:image;base64,{data_url}" alt="Logo">
                       </div>
            </div>""",
        unsafe_allow_html=True,
    )


# def set_header():
#     logo_path = "./path/to/your/local/LIME_logo.png"  # Replace with the actual file path
#     text = "LiME"
#     return st.markdown(f"""<div class='main-header'>
#                     <img src="data:image/png;base64,{data_url}" alt="Logo" style="float: left; margin-right: 10px; width: 100px; height: auto;">
#                     <h1>{text}</h1>
#             </div>""", unsafe_allow_html=True)


def s_curve(x, K, b, a, x0):
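    """Generalized logistic (s-curve) response: K / (1 + b * exp(-a * (x - x0)))."""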
    return K / (1 + b * np.exp(-a * (x - x0)))


def panel_level(input_df, date_column="Date"):
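    """Aggregate the input dataframe to one row per date by summing all numeric columns."""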
    # Ensure 'Date' is set as the index
    if date_column not in input_df.index.names:
        input_df = input_df.set_index(date_column)

    # Select numeric columns only (excluding 'Date' since it's now the index)
    numeric_columns_df = input_df.select_dtypes(include="number")

    # Group by 'Date' (which is the index) and sum the numeric columns
    aggregated_df = numeric_columns_df.groupby(input_df.index).sum()

    # Reset the index to restore the date column
    aggregated_df = aggregated_df.reset_index()

    return aggregated_df


def fetch_actual_data(
    panel=None,
    target_file="Overview_data_test.xlsx",
    updated_rcs=None,
    metrics=None,
):
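    """Read raw, spend, and contribution sheets from the target Excel file, optionally filter or aggregate by panel, and return (actual_input_df, actual_output_df) of channel spends and contributions (updated_rcs and metrics are currently unused)."""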
    excel = pd.read_excel(Path(target_file), sheet_name=None)

    # Extract dataframes for raw data, spend input, and contribution MMM
    raw_df = excel["RAW DATA MMM"]
    spend_df = excel["SPEND INPUT"]
    contri_df = excel["CONTRIBUTION MMM"]

    # Check if the panel is not None
    if panel is not None and panel != "Aggregated":
        raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
        spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
        contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
    elif panel == "Aggregated":
        raw_df = panel_level(raw_df, date_column="Date")
        spend_df = panel_level(spend_df, date_column="Week")
        contri_df = panel_level(contri_df, date_column="Date")

    # Revenue_df = excel['Revenue']

    ## remove seasonalities, indices etc.
    unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]

    exclude_columns = [
        "Date",
        "Region",
        "Controls_Grammarly_Index_SeasonalAVG",
        "Controls_Quillbot_Index",
        "Daily_Positive_Outliers",
        "External_RemoteClass_Index",
        "Intervals ON 20190520-20190805 | 20200518-20200803 | 20210517-20210802",
        "Intervals ON 20190826-20191209 | 20200824-20201207 | 20210823-20211206",
        "Intervals ON 20201005-20201019",
        "Promotion_PercentOff",
        "Promotion_TimeBased",
        "Seasonality_Indicator_Chirstmas",
        "Seasonality_Indicator_NewYears_Days",
        "Seasonality_Indicator_Thanksgiving",
        "Trend 20200302 / 20200803",
    ] + unnamed_cols

    raw_df["Date"] = pd.to_datetime(raw_df["Date"])
    contri_df["Date"] = pd.to_datetime(contri_df["Date"])
    input_df = raw_df.sort_values(by="Date")
    output_df = contri_df.sort_values(by="Date")
    spend_df["Week"] = pd.to_datetime(
        spend_df["Week"], format="%Y-%m-%d", errors="coerce"
    )
    spend_df.sort_values(by="Week", inplace=True)

    # spend_df['Week'] = pd.to_datetime(spend_df['Week'], errors='coerce')
    # spend_df = spend_df.sort_values(by='Week')

    channel_list = [col for col in input_df.columns if col not in exclude_columns]
    channel_list = list(set(channel_list) - set(["fb_level_achieved_tier_1", "ga_app"]))

    infeasible_channels = [
        c
        for c in contri_df.select_dtypes(include=["float", "int"]).columns
        if contri_df[c].sum() <= 0
    ]
    # st.write(channel_list)
    channel_list = list(set(channel_list) - set(infeasible_channels))

    upper_limits = {}
    output_cols = []
    actual_output_dic = {}
    actual_input_dic = {}

    for inp_col in channel_list:
        # st.write(inp_col)
        spends = input_df[inp_col].values
        x = spends.copy()
        # upper limit for penalty
        upper_limits[inp_col] = 2 * x.max()

        # contribution
        # out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
        out_col = inp_col
        y = output_df[out_col].values.copy()
        actual_output_dic[inp_col] = y.copy()
        actual_input_dic[inp_col] = x.copy()
        ##output cols aggregation
        output_cols.append(out_col)

    return pd.DataFrame(actual_input_dic), pd.DataFrame(actual_output_dic)


# Function to initialize model results data
def initialize_data(panel=None, metrics=None):
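    """Fit an s-curve response curve per channel from the stored model outputs, build the default Scenario, populate session state, and return (response_curve_data, scenario)."""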
    # Extract dataframes for raw data, spend input, and contribution data
    raw_df = st.session_state["project_dct"]["current_media_performance"][
        "model_outputs"
    ][metrics]["raw_data"].copy()
    spend_df = st.session_state["project_dct"]["current_media_performance"][
        "model_outputs"
    ][metrics]["spends_data"].copy()
    contribution_df = st.session_state["project_dct"]["current_media_performance"][
        "model_outputs"
    ][metrics]["contribution_data"].copy()

    # Check if 'Panel' or 'panel' is in the columns
    panel_column = None
    if "Panel" in raw_df.columns:
        panel_column = "Panel"
    elif "panel" in raw_df.columns:
        panel_column = "panel"

    # Filter data by panel if provided
    if panel and panel.lower() != "aggregated":
        raw_df = raw_df[raw_df[panel_column] == panel].drop(columns=[panel_column])
        spend_df = spend_df[spend_df[panel_column] == panel].drop(
            columns=[panel_column]
        )
        contribution_df = contribution_df[contribution_df[panel_column] == panel].drop(
            columns=[panel_column]
        )
    else:
        raw_df = panel_level(raw_df, date_column="Date")
        spend_df = panel_level(spend_df, date_column="Date")
        contribution_df = panel_level(contribution_df, date_column="Date")

    # Remove unnecessary columns
    unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]
    exclude_columns = ["Date"] + unnamed_cols

    # Convert Date columns to datetime
    for df in [raw_df, spend_df, contribution_df]:
        df["Date"] = pd.to_datetime(df["Date"], format="%Y-%m-%d", errors="coerce")

    # Sort data by Date
    input_df = raw_df.sort_values(by="Date")
    contribution_df = contribution_df.sort_values(by="Date")
    spend_df.sort_values(by="Date", inplace=True)

    # Extract channels excluding unwanted columns
    channel_list = [col for col in input_df.columns if col not in exclude_columns]

    # Filter out channels with non-positive contributions
    negative_contributions = [
        col
        for col in contribution_df.select_dtypes(include=["float", "int"]).columns
        if contribution_df[col].sum() <= 0
    ]
    channel_list = list(set(channel_list) - set(negative_contributions))

    # Initialize dictionaries for metrics and response curves
    response_curves, mapes, rmses, upper_limits = {}, {}, {}, {}
    r2_scores, powers, conversion_rates, actual_output, actual_input = (
        {},
        {},
        {},
        {},
        {},
    )
    channels = {}
    sales = None
    dates = input_df["Date"].values

    # Fit s-curve for each channel
    for channel in channel_list:
        spends = input_df[channel].values
        x = spends.copy()
        upper_limits[channel] = 2 * x.max()

        # Get corresponding output column
        output_col = [
            _col for _col in contribution_df.columns if _col.startswith(channel)
        ][0]
        y = contribution_df[output_col].values.copy()
        actual_output[channel] = y.copy()
        actual_input[channel] = x.copy()

        # Scale input data
        power = np.ceil(np.log(x.max()) / np.log(10)) - 3
        if power >= 0:
            x = x / 10**power
        x, y = x.astype("float64"), y.astype("float64")

        # Set bounds for curve fitting
        if y.max() <= 0.01:
            bounds = (
                (0, 0, 0, 0),
                (3 * 0.01, 1000, 1, x.max() if x.max() > 0 else 0.01),
            )
        else:
            bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))

        # Set y to 0 where x is 0
        y[x == 0] = 0

        # Fit s-curve and calculate metrics
        # params, _ = curve_fit(
        #     s_curve,
        #     x
        #     y,
        #     p0=(2 * y.max(), 0.01, 1e-5, x.max()),
        #     bounds=bounds,
        #     maxfev=int(1e6),
        # )
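        # Append zero-valued points so the fitted curve is pulled toward the origin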
        params, _ = curve_fit(
            s_curve,
            list(x) + [0] * len(x),
            list(y) + [0] * len(y),
            p0=(2 * y.max(), 0.01, 1e-5, x.max()),
            bounds=bounds,
            maxfev=int(1e6),
        )

        mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
        rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
        r2_score_ = r2_score(y, s_curve(x, *params))

        # Store metrics and parameters
        response_curves[channel] = {
            "K": params[0],
            "b": params[1],
            "a": params[2],
            "x0": params[3],
        }
        mapes[channel] = mape
        rmses[channel] = rmse
        r2_scores[channel] = r2_score_
        powers[channel] = power

        conversion_rate = spend_df[channel].sum() / max(input_df[channel].sum(), 1e-9)
        conversion_rates[channel] = conversion_rate
        correction = y - s_curve(x, *params)

        # Initialize Channel object
        channel_obj = Channel(
            name=channel,
            dates=dates,
            spends=spends,
            conversion_rate=conversion_rate,
            response_curve_type="s-curve",
            response_curve_params={
                "K": params[0],
                "b": params[1],
                "a": params[2],
                "x0": params[3],
            },
            bounds=np.array([-10, 10]),
            correction=correction,
        )
        channels[channel] = channel_obj
        if sales is None:
            sales = channel_obj.actual_sales
        else:
            sales += channel_obj.actual_sales

    # Calculate other contributions
    other_contributions = (
        contribution_df.drop(columns=[*response_curves.keys()])
        .sum(axis=1, numeric_only=True)
        .values
    )

    # Initialize Scenario object
    scenario = Scenario(
        name="default",
        channels=channels,
        constant=other_contributions,
        correction=np.array([]),
    )

    # Set session state variables
    st.session_state.update(
        {
            "initialized": True,
            "actual_df": input_df,
            "raw_df": raw_df,
            "contri_df": contribution_df,
            "default_scenario_dict": class_to_dict(scenario),
            "scenario": scenario,
            "channels_list": channel_list,
            "optimization_channels": {
                channel_name: False for channel_name in channel_list
            },
            "rcs": response_curves.copy(),
            "powers": powers,
            "actual_contribution_df": pd.DataFrame(actual_output),
            "actual_input_df": pd.DataFrame(actual_input),
            "xlsx_buffer": io.BytesIO(),
            "saved_scenarios": (
                pickle.load(open("../saved_scenarios.pkl", "rb"))
                if Path("../saved_scenarios.pkl").exists()
                else OrderedDict()
            ),
            "disable_download_button": True,
        }
    )

    for channel in channels.values():
        st.session_state[channel.name] = numerize(
            channel.actual_total_spends * channel.conversion_rate, 1
        )

    # Prepare response curve data for output
    response_curve_data = {}
    for channel, params in st.session_state["rcs"].items():
        x = st.session_state["actual_input_df"][channel].values.astype(float)
        y = st.session_state["actual_contribution_df"][channel].values.astype(float)
        power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
        x_plot = list(np.linspace(0, 5 * max(x), 100))

        response_curve_data[channel] = {
            "K": float(params["K"]),
            "b": float(params["b"]),
            "a": float(params["a"]),
            "x0": float(params["x0"]),
            "power": power,
            "x": list(x),
            "y": list(y),
            "x_plot": x_plot,
        }

    return response_curve_data, scenario


# def initialize_data(panel=None, metrics=None):
#     # Extract dataframes for raw data, spend input, and contribution data
#     raw_df = st.session_state["project_dct"]["current_media_performance"][
#         "model_outputs"
#     ][metrics]["raw_data"]
#     spend_df = st.session_state["project_dct"]["current_media_performance"][
#         "model_outputs"
#     ][metrics]["spends_data"]
#     contri_df = st.session_state["project_dct"]["current_media_performance"][
#         "model_outputs"
#     ][metrics]["contribution_data"]

#     # Check if the panel is not None
#     if panel is not None and panel.lower() != "aggregated":
#         raw_df = raw_df[raw_df["Panel"] == panel].drop(columns=["Panel"])
#         spend_df = spend_df[spend_df["Panel"] == panel].drop(columns=["Panel"])
#         contri_df = contri_df[contri_df["Panel"] == panel].drop(columns=["Panel"])
#     elif panel.lower() == "aggregated":
#         raw_df = panel_level(raw_df, date_column="Date")
#         spend_df = panel_level(spend_df, date_column="Date")
#         contri_df = panel_level(contri_df, date_column="Date")

#     ## remove sesonalities, indices etc ...
#     unnamed_cols = [col for col in raw_df.columns if col.lower().startswith("unnamed")]

#     ## remove sesonalities, indices etc ...
#     exclude_columns = ["Date"] + unnamed_cols

#     raw_df["Date"] = pd.to_datetime(raw_df["Date"], format="%Y-%m-%d", errors="coerce")
#     contri_df["Date"] = pd.to_datetime(
#         contri_df["Date"], format="%Y-%m-%d", errors="coerce"
#     )
#     spend_df["Date"] = pd.to_datetime(
#         spend_df["Date"], format="%Y-%m-%d", errors="coerce"
#     )

#     input_df = raw_df.sort_values(by="Date")
#     output_df = contri_df.sort_values(by="Date")
#     spend_df.sort_values(by="Date", inplace=True)

#     channel_list = [col for col in input_df.columns if col not in exclude_columns]

#     negative_contribution = [
#         c
#         for c in contri_df.select_dtypes(include=["float", "int"]).columns
#         if contri_df[c].sum() <= 0
#     ]
#     channel_list = list(set(channel_list) - set(negative_contribution))

#     response_curves = {}
#     mapes = {}
#     rmses = {}
#     upper_limits = {}
#     powers = {}
#     r2 = {}
#     conv_rates = {}
#     output_cols = []
#     channels = {}
#     sales = None
#     dates = input_df.Date.values
#     actual_output_dic = {}
#     actual_input_dic = {}

#     for inp_col in channel_list:
#         spends = input_df[inp_col].values
#         x = spends.copy()
#         # upper limit for penalty
#         upper_limits[inp_col] = 2 * x.max()

#         # contribution
#         out_col = [_col for _col in output_df.columns if _col.startswith(inp_col)][0]
#         y = output_df[out_col].values.copy()
#         actual_output_dic[inp_col] = y.copy()
#         actual_input_dic[inp_col] = x.copy()
#         ##output cols aggregation
#         output_cols.append(out_col)

#         ## scale the input
#         power = np.ceil(np.log(x.max()) / np.log(10)) - 3
#         if power >= 0:
#             x = x / 10**power

#         x = x.astype("float64")
#         y = y.astype("float64")

#         if y.max() <= 0.01:
#             if x.max() <= 0.0:
#                 bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, 0.01))

#             else:
#                 bounds = ((0, 0, 0, 0), (3 * 0.01, 1000, 1, x.max()))
#         else:
#             bounds = ((0, 0, 0, 0), (3 * y.max(), 1000, 1, x.max()))

#         params, _ = curve_fit(
#             s_curve,
#             x,
#             y,
#             p0=(2 * y.max(), 0.01, 1e-5, x.max()),
#             bounds=bounds,
#             maxfev=int(1e5),
#         )
#         mape = (100 * abs(1 - s_curve(x, *params) / y.clip(min=1))).mean()
#         rmse = np.sqrt(((y - s_curve(x, *params)) ** 2).mean())
#         r2_ = r2_score(y, s_curve(x, *params))

#         response_curves[inp_col] = {
#             "K": params[0],
#             "b": params[1],
#             "a": params[2],
#             "x0": params[3],
#         }

#         mapes[inp_col] = mape
#         rmses[inp_col] = rmse
#         r2[inp_col] = r2_
#         powers[inp_col] = power

#         conv = spend_df[inp_col].sum() / max(input_df[inp_col].sum(), 1e-9)
#         conv_rates[inp_col] = conv

#         correction = y - s_curve(x, *params)

#         channel = Channel(
#             name=inp_col,
#             dates=dates,
#             spends=spends,
#             conversion_rate=conv_rates[inp_col],
#             response_curve_type="s-curve",
#             response_curve_params={
#                 "K": params[0],
#                 "b": params[1],
#                 "a": params[2],
#                 "x0": params[3],
#             },
#             bounds=np.array([-10, 10]),
#             correction=correction,
#         )

#         channels[inp_col] = channel
#         if sales is None:
#             sales = channel.actual_sales
#         else:
#             sales += channel.actual_sales

#     other_contributions = (
#         output_df.drop([*output_cols], axis=1).sum(axis=1, numeric_only=True).values
#     )

#     scenario = Scenario(
#         name="default",
#         channels=channels,
#         constant=other_contributions,
#         correction=np.array([]),
#     )

#     ## setting session variables
#     st.session_state["initialized"] = True
#     st.session_state["actual_df"] = input_df
#     st.session_state["raw_df"] = raw_df
#     st.session_state["contri_df"] = output_df
#     default_scenario_dict = class_to_dict(scenario)
#     st.session_state["default_scenario_dict"] = default_scenario_dict
#     st.session_state["scenario"] = scenario
#     st.session_state["channels_list"] = channel_list
#     st.session_state["optimization_channels"] = {
#         channel_name: False for channel_name in channel_list
#     }
#     st.session_state["rcs"] = response_curves.copy()

#     st.session_state["powers"] = powers
#     st.session_state["actual_contribution_df"] = pd.DataFrame(actual_output_dic)
#     st.session_state["actual_input_df"] = pd.DataFrame(actual_input_dic)

#     for channel in channels.values():
#         st.session_state[channel.name] = numerize(
#             channel.actual_total_spends * channel.conversion_rate, 1
#         )

#     st.session_state["xlsx_buffer"] = io.BytesIO()

#     if Path("../saved_scenarios.pkl").exists():
#         with open("../saved_scenarios.pkl", "rb") as f:
#             st.session_state["saved_scenarios"] = pickle.load(f)
#     else:
#         st.session_state["saved_scenarios"] = OrderedDict()

#     # st.session_state["total_spends_change"] = 0
#     st.session_state["optimization_channels"] = {
#         channel_name: False for channel_name in channel_list
#     }
#     st.session_state["disable_download_button"] = True

#     rcs_data = {}
#     for channel in st.session_state["rcs"]:
#         # Convert to native Python lists and types
#         x = list(st.session_state["actual_input_df"][channel].values.astype(float))
#         y = list(
#             st.session_state["actual_contribution_df"][channel].values.astype(float)
#         )
#         power = float(np.ceil(np.log(max(x)) / np.log(10)) - 3)
#         x_plot = list(np.linspace(0, 5 * max(x), 100))

#         rcs_data[channel] = {
#             "K": float(st.session_state["rcs"][channel]["K"]),
#             "b": float(st.session_state["rcs"][channel]["b"]),
#             "a": float(st.session_state["rcs"][channel]["a"]),
#             "x0": float(st.session_state["rcs"][channel]["x0"]),
#             "power": power,
#             "x": x,
#             "y": y,
#             "x_plot": x_plot,
#         }

#     return rcs_data, scenario


# def initialize_data():
#     # fetch data from excel
#     output = pd.read_excel('data.xlsx',sheet_name=None)
#     raw_df = output['RAW DATA MMM']
#     contribution_df = output['CONTRIBUTION MMM']
#     Revenue_df = output['Revenue']

#     ## channels to be shows
#     channel_list = []
#     for col in raw_df.columns:
#         if 'click' in col.lower() or 'spend' in col.lower() or 'imp' in col.lower():
#             channel_list.append(col)
#         else:
#             pass

#     ## NOTE : Considered only Desktop spends for all calculations
#     acutal_df = raw_df[raw_df.Region == 'Desktop'].copy()
#     ## NOTE : Considered one year of data
#     acutal_df = acutal_df[acutal_df.Date>'2020-12-31']
#     actual_df = acutal_df.drop('Region',axis=1).sort_values(by='Date')[[*channel_list,'Date']]

#     ##load response curves
#     with open('./grammarly_response_curves.json','r') as f:
#         response_curves = json.load(f)

#     ## create channel dict for scenario creation
#     dates = actual_df.Date.values
#     channels = {}
#     rcs = {}
#     constant = 0.
#     for i,info_dict in enumerate(response_curves):
#         name = info_dict.get('name')
#         response_curve_type = info_dict.get('response_curve')
#         response_curve_params = info_dict.get('params')
#         rcs[name] = response_curve_params
#         if name != 'constant':
#             spends = actual_df[name].values
#             channel = Channel(name=name,dates=dates,
#                             spends=spends,
#                             response_curve_type=response_curve_type,
#                             response_curve_params=response_curve_params,
#                             bounds=np.array([-30,30]))

#             channels[name] = channel
#         else:
#             constant = info_dict.get('value',0.) * len(dates)

#     ## create scenario
#     scenario = Scenario(name='default', channels=channels, constant=constant)
#     default_scenario_dict = class_to_dict(scenario)


#     ## setting session variables
#     st.session_state['initialized'] = True
#     st.session_state['actual_df'] = actual_df
#     st.session_state['raw_df'] = raw_df
#     st.session_state['default_scenario_dict'] = default_scenario_dict
#     st.session_state['scenario'] = scenario
#     st.session_state['channels_list'] = channel_list
#     st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}
#     st.session_state['rcs'] = rcs
#     for channel in channels.values():
#         if channel.name not in st.session_state:
#             st.session_state[channel.name] = float(channel.actual_total_spends)

#     if 'xlsx_buffer' not in st.session_state:
#         st.session_state['xlsx_buffer'] = io.BytesIO()

#     ## for saving scenarios
#     if 'saved_scenarios' not in st.session_state:
#         if Path('../saved_scenarios.pkl').exists():
#             with open('../saved_scenarios.pkl','rb') as f:
#                 st.session_state['saved_scenarios'] = pickle.load(f)

#         else:
#             st.session_state['saved_scenarios'] = OrderedDict()

#     if 'total_spends_change' not in st.session_state:
#         st.session_state['total_spends_change'] = 0

#     if 'optimization_channels' not in st.session_state:
#         st.session_state['optimization_channels'] = {channel_name : False for channel_name in channel_list}

#     if 'disable_download_button' not in st.session_state:
#         st.session_state['disable_download_button'] = True


def create_channel_summary(scenario):
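    """Return a channel summary dataframe of Spends, Revenue, and ROI built from hardcoded figures (the scenario argument is currently unused)."""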

    # Provided data
    data = {
        "Channel": [
            "Paid Search",
            "Ga will cid baixo risco",
            "Digital tactic others",
            "Fb la tier 1",
            "Fb la tier 2",
            "Paid social others",
            "Programmatic",
            "Kwai",
            "Indicacao",
            "Infleux",
            "Influencer",
        ],
        "Spends": [
            "$ 11.3K",
            "$ 155.2K",
            "$ 50.7K",
            "$ 125.4K",
            "$ 125.2K",
            "$ 105K",
            "$ 3.3M",
            "$ 47.5K",
            "$ 55.9K",
            "$ 632.3K",
            "$ 48.3K",
        ],
        "Revenue": [
            "558.0K",
            "3.5M",
            "5.2M",
            "3.1M",
            "3.1M",
            "2.1M",
            "20.8M",
            "1.6M",
            "728.4K",
            "22.9M",
            "4.8M",
        ],
    }

    # Create DataFrame
    df = pd.DataFrame(data)

    # Convert currency strings to numeric values
    df["Spends"] = (
        df["Spends"]
        .replace({"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
        .map(pd.eval)
        .astype(int)
    )
    df["Revenue"] = (
        df["Revenue"]
        .replace({"\$": "", "K": "*1e3", "M": "*1e6"}, regex=True)
        .map(pd.eval)
        .astype(int)
    )

    # Calculate ROI
    df["ROI"] = (df["Revenue"] - df["Spends"]) / df["Spends"]

    # Format columns
    format_currency = lambda x: f"${x:,.1f}"
    format_roi = lambda x: f"{x:.1f}"

    df["Spends"] = [
        "$ 11.3K",
        "$ 155.2K",
        "$ 50.7K",
        "$ 125.4K",
        "$ 125.2K",
        "$ 105K",
        "$ 3.3M",
        "$ 47.5K",
        "$ 55.9K",
        "$ 632.3K",
        "$ 48.3K",
    ]
    df["Revenue"] = [
        "$ 536.3K",
        "$ 3.4M",
        "$ 5M",
        "$ 3M",
        "$ 3M",
        "$ 2M",
        "$ 20M",
        "$ 1.5M",
        "$ 7.1M",
        "$ 22M",
        "$ 4.6M",
    ]
    df["ROI"] = df["ROI"].apply(format_roi)

    return df


# @st.cache(allow_output_mutation=True)
# def create_contribution_pie(scenario):
#     #c1f7dc
#     colors_map = {col:color for col,color in zip(st.session_state['channels_list'],plotly.colors.n_colors(plotly.colors.hex_to_rgb('#BE6468'), plotly.colors.hex_to_rgb('#E7B8B7'),23))}
#     total_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "pie"}, {"type": "pie"}]])
#     total_contribution_fig.add_trace(
#                 go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
#                     values= [round(scenario.channels[channel_name].actual_total_spends * scenario.channels[channel_name].conversion_rate,1) for channel_name in st.session_state['channels_list']] + [0],
#                     marker=dict(colors = [plotly.colors.label_rgb(colors_map[channel_name]) for channel_name in st.session_state['channels_list']] + ['#F0F0F0']),
#                         hole=0.3),
#                 row=1, col=1)

#     total_contribution_fig.add_trace(
#                 go.Pie(labels=[channel_name_formating(channel_name) for channel_name in st.session_state['channels_list']] + ['Non Media'],
#                     values= [scenario.channels[channel_name].actual_total_sales for channel_name in st.session_state['channels_list']] + [scenario.correction.sum() + scenario.constant.sum()],
#                         hole=0.3),
#                 row=1, col=2)

#     total_contribution_fig.update_traces(textposition='inside',texttemplate='%{percent:.1%}')
#     total_contribution_fig.update_layout(uniformtext_minsize=12,title='Channel contribution', uniformtext_mode='hide')
#     return total_contribution_fig

# @st.cache(allow_output_mutation=True)

# def create_contribuion_stacked_plot(scenario):
#     weekly_contribution_fig = make_subplots(rows=1, cols=2,subplot_titles=['Spends','Revenue'],specs=[[{"type": "bar"}, {"type": "bar"}]])
#     raw_df = st.session_state['raw_df']
#     df = raw_df.sort_values(by='Date')
#     x = df.Date
#     weekly_spends_data = []
#     weekly_sales_data = []
#     for channel_name in st.session_state['channels_list']:
#         weekly_spends_data.append((go.Bar(x=x,
#                                           y=scenario.channels[channel_name].actual_spends * scenario.channels[channel_name].conversion_rate,
#                                           name=channel_name_formating(channel_name),
#                                           hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
#                                           legendgroup=channel_name)))
#         weekly_sales_data.append((go.Bar(x=x,
#                                          y=scenario.channels[channel_name].actual_sales,
#                                          name=channel_name_formating(channel_name),
#                                          hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
#                                          legendgroup=channel_name, showlegend=False)))
#     for _d in weekly_spends_data:
#         weekly_contribution_fig.add_trace(_d, row=1, col=1)
#     for _d in weekly_sales_data:
#         weekly_contribution_fig.add_trace(_d, row=1, col=2)
#     weekly_contribution_fig.add_trace(go.Bar(x=x,
#                                          y=scenario.constant + scenario.correction,
#                                          name='Non Media',
#                                          hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), row=1, col=2)
#     weekly_contribution_fig.update_layout(barmode='stack', title='Channel contribuion by week', xaxis_title='Date')
#     weekly_contribution_fig.update_xaxes(showgrid=False)
#     weekly_contribution_fig.update_yaxes(showgrid=False)
#     return weekly_contribution_fig

# @st.cache(allow_output_mutation=True)
# def create_channel_spends_sales_plot(channel):
#     if channel is not None:
#         x = channel.dates
#         _spends = channel.actual_spends * channel.conversion_rate
#         _sales = channel.actual_sales
#         channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
#         channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
#         channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#005b96'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
#         channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
#         channel_sales_spends_fig.update_xaxes(showgrid=False)
#         channel_sales_spends_fig.update_yaxes(showgrid=False)
#     else:
#         raw_df = st.session_state['raw_df']
#         df = raw_df.sort_values(by='Date')
#         x = df.Date
#         scenario = class_from_dict(st.session_state['default_scenario_dict'])
#         _sales = scenario.constant + scenario.correction
#         channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
#         channel_sales_spends_fig.add_trace(go.Bar(x=x, y=_sales,marker_color='#c1f7dc',name='Revenue', hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}"), secondary_y = False)
#         # channel_sales_spends_fig.add_trace(go.Scatter(x=x, y=_spends,line=dict(color='#15C39A'),name='Spends',hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}"), secondary_y = True)
#         channel_sales_spends_fig.update_layout(xaxis_title='Date',yaxis_title='Revenue',yaxis2_title='Spends ($)',title='Channel spends and Revenue week wise')
#         channel_sales_spends_fig.update_xaxes(showgrid=False)
#         channel_sales_spends_fig.update_yaxes(showgrid=False)
#     return channel_sales_spends_fig


# Define a shared color palette


def create_contribution_pie():
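    """Build side-by-side Spends and Revenue pie charts from hardcoded channel contribution values."""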
    color_palette = [
        "#F3F3F0",
        "#5E7D7E",
        "#2FA1FF",
        "#00EDED",
        "#00EAE4",
        "#304550",
        "#EDEBEB",
        "#7FBEFD",
        "#003059",
        "#A2F3F3",
        "#E1D6E2",
        "#B6B6B6",
    ]
    total_contribution_fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Spends", "Revenue"],
        specs=[[{"type": "pie"}, {"type": "pie"}]],
    )

    channels_list = [
        "Paid Search",
        "Ga will cid baixo risco",
        "Digital tactic others",
        "Fb la tier 1",
        "Fb la tier 2",
        "Paid social others",
        "Programmatic",
        "Kwai",
        "Indicacao",
        "Infleux",
        "Influencer",
        "Non Media",
    ]

    # Assign palette colors to channels, cycling if there are more channels than colors
    colors_map = {
        col: color_palette[i % len(color_palette)]
        for i, col in enumerate(channels_list)
    }
    # Pin 'Non Media' to a fixed palette entry so it keeps the same color across charts
    colors_map["Non Media"] = color_palette[5]

    # Hardcoded values for Spends and Revenue
    spends_values = [0.5, 3.36, 1.1, 2.7, 2.7, 2.27, 70.6, 1, 1, 13.7, 1, 0]
    revenue_values = [1, 4, 5, 3, 3, 2, 50.8, 1.5, 0.7, 13, 0, 16]

    # Add trace for Spends pie chart
    total_contribution_fig.add_trace(
        go.Pie(
            labels=[channel_name for channel_name in channels_list],
            values=spends_values,
            marker=dict(
                colors=[colors_map[channel_name] for channel_name in channels_list]
            ),
            hole=0.3,
        ),
        row=1,
        col=1,
    )

    # Add trace for Revenue pie chart
    total_contribution_fig.add_trace(
        go.Pie(
            labels=[channel_name for channel_name in channels_list],
            values=revenue_values,
            marker=dict(
                colors=[colors_map[channel_name] for channel_name in channels_list]
            ),
            hole=0.3,
        ),
        row=1,
        col=2,
    )

    total_contribution_fig.update_traces(
        textposition="inside", texttemplate="%{percent:.1%}"
    )
    total_contribution_fig.update_layout(
        uniformtext_minsize=12,
        title="Channel contribution",
        uniformtext_mode="hide",
    )
    return total_contribution_fig
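

# Typical usage in a Streamlit page (illustrative only):
#     st.plotly_chart(create_contribution_pie(), use_container_width=True)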


def create_contribuion_stacked_plot(scenario):
    """Build weekly stacked bar charts of spends and revenue for every channel in a scenario."""
    weekly_contribution_fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=["Spends", "Revenue"],
        specs=[[{"type": "bar"}, {"type": "bar"}]],
    )
    raw_df = st.session_state["raw_df"]
    df = raw_df.sort_values(by="Date")
    x = df.Date
    weekly_spends_data = []
    weekly_sales_data = []

    for i, channel_name in enumerate(st.session_state["channels_list"]):
        color = color_palette[i % len(color_palette)]

        weekly_spends_data.append(
            go.Bar(
                x=x,
                y=scenario.channels[channel_name].actual_spends
                * scenario.channels[channel_name].conversion_rate,
                name=channel_name_formating(channel_name),
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
                legendgroup=channel_name,
                marker_color=color,
            )
        )

        weekly_sales_data.append(
            go.Bar(
                x=x,
                y=scenario.channels[channel_name].actual_sales,
                name=channel_name_formating(channel_name),
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
                legendgroup=channel_name,
                showlegend=False,
                marker_color=color,
            )
        )

    for _d in weekly_spends_data:
        weekly_contribution_fig.add_trace(_d, row=1, col=1)
    for _d in weekly_sales_data:
        weekly_contribution_fig.add_trace(_d, row=1, col=2)

    weekly_contribution_fig.add_trace(
        go.Bar(
            x=x,
            y=scenario.constant + scenario.correction,
            name="Non Media",
            hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
            marker_color=color_palette[-1],
        ),
        row=1,
        col=2,
    )

    weekly_contribution_fig.update_layout(
        barmode="stack",
        title="Channel contribution by week",
        xaxis_title="Date",
    )
    weekly_contribution_fig.update_xaxes(showgrid=False)
    weekly_contribution_fig.update_yaxes(showgrid=False)
    return weekly_contribution_fig


def create_channel_spends_sales_plot(channel):
    """Plot weekly revenue (bars) and spends (line) for one channel on twin y-axes;
    with no channel, plot the base (non-media) revenue of the default scenario."""
    if channel is not None:
        x = channel.dates
        _spends = channel.actual_spends * channel.conversion_rate
        _sales = channel.actual_sales
        channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
        channel_sales_spends_fig.add_trace(
            go.Bar(
                x=x,
                y=_sales,
                marker_color=color_palette[3],  # fixed palette color for revenue bars
                name="Revenue",
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
            ),
            secondary_y=False,
        )

        channel_sales_spends_fig.add_trace(
            go.Scatter(
                x=x,
                y=_spends,
                line=dict(color=color_palette[2]),  # fixed palette color for the spends line
                name="Spends",
                hovertemplate="Date:%{x}<br>Spend:%{y:$.2s}",
            ),
            secondary_y=True,
        )

        channel_sales_spends_fig.update_layout(
            xaxis_title="Date",
            yaxis_title="Revenue",
            yaxis2_title="Spends ($)",
            title="Channel spends and Revenue week-wise",
        )
        channel_sales_spends_fig.update_xaxes(showgrid=False)
        channel_sales_spends_fig.update_yaxes(showgrid=False)
    else:
        raw_df = st.session_state["raw_df"]
        df = raw_df.sort_values(by="Date")
        x = df.Date
        scenario = class_from_dict(st.session_state["default_scenario_dict"])
        _sales = scenario.constant + scenario.correction
        channel_sales_spends_fig = make_subplots(specs=[[{"secondary_y": True}]])
        channel_sales_spends_fig.add_trace(
            go.Bar(
                x=x,
                y=_sales,
                marker_color=color_palette[0],  # fixed palette color for base revenue bars
                name="Revenue",
                hovertemplate="Date:%{x}<br>Revenue:%{y:$.2s}",
            ),
            secondary_y=False,
        )

        channel_sales_spends_fig.update_layout(
            xaxis_title="Date",
            yaxis_title="Revenue",
            yaxis2_title="Spends ($)",
            title="Channel spends and Revenue week-wise",
        )
        channel_sales_spends_fig.update_xaxes(showgrid=False)
        channel_sales_spends_fig.update_yaxes(showgrid=False)

    return channel_sales_spends_fig


def format_numbers(value, n_decimals=1, include_indicator=True):
    """Format a number for display, abbreviating values of 1 or more (e.g. 1.2K) and
    optionally prefixing the currency indicator."""
    if value is None:
        return None
    # numerize abbreviates values >= 1; smaller values are simply rounded for display
    _value = round(value, n_decimals) if value < 1 else numerize(value, n_decimals)
    if include_indicator:
        return f"{CURRENCY_INDICATOR} {_value}"
    else:
        return f"{_value}"


def decimal_formater(num_string, n_decimals=1):
    """Pad a numeric string with trailing zeros so it shows at least n_decimals decimal places."""
    parts = num_string.split(".")
    if len(parts) == 1:
        return num_string + "." + "0" * n_decimals
    else:
        to_be_padded = n_decimals - len(parts[-1])
        if to_be_padded > 0:
            return num_string + "0" * to_be_padded
        else:
            return num_string
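
# Example:
#     decimal_formater("12")      -> "12.0"
#     decimal_formater("12.3", 2) -> "12.30"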


def channel_name_formating(channel_name):
    """Turn a raw channel column name into a display label, relabeling a trailing
    impressions/clicks suffix as 'Spend'."""
    name_mod = channel_name.replace("_", " ")
    # Match the suffix case-insensitively so "Imp", "imp", "Clicks", "clicks" are all handled
    if name_mod.lower().endswith(" imp"):
        name_mod = name_mod[: -len(" imp")] + " Spend"
    elif name_mod.lower().endswith(" clicks"):
        name_mod = name_mod[: -len(" clicks")] + " Spend"
    return name_mod
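
# Example:
#     channel_name_formating("paid_search_imp")     -> "paid search Spend"
#     channel_name_formating("fb_la_tier_1_clicks") -> "fb la tier 1 Spend"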


def send_email(email, message):
    """Send a notification email via Gmail SMTP."""
    # Credentials are read from environment variables (variable names here are
    # illustrative) rather than being hard-coded in source.
    sender = os.environ.get("NOTIFICATION_EMAIL")
    password = os.environ.get("NOTIFICATION_EMAIL_PASSWORD")
    s = smtplib.SMTP("smtp.gmail.com", 587)
    s.starttls()
    s.login(sender, password)
    s.sendmail(sender, email, message)
    s.quit()


# if __name__ == "__main__":
#     initialize_data()


#############################################################################################################



# Function to get panel names for the selected metric's model output
def get_panels_names(file_selected):
    raw_data_df = st.session_state["project_dct"]["current_media_performance"][
        "model_outputs"
    ][file_selected]["raw_data"]

    if "panel" in raw_data_df.columns:
        panel = list(set(raw_data_df["panel"]))
    elif "Panel" in raw_data_df.columns:
        panel = list(set(raw_data_df["Panel"]))
    else:
        panel = []

    return panel + ["aggregated"]


# Function to get metrics names
def get_metrics_names():
    return list(
        st.session_state["project_dct"]["current_media_performance"][
            "model_outputs"
        ].keys()
    )


# Function to load the original and modified rcs metadata files into dictionaries
def load_rcs_metadata_files():
    original_data = st.session_state["project_dct"]["response_curves"][
        "original_metadata_file"
    ]
    modified_data = st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ]

    return original_data, modified_data


# Function to format name
def name_formating(name):
    # Replace underscores with spaces
    name_mod = name.replace("_", " ")

    # Capitalize the first letter of each word
    name_mod = name_mod.title()

    return name_mod
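
# Example:
#     name_formating("paid_search") -> "Paid Search"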


# Function to load the original and modified scenario metadata files into dictionaries
def load_scenario_metadata_files():
    original_data = st.session_state["project_dct"]["scenario_planner"][
        "original_metadata_file"
    ]
    modified_data = st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ]

    return original_data, modified_data


# Function to generate RCS data and store it as dictionary
def generate_rcs_data():
    # Retrieve the list of all metric names from the saved model outputs
    metrics_list = get_metrics_names()

    # Dictionary to store RCS data for all metrics and their respective panels
    all_rcs_data_original = {}
    all_rcs_data_modified = {}

    # Iterate over each metric in the metrics list
    for metric in metrics_list:
        # Retrieve the list of panel names from the current metric's model output
        panel_list = get_panels_names(file_selected=metric)

        # Use previously saved modified RCS data if it exists; otherwise each channel
        # falls back to its original fitted parameters below
        modified_data = {}
        if (
            st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]
            is not None
        ):
            modified_data = st.session_state["project_dct"]["response_curves"][
                "modified_metadata_file"
            ]

        # Iterate over each panel in the panel list
        for panel in panel_list:
            # Initialize the original RCS data for the current panel and metric
            rcs_dict_original, scenario = initialize_data(
                panel=panel,
                metrics=metric,
            )

            # Ensure the dictionary has the metric as a key for original data
            if metric not in all_rcs_data_original:
                all_rcs_data_original[metric] = {}

            # Store the original RCS data under the corresponding panel for the current metric
            all_rcs_data_original[metric][panel] = rcs_dict_original

            # Ensure the dictionary has the metric as a key for modified data
            if metric not in all_rcs_data_modified:
                all_rcs_data_modified[metric] = {}

            # Store the modified RCS data under the corresponding panel for the current metric
            for channel in rcs_dict_original:
                all_rcs_data_modified[metric][panel] = all_rcs_data_modified[
                    metric
                ].get(panel, {})

                try:
                    updated_rcs_dict = modified_data[metric][panel][channel]
                except (KeyError, TypeError):
                    # No modified entry for this channel; keep the original fitted parameters
                    updated_rcs_dict = {
                        "K": rcs_dict_original[channel]["K"],
                        "b": rcs_dict_original[channel]["b"],
                        "a": rcs_dict_original[channel]["a"],
                        "x0": rcs_dict_original[channel]["x0"],
                    }

                all_rcs_data_modified[metric][panel][channel] = updated_rcs_dict

    # Write the original RCS data
    st.session_state["project_dct"]["response_curves"][
        "original_metadata_file"
    ] = all_rcs_data_original

    # Write the modified RCS data
    st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ] = all_rcs_data_modified
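

# The stored RCS parameters (K, b, a, x0) describe each channel's fitted response curve.
# A minimal illustrative sketch of the assumed S-curve form is given below; the actual
# fitting and evaluation live elsewhere in the app, so treat the formula as an assumption.
def response_curve_sketch(spend, K, b, a, x0):
    """Illustrative S-curve mapping weekly spend to modeled incremental response."""
    return K / (1 + b * np.exp(-a * (spend - x0)))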


# Function to generate scenario data and store it as dictionary
def generate_scenario_data():
    # Retrieve the list of all metric names from the saved model outputs
    metrics_list = get_metrics_names()

    # Dictionary to store scenario data for all metrics and their respective panels
    all_scenario_data_original = {}
    all_scenario_data_modified = {}

    # Iterate over each metric in the metrics list
    for metric in metrics_list:
        # Retrieve the list of panel names from the current metric's model output
        panel_list = get_panels_names(metric)

        # Use previously saved modified scenario data if it exists; otherwise each panel
        # falls back to the freshly initialized scenario below
        modified_data = {}
        if (
            st.session_state["project_dct"]["scenario_planner"][
                "modified_metadata_file"
            ]
            is not None
        ):
            modified_data = st.session_state["project_dct"]["scenario_planner"][
                "modified_metadata_file"
            ]

        # Iterate over each panel in the panel list
        for panel in panel_list:
            # Initialize the original scenario data for the current panel and metric
            rcs_dict_original, scenario = initialize_data(
                panel=panel,
                metrics=metric,
            )

            # Ensure the dictionary has the metric as a key for original data
            if metric not in all_scenario_data_original:
                all_scenario_data_original[metric] = {}

            # Store the original scenario data under the corresponding panel for the current metric
            all_scenario_data_original[metric][panel] = class_convert_to_dict(scenario)

            # Ensure the dictionary has the metric as a key for modified data
            if metric not in all_scenario_data_modified:
                all_scenario_data_modified[metric] = {}

            # Store the modified scenario data under the corresponding panel for the current metric
            try:
                all_scenario_data_modified[metric][panel] = modified_data[metric][panel]
            except (KeyError, TypeError):
                # No modified entry for this panel; keep the freshly initialized scenario
                all_scenario_data_modified[metric][panel] = class_convert_to_dict(
                    scenario
                )

    # Write the original scenario data
    st.session_state["project_dct"]["scenario_planner"][
        "original_metadata_file"
    ] = all_scenario_data_original

    # Write the modified scenario data
    st.session_state["project_dct"]["scenario_planner"][
        "modified_metadata_file"
    ] = all_scenario_data_modified
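

# Both metadata dictionaries written by the generators above are keyed metric -> panel:
#     response_curves:  {metric: {panel: {channel: {"K", "b", "a", "x0"}}}}
#     scenario_planner: {metric: {panel: <scenario dict from class_convert_to_dict>}}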


#############################################################################################################