# Response Curves page: shows the fitted S-curve for a selected
# metric / panel / channel next to the raw data points, lets the user
# hand-tune the curve parameters (K, b, a, x0), reset them, download the
# modified parameters as JSON, and save project progress to the database.
import streamlit as st

# set_page_config must be the first Streamlit command executed on the page,
# which is why it runs before the remaining imports.
st.set_page_config(
    page_title="Response Curves",
    page_icon="⚖️",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Disable +/- for number input: hide the step buttons Streamlit renders next
# to every st.number_input so parameters are only edited by typing.
st.markdown(
    """
    <style>
    button.step-up {display: none;}
    button.step-down {display: none;}
    </style>
    """,
    unsafe_allow_html=True,
)

import sys
import json
import pickle
import traceback
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from post_gres_cred import db_cred
from sklearn.metrics import r2_score
from log_application import log_message
from utilities import project_selection, update_db, set_header, load_local_css
from utilities import (
    get_panels_names,
    get_metrics_names,
    name_formating,
    generate_rcs_data,
    load_rcs_metadata_files,
)

schema = db_cred["schema"]
load_local_css("styles.css")
set_header()

# Initialize project name session state
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# Fetch project dictionary; without it there is nothing to show, so let the
# user pick a project and halt this run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()

# Display Username and Project Name
if "username" in st.session_state and st.session_state["username"] is not None:
    cols1 = st.columns([2, 1])
    with cols1[0]:
        st.markdown(f"**Welcome {st.session_state['username']}**")
    with cols1[1]:
        st.markdown(f"**Current Project: {st.session_state['project_name']}**")


def s_curve(x, K, b, a, x0):
    """Evaluate the logistic response curve K / (1 + b * exp(-a * (x - x0))).

    Args:
        x: Input value(s); works element-wise on NumPy arrays.
        K: Upper asymptote of the curve.
        b: Scale factor on the exponential term.
        a: Growth-rate coefficient.
        x0: Horizontal shift of the curve.

    Returns:
        The curve value(s) at ``x``.
    """
    return K / (1 + b * np.exp(-a * (x - x0)))


def _rcs_param_keys(metrics_selected, panel_selected, channel_selected):
    """Return the session-state widget keys (K, b, a, x0) for one selection.

    The keys are unique per metric / panel / channel so each channel keeps
    its own widget values across reruns.
    """
    suffix = f"{metrics_selected}_{panel_selected}_{channel_selected}"
    return (
        f"K_updated_key_{suffix}",
        f"b_updated_key_{suffix}",
        f"a_updated_key_{suffix}",
        f"x0_updated_key_{suffix}",
    )


def modify_rcs_parameters(metrics_selected, panel_selected, channel_selected):
    """Copy the edited widget values into the modified RCS metadata.

    Used as the ``on_change`` callback of the four parameter inputs. Mutates
    ``st.session_state["project_dct"]["response_curves"]["modified_metadata_file"]``
    in place, replacing the selected channel's entry with a dict holding only
    ``K``, ``b``, ``a`` and ``x0``.
    """
    K_key, b_key, a_key, x0_key = _rcs_param_keys(
        metrics_selected, panel_selected, channel_selected
    )

    rcs_data_modified = st.session_state["project_dct"]["response_curves"][
        "modified_metadata_file"
    ]

    rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
        "K": st.session_state[K_key],
        "b": st.session_state[b_key],
        "a": st.session_state[a_key],
        "x0": st.session_state[x0_key],
    }


def reset_parameters(
    metrics_selected, panel_selected, channel_selected, original_channel_data
):
    """Restore one channel's RCS parameters to their original fitted values.

    Drops the channel's widget keys from session state, so the inputs are
    re-seeded from the metadata on the next run. It then overwrites the
    channel's entry in the modified metadata with the original
    ``K``/``b``/``a``/``x0``.

    Args:
        metrics_selected: Selected response metric.
        panel_selected: Selected panel.
        channel_selected: Selected channel.
        original_channel_data: Original metadata entry for the channel; must
            contain the keys ``K``, ``b``, ``a`` and ``x0``.
    """
    for key in _rcs_param_keys(metrics_selected, panel_selected, channel_selected):
        # Skip keys that were never created instead of raising KeyError
        if key in st.session_state:
            del st.session_state[key]

    response_curves = st.session_state["project_dct"]["response_curves"]
    rcs_data_modified = response_curves["modified_metadata_file"]

    rcs_data_modified[metrics_selected][panel_selected][channel_selected] = {
        "K": original_channel_data["K"],
        "b": original_channel_data["b"],
        "a": original_channel_data["a"],
        "x0": original_channel_data["x0"],
    }

    response_curves["modified_metadata_file"] = rcs_data_modified


