import streamlit as st
import plotly.express as px
import numpy as np
import plotly.graph_objects as go
from sklearn.metrics import r2_score
from collections import OrderedDict
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import streamlit as st
import re
from matplotlib.colors import ListedColormap
# from st_aggrid import AgGrid, GridOptionsBuilder
# from src.agstyler import PINLEFT, PRECISION_TWO, draw_grid


def format_numbers(x):
    """
    Render a number with one decimal place, comma grouping and a magnitude suffix.

    Values whose magnitude is at least one million are shown in millions with an
    'M' suffix, those of at least one thousand in thousands with a 'K' suffix,
    and anything smaller is shown unscaled without a suffix.

    Parameters:
    x (number): The value to format.

    Returns:
    str: The formatted value, e.g. '1.2M', '-2.5K' or '999.9'.
    """
    magnitude = abs(x)
    # Largest scale first so a value >= 1e6 is never reported in thousands.
    for threshold, suffix in ((1e6, 'M'), (1e3, 'K')):
        if magnitude >= threshold:
            return f'{x / threshold:,.1f}{suffix}'
    return f'{x:,.1f}'

    

def line_plot(data, x_col, y1_cols, y2_cols, title):
    """
    Create a line plot with two sets of y-axis data.

    Parameters:
    data (DataFrame): The data containing the columns to be plotted.
    x_col (str): The column name for the x-axis.
    y1_cols (list): Column names plotted against the primary (left) y-axis.
    y2_cols (list or None): Column names plotted against the secondary (right)
        y-axis. ``None`` is treated as an empty list.
    title (str): The title of the plot; no title is set when it is falsy.

    Returns:
    fig (Figure): The Plotly figure object with the line plot.
    """
    # Explicit None check (not `or []`) so array-like inputs such as a pandas
    # Index keep working; their truth value is ambiguous.
    if y2_cols is None:
        y2_cols = []

    fig = go.Figure()

    # Every primary-axis series is drawn in the same colour.
    for y1_col in y1_cols:
        fig.add_trace(go.Scatter(x=data[x_col], y=data[y1_col], mode='lines',
                                 name=y1_col, line=dict(color='#11B6BD')))

    # Secondary-axis series are bound to 'y2' and share a second colour.
    for y2_col in y2_cols:
        fig.add_trace(go.Scatter(x=data[x_col], y=data[y2_col], mode='lines',
                                 name=y2_col, yaxis='y2', line=dict(color='#739FAE')))

    # 'y2' is overlaid on the primary axis and placed on the right. It gets the
    # same configuration whether or not any secondary series were added.
    fig.update_layout(yaxis=dict(), yaxis2=dict(overlaying='y', side='right'))

    if title:
        fig.update_layout(title=title)

    # Hide grid lines and lay the legend out horizontally, centred above the plot.
    fig.update_xaxes(showgrid=False)
    fig.update_yaxes(showgrid=False)
    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="top",
        y=1.1,
        xanchor="center",
        x=0.5
    ))

    return fig



def line_plot_target(df, target, title):
    """
    Create a line plot of a target column together with a linear trendline.

    Dashed vertical lines mark January 1st of every year in the data except the
    earliest one.

    Parameters:
    df (DataFrame): The data; must contain a datetime 'date' column and ``target``.
    target (str): The column name for the y-axis.
    title (str): The title of the plot.

    Returns:
    fig (Figure): The Plotly figure object with the line plot and trendline.
    """
    dates = df['date']
    values = df[target]

    # np.polyfit needs numbers, so use the dates' int64 epoch representation.
    # ``Series.view`` is deprecated since pandas 2.2; ``astype('int64')`` gives
    # the same integers. Computed once and reused for fitting and evaluation.
    date_ints = dates.astype('int64')

    # Fit only on observed target values; a NaN would otherwise poison the
    # least-squares fit.
    observed = values.notna()
    trendline = np.poly1d(np.polyfit(date_ints[observed], values[observed], 1))

    fig = go.Figure()

    # Target series.
    fig.add_trace(go.Scatter(x=dates, y=values, mode='lines', name=target,
                             line=dict(color='#11B6BD')))

    # Trendline evaluated at every date in the frame.
    fig.add_trace(go.Scatter(x=dates, y=trendline(date_ints), mode='lines',
                             name='Trendline', line=dict(color='#739FAE')))

    fig.update_layout(
        title=title,
        xaxis=dict(type='date')
    )

    # Year boundaries: sort so the skipped year is the earliest one even when
    # rows are not in chronological order.
    for year in sorted(dates.dt.year.dropna().unique())[1:]:
        january_1 = pd.Timestamp(year=int(year), month=1, day=1)
        fig.add_shape(
            go.layout.Shape(
                type="line",
                x0=january_1,
                x1=january_1,
                # yref="paper": span the full plot height regardless of data range.
                y0=0,
                y1=1,
                xref="x",
                yref="paper",
                line=dict(color="grey", width=1.5, dash="dash"),
            )
        )

    # Horizontal legend centred above the plot.
    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="top",
        y=1.1,
        xanchor="center",
        x=0.5
    ))

    return fig


