import streamlit as st
import os
import re
import torch
import numpy as np
from pathlib import Path
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings

st.set_page_config(
    page_title="Vision 2030 Virtual Assistant",
    page_icon="🇸🇦",
    layout="wide"
)

st.title("Vision 2030 Virtual Assistant")
st.markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")

@st.cache_resource
def load_model_and_tokenizer():
    """Load the ALLaM-7B model and tokenizer with error handling"""
    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    st.info(f"Loading model: {model_name} (this may take a few minutes)")

    try:
        # First attempt: let transformers select the tokenizer/model classes automatically
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )

        st.success("Model loaded successfully!")

    except Exception as e:
        st.error(f"First loading attempt failed: {e}")
        st.info("Trying alternative loading approach...")

        # Fallback: load the tokenizer explicitly as a LlamaTokenizer and use float16 weights
        from transformers import LlamaTokenizer

        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )

        st.success("Model loaded successfully with LlamaTokenizer!")

    return model, tokenizer

def detect_language(text):
    """Detect if text is primarily Arabic or English"""
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"

def process_pdfs():
    """Process uploaded PDF documents"""
    documents = []

    if 'uploaded_pdfs' in st.session_state and st.session_state.uploaded_pdfs:
        for pdf_file in st.session_state.uploaded_pdfs:
            try:
                # Save the uploaded file to a temporary path so PyPDF2 can read it
                pdf_path = f"temp_{pdf_file.name}"
                with open(pdf_path, "wb") as f:
                    f.write(pdf_file.getbuffer())

                # Extract text page by page
                text = ""
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        # extract_text() can yield nothing for image-only pages
                        text += (page.extract_text() or "") + "\n\n"

                # Remove the temporary file once the text has been extracted
                os.remove(pdf_path)

                if text.strip():
                    doc = Document(
                        page_content=text,
                        metadata={"source": pdf_file.name, "filename": pdf_file.name}
                    )
                    documents.append(doc)
                    st.info(f"Successfully processed: {pdf_file.name}")
                else:
                    st.warning(f"No text extracted from {pdf_file.name}")
            except Exception as e:
                st.error(f"Error processing {pdf_file.name}: {e}")

    st.success(f"Processed {len(documents)} PDF documents")
    return documents

def create_vector_store(documents):
    """Split documents into chunks and create a FAISS vector store"""
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    # Split each document into chunks, preserving the source metadata
    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_text(doc.page_content)
        chunks.extend([
            Document(page_content=chunk, metadata=doc.metadata)
            for chunk in doc_chunks
        ])

    st.info(f"Created {len(chunks)} chunks from {len(documents)} documents")

    # Multilingual embeddings so Arabic and English queries map into the same vector space
    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    vector_store = FAISS.from_documents(
        chunks,
        embedding_function
    )

    return vector_store

def retrieve_context(query, vector_store, top_k=5):
    """Retrieve most relevant document chunks for a given query"""
    # similarity_search_with_score returns (document, distance) pairs;
    # with the default FAISS index, lower scores mean closer matches
    results = vector_store.similarity_search_with_score(query, k=top_k)

    contexts = []
    for doc, score in results:
        contexts.append({
            "content": doc.page_content,
            "source": doc.metadata.get("source", "Unknown"),
            "relevance_score": score
        })

    return contexts

def generate_response(query, contexts, model, tokenizer):
    """Generate a response using retrieved contexts with ALLaM-specific formatting"""
    language = detect_language(query)

    # Instruct the model in the same language as the question
    if language == "arabic":
        instruction = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
            "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
        )
    else:
        instruction = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
            "If you don't know the answer, honestly say you don't know."
        )

    # Concatenate the retrieved chunks into a single context block
    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]</s>"""

    try:
        with st.spinner("Generating response..."):
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1
            )

            full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Keep only the text generated after the instruction block
            response = full_output.split("[/INST]")[-1].strip()

            # Fall back to the full output if the split produced an empty string
            if not response:
                response = full_output

        return response, [ctx.get("source", "Unknown") for ctx in contexts]

    except Exception as e:
        st.error(f"Error during generation: {e}")
        return "I apologize, but I encountered an error while generating a response.", []

# Initialize session state
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = []

if 'vector_store' not in st.session_state:
    st.session_state.vector_store = None

if 'uploaded_pdfs' not in st.session_state:
    st.session_state.uploaded_pdfs = None

st.header("1. Upload Vision 2030 Documents") |
|
uploaded_files = st.file_uploader("Upload PDF documents about Vision 2030", |
|
type=["pdf"], |
|
accept_multiple_files=True, |
|
help="Upload one or more PDF documents containing information about Vision 2030") |
|
|
|
if uploaded_files: |
|
st.session_state.uploaded_pdfs = uploaded_files |
|
if st.button("Process PDFs"): |
|
documents = process_pdfs() |
|
if documents: |
|
with st.spinner("Creating vector database..."): |
|
st.session_state.vector_store = create_vector_store(documents) |
|
st.success("Vector database created successfully!") |
|
|
|
|
|
model, tokenizer = load_model_and_tokenizer() |
|
|
|
|
|
st.header("2. Chat with the Vision 2030 Assistant") |
|
|
|
|
|
for message in st.session_state.conversation_history: |
|
if message["role"] == "user": |
|
st.markdown(f"**You:** {message['content']}") |
|
else: |
|
st.markdown(f"**Assistant:** {message['content']}") |
|
if 'sources' in message and message['sources']: |
|
st.markdown(f"*Sources: {', '.join([os.path.basename(src) for src in message['sources']])}*") |
|
st.divider() |
|
|
|
|
|
user_input = st.text_input("Ask a question about Vision 2030 (in Arabic or English):", key="user_query") |
|
|
|
|
|
st.markdown("**Example questions:**") |
|
examples_col1, examples_col2 = st.columns(2) |
|
with examples_col1: |
|
st.markdown("- What is Saudi Vision 2030?") |
|
st.markdown("- What are the economic goals of Vision 2030?") |
|
st.markdown("- How does Vision 2030 support women's empowerment?") |
|
with examples_col2: |
|
st.markdown("- ما هي رؤية السعودية 2030؟") |
|
st.markdown("- ما هي الأهداف الاقتصادية لرؤية 2030؟") |
|
st.markdown("- كيف تدعم رؤية 2030 تمكين المرأة السعودية؟") |
|
|
|
|
|
if user_input and st.session_state.vector_store:
    # Record the user's question
    st.session_state.conversation_history.append({"role": "user", "content": user_input})

    # Retrieve relevant chunks and generate an answer
    response, sources = generate_response(user_input, retrieve_context(user_input, st.session_state.vector_store), model, tokenizer)

    st.session_state.conversation_history.append({"role": "assistant", "content": response, "sources": sources})

    # Refresh the page so the new exchange appears in the history above
    # (st.experimental_rerun() was renamed to st.rerun() in newer Streamlit releases)
    st.experimental_rerun()

elif user_input and not st.session_state.vector_store:
    st.warning("Please upload and process Vision 2030 PDF documents first")

if st.button("Reset Conversation") and len(st.session_state.conversation_history) > 0:
    st.session_state.conversation_history = []
    st.experimental_rerun()