Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -13,6 +13,7 @@ import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 import PyPDF2
 import io
+import spaces  # Added for @spaces.GPU decorator
 
 # Set up logging
 logging.basicConfig(
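Note: ZeroGPU Spaces start without CUDA attached; a GPU is allocated only while a function decorated with `@spaces.GPU` is executing, which is why the rest of this diff keeps models on CPU and moves them to `cuda` inside decorated calls. A minimal sketch of the pattern (the `Linear` model and `predict` function are illustrative, not from app.py):

```python
import spaces
import torch

model = torch.nn.Linear(4, 4)  # created on CPU at startup, before any GPU exists

@spaces.GPU  # a GPU is attached only for the duration of this call
def predict(x):
    model.to('cuda')                   # move weights in on demand
    with torch.no_grad():
        y = model(x.to('cuda')).cpu()  # compute on GPU, return results on CPU
    model.to('cpu')                    # hand the GPU back for other requests
    return y
```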
@@ -42,13 +43,11 @@ class Vision2030Assistant:
         logger.info("Assistant initialized successfully")
 
     def load_embedding_models(self):
-        """Load Arabic and English embedding models."""
+        """Load Arabic and English embedding models on CPU."""
         try:
             self.arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
             self.english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
-
-            self.arabic_embedder = self.arabic_embedder.to('cuda')
-            self.english_embedder = self.english_embedder.to('cuda')
+            # Models remain on CPU; GPU usage handled in decorated functions
             logger.info("Embedding models loaded successfully")
         except Exception as e:
             logger.error(f"Failed to load embedding models: {e}")
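Keeping the embedders on CPU works because `SentenceTransformer.encode` accepts a `device` argument, so placement can be deferred to call time. A short sketch, assuming the same MiniLM checkpoint:

```python
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')  # stays on CPU
vec = embedder.encode("a thriving economy")                    # encodes on CPU
# vec = embedder.encode("a thriving economy", device='cuda')   # inside a @spaces.GPU call
```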
@@ -58,7 +57,7 @@ class Vision2030Assistant:
         """Fallback method for embedding models using a simple random vector approach."""
         logger.warning("Using fallback embedding method")
         class SimpleEmbedder:
-            def encode(self, text):
+            def encode(self, text, device=None):  # Added device parameter for compatibility
                 import hashlib
                 hash_obj = hashlib.md5(text.encode())
                 np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
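The hunk ends before the method's return statement; a self-contained version of the fallback would look like the sketch below. Seeding NumPy from an MD5 digest keeps the pseudo-random vector deterministic per input text. The 384-dimensional output is an assumption chosen to match all-MiniLM-L6-v2:

```python
import hashlib
import numpy as np

class SimpleEmbedder:
    def encode(self, text, device=None):  # device is ignored; kept for API parity
        hash_obj = hashlib.md5(text.encode())
        np.random.seed(int(hash_obj.hexdigest(), 16) % 2**32)
        return np.random.rand(384).astype(np.float32)  # assumed embedding size
```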
@@ -67,17 +66,15 @@ class Vision2030Assistant:
             self.english_embedder = SimpleEmbedder()
 
     def load_language_model(self):
-        """Load the DistilGPT-2 language model."""
+        """Load the DistilGPT-2 language model on CPU."""
         try:
             self.tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
             self.model = AutoModelForCausalLM.from_pretrained("distilgpt2")
-            if has_gpu:
-                self.model = self.model.to('cuda')
             self.generator = pipeline(
-                'text-generation',
-                model=self.model,
-                tokenizer=self.tokenizer,
-                device
+                'text-generation',
+                model=self.model,
+                tokenizer=self.tokenizer,
+                device=-1  # CPU
             )
             logger.info("Language model loaded successfully")
         except Exception as e:
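`device=-1` is the transformers pipeline convention for CPU; a non-negative integer selects that CUDA device index instead. A runnable sketch of the same CPU-only pipeline:

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=-1)
print(generator("Vision 2030 is", max_length=20)[0]['generated_text'])
```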
@@ -100,7 +97,7 @@ class Vision2030Assistant:
         self.pdf_arabic_texts = []
 
     def _create_indices(self):
-        """Create FAISS indices for the initial knowledge base."""
+        """Create FAISS indices for the initial knowledge base on CPU."""
         try:
             # English index
             english_vectors = [self.english_embedder.encode(text) for text in self.english_texts]
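The index construction used here (and again in `process_pdf` below) follows the standard FAISS IVF recipe: a flat L2 quantizer plus `nlist` inverted lists. One detail worth noting is that IVF indices must be trained before vectors can be added. A minimal sketch with stand-in vectors:

```python
import faiss
import numpy as np

vectors = np.random.rand(100, 384).astype('float32')  # stand-in embeddings
dim, nlist = vectors.shape[1], max(1, len(vectors) // 10)
quantizer = faiss.IndexFlatL2(dim)
index = faiss.IndexIVFFlat(quantizer, dim, nlist)
index.train(vectors)  # required for IVF indices before add()
index.add(vectors)
distances, ids = index.search(vectors[:1], 3)  # 3 nearest neighbours of the first vector
```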
@@ -123,21 +120,21 @@ class Vision2030Assistant:
     def _create_sample_eval_data(self):
         """Create sample evaluation data for testing factual accuracy."""
         self.eval_data = [
-            {"question": "What are the key pillars of Vision 2030?",
-             "lang": "en",
+            {"question": "What are the key pillars of Vision 2030?",
+             "lang": "en",
              "reference": "The key pillars of Vision 2030 are a vibrant society, a thriving economy, and an ambitious nation."},
-            {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
-             "lang": "ar",
+            {"question": "ما هي الركائز الرئيسية لرؤية 2030؟",
+             "lang": "ar",
              "reference": "الركائز الرئيسية لرؤية 2030 هي مجتمع حيوي، واقتصاد مزدهر، ووطن طموح."}
         ]
 
-    def retrieve_context(self, query, lang, session_id):
-        """Retrieve relevant context."""
+    def retrieve_context(self, query, lang, session_id, device='cpu'):
+        """Retrieve relevant context using the specified device for encoding."""
         try:
             history = self.session_history.get(session_id, [])
             history_context = " ".join([f"Q: {q} A: {a}" for q, a in history[-2:]])
             embedder = self.arabic_embedder if lang == "ar" else self.english_embedder
-            query_vec = embedder.encode(query)
+            query_vec = embedder.encode(query, device=device)
 
             if lang == "ar":
                 if self.has_pdf_content and self.pdf_arabic_texts:
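Language routing throughout the file relies on a single check: if any character of the query falls in the Arabic Unicode block U+0600–U+06FF, the query is treated as Arabic. Extracted as a standalone helper (hypothetical name) for illustration:

```python
def detect_lang(query: str) -> str:
    # Any character in the Arabic block marks the whole query as Arabic
    return "ar" if any('\u0600' <= c <= '\u06FF' for c in query) else "en"

print(detect_lang("What are the key pillars of Vision 2030?"))  # en
print(detect_lang("ما هي الركائز الرئيسية لرؤية 2030؟"))  # ar
```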
@@ -161,25 +158,30 @@ class Vision2030Assistant:
             logger.error(f"Retrieval error: {e}")
             return "Error retrieving context."
 
+    @spaces.GPU
     def generate_response(self, query, session_id):
-        """Generate a response."""
+        """Generate a response using GPU resources when available."""
         if not query.strip():
             return "Please enter a valid question."
-
+
         start_time = time.time()
         try:
             lang = "ar" if any('\u0600' <= c <= '\u06FF' for c in query) else "en"
-            context = self.retrieve_context(query, lang, session_id)
-
+            context = self.retrieve_context(query, lang, session_id, device='cuda')
+
             if "Error" in context or "No relevant" in context:
                 reply = context
             elif self.generator:
+                # Move the language model to GPU
+                self.generator.model.to('cuda')
                 prompt = f"Context: {context}\nQuestion: {query}\nAnswer:"
                 response = self.generator(prompt, max_length=150, num_return_sequences=1, do_sample=True, temperature=0.7)
                 reply = response[0]['generated_text'].split("Answer:")[-1].strip()
+                # Move the language model back to CPU
+                self.generator.model.to('cpu')
             else:
                 reply = context
-
+
             self.session_history.setdefault(session_id, []).append((query, reply))
             self.metrics["response_times"].append(time.time() - start_time)
             return reply
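The reply extraction relies on the prompt convention `Context: ...\nQuestion: ...\nAnswer:`: the model continues after the final "Answer:" marker, and everything following it is taken as the reply. A tiny illustration (the `extract_reply` helper is hypothetical):

```python
def extract_reply(generated_text: str) -> str:
    # Keep only what follows the last "Answer:" marker
    return generated_text.split("Answer:")[-1].strip()

sample = "Context: ...\nQuestion: ...\nAnswer: A vibrant society."
print(extract_reply(sample))  # -> A vibrant society.
```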
@@ -199,25 +201,26 @@ class Vision2030Assistant:
             logger.error(f"Evaluation error: {e}")
             return 0.0
 
+    @spaces.GPU
     def process_pdf(self, file):
-        """Process a PDF file and update the knowledge base."""
+        """Process a PDF file and update the knowledge base using GPU for encoding."""
         if not file:
             return "Please upload a PDF file."
-
+
         try:
             pdf_reader = PyPDF2.PdfReader(io.BytesIO(file))
             text = "".join([page.extract_text() or "" for page in pdf_reader.pages])
             if not text.strip():
                 return "No extractable text found in PDF."
-
+
             # Split text into chunks
             chunks = [text[i:i+300] for i in range(0, len(text), 300)]
             self.pdf_english_texts = [c for c in chunks if not any('\u0600' <= char <= '\u06FF' for char in c)]
             self.pdf_arabic_texts = [c for c in chunks if any('\u0600' <= char <= '\u06FF' for char in c)]
 
-            # Create indices for PDF content
+            # Create indices for PDF content using GPU
             if self.pdf_english_texts:
-                english_vectors = [self.english_embedder.encode(text) for text in self.pdf_english_texts]
+                english_vectors = [self.english_embedder.encode(text, device='cuda') for text in self.pdf_english_texts]
                 dim = len(english_vectors[0])
                 nlist = max(1, len(english_vectors) // 10)
                 quantizer = faiss.IndexFlatL2(dim)
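The PDF handling is fixed-width slicing plus the same per-chunk script check; a self-contained sketch (the `chunk_pdf` helper is illustrative, not part of app.py):

```python
import io
import PyPDF2

def chunk_pdf(file_bytes: bytes, size: int = 300):
    reader = PyPDF2.PdfReader(io.BytesIO(file_bytes))
    text = "".join(page.extract_text() or "" for page in reader.pages)
    chunks = [text[i:i + size] for i in range(0, len(text), size)]
    has_arabic = lambda c: any('\u0600' <= ch <= '\u06FF' for ch in c)
    english = [c for c in chunks if not has_arabic(c)]
    arabic = [c for c in chunks if has_arabic(c)]
    return english, arabic
```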
@@ -226,7 +229,7 @@ class Vision2030Assistant:
                 self.pdf_english_index.add(np.array(english_vectors))
 
             if self.pdf_arabic_texts:
-                arabic_vectors = [self.arabic_embedder.encode(text) for text in self.pdf_arabic_texts]
+                arabic_vectors = [self.arabic_embedder.encode(text, device='cuda') for text in self.pdf_arabic_texts]
                 dim = len(arabic_vectors[0])
                 nlist = max(1, len(arabic_vectors) // 10)
                 quantizer = faiss.IndexFlatL2(dim)