diff --git "a/pages/9_Scenario_Planner.py" "b/pages/9_Scenario_Planner.py" new file mode 100644--- /dev/null +++ "b/pages/9_Scenario_Planner.py" @@ -0,0 +1,2977 @@ +# Importing necessary libraries +import streamlit as st + +st.set_page_config( + page_title="Scenario Planner", + page_icon="⚖️", + layout="wide", + initial_sidebar_state="collapsed", +) + +# Disable +/- for number input +st.markdown( + """ +""", + unsafe_allow_html=True, +) + +import re +import sys +import copy +import pickle +import traceback +import numpy as np +import pandas as pd +from scenario import numerize +import plotly.graph_objects as go +from post_gres_cred import db_cred +from scipy.optimize import minimize +from log_application import log_message +from utilities import project_selection, update_db, set_header, load_local_css +from utilities import ( + get_panels_names, + get_metrics_names, + name_formating, + load_rcs_metadata_files, + load_scenario_metadata_files, + generate_rcs_data, + generate_scenario_data, +) +from constants import ( + xtol_tolerance_per, + mroi_threshold, + word_length_limit_lower, + word_length_limit_upper, +) + + +schema = db_cred["schema"] +load_local_css("styles.css") +set_header() + +# Initialize project name session state +if "project_name" not in st.session_state: + st.session_state["project_name"] = None + +# Fetch project dictionary +if "project_dct" not in st.session_state: + project_selection() + st.stop() + +# Display Username and Project Name +if "username" in st.session_state and st.session_state["username"] is not None: + + cols1 = st.columns([2, 1]) + + with cols1[0]: + st.markdown(f"**Welcome {st.session_state['username']}**") + with cols1[1]: + st.markdown(f"**Current Project: {st.session_state['project_name']}**") + +# Initialize ROI threshold +if "roi_threshold" not in st.session_state: + st.session_state.roi_threshold = 1 + +# Initialize message display holder +if "message_display" not in st.session_state: + st.session_state.message_display = {"type": "success", "message": None, "icon": ""} + + +# Function to reset modified_scenario_data +def reset_scenario(metrics_selected=None, panel_selected=None): + # Clear message_display + st.session_state.message_display = {"type": "success", "message": None, "icon": ""} + + # Use default values from session state if not provided + if metrics_selected is None: + metrics_selected = st.session_state["response_metrics_selectbox_sp"] + if panel_selected is None: + panel_selected = st.session_state["panel_selected_selectbox_sp"] + + # Load original scenario data + original_data = st.session_state["project_dct"]["scenario_planner"][ + "original_metadata_file" + ] + original_scenario_data = original_data[metrics_selected][panel_selected] + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Update the specific section with the original scenario data + data[metrics_selected][panel_selected] = copy.deepcopy(original_scenario_data) + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + +# Function to build s curve +def s_curve(x, power, K, b, a, x0): + return K / (1 + b * np.exp(-a * ((x / 10**power) - x0))) + + +# Function to retrieve S-curve parameters for a given metric, panel, and channel +def get_s_curve_params( + metrics_selected, + panel_selected, + channel_selected, + original_rcs_data, + modified_rcs_data, +): + # Retrieve 'power' parameter from the original data for the specific metric, panel, and channel + power = original_rcs_data[metrics_selected][panel_selected][channel_selected][ + "power" + ] + + # Get the S-curve parameters from the modified data for the same metric, panel, and channel + s_curve_param = modified_rcs_data[metrics_selected][panel_selected][ + channel_selected + ] + + # Load modified scenario metadata + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Update modified S-curve parameters + data[metrics_selected][panel_selected]["channels"][channel_selected][ + "response_curve_params" + ] = s_curve_param + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + # Update the 'power' parameter in the modified S-curve parameters with the original 'power' value + s_curve_param["power"] = power + + # Return the updated S-curve parameters + return s_curve_param + + +# Function to calculate total contribution +def get_total_contribution( + spends, channels, s_curve_params, channels_proportion, modified_scenario_data +): + total_contribution = 0 + for i in range(len(channels)): + channel_name = channels[i] + channel_s_curve_params = s_curve_params[channel_name] + spend_proportion = spends[i] * channels_proportion[channel_name] + total_contribution += sum( + s_curve( + spend_proportion, + channel_s_curve_params["power"], + channel_s_curve_params["K"], + channel_s_curve_params["b"], + channel_s_curve_params["a"], + channel_s_curve_params["x0"], + ) + ) + sum( + modified_scenario_data["channels"][channel_name]["correction"] + ) # correction for s-curve + return total_contribution + sum(modified_scenario_data["constant"]) + + +# Function to calculate total spends +def get_total_spends(spends, channels_conversion_ratio): + return np.sum(spends * np.array(list(channels_conversion_ratio.values()))) + + +# Function to optimizes spends for all channels given bounds and a total spend target +def optimizer( + optimization_goal, + s_curve_params, + channels_spends, + channels_proportion, + channels_conversion_ratio, + total_target, + bounds_dict, + modified_scenario_data, +): + # Extract channel names and corresponding actual spends + channels = list(channels_spends.keys()) + actual_spends = np.array(list(channels_spends.values())) + num_channels = len(actual_spends) + + # Define the objective function based on the optimization goal + def objective_fun(spends): + if optimization_goal == "Spend": + # Minimize negative total contribution to maximize the total contribution + return -get_total_contribution( + spends, + channels, + s_curve_params, + channels_proportion, + modified_scenario_data, + ) + else: + # Minimize total spends + return get_total_spends(spends, channels_conversion_ratio) + + def constraint_fun(spends): + if optimization_goal == "Spend": + # Ensure the total spends equals the total spend target + return get_total_spends(spends, channels_conversion_ratio) + else: + # Ensure the total contribution equals the total contribution target + return get_total_contribution( + spends, + channels, + s_curve_params, + channels_proportion, + modified_scenario_data, + ) + + # Equality constraint + constraints = { + "type": "eq", + "fun": lambda spends: constraint_fun(spends) - total_target, + } # Sum of all channel spends/metrics should equal the total spend/metrics target + + # Bounds for each channel's spend based + bounds = [ + ( + actual_spends[i] * (1 + bounds_dict[channels[i]][0] / 100), + actual_spends[i] * (1 + bounds_dict[channels[i]][1] / 100), + ) + for i in range(num_channels) + ] + + # Initial guess for the optimization + initial_guess = np.array(actual_spends) + + # Calculate xtol as n% of the minimum of spends + xtol = max(10, (xtol_tolerance_per / 100) * np.min(actual_spends)) + + # Perform the optimization using 'trust-constr' method + result = minimize( + objective_fun, + initial_guess, + method="trust-constr", + constraints=constraints, + bounds=bounds, + options={ + "disp": True, # Display the optimization process + "xtol": xtol, # Dynamic step size tolerance + "maxiter": 1e5, # Maximum number of iterations + }, + ) + + # Extract the optimized spends from the result + optimized_spends_array = result.x + + # Convert optimized spends back to a dictionary with channel names + optimized_spends = { + channels[i]: max(0, optimized_spends_array[i]) for i in range(num_channels) + } + + return optimized_spends, result.success + + +# Function to calculate achievable targets at lower and upper spend bounds +@st.cache_data(show_spinner=False) +def max_target_achievable( + channels_spends, + s_curve_params, + channels_proportion, + modified_scenario_data, + bounds_dict, +): + # Extract channel names and corresponding actual spends + channels = list(channels_spends.keys()) + actual_spends = np.array(list(channels_spends.values())) + num_channels = len(actual_spends) + + # Bounds for each channel's spend + lower_spends, upper_spends = [], [] + for i in range(num_channels): + lower_spends.append(actual_spends[i] * (1 + bounds_dict[channels[i]][0] / 100)) + upper_spends.append(actual_spends[i] * (1 + bounds_dict[channels[i]][1] / 100)) + + # Calculate achievable targets at lower and upper spend bounds + lower_achievable_target = get_total_contribution( + lower_spends, + channels, + s_curve_params, + channels_proportion, + modified_scenario_data, + ) + upper_achievable_target = get_total_contribution( + upper_spends, + channels, + s_curve_params, + channels_proportion, + modified_scenario_data, + ) + + # Return achievable targets with ±0.1% safety margin + return max(0, 1.001 * lower_achievable_target), 0.999 * upper_achievable_target + + +# Function to check if number is in valid format +def is_valid_number_format(number_str): + # Check for None + if number_str is None: + # Store the message details in session state for invalid input + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + + # Define the valid suffixes + valid_suffixes = {"K", "M", "B", "T"} + + # Check for negative numbers + if number_str[0] == "-": + # Store the message details in session state for invalid input + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + + # Check if the string ends with a digit + if number_str[-1].isdigit(): + try: + # Attempt to convert the entire string to float + number = float(number_str) + # Ensure the number is non-negative + if number >= 0: + return True + else: + # Store the message details in session state for invalid input + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + except ValueError: + # Store the message details in session state for invalid input + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + + # Check if the string ends with a valid suffix + suffix = number_str[-1].upper() + if suffix in valid_suffixes: + num_part = number_str[:-1] # Extract the numerical part + try: + # Attempt to convert the numerical part to float + number = float(num_part) + # Ensure the number part is non-negative + if number >= 0: + return True + else: + # Store the message details in session state for invalid input + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + except ValueError: + # Store the message details in session state for invalid input + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + + # If neither condition is met, return False + st.session_state.message_display = { + "type": "warning", + "message": "Invalid input: Please enter a valid number.", + "icon": "⚠️", + } + return False + + +# Function to converts a string with number suffixes (K, M, B, T) to a float +def convert_to_float(number_str): + # Dictionary mapping suffixes to their multipliers + multipliers = { + "K": 1e3, # Thousand + "M": 1e6, # Million + "B": 1e9, # Billion + "T": 1e12, # Trillion + } + + # If there's no suffix, directly convert to float + if number_str[-1].isdigit(): + return float(number_str) + + # Extract the suffix (last character) and the numerical part + suffix = number_str[-1].upper() + num_part = number_str[:-1] + + # Convert the numerical part to float and multiply by the corresponding multiplier + return float(num_part) * multipliers[suffix] + + +# Function to update absolute_channel_spends change +def absolute_channel_spends_change( + channel_key, channel_spends_actual, channel, metrics_selected, panel_selected +): + # Do not update if the number is in an invalid format + if not is_valid_number_format(st.session_state[f"{channel_key}_abs_spends_key"]): + return + + # Get updated absolute spends from session state + new_absolute_spends = ( + convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"]) + * st.session_state["multiplier"] + ) + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Total channel spends + total_channel_spends = 0 + for current_channel in list( + data[metrics_selected][panel_selected]["channels"].keys() + ): + # Channel key + channel_key = f"{metrics_selected}_{panel_selected}_{current_channel}" + + total_channel_spends += ( + convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"]) + * st.session_state["multiplier"] + ) + + # Check if total channel spends are within the allowed range (±50% of the original total spends) + if ( + total_channel_spends + < 1.5 * data[metrics_selected][panel_selected]["actual_total_spends"] + and total_channel_spends + > 0.5 * data[metrics_selected][panel_selected]["actual_total_spends"] + ): + # Update the modified_total_spends for the specified channel + data[metrics_selected][panel_selected]["channels"][channel][ + "modified_total_spends" + ] = new_absolute_spends / float( + data[metrics_selected][panel_selected]["channels"][channel][ + "conversion_rate" + ] + ) + + # Update total spends + data[metrics_selected][panel_selected][ + "modified_total_spends" + ] = total_channel_spends + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] = data + + else: + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Keep total spending within ±50% of the original value.", + "icon": "⚠️", + } + + +# Function to update percentage_channel_spends change +def percentage_channel_spends_change( + channel_key, channel_spends_actual, channel, metrics_selected, panel_selected +): + # Retrieve the percentage spend change from session state + percentage_channel_spends = round( + st.session_state[f"{channel_key}_per_spends_key"], 0 + ) + + # Calculate the new absolute spends + new_absolute_spends = channel_spends_actual * (1 + percentage_channel_spends / 100) + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Total channel spends + total_channel_spends = 0 + for current_channel in list( + data[metrics_selected][panel_selected]["channels"].keys() + ): + # Channel key + channel_key = f"{metrics_selected}_{panel_selected}_{current_channel}" + + # Current channel spends actual + current_channel_spends_actual = data[metrics_selected][panel_selected][ + "channels" + ][current_channel]["actual_total_spends"] + + # Current channel conversion rate + current_channel_conversion_rate = data[metrics_selected][panel_selected][ + "channels" + ][current_channel]["conversion_rate"] + + # Calculate the current channel absolute spends + current_channel_absolute_spends = ( + current_channel_spends_actual + * current_channel_conversion_rate + * (1 + st.session_state[f"{channel_key}_per_spends_key"] / 100) + ) + + total_channel_spends += current_channel_absolute_spends + + # Check if total channel spends are within the allowed range (±50% of the original total spends) + if ( + total_channel_spends + < 1.5 * data[metrics_selected][panel_selected]["actual_total_spends"] + and total_channel_spends + > 0.5 * data[metrics_selected][panel_selected]["actual_total_spends"] + ): + # Update the modified_total_spends for the specified channel + data[metrics_selected][panel_selected]["channels"][channel][ + "modified_total_spends" + ] = float(new_absolute_spends) / float( + data[metrics_selected][panel_selected]["channels"][channel][ + "conversion_rate" + ] + ) + + # Update total spends + data[metrics_selected][panel_selected][ + "modified_total_spends" + ] = total_channel_spends + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] = data + + else: + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Keep total spending within ±50% of the original value.", + "icon": "⚠️", + } + + +# # Function to update total input change +# def total_input_change(per_change): +# # Load modified scenario data +# data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + +# # Get the list of all channels in the specified panel and metric +# channel_list = list(data[metrics_selected][panel_selected]["channels"].keys()) + +# # Iterate over each channel to update their modified spends +# for channel in channel_list: +# # Retrieve the actual spends for the channel +# channel_actual_spends = data[metrics_selected][panel_selected]["channels"][ +# channel +# ]["actual_total_spends"] + +# # Calculate the modified spends for the channel based on the percent change +# modified_channel_metrics = channel_actual_spends * ((100 + per_change) / 100) + +# # Update the channel's modified total spends in the data +# data[metrics_selected][panel_selected]["channels"][channel][ +# "modified_total_spends" +# ] = modified_channel_metrics + +# # Update modified scenario metadata +# st.session_state["project_dct"]["scenario_planner"][ +# "modified_metadata_file" +# ] = data + + +# Function to update total input change +def total_input_change(per_change, metrics_selected, panel_selected): + # Load modified and original scenario data + modified_data = st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ].copy() + original_data = st.session_state["project_dct"]["scenario_planner"][ + "original_metadata_file" + ].copy() + + # Get the list of all channels in the selected panel and metric + channel_list = list( + modified_data[metrics_selected][panel_selected]["channels"].keys() + ) + + # Separate channels into unfrozen and frozen based on optimization status + unfrozen_channels, frozen_channels = [], [] + for channel in channel_list: + channel_key = f"{metrics_selected}_{panel_selected}_{channel}" + if st.session_state.get(f"{channel_key}_allow_optimize_key", False): + frozen_channels.append(channel) + else: + unfrozen_channels.append(channel) + + # Calculate spends and total share from frozen channels, weighted by conversion rate + frozen_channel_share, frozen_channel_spends = 0, 0 + for channel in frozen_channels: + conversion_rate = original_data[metrics_selected][panel_selected]["channels"][ + channel + ]["conversion_rate"] + actual_spends = original_data[metrics_selected][panel_selected]["channels"][ + channel + ]["actual_total_spends"] + modified_spends = modified_data[metrics_selected][panel_selected]["channels"][ + channel + ]["modified_total_spends"] + spends_diff = max(actual_spends, 1e-3) * ((100 + per_change) / 100) - max( + modified_spends, 1e-3 + ) + frozen_channel_share += spends_diff * conversion_rate + frozen_channel_spends += max(actual_spends, 1e-3) * conversion_rate + + # Redistribute frozen share across unfrozen channels based on original spend weights + for channel in unfrozen_channels: + conversion_rate = original_data[metrics_selected][panel_selected]["channels"][ + channel + ]["conversion_rate"] + actual_spends = original_data[metrics_selected][panel_selected]["channels"][ + channel + ]["actual_total_spends"] + + # Calculate weight of the current channel's original spends + total_original_spends = original_data[metrics_selected][panel_selected][ + "actual_total_spends" + ] + channel_weight = (actual_spends * conversion_rate) / ( + total_original_spends - frozen_channel_spends + ) + + # Calculate the modified spends with redistributed frozen share + modified_spends = ( + max(actual_spends, 1e-3) * ((100 + per_change) / 100) + + (frozen_channel_share * channel_weight) / conversion_rate + ) + + # Update modified total spends in the modified data + modified_data[metrics_selected][panel_selected]["channels"][channel][ + "modified_total_spends" + ] = modified_spends + + # Save the updated modified scenario data back to the session state + st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] = modified_data + + +# Function to update total_absolute_main_key change +def total_absolute_main_key_change(metrics_selected, panel_selected, optimization_goal): + # Do not update if the number is in an invalid format + if not is_valid_number_format(st.session_state["total_absolute_main_key"]): + return + + # Get updated absolute from session state + new_absolute = ( + convert_to_float(st.session_state["total_absolute_main_key"]) + * st.session_state["multiplier"] + ) + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + if optimization_goal == "Spend": + # Retrieve the old absolute spends + old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"] + else: + # Retrieve the old absolute metrics + old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"] + + # Calculate the allowable range for new spends + lower_bound = old_absolute * 0.5 + upper_bound = old_absolute * 1.5 + + # Ensure the new spends are within ±50% of the old value + if new_absolute < lower_bound or new_absolute > upper_bound: + new_absolute = old_absolute + + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Keep total spending within ±50% of the original value.", + "icon": "⚠️", + } + + if optimization_goal == "Spend": + # Update the modified_total_spends with the constrained value + data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute + else: + # Update the modified_total_sales with the constrained value + data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + # Update total input change + if optimization_goal == "Spend": + per_change = ((new_absolute - old_absolute) / old_absolute) * 100 + total_input_change(per_change, metrics_selected, panel_selected) + + +# Function to update total_absolute_key change +def total_absolute_key_change(metrics_selected, panel_selected, optimization_goal): + # Get updated absolute from session state + new_absolute = ( + convert_to_float(st.session_state["total_absolute_key"]) + * st.session_state["multiplier"] + ) + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + if optimization_goal == "Spend": + # Update the modified_total_spends for the specified channel + data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute + old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"] + else: + # Update the modified_total_sales for the specified channel + data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute + old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"] + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + # Update total input change + if optimization_goal == "Spend": + per_change = ((new_absolute - old_absolute) / old_absolute) * 100 + total_input_change(per_change, metrics_selected, panel_selected) + + +# Function to update total_absolute_key change +def total_percentage_key_change( + metrics_selected, + panel_selected, + absolute_value, + optimization_goal, +): + # Get updated absolute from session state + new_absolute = absolute_value * (1 + st.session_state["total_percentage_key"] / 100) + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + if optimization_goal == "Spend": + # Update the modified_total_spends for the specified channel + data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute + old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"] + else: + # Update the modified_total_sales for the specified channel + data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute + old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"] + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + # Update total input change + if optimization_goal == "Spend": + per_change = ((new_absolute - old_absolute) / old_absolute) * 100 + total_input_change(per_change, metrics_selected, panel_selected) + + +# Function to update bound change +def bound_change(metrics_selected, panel_selected, channel_key, channel): + # Get updated bounds from session state + new_lower_bound = st.session_state[f"{channel_key}_lower_key"] + new_upper_bound = st.session_state[f"{channel_key}_upper_key"] + if new_lower_bound > new_upper_bound: + new_bounds = [-10, 10] + + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Lower bound cannot be greater than Upper bound.", + "icon": "⚠️", + } + + else: + new_bounds = [new_lower_bound, new_upper_bound] + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Update the bounds for the specified channel + data[metrics_selected][panel_selected]["channels"][channel]["bounds"] = new_bounds + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + +# Function to update freeze change +def freeze_change(metrics_selected, panel_selected, channel_key, channel, channel_list): + # Initialize counter for channels that are not frozen + unfrozen_channel_count = 0 + + # Check the optimization status of each channel + for current_channel in channel_list: + current_channel_key = f"{metrics_selected}_{panel_selected}_{current_channel}" + unfrozen_channel_count += ( + 1 + if not st.session_state[f"{current_channel_key}_allow_optimize_key"] + else 0 + ) + + # Ensure at least two channels are allowed for optimization + if unfrozen_channel_count < 2: + st.session_state.message_display = { + "type": "warning", + "message": "Please allow at least two channels to be optimized.", + "icon": "⚠️", + } + return + + if st.session_state[f"{channel_key}_allow_optimize_key"]: + # Updated bounds from session state + new_lower_bound, new_upper_bound = 0, 0 + new_bounds = [new_lower_bound, new_upper_bound] + new_freeze = True + else: + # Updated bounds from session state + new_lower_bound, new_upper_bound = -10, 10 + new_bounds = [new_lower_bound, new_upper_bound] + new_freeze = False + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Update the bounds for the specified channel + data[metrics_selected][panel_selected]["channels"][channel]["bounds"] = new_bounds + data[metrics_selected][panel_selected]["channels"][channel]["freeze"] = new_freeze + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + +# Function to calculate y, ROI and MROI for given point +def get_point_parms( + x_val, + current_s_curve_params, + current_channel_proportion, + current_conversion_rate, + channel_correction, +): + # Calculate y value for the given spend point + y_val = ( + sum( + s_curve( + (x_val * current_channel_proportion), + current_s_curve_params["power"], + current_s_curve_params["K"], + current_s_curve_params["b"], + current_s_curve_params["a"], + current_s_curve_params["x0"], + ) + ) + + channel_correction + ) + + # Calculate MROI using a small nudge for actual spends + nudge = 1e-3 + x1 = float(x_val * current_conversion_rate) + y1 = float(y_val) + x2 = x1 + nudge + y2 = ( + sum( + s_curve( + ((x2 / current_conversion_rate) * current_channel_proportion), + current_s_curve_params["power"], + current_s_curve_params["K"], + current_s_curve_params["b"], + current_s_curve_params["a"], + current_s_curve_params["x0"], + ) + ) + + channel_correction + ) + mroi_val = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0 + + # Calculate ROI + roi_val = y_val / (x_val * current_conversion_rate) + + return roi_val, mroi_val, y_val + + +# Function to find segment value +def find_segment_value(x, roi, mroi, roi_threshold=1, mroi_threshold=0.05): + # Initialize the start and end values of the x array + start_value = x[0] + end_value = x[-1] + + # Define the condition for the "green region" where both ROI and MROI exceed their respective thresholds + green_condition = (roi > roi_threshold) & (mroi > mroi_threshold) + + # Find indices where ROI exceeds the ROI threshold + left_indices = np.where(roi > roi_threshold)[0] + + # Find indices where both ROI and MROI exceed their thresholds (green condition) + right_indices = np.where(green_condition)[0] + + # Determine the left value based on the first index where ROI exceeds the threshold + left_value = x[left_indices[0]] if left_indices.size > 0 else x[0] + + # Determine the right value based on the last index where both ROI and MROI exceed their thresholds + right_value = x[right_indices[-1]] if right_indices.size > 0 else x[0] + + # Ensure the left value does not exceed the right value, adjust if necessary + if left_value > right_value: + left_value = right_value + + return start_value, end_value, left_value, right_value + + +# Function to generate response curves plots +@st.cache_data(show_spinner=False) +def generate_response_curve_plots( + channel_list, + s_curve_params, + channels_proportion, + original_scenario_data, + multiplier, +): + figures, channel_roi_mroi, region_start_end = [], {}, {} + + for channel in channel_list: + spends_actual = original_scenario_data["channels"][channel][ + "actual_total_spends" + ] + conversion_rate = original_scenario_data["channels"][channel]["conversion_rate"] + channel_correction = sum( + original_scenario_data["channels"][channel]["correction"] + ) + + x_actual = np.linspace(0, 5 * spends_actual, 100) + x_plot = x_actual * conversion_rate + + # Calculate y values for the S-curve + y_plot = [ + sum( + s_curve( + (x * channels_proportion[channel]), + s_curve_params[channel]["power"], + s_curve_params[channel]["K"], + s_curve_params[channel]["b"], + s_curve_params[channel]["a"], + s_curve_params[channel]["x0"], + ) + ) + + channel_correction + for x in x_actual + ] + + # Calculate ROI and ensure they are scalar values + roi = [float(y) / float(x) if x != 0 else 0 for x, y in zip(x_plot, y_plot)] + + # Calculate MROI using a small nudge + nudge = 1e-3 + mroi = [] + for i in range(len(x_plot)): + x1 = float(x_plot[i]) + y1 = float(y_plot[i]) + x2 = x1 + nudge + y2 = ( + sum( + s_curve( + ((x2 / conversion_rate) * channels_proportion[channel]), + s_curve_params[channel]["power"], + s_curve_params[channel]["K"], + s_curve_params[channel]["b"], + s_curve_params[channel]["a"], + s_curve_params[channel]["x0"], + ) + ) + + channel_correction + ) + mroi_value = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0 + mroi.append(mroi_value) + + # Channel correction + channel_correction = sum( + original_scenario_data["channels"][channel]["correction"] + ) + + # Calculate y, ROI and MROI for the actual spend point + roi_actual, mroi_actual, y_actual = get_point_parms( + spends_actual, + s_curve_params[channel], + channels_proportion[channel], + conversion_rate, + channel_correction, + ) + + # Create the plotly figure + fig = go.Figure() + + # Add S-curve line + fig.add_trace( + go.Scatter( + x=np.array(x_plot) / multiplier, + y=np.array(y_plot) / multiplier, + mode="lines", + name="Metrics", + hoverinfo="text", + text=[ + f"Spends: {numerize(x / multiplier)}
{metrics_selected_formatted}: {numerize(y / multiplier)}
ROI: {r:.2f}
MROI: {m:.2f}" + for x, y, r, m in zip(x_plot, y_plot, roi, mroi) + ], + ) + ) + + # Add current spend point + fig.add_trace( + go.Scatter( + x=[spends_actual * conversion_rate / multiplier], + y=[y_actual / multiplier], + mode="markers", + marker=dict(color="cyan", size=10, symbol="circle"), + name="Actual Spend", + hoverinfo="text", + text=[ + f"Actual Spend: {numerize(spends_actual * conversion_rate / multiplier)}
{metrics_selected_formatted}: {numerize(y_actual / multiplier)}
ROI: {roi_actual:.2f}
MROI: {mroi_actual:.2f}" + ], + showlegend=True, + ) + ) + + # ROI Threshold + roi_threshold = st.session_state.roi_threshold + + # Scale x and y values + x, y = np.array(x_plot), np.array(y_plot) + x_scaled, y_scaled = x / max(x), y / max(y) + + # Calculate MROI scaled starting from the first point + mroi_scaled = np.zeros_like(x_scaled) + for j in range(1, len(x_scaled)): + x1, y1 = x_scaled[j - 1], y_scaled[j - 1] + x2, y2 = x_scaled[j], y_scaled[j] + mroi_scaled[j] = (y2 - y1) / (x2 - x1) if (x2 - x1) != 0 else 0 + + # Get the start_value, end_value, left_value, right_value for segments + start_value, end_value, left_value, right_value = find_segment_value( + x_plot, np.array(roi), mroi_scaled, roi_threshold, mroi_threshold + ) + + # Store region start and end points + region_start_end[channel] = { + "start_value": start_value, + "end_value": end_value, + "left_value": left_value, + "right_value": right_value, + } + + # Adding background colors + y_max = max(y_plot) * 1.3 # 30% extra space above the max + + # Yellow region + fig.add_shape( + type="rect", + x0=start_value / multiplier, + y0=0, + x1=left_value / multiplier, + y1=y_max / multiplier, + line=dict(width=0), + fillcolor="rgba(255, 255, 0, 0.3)", + layer="below", + ) + + # Green region + fig.add_shape( + type="rect", + x0=left_value / multiplier, + y0=0, + x1=right_value / multiplier, + y1=y_max / multiplier, + line=dict(width=0), + fillcolor="rgba(0, 255, 0, 0.3)", + layer="below", + ) + + # Red region + fig.add_shape( + type="rect", + x0=right_value / multiplier, + y0=0, + x1=end_value / multiplier, + y1=y_max / multiplier, + line=dict(width=0), + fillcolor="rgba(255, 0, 0, 0.3)", + layer="below", + ) + + # Layout adjustments + fig.update_layout( + title=f"{name_formating(channel)}", + showlegend=False, + xaxis=dict( + showgrid=True, + showticklabels=True, + tickformat=".2s", + gridcolor="lightgrey", + gridwidth=0.5, + griddash="dot", + ), + yaxis=dict( + showgrid=True, + showticklabels=True, + tickformat=".2s", + gridcolor="lightgrey", + gridwidth=0.5, + griddash="dot", + ), + template="plotly_white", + margin=dict(l=20, r=20, t=30, b=20), + height=100 * (len(channel_list) + 4 - 1) // 4, + ) + + figures.append(fig) + + # Store data of each channel ROI and MROI + channel_roi_mroi[channel] = { + "actual_roi": roi_actual, + "actual_mroi": mroi_actual, + } + + return figures, channel_roi_mroi, region_start_end + + +# Function to add modified spends/metrics point on plot +def modified_metrics_point( + fig, + modified_spends, + s_curve_params, + channels_proportion, + conversion_rate, + channel_correction, +): + # Calculate ROI, MROI, and y for the modified point + roi_modified, mroi_modified, y_modified = get_point_parms( + modified_spends, + s_curve_params, + channels_proportion, + conversion_rate, + channel_correction, + ) + + # Add modified spend point + fig.add_trace( + go.Scatter( + x=[modified_spends * conversion_rate / st.session_state["multiplier"]], + y=[y_modified / st.session_state["multiplier"]], + mode="markers", + marker=dict(color="blueviolet", size=10, symbol="circle"), + name="Optimized Spend", + hoverinfo="text", + text=[ + f"Modified Spend: {numerize(modified_spends * conversion_rate / st.session_state.multiplier)}
{metrics_selected_formatted}: {numerize(y_modified / st.session_state.multiplier)}
ROI: {roi_modified:.2f}
MROI: {mroi_modified:.2f}" + ], + showlegend=True, + ) + ) + + return roi_modified, mroi_modified, fig + + +# Function to update bound type change +def bound_type_change(): + # Get updated bound type from session state + new_bound_type = st.session_state["bound_type_key"] + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + + # Update the bound type + data[metrics_selected][panel_selected]["bound_type"] = new_bound_type + + # Set bounds to default value if bound type is False (Default) + channel_list = list(data[metrics_selected][panel_selected]["channels"].keys()) + if not new_bound_type: + for channel in channel_list: + data[metrics_selected][panel_selected]["channels"][channel]["bounds"] = [ + -10, + 10, + ] + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data + + +# Function to format the numbers with decimal +def format_value(input_value): + value = abs(input_value) + return f"{input_value:.4f}" if value < 1 else f"{numerize(input_value, 1)}" + + +# Function to format the numbers with decimal +def round_value(input_value): + value = abs(input_value) + return round(input_value, 4) if value < 1 else round(input_value, 1) + + +# Function to generate ROI and MROI plots for all channels +@st.cache_data(show_spinner=False) +def roi_mori_plot(channel_roi_mroi): + # Dictionary to store plots + channel_roi_mroi_plot = {} + for channel in channel_roi_mroi: + channel_roi_mroi_data = channel_roi_mroi[channel] + # Extract the data + actual_roi = channel_roi_mroi_data["actual_roi"] + optimized_roi = channel_roi_mroi_data["optimized_roi"] + actual_mroi = channel_roi_mroi_data["actual_mroi"] + optimized_mroi = channel_roi_mroi_data["optimized_mroi"] + + # Plot ROI + fig_roi = go.Figure() + fig_roi.add_trace( + go.Bar( + x=["Actual ROI"], + y=[actual_roi], + name="Actual ROI", + marker_color="cyan", + width=1, + text=[format_value(actual_roi)], + textposition="auto", + textfont=dict(color="black", size=14), + ) + ) + fig_roi.add_trace( + go.Bar( + x=["Optimized ROI"], + y=[optimized_roi], + name="Optimized ROI", + marker_color="blueviolet", + width=1, + text=[format_value(optimized_roi)], + textposition="auto", + textfont=dict(color="black", size=14), + ) + ) + + fig_roi.update_layout( + annotations=[ + dict( + x=0.5, + y=1.3, + xref="paper", + yref="paper", + text="ROI", + showarrow=False, + font=dict(size=14), + ) + ], + barmode="group", + bargap=0, + showlegend=False, + width=110, + height=110, + xaxis=dict( + showticklabels=True, + showgrid=False, + tickangle=0, + ticktext=["Actual", "Optimized"], + tickvals=["Actual ROI", "Optimized ROI"], + ), + yaxis=dict(showticklabels=False, showgrid=False), + margin=dict(t=20, b=20, r=0, l=0), + ) + + # Plot MROI + fig_mroi = go.Figure() + fig_mroi.add_trace( + go.Bar( + x=["Actual MROI"], + y=[actual_mroi], + name="Actual MROI", + marker_color="cyan", + width=1, + text=[format_value(actual_mroi)], + textposition="auto", + textfont=dict(color="black", size=14), + ) + ) + fig_mroi.add_trace( + go.Bar( + x=["Optimized MROI"], + y=[optimized_mroi], + name="Optimized MROI", + marker_color="blueviolet", + width=1, + text=[format_value(optimized_mroi)], + textposition="auto", + textfont=dict(color="black", size=14), + ) + ) + + fig_mroi.update_layout( + annotations=[ + dict( + x=0.5, + y=1.3, + xref="paper", + yref="paper", + text="MROI", + showarrow=False, + font=dict(size=14), + ) + ], + barmode="group", + bargap=0, + showlegend=False, + width=110, + height=110, + xaxis=dict( + showticklabels=True, + showgrid=False, + tickangle=0, + ticktext=["Actual", "Optimized"], + tickvals=["Actual MROI", "Optimized MROI"], + ), + yaxis=dict(showticklabels=False, showgrid=False), + margin=dict(t=20, b=20, r=0, l=0), + ) + + # Store plots + channel_roi_mroi_plot[channel] = {"fig_roi": fig_roi, "fig_mroi": fig_mroi} + + return channel_roi_mroi_plot + + +# Function to save the current scenario with the mentioned name +def save_scenario( + scenario_dict, + metrics_selected, + panel_selected, + optimization_goal, + channel_roi_mroi, + timeframe, + multiplier, +): + # Remove extra space at start and ends + if st.session_state["scenario_name"] is not None: + st.session_state["scenario_name"] = st.session_state["scenario_name"].strip() + + if ( + st.session_state["scenario_name"] is None + or st.session_state["scenario_name"] == "" + ): + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Please provide a name to save the scenario.", + "icon": "⚠️", + } + return + + # Check the scenario name + if not ( + word_length_limit_lower + <= len(st.session_state["scenario_name"]) + <= word_length_limit_upper + and bool(re.match("^[A-Za-z0-9_]*$", st.session_state["scenario_name"])) + ): + # Store the warning message details in session state + st.session_state.message_display = { + "type": "warning", + "message": f"Please provide a valid scenario name ({word_length_limit_lower}-{word_length_limit_upper} characters, only A-Z, a-z, 0-9, and _).", + "icon": "⚠️", + } + return + + # Check if the dictionary is empty + if not scenario_dict: + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Nothing to save. The scenario data is empty.", + "icon": "⚠️", + } + return + + # Add additional scenario details + scenario_dict["panel_selected"] = panel_selected + scenario_dict["metrics_selected"] = metrics_selected + scenario_dict["optimization"] = optimization_goal + scenario_dict["channel_roi_mroi"] = channel_roi_mroi + scenario_dict["timeframe"] = timeframe + scenario_dict["multiplier"] = multiplier + + # Load existing scenarios + saved_scenarios_dict = st.session_state["project_dct"]["saved_scenarios"][ + "saved_scenarios_dict" + ] + + # Check if the name is already taken + if st.session_state["scenario_name"] in saved_scenarios_dict.keys(): + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Name already exists. Please change the name or delete the existing scenario from the Saved Scenario page.", + "icon": "⚠️", + } + return + + # Update the dictionary with the new scenario + saved_scenarios_dict[st.session_state["scenario_name"]] = scenario_dict + + # Update the updated dictionary + st.session_state["project_dct"]["saved_scenarios"][ + "saved_scenarios_dict" + ] = saved_scenarios_dict + + # Update DB + update_db( + prj_id=st.session_state["project_number"], + page_nam="Scenario Planner", + file_nam="project_dct", + pkl_obj=pickle.dumps(st.session_state["project_dct"]), + schema=schema, + ) + + # Store the message details in session state + st.session_state.message_display = { + "type": "success", + "message": f"Scenario '{st.session_state.scenario_name}' has been successfully saved!", + "icon": "💾", + } + st.toast( + f"Scenario '{st.session_state.scenario_name}' has been successfully saved!", + icon="💾", + ) + + # Clear the scenario name input + st.session_state["scenario_name"] = "" + + +# Function to calculate the RGBA color code based on the spends value and region boundaries +def calculate_rgba(spends_value, region_start_end): + # Get region start and end points + start_value = region_start_end["start_value"] + end_value = region_start_end["end_value"] + left_value = region_start_end["left_value"] + right_value = region_start_end["right_value"] + + # Calculate alpha dynamically based on the position within the range + def calculate_alpha(position, start, end, min_alpha=0.1, max_alpha=0.4): + return min_alpha + (max_alpha - min_alpha) * (position - start) / (end - start) + + if start_value <= spends_value <= left_value: + # Yellow range (0, 128, 0) - More transparent towards left, darker towards start + alpha = calculate_alpha(spends_value, left_value, start_value) + return (255, 255, 0, alpha) # RGB for yellow + elif left_value < spends_value <= right_value: + # Green range (0, 128, 0) - More transparent towards right, darker towards left + alpha = calculate_alpha(spends_value, right_value, left_value) + return (0, 128, 0, alpha) # RGB for green + elif right_value < spends_value <= end_value: + # Red range (255, 0, 0) - More transparent towards right, darker towards end + alpha = calculate_alpha(spends_value, right_value, end_value) + return (255, 0, 0, alpha) # RGB for red + + +# Function to format and display the channel name with a color and background color +def display_channel_name_with_background_color( + channel_name, background_color=(0, 128, 0, 0.1) +): + formatted_name = name_formating(channel_name) + + # Unpack the RGBA values + r, g, b, a = background_color + + # Create the HTML content with specified background color + html_content = f""" +
+ {formatted_name} +
+ """ + + return html_content + + +# Function to check optimization success +def check_optimization_success( + channel_list, + input_channels_spends, + output_channels_spends, + bounds_dict, + optimization_goal, + modified_total_metrics, + actual_total_metrics, + modified_total_spends, + actual_total_spends, + original_total_spends, + optimization_success, +): + for channel in channel_list: + input_channel_spends = input_channels_spends[channel] + output_channel_spends = abs(output_channels_spends[channel]) + + lower_percent = bounds_dict[channel][0] + upper_percent = bounds_dict[channel][1] + + lower_allowed_value = max( + (input_channel_spends * (100 + lower_percent - 1) / 100), 0 + ) # 1% Tolerance + upper_allowed_value = max( + (input_channel_spends * (100 + upper_percent + 1) / 100), 10 + ) # 1% Tolerance + + # Check if output spends are within allowed bounds + if ( + output_channel_spends > upper_allowed_value + or output_channel_spends < lower_allowed_value + ): + error_message = "Optimization failed: strict bounds. Use flexible bounds." + return False, error_message, "❌" + + # Check optimization goal and percent change + if optimization_goal == "Spend": + percent_change_happened = abs( + (modified_total_spends - actual_total_spends) / actual_total_spends + ) + if percent_change_happened > 0.01: # Greater than 1% Tolerance + error_message = "Optimization failed: input and optimized spends differ. Use flexible bounds." + return False, error_message, "❌" + else: + percent_change_happened = abs( + (modified_total_metrics - actual_total_metrics) / actual_total_metrics + ) + if percent_change_happened > 0.01: # Greater than 1% Tolerance + error_message = "Optimization failed: input and optimized metrics differ. Use flexible bounds." + return False, error_message, "❌" + + # Define the allowable range for new spends + lower_limit = original_total_spends * 0.5 + upper_limit = original_total_spends * 1.5 + + # Check if the new spends are within the allowed range + if modified_total_spends < lower_limit or modified_total_spends > upper_limit: + error_message = "New spends optimized are outside the allowed range of ±50%." + return False, error_message, "❌" + + # Check if the optimization failed to converge + if not optimization_success: + error_message = "Optimization failed to converge." + return False, error_message, "❌" + + return True, "Optimization successful.", "💸" + + +# Function to check if the optimization target is achievable within the given bounds +def check_target_achievability( + optimize_allow, + optimization_goal, + lower_achievable_target, + upper_achievable_target, + total_absolute_target, +): + # Format the messages with appropriate numerization and naming + given_input = "response metric" if optimization_goal == "Spend" else "spends" + + # Combined achievable message + achievable_message = ( + f"Achievable {optimization_goal} with the given {given_input} and bounds ranges from " + f"{numerize(lower_achievable_target / st.session_state.multiplier)} to " + f"{numerize(upper_achievable_target / st.session_state.multiplier)}" + ) + + # Check if the target is within achievable bounds + if (lower_achievable_target > total_absolute_target) or ( + upper_achievable_target < total_absolute_target + ): + # Update session state with the error message + st.session_state.message_display = { + "type": "error", + "message": achievable_message, + "icon": "⚠️", + } + optimize_allow = False + + elif (st.session_state.message_display["message"] is not None) and ( + str(st.session_state.message_display["message"]).startswith("Achievable") + ): + # Clear message_display + st.session_state.message_display = { + "type": "success", + "message": None, + "icon": "", + } + optimize_allow = True + + return optimize_allow + + +# Function to display a message with the appropriate type and icon +def display_message(): + # Retrieve the message details from the session state + message_type = st.session_state.message_display["type"] + message = st.session_state.message_display["message"] + icon = st.session_state.message_display["icon"] + + # Display the message if it exists + if message is not None: + if message_type == "success": + st.success(message, icon=icon) + # Log message + log_message("info", message, "Scenario Planner") + elif message_type == "warning": + st.warning(message, icon=icon) + # Log message + log_message("warning", message, "Scenario Planner") + elif message_type == "error": + st.error(message, icon=icon) + # Log message + log_message("error", message, "Scenario Planner") + else: + st.info(message, icon=icon) + # Log message + log_message("info", message, "Scenario Planner") + + +# Function to change bounds for all channels +def all_bound_change(channel_list, apply_all=False): + # Fetch updated upper and lower bounds for all channels + all_lower_bound = st.session_state["all_lower_key"] + all_upper_bound = st.session_state["all_upper_key"] + + # Check if lower bound is not greater than upper bound + if all_lower_bound < all_upper_bound: + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] + + # Update the bounds for the all channels + if apply_all: + for channel in channel_list: + if not data[metrics_selected][panel_selected]["channels"][channel][ + "freeze" + ]: + data[metrics_selected][panel_selected]["channels"][channel][ + "bounds" + ] = [ + all_lower_bound, + all_upper_bound, + ] + + # Update the bounds for the all channels holder + data[metrics_selected][panel_selected]["bounds"] = [ + all_lower_bound, + all_upper_bound, + ] + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] = data + + else: + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Lower bound cannot be greater than Upper bound.", + "icon": "⚠️", + } + return + + +try: + # Page Title + st.title("Scenario Planner") + + # Retrieve the list of all metric names from the specified directory + metrics_list = get_metrics_names() + + # Check if there are any metrics available in the metrics list + if not metrics_list: + # Display a warning message to the user if no metrics are found + st.warning( + "Please tune at least one model to generate response curves data.", + icon="⚠️", + ) + + # Log message + log_message( + "warning", + "Please tune at least one model to generate response curves data.", + "Scenario Planner", + ) + + st.stop() + + # Widget columns + metric_col, panel_col, timeframe_col, save_progress_col = st.columns(4) + + # Metrics Selection + metrics_selected = metric_col.selectbox( + "Response Metrics", + sorted(metrics_list), + format_func=name_formating, + key="response_metrics_selectbox_sp", + index=0, + ) + metrics_selected_formatted = name_formating(metrics_selected) + + # Retrieve the list of all panel names for specified Metrics + panel_list = get_panels_names(metrics_selected) + + # Panel Selection + panel_selected = panel_col.selectbox( + "Panel", + sorted(panel_list), + format_func=name_formating, + key="panel_selected_selectbox_sp", + index=0, + ) + panel_selected_formatted = name_formating(panel_selected) + + # Timeframe Selection + timeframe_selected = timeframe_col.selectbox( + "Timeframe", + ["Input Data Range", "Yearly", "Quarterly", "Monthly"], + key="timeframe_selected_selectbox_sp", + index=0, + ) + + # Check if the RCS metadata file does not exist + if ( + st.session_state["project_dct"]["response_curves"]["original_metadata_file"] + is None + or st.session_state["project_dct"]["response_curves"]["modified_metadata_file"] + is None + ): + # RCS metadata file does not exist. Generating new RCS data + generate_rcs_data() + + # Log message + log_message( + "info", + "RCS metadata file does not exist. Generating new RCS data.", + "Scenario Planner", + ) + + # Load rcs metadata files if they exist + original_rcs_data, modified_rcs_data = load_rcs_metadata_files() + + # Check if the scenario metadata file does not exist + if ( + st.session_state["project_dct"]["scenario_planner"]["original_metadata_file"] + is None + or st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] + is None + ): + # Scenario file does not exist. Generating new senario file data + generate_scenario_data() + + # Load scenario metadata files if they exist + original_data, modified_data = load_scenario_metadata_files() + + try: + # Data date range + date_range = pd.to_datetime( + list(original_data[metrics_selected][panel_selected]["channels"].values())[ + 0 + ]["dates"] + ) + + # Calculate the number of days between max and min dates + date_diff = pd.Series(date_range).diff() + day_data = int( + (date_range.max() - date_range.min()).days + + (6 if date_diff.value_counts().idxmax() == pd.Timedelta(weeks=1) else 0) + ) + + # Set the multiplier based on the selected timeframe + if timeframe_selected == "Input Data Range": + st.session_state["multiplier"] = 1 + elif timeframe_selected == "Yearly": + st.session_state["multiplier"] = day_data / 365 + elif timeframe_selected == "Quarterly": + st.session_state["multiplier"] = day_data / 90 + elif timeframe_selected == "Monthly": + st.session_state["multiplier"] = day_data / 30 + except: + st.session_state["multiplier"] = 1 + + # Extract original scenario data for the selected metric and panel + original_scenario_data = original_data[metrics_selected][panel_selected] + + # Extract modified scenario data for the same metric and panel + modified_scenario_data = modified_data[metrics_selected][panel_selected] + + # Display Actual Vs Optimized + st.divider() + ( + actual_spends_col, + actual_metrics_col, + actual_CPA_col, + base_col, + optimized_spends_col, + optimized_metrics_col, + optimized_CPA_col, + ) = st.columns([1, 1, 1, 1, 1.5, 1.5, 1.5]) + + # Base Contribution + base_contribution = ( + sum(original_scenario_data["constant"]) / st.session_state["multiplier"] + ) + + # Display Base Metric + base_col.metric( + f"Base {metrics_selected_formatted}", + numerize(base_contribution), + ) + + # Extracting and formatting values + actual_spends = numerize( + original_scenario_data["actual_total_spends"] / st.session_state["multiplier"] + ) + actual_metric_value = numerize( + original_scenario_data["actual_total_sales"] / st.session_state["multiplier"] + ) + optimized_spends = numerize( + modified_scenario_data["modified_total_spends"] / st.session_state["multiplier"] + ) + optimized_metric_value = numerize( + modified_scenario_data["modified_total_sales"] / st.session_state["multiplier"] + ) + + # Calculate the deltas (differences) for spends and metrics + spends_delta_value = ( + modified_scenario_data["modified_total_spends"] + - original_scenario_data["actual_total_spends"] + ) / st.session_state["multiplier"] + + metrics_delta_value = ( + modified_scenario_data["modified_total_sales"] + - original_scenario_data["actual_total_sales"] + ) / st.session_state["multiplier"] + + # Calculate the percentage changes for spends and metrics + spends_percentage_change = ( + spends_delta_value + / ( + original_scenario_data["actual_total_spends"] + / st.session_state["multiplier"] + ) + ) * 100 + + metrics_percentage_change_media = ( + metrics_delta_value + / ( + ( + original_scenario_data["actual_total_sales"] + / st.session_state["multiplier"] + ) + - base_contribution + ) + ) * 100 + + metrics_percentage_change_all = ( + metrics_delta_value + / ( + original_scenario_data["actual_total_sales"] + / st.session_state["multiplier"] + ) + ) * 100 + + # Format the percentage change for display + spends_percentage_display = ( + f"({round(spends_percentage_change, 1)}%)" + if abs(spends_percentage_change) >= 0.1 + else "(0%)" + ) + metrics_percentage_display_media = ( + f"({round(metrics_percentage_change_media, 1)}%)" + if abs(metrics_percentage_change_media) >= 0.1 + else "(0%)" + ) + metrics_percentage_display_all = ( + f"({round(metrics_percentage_change_all, 1)}%)" + if abs(metrics_percentage_change_all) >= 0.1 + else "(0%)" + ) + + # Check if the delta for spends is less than 0.1% in absolute terms + if abs(spends_delta_value) < 0.001 * original_scenario_data["actual_total_spends"]: + spends_delta = "0" + else: + spends_delta = numerize(spends_delta_value) + + # Check if the delta for metrics is less than 0.1% in absolute terms + if abs(metrics_delta_value) < 0.001 * original_scenario_data["actual_total_sales"]: + metrics_delta = "0" + else: + metrics_delta = numerize(metrics_delta_value) + + # Display current and optimized CPA + actual_CPA = ( + original_scenario_data["actual_total_spends"] + / original_scenario_data["actual_total_sales"] + ) + optimized_CPA = ( + modified_scenario_data["modified_total_spends"] + / modified_scenario_data["modified_total_sales"] + ) + CPA_delta_value = optimized_CPA - actual_CPA + + # Calculate the percentage change for CPA + CPA_percentage_change = ( + ((CPA_delta_value / actual_CPA) * 100) if actual_CPA != 0 else 0 + ) + CPA_percentage_display = ( + f"({round(CPA_percentage_change, 1)}%)" + if abs(CPA_percentage_change) >= 0.1 + else "(0%)" + ) + + # Check if the CPA delta is less than 0.1% in absolute terms + if abs(CPA_delta_value) < 0.001 * actual_CPA: + CPA_delta = "0" + else: + CPA_delta = round_value(CPA_delta_value) + + # Display the metrics with percentage changes + actual_CPA_col.metric( + "Actual CPA", + (numerize(actual_CPA) if actual_CPA >= 1000 else round_value(actual_CPA)), + ) + optimized_spends_col.metric( + "Optimized Spend", + f"{optimized_spends} {spends_percentage_display}", + delta=spends_delta, + ) + optimized_metrics_col.metric( + f"Optimized {metrics_selected_formatted}", + f"{optimized_metric_value} {metrics_percentage_display_all}", + delta=f"{metrics_delta} {metrics_percentage_display_media}", + ) + optimized_CPA_col.metric( + "Optimized CPA", + ( + f"{numerize(optimized_CPA) if optimized_CPA >= 1000 else round_value(optimized_CPA)} {CPA_percentage_display}" + ), + delta=CPA_delta, + delta_color="inverse", + ) + + # Displaying metrics in the columns + actual_spends_col.metric("Actual Spend", actual_spends) + actual_metrics_col.metric( + f"Actual {metrics_selected_formatted}", + actual_metric_value, + ) + + # Check if the percentage display for media starts with a negative sign + if str(metrics_percentage_display_all[1:]).startswith("-"): + # If negative, set the color to red + metrics_percentage_display_media_str = f'red {metrics_percentage_display_media}' + else: + # If positive, set the color to green + metrics_percentage_display_media_str = f'green {metrics_percentage_display_media}' + + # Display percentage calculation note + st.markdown( + f"**Note:** The percentage change for the response metric in {metrics_percentage_display_media_str} reflects the change based on the media-driven portion only, excluding the fixed base contribution and the percentage in black **{metrics_percentage_display_all}** represents the change based on the total response metric, including the base contribution. For spends, the percentage change **{spends_percentage_display}** is based on the total actual spends (base spends are always zero).", + unsafe_allow_html=True, + ) + + # Divider + st.divider() + + # Calculate ROI threshold + st.session_state.roi_threshold = ( + original_scenario_data["actual_total_sales"] + - sum(original_scenario_data["constant"]) + ) / original_scenario_data["actual_total_spends"] + + # Fetch and sort channels based on actual spends + channel_list = list( + sorted( + original_scenario_data["channels"], + key=lambda channel: ( + original_scenario_data["channels"][channel]["actual_total_spends"] + * original_scenario_data["channels"][channel]["conversion_rate"] + ), + reverse=True, + ) + ) + + # Create columns for optimization goal and buttons + ( + optimization_goal_col, + message_display_col, + button_col, + bounds_col, + ) = st.columns([3, 6, 3, 3]) + + # Display spinnner or message + with message_display_col: + st.write("###") + spinner_placeholder = st.empty() + + # Save Progress + with save_progress_col: + st.write("####") # Padding + save_progress_placeholder = st.empty() + + # Save page progress + with spinner_placeholder, st.spinner("Saving Progress ..."): + if save_progress_placeholder.button("Save Progress", use_container_width=True): + # Update DB + update_db( + prj_id=st.session_state["project_number"], + page_nam="Scenario Planner", + file_nam="project_dct", + pkl_obj=pickle.dumps(st.session_state["project_dct"]), + schema=schema, + ) + + # Store the message details in session state + with message_display_col: + st.session_state.message_display = { + "type": "success", + "message": "Progress saved successfully!", + "icon": "💾", + } + st.toast("Progress saved successfully!", icon="💾") + + # Create columns for absolute text, slider, percentage number and bound type + absolute_text_col, absolute_slider_col, percentage_number_col, all_bounds_col = ( + st.columns([2, 4, 2, 2]) + ) + + # Dropdown for selecting optimization goal + optimization_goal = optimization_goal_col.selectbox( + "Fix", ["Spend", metrics_selected_formatted] + ) + + # Button columns with padding for alignment + with button_col: + st.write("##") # Padding + optimize_button_col, reset_button_col = st.columns(2) + reset_button_col.button( + "Reset", + use_container_width=True, + on_click=reset_scenario, + args=(metrics_selected, panel_selected), + ) + + # Absolute value display + if optimization_goal == "Spend": + absolute_value = modified_scenario_data["actual_total_spends"] + st.session_state.total_absolute_main_key = numerize( + modified_scenario_data["modified_total_spends"] + / st.session_state["multiplier"] + ) + else: + absolute_value = modified_scenario_data["actual_total_sales"] + st.session_state.total_absolute_main_key = numerize( + modified_scenario_data["modified_total_sales"] + / st.session_state["multiplier"] + ) + + total_absolute = absolute_text_col.text_input( + "Absolute", + key="total_absolute_main_key", + on_change=total_absolute_main_key_change, + args=( + metrics_selected, + panel_selected, + optimization_goal, + ), + ) + + # Generate and process slider options + slider_options = list( + np.linspace(int(0.5 * absolute_value), int(1.5 * absolute_value), 50) + ) # Generate range + slider_options.append( + modified_scenario_data["modified_total_spends"] + if optimization_goal == "Spend" + else modified_scenario_data["modified_total_sales"] + ) + slider_options = sorted(slider_options) # Sort the list + numerized_slider_options = [ + numerize(value / st.session_state["multiplier"]) for value in slider_options + ] # Numerize each value + + # Slider for adjusting absolute value within a range + st.session_state.total_absolute_key = numerize( + modified_scenario_data["modified_total_spends"] / st.session_state["multiplier"] + if optimization_goal == "Spend" + else modified_scenario_data["modified_total_sales"] + / st.session_state["multiplier"] + ) + + slider_value = absolute_slider_col.select_slider( + "Absolute", + numerized_slider_options, + key="total_absolute_key", + on_change=total_absolute_key_change, + args=( + metrics_selected, + panel_selected, + optimization_goal, + ), + ) + + # Number input for percentage value + if optimization_goal == "Spend": + st.session_state.total_percentage_key = int( + round( + ( + ( + modified_scenario_data["modified_total_spends"] + - modified_scenario_data["actual_total_spends"] + ) + / modified_scenario_data["actual_total_spends"] + ) + * 100, + 0, + ) + ) + else: + st.session_state.total_percentage_key = int( + round( + ( + ( + modified_scenario_data["modified_total_sales"] + - modified_scenario_data["actual_total_sales"] + ) + / modified_scenario_data["actual_total_sales"] + ) + * 100, + 0, + ) + ) + + percentage_target = percentage_number_col.number_input( + "Percentage", + min_value=-50, + max_value=50, + key="total_percentage_key", + on_change=total_percentage_key_change, + args=( + metrics_selected, + panel_selected, + absolute_value, + optimization_goal, + ), + ) + + # Toggle input for bound type + st.session_state["bound_type_key"] = modified_scenario_data["bound_type"] + with bounds_col: + st.write("##") # Padding + + # Columns for custom bounds toggle and apply all bounds button + allow_custom_bounds_col, apply_all_bounds_col = st.columns(2) + + # Toggle for enabling/disabling custom bounds + bound_type = allow_custom_bounds_col.toggle( + "Bounds", + on_change=bound_type_change, + key="bound_type_key", + ) + + # Button to apply all bounds + apply_all_bounds = apply_all_bounds_col.button( + "Apply All", + use_container_width=True, + on_click=all_bound_change, + args=(channel_list, True), + disabled=not bound_type, + ) + + # Section for setting all lower and upper bounds + with all_bounds_col: + lower_bound_all, upper_bound_all = st.columns([1, 1]) + + # Initialize session state keys for lower and upper bounds + st.session_state["all_lower_key"] = (modified_scenario_data["bounds"])[0] + st.session_state["all_upper_key"] = (modified_scenario_data["bounds"])[1] + + # Input for all lower bounds + all_lower_bound = lower_bound_all.number_input( + "All Lower Bounds", + min_value=-100, + max_value=100, + key="all_lower_key", + on_change=all_bound_change, + args=(channel_list, False), + disabled=not bound_type, + ) + + # Input for all upper bounds + all_upper_bound = upper_bound_all.number_input( + "All Upper Bounds", + min_value=-100, + max_value=100, + key="all_upper_key", + on_change=all_bound_change, + args=(channel_list, False), + disabled=not bound_type, + ) + + # Collect inputs from the user interface + total_channel_spends, optimize_allow = 0, True + bounds_dict = {} + s_curve_params = {} + channels_spends = {} + channels_proportion = {} + channels_conversion_ratio = {} + channels_name_plot_placeholder = {} + + # Optimization Inputs UI + with st.expander("Optimization Inputs", expanded=True): + # Initialize total contributions for actual and optimized spends and metrics + ( + total_actual_spend_contribution, + total_actual_metric_contribution, + total_optimized_spend_contribution, + total_optimized_metric_contribution, + ) = ( + 0, + sum(modified_scenario_data["constant"]), + 0, + sum(modified_scenario_data["constant"]), + ) + + # Iterate over each channel in the channel list + for channel in channel_list: + # Accumulate actual total spends + total_actual_spend_contribution += ( + modified_scenario_data["channels"][channel]["actual_total_spends"] + * modified_scenario_data["channels"][channel]["conversion_rate"] + ) + + # Accumulate actual total sales (metrics) + total_actual_metric_contribution += modified_scenario_data["channels"][ + channel + ]["actual_total_sales"] + + # Accumulate optimized total spends + total_optimized_spend_contribution += ( + modified_scenario_data["channels"][channel]["modified_total_spends"] + * modified_scenario_data["channels"][channel]["conversion_rate"] + ) + + # Accumulate optimized total sales (metrics) + total_optimized_metric_contribution += modified_scenario_data["channels"][ + channel + ]["modified_total_sales"] + + for channel in channel_list: + + st.divider() + + # Channel key + channel_key = f"{metrics_selected}_{panel_selected}_{channel}" + + # Create columns + if st.session_state["bound_type_key"]: + ( + name_plot_col, + input_col, + spends_col, + metrics_col, + bounds_input_col, + bounds_display_col, + allow_col, + ) = st.columns([3, 2, 2, 2, 2, 2, 1]) + else: + ( + name_plot_col, + input_col, + spends_col, + metrics_col, + bounds_display_col, + allow_col, + ) = st.columns([1.5, 1, 1.5, 1.5, 1, 0.5]) + bounds_input_col = st.empty() + + # Display channel name and ROI/MROI plot + with name_plot_col: + # Placeholder for channel name + channel_name_placeholder = st.empty() + channel_name_placeholder.markdown( + display_channel_name_with_background_color(channel), + unsafe_allow_html=True, + ) + + # Placeholder for ROI and MROI plot + channel_plot_placeholder = st.container() + + # Store placeholder for channel name and ROI/MROI plots + channels_name_plot_placeholder[channel] = { + "channel_name_placeholder": channel_name_placeholder, + "channel_plot_placeholder": channel_plot_placeholder, + } + + # Channel spends and sales + channel_spends_actual = ( + original_scenario_data["channels"][channel]["actual_total_spends"] + * original_scenario_data["channels"][channel]["conversion_rate"] + ) + channel_metrics_actual = original_scenario_data["channels"][channel][ + "actual_total_sales" + ] + + channel_spends_modified = ( + modified_scenario_data["channels"][channel]["modified_total_spends"] + * original_scenario_data["channels"][channel]["conversion_rate"] + ) + channel_metrics_modified = modified_scenario_data["channels"][channel][ + "modified_total_sales" + ] + + # Channel spends input + with input_col: + # Absolute Spends Input + st.session_state[f"{channel_key}_abs_spends_key"] = numerize( + modified_scenario_data["channels"][channel]["modified_total_spends"] + * original_scenario_data["channels"][channel]["conversion_rate"] + / st.session_state["multiplier"] + ) + absolute_channel_spends = st.text_input( + "Absolute Spends", + key=f"{channel_key}_abs_spends_key", + on_change=absolute_channel_spends_change, + args=( + channel_key, + channel_spends_actual, + channel, + metrics_selected, + panel_selected, + ), + ) + + # Update Percentage Spends Input + st.session_state[f"{channel_key}_per_spends_key"] = int( + round( + ( + ( + convert_to_float( + st.session_state[f"{channel_key}_abs_spends_key"] + ) + * st.session_state["multiplier"] + - float(channel_spends_actual) + ) + / channel_spends_actual + ) + * 100, + 0, + ) + ) + + # Percentage Spends Input + percentage_channel_spends = st.number_input( + "Percentage Spends", + min_value=-1000, + max_value=1000, + key=f"{channel_key}_per_spends_key", + on_change=percentage_channel_spends_change, + args=( + channel_key, + channel_spends_actual, + channel, + metrics_selected, + panel_selected, + ), + ) + + # Store channel spends, conversion ratio and proportion list + channels_spends[channel] = original_scenario_data["channels"][channel][ + "actual_total_spends" + ] * (1 + percentage_channel_spends / 100) + + channels_conversion_ratio[channel] = original_scenario_data["channels"][ + channel + ]["conversion_rate"] + + channels_proportion[channel] = original_scenario_data["channels"][ + channel + ]["spends"] / sum(original_scenario_data["channels"][channel]["spends"]) + + # Calculate the percent contribution of actual spends for the channel + channel_actual_spend_contribution = round( + ( + modified_scenario_data["channels"][channel][ + "actual_total_spends" + ] + * channels_conversion_ratio[channel] + / total_actual_spend_contribution + ) + * 100, + 1, + ) + + # Calculate the percent contribution of actual metrics (sales) for the channel + channel_actual_metric_contribution = round( + ( + modified_scenario_data["channels"][channel][ + "actual_total_sales" + ] + / total_actual_metric_contribution + ) + * 100, + 1, + ) + + # Calculate the percent contribution of optimized spends for the channel + channel_optimized_spend_contribution = round( + ( + modified_scenario_data["channels"][channel][ + "modified_total_spends" + ] + * channels_conversion_ratio[channel] + / total_optimized_spend_contribution + ) + * 100, + 1, + ) + + # Calculate the percent contribution of optimized metrics (sales) for the channel + channel_optimized_metric_contribution = round( + ( + modified_scenario_data["channels"][channel][ + "modified_total_sales" + ] + / total_optimized_metric_contribution + ) + * 100, + 1, + ) + + # Channel metrics display + with metrics_col: + # Absolute Metrics + st.metric( + f"Actual {name_formating(metrics_selected)}", + value=str( + numerize( + channel_metrics_actual / st.session_state["multiplier"] + ) + ) + + f"({channel_actual_metric_contribution}%)", + ) + + # Optimized Metrics + optimized_metric = ( + channel_metrics_modified / st.session_state["multiplier"] + ) + actual_metric = channel_metrics_actual / st.session_state["multiplier"] + delta_value = ( + channel_metrics_modified - channel_metrics_actual + ) / st.session_state["multiplier"] + + # Check if the delta is less than 0.1% in absolute terms + if ( + abs(delta_value) < 0.001 * actual_metric + ): # 0.1% of the actual metric + delta_display = "0" + else: + delta_display = numerize(delta_value) + + st.metric( + f"Optimized {name_formating(metrics_selected)}", + value=str(numerize(optimized_metric)) + + f"({channel_optimized_metric_contribution}%)", + delta=delta_display, + ) + + # Channel spends display + with spends_col: + # Absolute Spends + st.metric( + "Actual Spend", + value=str( + numerize(channel_spends_actual / st.session_state["multiplier"]) + ) + + f"({channel_actual_spend_contribution}%)", + ) + + # Optimized Spends + optimized_spends = ( + channel_spends_modified / st.session_state["multiplier"] + ) + actual_spends = channel_spends_actual / st.session_state["multiplier"] + delta_spends_value = ( + channel_spends_modified - channel_spends_actual + ) / st.session_state["multiplier"] + + # Check if the delta is less than 0.1% in absolute terms + if ( + abs(delta_spends_value) < 0.001 * actual_spends + ): # 0.1% of the actual spend + delta_spends_display = "0" + else: + delta_spends_display = numerize(delta_spends_value) + + st.metric( + "Optimized Spend", + value=str(numerize(optimized_spends)) + + f"({channel_optimized_spend_contribution}%)", + delta=delta_spends_display, + ) + + # Channel allows optimize + with allow_col: + # Allow Optimize (Freeze) + st.write("#") # Padding + st.session_state[f"{channel_key}_allow_optimize_key"] = ( + modified_scenario_data["channels"][channel]["freeze"] + ) + freeze = st.checkbox( + "Freeze", + key=f"{channel_key}_allow_optimize_key", + on_change=freeze_change, + args=( + metrics_selected, + panel_selected, + channel_key, + channel, + channel_list, + ), + ) + + # If channel is frozen, set bounds to keep the spend unchanged + if freeze: + lower_bound, upper_bound = 0, 0 # Freeze the spend at current level + + # Channel bounds input + if st.session_state["bound_type_key"]: + with bounds_input_col: + # Channel upper bound + st.session_state[f"{channel_key}_upper_key"] = ( + modified_scenario_data["channels"][channel]["bounds"] + )[1] + upper_bound = st.number_input( + "Upper bound (%)", + min_value=-100, + max_value=100, + key=f"{channel_key}_upper_key", + disabled=st.session_state[f"{channel_key}_allow_optimize_key"], + on_change=bound_change, + args=( + metrics_selected, + panel_selected, + channel_key, + channel, + ), + ) + + # Channel lower bound + st.session_state[f"{channel_key}_lower_key"] = ( + modified_scenario_data["channels"][channel]["bounds"] + )[0] + lower_bound = st.number_input( + "Lower bound (%)", + min_value=-100, + max_value=100, + key=f"{channel_key}_lower_key", + disabled=st.session_state[f"{channel_key}_allow_optimize_key"], + on_change=bound_change, + args=( + metrics_selected, + panel_selected, + channel_key, + channel, + ), + ) + + # Check if lower bound is greater than upper bound + if lower_bound > upper_bound: + lower_bound = -10 # Default lower bound + upper_bound = 10 # Default upper bound + + # Store bounds + bounds_dict[channel] = [lower_bound, upper_bound] + + else: + # If channel is frozen, set bounds to keep the spend unchanged + if freeze: + lower_bound, upper_bound = 0, 0 # Freeze the spend at current level + else: + lower_bound = -10 # Default lower bound + upper_bound = 10 # Default upper bound + + # Store bounds + bounds_dict[channel] = modified_scenario_data["channels"][channel][ + "bounds" + ] + + # Display the bounds for each channel's spend in the bounds_display_col + with bounds_display_col: + # Retrieve the actual spends for the channel from the original scenario data + actual_spends = ( + modified_scenario_data["channels"][channel]["modified_total_spends"] + * modified_scenario_data["channels"][channel]["conversion_rate"] + ) + + # Calculate the limit for spends + upper_limit_spends = actual_spends * (1 + upper_bound / 100) + lower_limit_spends = actual_spends * (1 + lower_bound / 100) + + # Display the upper limit spends + st.metric( + "Upper Bound", + numerize(upper_limit_spends / st.session_state["multiplier"]), + ) + st.metric( + "Lower Bound", + numerize(lower_limit_spends / st.session_state["multiplier"]), + ) + + # Store S-curve parameters + s_curve_params[channel] = get_s_curve_params( + metrics_selected, + panel_selected, + channel, + original_rcs_data, + modified_rcs_data, + ) + + # Total channel spends + total_channel_spends += ( + convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"]) + * st.session_state["multiplier"] + ) + + # Check if total channel spends are within the allowed range (±50% of the original total spends) + if ( + total_channel_spends > 1.5 * original_scenario_data["actual_total_spends"] + or total_channel_spends + < 0.5 * original_scenario_data["actual_total_spends"] + ): + # Store the message details in session state + st.session_state.message_display = { + "type": "warning", + "message": "Keep total spending within ±50% of the original value.", + "icon": "⚠️", + } + + if optimization_goal == "Spend": + # Get maximum achievable spends + lower_achievable_target, upper_achievable_target = 0, 0 + for channel in channel_list: + channel_spends_actual = ( + channels_spends[channel] * channels_conversion_ratio[channel] + ) + lower_achievable_target += channel_spends_actual * ( + 1 + bounds_dict[channel][0] / 100 + ) + upper_achievable_target += channel_spends_actual * ( + 1 + bounds_dict[channel][1] / 100 + ) + else: + # Get maximum achievable target metric + lower_achievable_target, upper_achievable_target = max_target_achievable( + channels_spends, + s_curve_params, + channels_proportion, + modified_scenario_data, + bounds_dict, + ) + + # Total target of selected metric + if optimization_goal == "Spend": + total_absolute_target = modified_scenario_data["modified_total_spends"] + else: + total_absolute_target = modified_scenario_data["modified_total_sales"] + + # Check if the target is achievable within the specified bounds + if optimize_allow: + optimize_allow = check_target_achievability( + optimize_allow, + name_formating(optimization_goal), + lower_achievable_target, + upper_achievable_target, + total_absolute_target, + ) + + # Perform the optimization + if optimize_button_col.button( + "Optimize", + use_container_width=True, + disabled=not optimize_allow, + key="run_optimizer", + ): + with message_display_col: + with spinner_placeholder, st.spinner("Optimizing ..."): + # Call the optimizer function to get optimized spends + optimized_spends, optimization_success = optimizer( + optimization_goal, + s_curve_params, + channels_spends, + channels_proportion, + channels_conversion_ratio, + total_absolute_target, + bounds_dict, + modified_scenario_data, + ) + + # Initialize dictionaries to store input and output channel spends + input_channels_spends, output_channels_spends = {}, {} + for channel in channel_list: + # Calculate input channel spends by converting spends using conversion ratio + input_channels_spends[channel] = ( + channels_spends[channel] * channels_conversion_ratio[channel] + ) + # Calculate output channel spends by converting optimized spends using conversion ratio + output_channels_spends[channel] = ( + optimized_spends[channel] * channels_conversion_ratio[channel] + ) + + # Calculate total actual and modified spends + actual_total_spends = sum(list(input_channels_spends.values())) + modified_total_spends = sum(list(output_channels_spends.values())) + + # Retrieve the actual total metrics from modified scenario data + actual_total_metrics = modified_scenario_data["modified_total_sales"] + modified_total_metrics = 0 # Initialize modified total metrics + modified_channels_metrics = {} + + # Calculate modified metrics for each channel + for channel in optimized_spends.keys(): + channel_s_curve_params = s_curve_params[channel] + spend_proportion = ( + optimized_spends[channel] * channels_proportion[channel] + ) + # Calculate the metrics using the S-curve function + modified_channels_metrics[channel] = sum( + s_curve( + spend_proportion, + channel_s_curve_params["power"], + channel_s_curve_params["K"], + channel_s_curve_params["b"], + channel_s_curve_params["a"], + channel_s_curve_params["x0"], + ) + ) + sum( + modified_scenario_data["channels"][channel]["correction"] + ) # correction for s-curve + + modified_total_metrics += modified_channels_metrics[ + channel + ] # Add channel metrics to total metrics + + # Add the constant and correction term to the modified total metrics + modified_total_metrics += sum(modified_scenario_data["constant"]) + + # Retrieve the original total spends from modified scenario data + original_total_spends = modified_scenario_data["actual_total_spends"] + + # Check the success of the optimization process + success, message, icon = check_optimization_success( + channel_list, + input_channels_spends, + output_channels_spends, + bounds_dict, + optimization_goal, + modified_total_metrics, + actual_total_metrics, + modified_total_spends, + actual_total_spends, + original_total_spends, + optimization_success, + ) + + # Store the message details in session state + st.session_state.message_display = { + "type": "success" if success else "error", + "message": message, + "icon": icon, + } + + # Update data only if the optimization is successful + if success: + # Update the modified spend and metrics for each channel in the scenario data + for channel in channel_list: + modified_scenario_data["channels"][channel][ + "modified_total_spends" + ] = optimized_spends[channel] + + # Update the modified metrics for each channel in the scenario data + modified_scenario_data["channels"][channel][ + "modified_total_sales" + ] = modified_channels_metrics[channel] + + # Update the total modified spends in the scenario data + modified_scenario_data["modified_total_spends"] = ( + modified_total_spends + ) + + # Update the total modified metrics in the scenario data + modified_scenario_data["modified_total_sales"] = ( + modified_total_metrics + ) + + # Load modified scenario data + data = st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] + + # Update the specific section with the modified scenario data + data[metrics_selected][panel_selected] = modified_scenario_data + + # Update modified scenario metadata + st.session_state["project_dct"]["scenario_planner"][ + "modified_metadata_file" + ] = data + + # Reset optimizer button + del st.session_state["run_optimizer"] + + # Rerun to update values + st.rerun() + + ########################################## Response Curves ########################################## + + # Generate plots + figures, channel_roi_mroi, region_start_end = generate_response_curve_plots( + channel_list, + s_curve_params, + channels_proportion, + original_scenario_data, + st.session_state["multiplier"], + ) + + # Display Response Curves + st.subheader(f"Response Curves (X: Spends Vs Y: {metrics_selected_formatted})") + with st.expander("Response Curves", expanded=True): + cols = st.columns(4) # Create 4 columns for the first row + for i, fig in enumerate(figures): + col = cols[i % 4] # Rotate through the columns + with col: + # Get channel parameters + channel = channel_list[i] + modified_total_spends = modified_scenario_data["channels"][channel][ + "modified_total_spends" + ] + conversion_rate = modified_scenario_data["channels"][channel][ + "conversion_rate" + ] + channel_correction = sum( + modified_scenario_data["channels"][channel]["correction"] + ) + + # Updated figure with modified metrics point + roi_optimized, mroi_optimized, fig_updated = modified_metrics_point( + fig, + modified_total_spends, + s_curve_params[channel], + channels_proportion[channel], + conversion_rate, + channel_correction, + ) + + # Store data of each channel ROI and MROI + channel_roi_mroi[channel]["optimized_roi"] = roi_optimized + channel_roi_mroi[channel]["optimized_mroi"] = mroi_optimized + + st.plotly_chart(fig_updated, use_container_width=True) + + # Start a new row after every 4 plots + if (i + 1) % 4 == 0 and i + 1 < len(figures): + cols = st.columns(4) # Create new row with 4 columns + + # Generate the plots + channel_roi_mroi_plot = roi_mori_plot(channel_roi_mroi) + + # Display the plots and name with background color + for channel in channel_list: + with channels_name_plot_placeholder[channel]["channel_plot_placeholder"]: + # Create subplots with 2 columns for ROI and MROI + roi_plot_col, mroi_plot_col = st.columns(2) + + # Display ROI and MROI plots + roi_plot_col.plotly_chart(channel_roi_mroi_plot[channel]["fig_roi"]) + mroi_plot_col.plotly_chart(channel_roi_mroi_plot[channel]["fig_mroi"]) + + # Placeholder for the channel name + channel_name_placeholder = channels_name_plot_placeholder[channel][ + "channel_name_placeholder" + ] + + # Retrieve modified total spends and conversion rate for the channel + modified_total_spends = modified_scenario_data["channels"][channel][ + "modified_total_spends" + ] + conversion_rate = modified_scenario_data["channels"][channel]["conversion_rate"] + + # Calculate the actual spend value for the channel + channel_spends_value = modified_total_spends * conversion_rate + + # Calculate the RGBA color value for the channel based on its spend + channel_rgba_value = calculate_rgba( + channel_spends_value, region_start_end[channel] + ) + + # Display the channel name with the calculated background color + channel_name_placeholder.markdown( + display_channel_name_with_background_color(channel, channel_rgba_value), + unsafe_allow_html=True, + ) + + # Input field for the scenario name + st.text_input("Scenario Name", key="scenario_name") + + # Disable the "Save Scenario" button until a name is provided + if ( + st.session_state["scenario_name"] is None + or st.session_state["scenario_name"] == "" + ): + save_scenario_button_disabled = True + else: + save_scenario_button_disabled = False + + # Button to save the scenario + save_button_placeholder = st.empty() + with st.spinner("Saving ..."): + save_button_placeholder.button( + "Save Scenario", + on_click=save_scenario, + args=( + modified_scenario_data, + metrics_selected, + panel_selected, + optimization_goal, + channel_roi_mroi, + st.session_state["timeframe_selected_selectbox_sp"], + st.session_state["multiplier"], + ), + disabled=save_scenario_button_disabled, + ) + + ########################################## Display Message ########################################## + + # Display all message + with message_display_col: + display_message() + +except Exception as e: + # Capture the error details + exc_type, exc_value, exc_traceback = sys.exc_info() + error_message = "".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + + # Log message + log_message("error", f"An error occurred: {error_message}.", "Scenario Planner") + + # Display a warning message + st.warning( + "Oops! Something went wrong. Please try refreshing the tool or creating a new project.", + icon="⚠️", + )