# Import necessary libraries
import gradio as gr
import time
import logging
import os
import re
from datetime import datetime
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, util
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import PyPDF2
import io
import spaces  # Added for @spaces.GPU decorator
# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger('Vision2030Assistant')

# Check for GPU availability
has_gpu = torch.cuda.is_available()
logger.info(f"GPU available: {has_gpu}")
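# Note (not in the original code): on ZeroGPU Spaces the GPU is attached only
# while a @spaces.GPU-decorated function is executing, so this import-time
# check may report False even though decorated functions can still use CUDA.
# The GPU-using methods below therefore re-check availability at call time.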
# Define the Vision2030Assistant class
class Vision2030Assistant:
    def __init__(self):
        """Initialize the Vision 2030 Assistant with models, knowledge base, and indices."""
        logger.info("Initializing Vision 2030 Assistant...")
        self.load_embedding_models()
        self.load_language_model()
        self._create_knowledge_base()
        self._create_indices()
        self._create_sample_eval_data()
        self.metrics = {"response_times": [], "user_ratings": [], "factual_accuracy": []}
        self.session_history = {}  # Dictionary to store session history
        self.has_pdf_content = False  # Flag to indicate if PDF content is available
        logger.info("Assistant initialized successfully")
    def load_embedding_models(self):
        """Load Arabic and English embedding models on CPU."""
        try:
            self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
            self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            # Models remain on CPU; GPU usage is handled in decorated functions
            logger.info("Embedding models loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load embedding models: {e}")
            self._fallback_embedding()
    def _fallback_embedding(self):
        """Fallback method for embedding models using a simple random-vector approach."""
        logger.warning("Using fallback embedding method")

        class SimpleEmbedder:
            def encode(self, text, device=None):  # device parameter kept for API compatibility
                import hashlib
                hash_obj = hashlib.md5(text.encode())
                np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
                return np.random.randn(384).astype(np.float32)

        self.arabic_embedder = SimpleEmbedder()
        self.english_embedder = SimpleEmbedder()
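    # Note (not in the original code): the fallback embedder hashes each text
    # to seed NumPy's RNG, so every string always maps to the same 384-dim
    # vector. This keeps the FAISS pipeline functional when model downloads
    # fail, but the vectors carry no semantic meaning, so retrieval quality
    # degrades to essentially random matches.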
    def load_language_model(self):
        """Load the DistilGPT-2 language model on CPU."""
        try:
            self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
            # GPT-2 tokenizers have no pad token; reuse EOS to avoid generation warnings
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
            self.generator = pipeline(
                'text-generation',
                model=self.model,
                tokenizer=self.tokenizer,
                device=-1  # CPU
            )
            logger.info("Language model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load language model: {e}")
            self.generator = None
    def _create_knowledge_base(self):
        """Initialize the knowledge base with basic Vision 2030 information."""
        self.english_texts = [
            "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
            "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
            "NEOM is a planned smart city in Tabuk Province, a key Vision 2030 project."
        ]
        # Arabic equivalents of the texts above (kept in Arabic: language
        # detection and retrieval depend on the Arabic characters)
        self.arabic_texts = [
            "رؤية 2030 هي إطار استراتيجي لتقليل الاعتماد على النفط وتنويع الاقتصاد.",
            "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
            "نيوم مدينة ذكية مخططة في تبوك، مشروع رئيسي لرؤية 2030."
        ]
        self.pdf_english_texts = []
        self.pdf_arabic_texts = []
    def _create_indices(self):
        """Create FAISS indices for the initial knowledge base on CPU."""
        try:
            # English index
            english_vectors = [self.english_embedder.encode(text) for text in self.english_texts]
            dim = len(english_vectors[0])
            nlist = max(1, len(english_vectors) // 10)
            quantizer = faiss.IndexFlatL2(dim)
            self.english_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
            self.english_index.train(np.array(english_vectors))
            self.english_index.add(np.array(english_vectors))

            # Arabic index. Each IVF index needs its own quantizer and its own
            # dimensionality: the Arabic embedder's vectors differ from the
            # English ones, so reusing the English quantizer/dim would fail.
            arabic_vectors = [self.arabic_embedder.encode(text) for text in self.arabic_texts]
            arabic_dim = len(arabic_vectors[0])
            arabic_nlist = max(1, len(arabic_vectors) // 10)
            arabic_quantizer = faiss.IndexFlatL2(arabic_dim)
            self.arabic_index = faiss.IndexIVFFlat(arabic_quantizer, arabic_dim, arabic_nlist)
            self.arabic_index.train(np.array(arabic_vectors))
            self.arabic_index.add(np.array(arabic_vectors))
            logger.info("FAISS indices created successfully")
        except Exception as e:
            logger.error(f"Error creating indices: {e}")
    def _create_sample_eval_data(self):
        """Create sample evaluation data for testing factual accuracy."""
        self.eval_data = [
            {"question": "What are the key pillars of Vision 2030?",
             "lang": "en",
             "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
            # Arabic entry: "What are the key pillars of Vision 2030?"
            {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
             "lang": "ar",
             "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."}
        ]
    def retrieve_context(self, query, lang, session_id, device='cpu'):
        """Retrieve relevant context using the specified device for encoding."""
        try:
            history = self.session_history.get(session_id, [])
            history_context = " ".join([f"Q: {q} A: {a}" for q, a in history[-2:]])
            embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
            query_vec = embedder.encode(query, device=device)
            if lang == "ar":
                if self.has_pdf_content and self.pdf_arabic_texts:
                    index = self.pdf_arabic_index
                    texts = self.pdf_arabic_texts
                else:
                    index = self.arabic_index
                    texts = self.arabic_texts
            else:
                if self.has_pdf_content and self.pdf_english_texts:
                    index = self.pdf_english_index
                    texts = self.pdf_english_texts
                else:
                    index = self.english_index
                    texts = self.english_texts
            D, I = index.search(np.array([query_vec]), k=2)
            context = "\n".join([texts[i] for i in I[0] if i >= 0]) + f"\nHistory: {history_context}"
            return context if context.strip() else "No relevant information found."
        except Exception as e:
            logger.error(f"Retrieval error: {e}")
            return "Error retrieving context."
    @spaces.GPU  # Request a ZeroGPU slot for generation (the reason `spaces` is imported)
    def generate_response(self, query, session_id):
        """Generate a response using GPU resources when available."""
        if not query.strip():
            return "Please enter a valid question."
        start_time = time.time()
        try:
            # Treat the query as Arabic if it contains any Arabic-block characters
            lang = "ar" if any('\u0600' <= c <= '\u06FF' for c in query) else "en"
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            context = self.retrieve_context(query, lang, session_id, device=device)
            if "Error" in context or "No relevant" in context:
                reply = context
            elif self.generator:
                # Move the language model to the GPU if one is attached, and keep
                # the pipeline's input placement in sync with the model
                self.generator.model.to(device)
                self.generator.device = torch.device(device)
                prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
                response = self.generator(prompt, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
                reply = response[0]['generated_text'].split("Answer:")[-1].strip()
                # Move the language model back to the CPU
                self.generator.model.to('cpu')
                self.generator.device = torch.device('cpu')
            else:
                reply = context
            self.session_history.setdefault(session_id, []).append((query, reply))
            self.metrics["response_times"].append(time.time() - start_time)
            return reply
        except Exception as e:
            logger.error(f"Response generation error: {e}")
            return "Sorry, an error occurred. Please try again."
    def evaluate_factual_accuracy(self, response, reference):
        """Evaluate the factual accuracy of a response using semantic similarity."""
        try:
            # Pick the embedder matching the reference's script: the eval data
            # contains both English and Arabic references, so assuming English
            # would mis-score the Arabic entry
            is_arabic = any('\u0600' <= c <= '\u06FF' for c in reference)
            embedder = self.arabic_embedder if is_arabic else self.english_embedder
            response_vec = embedder.encode(response)
            reference_vec = embedder.encode(reference)
            similarity = util.cos_sim(response_vec, reference_vec).item()
            return similarity
        except Exception as e:
            logger.error(f"Evaluation error: {e}")
            return 0.0
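    # A minimal usage sketch (hypothetical helper, not in the original app):
    # ties _create_sample_eval_data, generate_response, and
    # evaluate_factual_accuracy together into one pass over the eval set.
    def run_evaluation(self, session_id="eval"):
        """Score the assistant against the sample eval data; returns one
        similarity score per question and records it in self.metrics."""
        results = []
        for item in self.eval_data:
            response = self.generate_response(item["question"], session_id)
            score = self.evaluate_factual_accuracy(response, item["reference"])
            self.metrics["factual_accuracy"].append(score)
            results.append({"question": item["question"], "score": score})
        return results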
    @spaces.GPU  # Encoding PDF chunks also runs on the ZeroGPU slot
    def process_pdf(self, file):
        """Process a PDF file and update the knowledge base, encoding on GPU when available."""
        if not file:
            return "Please upload a PDF file."
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(file))
            text = "".join([page.extract_text() or "" for page in pdf_reader.pages])
            if not text.strip():
                return "No extractable text found in PDF."
            # Split text into fixed-width 300-character chunks and route each
            # chunk to the matching language list by script
            chunks = [text[i:i+300] for i in range(0, len(text), 300)]
            self.pdf_english_texts = [c for c in chunks if not any('\u0600' <= char <= '\u06FF' for char in c)]
            self.pdf_arabic_texts = [c for c in chunks if any('\u0600' <= char <= '\u06FF' for char in c)]
            # Create indices for PDF content, using the GPU when one is attached
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            if self.pdf_english_texts:
                english_vectors = [self.english_embedder.encode(chunk, device=device) for chunk in self.pdf_english_texts]
                dim = len(english_vectors[0])
                nlist = max(1, len(english_vectors) // 10)
                quantizer = faiss.IndexFlatL2(dim)
                self.pdf_english_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
                self.pdf_english_index.train(np.array(english_vectors))
                self.pdf_english_index.add(np.array(english_vectors))
            if self.pdf_arabic_texts:
                arabic_vectors = [self.arabic_embedder.encode(chunk, device=device) for chunk in self.pdf_arabic_texts]
                dim = len(arabic_vectors[0])
                nlist = max(1, len(arabic_vectors) // 10)
                quantizer = faiss.IndexFlatL2(dim)
                self.pdf_arabic_index = faiss.IndexIVFFlat(quantizer, dim, nlist)
                self.pdf_arabic_index.train(np.array(arabic_vectors))
                self.pdf_arabic_index.add(np.array(arabic_vectors))
            self.has_pdf_content = True
            return f"PDF processed: {len(self.pdf_english_texts)} English, {len(self.pdf_arabic_texts)} Arabic chunks."
        except Exception as e:
            logger.error(f"PDF processing error: {e}")
            return f"Error processing PDF: {e}"
# Create the Gradio interface
def create_interface():
    """Set up the Gradio interface for chatting and PDF uploading."""
    assistant = Vision2030Assistant()

    def chat(query, history, session_id):
        reply = assistant.generate_response(query, session_id)
        history.append((query, reply))
        return history, ""

    with gr.Blocks() as demo:
        gr.Markdown("# Vision 2030 Virtual Assistant")
        session_id = gr.State(value="user1")  # Fixed session ID for simplicity
        chatbot = gr.Chatbot()
        msg = gr.Textbox(label="Ask a question")
        submit = gr.Button("Submit")
        pdf_upload = gr.File(label="Upload PDF", type="binary")
        upload_status = gr.Textbox(label="Upload Status")
        submit.click(chat, [msg, chatbot, session_id], [chatbot, msg])
        pdf_upload.upload(assistant.process_pdf, pdf_upload, upload_status)
    return demo
# Launch the interface
if __name__ == "__main__":
    demo = create_interface()
    demo.launch()