#!/usr/bin/env python # coding: utf-8 # # Retrieval-Augmented QA Demo # # This notebook builds a minimal RAG (Retrieval-Augmented Generation) pipeline with enhancements: # # - Slimmed & deduplicated corpora # - Chunking long passages # - Persistent FAISS index & embeddings # - Distance threshold to avoid hallucinations # - Context-length control # - Polished Gradio interface with separate contexts panel # ## 1. Configuration & Imports # # We detect device, print settings, and support loading saved index. # In[2]: import os import pickle from datasets import load_dataset from sentence_transformers import SentenceTransformer, CrossEncoder import faiss import numpy as np import torch from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline from transformers import AutoTokenizer as _AutoTokenizer import gradio as gr import evaluate # Settings data_dir = os.path.join(os.getcwd(), "data") os.makedirs(data_dir, exist_ok=True) INDEX_PATH = os.path.join(data_dir, "faiss_index.faiss") EMB_PATH = os.path.join(data_dir, "embeddings.npy") PCTX_PATH = os.path.join(data_dir, "passages.pkl") MODEL_NAME = os.getenv("MODEL_NAME", "google/flan-t5-small") EMBEDDER_MODEL = os.getenv("EMBEDDER_MODEL", "sentence-transformers/all-MiniLM-L6-v2") device = 0 if torch.cuda.is_available() else -1 print(f"Using model: {MODEL_NAME}, embedder: {EMBEDDER_MODEL}, device: {'GPU' if device==0 else 'CPU'}") # Threshold for maximum acceptable L2 distance dist_threshold = 1.0 # tune as needed # Max words per context snippet max_context_words = 200 # ## Useful functions def make_context_snippets(contexts, max_words=200): snippets = [] for c in contexts: words = c.split() if len(words) > max_words: c = " ".join(words[:max_words]) + " ... [truncated]" snippets.append(c) return snippets # ## 2. Load, Deduplicate & Chunk Corpora # # For this demo we sample small slices and remove duplicates. We also chunk any passage >512 tokens. # # tokenizer for chunking chunk_tokenizer = _AutoTokenizer.from_pretrained(MODEL_NAME) max_tokens = chunk_tokenizer.model_max_length def chunk_text(text: str, max_tokens: int, stride: int = None) -> list[str]: """ Split `text` into overlapping chunks of up to max_tokens words. By default uses 25% overlap (stride = max_tokens // 4). """ words = text.split() if stride is None: stride = max_tokens // 4 # 25% overlap chunks = [] start = 0 while start < len(words): end = start + max_tokens chunk = " ".join(words[start:end]) chunks.append(chunk) # advance by stride, not full window start += stride return chunks # Load corpora wiki_ds = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus", split="passages") wiki_passages = wiki_ds["passage"] squad_ds = load_dataset("rajpurkar/squad_v2", split="train[:100]") squad_passages = [ex["context"] for ex in squad_ds] trivia_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split="validation[:100]") trivia_passages = [] for ex in trivia_ds: for field in ("wiki_context", "search_context"): txt = ex.get(field) or "" if txt: trivia_passages.append(txt) # Combine, dedupe, chunk all_passages = wiki_passages + squad_passages + trivia_passages unique_passages = list(dict.fromkeys(all_passages)) passages = [] for p in unique_passages: # count tokens without encoding to avoid warnings tokens = chunk_tokenizer.tokenize(p) if len(tokens) > max_tokens: passages.extend(chunk_text(p, max_tokens)) else: passages.append(p) print(f"Total passages after dedupe & chunk: {len(passages)}") # Persist raw passages list with open(PCTX_PATH, "wb") as f: pickle.dump(passages, f) # ## 3. Build or Load FAISS Index & Embeddings # # We save embeddings & index to disk to skip slow re-encoding. # ── Initialize embedder and reranker ── from sentence_transformers import SentenceTransformer from torch import no_grad embedder = SentenceTransformer(EMBEDDER_MODEL) reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") # ── Load or (re)build FAISS index with cosine similarity ── if os.path.exists(INDEX_PATH) and os.path.exists(EMB_PATH): print("Loading saved index and embeddings…") index = faiss.read_index(INDEX_PATH) embeddings = np.load(EMB_PATH) else: print("Encoding passages (with overlap)…") embeddings = embedder.encode( passages, show_progress_bar=True, convert_to_numpy=True, batch_size=32 ) # Normalize to unit length so that inner‐product = cosine sim embeddings = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True) # Build a FAISS index over inner‐product (cosine) space dim = embeddings.shape[1] index = faiss.IndexFlatIP(dim) index.add(embeddings) # Persist to disk for faster reload faiss.write_index(index, INDEX_PATH) np.save(EMB_PATH, embeddings) print(f"Indexed {index.ntotal} vectors.") # ## 4. Load QA Model & Pipeline tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME) qa_pipeline = pipeline( "text2text-generation", model=model, tokenizer=tokenizer, device=device, early_stopping=True ) print("QA pipeline ready.") # ## 5. Retrieval + Generation Functions # # We bail out early if top distance > threshold to avoid hallucination. def retrieve(question: str, k: int = 20, rerank_k: int = 5): # 1) dense‐search top k q_emb = embedder.encode([question], convert_to_numpy=True) distances, indices = index.search(q_emb, k) # 2) pull out those k contexts candidates = [passages[i] for i in indices[0]] # 3) score with cross‐encoder pairs = [[question, ctx] for ctx in candidates] scores = reranker.predict(pairs) # 4) pick top rerank_k top_idxs = np.argsort(scores)[-rerank_k:][::-1] final_ctxs = [candidates[i] for i in top_idxs] final_dist = [distances[0][i] for i in top_idxs] return final_ctxs, final_dist def generate(question: str, contexts: list) -> str: """ Build a RAG prompt from the retrieved contexts and generate an answer using the HF text2text pipeline. """ # 1) Turn each context into a truncated snippet snippet_lines = [ f"Context {i+1}: {s}" for i, s in enumerate(make_context_snippets(contexts, max_context_words)) ] # 2) Build the full prompt prompt = ( "You are a helpful assistant. Use ONLY the following contexts to answer. " "If the answer is not contained, say 'Sorry, I don't know.'\n\n" + "\n".join(snippet_lines) + f"\n\nQuestion: {question}\nAnswer:" ) # 3) Call the pipeline (it handles tokenization + generation + decoding) result = qa_pipeline(prompt, truncation=True, max_new_tokens=200)[0]["generated_text"] return result.strip() def retrieve_and_answer(question, k=5): contexts, distances = retrieve(question, k=20) if not contexts or distances[0] > dist_threshold: return "Sorry, I don't know.", [] ans = generate(question, contexts) return ans, contexts import random print("Some sample passages:\n") for p in random.sample(passages, 5): print(p, "\n" + "-"*80 + "\n") # ## 6. Gradio Demo Interface # # Separate panels for answer and contexts. def answer_and_contexts(question: str): """ Full end-to-end: retrieve, threshold-check, generate answer, and return both the answer and a formatted string of contexts. """ answer, contexts = retrieve_and_answer(question) # If no valid contexts, just return the apology if not contexts: return answer, "" # Otherwise format each snippet for display ctx_snippets = [ f"Context {i+1}: {s}" for i, s in enumerate(make_context_snippets(contexts, max_context_words)) ] return answer, "\n\n---\n\n".join(ctx_snippets) iface = gr.Interface( fn=answer_and_contexts, inputs=gr.Textbox(lines=1, placeholder="Enter your question here...", label="Question"), outputs=[ gr.Textbox(label="Answer"), gr.Textbox(label="Retrieved Contexts") ], title="🔍 RAG QA Demo", description="Retrieval-Augmented QA with distance threshold and context preview" ) iface.launch() # # Test the Model # load SQuAD v2 (we only need validation split) squad = load_dataset("rajpurkar/squad_v2", split="validation") # load the SQuAD metric (handles no-answer properly) squad_metric = evaluate.load("squad") def retrieval_recall(dataset, k=20, num_samples=100): hits = 0 for ex in dataset.select(range(num_samples)): question = ex["question"] gold_answers = ex["answers"]["text"] # list, empty if unanswerable # get your top-k contexts ctxs, _ = retrieve(question, k=k, rerank_k=k) # or rerank_k smaller # check if any gold answer appears in any context if any(any(ans in ctx for ctx in ctxs) for ans in gold_answers): hits += 1 recall = hits / num_samples print(f"Retrieval Recall@{k}: {recall:.3f}") return recall # ## Only answerable Questions def retrieval_recall_answerable(dataset, k=20, num_samples=100): hits = 0 total = 0 for ex in dataset.select(range(num_samples)): if not ex["answers"]["text"]: continue # skip unanswerable total += 1 ctxs, _ = retrieve(ex["question"], k=k, rerank_k=k) if any(any(ans in ctx for ctx in ctxs) for ans in ex["answers"]["text"]): hits += 1 recall = hits / total print(f"Retrieval Recall@{k} on answerable only: {recall:.3f} ({hits}/{total})") return recall def qa_eval_all(dataset, num_samples=100, k=20): preds, refs = [], [] for ex in dataset.select(range(num_samples)): qid = ex["id"] gold = ex["answers"] # ensure metric has something to iterate over if not gold["text"]: gold = {"text":[""], "answer_start":[0]} ans, _ = retrieve_and_answer(ex["question"], k=k) # for metric purposes, treat our refusal as empty string pred_text = "" if ans.strip().lower().startswith("sorry") else ans preds.append({"id": qid, "prediction_text": pred_text}) refs.append({"id": qid, "answers": gold}) results = squad_metric.compute(predictions=preds, references=refs) print(f"Full QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}") return results def qa_eval_answerable(dataset, num_samples=100, k=20): preds, refs = [], [] for ex in dataset.select(range(num_samples)): if not ex["answers"]["text"]: continue # skip unanswerable qid = ex["id"] ans, _ = retrieve_and_answer(ex["question"], k=k) preds.append({"id": qid, "prediction_text": ans}) refs.append({"id": qid, "answers": ex["answers"]}) results = squad_metric.compute(predictions=preds, references=refs) print(f"Answerable-only QA EM: {results['exact_match']:.2f}, F1: {results['f1']:.2f}") return results retrieval_recall(squad, k=2, num_samples=100) retrieval_recall_answerable(squad, k=2, num_samples=100) qa_eval_all(squad, num_samples=100, k=2) qa_eval_answerable(squad, num_samples=100, k=2)