# Hugging Face Space status (page residue): Running on Zero
# Minimal working Vision 2030 Virtual Assistant
import gradio as gr | |
import time | |
import logging | |
import os | |
import re | |
from datetime import datetime | |
import numpy as np | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from sklearn.metrics import precision_recall_fscore_support, accuracy_score | |
import PyPDF2 | |
import io | |
import json | |
from langdetect import detect | |
from sentence_transformers import SentenceTransformer | |
import faiss | |
import torch | |
import spaces | |
# Configure application-wide logging: INFO and above, timestamped, written to stderr.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('vision2030_assistant')

# Probe for CUDA once at import time; the rest of the module branches on this flag.
has_gpu = torch.cuda.is_available()
logger.info(f"GPU available: {has_gpu}")
class Vision2030Assistant:
    """Bilingual (Arabic/English) question answering over Saudi Vision 2030 facts.

    Arabic questions are first matched against a fixed, ordered list of keyword
    rules. Everything else is answered by nearest-neighbour retrieval (FAISS
    ``IndexFlatL2``, i.e. squared L2 distance). Retrieval searches text from an
    uploaded PDF first and falls back to a small built-in knowledge base.
    """

    # Ordered (keywords, reply) rules for Arabic questions. The first rule with a
    # keyword occurring as a substring of the question wins, so the specific
    # topics must stay ahead of the generic "ما هي" ("what is") catch-all.
    _ARABIC_KEYWORD_REPLIES = (
        # National identity
        (("الهوية الوطنية", "تعزيز الهوية"),
         "تتضمن رؤية 2030 مبادرات متعددة لتعزيز الهوية الوطنية السعودية بما في ذلك البرامج الثقافية والحفاظ على التراث وتعزيز القيم السعودية."),
        # Hajj and Umrah
        (("المعتمرين", "الحجاج", "العمرة", "الحج"),
         "تهدف رؤية 2030 إلى زيادة القدرة على استقبال المعتمرين من 8 ملايين إلى 30 مليون معتمر سنويًا."),
        # Economic diversification
        (("تنويع مصادر الدخل", "الاقتصاد المزدهر", "تنمية الاقتصاد"),
         "تهدف رؤية 2030 إلى زيادة الإيرادات الحكومية غير النفطية من 163 مليار ريال سعودي إلى 1 تريليون ريال سعودي من خلال تطوير قطاعات متنوعة مثل السياحة والتصنيع والطاقة المتجددة."),
        # UNESCO sites
        (("المواقع الأثرية", "اليونسكو", "التراث العالمي"),
         "تضع رؤية 2030 هدفًا بتسجيل ما لا يقل عن 10 مواقع سعودية في قائمة التراث العالمي لليونسكو."),
        # Real wealth
        (("الثروة الحقيقية", "أثمن", "ثروة"),
         "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب."),
        # Global gateway
        (("بوابة للعالم", "مكانتها", "موقعها الاستراتيجي"),
         "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا."),
        # Key pillars
        (("ركائز", "اركان"),
         "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."),
        # General Vision 2030
        (("ما هي", "ماهي"),
         "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة. الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."),
    )

    # Words ignored by evaluate_factual_accuracy (a deliberately small list).
    _STOPWORDS = frozenset({
        "the", "is", "a", "an", "and", "or", "of", "to", "in", "for", "with", "by", "on", "at",
        "في", "من", "إلى", "على", "و", "هي", "هو", "عن", "مع",
    })

    def __init__(self):
        """Load embedders, build the knowledge base and its indices, and reset metrics."""
        logger.info("Initializing Vision 2030 Assistant...")
        # Initialize embedding models
        self.load_embedding_models()
        # Create data
        self._create_knowledge_base()
        self._create_indices()
        # Create sample evaluation data
        self._create_sample_eval_data()
        # Initialize metrics
        self.metrics = {
            "response_times": [],
            "user_ratings": [],
            "factual_accuracy": [],
        }
        self.response_history = []
        # Full feedback records (rating plus free text); metrics only keeps ratings.
        self.feedback_history = []
        # Flag for PDF content
        self.has_pdf_content = False
        logger.info("Vision 2030 Assistant initialized successfully")

    # ------------------------------------------------------------------ models

    def load_embedding_models(self):
        """Load the Arabic and English SentenceTransformer embedders.

        On any failure, both embedders are replaced by deterministic hash-based
        fallbacks so that the app can still start.
        """
        logger.info("Loading embedding models...")
        try:
            self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
            self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
            if has_gpu:
                self.arabic_embedder = self.arabic_embedder.to('cuda')
                self.english_embedder = self.english_embedder.to('cuda')
                logger.info("Models moved to GPU")
            logger.info("Embedding models loaded successfully")
        except Exception as e:
            logger.error(f"Error loading embedding models: {str(e)}")
            self._create_fallback_embedders()

    @staticmethod
    def _hash_embedding(text, dim=384):
        """Return a deterministic pseudo-random float32 vector of length ``dim`` for ``text``.

        A private Generator seeded from the text's MD5 digest is used, so the
        global ``np.random`` state is left untouched (the previous
        ``np.random.seed`` call reset it for every other caller). MD5 serves only
        as a stable hash here, not for security.
        """
        import hashlib

        digest = hashlib.md5(text.encode("utf-8")).hexdigest()
        rng = np.random.default_rng(int(digest, 16))
        return rng.standard_normal(dim).astype(np.float32)

    def _create_fallback_embedders(self):
        """Install hash-based embedders used when the real models cannot be loaded.

        Their vectors carry no semantics. Retrieval still runs but returns
        effectively arbitrary passages.
        """
        logger.warning("Using fallback embedding methods")
        hash_embedding = self._hash_embedding

        class SimpleEmbedder:
            """Minimal stand-in exposing the parts of the SentenceTransformer API used here."""

            def __init__(self, dim=384):
                self.dim = dim

            def encode(self, text):
                return hash_embedding(text, self.dim)

            def get_sentence_embedding_dimension(self):
                return self.dim

        self.arabic_embedder = SimpleEmbedder()
        self.english_embedder = SimpleEmbedder()

    @staticmethod
    def _embedding_dim(embedder, default=384):
        """Best-effort output dimension of ``embedder``; ``default`` when it cannot be determined."""
        getter = getattr(embedder, "get_sentence_embedding_dimension", None)
        if callable(getter):
            try:
                dim = getter()
            except Exception:
                dim = None
            if dim:
                return int(dim)
        dim = getattr(embedder, "dim", None)
        return int(dim) if dim else default

    @staticmethod
    def _encode(embedder, text):
        """Embed one text as a contiguous 1-D float32 array (the layout FAISS requires).

        ``SentenceTransformer.encode`` already disables autograd internally, so
        no extra ``torch.no_grad()`` wrapper is needed.
        """
        return np.ascontiguousarray(embedder.encode(text), dtype=np.float32).reshape(-1)

    def _embed_texts(self, embedder, texts, label, pad_failures):
        """Encode ``texts`` and return ``(kept_texts, vectors)``, aligned index-for-index.

        When an encoding fails, the error is logged. With ``pad_failures``, a
        random vector of the embedder's width stands in so that every text keeps
        its slot. Otherwise the text is dropped from ``kept_texts`` together with
        its vector.
        """
        dim = self._embedding_dim(embedder)
        kept_texts, vectors = [], []
        for text in texts:
            try:
                vec = self._encode(embedder, text)
            except Exception as e:
                logger.error(f"Error encoding {label} text: {str(e)}")
                if not pad_failures:
                    continue
                vec = np.random.randn(dim).astype(np.float32)
            kept_texts.append(text)
            vectors.append(vec)
        return kept_texts, vectors

    @staticmethod
    def _build_index(vectors):
        """Build an exact squared-L2 FAISS index over equally sized 1-D vectors."""
        matrix = np.ascontiguousarray(np.vstack(vectors), dtype=np.float32)
        index = faiss.IndexFlatL2(matrix.shape[1])
        index.add(matrix)
        return index

    def _search(self, embedder, index, texts, query, k):
        """Embed ``query`` and return ``(distances, passages)`` for its top-``k`` neighbours."""
        query_vec = self._encode(embedder, query).reshape(1, -1)
        distances, ids = index.search(query_vec, k)
        # FAISS pads missing neighbours with id -1; also guard against ids past the text list.
        hits = [texts[i] for i in ids[0] if 0 <= i < len(texts)]
        return distances[0], hits

    # ------------------------------------------------------------------- data

    def _create_knowledge_base(self):
        """Populate the built-in English and Arabic passages and empty PDF containers.

        The two lists are parallel translations of each other, but each
        language is indexed and searched independently.
        """
        logger.info("Creating Vision 2030 knowledge base")
        # English texts
        self.english_texts = [
            "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
            "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
            "Vision 2030 targets increasing the private sector's contribution to GDP from 40% to 65%.",
            "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
            "Vision 2030 aims to increase women's participation in the workforce from 22% to 30%.",
            "The Red Sea Project is a Vision 2030 initiative to develop luxury tourism destinations across 50 islands off Saudi Arabia's Red Sea coast.",
            "Qiddiya is an entertainment mega-project being built in Riyadh as part of Vision 2030.",
            "The real wealth of Saudi Arabia, as emphasized in Vision 2030, is its people, particularly the youth.",
            "Saudi Arabia aims to strengthen its position as a global gateway by leveraging its strategic location between Asia, Europe, and Africa.",
            "Vision 2030 aims to have at least five Saudi universities among the top 200 universities in international rankings.",
            "Vision 2030 sets a target of having at least 10 Saudi sites registered on the UNESCO World Heritage List.",
            "Vision 2030 aims to increase the capacity to welcome Umrah visitors from 8 million to 30 million annually.",
            "Vision 2030 includes multiple initiatives to strengthen Saudi national identity including cultural programs and heritage preservation.",
            "Vision 2030 aims to increase non-oil government revenue from SAR 163 billion to SAR 1 trillion.",
        ]
        # Arabic texts
        self.arabic_texts = [
            "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة.",
            "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
            "تستهدف رؤية 2030 زيادة مساهمة القطاع الخاص في الناتج المحلي الإجمالي من 40٪ إلى 65٪.",
            "نيوم هي مدينة ذكية مخططة عبر الحدود في مقاطعة تبوك شمال غرب المملكة العربية السعودية، وهي مشروع رئيسي من رؤية 2030.",
            "تهدف رؤية 2030 إلى زيادة مشاركة المرأة في القوى العاملة من 22٪ إلى 30٪.",
            "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
            "القدية هي مشروع ترفيهي ضخم يتم بناؤه في الرياض كجزء من رؤية 2030.",
            "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب.",
            "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا.",
            "تهدف رؤية 2030 إلى أن تكون خمس جامعات سعودية على الأقل ضمن أفضل 200 جامعة في التصنيفات الدولية.",
            "تضع رؤية 2030 هدفًا بتسجيل ما لا يقل عن 10 مواقع سعودية في قائمة التراث العالمي لليونسكو.",
            "تهدف رؤية 2030 إلى زيادة القدرة على استقبال المعتمرين من 8 ملايين إلى 30 مليون معتمر سنويًا.",
            "تتضمن رؤية 2030 مبادرات متعددة لتعزيز الهوية الوطنية السعودية بما في ذلك البرامج الثقافية والحفاظ على التراث.",
            "تهدف رؤية 2030 إلى زيادة الإيرادات الحكومية غير النفطية من 163 مليار ريال سعودي إلى 1 تريليون ريال سعودي.",
        ]
        # Initialize PDF content containers
        self.pdf_english_texts = []
        self.pdf_arabic_texts = []
        logger.info(f"Created knowledge base: {len(self.english_texts)} English, {len(self.arabic_texts)} Arabic texts")

    def _create_indices(self):
        """Build FAISS indices for the built-in knowledge base (and any PDF text).

        Failed encodings are padded with random vectors so that row i of each
        index always corresponds to text i of its list. Each language is built
        in its own ``try`` block, so one failing does not prevent the other.
        """
        logger.info("Creating FAISS indices for text retrieval")
        try:
            _, self.english_vectors = self._embed_texts(
                self.english_embedder, self.english_texts, "English", pad_failures=True)
            if self.english_vectors:
                self.english_index = self._build_index(self.english_vectors)
                logger.info(f"Created English index with {len(self.english_vectors)} vectors")
            else:
                logger.warning("No English texts to index")
        except Exception as e:
            logger.error(f"Error creating FAISS indices: {str(e)}")

        try:
            _, self.arabic_vectors = self._embed_texts(
                self.arabic_embedder, self.arabic_texts, "Arabic", pad_failures=True)
            if self.arabic_vectors:
                self.arabic_index = self._build_index(self.arabic_vectors)
                logger.info(f"Created Arabic index with {len(self.arabic_vectors)} vectors")
            else:
                logger.warning("No Arabic texts to index")
        except Exception as e:
            logger.error(f"Error creating FAISS indices: {str(e)}")

        # Create PDF indices if PDF content exists in either language.
        if getattr(self, 'pdf_english_texts', None) or getattr(self, 'pdf_arabic_texts', None):
            self._create_pdf_indices()

    def _create_pdf_indices(self):
        """(Re)build FAISS indices over the current PDF chunks.

        Indices from a previous upload are discarded first, so stale vectors
        can never be paired with a new text list. A chunk that fails to encode
        is removed from its text list along with its vector, keeping row i and
        text i aligned. Afterwards ``has_pdf_content`` is True only if at least
        one PDF index was built.
        """
        for attr in ("pdf_english_vectors", "pdf_english_index",
                     "pdf_arabic_vectors", "pdf_arabic_index"):
            vars(self).pop(attr, None)
        self.has_pdf_content = False

        if not self.pdf_english_texts and not self.pdf_arabic_texts:
            return
        logger.info("Creating indices for PDF content")
        try:
            if self.pdf_english_texts:
                self.pdf_english_texts, self.pdf_english_vectors = self._embed_texts(
                    self.english_embedder, self.pdf_english_texts, "English PDF", pad_failures=False)
                if self.pdf_english_vectors:
                    self.pdf_english_index = self._build_index(self.pdf_english_vectors)
                    logger.info(f"Created English PDF index with {len(self.pdf_english_vectors)} vectors")
            if self.pdf_arabic_texts:
                self.pdf_arabic_texts, self.pdf_arabic_vectors = self._embed_texts(
                    self.arabic_embedder, self.pdf_arabic_texts, "Arabic PDF", pad_failures=False)
                if self.pdf_arabic_vectors:
                    self.pdf_arabic_index = self._build_index(self.pdf_arabic_vectors)
                    logger.info(f"Created Arabic PDF index with {len(self.pdf_arabic_vectors)} vectors")
        except Exception as e:
            logger.error(f"Error creating PDF indices: {str(e)}")
        # Set flag to indicate PDF content is available
        self.has_pdf_content = hasattr(self, "pdf_english_index") or hasattr(self, "pdf_arabic_index")

    def _create_sample_eval_data(self):
        """Create sample evaluation data with ground-truth reference answers."""
        self.eval_data = [
            {
                "question": "What are the key pillars of Vision 2030?",
                "lang": "en",
                "reference_answer": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
            },
            {
                "question": "ما هي الركائز الرئيسية لرؤية 2030؟",
                "lang": "ar",
                "reference_answer": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
            },
            {
                "question": "What is NEOM?",
                "lang": "en",
                "reference_answer": "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
            },
            {
                "question": "ما هو مشروع البحر الأحمر؟",
                "lang": "ar",
                "reference_answer": "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
            },
            {
                "question": "ما هي الثروة الحقيقية التي تعتز بها المملكة كما وردت في الرؤية؟",
                "lang": "ar",
                "reference_answer": "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب.",
            },
            {
                "question": "كيف تسعى المملكة إلى تعزيز مكانتها كبوابة للعالم؟",
                "lang": "ar",
                "reference_answer": "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا.",
            },
        ]
        logger.info(f"Created {len(self.eval_data)} sample evaluation examples")

    # -------------------------------------------------------------- answering

    def retrieve_context(self, query, lang, k=2, pdf_max_distance=1.5):
        """Return up to ``k`` newline-joined passages relevant to ``query``.

        If a PDF index exists for ``lang`` ("ar" or "en") and its best hit has a
        squared-L2 distance below ``pdf_max_distance``, the PDF passages are
        returned. Otherwise the built-in knowledge base is searched: "ar" maps
        to Arabic and anything else to English. Returns "" on any error.

        NOTE(review): the 1.5 threshold only makes sense for unit-normalised
        embeddings. The Arabic CAMeL-BERT vectors are presumably not normalised,
        so Arabic PDF hits may rarely pass. Confirm the distance scale before
        tuning this value.
        """
        start_time = time.time()
        try:
            if self.has_pdf_content and lang in ("ar", "en"):
                is_arabic = lang == "ar"
                index = getattr(self, "pdf_arabic_index" if is_arabic else "pdf_english_index", None)
                vectors = getattr(self, "pdf_arabic_vectors" if is_arabic else "pdf_english_vectors", None)
                if index is not None and vectors:
                    texts = self.pdf_arabic_texts if is_arabic else self.pdf_english_texts
                    embedder = self.arabic_embedder if is_arabic else self.english_embedder
                    distances, hits = self._search(embedder, index, texts, query, k)
                    # Only trust the PDF when its best match is close enough.
                    if len(distances) and distances[0] < pdf_max_distance:
                        context = "\n".join(hits)
                        if context.strip():
                            logger.info(f"Retrieved context from PDF ({'Arabic' if is_arabic else 'English'})")
                            return context

            # Fall back to the pre-built knowledge base
            if lang == "ar":
                embedder, index, texts = self.arabic_embedder, self.arabic_index, self.arabic_texts
            else:
                embedder, index, texts = self.english_embedder, self.english_index, self.english_texts
            _, hits = self._search(embedder, index, texts, query, k)
            context = "\n".join(hits)
            retrieval_time = time.time() - start_time
            logger.info(f"Retrieved context in {retrieval_time:.2f}s")
            return context
        except Exception as e:
            logger.error(f"Error retrieving context: {str(e)}")
            return ""

    def generate_response(self, user_input):
        """Answer ``user_input`` and record timing and the interaction.

        Returns "" for blank input. Arabic input is matched against
        ``_ARABIC_KEYWORD_REPLIES`` first; otherwise, or when no rule matches,
        the retrieved context is returned verbatim. On an unexpected error, a
        canned apology is returned in the detected language.
        """
        if not user_input or not user_input.strip():
            return ""
        start_time = time.time()
        # Default response in case of failure
        default_response = {
            "en": "I apologize, but I couldn't process your request properly. Please try again.",
            "ar": "أعتذر، لم أتمكن من معالجة طلبك بشكل صحيح. الرجاء المحاولة مرة أخرى."
        }
        lang = "en"
        try:
            try:
                # Simplify to just Arabic vs non-Arabic
                lang = "ar" if detect(user_input) == "ar" else "en"
            except Exception:
                # langdetect raises when the text has no detectable features (e.g. only digits).
                lang = "en"
            logger.info(f"Detected language: {lang}")

            reply = None
            if lang == "ar":
                reply = next(
                    (answer for keywords, answer in self._ARABIC_KEYWORD_REPLIES
                     if any(keyword in user_input for keyword in keywords)),
                    None,
                )
            if reply is None:
                context = self.retrieve_context(user_input, lang)
                if context:
                    reply = context
                elif lang == "ar":
                    reply = "لم أتمكن من العثور على معلومات كافية حول هذا السؤال."
                else:
                    reply = "I couldn't find enough information about this question."

            # Record response time
            response_time = time.time() - start_time
            self.metrics["response_times"].append(response_time)
            logger.info(f"Generated response in {response_time:.2f}s")
            # Store the interaction for later evaluation
            self.response_history.append({
                "timestamp": datetime.now().isoformat(),
                "user_input": user_input,
                "response": reply,
                "language": lang,
                "response_time": response_time,
            })
            return reply
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return default_response.get(lang, default_response["en"])

    # ------------------------------------------------------------- evaluation

    def evaluate_factual_accuracy(self, response, reference):
        """Fraction of the reference's non-stopword tokens that also appear in ``response``.

        Tokens are lower-cased ``\\w+`` runs (Unicode-aware, so this works for
        Arabic). Returns 0 when the reference has no content tokens. This is a
        simplified keyword-overlap recall, not a semantic measure.
        """
        keywords_reference = set(re.findall(r'\w+', reference.lower())) - self._STOPWORDS
        keywords_response = set(re.findall(r'\w+', response.lower())) - self._STOPWORDS
        if not keywords_reference:
            return 0
        return len(keywords_reference & keywords_response) / len(keywords_reference)

    def evaluate_on_test_set(self):
        """Answer every ``eval_data`` question and score it against its reference.

        Returns a dict with ``average_factual_accuracy`` and
        ``average_response_time``, both computed over this run only (previously
        the values were cumulative over every run and every chat message), plus
        ``detailed_results``: one dict per question with ``question``, ``lang``,
        ``reference``, ``response`` and ``factual_accuracy``. Accuracies are
        also appended to ``metrics["factual_accuracy"]`` for history.
        """
        logger.info("Running evaluation on test set")
        first_timing = len(self.metrics["response_times"])
        eval_results = []
        for example in self.eval_data:
            response = self.generate_response(example["question"])
            accuracy = self.evaluate_factual_accuracy(response, example["reference_answer"])
            eval_results.append({
                "question": example["question"],
                "lang": example.get("lang"),
                "reference": example["reference_answer"],
                "response": response,
                "factual_accuracy": accuracy,
            })
            self.metrics["factual_accuracy"].append(accuracy)

        accuracies = [result["factual_accuracy"] for result in eval_results]
        run_times = self.metrics["response_times"][first_timing:]
        avg_accuracy = sum(accuracies) / len(accuracies) if accuracies else 0
        avg_response_time = sum(run_times) / len(run_times) if run_times else 0
        results = {
            "average_factual_accuracy": avg_accuracy,
            "average_response_time": avg_response_time,
            "detailed_results": eval_results,
        }
        logger.info(f"Evaluation results: Factual accuracy = {avg_accuracy:.2f}, Avg response time = {avg_response_time:.2f}s")
        return results

    @staticmethod
    def _language_label(text):
        """'Arabic' if langdetect says so, else 'English' (also when detection fails)."""
        try:
            return "Arabic" if detect(text) == "ar" else "English"
        except Exception:
            return "English"

    def visualize_evaluation_results(self, results):
        """Return a matplotlib Figure: accuracy per question (top) and per language (bottom).

        A standalone ``Figure`` is used instead of ``plt.figure``, so it is not
        registered with pyplot and repeated evaluations do not pile up open
        figures. Language comes from each result's ``lang`` code when present;
        otherwise it falls back to langdetect.
        """
        from matplotlib.figure import Figure

        df = pd.DataFrame(results["detailed_results"])
        avg = results["average_factual_accuracy"]
        fig = Figure(figsize=(12, 8), tight_layout=True)
        ax_question, ax_language = fig.subplots(2, 1)
        if df.empty:
            ax_question.set_title("No evaluation results")
            ax_language.set_axis_off()
            return fig

        # Bar chart of factual accuracy by question
        ax_question.bar(range(len(df)), df["factual_accuracy"], color="skyblue")
        ax_question.axhline(y=avg, color='r', linestyle='-', label=f"Avg: {avg:.2f}")
        ax_question.set_xlabel("Question Index")
        ax_question.set_ylabel("Factual Accuracy")
        ax_question.set_title("Factual Accuracy by Question")
        ax_question.set_ylim(0, 1.1)
        ax_question.legend()

        def label(row):
            code = row.get("lang")
            if isinstance(code, str) and code:
                return "Arabic" if code == "ar" else "English"
            return self._language_label(row["question"])

        df["language"] = df.apply(label, axis=1)
        lang_accuracy = df.groupby("language")["factual_accuracy"].mean()

        # Bar chart of accuracy by language
        ax_language.bar(lang_accuracy.index, lang_accuracy.values, color=["lightblue", "lightgreen"])
        ax_language.axhline(y=avg, color='r', linestyle='-', label=f"Overall: {avg:.2f}")
        ax_language.set_xlabel("Language")
        ax_language.set_ylabel("Average Factual Accuracy")
        ax_language.set_title("Factual Accuracy by Language")
        ax_language.set_ylim(0, 1.1)
        ax_language.legend()
        # Add value labels above each language bar (categorical x positions are 0..n-1).
        for i, v in enumerate(lang_accuracy):
            ax_language.text(i, v + 0.05, f"{v:.2f}", ha='center')
        return fig

    def record_user_feedback(self, user_input, response, rating, feedback_text=""):
        """Record a user rating (and optional comment) for one response; always returns True.

        Ratings go to ``metrics["user_ratings"]``. The full record goes to the
        in-memory ``feedback_history`` list (a production system would persist it).
        """
        feedback = {
            "timestamp": datetime.now().isoformat(),
            "user_input": user_input,
            "response": response,
            "rating": rating,
            "feedback_text": feedback_text,
        }
        self.metrics["user_ratings"].append(rating)
        self.feedback_history.append(feedback)
        logger.info(f"Recorded user feedback: rating={rating}")
        return True

    # -------------------------------------------------------------------- PDF

    @staticmethod
    def _as_pdf_stream(file):
        """Normalise an upload payload to something ``PyPDF2.PdfReader`` accepts.

        Raw bytes are wrapped in ``BytesIO``. Paths and readable file objects are
        passed through. Objects exposing only ``.name`` are treated as paths.
        """
        if isinstance(file, (bytes, bytearray, memoryview)):
            return io.BytesIO(bytes(file))
        if isinstance(file, (str, os.PathLike)) or hasattr(file, "read"):
            return file
        name = getattr(file, "name", None)
        if name:
            return name
        raise TypeError(f"Unsupported upload type: {type(file).__name__}")

    @staticmethod
    def _chunk_text(full_text, min_chars=20, long_paragraph_chars=500, max_chunk_chars=300):
        """Split extracted PDF text into retrieval chunks.

        Paragraphs are separated by blank lines. Paragraphs shorter than
        ``min_chars`` (after stripping) are skipped. Paragraphs longer than
        ``long_paragraph_chars`` are regrouped sentence by sentence into chunks
        of at most ``max_chunk_chars``, counting the joining space; a single
        longer sentence still forms its own chunk. Sentence ends include the
        Arabic question mark.
        """
        chunks = []
        for paragraph in re.split(r'\n\s*\n', full_text):
            stripped = paragraph.strip()
            if len(stripped) < min_chars:
                continue
            if len(paragraph) <= long_paragraph_chars:
                chunks.append(stripped)
                continue
            current = ""
            for sentence in re.split(r'(?<=[.!?؟])\s+', paragraph):
                if current and len(current) + 1 + len(sentence) > max_chunk_chars:
                    chunks.append(current.strip())
                    current = sentence
                else:
                    current = f"{current} {sentence}" if current else sentence
            if current:
                chunks.append(current.strip())
        return chunks

    @staticmethod
    def _is_arabic(chunk):
        """Classify a chunk as Arabic via langdetect, falling back to an Arabic-script check."""
        try:
            return detect(chunk) == "ar"
        except Exception:
            return any('\u0600' <= c <= '\u06FF' for c in chunk)

    def process_pdf(self, file):
        """Extract, chunk, language-split and index an uploaded PDF.

        ``file`` may be raw bytes (Gradio ``type="binary"``), a path, or a
        readable file object. Returns a user-facing status string; errors are
        reported in that string rather than raised.
        """
        if file is None:
            return "No file uploaded. Please select a PDF file."
        try:
            logger.info("Processing uploaded file")
            reader = PyPDF2.PdfReader(self._as_pdf_stream(file))
            # Extract text from the PDF (pages without extractable text are skipped)
            page_texts = (page.extract_text() for page in reader.pages)
            full_text = "".join(f"{text}\n" for text in page_texts if text)
            if not full_text.strip():
                return "The uploaded PDF doesn't contain extractable text. Please try another file."

            # Categorize text by language
            english_chunks, arabic_chunks = [], []
            for chunk in self._chunk_text(full_text):
                (arabic_chunks if self._is_arabic(chunk) else english_chunks).append(chunk)

            # Store PDF content and build its indices
            self.pdf_english_texts = english_chunks
            self.pdf_arabic_texts = arabic_chunks
            self._create_pdf_indices()
            logger.info(f"Successfully processed PDF: {len(arabic_chunks)} Arabic chunks, {len(english_chunks)} English chunks")
            return f"✅ Successfully processed the PDF! Found {len(arabic_chunks)} Arabic and {len(english_chunks)} English text segments. PDF content will now be prioritized when answering questions."
        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            return f"❌ Error processing the PDF: {str(e)}. Please try another file."
