# Importing necessary libraries
import streamlit as st

st.set_page_config(
    page_title="Data Import",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)


import re
import sys
import pickle
import numbers
import traceback
import pandas as pd
from scenario import numerize
from post_gres_cred import db_cred
from collections import OrderedDict
from log_application import log_message
from utilities import set_header, load_local_css, update_db, project_selection
from constants import (
    upload_rows_limit,
    upload_column_limit,
    word_length_limit_lower,
    word_length_limit_upper,
    minimum_percent_overlap,
    minimum_row_req,
    percent_drop_col_threshold,
)

schema = db_cred["schema"]
load_local_css("styles.css")
set_header()


# Initialize project name session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# Fetch project dictionary
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Display Username and Project Name
if "username" in st.session_state and st.session_state["username"] is not None:

    cols1 = st.columns([2, 1])

    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")


# Initialize session state keys
if "granularity_selection_key" not in st.session_state:
    st.session_state["granularity_selection_key"] = st.session_state["project_dct"][
        "data_import"
    ]["granularity_selection"]


# Function to format name
def name_format_func(name):
    return str(name).strip().title()


# Function to get columns with specified prefix and remove prefix
def get_columns_with_prefix(df, prefix):
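    # Example (hypothetical names): columns ["spends_facebook", "spends_google", "date"]
    # with prefix "spends_" -> ["facebook", "google"]; a bare "spends_" column is skipped.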
    return [
        col.replace(prefix, "")
        for col in df.columns
        if col.startswith(prefix) and str(col) != str(prefix)
    ]


# Function to fetch columns info
@st.cache_data(show_spinner=False)
def fetch_columns(gold_layer_df, data_upload_df):
    # Get lists of columns starting with 'spends_' and 'response_metric_' from gold_layer_df
    spends_columns_gold_layer = get_columns_with_prefix(gold_layer_df, "spends_")
    response_metric_columns_gold_layer = get_columns_with_prefix(
        gold_layer_df, "response_metric_"
    )

    # Get lists of columns starting with 'spends_' and 'response_metric_' from data_upload_df
    spends_columns_upload = get_columns_with_prefix(data_upload_df, "spends_")
    response_metric_columns_upload = get_columns_with_prefix(
        data_upload_df, "response_metric_"
    )

    # Combine lists from both DataFrames
    spends_columns = spends_columns_gold_layer + spends_columns_upload
    # Deduplicate and drop any spends columns that end with '_total'
    spends_columns = list(
        set([col for col in spends_columns if not col.endswith("_total")])
    )

    response_metric_columns = (
        response_metric_columns_gold_layer + response_metric_columns_upload
    )
    # Filter columns ending with '_total' and remove the '_total' suffix
    response_metric_columns = list(
        set(
            [
                col[:-6]
                for col in response_metric_columns
                if col.endswith("_total") and len(col[:-6]) != 0
            ]
        )
    )
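    # Example (hypothetical names): gold layer columns ["spends_facebook", "response_metric_revenue_total"]
    # combined with uploaded columns ["spends_google"] give spends_columns ["facebook", "google"] and
    # response_metric_columns ["revenue"] (only metrics ending in "_total" are kept, with the suffix stripped).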

    # Get list of all columns from both DataFrames
    gold_layer_columns = list(gold_layer_df.columns)
    data_upload_columns = list(data_upload_df.columns)

    # Combine all columns and get unique columns
    all_columns = list(set(gold_layer_columns + data_upload_columns))

    return (
        spends_columns,
        response_metric_columns,
        all_columns,
        gold_layer_columns,
        data_upload_columns,
    )


# Function to format values for display
@st.cache_data(show_spinner=False)
def format_values_for_display(values_list):
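    # Example: ["Facebook ", "google", "TV"] -> "facebook, google, and tv"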
    # Format value
    formatted_list = [value.lower().strip() for value in values_list]
    # Join values with commas and 'and' before the last value
    if len(formatted_list) > 1:
        return ", ".join(formatted_list[:-1]) + ", and " + formatted_list[-1]
    elif formatted_list:
        return formatted_list[0]
    return "No values available"


# Function to validate input DataFrame
@st.cache_data(show_spinner=False)
def valid_input_df(
    df,
    gold_layer_df,
    spends_columns,
    response_metric_columns,
    total_columns,
    gold_layer_columns,
    data_upload_columns,
):
    # Check if DataFrame is empty
    if df.empty:
        return (True, None)

    # Check for invalid column names
    invalid_columns = [
        col
        for col in df.columns
        if not re.match(r"^[A-Za-z0-9_]+$", col)
        or not (word_length_limit_lower <= len(col) <= word_length_limit_upper)
    ]
    if invalid_columns:
        return (
            False,
            f"Invalid column names: {format_values_for_display(invalid_columns)}. Use only letters, numbers, and underscores. Column name length should be {word_length_limit_lower} to {word_length_limit_upper} characters long.",
        )

    # Ensure 'panel' column values are strings and conform to specified pattern and length
    if "panel" in df.columns:
        df["panel"] = df["panel"].astype(str).str.strip()
        invalid_panel_values = [
            val
            for val in df["panel"].unique()
            if not re.match(r"^[A-Za-z0-9_]+$", val)
            or not (word_length_limit_lower <= len(val) <= word_length_limit_upper)
        ]
        if invalid_panel_values:
            return (
                False,
                f"Invalid panel values: {format_values_for_display(invalid_panel_values)}. Use only letters, numbers, and underscores. Panel name length should be {word_length_limit_lower} to {word_length_limit_upper} characters long.",
            )

    # Check for missing required columns
    required_columns = ["date", "panel"]
    missing_columns = [col for col in required_columns if col not in df.columns]
    if missing_columns:
        return (
            False,
            f"Missing compulsory columns: {format_values_for_display(missing_columns)}.",
        )

    # Check if all other columns are numeric
    non_numeric_columns = [
        col
        for col in df.columns
        if col not in required_columns and not pd.api.types.is_numeric_dtype(df[col])
    ]
    if non_numeric_columns:
        return (
            False,
            f"Non-numeric columns: {format_values_for_display(non_numeric_columns)}. All columns except {format_values_for_display(required_columns)} should be numeric.",
        )

    # Ensure all columns in data_upload_columns are unique
    duplicate_columns_in_upload = [
        col for col in data_upload_columns if data_upload_columns.count(col) > 1
    ]
    if duplicate_columns_in_upload:
        return (
            False,
            f"Duplicate columns found in the uploaded data: {format_values_for_display(set(duplicate_columns_in_upload))}.",
        )

    # Convert 'date' column to datetime format
    try:
        df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")
    except Exception:
        return False, "The 'date' column is not in the correct format 'YYYY-MM-DD'."

    # Check date frequency
    unique_panels = df["panel"].unique()
    for panel in unique_panels:
        date_diff = df[df["panel"] == panel]["date"].diff().dropna()
        if not (
            (date_diff == pd.Timedelta(days=1)).all()
            or (date_diff == pd.Timedelta(weeks=1)).all()
        ):
            return False, "The 'date' column does not have a daily or weekly frequency."

    # Check for null values in 'date' or 'panel' columns
    if df[required_columns].isnull().any().any():
        return (
            False,
            f"The {format_values_for_display(required_columns)} should not contain null values.",
        )

    # Check for panels whose date overlap falls below the minimum required percentage
    if not gold_layer_df.empty:
        panels_with_low_overlap = []
        unique_panels = list(
            set(df["panel"].unique()).union(set(gold_layer_df["panel"].unique()))
        )
        for panel in unique_panels:
            gold_layer_dates = set(
                gold_layer_df[gold_layer_df["panel"] == panel]["date"]
            )
            data_upload_dates = set(df[df["panel"] == panel]["date"])
            if gold_layer_dates and data_upload_dates:
                overlap = len(gold_layer_dates & data_upload_dates) / len(
                    gold_layer_dates | data_upload_dates
                )
            else:
                overlap = 0
            if overlap < (minimum_percent_overlap / 100):
                panels_with_low_overlap.append(panel)

        if panels_with_low_overlap:
            return (
                False,
                f"Date columns in the gold layer and uploaded data do not have at least {minimum_percent_overlap}% overlap for panels: {format_values_for_display(panels_with_low_overlap)}.",
            )

    # Require at least two spends columns
    if len(spends_columns) < 2:
        return False, "Please add at least two spends columns."

    # Check if response_metric_columns is empty
    if len(response_metric_columns) < 1:
        return False, "Please add response metric columns."

    # Check that numeric columns contain no negative values, except those starting with 'exogenous_' or 'internal_'
    valid_prefixes = ["exogenous_", "internal_"]
    negative_values_columns = [
        col
        for col in df.select_dtypes(include=[float, int]).columns
        if not any(col.startswith(prefix) for prefix in valid_prefixes)
        and (df[col] < 0).any()
    ]
    if negative_values_columns:
        return (
            False,
            f"Negative values detected in columns: {format_values_for_display(negative_values_columns)}. Ensure all media and response metric columns are positive.",
        )

    # Check for unassociated columns
    detected_channels = spends_columns + ["total"]
    unassociated_columns = []
    for col in df.columns:
        if (col.startswith("_") or col.endswith("_")) or not (
            col.startswith("exogenous_")  # Column starts with "exogenous_"
            or col.startswith("internal_")  # Column starts with "internal_"
            or any(
                col == f"spends_{channel}" for channel in spends_columns
            )  # Column matches the format "spends_<channel>"
            or any(
                col == f"response_metric_{metric}_{channel}"
                for metric in response_metric_columns
                for channel in detected_channels
            )  # Column matches the format "response_metric_<metric>_<channel>"
            or any(
                col.startswith("media_")
                and col.endswith(f"_{channel}")
                and len(col) > len(f"media__{channel}")
                for channel in spends_columns
            )  # Column matches the format "media_<media_variable_name>_<channel>"
            or col in ["date", "panel"]
        ):
            unassociated_columns.append(col)

    if unassociated_columns:
        return (
            False,
            f"Columns with incorrect format detected: {format_values_for_display(unassociated_columns)}.",
        )

    return True, "The data is valid and meets all requirements."


# Function to load the uploaded Excel file and transform it into a daily-level DataFrame
@st.cache_data(show_spinner=False)
def load_and_transform_data(uploaded_file):
    # Load the uploaded file into a DataFrame if a file is uploaded
    if uploaded_file is not None:
        df = pd.read_excel(uploaded_file)
    else:
        df = pd.DataFrame()
        return df

    # Check if DataFrame exceeds row and column limits
    if len(df) > upload_rows_limit or len(df.columns) > upload_column_limit:
        st.warning(
            f"Data exceeds the row limit of {numerize(upload_rows_limit)} or column limit of {numerize(upload_column_limit)}. Please upload a smaller file.",
            icon="⚠️",
        )

        # Log message
        log_message(
            "warning",
            f"Data exceeds the row limit of {numerize(upload_rows_limit)} or column limit of {numerize(upload_column_limit)}. Please upload a smaller file.",
            "Data Import",
        )

        return pd.DataFrame()

    # If the DataFrame contains only 'panel' and 'date' columns, return empty DataFrame
    if set(df.columns) == {"date", "panel"}:
        return pd.DataFrame()

    # Transform column names: lower, strip start and end, replace spaces with _
    df.columns = [str(col).strip().lower().replace(" ", "_") for col in df.columns]

    # If 'panel' column exists, clean its values
    try:
        if "panel" in df.columns:
            df["panel"] = (
                df["panel"].astype(str).str.lower().str.strip().str.replace(" ", "_")
            )
    except Exception:
        return df

    try:
        df["date"] = pd.to_datetime(df["date"], format="%Y-%m-%d")
    except Exception:
        # The 'date' column is not in the correct format 'YYYY-MM-DD'; return as-is so validation can report it
        return df

    # Check date frequency and convert to daily if needed
    date_diff = df["date"].diff().dropna()
    if (date_diff == pd.Timedelta(days=1)).all():
        # Data is already at daily level
        return df
    elif (date_diff == pd.Timedelta(weeks=1)).all():
        # Data is at weekly level, convert to daily
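        # Example: a weekly row dated Monday 2024-08-05 with a numeric value of 700 becomes seven
        # daily rows (2024-08-05 through 2024-08-11), each carrying 700 / 7 = 100; non-numeric
        # fields such as 'panel' are copied unchanged.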
        weekly_data = df.copy()
        daily_data = []

        for index, row in weekly_data.iterrows():
            week_start = row["date"] - pd.to_timedelta(row["date"].weekday(), unit="D")
            for i in range(7):
                daily_date = week_start + pd.DateOffset(days=i)
                new_row = row.copy()
                new_row["date"] = daily_date
                for col in df.columns:
                    if isinstance(new_row[col], numbers.Number):
                        new_row[col] = new_row[col] / 7
                daily_data.append(new_row)

        daily_data_df = pd.DataFrame(daily_data)
        return daily_data_df
    else:
        # The 'date' column does not have a daily or weekly frequency
        return df


# Function to merge DataFrames if present
@st.cache_data(show_spinner=False)
def merge_dataframes(gold_layer_df, data_upload_df):
    if gold_layer_df.empty and data_upload_df.empty:
        return pd.DataFrame()

    if not gold_layer_df.empty and not data_upload_df.empty:
        # Merge gold_layer_df and data_upload_df on 'panel' and 'date'
        merged_df = pd.merge(
            gold_layer_df,
            data_upload_df,
            on=["panel", "date"],
            how="outer",
            suffixes=("_gold", "_upload"),
        )
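        # Example (hypothetical name): if both frames contain "spends_facebook", the outer merge
        # yields "spends_facebook_gold" and "spends_facebook_upload"; the loop below collapses them
        # back into "spends_facebook", preferring the uploaded values where both are present.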

        # Handle duplicate columns
        for col in merged_df.columns:
            if col.endswith("_gold"):
                base_col = col[:-5]  # Remove '_gold' suffix
                upload_col = base_col + "_upload"  # Column name in data_upload_df
                if upload_col in merged_df.columns:
                    # Prefer values from data_upload_df
                    merged_df[base_col] = merged_df[upload_col].combine_first(
                        merged_df[col]
                    )
                    merged_df.drop(columns=[col, upload_col], inplace=True)
                else:
                    # Rename column to remove the suffix
                    merged_df.rename(columns={col: base_col}, inplace=True)

    elif data_upload_df.empty:
        merged_df = gold_layer_df.copy()

    elif gold_layer_df.empty:
        merged_df = data_upload_df.copy()

    return merged_df


# Function to check if all required columns are present in the Uploaded DataFrame
@st.cache_data(show_spinner=False)
def check_required_columns(df, detected_channels, detected_response_metric):
    required_columns = []

    # Require a 'spends_' column for every detected channel
    for channel in detected_channels:
        required_columns.append(f"spends_{channel}")

    # Require a 'response_metric_<metric>_<channel>' column for every metric/channel pair, plus '_total'
    for response_metric in detected_response_metric:
        for channel in detected_channels + ["total"]:
            required_columns.append(f"response_metric_{response_metric}_{channel}")

    # Check for missing columns
    missing_columns = [col for col in required_columns if col not in df.columns]
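    # Example (hypothetical names): detected_channels ["facebook", "google"] and
    # detected_response_metric ["revenue"] require spends_facebook, spends_google,
    # response_metric_revenue_facebook, response_metric_revenue_google, and
    # response_metric_revenue_total.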

    # Channel groupings
    no_media_data = []
    channel_columns_dict = {}
    for channel in detected_channels:
        channel_columns = [
            col
            for col in df.columns
            if channel in col
            and not (
                col.startswith("response_metric_")
                or col.startswith("exogenous_")
                or col.startswith("internal_")
            )
            and col.endswith(channel)
        ]
        channel_columns_dict[channel] = channel_columns

        if len(channel_columns) <= 1:
            no_media_data.append(channel)

    return missing_columns, no_media_data, channel_columns_dict


# Function to prepare tool DataFrame
def prepare_tool_df(merged_df, granularity_selection):
    # Drop all response metric columns that do not end with '_total'
    cols_to_drop = [
        col
        for col in merged_df.columns
        if col.startswith("response_metric_") and not col.endswith("_total")
    ]

    # Create a DataFrame to be used for the tool
    tool_df = merged_df.drop(columns=cols_to_drop)

    # Convert to weekly granularity by aggregating all data for given panel and week
    if granularity_selection.lower() == "weekly":
        tool_df.set_index("date", inplace=True)
        tool_df = (
            tool_df.groupby(
                [pd.Grouper(freq="W-MON", closed="left", label="left"), "panel"]
            )
            .sum()
            .reset_index()
        )
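        # Example: daily rows for Monday 2024-08-05 through Sunday 2024-08-11 are summed into a
        # single weekly row labelled 2024-08-05, matching the Monday-start convention described
        # in the upload guidelines.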

    return tool_df


# Function to generate imputation DataFrame
def generate_imputation_df(tool_df):
    # Initialize lists to store the column details
    column_names = []
    categories = []
    missing_values_info = []
    zero_values_info = []
    imputation_methods = []

    # Helper to calculate the count and percentage of missing values
    def calculate_missing_percentage(series):
        return series.isnull().sum(), (series.isnull().mean() * 100)

    # Helper to calculate the count and percentage of zero values
    def calculate_zero_percentage(series):
        return (series == 0).sum(), ((series == 0).mean() * 100)

    # Iterate over each column to categorize and calculate missing and zero values
    for col in tool_df.columns:
        # Determine category based on column name prefix
        if col == "date" or col == "panel":
            continue
        elif col.startswith("response_metric_"):
            categories.append("Response Metrics")
        elif col.startswith("spends_"):
            categories.append("Spends")
        elif col.startswith("exogenous_"):
            categories.append("Exogenous")
        elif col.startswith("internal_"):
            categories.append("Internal")
        else:
            categories.append("Media")

        # Calculate missing values and percentage
        missing_count, missing_percentage = calculate_missing_percentage(tool_df[col])
        missing_values_info.append(f"{missing_count} ({missing_percentage:.1f}%)")

        # Calculate zero values and percentage
        zero_count, zero_percentage = calculate_zero_percentage(tool_df[col])
        zero_values_info.append(f"{zero_count} ({zero_percentage:.1f}%)")

        # Determine default imputation method based on conditions
        if col.startswith("spends_"):
            imputation_methods.append("Fill with 0")
        elif col.startswith("response_metric_"):
            imputation_methods.append("Fill with Mean")
        elif zero_percentage + missing_percentage > percent_drop_col_threshold:
            imputation_methods.append("Drop Column")
        else:
            imputation_methods.append("Fill with Mean")

        column_names.append(col)

    # Create the DataFrame
    imputation_df = pd.DataFrame(
        {
            "Column Name": column_names,
            "Category": categories,
            "Missing Values": missing_values_info,
            "Zero Values": zero_values_info,
            "Imputation Method": imputation_methods,
        }
    )

    # Define the category order for sorting
    category_order = {
        "Response Metrics": 1,
        "Spends": 2,
        "Media": 3,
        "Exogenous": 4,
        "Internal": 5,
    }

    # Add a temporary column for sorting based on the category order
    imputation_df["Category Order"] = imputation_df["Category"].map(category_order)

    # Sort the DataFrame based on the category order and then drop the temporary column
    imputation_df = imputation_df.sort_values(
        by=["Category Order", "Column Name"]
    ).drop(columns=["Category Order"])

    return imputation_df


# Function to perform imputation as per user requests
def perform_imputation(imputation_df, tool_df):
    # Detect channels associated with spends
    detected_channels = [
        col.replace("spends_", "")
        for col in tool_df.columns
        if col.startswith("spends_")
    ]

    # Create a dictionary with keys as channels and values as associated columns
    group_dict = {
        channel: [
            col
            for col in tool_df.columns
            if channel in col
            and not (
                col.startswith("response_metric_")
                or col.startswith("exogenous_")
                or col.startswith("internal_")
            )
        ]
        for channel in detected_channels
    }
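    # Example (hypothetical names): channel "facebook" groups "spends_facebook" together with media
    # columns such as "media_clicks_facebook"; response metric, exogenous, and internal columns are
    # excluded from the grouping.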

    # Create a reverse dictionary with keys as columns and values as channels
    column_to_channel_dict = {
        col: channel for channel, cols in group_dict.items() for col in cols
    }

    # Perform imputation
    already_dropped = []
    for index, row in imputation_df.iterrows():
        col_name = row["Column Name"]
        impute_method = row["Imputation Method"]

        # Skip already dropped columns
        if col_name in already_dropped:
            continue

        # Disallow dropping response metric columns; abort with an error message
        if impute_method == "Drop Column" and col_name.startswith("response_metric_"):
            return None, {}, f"Cannot drop response metric column: {col_name}"

        # Drop column if requested
        if impute_method == "Drop Column":
            # If spends column is dropped, drop all related columns
            if col_name.startswith("spends_"):
                tool_df.drop(
                    columns=group_dict[col_name.replace("spends_", "")],
                    inplace=True,
                )
                already_dropped += group_dict[col_name.replace("spends_", "")]
                del group_dict[col_name.replace("spends_", "")]
            else:
                tool_df.drop(columns=[col_name], inplace=True)
                if not (
                    col_name.startswith("exogenous_")
                    or col_name.startswith("internal_")
                ):
                    group_name = column_to_channel_dict[col_name]
                    group_dict[group_name].remove(col_name)

                    # Abort if the channel is left with one or fewer associated columns
                    if len(group_dict[group_name]) <= 1:
                        return (
                            None,
                            {},
                            f"No media variable associated with category {col_name.replace('spends_', '')}.",
                        )
            continue

        # Ensure the column is not entirely empty for any panel before mean/median imputation
        for panel in tool_df["panel"].unique():
            panel_df = tool_df[tool_df["panel"] == panel]

            # Check if the column is entirely null or empty for the current panel
            if panel_df[col_name].isnull().all():
                if impute_method in ["Fill with Mean", "Fill with Median"]:
                    return (
                        None,
                        {},
                        f"Cannot impute for empty column(s) with mean or median. Select 'Fill with 0'. Details: Panel: {panel}, Column: {col_name}",
                    )

        # Fill missing values as requested
        if impute_method == "Fill with Mean":
            tool_df[col_name] = tool_df.groupby("panel")[col_name].transform(
                lambda x: x.fillna(x.mean())
            )
        elif impute_method == "Fill with Median":
            tool_df[col_name] = tool_df.groupby("panel")[col_name].transform(
                lambda x: x.fillna(x.median())
            )
        elif impute_method == "Fill with 0":
            tool_df[col_name] = tool_df[col_name].fillna(0)

    # Check if final DataFrame has at least one response metric and two spends categories
    response_metrics = [
        col for col in tool_df.columns if col.startswith("response_metric_")
    ]
    spends_categories = [col for col in tool_df.columns if col.startswith("spends_")]

    if len(response_metrics) < 1:
        return (None, {}, "The final DataFrame must have at least one response metric.")
    if len(spends_categories) < 2:
        return (
            None,
            {},
            "The final DataFrame must have at least two spends categories.",
        )

    return tool_df, group_dict, "Imputed Successfully!"


# Function to display groups with custom styling
def display_groups(input_dict):
    # Define custom CSS for a pastel rounded rectangle
    custom_css = """
    <style>
    .group-box {
        background-color: #ffdaab;
        border-radius: 10px;
        padding: 10px;
        margin: 5px 0;
    }
    </style>
    """
    st.markdown(custom_css, unsafe_allow_html=True)

    for group_name, values in input_dict.items():
        group_html = f"<div class='group-box'><strong>{group_name}:</strong> {format_values_for_display(values)}</div>"
        st.markdown(group_html, unsafe_allow_html=True)


# Function to categorize columns and create an ordered dictionary
def create_ordered_category_dict(df):
    category_dict = {
        "Response Metrics": [],
        "Spends": [],
        "Media": [],
        "Exogenous": [],
        "Internal": [],
    }

    # Define the category order for sorting
    category_order = {
        "Response Metrics": 1,
        "Spends": 2,
        "Media": 3,
        "Exogenous": 4,
        "Internal": 5,
    }

    for column in df.columns:
        if column == "date" or column == "panel":
            continue  # Skip 'date' and 'panel' columns

        if column.startswith("response_metric_"):
            category_dict["Response Metrics"].append(column)
        elif column.startswith("spends_"):
            category_dict["Spends"].append(column)
        elif column.startswith("exogenous_"):
            category_dict["Exogenous"].append(column)
        elif column.startswith("internal_"):
            category_dict["Internal"].append(column)
        else:
            category_dict["Media"].append(column)

    # Sort the dictionary based on the defined category order
    sorted_category_dict = OrderedDict(
        sorted(category_dict.items(), key=lambda item: category_order[item[0]])
    )

    return sorted_category_dict


try:
    # Page Title
    st.title("Data Import")

    # Create file uploader
    uploaded_file = st.file_uploader(
        "Upload Data", type=["xlsx"], accept_multiple_files=False
    )

    # Expander with markdown for upload rules
    with st.expander("Upload Rules and Guidelines"):
        st.markdown(
            """
        ### Upload Guidelines
        
        Please ensure your data adheres to the following rules:
        
        1. **File Format**: 
        - Upload all data in a single Excel file.
        
        2. **Compulsory Columns**: 
        - **Date**: Must be in the format `YYYY-MM-DD` only.
        - **Panel**: If no panel data exists, use `aggregated` as a single panel.
        
        3. **Column Naming Conventions**: 
        - All columns should start with the associated category prefix.
        
        **Examples**:

        - **Response Metric Column**: 
        - Format: `response_metric_<response_metric_name>_<channel_name>`
        - Example: `response_metric_revenue_facebook`
        
        - **Total Response Metric**: 
        - Format: `response_metric_<response_metric_name>_total`
        - Example: `response_metric_revenue_total`

        - **Spend Column**: 
        - Format: `spends_<channel_name>`
        - Example: `spends_facebook`

        - **Media Column**: 
        - Format: `media_<media_variable_name>_<channel_name>`
        - Example: `media_clicks_facebook`
        
        - **Exogenous Column**: 
        - Format: `exogenous_<variable_name>`
        - Example: `exogenous_unemployment_rate`
        
        - **Internal Column**: 
        - Format: `internal_<variable_name>`
        - Example: `internal_discount`
        
        **Notes**:
        
        - The `total` response metric should represent the total for a particular date and panel, including all channels and organic contributions.
        - The `date` column for weekly data should be the Monday of that week, representing the data from that Monday to the following Sunday. Example: If the week starts on Monday, August 5th, 2024, and ends on Sunday, August 11th, 2024, the date column for that week should display 2024-08-05.
        """
        )
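    # Illustrative sketch (hypothetical column names, not executed): a minimal upload consistent
    # with the naming rules above would look like the frame below -- two spend channels, one media
    # column per channel, and a total response metric. A real upload also needs enough rows to
    # satisfy the minimum history requirement.
    #
    #   pd.DataFrame(
    #       {
    #           "date": ["2024-08-05", "2024-08-12"],
    #           "panel": ["aggregated", "aggregated"],
    #           "spends_facebook": [100.0, 120.0],
    #           "media_clicks_facebook": [1000, 1100],
    #           "spends_google": [80.0, 95.0],
    #           "media_impressions_google": [5000, 5200],
    #           "response_metric_revenue_total": [2500.0, 2700.0],
    #       }
    #   )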

    # Upload warning placeholder
    upload_warning_placeholder = st.container()

    # Load the uploaded file into a DataFrame if a file is uploaded
    data_upload_df = load_and_transform_data(uploaded_file)

    # Columns for user input
    granularity_col, validate_process_col = st.columns(2)

    # Dropdown for data granularity
    granularity_selection = granularity_col.selectbox(
        "Select data granularity",
        options=["daily", "weekly"],
        format_func=name_format_func,
        key="granularity_selection_key",
    )

    # Gold Layer DataFrame
    gold_layer_df = st.session_state["project_dct"]["data_import"]["gold_layer_df"]
    if not gold_layer_df.empty:
        st.subheader("Gold Layer DataFrame")
        with st.expander("Gold Layer DataFrame"):
            st.dataframe(
                gold_layer_df,
                hide_index=True,
                column_config={
                    "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
                },
            )
    else:
        st.info(
            "No gold layer data is selected for this project. Please upload data manually.",
            icon="📊",
        )

    # Check input data
    with validate_process_col:
        st.write("##")  # Padding

    if validate_process_col.button("Validate and Process", use_container_width=True):
        with st.spinner("Processing ..."):
            # Check if both DataFrames are empty
            valid_input = True
            if gold_layer_df.empty and data_upload_df.empty:
                # If both gold_layer_df and data_upload_df are empty, display a warning and stop the script
                st.warning(
                    "Both the Gold Layer data and the uploaded data are empty. Please provide at least one data source.",
                    icon="⚠️",
                )

                # Log message
                log_message(
                    "warning",
                    "Both the Gold Layer data and the uploaded data are empty. Please provide at least one data source.",
                    "Data Import",
                )
                valid_input = False

            # If the uploaded DataFrame is empty and the Gold Layer is not, swap them to ensure all validation conditions are checked
            elif not gold_layer_df.empty and data_upload_df.empty:
                data_upload_df, gold_layer_df = (
                    gold_layer_df.copy(),
                    data_upload_df.copy(),
                )
                valid_input = True

            if valid_input:
                # Fetch all necessary columns list
                (
                    spends_columns,
                    response_metric_columns,
                    total_columns,
                    gold_layer_columns,
                    data_upload_columns,
                ) = fetch_columns(gold_layer_df, data_upload_df)

                with upload_warning_placeholder:
                    valid_input, message = valid_input_df(
                        data_upload_df,
                        gold_layer_df,
                        spends_columns,
                        response_metric_columns,
                        total_columns,
                        gold_layer_columns,
                        data_upload_columns,
                    )
                    if not valid_input:
                        st.warning(message, icon="⚠️")

                        # Log message
                        log_message("warning", message, "Data Import")

            # Merge gold_layer_df and data_upload_df on 'panel' and 'date'
            if valid_input:
                merged_df = merge_dataframes(gold_layer_df, data_upload_df)

                missing_columns, no_media_data, channel_columns_dict = (
                    check_required_columns(
                        merged_df, spends_columns, response_metric_columns
                    )
                )

                with upload_warning_placeholder:
                    # Warning for categories with no media data
                    if no_media_data:
                        st.warning(
                            f"Categories without media data: {format_values_for_display(no_media_data)}. Please upload at least one media column to proceed.",
                            icon="⚠️",
                        )
                        valid_input = False

                        # Log message
                        log_message(
                            "warning",
                            f"Categories without media data: {format_values_for_display(no_media_data)}. Please upload at least one media column to proceed.",
                            "Data Import",
                        )

                    # Warning for insufficient rows
                    elif any(
                        granularity_selection == "daily"
                        and len(merged_df[merged_df["panel"] == panel])
                        < minimum_row_req
                        for panel in merged_df["panel"].unique()
                    ):
                        st.warning(
                            f"Insufficient data. Please provide at least {minimum_row_req} days of data for all panel.",
                            icon="⚠️",
                        )
                        valid_input = False

                        # Log message
                        log_message(
                            "warning",
                            f"Insufficient data. Please provide at least {minimum_row_req} days of data for all panel.",
                            "Data Import",
                        )

                    elif any(
                        granularity_selection == "weekly"
                        and len(merged_df[merged_df["panel"] == panel])
                        < minimum_row_req * 7
                        for panel in merged_df["panel"].unique()
                    ):
                        st.warning(
                            f"Insufficient data. Please provide at least {minimum_row_req} weeks of data for all panel.",
                            icon="⚠️",
                        )
                        valid_input = False

                        # Log message
                        log_message(
                            "warning",
                            f"Insufficient data. Please provide at least {minimum_row_req} weeks of data for all panel.",
                            "Data Import",
                        )

                    # Info for missing columns
                    elif missing_columns:
                        st.info(
                            f"Missing columns: {format_values_for_display(missing_columns)}. Please upload all required columns.",
                            icon="💡",
                        )

            if valid_input:
                # Create a copy of the merged DataFrame for dashboard purposes
                dashboard_df = merged_df.copy()

                # Create a DataFrame for tool purposes
                tool_df = prepare_tool_df(merged_df, granularity_selection)

                # Create Imputation DataFrame
                imputation_df = generate_imputation_df(tool_df)

                # Save data to project dictionary
                st.session_state["project_dct"]["data_import"][
                    "granularity_selection"
                ] = st.session_state["granularity_selection_key"]
                st.session_state["project_dct"]["data_import"][
                    "dashboard_df"
                ] = dashboard_df
                st.session_state["project_dct"]["data_import"]["tool_df"] = tool_df
                st.session_state["project_dct"]["data_import"]["unique_panels"] = (
                    tool_df["panel"].unique()
                )
                st.session_state["project_dct"]["data_import"][
                    "imputation_df"
                ] = imputation_df

                # Success message
                with upload_warning_placeholder:
                    st.success("Processed Successfully!", icon="🗂️")
                    st.toast("Processed Successfully!", icon="🗂️")

                # Log message
                log_message("info", "Processed Successfully!", "Data Import")

    # Load saved data from project dictionary
    if st.session_state["project_dct"]["data_import"]["tool_df"] is None:
        st.stop()
    else:
        tool_df = st.session_state["project_dct"]["data_import"]["tool_df"]
        imputation_df = st.session_state["project_dct"]["data_import"]["imputation_df"]
        unique_panels = st.session_state["project_dct"]["data_import"]["unique_panels"]

    # Unique Panels
    st.subheader("Unique Panels")

    # Get unique panels count
    total_count = len(unique_panels)

    # Define custom CSS for a pastel rounded rectangle
    custom_css = """
    <style>
    .panel-box {
        background-color: #ffdaab;
        border-radius: 10px;
        padding: 10px;
        margin: 0 0;
    }
    </style>
    """

    # Display unique panels with total count
    st.markdown(custom_css, unsafe_allow_html=True)
    panel_html = f"<div class='panel-box'><strong>Unique Panels:</strong> {format_values_for_display(unique_panels)}<br><strong>Total Count:</strong> {total_count}</div>"
    st.markdown(panel_html, unsafe_allow_html=True)
    st.write("##")  # Padding

    # Impute Missing Values
    st.subheader("Impute Missing Values")
    edited_imputation_df = st.data_editor(
        imputation_df,
        column_config={
            "Imputation Method": st.column_config.SelectboxColumn(
                options=[
                    "Drop Column",
                    "Fill with Mean",
                    "Fill with Median",
                    "Fill with 0",
                ],
                required=True,
                default="Fill with 0",
            ),
        },
        column_order=[
            "Column Name",
            "Category",
            "Missing Values",
            "Zero Values",
            "Imputation Method",
        ],
        disabled=["Column Name", "Category", "Missing Values", "Zero Values"],
        hide_index=True,
        use_container_width=True,
        key="imputation_df_key",
    )

    # Expander with markdown for imputation rules
    with st.expander("Impute Missing Values Guidelines"):
        st.markdown(
            f"""
        ### Imputation Guidelines
        
        Please adhere to the following rules when handling missing values:

        1. **Default Imputation Strategies**:
        - **Response Metrics**: Imputed using the **mean** value of the column.
        - **Spends**: Imputed with **zero** values.
        - **Media, Exogenous, Internal**: Imputation strategy is **dynamic** based on the data.

        2. **Drop Threshold**:
        - If the combined percentage of **zeros** and **null values** in any column exceeds `{percent_drop_col_threshold}%`, the column will be **marked for dropping** by default; the user can change this manually.
        - **Example**: If `spends_facebook` has more than `{percent_drop_col_threshold}%` of zeros and nulls combined, it will be marked for dropping.

        3. **Category Generation and Association**:
        - Categories are automatically generated from the **Spends** columns. 
        - **Example**: The column `spends_facebook` will generate the **facebook** category. This means columns like `spends_facebook`, `media_impression_facebook` and `media_clicks_facebook` will also be associated with this category.

        4. **Column Association and Imputation**:
        - Each category must have at least **one Media column** associated with it for imputation to proceed.
        - **Example**: If the **facebook** category does not have any media columns like `media_impression_facebook`, imputation will not be allowed for that category.
        - Solution: Either **drop the entire category** if it is empty, or **impute the columns** associated with the category instead of dropping them.

        5. **Response Metrics and Category Count**:
        - Dropping **Response Metric** columns is **not allowed** under any circumstances.
        - At least **two categories** must remain after imputation, or imputation will not proceed.
        - **Example**: If only **facebook** remains after selection, imputation will be halted.

        **Notes**:

        - The decision to drop a spends column will result in all associated columns being dropped. 
        - **Example**: Dropping `spends_facebook` will also drop all related columns like `media_impression_facebook` and `media_clicks_facebook`.
        """
        )
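    # Example: if percent_drop_col_threshold were 40, a media column with 25% nulls and 20% zeros
    # (45% combined) would default to "Drop Column", while spends and response metric columns keep
    # their fixed defaults ("Fill with 0" and "Fill with Mean" respectively).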

    # Imputation Warning Placeholder
    imputation_warning_placeholder = st.container()

    # Save the DataFrame and dictionary from the current session
    if st.button("Impute and Save", use_container_width=True):
        with st.spinner("Imputing ..."):
            with imputation_warning_placeholder:
                # Perform Imputation
                imputed_tool_df, group_dict, message = perform_imputation(
                    edited_imputation_df.copy(), tool_df.copy()
                )

                if imputed_tool_df is None:
                    st.warning(message, icon="⚠️")

                    # Log message
                    log_message("warning", message, "Data Import")

                else:
                    st.session_state["project_dct"]["data_import"][
                        "imputed_tool_df"
                    ] = imputed_tool_df
                    st.session_state["project_dct"]["data_import"][
                        "imputation_df"
                    ] = edited_imputation_df
                    st.session_state["project_dct"]["data_import"][
                        "group_dict"
                    ] = group_dict
                    st.session_state["project_dct"]["data_import"]["category_dict"] = (
                        create_ordered_category_dict(imputed_tool_df)
                    )

            if imputed_tool_df is not None:
                # Update DB
                update_db(
                    prj_id=st.session_state["project_number"],
                    page_nam="Data Import",
                    file_nam="project_dct",
                    pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                    schema=schema,
                )

                # Success message
                st.success("Saved Successfully!", icon="💾")
                st.toast("Saved Successfully!", icon="💾")

                # Log message
                log_message("info", "Saved Successfully!", "Data Import")

    # Load saved data from project dictionary
    if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
        st.stop()
    else:
        imputed_tool_df = st.session_state["project_dct"]["data_import"][
            "imputed_tool_df"
        ]
        group_dict = st.session_state["project_dct"]["data_import"]["group_dict"]
        category_dict = st.session_state["project_dct"]["data_import"]["category_dict"]

    # Channel Groupings
    st.subheader("Channel Groupings")
    display_groups(group_dict)
    st.write("##")  # Padding

    # Variable Categorization
    st.subheader("Variable Categorization")
    display_groups(category_dict)
    st.write("##")  # Padding

    # Final DataFrame
    st.subheader("Final DataFrame")
    st.dataframe(
        imputed_tool_df,
        hide_index=True,
        column_config={
            "date": st.column_config.DateColumn("date", format="YYYY-MM-DD")
        },
    )
    st.write("##")  # Padding

except Exception as e:
    # Capture the error details
    exc_type, exc_value, exc_traceback = sys.exc_info()
    error_message = "".join(
        traceback.format_exception(exc_type, exc_value, exc_traceback)
    )

    # Log message
    log_message("error", f"An error occurred: {error_message}.", "Data Import")

    # Display a warning message
    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )