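"""Vision 2030 Virtual Assistant.

A Streamlit app that answers Arabic and English questions about Saudi Vision 2030
using retrieval-augmented generation: uploaded PDFs are chunked, embedded into a
FAISS index, and the most relevant passages are passed to the ALLaM-7B-Instruct
model to generate an answer.
"""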
import streamlit as st
import os
import re
import torch
import PyPDF2
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from langchain_community.embeddings import HuggingFaceEmbeddings
# Set page configuration
st.set_page_config(
    page_title="Vision 2030 Virtual Assistant",
    page_icon="🇸🇦",
    layout="wide"
)

# App title and description
st.title("Vision 2030 Virtual Assistant")
st.markdown("Ask questions about Saudi Vision 2030 goals, projects, and progress in Arabic or English.")
# Function definitions
@st.cache_resource
def load_model_and_tokenizer():
    """Load the ALLaM-7B model and tokenizer once per session, with error handling."""
    model_name = "ALLaM-AI/ALLaM-7B-Instruct-preview"
    st.info(f"Loading model: {model_name} (this may take a few minutes)")

    try:
        # First attempt with AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=True,
            use_fast=False
        )

        # Load model with appropriate settings for ALLaM
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto",
        )
        st.success("Model loaded successfully!")
    except Exception as e:
        st.error(f"First loading attempt failed: {e}")
        st.info("Trying alternative loading approach...")

        # Fall back to the explicit LlamaTokenizer class if AutoTokenizer fails
        from transformers import LlamaTokenizer
        tokenizer = LlamaTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map="auto",
        )
        st.success("Model loaded successfully with LlamaTokenizer!")

    return model, tokenizer
def detect_language(text):
    """Detect whether text is primarily Arabic or English."""
    arabic_chars = re.findall(r'[\u0600-\u06FF]', text)
    is_arabic = len(arabic_chars) > len(text) * 0.5
    return "arabic" if is_arabic else "english"
def process_pdfs():
    """Process uploaded PDF documents into LangChain Document objects."""
    documents = []

    if 'uploaded_pdfs' in st.session_state and st.session_state.uploaded_pdfs:
        for pdf_file in st.session_state.uploaded_pdfs:
            try:
                # Save the uploaded file temporarily
                pdf_path = f"temp_{pdf_file.name}"
                with open(pdf_path, "wb") as f:
                    f.write(pdf_file.getbuffer())

                # Extract text (extract_text() can return None or "" for image-only pages)
                text = ""
                with open(pdf_path, 'rb') as file:
                    reader = PyPDF2.PdfReader(file)
                    for page in reader.pages:
                        text += (page.extract_text() or "") + "\n\n"

                # Remove temporary file
                os.remove(pdf_path)

                if text.strip():  # If we got some text
                    doc = Document(
                        page_content=text,
                        metadata={"source": pdf_file.name, "filename": pdf_file.name}
                    )
                    documents.append(doc)
                    st.info(f"Successfully processed: {pdf_file.name}")
                else:
                    st.warning(f"No text extracted from {pdf_file.name}")
            except Exception as e:
                st.error(f"Error processing {pdf_file.name}: {e}")

    st.success(f"Processed {len(documents)} PDF documents")
    return documents
def create_vector_store(documents):
    """Split documents into chunks and create a FAISS vector store."""
    # Text splitter for breaking documents into chunks
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50,
        separators=["\n\n", "\n", ".", "!", "?", ",", " ", ""]
    )

    # Split documents into chunks, preserving each document's metadata
    chunks = []
    for doc in documents:
        doc_chunks = text_splitter.split_text(doc.page_content)
        chunks.extend([
            Document(page_content=chunk, metadata=doc.metadata)
            for chunk in doc_chunks
        ])

    st.info(f"Created {len(chunks)} chunks from {len(documents)} documents")

    # Multilingual embedding model so Arabic and English queries share one index
    embedding_function = HuggingFaceEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )

    # Create FAISS index
    vector_store = FAISS.from_documents(chunks, embedding_function)

    return vector_store
def retrieve_context(query, vector_store, top_k=5):
    """Retrieve the most relevant document chunks for a given query."""
    # Similarity search; FAISS returns a distance, so lower scores mean closer matches
    results = vector_store.similarity_search_with_score(query, k=top_k)

    # Format the retrieved contexts
    contexts = []
    for doc, score in results:
        contexts.append({
            "content": doc.page_content,
            "source": doc.metadata.get("source", "Unknown"),
            "relevance_score": score
        })

    return contexts
def generate_response(query, contexts, model, tokenizer):
    """Generate a response from the retrieved contexts using ALLaM-specific prompt formatting."""
    # Auto-detect the query language
    language = detect_language(query)

    # Format the instruction based on language
    if language == "arabic":
        instruction = (
            "أنت مساعد افتراضي يهتم برؤية السعودية 2030. استخدم المعلومات التالية للإجابة على السؤال. "
            "إذا لم تعرف الإجابة، فقل بأمانة إنك لا تعرف."
        )
    else:  # english
        instruction = (
            "You are a virtual assistant for Saudi Vision 2030. Use the following information to answer the question. "
            "If you don't know the answer, honestly say you don't know."
        )

    # Combine retrieved contexts
    context_text = "\n\n".join([f"Document: {ctx['content']}" for ctx in contexts])

    # Llama-style [INST] prompt; the prompt is left open (no closing </s>) so the model
    # continues with the answer after [/INST]
    prompt = f"""<s>[INST] {instruction}

Context:
{context_text}

Question: {query} [/INST]"""

    try:
        with st.spinner("Generating response..."):
            inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

            # Generate with appropriate sampling parameters for ALLaM
            outputs = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                repetition_penalty=1.1
            )

            # Decode and keep only the answer part (after [/INST])
            full_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
            response = full_output.split("[/INST]")[-1].strip()

            # If the response is empty for some reason, fall back to the full output
            if not response:
                response = full_output

            return response, [ctx.get("source", "Unknown") for ctx in contexts]
    except Exception as e:
        st.error(f"Error during generation: {e}")
        # Fallback response
        return "I apologize, but I encountered an error while generating a response.", []
# Initialize the app state
if 'conversation_history' not in st.session_state:
    st.session_state.conversation_history = []

if 'vector_store' not in st.session_state:
    st.session_state.vector_store = None

if 'uploaded_pdfs' not in st.session_state:
    st.session_state.uploaded_pdfs = None
# PDF upload section
st.header("1. Upload Vision 2030 Documents")
uploaded_files = st.file_uploader(
    "Upload PDF documents about Vision 2030",
    type=["pdf"],
    accept_multiple_files=True,
    help="Upload one or more PDF documents containing information about Vision 2030"
)

if uploaded_files:
    st.session_state.uploaded_pdfs = uploaded_files
    if st.button("Process PDFs"):
        documents = process_pdfs()
        if documents:
            with st.spinner("Creating vector database..."):
                st.session_state.vector_store = create_vector_store(documents)
            st.success("Vector database created successfully!")
# Load the model (cached)
model, tokenizer = load_model_and_tokenizer()
# Chat interface
st.header("2. Chat with the Vision 2030 Assistant")

# Display conversation history
for message in st.session_state.conversation_history:
    if message["role"] == "user":
        st.markdown(f"**You:** {message['content']}")
    else:
        st.markdown(f"**Assistant:** {message['content']}")
        if 'sources' in message and message['sources']:
            st.markdown(f"*Sources: {', '.join([os.path.basename(src) for src in message['sources']])}*")

st.divider()
# Input for new question
user_input = st.text_input("Ask a question about Vision 2030 (in Arabic or English):", key="user_query")

# Examples
st.markdown("**Example questions:**")
examples_col1, examples_col2 = st.columns(2)
with examples_col1:
    st.markdown("- What is Saudi Vision 2030?")
    st.markdown("- What are the economic goals of Vision 2030?")
    st.markdown("- How does Vision 2030 support women's empowerment?")
with examples_col2:
    st.markdown("- ما هي رؤية السعودية 2030؟")
    st.markdown("- ما هي الأهداف الاقتصادية لرؤية 2030؟")
    st.markdown("- كيف تدعم رؤية 2030 تمكين المرأة السعودية؟")
# Process the user input
if user_input and st.session_state.vector_store:
    # text_input keeps its value across reruns, so skip queries that were already answered
    if st.session_state.get("last_processed_query") != user_input:
        # Add user message to history
        st.session_state.conversation_history.append({"role": "user", "content": user_input})

        # Retrieve context and generate a response
        contexts = retrieve_context(user_input, st.session_state.vector_store)
        response, sources = generate_response(user_input, contexts, model, tokenizer)

        # Add assistant message to history and remember the handled query
        st.session_state.conversation_history.append({"role": "assistant", "content": response, "sources": sources})
        st.session_state.last_processed_query = user_input

        # Rerun to update the UI
        st.rerun()
elif user_input and not st.session_state.vector_store:
    st.warning("Please upload and process Vision 2030 PDF documents first")
# Reset conversation button
if st.button("Reset Conversation") and len(st.session_state.conversation_history) > 0:
    st.session_state.conversation_history = []
    st.session_state.last_processed_query = None
    st.rerun()