def correlation_plot(df, selected_features, target):
    """
    Create a lower-triangle correlation heatmap for selected features and a target.

    Parameters:
    df (DataFrame): The data containing the columns to be plotted.
    selected_features (list or str): Column name(s) to include in the plot.
    target (str): The target column name to include in the plot.

    Returns:
    fig (Figure): The Matplotlib figure object with the correlation heatmap plot.
    """
    if isinstance(selected_features, str):
        selected_features = [selected_features]

    # Keep first occurrences only, so the target is not duplicated in the
    # matrix when it is also one of the selected features.
    columns = list(dict.fromkeys([*selected_features, target]))
    corr = df[columns].corr()

    fig, ax = plt.subplots(figsize=(16, 12))

    # Hide the upper triangle including the diagonal. The mask is built from the
    # matrix shape, not its values: a value-based mask would leave cells with a
    # correlation of exactly 0 unmasked.
    mask = np.triu(np.ones(corr.shape, dtype=bool))
    sns.heatmap(corr, annot=True, cmap='Blues', fmt=".2f", linewidths=0.5,
                mask=mask, ax=ax)

    plt.setp(ax.get_xticklabels(), rotation=45)
    plt.setp(ax.get_yticklabels(), rotation=0)

    return fig


def _yearly_totals(data, columns):
    """Return ``data[columns]`` summed per calendar year of ``data['date']`` (index named 'Year')."""
    # Explicit copy: adding 'Year' to a column slice of ``data`` would otherwise
    # operate on a possible view and raise SettingWithCopyWarning.
    frame = data[columns].copy()
    frame['Year'] = data['date'].dt.year
    return frame.groupby('Year')[columns].sum()


def _spend_like_columns(columns):
    """Return the names in ``columns`` containing 'spends' or 'cost'."""
    return [col for col in columns if any(keyword in col for keyword in ('spends', 'cost'))]


def _dollar_format(sum_df, columns):
    """Prefix every (already formatted) value in ``columns`` of ``sum_df`` with '$'; returns ``sum_df``."""
    for col in columns:
        sum_df[col] = sum_df[col].map(lambda x: f'${x}')
    return sum_df


def summary(data, selected_feature, spends, Target=None):
    """
    Create a per-year summary table of selected features.

    Parameters:
    data (DataFrame): Source data; must contain a datetime 'date' column.
    selected_feature (list): Column names to summarise. The list is not modified.
    spends (str): Spend column. Used only when ``Target`` is falsy: it is added
        to the summarised columns and is the numerator of 'CPM/CPC'.
    Target (str, optional): When truthy, only ``selected_feature`` is summarised
        and a 'Total' row is appended (the value itself is otherwise unused).

    Returns:
    sum_df (DataFrame): Year-indexed table of formatted strings. With ``Target``
        the last row is 'Total'; without it the last row is 'Grand Total', zeros
        and missing values are shown as '-', and a 'CPM/CPC' column is added
        when a non-spend feature is present.
    """
    if Target:
        sum_df = _yearly_totals(data, selected_feature)
        sum_df.loc['Total'] = sum_df.sum(numeric_only=True)
        sum_df = sum_df.applymap(format_numbers)
        return _dollar_format(sum_df, _spend_like_columns(sum_df.columns))

    # Build a new, order-preserving, de-duplicated list instead of appending to
    # the caller's list (and instead of set(), whose order is not deterministic).
    features = list(dict.fromkeys([*selected_feature, spends]))

    sum_df = _yearly_totals(data, features)
    sum_df.loc['Grand Total'] = sum_df.sum()

    if len(features) > 1:
        # ``spends`` is always the numerator; the first other feature is the denominator.
        imp_clicks = next(col for col in features if col != spends)
        # Computed after the Grand Total row exists, so that row holds the ratio
        # of totals rather than a sum of yearly ratios.
        # NOTE(review): the x1000 factor suits CPM; confirm it is intended when
        # the denominator is clicks (CPC).
        ratio = sum_df[spends] / sum_df[imp_clicks] * 1000
        # Division by zero yields +/-inf; treat it as missing so it renders as '-'.
        sum_df['CPM/CPC'] = ratio.replace([np.inf, -np.inf], np.nan)
        dollar_cols = [spends]
    else:
        dollar_cols = _spend_like_columns(sum_df.columns)

    # format_numbers turns NaN into 'nan', so a single replace covers both cases.
    sum_df = sum_df.applymap(format_numbers).replace({"0.0": '-', 'nan': '-'})
    return _dollar_format(sum_df, dollar_cols)



def sanitize_key(key, prefix=""):
    """
    Return ``prefix`` followed by ``key`` stripped of every character outside [A-Za-z0-9].

    Spaces, punctuation and non-ASCII letters are all removed from ``key``;
    ``prefix`` is inserted verbatim and is not sanitised.
    """
    non_alnum = re.compile(r'[^a-zA-Z0-9]')
    cleaned = non_alnum.sub('', key)
    return "{}{}".format(prefix, cleaned)