rahideer's picture
Update app.py
3e7f650 verified
raw
history blame
2.26 kB
import streamlit as st
from datasets import load_dataset
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# Load a multilingual dataset (xnli or tydi_qa)
def load_data():
try:
# Use a specific version of the dataset
dataset = load_dataset("xnli", "all_languages", split="validation") # Using a direct name instead of a wildcard pattern
st.write(f"Loaded {len(dataset)} examples from the 'validation' split.")
return dataset
except Exception as e:
st.write(f"Error loading 'xnli' dataset: {e}")
return None
# Initialize RAG model components
def initialize_rag():
try:
# Initialize tokenizer and retriever
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="compressed", passages_path="./path_to_data")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
return tokenizer, retriever, model
except Exception as e:
st.write(f"Error initializing RAG components: {e}")
return None, None, None
# Main function to run the app
def main():
st.title("Multilingual RAG Translator/Answer Bot")
# Load the dataset
dataset = load_data()
if dataset is None:
st.write("Dataset could not be loaded.")
return
# Initialize RAG model components
tokenizer, retriever, model = initialize_rag()
if tokenizer is None or retriever is None or model is None:
st.write("RAG components could not be initialized.")
return
# UI to input a query
query = st.text_input("Enter your question in Urdu, Hindi, or French:")
if query:
# Tokenize the input query
inputs = tokenizer(query, return_tensors="pt")
# Retrieve relevant documents
retrieved_docs = retriever.retrieve(query)
# Generate an answer using the model
generated = model.generate(input_ids=inputs['input_ids'], context_input_ids=retrieved_docs['input_ids'])
answer = tokenizer.decode(generated[0], skip_special_tokens=True)
st.write("Answer:", answer)
# Run the Streamlit app
if __name__ == "__main__":
main()