@st.cache_data(show_spinner=False)
def updated_parm_gen(original_data, modified_data, metrics_selected, panel_selected):
    """Tabulate the channels whose parameters differ from the original fit.

    The result is cached on the argument values, so editing
    ``modified_data`` produces a fresh table.

    Args:
        original_data: Original RCS metadata, indexed as
            ``[metric][panel][channel]``.
        modified_data: Modified RCS metadata with the same layout.
        metrics_selected: Metric to inspect.
        panel_selected: Panel to inspect.

    Returns:
        A DataFrame with one row per changed channel, holding the original
        and modified ``K``/``b``/``a``/``x0``. It is empty when nothing changed.
    """
    original_data_selection = original_data[metrics_selected][panel_selected]
    modified_data_selection = modified_data[metrics_selected][panel_selected]
    param_names = ("K", "b", "a", "x0")

    data = []
    for channel, original_params in original_data_selection.items():
        modified_params = modified_data_selection[channel]
        K_o, b_o, a_o, x0_o = (original_params[p] for p in param_names)
        K_m, b_m, a_m, x0_m = (modified_params[p] for p in param_names)

        # Only report channels where at least one parameter changed
        if (K_o != K_m) or (b_o != b_m) or (a_o != a_m) or (x0_o != x0_m):
            data.append(
                {
                    "Metric": name_formating(metrics_selected),
                    "Panel": name_formating(panel_selected),
                    "Channel": name_formating(channel),
                    "K (Original)": K_o,
                    "b (Original)": b_o,
                    "a (Original)": a_o,
                    "x0 (Original)": x0_o,
                    "K (Modified)": K_m,
                    "b (Modified)": b_m,
                    "a (Modified)": a_m,
                    "x0 (Modified)": x0_m,
                }
            )

    return pd.DataFrame(data)


def create_json_file():
    """Serialize the current modified RCS metadata to an indented JSON string.

    Deliberately NOT wrapped in ``st.cache_data``. The function takes no
    arguments and reads ``st.session_state``, so a cached result would never
    reflect later edits or resets. It would also be shared across every
    user session.
    """
    return json.dumps(
        st.session_state["project_dct"]["response_curves"]["modified_metadata_file"],
        indent=4,
    )


