import streamlit as st
import os
import re
import torch
import numpy as np
from pathlib import Path
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings

# Set page configuration
st.set_page_config(
    page_title="Vision 2030 Virtual Assistant",
    page_icon="🇸🇦",
    layout="wide"
)

# App title and description
st.title("Vision 2030 Virtual Assistant")
st.markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")

# Function definitions

@st.cache_resource
def load_model_and_tokenizer():
    """Load the ALLaM-7B model and tokenizer with error handling"""
    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    st.info(f"Loading model: {model_name} (this may take a few minutes)")

    try:
        # First attempt with AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False
        )

        # Load model with appropriate settings for ALLaM
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

        st.success("Model loaded successfully!")

    except Exception as e:
        st.error(f"First loading attempt failed: {e}")
        st.info("Trying alternative loading approach...")

        # Try with a specific tokenizer class if the first attempt fails
        from transformers import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

        st.success("Model loaded successfully with LlamaTokenizer!")

    return model, tokenizer


def detect_language(text):
    """Detect if text is primarily Arabic or English"""
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"


def process_pdfs():
    """Process uploaded PDF documents"""
    documents = []

    if 'uploaded_pdfs' in st.session_state and st.session_state.uploaded_pdfs:
        for pdf_file in st.session_state.uploaded_pdfs:
            try:
                # Save the uploaded file temporarily
                pdf_path = f"temp_{pdf_file.name}"
                with open(pdf_path, "wb") as f:
                    f.write(pdf_file.getbuffer())

                # Extract text (extract_text() can return None for pages with
                # no extractable text, so guard against that)
                text = ""
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        text += (page.extract_text() or "") + "\n\n"

                # Remove temporary file
                os.remove(pdf_path)

                if text.strip():  # If we got some text
                    doc = Document(
                        page_content=text,
                        metadata={"source": pdf_file.name, "filename": pdf_file.name}
                    )
                    documents.append(doc)
                    st.info(f"Successfully processed: {pdf_file.name}")
                else:
                    st.warning(f"No text extracted from {pdf_file.name}")
            except Exception as e:
                st.error(f"Error processing {pdf_file.name}: {e}")

        st.success(f"Processed {len(documents)} PDF documents")

    return documents


def create_vector_store(documents):
    """Split documents into chunks and create a FAISS vector store"""
    # Text splitter for breaking documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    # Split documents into chunks
    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_text(doc.page_content)
        # Preserve metadata for each chunk
        chunks.extend([
            Document(page_content=chunk, metadata=doc.metadata)
            for chunk in doc_chunks
        ])
    st.info(f"Created {len(chunks)} chunks from {len(documents)} documents")

    # Create a proper embedding function for LangChain
    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    # Create FAISS index
    vector_store = FAISS.from_documents(
        chunks,
        embedding_function
    )

    return vector_store


def retrieve_context(query, vector_store, top_k=5):
    """Retrieve the most relevant document chunks for a given query"""
    # Search the vector store using similarity search
    results = vector_store.similarity_search_with_score(query, k=top_k)

    # Format the retrieved contexts
    contexts = []
    for doc, score in results:
        contexts.append({
            "content": doc.page_content,
            "source": doc.metadata.get("source", "Unknown"),
            "relevance_score": score
        })

    return contexts


def generate_response(query, contexts, model, tokenizer):
    """Generate a response using retrieved contexts with ALLaM-specific formatting"""
    # Auto-detect language
    language = detect_language(query)

    # Format the prompt based on language
    if language == "arabic":
        instruction = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
            "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
        )
    else:  # english
        instruction = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
            "If you don't know the answer, honestly say you don't know."
        )

    # Combine retrieved contexts
    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    # Format the prompt for the ALLaM instruction format
    prompt = f"""[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]"""

    try:
        with st.spinner("Generating response..."):
            # Tokenize the prompt and move it to the model's device
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Generate with appropriate parameters for ALLaM
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1
            )

            # Decode the response
            full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract just the answer part (after the instruction)
            response = full_output.split("[/INST]")[-1].strip()

            # If the response is empty for some reason, return the full output
            if not response:
                response = full_output

            return response, [ctx.get("source", "Unknown") for ctx in contexts]

    except Exception as e:
        st.error(f"Error during generation: {e}")
        # Fallback response
        return "I apologize, but I encountered an error while generating a response.", []


# Initialize the app state
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = []

if 'vector_store' not in st.session_state:
    st.session_state.vector_store = None

if 'uploaded_pdfs' not in st.session_state:
    st.session_state.uploaded_pdfs = None

# PDF upload section
st.header("1. Upload Vision 2030 Documents")
uploaded_files = st.file_uploader(
    "Upload PDF documents about Vision 2030",
    type=["pdf"],
    accept_multiple_files=True,
    help="Upload one or more PDF documents containing information about Vision 2030"
)

if uploaded_files:
    st.session_state.uploaded_pdfs = uploaded_files

    if st.button("Process PDFs"):
        documents = process_pdfs()

        if documents:
            with st.spinner("Creating vector database..."):
                st.session_state.vector_store = create_vector_store(documents)
            st.success("Vector database created successfully!")

# Load the model (cached)
model, tokenizer = load_model_and_tokenizer()

# Chat interface
st.header("2. Chat with the Vision 2030 Assistant")
# Display conversation history
for message in st.session_state.conversation_history:
    if message["role"] == "user":
        st.markdown(f"**You:** {message['content']}")
    else:
        st.markdown(f"**Assistant:** {message['content']}")
        if 'sources' in message and message['sources']:
            st.markdown(f"*Sources: {', '.join([os.path.basename(src) for src in message['sources']])}*")
    st.divider()

# Input for a new question
user_input = st.text_input("Ask a question about Vision 2030 (in Arabic or English):", key="user_query")

# Examples
st.markdown("**Example questions:**")
examples_col1, examples_col2 = st.columns(2)
with examples_col1:
    st.markdown("- What is Saudi Vision 2030?")
    st.markdown("- What are the economic goals of Vision 2030?")
    st.markdown("- How does Vision 2030 support women's empowerment?")
with examples_col2:
    st.markdown("- ما هي رؤية السعودية 2030؟")
    st.markdown("- ما هي الأهداف الاقتصادية لرؤية 2030؟")
    st.markdown("- كيف تدعم رؤية 2030 تمكين المرأة السعودية؟")

# Process the user input
if user_input and st.session_state.vector_store:
    # Add the user message to history
    st.session_state.conversation_history.append({"role": "user", "content": user_input})

    # Get a response
    response, sources = generate_response(
        user_input,
        retrieve_context(user_input, st.session_state.vector_store),
        model,
        tokenizer
    )

    # Add the assistant message to history
    st.session_state.conversation_history.append({"role": "assistant", "content": response, "sources": sources})

    # Rerun to update the UI
    st.experimental_rerun()

elif user_input and not st.session_state.vector_store:
    st.warning("Please upload and process Vision 2030 PDF documents first")

# Reset conversation button
if st.button("Reset Conversation") and len(st.session_state.conversation_history) > 0:
    st.session_state.conversation_history = []
    st.experimental_rerun()
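
# Usage note (a sketch, assuming this script is saved as app.py): a Streamlit
# app is launched with the streamlit CLI rather than run directly with Python.
# The package names below are the standard PyPI distributions for the imports
# above; `accelerate` is additionally needed for device_map="auto".
#
#   pip install streamlit torch transformers accelerate sentence-transformers \
#       langchain langchain-community faiss-cpu PyPDF2
#   streamlit run app.py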