# Vision 2030 Assistant: a Hugging Face Space running on ZeroGPU
import os
import re
import spaces
import gradio as gr

# WARNING: Don't import torch, CUDA, or other GPU-related modules at the top level.
# On ZeroGPU Spaces, GPU work must happen inside functions decorated with @spaces.GPU.
# Helper functions that don't use the GPU
def safe_tokenize(text):
    """Pure-regex tokenizer with no NLTK dependency."""
    if not text:
        return []
    # Put spaces around punctuation so each mark becomes its own token
    text = re.sub(r'([.,!?;:()\[\]{}"\'/\\])', r' \1 ', text)
    # Lowercase, split on whitespace, and drop empty strings
    return [token for token in re.split(r'\s+', text.lower()) if token]
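
# For example, safe_tokenize('Hello, world!') returns ['hello', ',', 'world', '!'].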
def detect_language(text):
    """Detect whether text is primarily Arabic or English."""
    # Simple heuristic: count characters in the Arabic Unicode block
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"
# Comprehensive evaluation dataset (bilingual query/reference pairs)
comprehensive_evaluation_data = [
    # === Overview ===
    {
        # Query: "What is Saudi Vision 2030?"
        # Reference: "Saudi Vision 2030 is a strategic plan that aims to diversify the Saudi
        # economy and reduce dependence on oil while developing various sectors such as
        # health, education, and tourism."
        "query": "ما هي رؤية السعودية 2030؟",
        "reference": "رؤية السعودية 2030 هي خطة استراتيجية تهدف إلى تنويع الاقتصاد السعودي وتقليل الاعتماد على النفط مع تطوير قطاعات مختلفة مثل الصحة والتعليم والسياحة.",
        "category": "overview",
        "language": "arabic"
    },
    {
        "query": "What is Saudi Vision 2030?",
        "reference": "Saudi Vision 2030 is a strategic framework aiming to diversify Saudi Arabia's economy and reduce dependence on oil, while developing sectors like health, education, and tourism.",
        "category": "overview",
        "language": "english"
    },
    # === Economic Goals ===
    {
        # Query: "What are the economic goals of Vision 2030?"
        # Reference: "The economic goals include raising the private sector's contribution to 65%,
        # raising non-oil exports to 50% of non-oil GDP, and cutting unemployment to 7%."
        "query": "ما هي الأهداف الاقتصادية لرؤية 2030؟",
        "reference": "تشمل الأهداف الاقتصادية زيادة مساهمة القطاع الخاص إلى 65%، وزيادة الصادرات غير النفطية إلى 50% من الناتج المحلي غير النفطي، وخفض البطالة إلى 7%.",
        "category": "economic",
        "language": "arabic"
    },
    {
        "query": "What are the economic goals of Vision 2030?",
        "reference": "The economic goals of Vision 2030 include increasing the private sector's contribution from 40% to 65% of GDP, raising non-oil exports from 16% to 50%, and reducing unemployment from 11.6% to 7%.",
        "category": "economic",
        "language": "english"
    },
    # === Social Goals ===
    {
        # Query: "How does Vision 2030 promote Saudi cultural heritage?"
        # Reference: "Vision 2030 includes preserving national identity, registering
        # archaeological sites with UNESCO, and promoting cultural events."
        "query": "كيف تعزز رؤية 2030 الإرث الثقافي السعودي؟",
        "reference": "تتضمن رؤية 2030 الحفاظ على الهوية الوطنية، تسجيل مواقع أثرية في اليونسكو، وتعزيز الفعاليات الثقافية.",
        "category": "social",
        "language": "arabic"
    },
    {
        "query": "How does Vision 2030 aim to improve quality of life?",
        "reference": "Vision 2030 plans to enhance quality of life by expanding sports facilities, promoting cultural activities, and boosting tourism and entertainment sectors.",
        "category": "social",
        "language": "english"
    }
]
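
# Illustrative sketch (not part of the original service): score a generated answer
# against one of the reference answers above with token-overlap F1, reusing
# safe_tokenize(). The function name and metric choice are assumptions added for
# demonstration.
def token_overlap_f1(response, reference):
    """Token-level F1 overlap between a response and a reference answer."""
    response_tokens = set(safe_tokenize(response))
    reference_tokens = set(safe_tokenize(reference))
    if not response_tokens or not reference_tokens:
        return 0.0
    overlap = response_tokens & reference_tokens
    precision = len(overlap) / len(response_tokens)
    recall = len(overlap) / len(reference_tokens)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)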
# RAG service class
class Vision2030Service:
    def __init__(self):
        self.initialized = False
        self.model = None
        self.tokenizer = None
        self.vector_store = None
        self.conversation_history = []
    def initialize(self):
        """Initialize the system -- ALL GPU operations must happen here."""
        if self.initialized:
            return True
        try:
            # Import all GPU-dependent libraries only inside this function
            import torch
            import PyPDF2
            from transformers import AutoTokenizer, AutoModelForCausalLM
            from langchain.text_splitter import RecursiveCharacterTextSplitter
            from langchain_community.vectorstores import FAISS
            from langchain_community.embeddings import HuggingFaceEmbeddings
            from langchain.schema import Document

            # Paths of the source PDF files
            pdf_files = ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]

            # Load an existing vector store if present, otherwise build one from the PDFs
            vector_store_dir = "vector_stores"
            os.makedirs(vector_store_dir, exist_ok=True)
            embedding_function = HuggingFaceEmbeddings(
                model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
            )
            if os.path.exists(os.path.join(vector_store_dir, "index.faiss")):
                print("Loading existing vector store...")
                # Note: recent langchain-community releases also require
                # allow_dangerous_deserialization=True here to load the pickled metadata
                self.vector_store = FAISS.load_local(vector_store_dir, embedding_function)
            else:
                print("Creating new vector store...")
                # Extract text from each PDF
                documents = []
                for pdf_path in pdf_files:
                    if not os.path.exists(pdf_path):
                        print(f"Warning: {pdf_path} does not exist")
                        continue
                    print(f"Processing {pdf_path}...")
                    text = ""
                    with open(pdf_path, 'rb') as file:
                        reader = PyPDF2.PdfReader(file)
                        for page in reader.pages:
                            page_text = page.extract_text()
                            if page_text:
                                text += page_text + "\n\n"
                    if text.strip():
                        documents.append(Document(
                            page_content=text,
                            metadata={"source": pdf_path, "filename": os.path.basename(pdf_path)}
                        ))
                if not documents:
                    raise ValueError("No documents were processed successfully.")

                # Split the documents into overlapping chunks
                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=500,
                    chunk_overlap=50,
                    separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
                )
                chunks = []
                for doc in documents:
                    doc_chunks = text_splitter.split_text(doc.page_content)
                    chunks.extend(
                        Document(page_content=chunk, metadata=doc.metadata)
                        for chunk in doc_chunks
                    )

                # Embed the chunks and persist the FAISS index
                self.vector_store = FAISS.from_documents(chunks, embedding_function)
                self.vector_store.save_local(vector_store_dir)

            # Load the language model and tokenizer
            model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                trust_remote_code=True,
                use_fast=False
            )
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
            )

            self.initialized = True
            return True
        except Exception as e:
            import traceback
            print(f"Initialization error: {e}")
            print(traceback.format_exc())
            return False
    def retrieve_context(self, query, top_k=5):
        """Retrieve the top-k most relevant chunks from the vector store."""
        if not self.initialized:
            return []
        try:
            results = self.vector_store.similarity_search_with_score(query, k=top_k)
            contexts = []
            for doc, score in results:
                contexts.append({
                    "content": doc.page_content,
                    "source": doc.metadata.get("source", "Unknown"),
                    # Note: FAISS returns a distance here, so lower scores mean more relevant
                    "relevance_score": score
                })
            return contexts
        except Exception as e:
            print(f"Error retrieving context: {e}")
            return []
    @spaces.GPU  # ZeroGPU: attach a GPU for the duration of this call
    def generate_response(self, query, contexts, language="auto"):
        """Generate a response with the model, grounded in the retrieved contexts."""
        # Import inside the function to avoid CUDA initialization in the main process
        import torch

        if not self.initialized or self.model is None or self.tokenizer is None:
            return "I'm still initializing. Please try again in a moment."
        try:
            # Auto-detect the query language if not specified
            if language == "auto":
                language = detect_language(query)

            # Choose the system instruction based on language
            if language == "arabic":
                # Arabic version of the English instruction below
                instruction = (
                    "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
                    "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
                )
            else:  # english
                instruction = (
                    "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
                    "If you don't know the answer, honestly say you don't know."
                )

            # Combine the retrieved contexts into a single block
            context_text = "\n\n".join(f"Document: {ctx['content']}" for ctx in contexts)

            # Format the prompt in the [INST] instruction style. The closing </s>
            # belongs after the model's answer, not at the end of the prompt, so it
            # is not included here.
            prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]"""

            # Generate the response
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            outputs = self.model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1
            )

            # Decode and keep only the answer part (after [/INST])
            full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_output.split("[/INST]")[-1].strip()

            # If the split produced an empty string, fall back to the full output
            if not response:
                response = full_output
            return response
        except Exception as e:
            import traceback
            print(f"Error generating response: {e}")
            print(traceback.format_exc())
            return "Sorry, I encountered an error while generating a response."
    def answer_question(self, query):
        """Process a user query and return a response plus its sources."""
        if not self.initialized:
            if not self.initialize():
                return "System initialization failed. Please check the logs.", []
        try:
            # Add the user query to the conversation history
            self.conversation_history.append({"role": "user", "content": query})

            # Build a retrieval query from the last 3 turns (6 messages) of conversation
            conversation_context = "\n".join(
                f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}"
                for msg in self.conversation_history[-6:]
            )
            enhanced_query = f"{conversation_context}\n{query}"

            # Retrieve relevant contexts and generate a response
            contexts = self.retrieve_context(enhanced_query, top_k=5)
            response = self.generate_response(query, contexts)

            # Record the response in the conversation history
            self.conversation_history.append({"role": "assistant", "content": response})

            # Deduplicate the source filenames
            sources = [ctx.get("source", "Unknown") for ctx in contexts]
            unique_sources = list(set(sources))
            return response, unique_sources
        except Exception as e:
            import traceback
            print(f"Error answering question: {e}")
            print(traceback.format_exc())
            return f"Sorry, I encountered an error: {str(e)}", []
    def reset_conversation(self):
        """Reset the conversation history."""
        self.conversation_history = []
        return "Conversation has been reset."
# Main function that builds the Gradio UI
def main():
    # Create the Vision 2030 service
    service = Vision2030Service()

    # Build the Gradio interface
    with gr.Blocks(title="Vision 2030 Assistant") as demo:
        gr.Markdown("# Vision 2030 Assistant")
        gr.Markdown("Ask questions about Saudi Vision 2030 in English or Arabic")

        with gr.Tab("Chat"):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(label="Your question", placeholder="Ask about Vision 2030...")
            clear = gr.Button("Clear History")

            def respond(message, history):
                if not message:
                    return history, ""
                response, sources = service.answer_question(message)
                sources_text = ", ".join(sources) if sources else "No specific sources"
                # Append the sources to the displayed response
                full_response = f"{response}\n\nSources: {sources_text}"
                return history + [[message, full_response]], ""

            def reset_chat():
                service.reset_conversation()
                # Clear both the chat display and the input box
                return [], ""

            msg.submit(respond, [msg, chatbot], [chatbot, msg])
            clear.click(reset_chat, None, [chatbot, msg])
        with gr.Tab("System Status"):
            init_btn = gr.Button("Initialize System")
            status_box = gr.Textbox(label="Status", value="System not initialized")

            def initialize_system():
                success = service.initialize()
                if success:
                    return "System initialized successfully!"
                else:
                    return "System initialization failed. Check logs for details."

            init_btn.click(initialize_system, None, status_box)

            # PDF check section
            gr.Markdown("### PDF Status")
            pdf_btn = gr.Button("Check PDF Files")
            pdf_status = gr.Textbox(label="PDF Files")

            def check_pdfs():
                result = []
                for pdf_file in ["saudi_vision203.pdf", "saudi_vision2030_ar.pdf"]:
                    if os.path.exists(pdf_file):
                        size_mb = os.path.getsize(pdf_file) / (1024 * 1024)
                        result.append(f"{pdf_file}: Found ({size_mb:.2f} MB)")
                    else:
                        result.append(f"{pdf_file}: Not found")
                return "\n".join(result)

            pdf_btn.click(check_pdfs, None, pdf_status)
            # Dependency check section
            gr.Markdown("### Dependencies")
            sys_btn = gr.Button("Check Dependencies")
            sys_status = gr.Textbox(label="Dependencies Status")

            def check_dependencies():
                result = []
                # These imports only check availability; they do not initialize CUDA
                try:
                    import torch
                    result.append(f"✓ PyTorch: {torch.__version__}")
                except ImportError:
                    result.append("✗ PyTorch: Not installed")
                try:
                    import transformers
                    result.append(f"✓ Transformers: {transformers.__version__}")
                except ImportError:
                    result.append("✗ Transformers: Not installed")
                try:
                    import sentencepiece
                    result.append("✓ SentencePiece: Installed")
                except ImportError:
                    result.append("✗ SentencePiece: Not installed")
                try:
                    import accelerate
                    result.append(f"✓ Accelerate: {accelerate.__version__}")
                except ImportError:
                    result.append("✗ Accelerate: Not installed")
                try:
                    import langchain
                    result.append(f"✓ LangChain: {langchain.__version__}")
                except ImportError:
                    result.append("✗ LangChain: Not installed")
                try:
                    import langchain_community
                    result.append(f"✓ LangChain Community: {langchain_community.__version__}")
                except ImportError:
                    result.append("✗ LangChain Community: Not installed")
                return "\n".join(result)

            sys_btn.click(check_dependencies, None, sys_status)
        with gr.Tab("Sample Questions"):
            gr.Markdown("### Sample Questions to Try")
            # Pull the sample questions from the evaluation dataset above
            sample_questions = [item["query"] for item in comprehensive_evaluation_data]
            questions_md = "\n".join(f"- {q}" for q in sample_questions)
            gr.Markdown(questions_md)

    return demo
if __name__ == "__main__":
    demo = main()
    demo.queue()
    demo.launch()
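
# A plausible requirements.txt for this Space (an assumption: the original pin set
# is not shown here):
#   gradio
#   spaces
#   torch
#   transformers
#   accelerate
#   sentencepiece
#   PyPDF2
#   sentence-transformers
#   langchain
#   langchain-community
#   faiss-cpu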