# Vision 2030 Virtual Assistant with Arabic (ALLaM-7B) and English (Mistral-7B-Instruct) + RAG + Improved Prompting
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langdetect import detect
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
# ----------------------------
# Load Arabic Model (ALLaM-7B)
# ----------------------------
print("Loading ALLaM-7B-Instruct-preview for Arabic...")
arabic_model_id = "ALLaM-AI/ALLaM-7B-Instruct-preview"
arabic_tokenizer = AutoTokenizer.from_pretrained(arabic_model_id)
arabic_model = AutoModelForCausalLM.from_pretrained(arabic_model_id, device_map="auto")
arabic_pipe = pipeline("text-generation", model=arabic_model, tokenizer=arabic_tokenizer)
# ----------------------------
# Load English Model (Mistral-7B-Instruct)
# ----------------------------
print("Loading Mistral-7B-Instruct-v0.2 for English...")
english_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
english_tokenizer = AutoTokenizer.from_pretrained(english_model_id)
english_model = AutoModelForCausalLM.from_pretrained(english_model_id, device_map="auto")
english_pipe = pipeline("text-generation", model=english_model, tokenizer=english_tokenizer)
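# If both 7B models do not fit in GPU memory, 4-bit quantization is one option
# (a sketch, assuming bitsandbytes is installed; not enabled in this app):
#
#   from transformers import BitsAndBytesConfig
#   bnb = BitsAndBytesConfig(load_in_4bit=True)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, device_map="auto", quantization_config=bnb)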
# ----------------------------
# Load Embedding Models for Retrieval
# ----------------------------
print("Loading Embedding Models for Retrieval...")
arabic_embedder = SentenceTransformer('CAMeL-Lab/bert-base-arabic-camelbert-ca')
english_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# ----------------------------
# Prepare FAISS Index (dummy example)
# ----------------------------
# In a real scenario, load the Vision 2030 documents, split them into chunks,
# and embed each chunk (a commented sketch follows the dummy list below).
# Here we create dummy data for demonstration.
documents = [
{"text": "Vision 2030 aims to diversify the Saudi economy.", "lang": "en"},
{"text": "رؤية 2030 تهدف إلى تنويع الاقتصاد السعودي.", "lang": "ar"}
]
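# A minimal sketch of real ingestion, assuming plain-text files in a
# hypothetical docs/ directory (the file names and the 500-character chunk
# size are illustrative, not part of the original app):
#
#   from pathlib import Path
#   documents = []
#   for path in Path("docs").glob("*.txt"):
#       text = path.read_text(encoding="utf-8")
#       lang = detect(text[:500])            # tag each file by detected language
#       for i in range(0, len(text), 500):   # naive fixed-size chunking
#           documents.append({"text": text[i:i + 500], "lang": lang})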
# Embed each document with the embedder matching its language
english_vectors = []
arabic_vectors = []
english_texts = []
arabic_texts = []
for doc in documents:
if doc["lang"] == "en":
vec = english_embedder.encode(doc["text"])
english_vectors.append(vec)
english_texts.append(doc["text"])
else:
vec = arabic_embedder.encode(doc["text"])
arabic_vectors.append(vec)
arabic_texts.append(doc["text"])
# Build one flat L2 FAISS index per language
english_index = faiss.IndexFlatL2(len(english_vectors[0]))
english_index.add(np.array(english_vectors))
arabic_index = faiss.IndexFlatL2(len(arabic_vectors[0]))
arabic_index.add(np.array(arabic_vectors))
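# Design note: IndexFlatL2 ranks by Euclidean distance. If the embeddings are
# L2-normalized first, an inner-product index gives cosine-similarity ranking
# instead, e.g.:
#
#   vecs = np.array(english_vectors)
#   faiss.normalize_L2(vecs)            # in-place unit normalization
#   english_index = faiss.IndexFlatIP(vecs.shape[1])
#   english_index.add(vecs)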
# ----------------------------
# Define the RAG response function with Improved Prompting
# ----------------------------
def retrieve_and_generate(user_input):
    try:
        lang = detect(user_input)
    except Exception:  # langdetect can fail on empty or very short input
        lang = "en"  # default fallback
if lang == "ar":
print("Detected Arabic input")
query_vec = arabic_embedder.encode(user_input)
D, I = arabic_index.search(np.array([query_vec]), k=1)
context = arabic_texts[I[0][0]] if I[0][0] >= 0 else ""
        # Improved Arabic few-shot prompt; in English it reads roughly:
        # "You are an expert on Saudi Vision 2030. Here is some important
        # information: {context}. Example - Q: What are the pillars of Vision
        # 2030? A: A vibrant society, a thriving economy, and an ambitious
        # nation. Answer the user's question clearly and accurately."
input_text = (
f"أنت خبير في رؤية السعودية 2030.\n"
f"إليك بعض المعلومات المهمة:\n{context}\n\n"
f"مثال:\n"
f"السؤال: ما هي ركائز رؤية 2030؟\n"
f"الإجابة: ركائز رؤية 2030 هي مجتمع حيوي، اقتصاد مزدهر، ووطن طموح.\n\n"
f"أجب عن سؤال المستخدم بشكل واضح ودقيق.\n"
f"السؤال: {user_input}\n"
f"الإجابة:"
)
        # return_full_text=False returns only the completion, not the prompt
        response = arabic_pipe(input_text, max_new_tokens=256, do_sample=True,
                               temperature=0.7, return_full_text=False)
        reply = response[0]['generated_text'].strip()
else:
print("Detected English input")
query_vec = english_embedder.encode(user_input)
D, I = english_index.search(np.array([query_vec]), k=1)
context = english_texts[I[0][0]] if I[0][0] >= 0 else ""
# Improved English Prompt
input_text = (
f"You are an expert on Saudi Arabia's Vision 2030.\n"
f"Here is some relevant information:\n{context}\n\n"
f"Example:\n"
f"Question: What are the key pillars of Vision 2030?\n"
f"Answer: The key pillars are a vibrant society, a thriving economy, and an ambitious nation.\n\n"
f"Answer the user's question clearly and accurately.\n"
f"Question: {user_input}\n"
f"Answer:"
)
        response = english_pipe(input_text, max_new_tokens=256, do_sample=True,
                                temperature=0.7, return_full_text=False)
        reply = response[0]['generated_text'].strip()
return reply
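# Quick smoke test outside the UI (uncomment to try; assumes models are loaded):
#   print(retrieve_and_generate("What are the main goals of Vision 2030?"))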
# ----------------------------
# Gradio UI
# ----------------------------
with gr.Blocks() as demo:
gr.Markdown("# Vision 2030 Virtual Assistant 🌍\n\nSupports Arabic & English queries about Vision 2030 (with RAG retrieval and improved prompting).")
chatbot = gr.Chatbot()
msg = gr.Textbox(label="Ask me anything about Vision 2030")
clear = gr.Button("Clear")
    def chat(message, history):
        history = history or []  # first call may pass None for an empty chat
        reply = retrieve_and_generate(message)
        history.append((message, reply))
        return history, ""  # clear the textbox after sending
msg.submit(chat, [msg, chatbot], [chatbot, msg])
clear.click(lambda: None, None, chatbot, queue=False)
# Launching the space
demo.launch()