abdull4h committed on
Commit 6750126 · verified · 1 Parent(s): bc85188

Update app.py

Files changed (1)
  1. app.py +268 -780
app.py CHANGED
@@ -1,813 +1,301 @@
- # Minimal working Vision 2030 Virtual Assistant
- import gradio as gr
- import time
- import logging
  import os
  import re
- from datetime import datetime
  import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- from sklearn.metrics import precision_recall_fscore_support, accuracy_score
  import PyPDF2
- import io
- import json
- from langdetect import detect
  from sentence_transformers import SentenceTransformer
- import faiss
- import torch
- import spaces
- from transformers import pipeline  # missing from the original imports; _load_qa_pipelines() calls pipeline()
-
- # Configure logging
- logging.basicConfig(
-     level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
-     handlers=[logging.StreamHandler()]
  )
- logger = logging.getLogger('vision2030_assistant')
-
- # Check for GPU availability
- has_gpu = torch.cuda.is_available()
- logger.info(f"GPU available: {has_gpu}")
-
- class Vision2030Assistant:
-     def __init__(self):
-         """Initialize the Vision 2030 Assistant with basic knowledge"""
-         logger.info("Initializing Vision 2030 Assistant...")
-
-         # Load QA pipelines for English & Arabic
-         self._load_qa_pipelines()
-
-         # Initialize embedding models
-         self.load_embedding_models()
-
-         # Create data
-         self._create_knowledge_base()
-         self._create_indices()
-
-         # Create sample evaluation data
-         self._create_sample_eval_data()
-
-         # Initialize metrics
-         self.metrics = {
-             "response_times": [],
-             "user_ratings": [],
-             "factual_accuracy": []
-         }
-         self.response_history = []
-
-         # Flag for PDF content
-         self.has_pdf_content = False
-
-         logger.info("Vision 2030 Assistant initialized successfully")
-
-     @spaces.GPU
-     def _load_qa_pipelines(self):
-         """
-         Load or initialize QA models for English and Arabic.
-         You can choose any Hugging Face QA model; below are just examples.
-         """
-         logger.info("Loading QA pipelines...")
-         try:
-             # English QA pipeline
-             self.qa_pipeline_en = pipeline(
-                 "question-answering",
-                 model="distilbert-base-cased-distilled-squad",
-                 tokenizer="distilbert-base-cased-distilled-squad",
-                 device=0 if has_gpu else -1  # Use GPU if available
-             )
-
-             # Arabic QA pipeline
-             # For Arabic, you can use a model like `aubmindlab/bert-base-arabertv02-qa`:
-             self.qa_pipeline_ar = pipeline(
-                 "question-answering",
-                 model="aubmindlab/bert-base-arabertv02-qa",
-                 tokenizer="aubmindlab/bert-base-arabertv02-qa",
-                 device=0 if has_gpu else -1
-             )
-
-             logger.info("QA pipelines loaded successfully.")
-         except Exception as e:
-             logger.error(f"Error loading QA pipelines: {str(e)}")
-             self.qa_pipeline_en = None
-             self.qa_pipeline_ar = None
-
-     @spaces.GPU
-     def load_embedding_models(self):
-         """Load embedding models for retrieval"""
-         logger.info("Loading embedding models...")
-
-         try:
-             # Load embedding models
-             self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
-             self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-
-             # Move to GPU if available
-             if has_gpu:
-                 self.arabic_embedder = self.arabic_embedder.to('cuda')
-                 self.english_embedder = self.english_embedder.to('cuda')
-                 logger.info("Models moved to GPU")
-
-             logger.info("Embedding models loaded successfully")
-         except Exception as e:
-             logger.error(f"Error loading embedding models: {str(e)}")
-             self._create_fallback_embedders()
-
-     def _create_fallback_embedders(self):
-         """Create fallback embedding methods if model loading fails"""
-         logger.warning("Using fallback embedding methods")
-
-         # Simple fallback using character-level encoding
-         def simple_encode(text, dim=384):
-             import hashlib
-             # Create a hash of the text
-             hash_object = hashlib.md5(text.encode())
-             # Use the hash to seed a random number generator
-             np.random.seed(int(hash_object.hexdigest(), 16) % 2**32)
-             # Generate a random vector
-             return np.random.randn(dim).astype(np.float32)
-
-         # Create embedding function objects
-         class SimpleEmbedder:
-             def __init__(self, dim=384):
-                 self.dim = dim
-
-             def encode(self, text):
-                 return simple_encode(text, self.dim)
-
-         self.arabic_embedder = SimpleEmbedder()
-         self.english_embedder = SimpleEmbedder()
-
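The fallback embedder above is deterministic rather than semantic: hashing the text seeds NumPy's RNG, so a given string always maps to the same pseudo-random vector. A standalone sketch of the trick (illustrative only, not part of the diff):

    import hashlib
    import numpy as np

    def simple_encode(text, dim=384):
        # Seed the RNG from an MD5 hash so the vector is a pure function of the text
        np.random.seed(int(hashlib.md5(text.encode()).hexdigest(), 16) % 2**32)
        return np.random.randn(dim).astype(np.float32)

    # Same input -> same vector, so lookups stay stable even without a real model
    assert np.allclose(simple_encode("Vision 2030"), simple_encode("Vision 2030"))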
 
-     def _create_knowledge_base(self):
-         """Create knowledge base with Vision 2030 information"""
-         logger.info("Creating Vision 2030 knowledge base")
-
-         # English texts
-         self.english_texts = [
-             "Vision 2030 is Saudi Arabia's strategic framework to reduce dependence on oil, diversify the economy, and develop public sectors.",
-             "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation.",
-             "Vision 2030 targets increasing the private sector's contribution to GDP from 40% to 65%.",
-             "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030.",
-             "Vision 2030 aims to increase women's participation in the workforce from 22% to 30%.",
-             "The Red Sea Project is a Vision 2030 initiative to develop luxury tourism destinations across 50 islands off Saudi Arabia's Red Sea coast.",
-             "Qiddiya is an entertainment mega-project being built in Riyadh as part of Vision 2030.",
-             "The real wealth of Saudi Arabia, as emphasized in Vision 2030, is its people, particularly the youth.",
-             "Saudi Arabia aims to strengthen its position as a global gateway by leveraging its strategic location between Asia, Europe, and Africa.",
-             "Vision 2030 aims to have at least five Saudi universities among the top 200 universities in international rankings.",
-             "Vision 2030 sets a target of having at least 10 Saudi sites registered on the UNESCO World Heritage List.",
-             "Vision 2030 aims to increase the capacity to welcome Umrah visitors from 8 million to 30 million annually.",
-             "Vision 2030 includes multiple initiatives to strengthen Saudi national identity including cultural programs and heritage preservation.",
-             "Vision 2030 aims to increase non-oil government revenue from SAR 163 billion to SAR 1 trillion."
-         ]
-
-         # Arabic texts
-         self.arabic_texts = [
-             "رؤية 2030 هي الإطار الاستراتيجي للمملكة العربية السعودية للحد من الاعتماد على النفط وتنويع الاقتصاد وتطوير القطاعات العامة.",
-             "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح.",
-             "تستهدف رؤية 2030 زيادة مساهمة القطاع الخاص في الناتج المحلي الإجمالي من 40٪ إلى 65٪.",
-             "نيوم هي مدينة ذكية مخططة عبر الحدود في مقاطعة تبوك شمال غرب المملكة العربية السعودية، وهي مشروع رئيسي من رؤية 2030.",
-             "تهدف رؤية 2030 إلى زيادة مشاركة المرأة في القوى العاملة من 22٪ إلى 30٪.",
-             "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي.",
-             "القدية هي مشروع ترفيهي ضخم يتم بناؤه في الرياض كجزء من رؤية 2030.",
-             "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب.",
-             "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا.",
-             "تهدف رؤية 2030 إلى أن تكون خمس جامعات سعودية على الأقل ضمن أفضل 200 جامعة في التصنيفات الدولية.",
-             "تضع رؤية 2030 هدفًا بتسجيل ما لا يقل عن 10 مواقع سعودية في قائمة التراث العالمي لليونسكو.",
-             "تهدف رؤية 2030 إلى زيادة القدرة على استقبال المعتمرين من 8 ملايين إلى 30 مليون معتمر سنويًا.",
-             "تتضمن رؤية 2030 مبادرات متعددة لتعزيز الهوية الوطنية السعودية بما في ذلك البرامج الثقافية والحفاظ على التراث.",
-             "تهدف رؤية 2030 إلى زيادة الإيرادات الحكومية غير النفطية من 163 مليار ريال سعودي إلى 1 تريليون ريال سعودي."
-         ]
-
-         # Initialize PDF content containers
-         self.pdf_english_texts = []
-         self.pdf_arabic_texts = []
-
-         logger.info(f"Created knowledge base: {len(self.english_texts)} English, {len(self.arabic_texts)} Arabic texts")
-
-     @spaces.GPU
-     def _create_indices(self):
-         """Create FAISS indices for text retrieval"""
-         logger.info("Creating FAISS indices for text retrieval")
-
-         try:
-             # Process and embed English texts
-             self.english_vectors = []
-             for text in self.english_texts:
-                 try:
-                     if has_gpu and hasattr(self.english_embedder, 'to'):
-                         with torch.no_grad():
-                             vec = self.english_embedder.encode(text)
-                     else:
-                         vec = self.english_embedder.encode(text)
-                     self.english_vectors.append(vec)
-                 except Exception as e:
-                     logger.error(f"Error encoding English text: {str(e)}")
-                     # Use a random vector as fallback
-                     self.english_vectors.append(np.random.randn(384).astype(np.float32))
-
-             # Create English index
-             if self.english_vectors:
-                 self.english_index = faiss.IndexFlatL2(len(self.english_vectors[0]))
-                 self.english_index.add(np.array(self.english_vectors))
-                 logger.info(f"Created English index with {len(self.english_vectors)} vectors")
-             else:
-                 logger.warning("No English texts to index")
-
-             # Process and embed Arabic texts
-             self.arabic_vectors = []
-             for text in self.arabic_texts:
-                 try:
-                     if has_gpu and hasattr(self.arabic_embedder, 'to'):
-                         with torch.no_grad():
-                             vec = self.arabic_embedder.encode(text)
-                     else:
-                         vec = self.arabic_embedder.encode(text)
-                     self.arabic_vectors.append(vec)
-                 except Exception as e:
-                     logger.error(f"Error encoding Arabic text: {str(e)}")
-                     # Use a random vector as fallback
-                     self.arabic_vectors.append(np.random.randn(384).astype(np.float32))
-
-             # Create Arabic index
-             if self.arabic_vectors:
-                 self.arabic_index = faiss.IndexFlatL2(len(self.arabic_vectors[0]))
-                 self.arabic_index.add(np.array(self.arabic_vectors))
-                 logger.info(f"Created Arabic index with {len(self.arabic_vectors)} vectors")
-             else:
-                 logger.warning("No Arabic texts to index")
-
-             # Create PDF indices if PDF content exists
-             if hasattr(self, 'pdf_english_texts') and self.pdf_english_texts:
-                 self._create_pdf_indices()
-
-         except Exception as e:
-             logger.error(f"Error creating FAISS indices: {str(e)}")
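For readers unfamiliar with the FAISS calls used here: IndexFlatL2 stores raw vectors and does an exhaustive search that returns squared-L2 distances. A minimal standalone sketch of the add/search pattern (illustrative, not part of the diff):

    import faiss
    import numpy as np

    index = faiss.IndexFlatL2(384)                            # dimension must match the embedder
    index.add(np.random.randn(10, 384).astype(np.float32))    # add an (n, d) float32 matrix
    query = np.random.randn(1, 384).astype(np.float32)        # queries are also batched (n_queries, d)
    D, I = index.search(query, k=2)                           # distances and row indices, both shaped (1, 2)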
 
-     def _create_pdf_indices(self):
-         """Create indices for PDF content"""
-         if not self.pdf_english_texts and not self.pdf_arabic_texts:
-             return
-
-         logger.info("Creating indices for PDF content")
-
-         try:
-             # Process and embed English PDF texts
-             if self.pdf_english_texts:
-                 self.pdf_english_vectors = []
-                 for text in self.pdf_english_texts:
-                     try:
-                         if has_gpu and hasattr(self.english_embedder, 'to'):
-                             with torch.no_grad():
-                                 vec = self.english_embedder.encode(text)
-                         else:
-                             vec = self.english_embedder.encode(text)
-                         self.pdf_english_vectors.append(vec)
-                     except Exception as e:
-                         logger.error(f"Error encoding English PDF text: {str(e)}")
-                         continue
-
-                 if self.pdf_english_vectors:
-                     self.pdf_english_index = faiss.IndexFlatL2(len(self.pdf_english_vectors[0]))
-                     self.pdf_english_index.add(np.array(self.pdf_english_vectors))
-                     logger.info(f"Created English PDF index with {len(self.pdf_english_vectors)} vectors")
-
-             # Process and embed Arabic PDF texts
-             if self.pdf_arabic_texts:
-                 self.pdf_arabic_vectors = []
-                 for text in self.pdf_arabic_texts:
-                     try:
-                         if has_gpu and hasattr(self.arabic_embedder, 'to'):
-                             with torch.no_grad():
-                                 vec = self.arabic_embedder.encode(text)
-                         else:
-                             vec = self.arabic_embedder.encode(text)
-                         self.pdf_arabic_vectors.append(vec)
-                     except Exception as e:
-                         logger.error(f"Error encoding Arabic PDF text: {str(e)}")
-                         continue
-
-                 if self.pdf_arabic_vectors:
-                     self.pdf_arabic_index = faiss.IndexFlatL2(len(self.pdf_arabic_vectors[0]))
-                     self.pdf_arabic_index.add(np.array(self.pdf_arabic_vectors))
-                     logger.info(f"Created Arabic PDF index with {len(self.pdf_arabic_vectors)} vectors")
-
-             # Set flag to indicate PDF content is available
-             self.has_pdf_content = True
-
-         except Exception as e:
-             logger.error(f"Error creating PDF indices: {str(e)}")
-
-     def _create_sample_eval_data(self):
-         """Create sample evaluation data with ground truth"""
-         self.eval_data = [
-             {
-                 "question": "What are the key pillars of Vision 2030?",
-                 "lang": "en",
-                 "reference_answer": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."
-             },
-             {
-                 "question": "ما هي الركائز الرئيسية لرؤية 2030؟",
-                 "lang": "ar",
-                 "reference_answer": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."
-             },
-             {
-                 "question": "What is NEOM?",
-                 "lang": "en",
-                 "reference_answer": "NEOM is a planned cross-border smart city in the Tabuk Province of northwestern Saudi Arabia, a key project of Vision 2030."
-             },
-             {
-                 "question": "ما هو مشروع البحر الأحمر؟",
-                 "lang": "ar",
-                 "reference_answer": "مشروع البحر الأحمر هو مبادرة رؤية 2030 لتطوير وجهات سياحية فاخرة عبر 50 جزيرة قبالة ساحل البحر الأحمر السعودي."
-             },
-             {
-                 "question": "ما هي الثروة الحقيقية التي تعتز بها المملكة كما وردت في الرؤية؟",
-                 "lang": "ar",
-                 "reference_answer": "الثروة الحقيقية للمملكة العربية السعودية، كما أكدت رؤية 2030، هي شعبها، وخاصة الشباب."
-             },
-             {
-                 "question": "كيف تسعى المملكة إلى تعزيز مكانتها كبوابة للعالم؟",
-                 "lang": "ar",
-                 "reference_answer": "تهدف المملكة العربية السعودية إلى تعزيز مكانتها كبوابة عالمية من خلال الاستفادة من موقعها الاستراتيجي بين آسيا وأوروبا وأفريقيا."
-             }
-         ]
-         logger.info(f"Created {len(self.eval_data)} sample evaluation examples")
-
-     @spaces.GPU
-     def retrieve_context(self, query, lang):
-         """Retrieve relevant context with priority to PDF content"""
-         start_time = time.time()
-
-         try:
-             # First check if we have PDF content
-             if self.has_pdf_content:
-                 # Try to retrieve from PDF content first
-                 if lang == "ar" and hasattr(self, 'pdf_arabic_index') and hasattr(self, 'pdf_arabic_vectors') and len(self.pdf_arabic_vectors) > 0:
-                     if has_gpu and hasattr(self.arabic_embedder, 'to'):
-                         with torch.no_grad():
-                             query_vec = self.arabic_embedder.encode(query)
-                     else:
-                         query_vec = self.arabic_embedder.encode(query)
-
-                     D, I = self.pdf_arabic_index.search(np.array([query_vec]), k=2)
-
-                     # If we found good matches in the PDF
-                     if D[0][0] < 1.5:  # Threshold for relevance
-                         context = "\n".join([self.pdf_arabic_texts[i] for i in I[0] if i < len(self.pdf_arabic_texts) and i >= 0])
-                         if context.strip():
-                             logger.info("Retrieved context from PDF (Arabic)")
-                             return context
-
-                 elif lang == "en" and hasattr(self, 'pdf_english_index') and hasattr(self, 'pdf_english_vectors') and len(self.pdf_english_vectors) > 0:
-                     if has_gpu and hasattr(self.english_embedder, 'to'):
-                         with torch.no_grad():
-                             query_vec = self.english_embedder.encode(query)
-                     else:
-                         query_vec = self.english_embedder.encode(query)
-
-                     D, I = self.pdf_english_index.search(np.array([query_vec]), k=2)
-
-                     # If we found good matches in the PDF
-                     if D[0][0] < 1.5:  # Threshold for relevance
-                         context = "\n".join([self.pdf_english_texts[i] for i in I[0] if i < len(self.pdf_english_texts) and i >= 0])
-                         if context.strip():
-                             logger.info("Retrieved context from PDF (English)")
-                             return context
-
-             # Fall back to the pre-built knowledge base
-             if lang == "ar":
-                 if has_gpu and hasattr(self.arabic_embedder, 'to'):
-                     with torch.no_grad():
-                         query_vec = self.arabic_embedder.encode(query)
-                 else:
-                     query_vec = self.arabic_embedder.encode(query)
-
-                 D, I = self.arabic_index.search(np.array([query_vec]), k=2)
-                 context = "\n".join([self.arabic_texts[i] for i in I[0] if i < len(self.arabic_texts) and i >= 0])
-             else:
-                 if has_gpu and hasattr(self.english_embedder, 'to'):
-                     with torch.no_grad():
-                         query_vec = self.english_embedder.encode(query)
-                 else:
-                     query_vec = self.english_embedder.encode(query)
-
-                 D, I = self.english_index.search(np.array([query_vec]), k=2)
-                 context = "\n".join([self.english_texts[i] for i in I[0] if i < len(self.english_texts) and i >= 0])
-
-             retrieval_time = time.time() - start_time
-             logger.info(f"Retrieved context in {retrieval_time:.2f}s")
-
-             return context
-         except Exception as e:
-             logger.error(f"Error retrieving context: {str(e)}")
-             return ""
-
-     def generate_response(self, user_input):
-         """Generate a more detailed answer using a QA pipeline if available."""
-         if not user_input or user_input.strip() == "":
-             return ""
-
-         start_time = time.time()
-
-         default_response = {
-             "en": "I apologize, but I couldn't process your request properly. Please try again.",
-             "ar": "أعتذر، لم أتمكن من معالجة طلبك بشكل صحيح. الرجاء المحاولة مرة أخرى."
-         }
-
-         try:
-             # 1) Detect language
-             try:
-                 lang_detected = detect(user_input)
-                 lang = "ar" if lang_detected == "ar" else "en"
-             except:
-                 lang = "en"  # fallback
-
-             logger.info(f"Detected language: {lang}")
-
-             # 2) Retrieve relevant context (could be from PDF or base knowledge)
-             context = self.retrieve_context(user_input, lang)
-
-             # 3) Decide whether to use a QA pipeline or fall back
-             if lang == "ar" and self.qa_pipeline_ar is not None and context:
-                 # Use Arabic QA pipeline
-                 try:
-                     answer = self.qa_pipeline_ar(question=user_input, context=context)
-                     reply = answer["answer"].strip()
-
-                     # If the QA model returns something too short or obviously unhelpful,
-                     # fall back to the original context-based approach:
-                     if len(reply) < 2:
-                         reply = context  # fall back to returning the raw context
-                 except Exception as e:
-                     logger.error(f"Error in Arabic QA pipeline: {str(e)}")
-                     reply = context if context else "لم أتمكن من العثور على معلومات كافية حول هذا السؤال."
-
-             elif lang == "en" and self.qa_pipeline_en is not None and context:
-                 # Use English QA pipeline
-                 try:
-                     answer = self.qa_pipeline_en(question=user_input, context=context)
-                     reply = answer["answer"].strip()
-                     if len(reply) < 2:
-                         reply = context
-                 except Exception as e:
-                     logger.error(f"Error in English QA pipeline: {str(e)}")
-                     reply = context if context else "I couldn't find enough information about this question."
-
-             else:
-                 # 4) No QA pipeline or no context: return the raw context or a short fallback message
-                 if lang == "ar":
-                     reply = context if context else "لم أتمكن من العثور على معلومات كافية حول هذا السؤال."
-                 else:
-                     reply = context if context else "I couldn't find enough information about this question."
-
-             # 5) Record metrics and return
-             response_time = time.time() - start_time
-             self.metrics["response_times"].append(response_time)
-             logger.info(f"Generated response in {response_time:.2f}s")
-
-             # Store the interaction
-             interaction = {
-                 "timestamp": datetime.now().isoformat(),
-                 "user_input": user_input,
-                 "response": reply,
-                 "language": lang,
-                 "response_time": response_time
-             }
-             self.response_history.append(interaction)
-
-             return reply
-
-         except Exception as e:
-             logger.error(f"Error generating response: {str(e)}")
-             # Fall back to the default message
-             return default_response.get(lang, default_response["en"])
-
-     def evaluate_factual_accuracy(self, response, reference):
-         """Simple evaluation of factual accuracy by keyword matching"""
-         # This is a simplified approach - in production, use more sophisticated methods
-         keywords_reference = set(re.findall(r'\b\w+\b', reference.lower()))
-         keywords_response = set(re.findall(r'\b\w+\b', response.lower()))
-
-         # Remove common stopwords (simplified approach)
-         english_stopwords = {"the", "is", "a", "an", "and", "or", "of", "to", "in", "for", "with", "by", "on", "at"}
-         arabic_stopwords = {"في", "من", "إلى", "على", "و", "هي", "هو", "عن", "مع"}
-
-         keywords_reference = {w for w in keywords_reference if w not in english_stopwords and w not in arabic_stopwords}
-         keywords_response = {w for w in keywords_response if w not in english_stopwords and w not in arabic_stopwords}
-
-         common_keywords = keywords_reference.intersection(keywords_response)
-
-         if len(keywords_reference) > 0:
-             accuracy = len(common_keywords) / len(keywords_reference)
-         else:
-             accuracy = 0
-
-         return accuracy
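In other words, the metric is keyword recall against the reference answer. A small worked example with hypothetical strings (assuming an initialized assistant; instantiation loads the models):

    assistant = Vision2030Assistant()
    score = assistant.evaluate_factual_accuracy(
        "its pillars: a vibrant society and economy",   # response
        "the key pillars are a vibrant society"         # reference
    )
    # reference keywords after stopword removal: {key, pillars, are, vibrant, society}
    # response keywords after stopword removal:  {its, pillars, vibrant, society, economy}
    print(score)  # 3 shared keywords / 5 reference keywords = 0.6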
 
-     @spaces.GPU
-     def evaluate_on_test_set(self):
-         """Evaluate the assistant on the test set"""
-         logger.info("Running evaluation on test set")
-
-         eval_results = []
-
-         for example in self.eval_data:
-             # Generate response
-             response = self.generate_response(example["question"])
-
-             # Calculate factual accuracy
-             accuracy = self.evaluate_factual_accuracy(response, example["reference_answer"])
-
-             eval_results.append({
-                 "question": example["question"],
-                 "reference": example["reference_answer"],
-                 "response": response,
-                 "factual_accuracy": accuracy
-             })
-
-             self.metrics["factual_accuracy"].append(accuracy)
-
-         # Calculate average factual accuracy
-         avg_accuracy = sum(self.metrics["factual_accuracy"]) / len(self.metrics["factual_accuracy"]) if self.metrics["factual_accuracy"] else 0
-         avg_response_time = sum(self.metrics["response_times"]) / len(self.metrics["response_times"]) if self.metrics["response_times"] else 0
-
-         results = {
-             "average_factual_accuracy": avg_accuracy,
-             "average_response_time": avg_response_time,
-             "detailed_results": eval_results
-         }
-
-         logger.info(f"Evaluation results: Factual accuracy = {avg_accuracy:.2f}, Avg response time = {avg_response_time:.2f}s")
-
-         return results
-
-     def visualize_evaluation_results(self, results):
-         """Generate visualization of evaluation results"""
-         # Create a DataFrame from the detailed results
-         df = pd.DataFrame(results["detailed_results"])
-
-         # Create the figure for visualizations
-         fig = plt.figure(figsize=(12, 8))
-
-         # Bar chart of factual accuracy by question
-         plt.subplot(2, 1, 1)
-         bars = plt.bar(range(len(df)), df["factual_accuracy"], color="skyblue")
-         plt.axhline(y=results["average_factual_accuracy"], color='r', linestyle='-',
-                     label=f"Avg: {results['average_factual_accuracy']:.2f}")
-         plt.xlabel("Question Index")
-         plt.ylabel("Factual Accuracy")
-         plt.title("Factual Accuracy by Question")
-         plt.ylim(0, 1.1)
-         plt.legend()
-
-         # Add language information
-         df["language"] = df["question"].apply(lambda x: "Arabic" if detect(x) == "ar" else "English")
-
-         # Group by language
-         lang_accuracy = df.groupby("language")["factual_accuracy"].mean()
-
-         # Bar chart of accuracy by language
-         plt.subplot(2, 1, 2)
-         lang_bars = plt.bar(lang_accuracy.index, lang_accuracy.values, color=["lightblue", "lightgreen"])
-         plt.axhline(y=results["average_factual_accuracy"], color='r', linestyle='-',
-                     label=f"Overall: {results['average_factual_accuracy']:.2f}")
-         plt.xlabel("Language")
-         plt.ylabel("Average Factual Accuracy")
-         plt.title("Factual Accuracy by Language")
-         plt.ylim(0, 1.1)
-
-         # Add value labels
-         for i, v in enumerate(lang_accuracy):
-             plt.text(i, v + 0.05, f"{v:.2f}", ha='center')
-
-         plt.tight_layout()
-         return fig
-
-     def record_user_feedback(self, user_input, response, rating, feedback_text=""):
-         """Record user feedback for a response"""
-         feedback = {
-             "timestamp": datetime.now().isoformat(),
-             "user_input": user_input,
-             "response": response,
-             "rating": rating,
-             "feedback_text": feedback_text
-         }
-
-         self.metrics["user_ratings"].append(rating)
-
-         # In a production system, store this in a database
-         logger.info(f"Recorded user feedback: rating={rating}")
-
-         return True
-
-     @spaces.GPU
-     def process_pdf(self, file):
-         """Process uploaded PDF file"""
-         if file is None:
-             return "No file uploaded. Please select a PDF file."
-
-         try:
-             logger.info("Processing uploaded file")
-
-             # Convert bytes to a file-like object
-             file_stream = io.BytesIO(file)
-
-             # Use PyPDF2 to read the file content
-             reader = PyPDF2.PdfReader(file_stream)
-
-             # Extract text from the PDF
-             full_text = ""
-             for page_num in range(len(reader.pages)):
-                 page = reader.pages[page_num]
-                 extracted_text = page.extract_text()
-                 if extracted_text:
-                     full_text += extracted_text + "\n"
-
-             if not full_text.strip():
-                 return "The uploaded PDF doesn't contain extractable text. Please try another file."
-
-             # Process the extracted text with better chunking
-             chunks = []
-             paragraphs = re.split(r'\n\s*\n', full_text)
-
-             for paragraph in paragraphs:
-                 # Skip very short paragraphs
-                 if len(paragraph.strip()) < 20:
-                     continue
-
-                 if len(paragraph) > 500:  # For very long paragraphs
-                     # Split into smaller chunks
-                     sentences = re.split(r'(?<=[.!?])\s+', paragraph)
-                     current_chunk = ""
-                     for sentence in sentences:
-                         if len(current_chunk) + len(sentence) > 300:
-                             if current_chunk:
-                                 chunks.append(current_chunk.strip())
-                             current_chunk = sentence
-                         else:
-                             current_chunk += " " + sentence if current_chunk else sentence
-
-                     if current_chunk:
-                         chunks.append(current_chunk.strip())
-                 else:
-                     chunks.append(paragraph.strip())
-
-             # Categorize text by language
-             english_chunks = []
-             arabic_chunks = []
-
-             for chunk in chunks:
-                 try:
-                     lang = detect(chunk)
-                     if lang == "ar":
-                         arabic_chunks.append(chunk)
-                     else:
-                         english_chunks.append(chunk)
-                 except:
-                     # If language detection fails, check for Arabic characters
-                     if any('\u0600' <= c <= '\u06FF' for c in chunk):
-                         arabic_chunks.append(chunk)
-                     else:
-                         english_chunks.append(chunk)
-
-             # Store PDF content
-             self.pdf_english_texts = english_chunks
-             self.pdf_arabic_texts = arabic_chunks
-
-             # Create indices for PDF content
-             self._create_pdf_indices()
-
-             logger.info(f"Successfully processed PDF: {len(arabic_chunks)} Arabic chunks, {len(english_chunks)} English chunks")
-
-             return f"✅ Successfully processed the PDF! Found {len(arabic_chunks)} Arabic and {len(english_chunks)} English text segments. PDF content will now be prioritized when answering questions."
-
-         except Exception as e:
-             logger.error(f"Error processing PDF: {str(e)}")
-             return f"❌ Error processing the PDF: {str(e)}. Please try another file."
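The chunking scheme above reads in isolation: paragraphs shorter than 20 characters are dropped, paragraphs up to 500 characters become one chunk, and longer ones are re-split on sentence boundaries into roughly 300-character pieces. A standalone sketch of the sentence-level re-splitting (illustrative only, mirroring the loop above):

    import re

    def chunk_paragraph(paragraph, limit=300):
        """Greedily pack sentences into chunks of at most ~limit characters."""
        sentences = re.split(r'(?<=[.!?])\s+', paragraph)
        chunks, current = [], ""
        for sentence in sentences:
            if len(current) + len(sentence) > limit:
                if current:                       # flush the chunk built so far
                    chunks.append(current.strip())
                current = sentence                # start a new chunk with this sentence
            else:
                current += " " + sentence if current else sentence
        if current:
            chunks.append(current.strip())
        return chunks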
 
- # Create the Gradio interface
- def create_interface():
-     # Initialize the assistant
-     assistant = Vision2030Assistant()
-
-     def chat(message, history):
-         if not message or message.strip() == "":
-             return history, ""
-
-         # Generate response
-         reply = assistant.generate_response(message)
-
-         # Update history
-         history.append((message, reply))
-
-         return history, ""
-
-     def provide_feedback(history, rating, feedback_text):
-         # Record feedback for the last conversation
-         if history and len(history) > 0:
-             last_interaction = history[-1]
-             assistant.record_user_feedback(last_interaction[0], last_interaction[1], rating, feedback_text)
-             return f"Thank you for your feedback! (Rating: {rating}/5)"
-         return "No conversation found to rate."
-
-     @spaces.GPU
-     def run_evaluation():
-         results = assistant.evaluate_on_test_set()
-
-         # Create summary text
-         summary = f"""
- Evaluation Results:
- ------------------
- Total questions evaluated: {len(results['detailed_results'])}
- Overall factual accuracy: {results['average_factual_accuracy']:.2f}
- Average response time: {results['average_response_time']:.4f} seconds
-
- Detailed Results:
- """
-
-         for i, result in enumerate(results['detailed_results']):
-             summary += f"\nQ{i+1}: {result['question']}\n"
-             summary += f"Reference: {result['reference']}\n"
-             summary += f"Response: {result['response']}\n"
-             summary += f"Accuracy: {result['factual_accuracy']:.2f}\n"
-             summary += "-" * 40 + "\n"
-
-         # Return both the results summary and visualization
-         fig = assistant.visualize_evaluation_results(results)
-
-         return summary, fig
-
-     def process_uploaded_file(file):
-         """Process the uploaded PDF file"""
-         return assistant.process_pdf(file)
-
-     # Build the Gradio layout
-     with gr.Blocks() as demo:
-         gr.Markdown("# Vision 2030 Virtual Assistant 🌟")
-         gr.Markdown("Ask questions about Saudi Arabia's Vision 2030 in both Arabic and English")
-
-         with gr.Tab("Chat"):
-             chatbot = gr.Chatbot(height=400)
-             msg = gr.Textbox(label="Your Question", placeholder="Ask about Vision 2030...")
-             with gr.Row():
-                 submit_btn = gr.Button("Submit")
-                 clear_btn = gr.Button("Clear Chat")
-
-             gr.Markdown("### Provide Feedback")
-             with gr.Row():
-                 rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, label="Rate the Response (1-5)")
-                 feedback_text = gr.Textbox(label="Additional Comments (Optional)")
-             feedback_btn = gr.Button("Submit Feedback")
-             feedback_result = gr.Textbox(label="Feedback Status")
-
-         with gr.Tab("Evaluation"):
-             evaluate_btn = gr.Button("Run Evaluation on Test Set")
-             eval_output = gr.Textbox(label="Evaluation Results", lines=20)
-             eval_chart = gr.Plot(label="Evaluation Metrics")
-
-         with gr.Tab("Upload PDF"):
-             gr.Markdown("""
-             ### Upload a Vision 2030 PDF Document
-             Upload a PDF document to enhance the assistant's knowledge base.
-             """)
-
-             with gr.Row():
-                 file_input = gr.File(
-                     label="Select PDF File",
-                     file_types=[".pdf"],
-                     type="binary"  # This is critical - use binary mode
-                 )
-
-             with gr.Row():
-                 upload_btn = gr.Button("Process PDF", variant="primary")
-
-             with gr.Row():
-                 upload_status = gr.Textbox(
-                     label="Upload Status",
-                     placeholder="Upload status will appear here...",
-                     interactive=False
-                 )
-
-             gr.Markdown("""
-             ### Notes:
-             - The PDF should contain text that can be extracted (not scanned images)
-             - After uploading, return to the Chat tab to ask questions about the uploaded content
-             """)
-
-         # Set up event handlers
-         msg.submit(chat, [msg, chatbot], [chatbot, msg])
-         submit_btn.click(chat, [msg, chatbot], [chatbot, msg])
-         clear_btn.click(lambda: [], None, chatbot)
-         feedback_btn.click(provide_feedback, [chatbot, rating, feedback_text], feedback_result)
-         evaluate_btn.click(run_evaluation, None, [eval_output, eval_chart])
-         upload_btn.click(process_uploaded_file, [file_input], [upload_status])
-
-     return demo
-
- # Launch the app
- demo = create_interface()
- demo.launch()
+ import streamlit as st
  import os
  import re
+ import torch
  import numpy as np
+ from pathlib import Path
  import PyPDF2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  from sentence_transformers import SentenceTransformer
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain.schema import Document
+ from langchain.embeddings import HuggingFaceEmbeddings
+
+ # Set page configuration
+ st.set_page_config(
+     page_title="Vision 2030 Virtual Assistant",
+     page_icon="🇸🇦",
+     layout="wide"
  )
+
+ # App title and description
+ st.title("Vision 2030 Virtual Assistant")
+ st.markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")
+
+ # Function definitions
+ @st.cache_resource
+ def load_model_and_tokenizer():
+     """Load the ALLaM-7B model and tokenizer with error handling"""
+     model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
+     st.info(f"Loading model: {model_name} (this may take a few minutes)")
+
+     try:
+         # First attempt with AutoTokenizer
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             trust_remote_code=True,
+             use_fast=False
+         )
+
+         # Load model with appropriate settings for ALLaM
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.bfloat16,
+             trust_remote_code=True,
+             device_map="auto",
+         )
+
+         st.success("Model loaded successfully!")
+
+     except Exception as e:
+         st.error(f"First loading attempt failed: {e}")
+         st.info("Trying alternative loading approach...")
+
+         # Try with a specific tokenizer class if the first attempt fails
+         from transformers import LlamaTokenizer
+
+         tokenizer = LlamaTokenizer.from_pretrained(model_name)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             torch_dtype=torch.float16,
+             trust_remote_code=True,
+             device_map="auto",
+         )
+
+         st.success("Model loaded successfully with LlamaTokenizer!")
+
+     return model, tokenizer
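Worth noting for readers new to Streamlit: @st.cache_resource memoizes the decorated function per process, so the 7B model is loaded once and reused across script reruns. A minimal sketch of that behavior (hypothetical function name, not part of the app):

    import streamlit as st

    @st.cache_resource
    def load_expensive_resource():
        print("loading...")   # runs only the first time in a given process
        return object()

    a = load_expensive_resource()
    b = load_expensive_resource()
    assert a is b             # later calls return the same cached instance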
 
+ def detect_language(text):
+     """Detect if text is primarily Arabic or English"""
+     arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
+     is_arabic = len(arabic_chars) > len(text) * 0.5
+     return "arabic" if is_arabic else "english"
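The heuristic simply counts code points in the Arabic block (U+0600 to U+06FF) and labels the text "arabic" when they make up more than half of it; spaces and digits dilute the ratio, so very short mixed strings can tip either way. For example:

    print(detect_language("ما هي رؤية السعودية 2030؟"))    # -> "arabic"
    print(detect_language("What is Saudi Vision 2030?"))   # -> "english"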
 
+ def process_pdfs():
+     """Process uploaded PDF documents"""
+     documents = []
+
+     if 'uploaded_pdfs' in st.session_state and st.session_state.uploaded_pdfs:
+         for pdf_file in st.session_state.uploaded_pdfs:
+             try:
+                 # Save the uploaded file temporarily
+                 pdf_path = f"temp_{pdf_file.name}"
+                 with open(pdf_path, "wb") as f:
+                     f.write(pdf_file.getbuffer())
+
+                 # Extract text (extract_text() can return None for image-only pages)
+                 text = ""
+                 with open(pdf_path, 'rb') as file:
+                     reader = PyPDF2.PdfReader(file)
+                     for page in reader.pages:
+                         text += (page.extract_text() or "") + "\n\n"
+
+                 # Remove the temporary file
+                 os.remove(pdf_path)
+
+                 if text.strip():  # If we got some text
+                     doc = Document(
+                         page_content=text,
+                         metadata={"source": pdf_file.name, "filename": pdf_file.name}
+                     )
+                     documents.append(doc)
+                     st.info(f"Successfully processed: {pdf_file.name}")
+                 else:
+                     st.warning(f"No text extracted from {pdf_file.name}")
+             except Exception as e:
+                 st.error(f"Error processing {pdf_file.name}: {e}")
+
+     st.success(f"Processed {len(documents)} PDF documents")
+     return documents
+ def create_vector_store(documents):
+     """Split documents into chunks and create a FAISS vector store"""
+     # Text splitter for breaking documents into chunks
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=500,
+         chunk_overlap=50,
+         separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
+     )
+
+     # Split documents into chunks
+     chunks = []
+     for doc in documents:
+         doc_chunks = text_splitter.split_text(doc.page_content)
+         # Preserve metadata for each chunk
+         chunks.extend([
+             Document(page_content=chunk, metadata=doc.metadata)
+             for chunk in doc_chunks
+         ])
+
+     st.info(f"Created {len(chunks)} chunks from {len(documents)} documents")
+
+     # Create a proper embedding function for LangChain
+     embedding_function = HuggingFaceEmbeddings(
+         model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+     )
+
+     # Create FAISS index
+     vector_store = FAISS.from_documents(
+         chunks,
+         embedding_function
+     )
+
+     return vector_store
+ def retrieve_context(query, vector_store, top_k=5):
+     """Retrieve most relevant document chunks for a given query"""
+     # Search the vector store using similarity search
+     results = vector_store.similarity_search_with_score(query, k=top_k)
+
+     # Format the retrieved contexts
+     contexts = []
+     for doc, score in results:
+         contexts.append({
+             "content": doc.page_content,
+             "source": doc.metadata.get("source", "Unknown"),
+             "relevance_score": score
+         })
+
+     return contexts
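A hypothetical smoke test of the two helpers above, run outside the Streamlit UI (the hard-coded Document and query are made up for illustration; building the store downloads the embedding model on first run):

    docs = [Document(
        page_content="Vision 2030 aims to diversify the Saudi economy beyond oil.",
        metadata={"source": "example.pdf"},
    )]
    store = create_vector_store(docs)
    for ctx in retrieve_context("What does Vision 2030 aim to do?", store, top_k=1):
        # FAISS returns an L2 distance here, so a smaller score means a closer match
        print(ctx["source"], ctx["relevance_score"], ctx["content"][:60])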
 
+ def generate_response(query, contexts, model, tokenizer):
+     """Generate a response using retrieved contexts with ALLaM-specific formatting"""
+     # Auto-detect language
+     language = detect_language(query)
+
+     # Format the prompt based on language
+     if language == "arabic":
+         instruction = (
+             "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
+             "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
+         )
+     else:  # english
+         instruction = (
+             "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
+             "If you don't know the answer, honestly say you don't know."
+         )
+
+     # Combine retrieved contexts
+     context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])
+
+     # Format the prompt for the ALLaM instruction format
+     prompt = f"""<s>[INST] {instruction}
+
+ Context:
+ {context_text}
+
+ Question: {query} [/INST]</s>"""
+
+     try:
+         with st.spinner("Generating response..."):
+             # Generate a response with appropriate parameters for ALLaM
+             inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+             outputs = model.generate(
+                 inputs.input_ids,
+                 attention_mask=inputs.attention_mask,
+                 max_new_tokens=512,
+                 temperature=0.7,
+                 top_p=0.9,
+                 do_sample=True,
+                 repetition_penalty=1.1
+             )
+
+             # Decode the response
+             full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+             # Extract just the answer part (after the instruction)
+             response = full_output.split("[/INST]")[-1].strip()
+
+             # If the response is empty for some reason, return the full output
+             if not response:
+                 response = full_output
+
+             return response, [ctx.get("source", "Unknown") for ctx in contexts]
+
+     except Exception as e:
+         st.error(f"Error during generation: {e}")
+         # Fallback response
+         return "I apologize, but I encountered an error while generating a response.", []
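One detail worth seeing in isolation: the answer extraction assumes the decoded text still contains the literal "[/INST]" marker (skip_special_tokens removes special tokens such as <s>, but "[/INST]" is plain text for most Llama-style tokenizers). A toy illustration with a made-up output string:

    full_output = "[INST] instruction... Question: What is NEOM? [/INST] NEOM is a planned smart city."
    answer = full_output.split("[/INST]")[-1].strip()
    print(answer)   # -> "NEOM is a planned smart city."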
 
+ # Initialize the app state
+ if 'conversation_history' not in st.session_state:
+     st.session_state.conversation_history = []
+
+ if 'vector_store' not in st.session_state:
+     st.session_state.vector_store = None
+
+ if 'uploaded_pdfs' not in st.session_state:
+     st.session_state.uploaded_pdfs = None
+ # PDF upload section
+ st.header("1. Upload Vision 2030 Documents")
+ uploaded_files = st.file_uploader(
+     "Upload PDF documents about Vision 2030",
+     type=["pdf"],
+     accept_multiple_files=True,
+     help="Upload one or more PDF documents containing information about Vision 2030"
+ )
+
+ if uploaded_files:
+     st.session_state.uploaded_pdfs = uploaded_files
+     if st.button("Process PDFs"):
+         documents = process_pdfs()
+         if documents:
+             with st.spinner("Creating vector database..."):
+                 st.session_state.vector_store = create_vector_store(documents)
+             st.success("Vector database created successfully!")
+ # Load the model (cached)
+ model, tokenizer = load_model_and_tokenizer()
+
+ # Chat interface
+ st.header("2. Chat with the Vision 2030 Assistant")
+
+ # Display conversation history
+ for message in st.session_state.conversation_history:
+     if message["role"] == "user":
+         st.markdown(f"**You:** {message['content']}")
+     else:
+         st.markdown(f"**Assistant:** {message['content']}")
+         if 'sources' in message and message['sources']:
+             st.markdown(f"*Sources: {', '.join([os.path.basename(src) for src in message['sources']])}*")
+     st.divider()
+
+ # Input for new question
+ user_input = st.text_input("Ask a question about Vision 2030 (in Arabic or English):", key="user_query")
+
+ # Examples
+ st.markdown("**Example questions:**")
+ examples_col1, examples_col2 = st.columns(2)
+ with examples_col1:
+     st.markdown("- What is Saudi Vision 2030?")
+     st.markdown("- What are the economic goals of Vision 2030?")
+     st.markdown("- How does Vision 2030 support women's empowerment?")
+ with examples_col2:
+     st.markdown("- ما هي رؤية السعودية 2030؟")
+     st.markdown("- ما هي الأهداف الاقتصادية لرؤية 2030؟")
+     st.markdown("- كيف تدعم رؤية 2030 تمكين المرأة السعودية؟")
+
+ # Process the user input
+ if user_input and st.session_state.vector_store:
+     # Add user message to history
+     st.session_state.conversation_history.append({"role": "user", "content": user_input})
+
+     # Get response
+     contexts = retrieve_context(user_input, st.session_state.vector_store)
+     response, sources = generate_response(user_input, contexts, model, tokenizer)
+
+     # Add assistant message to history
+     st.session_state.conversation_history.append({"role": "assistant", "content": response, "sources": sources})
+
+     # Rerun to update the UI
+     st.experimental_rerun()
+
+ elif user_input and not st.session_state.vector_store:
+     st.warning("Please upload and process Vision 2030 PDF documents first")
+
+ # Reset conversation button
+ if st.button("Reset Conversation") and len(st.session_state.conversation_history) > 0:
+     st.session_state.conversation_history = []
+     st.experimental_rerun()
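
Taken together, the new script is a straightforward RAG loop: PDFs become Documents, chunks go into a FAISS store, top-k retrieval feeds an instruction prompt, and ALLaM generates the answer. A hypothetical headless run of that loop, assuming `documents` was built as in process_pdfs() (every other name comes from the functions above):

    model, tokenizer = load_model_and_tokenizer()
    store = create_vector_store(documents)   # `documents` as produced by process_pdfs()
    question = "ما هي رؤية السعودية 2030؟"
    contexts = retrieve_context(question, store)
    response, sources = generate_response(question, contexts, model, tokenizer)
    print(response, sources)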