import os

# Workaround flags for the Streamlit/torch file watcher and LlamaIndex defaults.
# Environment flags like these only take effect if they are set before the
# libraries that read them are imported, so they are set first.
os.environ["STREAMLIT_SERVER_ENABLE_FILE_WATCHER"] = "false"
os.environ["TORCH_DISABLE_STREAMLIT_WATCHER"] = "1"
os.environ["LLAMA_INDEX_DISABLE_OPENAI"] = "1"

import shutil

import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.llms.groq import Groq

# Load variables from a local .env file before any API-key lookup below.
load_dotenv()

PERSIST_DIR = "./storage"
EMBED_MODEL = "./all-MiniLM-L6-v2"  # local copy of the sentence-transformers model
LLM_MODEL = "llama3-8b-8192"
CSV_FILE_PATH = "shl_assessments.csv"
GROQ_API_KEY = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")


def load_data_from_csv(csv_path):
    """Loads assessment data from a CSV file and returns it as a list of records."""
    try:
        df = pd.read_csv(csv_path)
        required_columns = [
            "Assessment Name",
            "URL",
            "Remote Testing Support",
            "Adaptive/IRT Support",
            "Duration (min)",
            "Test Type",
        ]
        if not all(col in df.columns for col in required_columns):
            raise ValueError(f"CSV file must contain columns: {', '.join(required_columns)}")
        return df.to_dict(orient="records")
    except FileNotFoundError:
        raise FileNotFoundError(f"Error: CSV file not found at {csv_path}")
    except ValueError as e:
        raise ValueError(f"Error reading CSV: {e}")
    except Exception as e:
        raise Exception(f"An unexpected error occurred while loading CSV data: {e}")


def load_groq_llm():
    """Creates the Groq LLM client, failing fast if no API key is configured."""
    api_key = st.secrets.get("GROQ_API_KEY") or os.getenv("GROQ_API_KEY")
    if not api_key:
        raise ValueError("GROQ_API_KEY not found in Streamlit secrets or environment.")
    return Groq(model=LLM_MODEL, api_key=api_key, temperature=0.1)


def load_embeddings():
    """Creates the embedding model shared by index building and querying."""
    return HuggingFaceEmbedding(model_name=EMBED_MODEL)


def build_index(data):
    """Builds the vector index from the provided assessment data and persists it."""
    Settings.embed_model = load_embeddings()
    Settings.llm = load_groq_llm()
    documents = [
        Document(
            text=(
                f"Name: {item['Assessment Name']}, "
                f"URL: {item['URL']}, "
                f"Remote Testing: {item['Remote Testing Support']}, "
                f"Adaptive/IRT: {item['Adaptive/IRT Support']}, "
                f"Duration: {item['Duration (min)']}, "
                f"Type: {item['Test Type']}"
            )
        )
        for item in data
    ]
    index = VectorStoreIndex.from_documents(documents)
    index.storage_context.persist(persist_dir=PERSIST_DIR)
    return index


def load_chat_engine():
    """Loads the chat engine from the persisted index."""
    if not os.path.exists(PERSIST_DIR):
        return None
    Settings.embed_model = load_embeddings()
    Settings.llm = load_groq_llm()
    storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
    index = load_index_from_storage(storage_context)
    return index.as_chat_engine(chat_mode="context", verbose=True)


def reset_index():
    """Resets the persisted index and chat history (not yet wired to a UI control)."""
    try:
        shutil.rmtree(PERSIST_DIR, ignore_errors=True)
        st.success("Knowledge index reset successfully!")
        st.session_state.messages = [{
            "role": "assistant",
            "content": "Hello! I'm your SHL assessment assistant. How can I help you?",
        }]
        st.session_state["index_built"] = False
        if "chat_engine" in st.session_state:
            del st.session_state["chat_engine"]
        return None
    except Exception as e:
        st.error(f"Error resetting index: {str(e)}")
        return None


def main():
    st.set_page_config(
        page_title="SHL Assessment Chatbot",
        layout="wide",
        initial_sidebar_state="collapsed",
    )
    # Placeholder for custom CSS/HTML styling (currently empty).
    st.markdown(""" """, unsafe_allow_html=True)

    if "messages" not in st.session_state:
        st.session_state.messages = [{
            "role": "assistant",
            "content": "Hello! I'm your SHL assessment assistant. How can I help you?",
        }]
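    # Cold-start flow: read the CSV, build and persist the vector index under
    # PERSIST_DIR, then reload it through load_chat_engine() as a context chat
    # engine. A minimal sketch of a possible optimization (an assumption, not part
    # of this app's current flow) would skip the rebuild inside the spinner block
    # below whenever a persisted index already exists:
    #
    #     if not os.path.exists(PERSIST_DIR):
    #         build_index(assessment_data)
    #     st.session_state["chat_engine"] = load_chat_engine()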
    if "index_built" not in st.session_state:
        st.session_state["index_built"] = False

    if not st.session_state["index_built"]:
        try:
            with st.spinner("Loading data and building index..."):
                assessment_data = load_data_from_csv(CSV_FILE_PATH)
                if assessment_data:
                    build_index(assessment_data)
                    st.session_state["chat_engine"] = load_chat_engine()
                    st.session_state["index_built"] = True
                else:
                    st.error("Failed to load assessment data. Please check the CSV file.")
        except Exception as e:
            st.error(f"Error initializing application: {e}")

    # --- Chat Interface ---
    chat_engine = st.session_state.get("chat_engine")
    if chat_engine:
        # Replay the conversation so far.
        for msg in st.session_state.messages:
            icon = "🤖" if msg["role"] == "assistant" else "👤"
            with st.chat_message(msg["role"]):
                st.markdown(f"{icon} {msg['content']}", unsafe_allow_html=True)

        if prompt := st.chat_input("Ask me about SHL assessments..."):
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(f"👤 {prompt}", unsafe_allow_html=True)

            with st.chat_message("assistant"):
                try:
                    # Wrap the user's question with output-formatting instructions.
                    formatted_prompt = f"""{prompt}

Please provide a list of all matching SHL assessments (minimum 1, maximum 10).
For each matching assessment, follow this exact format:
• Assessment Name: [Name]
  URL: [URL]
  Remote Testing Support: [Yes/No]
  Adaptive/IRT Support: [Yes/No]
  Duration: [Duration in minutes]
  Test Type: [Test Type]

If there are no matches, clearly state that. Respond in a clean, readable
bullet-point format. Do not use any "+" signs. Do not return JSON or markdown
tables. Do not bold anything.
"""
                    response = chat_engine.chat(formatted_prompt)
                    st.markdown(f"🤖 {response.response}", unsafe_allow_html=True)
                    st.session_state.messages.append({"role": "assistant", "content": response.response})
                except Exception as e:
                    st.error(f"An error occurred during chat: {e}")
    else:
        st.info("💬 The chat engine is not available yet. Check the CSV file and GROQ_API_KEY, then reload the page.")


if __name__ == "__main__":
    main()
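
# A minimal sketch of the expected setup, assuming this file is saved as app.py
# (the filename is an assumption, not taken from the code):
#
#   shl_assessments.csv must contain at least these columns:
#     Assessment Name, URL, Remote Testing Support, Adaptive/IRT Support,
#     Duration (min), Test Type
#
#   GROQ_API_KEY must be available in .streamlit/secrets.toml or in a .env file:
#     GROQ_API_KEY=...
#
#   Launch the app with:
#     streamlit run app.py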