diff --git "a/pages/9_Scenario_Planner.py" "b/pages/9_Scenario_Planner.py"
new file mode 100644--- /dev/null
+++ "b/pages/9_Scenario_Planner.py"
@@ -0,0 +1,2977 @@
+# Importing necessary libraries
+import streamlit as st
+
+st.set_page_config(
+ page_title="Scenario Planner",
+ page_icon="⚖️",
+ layout="wide",
+ initial_sidebar_state="collapsed",
+)
+
+# Disable +/- for number input
+st.markdown(
+ """
+""",
+ unsafe_allow_html=True,
+)
+
+import re
+import sys
+import copy
+import pickle
+import traceback
+import numpy as np
+import pandas as pd
+from scenario import numerize
+import plotly.graph_objects as go
+from post_gres_cred import db_cred
+from scipy.optimize import minimize
+from log_application import log_message
+from utilities import project_selection, update_db, set_header, load_local_css
+from utilities import (
+ get_panels_names,
+ get_metrics_names,
+ name_formating,
+ load_rcs_metadata_files,
+ load_scenario_metadata_files,
+ generate_rcs_data,
+ generate_scenario_data,
+)
+from constants import (
+ xtol_tolerance_per,
+ mroi_threshold,
+ word_length_limit_lower,
+ word_length_limit_upper,
+)
+
+
+schema = db_cred["schema"]
+load_local_css("styles.css")
+set_header()
+
+# Initialize project name session state
+if "project_name" not in st.session_state:
+ st.session_state["project_name"] = None
+
+# Fetch project dictionary
+if "project_dct" not in st.session_state:
+ project_selection()
+ st.stop()
+
+# Display Username and Project Name
+if "username" in st.session_state and st.session_state["username"] is not None:
+
+ cols1 = st.columns([2, 1])
+
+ with cols1[0]:
+ st.markdown(f"**Welcome {st.session_state['username']}**")
+ with cols1[1]:
+ st.markdown(f"**Current Project: {st.session_state['project_name']}**")
+
+# Initialize ROI threshold
+if "roi_threshold" not in st.session_state:
+ st.session_state.roi_threshold = 1
+
+# Initialize message display holder
+if "message_display" not in st.session_state:
+ st.session_state.message_display = {"type": "success", "message": None, "icon": ""}
+
+
+# Function to reset modified_scenario_data
+def reset_scenario(metrics_selected=None, panel_selected=None):
+ # Clear message_display
+ st.session_state.message_display = {"type": "success", "message": None, "icon": ""}
+
+ # Use default values from session state if not provided
+ if metrics_selected is None:
+ metrics_selected = st.session_state["response_metrics_selectbox_sp"]
+ if panel_selected is None:
+ panel_selected = st.session_state["panel_selected_selectbox_sp"]
+
+ # Load original scenario data
+ original_data = st.session_state["project_dct"]["scenario_planner"][
+ "original_metadata_file"
+ ]
+ original_scenario_data = original_data[metrics_selected][panel_selected]
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Update the specific section with the original scenario data
+ data[metrics_selected][panel_selected] = copy.deepcopy(original_scenario_data)
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+
+# Function to build s curve
+def s_curve(x, power, K, b, a, x0):
+ return K / (1 + b * np.exp(-a * ((x / 10**power) - x0)))
+
+
+# Function to retrieve S-curve parameters for a given metric, panel, and channel
+def get_s_curve_params(
+ metrics_selected,
+ panel_selected,
+ channel_selected,
+ original_rcs_data,
+ modified_rcs_data,
+):
+ # Retrieve 'power' parameter from the original data for the specific metric, panel, and channel
+ power = original_rcs_data[metrics_selected][panel_selected][channel_selected][
+ "power"
+ ]
+
+ # Get the S-curve parameters from the modified data for the same metric, panel, and channel
+ s_curve_param = modified_rcs_data[metrics_selected][panel_selected][
+ channel_selected
+ ]
+
+ # Load modified scenario metadata
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Update modified S-curve parameters
+ data[metrics_selected][panel_selected]["channels"][channel_selected][
+ "response_curve_params"
+ ] = s_curve_param
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+ # Update the 'power' parameter in the modified S-curve parameters with the original 'power' value
+ s_curve_param["power"] = power
+
+ # Return the updated S-curve parameters
+ return s_curve_param
+
+
+# Function to calculate total contribution
+def get_total_contribution(
+ spends, channels, s_curve_params, channels_proportion, modified_scenario_data
+):
+ total_contribution = 0
+ for i in range(len(channels)):
+ channel_name = channels[i]
+ channel_s_curve_params = s_curve_params[channel_name]
+ spend_proportion = spends[i] * channels_proportion[channel_name]
+ total_contribution += sum(
+ s_curve(
+ spend_proportion,
+ channel_s_curve_params["power"],
+ channel_s_curve_params["K"],
+ channel_s_curve_params["b"],
+ channel_s_curve_params["a"],
+ channel_s_curve_params["x0"],
+ )
+ ) + sum(
+ modified_scenario_data["channels"][channel_name]["correction"]
+ ) # correction for s-curve
+ return total_contribution + sum(modified_scenario_data["constant"])
+
+
+# Function to calculate total spends
+def get_total_spends(spends, channels_conversion_ratio):
+ return np.sum(spends * np.array(list(channels_conversion_ratio.values())))
+
+
+# Function to optimizes spends for all channels given bounds and a total spend target
+def optimizer(
+ optimization_goal,
+ s_curve_params,
+ channels_spends,
+ channels_proportion,
+ channels_conversion_ratio,
+ total_target,
+ bounds_dict,
+ modified_scenario_data,
+):
+ # Extract channel names and corresponding actual spends
+ channels = list(channels_spends.keys())
+ actual_spends = np.array(list(channels_spends.values()))
+ num_channels = len(actual_spends)
+
+ # Define the objective function based on the optimization goal
+ def objective_fun(spends):
+ if optimization_goal == "Spend":
+ # Minimize negative total contribution to maximize the total contribution
+ return -get_total_contribution(
+ spends,
+ channels,
+ s_curve_params,
+ channels_proportion,
+ modified_scenario_data,
+ )
+ else:
+ # Minimize total spends
+ return get_total_spends(spends, channels_conversion_ratio)
+
+ def constraint_fun(spends):
+ if optimization_goal == "Spend":
+ # Ensure the total spends equals the total spend target
+ return get_total_spends(spends, channels_conversion_ratio)
+ else:
+ # Ensure the total contribution equals the total contribution target
+ return get_total_contribution(
+ spends,
+ channels,
+ s_curve_params,
+ channels_proportion,
+ modified_scenario_data,
+ )
+
+ # Equality constraint
+ constraints = {
+ "type": "eq",
+ "fun": lambda spends: constraint_fun(spends) - total_target,
+ } # Sum of all channel spends/metrics should equal the total spend/metrics target
+
+ # Bounds for each channel's spend based
+ bounds = [
+ (
+ actual_spends[i] * (1 + bounds_dict[channels[i]][0] / 100),
+ actual_spends[i] * (1 + bounds_dict[channels[i]][1] / 100),
+ )
+ for i in range(num_channels)
+ ]
+
+ # Initial guess for the optimization
+ initial_guess = np.array(actual_spends)
+
+ # Calculate xtol as n% of the minimum of spends
+ xtol = max(10, (xtol_tolerance_per / 100) * np.min(actual_spends))
+
+ # Perform the optimization using 'trust-constr' method
+ result = minimize(
+ objective_fun,
+ initial_guess,
+ method="trust-constr",
+ constraints=constraints,
+ bounds=bounds,
+ options={
+ "disp": True, # Display the optimization process
+ "xtol": xtol, # Dynamic step size tolerance
+ "maxiter": 1e5, # Maximum number of iterations
+ },
+ )
+
+ # Extract the optimized spends from the result
+ optimized_spends_array = result.x
+
+ # Convert optimized spends back to a dictionary with channel names
+ optimized_spends = {
+ channels[i]: max(0, optimized_spends_array[i]) for i in range(num_channels)
+ }
+
+ return optimized_spends, result.success
+
+
+# Function to calculate achievable targets at lower and upper spend bounds
+@st.cache_data(show_spinner=False)
+def max_target_achievable(
+ channels_spends,
+ s_curve_params,
+ channels_proportion,
+ modified_scenario_data,
+ bounds_dict,
+):
+ # Extract channel names and corresponding actual spends
+ channels = list(channels_spends.keys())
+ actual_spends = np.array(list(channels_spends.values()))
+ num_channels = len(actual_spends)
+
+ # Bounds for each channel's spend
+ lower_spends, upper_spends = [], []
+ for i in range(num_channels):
+ lower_spends.append(actual_spends[i] * (1 + bounds_dict[channels[i]][0] / 100))
+ upper_spends.append(actual_spends[i] * (1 + bounds_dict[channels[i]][1] / 100))
+
+ # Calculate achievable targets at lower and upper spend bounds
+ lower_achievable_target = get_total_contribution(
+ lower_spends,
+ channels,
+ s_curve_params,
+ channels_proportion,
+ modified_scenario_data,
+ )
+ upper_achievable_target = get_total_contribution(
+ upper_spends,
+ channels,
+ s_curve_params,
+ channels_proportion,
+ modified_scenario_data,
+ )
+
+ # Return achievable targets with ±0.1% safety margin
+ return max(0, 1.001 * lower_achievable_target), 0.999 * upper_achievable_target
+
+
+# Function to check if number is in valid format
+def is_valid_number_format(number_str):
+ # Check for None
+ if number_str is None:
+ # Store the message details in session state for invalid input
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+
+ # Define the valid suffixes
+ valid_suffixes = {"K", "M", "B", "T"}
+
+ # Check for negative numbers
+ if number_str[0] == "-":
+ # Store the message details in session state for invalid input
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+
+ # Check if the string ends with a digit
+ if number_str[-1].isdigit():
+ try:
+ # Attempt to convert the entire string to float
+ number = float(number_str)
+ # Ensure the number is non-negative
+ if number >= 0:
+ return True
+ else:
+ # Store the message details in session state for invalid input
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+ except ValueError:
+ # Store the message details in session state for invalid input
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+
+ # Check if the string ends with a valid suffix
+ suffix = number_str[-1].upper()
+ if suffix in valid_suffixes:
+ num_part = number_str[:-1] # Extract the numerical part
+ try:
+ # Attempt to convert the numerical part to float
+ number = float(num_part)
+ # Ensure the number part is non-negative
+ if number >= 0:
+ return True
+ else:
+ # Store the message details in session state for invalid input
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+ except ValueError:
+ # Store the message details in session state for invalid input
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+
+ # If neither condition is met, return False
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Invalid input: Please enter a valid number.",
+ "icon": "⚠️",
+ }
+ return False
+
+
+# Function to converts a string with number suffixes (K, M, B, T) to a float
+def convert_to_float(number_str):
+ # Dictionary mapping suffixes to their multipliers
+ multipliers = {
+ "K": 1e3, # Thousand
+ "M": 1e6, # Million
+ "B": 1e9, # Billion
+ "T": 1e12, # Trillion
+ }
+
+ # If there's no suffix, directly convert to float
+ if number_str[-1].isdigit():
+ return float(number_str)
+
+ # Extract the suffix (last character) and the numerical part
+ suffix = number_str[-1].upper()
+ num_part = number_str[:-1]
+
+ # Convert the numerical part to float and multiply by the corresponding multiplier
+ return float(num_part) * multipliers[suffix]
+
+
+# Function to update absolute_channel_spends change
+def absolute_channel_spends_change(
+ channel_key, channel_spends_actual, channel, metrics_selected, panel_selected
+):
+ # Do not update if the number is in an invalid format
+ if not is_valid_number_format(st.session_state[f"{channel_key}_abs_spends_key"]):
+ return
+
+ # Get updated absolute spends from session state
+ new_absolute_spends = (
+ convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"])
+ * st.session_state["multiplier"]
+ )
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Total channel spends
+ total_channel_spends = 0
+ for current_channel in list(
+ data[metrics_selected][panel_selected]["channels"].keys()
+ ):
+ # Channel key
+ channel_key = f"{metrics_selected}_{panel_selected}_{current_channel}"
+
+ total_channel_spends += (
+ convert_to_float(st.session_state[f"{channel_key}_abs_spends_key"])
+ * st.session_state["multiplier"]
+ )
+
+ # Check if total channel spends are within the allowed range (±50% of the original total spends)
+ if (
+ total_channel_spends
+ < 1.5 * data[metrics_selected][panel_selected]["actual_total_spends"]
+ and total_channel_spends
+ > 0.5 * data[metrics_selected][panel_selected]["actual_total_spends"]
+ ):
+ # Update the modified_total_spends for the specified channel
+ data[metrics_selected][panel_selected]["channels"][channel][
+ "modified_total_spends"
+ ] = new_absolute_spends / float(
+ data[metrics_selected][panel_selected]["channels"][channel][
+ "conversion_rate"
+ ]
+ )
+
+ # Update total spends
+ data[metrics_selected][panel_selected][
+ "modified_total_spends"
+ ] = total_channel_spends
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"][
+ "modified_metadata_file"
+ ] = data
+
+ else:
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Keep total spending within ±50% of the original value.",
+ "icon": "⚠️",
+ }
+
+
+# Function to update percentage_channel_spends change
+def percentage_channel_spends_change(
+ channel_key, channel_spends_actual, channel, metrics_selected, panel_selected
+):
+ # Retrieve the percentage spend change from session state
+ percentage_channel_spends = round(
+ st.session_state[f"{channel_key}_per_spends_key"], 0
+ )
+
+ # Calculate the new absolute spends
+ new_absolute_spends = channel_spends_actual * (1 + percentage_channel_spends / 100)
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Total channel spends
+ total_channel_spends = 0
+ for current_channel in list(
+ data[metrics_selected][panel_selected]["channels"].keys()
+ ):
+ # Channel key
+ channel_key = f"{metrics_selected}_{panel_selected}_{current_channel}"
+
+ # Current channel spends actual
+ current_channel_spends_actual = data[metrics_selected][panel_selected][
+ "channels"
+ ][current_channel]["actual_total_spends"]
+
+ # Current channel conversion rate
+ current_channel_conversion_rate = data[metrics_selected][panel_selected][
+ "channels"
+ ][current_channel]["conversion_rate"]
+
+ # Calculate the current channel absolute spends
+ current_channel_absolute_spends = (
+ current_channel_spends_actual
+ * current_channel_conversion_rate
+ * (1 + st.session_state[f"{channel_key}_per_spends_key"] / 100)
+ )
+
+ total_channel_spends += current_channel_absolute_spends
+
+ # Check if total channel spends are within the allowed range (±50% of the original total spends)
+ if (
+ total_channel_spends
+ < 1.5 * data[metrics_selected][panel_selected]["actual_total_spends"]
+ and total_channel_spends
+ > 0.5 * data[metrics_selected][panel_selected]["actual_total_spends"]
+ ):
+ # Update the modified_total_spends for the specified channel
+ data[metrics_selected][panel_selected]["channels"][channel][
+ "modified_total_spends"
+ ] = float(new_absolute_spends) / float(
+ data[metrics_selected][panel_selected]["channels"][channel][
+ "conversion_rate"
+ ]
+ )
+
+ # Update total spends
+ data[metrics_selected][panel_selected][
+ "modified_total_spends"
+ ] = total_channel_spends
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"][
+ "modified_metadata_file"
+ ] = data
+
+ else:
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Keep total spending within ±50% of the original value.",
+ "icon": "⚠️",
+ }
+
+
+# # Function to update total input change
+# def total_input_change(per_change):
+# # Load modified scenario data
+# data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+# # Get the list of all channels in the specified panel and metric
+# channel_list = list(data[metrics_selected][panel_selected]["channels"].keys())
+
+# # Iterate over each channel to update their modified spends
+# for channel in channel_list:
+# # Retrieve the actual spends for the channel
+# channel_actual_spends = data[metrics_selected][panel_selected]["channels"][
+# channel
+# ]["actual_total_spends"]
+
+# # Calculate the modified spends for the channel based on the percent change
+# modified_channel_metrics = channel_actual_spends * ((100 + per_change) / 100)
+
+# # Update the channel's modified total spends in the data
+# data[metrics_selected][panel_selected]["channels"][channel][
+# "modified_total_spends"
+# ] = modified_channel_metrics
+
+# # Update modified scenario metadata
+# st.session_state["project_dct"]["scenario_planner"][
+# "modified_metadata_file"
+# ] = data
+
+
+# Function to update total input change
+def total_input_change(per_change, metrics_selected, panel_selected):
+ # Load modified and original scenario data
+ modified_data = st.session_state["project_dct"]["scenario_planner"][
+ "modified_metadata_file"
+ ].copy()
+ original_data = st.session_state["project_dct"]["scenario_planner"][
+ "original_metadata_file"
+ ].copy()
+
+ # Get the list of all channels in the selected panel and metric
+ channel_list = list(
+ modified_data[metrics_selected][panel_selected]["channels"].keys()
+ )
+
+ # Separate channels into unfrozen and frozen based on optimization status
+ unfrozen_channels, frozen_channels = [], []
+ for channel in channel_list:
+ channel_key = f"{metrics_selected}_{panel_selected}_{channel}"
+ if st.session_state.get(f"{channel_key}_allow_optimize_key", False):
+ frozen_channels.append(channel)
+ else:
+ unfrozen_channels.append(channel)
+
+ # Calculate spends and total share from frozen channels, weighted by conversion rate
+ frozen_channel_share, frozen_channel_spends = 0, 0
+ for channel in frozen_channels:
+ conversion_rate = original_data[metrics_selected][panel_selected]["channels"][
+ channel
+ ]["conversion_rate"]
+ actual_spends = original_data[metrics_selected][panel_selected]["channels"][
+ channel
+ ]["actual_total_spends"]
+ modified_spends = modified_data[metrics_selected][panel_selected]["channels"][
+ channel
+ ]["modified_total_spends"]
+ spends_diff = max(actual_spends, 1e-3) * ((100 + per_change) / 100) - max(
+ modified_spends, 1e-3
+ )
+ frozen_channel_share += spends_diff * conversion_rate
+ frozen_channel_spends += max(actual_spends, 1e-3) * conversion_rate
+
+ # Redistribute frozen share across unfrozen channels based on original spend weights
+ for channel in unfrozen_channels:
+ conversion_rate = original_data[metrics_selected][panel_selected]["channels"][
+ channel
+ ]["conversion_rate"]
+ actual_spends = original_data[metrics_selected][panel_selected]["channels"][
+ channel
+ ]["actual_total_spends"]
+
+ # Calculate weight of the current channel's original spends
+ total_original_spends = original_data[metrics_selected][panel_selected][
+ "actual_total_spends"
+ ]
+ channel_weight = (actual_spends * conversion_rate) / (
+ total_original_spends - frozen_channel_spends
+ )
+
+ # Calculate the modified spends with redistributed frozen share
+ modified_spends = (
+ max(actual_spends, 1e-3) * ((100 + per_change) / 100)
+ + (frozen_channel_share * channel_weight) / conversion_rate
+ )
+
+ # Update modified total spends in the modified data
+ modified_data[metrics_selected][panel_selected]["channels"][channel][
+ "modified_total_spends"
+ ] = modified_spends
+
+ # Save the updated modified scenario data back to the session state
+ st.session_state["project_dct"]["scenario_planner"][
+ "modified_metadata_file"
+ ] = modified_data
+
+
+# Function to update total_absolute_main_key change
+def total_absolute_main_key_change(metrics_selected, panel_selected, optimization_goal):
+ # Do not update if the number is in an invalid format
+ if not is_valid_number_format(st.session_state["total_absolute_main_key"]):
+ return
+
+ # Get updated absolute from session state
+ new_absolute = (
+ convert_to_float(st.session_state["total_absolute_main_key"])
+ * st.session_state["multiplier"]
+ )
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ if optimization_goal == "Spend":
+ # Retrieve the old absolute spends
+ old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"]
+ else:
+ # Retrieve the old absolute metrics
+ old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"]
+
+ # Calculate the allowable range for new spends
+ lower_bound = old_absolute * 0.5
+ upper_bound = old_absolute * 1.5
+
+ # Ensure the new spends are within ±50% of the old value
+ if new_absolute < lower_bound or new_absolute > upper_bound:
+ new_absolute = old_absolute
+
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Keep total spending within ±50% of the original value.",
+ "icon": "⚠️",
+ }
+
+ if optimization_goal == "Spend":
+ # Update the modified_total_spends with the constrained value
+ data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute
+ else:
+ # Update the modified_total_sales with the constrained value
+ data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+ # Update total input change
+ if optimization_goal == "Spend":
+ per_change = ((new_absolute - old_absolute) / old_absolute) * 100
+ total_input_change(per_change, metrics_selected, panel_selected)
+
+
+# Function to update total_absolute_key change
+def total_absolute_key_change(metrics_selected, panel_selected, optimization_goal):
+ # Get updated absolute from session state
+ new_absolute = (
+ convert_to_float(st.session_state["total_absolute_key"])
+ * st.session_state["multiplier"]
+ )
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ if optimization_goal == "Spend":
+ # Update the modified_total_spends for the specified channel
+ data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute
+ old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"]
+ else:
+ # Update the modified_total_sales for the specified channel
+ data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute
+ old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"]
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+ # Update total input change
+ if optimization_goal == "Spend":
+ per_change = ((new_absolute - old_absolute) / old_absolute) * 100
+ total_input_change(per_change, metrics_selected, panel_selected)
+
+
+# Function to update total_absolute_key change
+def total_percentage_key_change(
+ metrics_selected,
+ panel_selected,
+ absolute_value,
+ optimization_goal,
+):
+ # Get updated absolute from session state
+ new_absolute = absolute_value * (1 + st.session_state["total_percentage_key"] / 100)
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ if optimization_goal == "Spend":
+ # Update the modified_total_spends for the specified channel
+ data[metrics_selected][panel_selected]["modified_total_spends"] = new_absolute
+ old_absolute = data[metrics_selected][panel_selected]["actual_total_spends"]
+ else:
+ # Update the modified_total_sales for the specified channel
+ data[metrics_selected][panel_selected]["modified_total_sales"] = new_absolute
+ old_absolute = data[metrics_selected][panel_selected]["actual_total_sales"]
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+ # Update total input change
+ if optimization_goal == "Spend":
+ per_change = ((new_absolute - old_absolute) / old_absolute) * 100
+ total_input_change(per_change, metrics_selected, panel_selected)
+
+
+# Function to update bound change
+def bound_change(metrics_selected, panel_selected, channel_key, channel):
+ # Get updated bounds from session state
+ new_lower_bound = st.session_state[f"{channel_key}_lower_key"]
+ new_upper_bound = st.session_state[f"{channel_key}_upper_key"]
+ if new_lower_bound > new_upper_bound:
+ new_bounds = [-10, 10]
+
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Lower bound cannot be greater than Upper bound.",
+ "icon": "⚠️",
+ }
+
+ else:
+ new_bounds = [new_lower_bound, new_upper_bound]
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Update the bounds for the specified channel
+ data[metrics_selected][panel_selected]["channels"][channel]["bounds"] = new_bounds
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+
+# Function to update freeze change
+def freeze_change(metrics_selected, panel_selected, channel_key, channel, channel_list):
+ # Initialize counter for channels that are not frozen
+ unfrozen_channel_count = 0
+
+ # Check the optimization status of each channel
+ for current_channel in channel_list:
+ current_channel_key = f"{metrics_selected}_{panel_selected}_{current_channel}"
+ unfrozen_channel_count += (
+ 1
+ if not st.session_state[f"{current_channel_key}_allow_optimize_key"]
+ else 0
+ )
+
+ # Ensure at least two channels are allowed for optimization
+ if unfrozen_channel_count < 2:
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Please allow at least two channels to be optimized.",
+ "icon": "⚠️",
+ }
+ return
+
+ if st.session_state[f"{channel_key}_allow_optimize_key"]:
+ # Updated bounds from session state
+ new_lower_bound, new_upper_bound = 0, 0
+ new_bounds = [new_lower_bound, new_upper_bound]
+ new_freeze = True
+ else:
+ # Updated bounds from session state
+ new_lower_bound, new_upper_bound = -10, 10
+ new_bounds = [new_lower_bound, new_upper_bound]
+ new_freeze = False
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Update the bounds for the specified channel
+ data[metrics_selected][panel_selected]["channels"][channel]["bounds"] = new_bounds
+ data[metrics_selected][panel_selected]["channels"][channel]["freeze"] = new_freeze
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+
+# Function to calculate y, ROI and MROI for given point
+def get_point_parms(
+ x_val,
+ current_s_curve_params,
+ current_channel_proportion,
+ current_conversion_rate,
+ channel_correction,
+):
+ # Calculate y value for the given spend point
+ y_val = (
+ sum(
+ s_curve(
+ (x_val * current_channel_proportion),
+ current_s_curve_params["power"],
+ current_s_curve_params["K"],
+ current_s_curve_params["b"],
+ current_s_curve_params["a"],
+ current_s_curve_params["x0"],
+ )
+ )
+ + channel_correction
+ )
+
+ # Calculate MROI using a small nudge for actual spends
+ nudge = 1e-3
+ x1 = float(x_val * current_conversion_rate)
+ y1 = float(y_val)
+ x2 = x1 + nudge
+ y2 = (
+ sum(
+ s_curve(
+ ((x2 / current_conversion_rate) * current_channel_proportion),
+ current_s_curve_params["power"],
+ current_s_curve_params["K"],
+ current_s_curve_params["b"],
+ current_s_curve_params["a"],
+ current_s_curve_params["x0"],
+ )
+ )
+ + channel_correction
+ )
+ mroi_val = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0
+
+ # Calculate ROI
+ roi_val = y_val / (x_val * current_conversion_rate)
+
+ return roi_val, mroi_val, y_val
+
+
+# Function to find segment value
+def find_segment_value(x, roi, mroi, roi_threshold=1, mroi_threshold=0.05):
+ # Initialize the start and end values of the x array
+ start_value = x[0]
+ end_value = x[-1]
+
+ # Define the condition for the "green region" where both ROI and MROI exceed their respective thresholds
+ green_condition = (roi > roi_threshold) & (mroi > mroi_threshold)
+
+ # Find indices where ROI exceeds the ROI threshold
+ left_indices = np.where(roi > roi_threshold)[0]
+
+ # Find indices where both ROI and MROI exceed their thresholds (green condition)
+ right_indices = np.where(green_condition)[0]
+
+ # Determine the left value based on the first index where ROI exceeds the threshold
+ left_value = x[left_indices[0]] if left_indices.size > 0 else x[0]
+
+ # Determine the right value based on the last index where both ROI and MROI exceed their thresholds
+ right_value = x[right_indices[-1]] if right_indices.size > 0 else x[0]
+
+ # Ensure the left value does not exceed the right value, adjust if necessary
+ if left_value > right_value:
+ left_value = right_value
+
+ return start_value, end_value, left_value, right_value
+
+
+# Function to generate response curves plots
+@st.cache_data(show_spinner=False)
+def generate_response_curve_plots(
+ channel_list,
+ s_curve_params,
+ channels_proportion,
+ original_scenario_data,
+ multiplier,
+):
+ figures, channel_roi_mroi, region_start_end = [], {}, {}
+
+ for channel in channel_list:
+ spends_actual = original_scenario_data["channels"][channel][
+ "actual_total_spends"
+ ]
+ conversion_rate = original_scenario_data["channels"][channel]["conversion_rate"]
+ channel_correction = sum(
+ original_scenario_data["channels"][channel]["correction"]
+ )
+
+ x_actual = np.linspace(0, 5 * spends_actual, 100)
+ x_plot = x_actual * conversion_rate
+
+ # Calculate y values for the S-curve
+ y_plot = [
+ sum(
+ s_curve(
+ (x * channels_proportion[channel]),
+ s_curve_params[channel]["power"],
+ s_curve_params[channel]["K"],
+ s_curve_params[channel]["b"],
+ s_curve_params[channel]["a"],
+ s_curve_params[channel]["x0"],
+ )
+ )
+ + channel_correction
+ for x in x_actual
+ ]
+
+ # Calculate ROI and ensure they are scalar values
+ roi = [float(y) / float(x) if x != 0 else 0 for x, y in zip(x_plot, y_plot)]
+
+ # Calculate MROI using a small nudge
+ nudge = 1e-3
+ mroi = []
+ for i in range(len(x_plot)):
+ x1 = float(x_plot[i])
+ y1 = float(y_plot[i])
+ x2 = x1 + nudge
+ y2 = (
+ sum(
+ s_curve(
+ ((x2 / conversion_rate) * channels_proportion[channel]),
+ s_curve_params[channel]["power"],
+ s_curve_params[channel]["K"],
+ s_curve_params[channel]["b"],
+ s_curve_params[channel]["a"],
+ s_curve_params[channel]["x0"],
+ )
+ )
+ + channel_correction
+ )
+ mroi_value = (float(y2) - y1) / (x2 - x1) if x2 != x1 else 0
+ mroi.append(mroi_value)
+
+ # Channel correction
+ channel_correction = sum(
+ original_scenario_data["channels"][channel]["correction"]
+ )
+
+ # Calculate y, ROI and MROI for the actual spend point
+ roi_actual, mroi_actual, y_actual = get_point_parms(
+ spends_actual,
+ s_curve_params[channel],
+ channels_proportion[channel],
+ conversion_rate,
+ channel_correction,
+ )
+
+ # Create the plotly figure
+ fig = go.Figure()
+
+ # Add S-curve line
+ fig.add_trace(
+ go.Scatter(
+ x=np.array(x_plot) / multiplier,
+ y=np.array(y_plot) / multiplier,
+ mode="lines",
+ name="Metrics",
+ hoverinfo="text",
+ text=[
+ f"Spends: {numerize(x / multiplier)}
{metrics_selected_formatted}: {numerize(y / multiplier)}
ROI: {r:.2f}
MROI: {m:.2f}"
+ for x, y, r, m in zip(x_plot, y_plot, roi, mroi)
+ ],
+ )
+ )
+
+ # Add current spend point
+ fig.add_trace(
+ go.Scatter(
+ x=[spends_actual * conversion_rate / multiplier],
+ y=[y_actual / multiplier],
+ mode="markers",
+ marker=dict(color="cyan", size=10, symbol="circle"),
+ name="Actual Spend",
+ hoverinfo="text",
+ text=[
+ f"Actual Spend: {numerize(spends_actual * conversion_rate / multiplier)}
{metrics_selected_formatted}: {numerize(y_actual / multiplier)}
ROI: {roi_actual:.2f}
MROI: {mroi_actual:.2f}"
+ ],
+ showlegend=True,
+ )
+ )
+
+ # ROI Threshold
+ roi_threshold = st.session_state.roi_threshold
+
+ # Scale x and y values
+ x, y = np.array(x_plot), np.array(y_plot)
+ x_scaled, y_scaled = x / max(x), y / max(y)
+
+ # Calculate MROI scaled starting from the first point
+ mroi_scaled = np.zeros_like(x_scaled)
+ for j in range(1, len(x_scaled)):
+ x1, y1 = x_scaled[j - 1], y_scaled[j - 1]
+ x2, y2 = x_scaled[j], y_scaled[j]
+ mroi_scaled[j] = (y2 - y1) / (x2 - x1) if (x2 - x1) != 0 else 0
+
+ # Get the start_value, end_value, left_value, right_value for segments
+ start_value, end_value, left_value, right_value = find_segment_value(
+ x_plot, np.array(roi), mroi_scaled, roi_threshold, mroi_threshold
+ )
+
+ # Store region start and end points
+ region_start_end[channel] = {
+ "start_value": start_value,
+ "end_value": end_value,
+ "left_value": left_value,
+ "right_value": right_value,
+ }
+
+ # Adding background colors
+ y_max = max(y_plot) * 1.3 # 30% extra space above the max
+
+ # Yellow region
+ fig.add_shape(
+ type="rect",
+ x0=start_value / multiplier,
+ y0=0,
+ x1=left_value / multiplier,
+ y1=y_max / multiplier,
+ line=dict(width=0),
+ fillcolor="rgba(255, 255, 0, 0.3)",
+ layer="below",
+ )
+
+ # Green region
+ fig.add_shape(
+ type="rect",
+ x0=left_value / multiplier,
+ y0=0,
+ x1=right_value / multiplier,
+ y1=y_max / multiplier,
+ line=dict(width=0),
+ fillcolor="rgba(0, 255, 0, 0.3)",
+ layer="below",
+ )
+
+ # Red region
+ fig.add_shape(
+ type="rect",
+ x0=right_value / multiplier,
+ y0=0,
+ x1=end_value / multiplier,
+ y1=y_max / multiplier,
+ line=dict(width=0),
+ fillcolor="rgba(255, 0, 0, 0.3)",
+ layer="below",
+ )
+
+ # Layout adjustments
+ fig.update_layout(
+ title=f"{name_formating(channel)}",
+ showlegend=False,
+ xaxis=dict(
+ showgrid=True,
+ showticklabels=True,
+ tickformat=".2s",
+ gridcolor="lightgrey",
+ gridwidth=0.5,
+ griddash="dot",
+ ),
+ yaxis=dict(
+ showgrid=True,
+ showticklabels=True,
+ tickformat=".2s",
+ gridcolor="lightgrey",
+ gridwidth=0.5,
+ griddash="dot",
+ ),
+ template="plotly_white",
+ margin=dict(l=20, r=20, t=30, b=20),
+ height=100 * (len(channel_list) + 4 - 1) // 4,
+ )
+
+ figures.append(fig)
+
+ # Store data of each channel ROI and MROI
+ channel_roi_mroi[channel] = {
+ "actual_roi": roi_actual,
+ "actual_mroi": mroi_actual,
+ }
+
+ return figures, channel_roi_mroi, region_start_end
+
+
+# Function to add modified spends/metrics point on plot
+def modified_metrics_point(
+ fig,
+ modified_spends,
+ s_curve_params,
+ channels_proportion,
+ conversion_rate,
+ channel_correction,
+):
+ # Calculate ROI, MROI, and y for the modified point
+ roi_modified, mroi_modified, y_modified = get_point_parms(
+ modified_spends,
+ s_curve_params,
+ channels_proportion,
+ conversion_rate,
+ channel_correction,
+ )
+
+ # Add modified spend point
+ fig.add_trace(
+ go.Scatter(
+ x=[modified_spends * conversion_rate / st.session_state["multiplier"]],
+ y=[y_modified / st.session_state["multiplier"]],
+ mode="markers",
+ marker=dict(color="blueviolet", size=10, symbol="circle"),
+ name="Optimized Spend",
+ hoverinfo="text",
+ text=[
+ f"Modified Spend: {numerize(modified_spends * conversion_rate / st.session_state.multiplier)}
{metrics_selected_formatted}: {numerize(y_modified / st.session_state.multiplier)}
ROI: {roi_modified:.2f}
MROI: {mroi_modified:.2f}"
+ ],
+ showlegend=True,
+ )
+ )
+
+ return roi_modified, mroi_modified, fig
+
+
+# Function to update bound type change
+def bound_type_change():
+ # Get updated bound type from session state
+ new_bound_type = st.session_state["bound_type_key"]
+
+ # Load modified scenario data
+ data = st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"]
+
+ # Update the bound type
+ data[metrics_selected][panel_selected]["bound_type"] = new_bound_type
+
+ # Set bounds to default value if bound type is False (Default)
+ channel_list = list(data[metrics_selected][panel_selected]["channels"].keys())
+ if not new_bound_type:
+ for channel in channel_list:
+ data[metrics_selected][panel_selected]["channels"][channel]["bounds"] = [
+ -10,
+ 10,
+ ]
+
+ # Update modified scenario metadata
+ st.session_state["project_dct"]["scenario_planner"]["modified_metadata_file"] = data
+
+
+# Function to format the numbers with decimal
+def format_value(input_value):
+ value = abs(input_value)
+ return f"{input_value:.4f}" if value < 1 else f"{numerize(input_value, 1)}"
+
+
+# Function to format the numbers with decimal
+def round_value(input_value):
+ value = abs(input_value)
+ return round(input_value, 4) if value < 1 else round(input_value, 1)
+
+
+# Function to generate ROI and MROI plots for all channels
+@st.cache_data(show_spinner=False)
+def roi_mori_plot(channel_roi_mroi):
+ # Dictionary to store plots
+ channel_roi_mroi_plot = {}
+ for channel in channel_roi_mroi:
+ channel_roi_mroi_data = channel_roi_mroi[channel]
+ # Extract the data
+ actual_roi = channel_roi_mroi_data["actual_roi"]
+ optimized_roi = channel_roi_mroi_data["optimized_roi"]
+ actual_mroi = channel_roi_mroi_data["actual_mroi"]
+ optimized_mroi = channel_roi_mroi_data["optimized_mroi"]
+
+ # Plot ROI
+ fig_roi = go.Figure()
+ fig_roi.add_trace(
+ go.Bar(
+ x=["Actual ROI"],
+ y=[actual_roi],
+ name="Actual ROI",
+ marker_color="cyan",
+ width=1,
+ text=[format_value(actual_roi)],
+ textposition="auto",
+ textfont=dict(color="black", size=14),
+ )
+ )
+ fig_roi.add_trace(
+ go.Bar(
+ x=["Optimized ROI"],
+ y=[optimized_roi],
+ name="Optimized ROI",
+ marker_color="blueviolet",
+ width=1,
+ text=[format_value(optimized_roi)],
+ textposition="auto",
+ textfont=dict(color="black", size=14),
+ )
+ )
+
+ fig_roi.update_layout(
+ annotations=[
+ dict(
+ x=0.5,
+ y=1.3,
+ xref="paper",
+ yref="paper",
+ text="ROI",
+ showarrow=False,
+ font=dict(size=14),
+ )
+ ],
+ barmode="group",
+ bargap=0,
+ showlegend=False,
+ width=110,
+ height=110,
+ xaxis=dict(
+ showticklabels=True,
+ showgrid=False,
+ tickangle=0,
+ ticktext=["Actual", "Optimized"],
+ tickvals=["Actual ROI", "Optimized ROI"],
+ ),
+ yaxis=dict(showticklabels=False, showgrid=False),
+ margin=dict(t=20, b=20, r=0, l=0),
+ )
+
+ # Plot MROI
+ fig_mroi = go.Figure()
+ fig_mroi.add_trace(
+ go.Bar(
+ x=["Actual MROI"],
+ y=[actual_mroi],
+ name="Actual MROI",
+ marker_color="cyan",
+ width=1,
+ text=[format_value(actual_mroi)],
+ textposition="auto",
+ textfont=dict(color="black", size=14),
+ )
+ )
+ fig_mroi.add_trace(
+ go.Bar(
+ x=["Optimized MROI"],
+ y=[optimized_mroi],
+ name="Optimized MROI",
+ marker_color="blueviolet",
+ width=1,
+ text=[format_value(optimized_mroi)],
+ textposition="auto",
+ textfont=dict(color="black", size=14),
+ )
+ )
+
+ fig_mroi.update_layout(
+ annotations=[
+ dict(
+ x=0.5,
+ y=1.3,
+ xref="paper",
+ yref="paper",
+ text="MROI",
+ showarrow=False,
+ font=dict(size=14),
+ )
+ ],
+ barmode="group",
+ bargap=0,
+ showlegend=False,
+ width=110,
+ height=110,
+ xaxis=dict(
+ showticklabels=True,
+ showgrid=False,
+ tickangle=0,
+ ticktext=["Actual", "Optimized"],
+ tickvals=["Actual MROI", "Optimized MROI"],
+ ),
+ yaxis=dict(showticklabels=False, showgrid=False),
+ margin=dict(t=20, b=20, r=0, l=0),
+ )
+
+ # Store plots
+ channel_roi_mroi_plot[channel] = {"fig_roi": fig_roi, "fig_mroi": fig_mroi}
+
+ return channel_roi_mroi_plot
+
+
+# Function to save the current scenario with the mentioned name
+def save_scenario(
+ scenario_dict,
+ metrics_selected,
+ panel_selected,
+ optimization_goal,
+ channel_roi_mroi,
+ timeframe,
+ multiplier,
+):
+ # Remove extra space at start and ends
+ if st.session_state["scenario_name"] is not None:
+ st.session_state["scenario_name"] = st.session_state["scenario_name"].strip()
+
+ if (
+ st.session_state["scenario_name"] is None
+ or st.session_state["scenario_name"] == ""
+ ):
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Please provide a name to save the scenario.",
+ "icon": "⚠️",
+ }
+ return
+
+ # Check the scenario name
+ if not (
+ word_length_limit_lower
+ <= len(st.session_state["scenario_name"])
+ <= word_length_limit_upper
+ and bool(re.match("^[A-Za-z0-9_]*$", st.session_state["scenario_name"]))
+ ):
+ # Store the warning message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": f"Please provide a valid scenario name ({word_length_limit_lower}-{word_length_limit_upper} characters, only A-Z, a-z, 0-9, and _).",
+ "icon": "⚠️",
+ }
+ return
+
+ # Check if the dictionary is empty
+ if not scenario_dict:
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Nothing to save. The scenario data is empty.",
+ "icon": "⚠️",
+ }
+ return
+
+ # Add additional scenario details
+ scenario_dict["panel_selected"] = panel_selected
+ scenario_dict["metrics_selected"] = metrics_selected
+ scenario_dict["optimization"] = optimization_goal
+ scenario_dict["channel_roi_mroi"] = channel_roi_mroi
+ scenario_dict["timeframe"] = timeframe
+ scenario_dict["multiplier"] = multiplier
+
+ # Load existing scenarios
+ saved_scenarios_dict = st.session_state["project_dct"]["saved_scenarios"][
+ "saved_scenarios_dict"
+ ]
+
+ # Check if the name is already taken
+ if st.session_state["scenario_name"] in saved_scenarios_dict.keys():
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "warning",
+ "message": "Name already exists. Please change the name or delete the existing scenario from the Saved Scenario page.",
+ "icon": "⚠️",
+ }
+ return
+
+ # Update the dictionary with the new scenario
+ saved_scenarios_dict[st.session_state["scenario_name"]] = scenario_dict
+
+ # Update the updated dictionary
+ st.session_state["project_dct"]["saved_scenarios"][
+ "saved_scenarios_dict"
+ ] = saved_scenarios_dict
+
+ # Update DB
+ update_db(
+ prj_id=st.session_state["project_number"],
+ page_nam="Scenario Planner",
+ file_nam="project_dct",
+ pkl_obj=pickle.dumps(st.session_state["project_dct"]),
+ schema=schema,
+ )
+
+ # Store the message details in session state
+ st.session_state.message_display = {
+ "type": "success",
+ "message": f"Scenario '{st.session_state.scenario_name}' has been successfully saved!",
+ "icon": "💾",
+ }
+ st.toast(
+ f"Scenario '{st.session_state.scenario_name}' has been successfully saved!",
+ icon="💾",
+ )
+
+ # Clear the scenario name input
+ st.session_state["scenario_name"] = ""
+
+
+# Function to calculate the RGBA color code based on the spends value and region boundaries
+def calculate_rgba(spends_value, region_start_end):
+ # Get region start and end points
+ start_value = region_start_end["start_value"]
+ end_value = region_start_end["end_value"]
+ left_value = region_start_end["left_value"]
+ right_value = region_start_end["right_value"]
+
+ # Calculate alpha dynamically based on the position within the range
+ def calculate_alpha(position, start, end, min_alpha=0.1, max_alpha=0.4):
+ return min_alpha + (max_alpha - min_alpha) * (position - start) / (end - start)
+
+ if start_value <= spends_value <= left_value:
+ # Yellow range (0, 128, 0) - More transparent towards left, darker towards start
+ alpha = calculate_alpha(spends_value, left_value, start_value)
+ return (255, 255, 0, alpha) # RGB for yellow
+ elif left_value < spends_value <= right_value:
+ # Green range (0, 128, 0) - More transparent towards right, darker towards left
+ alpha = calculate_alpha(spends_value, right_value, left_value)
+ return (0, 128, 0, alpha) # RGB for green
+ elif right_value < spends_value <= end_value:
+ # Red range (255, 0, 0) - More transparent towards right, darker towards end
+ alpha = calculate_alpha(spends_value, right_value, end_value)
+ return (255, 0, 0, alpha) # RGB for red
+
+
+# Function to format and display the channel name with a color and background color
+def display_channel_name_with_background_color(
+ channel_name, background_color=(0, 128, 0, 0.1)
+):
+ formatted_name = name_formating(channel_name)
+
+ # Unpack the RGBA values
+ r, g, b, a = background_color
+
+ # Create the HTML content with specified background color
+ html_content = f"""
+