# Create the Gradio interface
def create_interface():
    """Build the Gradio Blocks app (Chat, Evaluation and Upload PDF tabs) around one shared assistant."""
    assistant = Vision2030Assistant()

    def chat(message, history):
        # Blank submissions leave the history untouched; the textbox is cleared either way.
        if message and message.strip() != "":
            history.append((message, assistant.generate_response(message)))
        return history, ""

    def provide_feedback(history, rating, feedback_text):
        # Feedback always refers to the most recent (question, answer) pair.
        if not history:
            return "No conversation found to rate."
        latest = history[-1]
        assistant.record_user_feedback(latest[0], latest[1], rating, feedback_text)
        return f"Thank you for your feedback! (Rating: {rating}/5)"

    def run_evaluation():
        results = assistant.evaluate_on_test_set()
        details = results['detailed_results']
        sections = ["\n".join([
            "",
            "Evaluation Results:",
            "------------------",
            f"Total questions evaluated: {len(details)}",
            f"Overall factual accuracy: {results['average_factual_accuracy']:.2f}",
            f"Average response time: {results['average_response_time']:.4f} seconds",
            "Detailed Results:",
            "",
        ])]
        for number, item in enumerate(details, start=1):
            sections.append(
                f"\nQ{number}: {item['question']}\n"
                f"Reference: {item['reference']}\n"
                f"Response: {item['response']}\n"
                f"Accuracy: {item['factual_accuracy']:.2f}\n"
                + "-" * 40 + "\n"
            )
        summary = "".join(sections)
        # The summary text goes to the textbox, the figure to the plot component.
        return summary, assistant.visualize_evaluation_results(results)

    def process_uploaded_file(file):
        """Forward the uploaded PDF bytes to the assistant and return its status message."""
        return assistant.process_pdf(file)

    with gr.Blocks() as demo:
        gr.Markdown("# Vision 2030 Virtual Assistant 🌟")
        gr.Markdown("Ask questions about Saudi Arabia's Vision 2030 in both Arabic and English")

        with gr.Tab("Chat"):
            chatbot = gr.Chatbot(height=400)
            msg = gr.Textbox(label="Your Question", placeholder="Ask about Vision 2030...")
            with gr.Row():
                submit_btn = gr.Button("Submit")
                clear_btn = gr.Button("Clear Chat")
            gr.Markdown("### Provide Feedback")
            with gr.Row():
                rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Rate the Response (1-5)")
                feedback_text = gr.Textbox(label="Additional Comments (Optional)")
            feedback_btn = gr.Button("Submit Feedback")
            feedback_result = gr.Textbox(label="Feedback Status")

        with gr.Tab("Evaluation"):
            evaluate_btn = gr.Button("Run Evaluation on Test Set")
            eval_output = gr.Textbox(label="Evaluation Results", lines=20)
            eval_chart = gr.Plot(label="Evaluation Metrics")

        with gr.Tab("Upload PDF"):
            gr.Markdown(
                "\n"
                "### Upload a Vision 2030 PDF Document\n"
                "Upload a PDF document to enhance the assistant's knowledge base.\n"
            )
            with gr.Row():
                file_input = gr.File(
                    label="Select PDF File",
                    file_types=[".pdf"],
                    type="binary",  # process_pdf expects raw bytes
                )
            with gr.Row():
                upload_btn = gr.Button("Process PDF", variant="primary")
            with gr.Row():
                upload_status = gr.Textbox(
                    label="Upload Status",
                    placeholder="Upload status will appear here...",
                    interactive=False,
                )
            gr.Markdown(
                "\n"
                "### Notes:\n"
                "- The PDF should contain text that can be extracted (not scanned images)\n"
                "- After uploading, return to the Chat tab to ask questions about the uploaded content\n"
            )

        # Event wiring (registration order is kept stable for API clients).
        msg.submit(chat, [msg, chatbot], [chatbot, msg])
        submit_btn.click(chat, [msg, chatbot], [chatbot, msg])
        clear_btn.click(lambda: [], None, chatbot)
        feedback_btn.click(provide_feedback, [chatbot, rating, feedback_text], feedback_result)
        evaluate_btn.click(run_evaluation, None, [eval_output, eval_chart])
        upload_btn.click(process_uploaded_file, [file_input], [upload_status])

    return demo
# Launch the app.
# `demo` stays at module level so the app object can still be imported or
# discovered; the server only starts when this file is executed as a script,
# so importing the module no longer launches it.
demo = create_interface()
if __name__ == "__main__":
    demo.launch()