try:
    # Page Title
    st.title("Response Curves")

    # Retrieve the list of all metric names from the specified directory
    metrics_list = get_metrics_names()

    # Nothing to show until at least one model has been tuned
    if not metrics_list:
        st.warning(
            "Please tune at least one model to generate response curves data.",
            icon="⚠️",
        )
        log_message(
            "warning",
            "Please tune at least one model to generate response curves data.",
            "Response Curves",
        )
        st.stop()

    # Widget columns
    metric_col, channel_col, panel_col, save_progress_col = st.columns(4)

    # Metrics Selection
    metrics_selected = metric_col.selectbox(
        "Response Metrics",
        sorted(metrics_list),
        format_func=name_formating,
        key="response_metrics_selectbox",
        index=0,
    )

    # Panel Selection (panels depend on the selected metric)
    panel_list = get_panels_names(metrics_selected)
    panel_selected = panel_col.selectbox(
        "Panel",
        sorted(panel_list),
        format_func=name_formating,
        key="panel_selected_selectbox",
        index=0,
    )

    # Save Progress
    with save_progress_col:
        st.write("####")  # Padding
        save_progress_placeholder = st.empty()

    # Placeholder to display message and spinner
    message_spinner_placeholder = st.container()

    # Save page progress
    with message_spinner_placeholder, st.spinner("Saving Progress ..."):
        if save_progress_placeholder.button("Save Progress", use_container_width=True):
            update_db(
                prj_id=st.session_state["project_number"],
                page_nam="Response Curves",
                file_nam="project_dct",
                pkl_obj=pickle.dumps(st.session_state["project_dct"]),
                schema=schema,
            )
            message_spinner_placeholder.success(
                "Progress saved successfully!", icon="💾"
            )
            st.toast("Progress saved successfully!", icon="💾")
            log_message("info", "Progress saved successfully!", "Response Curves")

    # Generate the RCS metadata when either the original or modified file is missing
    response_curves_state = st.session_state["project_dct"]["response_curves"]
    if (
        response_curves_state["original_metadata_file"] is None
        or response_curves_state["modified_metadata_file"] is None
    ):
        generate_rcs_data()
        log_message(
            "info",
            "RCS metadata file does not exist. Generating new RCS data.",
            "Response Curves",
        )

    # Load metadata files if they exist
    original_data, modified_data = load_rcs_metadata_files()

    # Channel Selection (channels depend on the selected metric and panel)
    channel_list_final = list(original_data[metrics_selected][panel_selected].keys())
    channel_selected = channel_col.selectbox(
        "Channel",
        sorted(channel_list_final),
        format_func=name_formating,
        key="selected_channel_name_selectbox",
    )

    # Original and user-modified data for the selected metric, panel and channel
    original_channel_data = original_data[metrics_selected][panel_selected][
        channel_selected
    ]
    modified_channel_data = modified_data[metrics_selected][panel_selected][
        channel_selected
    ]

    # X and Y values for plotting
    x = original_channel_data["x"]
    y = original_channel_data["y"]

    # Scaling exponent for X values and range for S-curve plotting
    power = original_channel_data["power"]
    x_plot = original_channel_data["x_plot"]

    # Original S-curve parameters
    K_orig = original_channel_data["K"]
    b_orig = original_channel_data["b"]
    a_orig = original_channel_data["a"]
    x0_orig = original_channel_data["x0"]

    # Modified S-curve parameters (user-adjusted)
    K_mod = modified_channel_data["K"]
    b_mod = modified_channel_data["b"]
    a_mod = modified_channel_data["a"]
    x0_mod = modified_channel_data["x0"]

    # The curves are evaluated on X divided by 10**power
    scaled_x_plot = np.array(x_plot) / 10**power
    scaled_x = np.array(x) / 10**power

    # Scatter plot of the original data points
    fig = px.scatter(
        x=x,
        y=y,
        title="Original and Modified S-Curve Plot",
        labels={"x": "Spends", "y": name_formating(metrics_selected)},
    )

    # Modified S-curve trace
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(scaled_x_plot, K_mod, b_mod, a_mod, x0_mod),
            line=dict(color="red"),
            name="Modified",
        ),
    )

    # Original S-curve trace
    fig.add_trace(
        go.Scatter(
            x=x_plot,
            y=s_curve(scaled_x_plot, K_orig, b_orig, a_orig, x0_orig),
            line=dict(color="rgba(0, 255, 0, 0.6)"),  # Semi-transparent green
            name="Original",
        ),
    )

    fig.update_layout(
        title="Comparison of Original and Modified Response-Curves",
        xaxis_title="Input (Clicks, Impressions, etc..)",
        yaxis_title=name_formating(metrics_selected),
        legend_title="Curve Type",
    )

    # Display s-curve
    st.plotly_chart(fig, use_container_width=True)

    # R-squared of each curve against the observed data points
    r2_orig = r2_score(y, s_curve(scaled_x, K_orig, b_orig, a_orig, x0_orig))
    r2_mod = r2_score(y, s_curve(scaled_x, K_mod, b_mod, a_mod, x0_mod))
    r2_diff = r2_mod - r2_orig

    # Display R-squared metrics
    st.write("## R-squared Comparison")
    r2_col = st.columns(3)
    r2_col[0].metric("R-squared (Original)", f"{r2_orig:.2f}")
    r2_col[1].metric("R-squared (Modified)", f"{r2_mod:.2f}")
    r2_col[2].metric("Difference in R-squared", f"{r2_diff:.2f}")

    # Seed each parameter widget from the modified metadata on first display
    param_keys = _rcs_param_keys(metrics_selected, panel_selected, channel_selected)
    for key, current_value in zip(param_keys, (K_mod, b_mod, a_mod, x0_mod)):
        if key not in st.session_state:
            st.session_state[key] = current_value

    # RCS parameters input; edits are written back via modify_rcs_parameters.
    # NOTE(review): min_value=0 rejects negative values; confirm fitted
    # parameters (notably x0) can never be negative.
    rsc_ip_col = st.columns(4)
    for param_col, label, key in zip(rsc_ip_col, ("K", "b", "a", "x0"), param_keys):
        with param_col:
            st.number_input(
                label,
                step=0.001,
                min_value=0.000000,
                format="%.6f",
                on_change=modify_rcs_parameters,
                args=(metrics_selected, panel_selected, channel_selected),
                key=key,
            )

    # Columns for Reset and Download buttons
    reset_download_col = st.columns(2)
    with reset_download_col[0]:
        if st.button(
            "Reset",
            use_container_width=True,
        ):
            reset_parameters(
                metrics_selected,
                panel_selected,
                channel_selected,
                original_channel_data,
            )
            log_message(
                "info",
                f"METRIC: {name_formating(metrics_selected)} ; PANEL: {name_formating(panel_selected)}, CHANNEL: {name_formating(channel_selected)} has been reset to its original value.",
                "Response Curves",
            )
            # Rerun so the inputs pick up the restored values
            st.rerun()

    with reset_download_col[1]:
        # Provide a download button for the modified RCS data
        try:
            json_data = create_json_file()
            st.download_button(
                label="Download",
                data=json_data,
                file_name=f"{name_formating(metrics_selected)}_{name_formating(panel_selected)}_rcs_data.json",
                mime="application/json",
                use_container_width=True,
            )
        except Exception as download_error:
            # Best-effort: a failed download must not break the page, but
            # record the reason instead of silently swallowing it.
            log_message(
                "warning",
                f"Failed to prepare RCS data download: {download_error}",
                "Response Curves",
            )

    # DataFrame showing only channels with non-matching parameters
    updated_parm_df = updated_parm_gen(
        original_data, modified_data, metrics_selected, panel_selected
    )

    if not updated_parm_df.empty:
        st.write("## Parameter Comparison for Selected Metric and Panel")
        st.dataframe(updated_parm_df, hide_index=True)
    else:
        st.info("No parameters are updated for the selected Metric and Panel")

except Exception as e:
    # Capture the error details
    exc_type, exc_value, exc_traceback = sys.exc_info()
    error_message = "".join(
        traceback.format_exception(exc_type, exc_value, exc_traceback)
    )

    log_message("error", f"An error occurred: {error_message}.", "Response Curves")

    st.warning(
        "Oops! Something went wrong. Please try refreshing the tool or creating a new project.",
        icon="⚠️",
    )