# advanced_rag.py
import os
import tempfile
import shutil
import time
import uuid
import atexit

import PyPDF2
import streamlit as st
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA, LLMChain  # noqa: F401 -- not used in this module
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
import psutil

from blockchain_utils_metamask import BlockchainManagerMetaMask


class AdvancedRAG:
    """PDF question answering over a FAISS index, with optional MetaMask-backed
    on-chain document verification and query logging.

    Two answer modes are available through ``ask``: ``"direct"`` returns the
    top-ranked raw chunks, ``"enhanced"`` feeds them to an LLM with a
    synthesis prompt.
    """

    # Number of chunks fetched from the vector store for every query.
    _TOP_K = 5

    # Prompt for ``enhanced_retrieval``; ``{context}`` is filled with the
    # retrieved chunks by ``create_stuff_documents_chain``.
    _ENHANCED_PROMPT_TEMPLATE = """
You are an AI research assistant with expertise in analyzing and synthesizing information from documents.
Below are relevant passages from documents that might answer the user's question.

USER QUESTION: {question}

RELEVANT PASSAGES:
{context}

Based on ONLY these passages, provide a comprehensive, accurate and well-structured answer to the question.
Your answer should:
1. Directly address the user's question
2. Synthesize information from multiple passages when applicable
3. Be detailed, precise and factual
4. Include specific examples or evidence from the passages
5. Acknowledge any limitations or gaps in the provided information

If the information to answer the question is not present in the passages, clearly state:
"I don't have enough information to answer this question based on the available documents."

ANSWER:
"""

    def __init__(self, llm_model_name="mistralai/Mistral-7B-Instruct-v0.2",
                 embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
                 chunk_size=1000, chunk_overlap=200, use_gpu=True,
                 use_blockchain=False, contract_address=None):
        """
        Initialize the advanced RAG system with multiple retrieval methods.

        Args:
            llm_model_name: The HuggingFace model for text generation
            embedding_model_name: The HuggingFace model for embeddings
            chunk_size: Size of document chunks
            chunk_overlap: Overlap between chunks
            use_gpu: Whether to use GPU acceleration (ignored when CUDA is unavailable)
            use_blockchain: Whether to enable blockchain verification
            contract_address: Address of the deployed RAG Document Verifier contract
        """
        self.llm_model_name = llm_model_name
        self.embedding_model_name = embedding_model_name
        self.use_gpu = use_gpu and torch.cuda.is_available()
        self.use_blockchain = use_blockchain

        # Device selection for embeddings
        self.device = "cuda" if self.use_gpu else "cpu"
        st.sidebar.info(f"Using device: {self.device}")

        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            length_function=len,
        )

        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model_name,
            model_kwargs={"device": self.device},
        )

        self.llm = self._init_llm(llm_model_name)

        # Built by process_pdfs; None until at least one PDF has been indexed.
        self.vector_store = None
        self.documents_processed = 0

        # Per-file {"chunks", "time"} dicts plus aggregate float metrics.
        self.processing_times = {}

        # Blockchain manager; stays None when disabled or when init fails.
        self.blockchain = None
        if use_blockchain:
            try:
                self.blockchain = BlockchainManagerMetaMask(
                    contract_address=contract_address
                )
                st.sidebar.success("Blockchain manager initialized. Please connect MetaMask to continue.")
            except Exception as e:
                st.sidebar.error(f"Failed to initialize blockchain manager: {str(e)}")
                self.use_blockchain = False

    @staticmethod
    def _init_llm(llm_model_name):
        """Create the HuggingFaceHub LLM, falling back to flan-t5-small on failure.

        The ``HF_TOKEN`` environment variable is passed to both the requested
        model and the fallback.
        """
        # Read outside the try so the fallback below can reuse the same token.
        hf_token = os.environ.get("HF_TOKEN")
        if not hf_token:
            st.warning("No HuggingFace token found. Using model without authentication.")
        try:
            return HuggingFaceHub(
                repo_id=llm_model_name,
                huggingfacehub_api_token=hf_token,
                model_kwargs={"temperature": 0.7, "max_length": 1024},
            )
        except Exception as e:
            st.error(f"Error initializing LLM: {str(e)}")
            st.info("Trying to initialize with default model...")
            # Fallback to a smaller model
            return HuggingFaceHub(
                repo_id="google/flan-t5-small",
                huggingfacehub_api_token=hf_token,
                model_kwargs={"temperature": 0.7, "max_length": 512},
            )

    def update_blockchain_connection(self, metamask_info):
        """Update blockchain connection with MetaMask info.

        Args:
            metamask_info: Mapping read with the keys ``"connected"``,
                ``"address"`` and ``"network_id"``.

        Returns:
            bool: The manager's ``is_connected`` after the update, or False
            when no blockchain manager exists or ``metamask_info`` is falsy.
        """
        if self.blockchain and metamask_info:
            self.blockchain.update_connection(
                is_connected=metamask_info.get("connected", False),
                user_address=metamask_info.get("address"),
                network_id=metamask_info.get("network_id"),
            )
            return self.blockchain.is_connected
        return False

    def _blockchain_ready(self):
        """True when blockchain support is enabled, initialized and connected."""
        return bool(self.use_blockchain and self.blockchain and self.blockchain.is_connected)

    @staticmethod
    def _new_temp_dir():
        """Create this batch's upload directory and schedule its removal.

        The path is stored in ``st.session_state['temp_dir']``. The previous
        batch's directory for this session is deleted first. To avoid removing
        anything this class did not create, it is only deleted when it sits
        directly under the system temp dir (where ``mkdtemp`` puts it).
        """
        previous_dir = st.session_state.get("temp_dir")
        if previous_dir and os.path.isdir(previous_dir):
            parent = os.path.dirname(os.path.abspath(previous_dir))
            if parent == os.path.abspath(tempfile.gettempdir()):
                shutil.rmtree(previous_dir, ignore_errors=True)

        temp_dir = tempfile.mkdtemp()
        st.session_state['temp_dir'] = temp_dir
        # Ensure the uploads do not outlive the process.
        atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)
        return temp_dir

    @staticmethod
    def _process_memory_gb():
        """Resident set size of the current process, in GB."""
        return psutil.Process().memory_info().rss / (1024 ** 3)

    @staticmethod
    def _extract_pdf_text(pdf_path):
        """Return the text of every page that yields any, each followed by a blank line."""
        with open(pdf_path, "rb") as f:
            reader = PyPDF2.PdfReader(f)
            # Extract while the file is open: PdfReader reads pages lazily.
            page_texts = [page.extract_text() for page in reader.pages]
        return "".join(page_text + "\n\n" for page_text in page_texts if page_text)

    def _verify_on_blockchain(self, file_name, pdf_path, split_docs, status):
        """Verify one PDF on-chain and tag its chunks with the result.

        Errors are reported in the sidebar; they never abort PDF processing.
        """
        try:
            # Random suffix keeps IDs distinct when the same file name is re-uploaded.
            document_id = f"{file_name}_{uuid.uuid4().hex[:8]}"

            status.update(label=f"Verifying {file_name} on blockchain...")
            verification = self.blockchain.verify_document(document_id, pdf_path)

            if verification.get('status'):
                st.sidebar.success(f"✅ {file_name} verified on blockchain")
                if 'tx_hash' in verification:
                    st.sidebar.info(f"Transaction: {verification['tx_hash'][:10]}...")

                # A fresh dict per chunk so no two chunks share mutable metadata.
                for doc in split_docs:
                    doc.metadata["blockchain"] = {
                        "verified": True,
                        "document_id": document_id,
                        "document_hash": verification.get("document_hash", ""),
                        "tx_hash": verification.get("tx_hash", ""),
                        "block_number": verification.get("block_number", 0),
                    }
            else:
                st.sidebar.warning(f"❌ Failed to verify {file_name} on blockchain")
                if 'error' in verification:
                    st.sidebar.error(f"Error: {verification['error']}")
        except Exception as e:
            st.sidebar.error(f"Blockchain verification error: {str(e)}")

    def process_pdfs(self, pdf_files):
        """Process PDF files, create a vector store, and verify documents on blockchain.

        Each file is saved to a per-batch temp directory, its text is
        extracted and split into chunks, and (when MetaMask is connected) it
        is verified on-chain. A failing file is reported and skipped. All
        chunks are then embedded into a new FAISS index that replaces
        ``self.vector_store``.

        Args:
            pdf_files: Uploaded files exposing ``name`` and ``getbuffer()``
                (e.g. Streamlit ``UploadedFile`` objects).

        Returns:
            bool: True if a vector store was built, False otherwise.
        """
        all_docs = []

        with st.status("Processing PDF files...") as status:
            temp_dir = self._new_temp_dir()

            start_time = time.time()
            mem_before = self._process_memory_gb()

            for i, pdf_file in enumerate(pdf_files):
                try:
                    file_start_time = time.time()

                    # basename() stops a crafted upload name from escaping temp_dir.
                    pdf_path = os.path.join(temp_dir, os.path.basename(pdf_file.name))
                    with open(pdf_path, "wb") as f:
                        f.write(pdf_file.getbuffer())

                    status.update(label=f"Processing {pdf_file.name} ({i+1}/{len(pdf_files)})...")

                    text = self._extract_pdf_text(pdf_path)
                    if not text:
                        st.sidebar.warning(f"No extractable text found in {pdf_file.name}")

                    docs = [Document(page_content=text, metadata={"source": pdf_file.name})]
                    split_docs = self.text_splitter.split_documents(docs)
                    all_docs.extend(split_docs)

                    if self._blockchain_ready():
                        self._verify_on_blockchain(pdf_file.name, pdf_path, split_docs, status)
                    elif self.use_blockchain:
                        st.sidebar.warning("MetaMask not connected. Document not verified on blockchain.")

                    processing_time = time.time() - file_start_time
                    st.sidebar.success(f"Processed {pdf_file.name}: {len(split_docs)} chunks in {processing_time:.2f}s")
                    self.processing_times[pdf_file.name] = {
                        "chunks": len(split_docs),
                        "time": processing_time,
                    }
                except Exception as e:
                    st.sidebar.error(f"Error processing {pdf_file.name}: {str(e)}")

            if not all_docs:
                status.update(label="No content extracted from PDFs", state="error")
                return False

            status.update(label="Building vector index...")
            try:
                index_start_time = time.time()
                self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
                index_time = time.time() - index_start_time

                # Process-level RSS delta (GPU memory is not included).
                mem_used = self._process_memory_gb() - mem_before
                total_time = time.time() - start_time

                status.update(label=f"Completed processing {len(all_docs)} chunks in {total_time:.2f}s", state="complete")

                self.processing_times["index_building"] = index_time
                self.processing_times["total_time"] = total_time
                self.processing_times["memory_used_gb"] = mem_used
                self.documents_processed = len(all_docs)
                return True
            except Exception as e:
                st.error(f"Error creating vector store: {str(e)}")
                status.update(label="Error creating vector store", state="error")
                return False

    def _retrieve(self, query):
        """Return the ``_TOP_K`` chunks most similar to ``query``."""
        retriever = self.vector_store.as_retriever(search_kwargs={"k": self._TOP_K})
        return retriever.invoke(query)

    @staticmethod
    def _blockchain_source_info(doc):
        """Return the verification fields shown alongside a source, or None if untagged."""
        if "blockchain" not in doc.metadata:
            return None
        chain_meta = doc.metadata["blockchain"]
        return {
            "verified": chain_meta["verified"],
            "document_id": chain_meta["document_id"],
            "tx_hash": chain_meta["tx_hash"],
        }

    def _log_query_to_blockchain(self, query, answer):
        """Log a query/answer pair on-chain when connected.

        Returns:
            dict | None: Log details on success; None when not connected or
            when logging failed (failures are reported with ``st.error``).
        """
        if not self._blockchain_ready():
            return None
        try:
            with st.status("Logging query to blockchain..."):
                log_result = self.blockchain.log_query(query, answer)
                if log_result.get("status"):
                    return {
                        "logged": True,
                        "query_id": log_result.get("query_id", ""),
                        "tx_hash": log_result.get("tx_hash", ""),
                        "block_number": log_result.get("block_number", 0),
                    }
                st.error(f"Error logging to blockchain: {log_result.get('error', 'Unknown error')}")
        except Exception as e:
            st.error(f"Error logging to blockchain: {str(e)}")
        return None

    def direct_retrieval(self, query):
        """
        Direct retrieval method: simply returns the most relevant document chunks
        without LLM processing.

        Args:
            query: User's question

        Returns:
            dict with keys ``answer``, ``sources``, ``query_time``,
            ``blockchain_log`` and ``method`` on success. A plain str message
            is returned when no index exists yet or when retrieval fails.
        """
        if self.vector_store is None:
            return "Please upload and process PDF files first."

        try:
            query_start_time = time.time()
            with st.status("Searching documents..."):
                docs = self._retrieve(query)
            query_time = time.time() - query_start_time

            sources = []
            answer_parts = ["Here are the most relevant passages for your query:\n\n"]
            for i, doc in enumerate(docs):
                blockchain_info = self._blockchain_source_info(doc)
                source_name = doc.metadata.get("source", "Unknown")
                answer_parts.append(f"**Passage {i+1}** (from {source_name}):\n{doc.page_content}\n\n")
                sources.append({
                    "content": doc.page_content,
                    "source": source_name,
                    "blockchain": blockchain_info,
                })
            answer = "".join(answer_parts)

            blockchain_log = self._log_query_to_blockchain(query, answer)

            return {
                "answer": answer,
                "sources": sources,
                "query_time": query_time,
                "blockchain_log": blockchain_log,
                "method": "direct",
            }
        except Exception as e:
            st.error(f"Error in direct retrieval: {str(e)}")
            return f"Error: {str(e)}"

    def enhanced_retrieval(self, query):
        """
        Enhanced retrieval method: uses an LLM to process the retrieved
        documents and generate a comprehensive answer.

        Args:
            query: User's question

        Returns:
            dict with keys ``answer``, ``sources`` (content truncated to 300
            chars), ``query_time``, ``blockchain_log`` and ``method`` on
            success. A plain str message is returned when no index exists yet
            or when retrieval/generation fails.
        """
        if self.vector_store is None:
            return "Please upload and process PDF files first."

        try:
            prompt = PromptTemplate(
                template=self._ENHANCED_PROMPT_TEMPLATE,
                input_variables=["context", "question"],
            )

            # query_time covers retrieval and generation.
            query_start_time = time.time()

            # Get documents first to track sources
            with st.status("Retrieving relevant documents..."):
                docs = self._retrieve(query)

            sources = []
            for doc in docs:
                content = doc.page_content
                sources.append({
                    "content": content[:300] + "..." if len(content) > 300 else content,
                    "source": doc.metadata.get("source", "Unknown"),
                    "blockchain": self._blockchain_source_info(doc),
                })

            document_chain = create_stuff_documents_chain(self.llm, prompt)

            with st.status("Generating enhanced answer..."):
                answer = document_chain.invoke({
                    "question": query,
                    "context": docs,
                })

            query_time = time.time() - query_start_time

            blockchain_log = self._log_query_to_blockchain(query, answer)

            return {
                "answer": answer,
                "sources": sources,
                "query_time": query_time,
                "blockchain_log": blockchain_log,
                "method": "enhanced",
            }
        except Exception as e:
            st.error(f"Error in enhanced retrieval: {str(e)}")
            return f"Error: {str(e)}"

    def ask(self, query, method="enhanced"):
        """
        Ask a question using the specified retrieval method.

        Args:
            query: User's question
            method: Retrieval method ("direct" or "enhanced"); any value other
                than "direct" uses the enhanced method.

        Returns:
            dict | str: Whatever the chosen retrieval method returns.
        """
        if method == "direct":
            return self.direct_retrieval(query)
        return self.enhanced_retrieval(query)

    def get_performance_metrics(self):
        """Return performance metrics for the RAG system, or None before any processing."""
        if not self.processing_times:
            return None

        return {
            "documents_processed": self.documents_processed,
            "index_building_time": self.processing_times.get("index_building", 0),
            "total_processing_time": self.processing_times.get("total_time", 0),
            "memory_used_gb": self.processing_times.get("memory_used_gb", 0),
            "device": self.device,
            "embedding_model": self.embedding_model_name,
            "blockchain_enabled": self.use_blockchain,
            "blockchain_connected": self.blockchain.is_connected if self.blockchain else False,
        }