import os import re import json from tqdm import tqdm from pathlib import Path import spaces import gradio as gr # WARNING: Don't import torch, cuda, or GPU-related modules at the top level # They must ONLY be imported inside functions decorated with @spaces.GPU # Helper functions that don't use GPU def safe_tokenize(text): """Pure regex tokenizer with no NLTK dependency""" if not text: return [] # Replace punctuation with spaces around them text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text) # Split on whitespace and filter empty strings return [token for token in re.split(r'\s+', text.lower()) if token] def detect_language(text): """Detect if text is primarily Arabic or English""" # Simple heuristic: count Arabic characters arabic_chars = re.findall(r'[\u0600-\u06FF]', text) is_arabic = len(arabic_chars) > len(text) * 0.5 return "arabic" if is_arabic else "english" # Comprehensive evaluation dataset comprehensive_evaluation_data = [ # === Overview === { "query": "ما هي رؤية السعودية 2030؟", "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.", "category": "overview", "language": "arabic" }, { "query": "What is Saudi Vision 2030?", "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.", "category": "overview", "language": "english" }, # === Economic Goals === { "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟", "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.", "category": "economic", "language": "arabic" }, { "query": "What are the economic goals of Vision 2030?", "reference": "The economic goals of Vision 2030 include increasing private sector contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, reducing unemployment from 11.6% to 7%.", "category": "economic", "language": "english" }, # === Social Goals === { "query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟", "reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.", "category": "social", "language": "arabic" }, { "query": "How does Vision 2030 aim to improve quality of life?", "reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.", "category": "social", "language": "english" } ] # RAG Service class class Vision2030Service: def __init__(self): self.initialized = False self.model = None self.tokenizer = None self.vector_store = None self.conversation_history = [] @spaces.GPU def initialize(self): """Initialize the system - ALL GPU operations must happen here""" if self.initialized: return True try: # Import all GPU-dependent libraries only inside this function import torch import PyPDF2 from transformers import AutoTokenizer, AutoModelForCausalLM from sentence_transformers import SentenceTransformer from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.vectorstores import FAISS from langchain.schema import Document from langchain.embeddings import HuggingFaceEmbeddings # Define paths for PDF files pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"] # Process PDFs and create vector store vector_store_dir = "vector_stores" os.makedirs(vector_store_dir, exist_ok=True) if os.path.exists(os.path.join(vector_store_dir, "index.faiss")): print("Loading existing vector store...") embedding_function = HuggingFaceEmbeddings( model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" ) # Important: Add allow_dangerous_deserialization=True to fix the pickle error self.vector_store = FAISS.load_local( vector_store_dir, embedding_function, allow_dangerous_deserialization=True # Add this parameter ) else: print("Creating new vector store...") # Process PDFs documents = [] for pdf_path in pdf_files: if not os.path.exists(pdf_path): print(f"Warning: {pdf_path} does not exist") continue print(f"Processing {pdf_path}...") text = "" with open(pdf_path, 'rb') as file: reader = PyPDF2.PdfReader(file) for page in reader.pages: page_text = page.extract_text() if page_text: text += page_text + "\n\n" if text.strip(): doc = Document( page_content=text, metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)} ) documents.append(doc) if not documents: raise ValueError("No documents were processed successfully.") # Split into chunks text_splitter = RecursiveCharacterTextSplitter( chunk_size=500, chunk_overlap=50, separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""] ) chunks = [] for doc in documents: doc_chunks = text_splitter.split_text(doc.page_content) chunks.extend([ Document(page_content=chunk, metadata=doc.metadata) for chunk in doc_chunks ]) # Create vector store embedding_function = HuggingFaceEmbeddings( model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" ) self.vector_store = FAISS.from_documents(chunks, embedding_function) self.vector_store.save_local(vector_store_dir) # Load model model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview" self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, use_fast=False ) self.model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto", ) self.initialized = True return True except Exception as e: import traceback print(f"Initialization error: {e}") print(traceback.format_exc()) return False @spaces.GPU def retrieve_context(self, query, top_k=5): """Retrieve contexts from vector store""" # Import must be inside the function to avoid CUDA init in main process if not self.initialized: return [] try: results = self.vector_store.similarity_search_with_score(query, k=top_k) contexts = [] for doc, score in results: contexts.append({ "content": doc.page_content, "source": doc.metadata.get("source", "Unknown"), "relevance_score": score }) return contexts except Exception as e: print(f"Error retrieving context: {e}") return [] @spaces.GPU def generate_response(self, query, contexts, language="auto"): """Generate response using the model""" # Import must be inside the function to avoid CUDA init in main process import torch if not self.initialized or self.model is None or self.tokenizer is None: return "I'm still initializing. Please try again in a moment." try: # Auto-detect language if not specified if language == "auto": language = detect_language(query) # Format the prompt based on language if language == "arabic": instruction = ( "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. " "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف." ) else: # english instruction = ( "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. " "If you don't know the answer, honestly say you don't know." ) # Combine retrieved contexts context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts]) # Format the prompt for ALLaM instruction format prompt = f"""[INST] {instruction} Context: {context_text} Question: {query} [/INST]""" # Generate response inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) outputs = self.model.generate( inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True, repetition_penalty=1.1 ) # Decode the response full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True) # Extract just the answer part (after the instruction) response = full_output.split("[/INST]")[-1].strip() # If response is empty for some reason, return the full output if not response: response = full_output return response except Exception as e: import traceback print(f"Error generating response: {e}") print(traceback.format_exc()) return f"Sorry, I encountered an error while generating a response." @spaces.GPU def answer_question(self, query): """Process a user query and return a response with sources""" if not self.initialized: if not self.initialize(): return "System initialization failed. Please check the logs.", [] try: # Add user query to conversation history self.conversation_history.append({"role": "user", "content": query}) # Get the full conversation context conversation_context = "\n".join([ f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}" for msg in self.conversation_history[-6:] # Keep last 3 turns ]) # Enhance query with conversation context enhanced_query = f"{conversation_context}\n{query}" # Retrieve relevant contexts contexts = self.retrieve_context(enhanced_query, top_k=5) # Generate response response = self.generate_response(query, contexts) # Add response to conversation history self.conversation_history.append({"role": "assistant", "content": response}) # Get sources sources = [ctx.get("source", "Unknown") for ctx in contexts] unique_sources = list(set(sources)) return response, unique_sources except Exception as e: import traceback print(f"Error answering question: {e}") print(traceback.format_exc()) return f"Sorry, I encountered an error: {str(e)}", [] def reset_conversation(self): """Reset the conversation history""" self.conversation_history = [] return "Conversation has been reset." # Main function with Gradio UI def main(): # Create the Vision 2030 service service = Vision2030Service() # Build the Gradio interface with gr.Blocks(title="Vision 2030 Assistant") as demo: gr.Markdown("# Vision 2030 Assistant") gr.Markdown("Ask questions about Saudi Vision 2030 in English or Arabic") with gr.Tab("Chat"): chatbot = gr.Chatbot() msg = gr.Textbox(label="Your question", placeholder="Ask about Vision 2030...") clear = gr.Button("Clear History") @spaces.GPU def respond(message, history): if not message: return history, "" response, sources = service.answer_question(message) sources_text = ", ".join(sources) if sources else "No specific sources" # Format the response to include sources full_response = f"{response}\n\nSources: {sources_text}" return history + [[message, full_response]], "" def reset_chat(): service.reset_conversation() return [], "Conversation history has been reset." msg.submit(respond, [msg, chatbot], [chatbot, msg]) clear.click(reset_chat, None, [chatbot, msg]) with gr.Tab("System Status"): init_btn = gr.Button("Initialize System") status_box = gr.Textbox(label="Status", value="System not initialized") @spaces.GPU def initialize_system(): success = service.initialize() if success: return "System initialized successfully!" else: return "System initialization failed. Check logs for details." init_btn.click(initialize_system, None, status_box) # PDF Check section gr.Markdown("### PDF Status") pdf_btn = gr.Button("Check PDF Files") pdf_status = gr.Textbox(label="PDF Files") def check_pdfs(): result = [] for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]: if os.path.exists(pdf_file): size = os.path.getsize(pdf_file) / (1024 * 1024) # Size in MB result.append(f"{pdf_file}: Found ({size:.2f} MB)") else: result.append(f"{pdf_file}: Not found") return "\n".join(result) pdf_btn.click(check_pdfs, None, pdf_status) # System check section gr.Markdown("### Dependencies") sys_btn = gr.Button("Check Dependencies") sys_status = gr.Textbox(label="Dependencies Status") @spaces.GPU def check_dependencies(): result = [] # Safe imports inside GPU-decorated function try: import torch result.append(f"✓ PyTorch: {torch.__version__}") except ImportError: result.append("✗ PyTorch: Not installed") try: import transformers result.append(f"✓ Transformers: {transformers.__version__}") except ImportError: result.append("✗ Transformers: Not installed") try: import sentencepiece result.append("✓ SentencePiece: Installed") except ImportError: result.append("✗ SentencePiece: Not installed") try: import accelerate result.append(f"✓ Accelerate: {accelerate.__version__}") except ImportError: result.append("✗ Accelerate: Not installed") try: import langchain result.append(f"✓ LangChain: {langchain.__version__}") except ImportError: result.append("✗ LangChain: Not installed") try: import langchain_community result.append(f"✓ LangChain Community: {langchain_community.__version__}") except ImportError: result.append("✗ LangChain Community: Not installed") return "\n".join(result) sys_btn.click(check_dependencies, None, sys_status) with gr.Tab("Sample Questions"): gr.Markdown("### Sample Questions to Try") sample_questions = [] for item in comprehensive_evaluation_data: sample_questions.append(item["query"]) questions_md = "\n".join([f"- {q}" for q in sample_questions]) gr.Markdown(questions_md) # Add a button to auto-initialize the system when viewing sample questions auto_init_btn = gr.Button("Initialize System First") auto_init_status = gr.Textbox(label="Initialization Status") auto_init_btn.click(initialize_system, None, auto_init_status) return demo if __name__ == "__main__": demo = main() demo.queue() demo.launch()