abdull4h committed on
Commit
d04e4d9
·
verified ·
1 Parent(s): 9df1e5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -31
app.py CHANGED
@@ -13,6 +13,7 @@ import torch
13
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  import PyPDF2
15
  import io
 
16
 
17
  # Set up logging
18
  logging.basicConfig(
@@ -42,13 +43,11 @@ class Vision2030Assistant:
42
  logger.info("Assistant initialized successfully")
43
 
44
  def load_embedding_models(self):
45
- """Load Arabic and English embedding models with fallback mechanism."""
46
  try:
47
  self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
48
  self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
49
- if has_gpu:
50
- self.arabic_embedder = self.arabic_embedder.to('cuda')
51
- self.english_embedder = self.english_embedder.to('cuda')
52
  logger.info("Embedding models loaded successfully")
53
  except Exception as e:
54
  logger.error(f"Failed to load embedding models: {e}")
@@ -58,7 +57,7 @@ class Vision2030Assistant:
58
  """Fallback method for embedding models using a simple random vector approach."""
59
  logger.warning("Using fallback embedding method")
60
  class SimpleEmbedder:
61
- def encode(self, text):
62
  import hashlib
63
  hash_obj = hashlib.md5(text.encode())
64
  np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
@@ -67,17 +66,15 @@ class Vision2030Assistant:
67
  self.english_embedder = SimpleEmbedder()
68
 
69
  def load_language_model(self):
70
- """Load the DistilGPT-2 language model for response generation."""
71
  try:
72
  self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
73
  self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
74
- if has_gpu:
75
- self.model = self.model.to('cuda')
76
  self.generator = pipeline(
77
- 'text-generation',
78
- model=self.model,
79
- tokenizer=self.tokenizer,
80
- device=0 if has_gpu else -1
81
  )
82
  logger.info("Language model loaded successfully")
83
  except Exception as e:
@@ -100,7 +97,7 @@ class Vision2030Assistant:
100
  self.pdf_arabic_texts = []
101
 
102
  def _create_indices(self):
103
- """Create FAISS indices for the initial knowledge base."""
104
  try:
105
  # English index
106
  english_vectors = [self.english_embedder.encode(text) for text in self.english_texts]
@@ -123,21 +120,21 @@ class Vision2030Assistant:
123
  def _create_sample_eval_data(self):
124
  """Create sample evaluation data for testing factual accuracy."""
125
  self.eval_data = [
126
- {"question": "What are the key pillars of Vision 2030?",
127
- "lang": "en",
128
  "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
129
- {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
130
- "lang": "ar",
131
  "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."}
132
  ]
133
 
134
- def retrieve_context(self, query, lang, session_id):
135
- """Retrieve relevant context based on the query and session history."""
136
  try:
137
  history = self.session_history.get(session_id, [])
138
  history_context = " ".join([f"Q: {q} A: {a}" for q, a in history[-2:]])
139
  embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
140
- query_vec = embedder.encode(query)
141
 
142
  if lang == "ar":
143
  if self.has_pdf_content and self.pdf_arabic_texts:
@@ -161,25 +158,30 @@ class Vision2030Assistant:
161
  logger.error(f"Retrieval error: {e}")
162
  return "Error retrieving context."
163
 
 
164
  def generate_response(self, query, session_id):
165
- """Generate a response to the user's query using context and session history."""
166
  if not query.strip():
167
  return "Please enter a valid question."
168
-
169
  start_time = time.time()
170
  try:
171
  lang = "ar" if any('\u0600' <= c <= '\u06FF' for c in query) else "en"
172
- context = self.retrieve_context(query, lang, session_id)
173
-
174
  if "Error" in context or "No relevant" in context:
175
  reply = context
176
  elif self.generator:
 
 
177
  prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
178
  response = self.generator(prompt, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
179
  reply = response[0]['generated_text'].split("Answer:")[-1].strip()
 
 
180
  else:
181
  reply = context
182
-
183
  self.session_history.setdefault(session_id, []).append((query, reply))
184
  self.metrics["response_times"].append(time.time() - start_time)
185
  return reply
@@ -199,25 +201,26 @@ class Vision2030Assistant:
199
  logger.error(f"Evaluation error: {e}")
200
  return 0.0
201
 
 
202
  def process_pdf(self, file):
203
- """Process an uploaded PDF file and update the knowledge base."""
204
  if not file:
205
  return "Please upload a PDF file."
206
-
207
  try:
208
  pdf_reader = PyPDF2.PdfReader(io.BytesIO(file))
209
  text = "".join([page.extract_text() or "" for page in pdf_reader.pages])
210
  if not text.strip():
211
  return "No extractable text found in PDF."
212
-
213
  # Split text into chunks
214
  chunks = [text[i:i+300] for i in range(0, len(text), 300)]
215
  self.pdf_english_texts = [c for c in chunks if not any('\u0600' <= char <= '\u06FF' for char in c)]
216
  self.pdf_arabic_texts = [c for c in chunks if any('\u0600' <= char <= '\u06FF' for char in c)]
217
 
218
- # Create indices for PDF content
219
  if self.pdf_english_texts:
220
- english_vectors = [self.english_embedder.encode(text) for text in self.pdf_english_texts]
221
  dim = len(english_vectors[0])
222
  nlist = max(1, len(english_vectors) // 10)
223
  quantizer = faiss.IndexFlatL2(dim)
@@ -226,7 +229,7 @@ class Vision2030Assistant:
226
  self.pdf_english_index.add(np.array(english_vectors))
227
 
228
  if self.pdf_arabic_texts:
229
- arabic_vectors = [self.arabic_embedder.encode(text) for text in self.pdf_arabic_texts]
230
  dim = len(arabic_vectors[0])
231
  nlist = max(1, len(arabic_vectors) // 10)
232
  quantizer = faiss.IndexFlatL2(dim)
 
13
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
14
  import PyPDF2
15
  import io
16
+ import spaces # Added for @spaces.GPU decorator
17
 
18
  # Set up logging
19
  logging.basicConfig(
 
43
  logger.info("Assistant initialized successfully")
44
 
45
  def load_embedding_models(self):
46
+ """Load Arabic and English embedding models on CPU."""
47
  try:
48
  self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
49
  self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
50
+ # Models remain on CPU; GPU usage handled in decorated functions
 
 
51
  logger.info("Embedding models loaded successfully")
52
  except Exception as e:
53
  logger.error(f"Failed to load embedding models: {e}")
 
57
  """Fallback method for embedding models using a simple random vector approach."""
58
  logger.warning("Using fallback embedding method")
59
  class SimpleEmbedder:
60
+ def encode(self, text, device=None): # Added device parameter for compatibility
61
  import hashlib
62
  hash_obj = hashlib.md5(text.encode())
63
  np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
 
66
  self.english_embedder = SimpleEmbedder()
67
 
68
  def load_language_model(self):
69
+ """Load the DistilGPT-2 language model on CPU."""
70
  try:
71
  self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
72
  self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
 
 
73
  self.generator = pipeline(
74
+ 'text-generation',
75
+ model=self.model,
76
+ tokenizer=self.tokenizer,
77
+ device=-1 # CPU
78
  )
79
  logger.info("Language model loaded successfully")
80
  except Exception as e:
 
97
  self.pdf_arabic_texts = []
98
 
99
  def _create_indices(self):
100
+ """Create FAISS indices for the initial knowledge base on CPU."""
101
  try:
102
  # English index
103
  english_vectors = [self.english_embedder.encode(text) for text in self.english_texts]
 
120
  def _create_sample_eval_data(self):
121
  """Create sample evaluation data for testing factual accuracy."""
122
  self.eval_data = [
123
+ {"question": "What are the key pillars of Vision 2030?",
124
+ "lang": "en",
125
  "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
126
+ {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
127
+ "lang": "ar",
128
  "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."}
129
  ]
130
 
131
+ def retrieve_context(self, query, lang, session_id, device='cpu'):
132
+ """Retrieve relevant context using the specified device for encoding."""
133
  try:
134
  history = self.session_history.get(session_id, [])
135
  history_context = " ".join([f"Q: {q} A: {a}" for q, a in history[-2:]])
136
  embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
137
+ query_vec = embedder.encode(query, device=device)
138
 
139
  if lang == "ar":
140
  if self.has_pdf_content and self.pdf_arabic_texts:
 
158
  logger.error(f"Retrieval error: {e}")
159
  return "Error retrieving context."
160
 
161
+ @spaces.GPU
162
  def generate_response(self, query, session_id):
163
+ """Generate a response using GPU resources when available."""
164
  if not query.strip():
165
  return "Please enter a valid question."
166
+
167
  start_time = time.time()
168
  try:
169
  lang = "ar" if any('\u0600' <= c <= '\u06FF' for c in query) else "en"
170
+ context = self.retrieve_context(query, lang, session_id, device='cuda')
171
+
172
  if "Error" in context or "No relevant" in context:
173
  reply = context
174
  elif self.generator:
175
+ # Move the language model to GPU
176
+ self.generator.model.to('cuda')
177
  prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
178
  response = self.generator(prompt, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
179
  reply = response[0]['generated_text'].split("Answer:")[-1].strip()
180
+ # Move the language model back to CPU
181
+ self.generator.model.to('cpu')
182
  else:
183
  reply = context
184
+
185
  self.session_history.setdefault(session_id, []).append((query, reply))
186
  self.metrics["response_times"].append(time.time() - start_time)
187
  return reply
 
201
  logger.error(f"Evaluation error: {e}")
202
  return 0.0
203
 
204
+ @spaces.GPU
205
  def process_pdf(self, file):
206
+ """Process a PDF file and update the knowledge base using GPU for encoding."""
207
  if not file:
208
  return "Please upload a PDF file."
209
+
210
  try:
211
  pdf_reader = PyPDF2.PdfReader(io.BytesIO(file))
212
  text = "".join([page.extract_text() or "" for page in pdf_reader.pages])
213
  if not text.strip():
214
  return "No extractable text found in PDF."
215
+
216
  # Split text into chunks
217
  chunks = [text[i:i+300] for i in range(0, len(text), 300)]
218
  self.pdf_english_texts = [c for c in chunks if not any('\u0600' <= char <= '\u06FF' for char in c)]
219
  self.pdf_arabic_texts = [c for c in chunks if any('\u0600' <= char <= '\u06FF' for char in c)]
220
 
221
+ # Create indices for PDF content using GPU
222
  if self.pdf_english_texts:
223
+ english_vectors = [self.english_embedder.encode(text, device='cuda') for text in self.pdf_english_texts]
224
  dim = len(english_vectors[0])
225
  nlist = max(1, len(english_vectors) // 10)
226
  quantizer = faiss.IndexFlatL2(dim)
 
229
  self.pdf_english_index.add(np.array(english_vectors))
230
 
231
  if self.pdf_arabic_texts:
232
+ arabic_vectors = [self.arabic_embedder.encode(text, device='cuda') for text in self.pdf_arabic_texts]
233
  dim = len(arabic_vectors[0])
234
  nlist = max(1, len(arabic_vectors) // 10)
235
  quantizer = faiss.IndexFlatL2(dim)