Update app.py
Browse files
app.py
CHANGED
@@ -1,357 +1,12 @@
|
|
1 |
-
#
|
2 |
import os
|
3 |
-
import tempfile
|
4 |
import shutil
|
5 |
-
import PyPDF2
|
6 |
import streamlit as st
|
7 |
import torch
|
8 |
-
from langchain_huggingface import HuggingFaceEmbeddings
|
9 |
-
from langchain_community.llms import HuggingFaceHub
|
10 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
11 |
-
from langchain_community.vectorstores import FAISS
|
12 |
-
from langchain.chains import RetrievalQA
|
13 |
-
from langchain.docstore.document import Document
|
14 |
-
from langchain.prompts import PromptTemplate
|
15 |
-
import time
|
16 |
-
import psutil
|
17 |
-
import uuid
|
18 |
import atexit
|
19 |
-
from
|
20 |
from metamask_component import metamask_connector
|
21 |
-
|
22 |
-
|
23 |
-
class BlockchainEnabledRAG:
|
24 |
-
def __init__(self,
|
25 |
-
llm_model_name="deepseek-ai/DeepSeek-V3-0324",
|
26 |
-
embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",
|
27 |
-
chunk_size=1000,
|
28 |
-
chunk_overlap=200,
|
29 |
-
use_gpu=True,
|
30 |
-
use_blockchain=False,
|
31 |
-
contract_address=None):
|
32 |
-
"""
|
33 |
-
Initialize the GPU-efficient RAG system with MetaMask blockchain integration.
|
34 |
-
|
35 |
-
Args:
|
36 |
-
llm_model_name: The HuggingFace model for text generation
|
37 |
-
embedding_model_name: The HuggingFace model for embeddings
|
38 |
-
chunk_size: Size of document chunks
|
39 |
-
chunk_overlap: Overlap between chunks
|
40 |
-
use_gpu: Whether to use GPU acceleration
|
41 |
-
use_blockchain: Whether to enable blockchain verification
|
42 |
-
contract_address: Address of the deployed RAG Document Verifier contract
|
43 |
-
"""
|
44 |
-
self.llm_model_name = llm_model_name
|
45 |
-
self.embedding_model_name = embedding_model_name
|
46 |
-
self.use_gpu = use_gpu and torch.cuda.is_available()
|
47 |
-
self.use_blockchain = use_blockchain
|
48 |
-
|
49 |
-
# Device selection for embeddings
|
50 |
-
self.device = "cuda" if self.use_gpu else "cpu"
|
51 |
-
st.sidebar.info(f"Using device: {self.device}")
|
52 |
-
|
53 |
-
# Initialize text splitter
|
54 |
-
self.text_splitter = RecursiveCharacterTextSplitter(
|
55 |
-
chunk_size=chunk_size,
|
56 |
-
chunk_overlap=chunk_overlap,
|
57 |
-
length_function=len,
|
58 |
-
)
|
59 |
-
|
60 |
-
# Initialize embeddings model
|
61 |
-
self.embeddings = HuggingFaceEmbeddings(
|
62 |
-
model_name=embedding_model_name,
|
63 |
-
model_kwargs={"device": self.device}
|
64 |
-
)
|
65 |
-
|
66 |
-
# Initialize LLM using HuggingFaceHub instead of Ollama
|
67 |
-
try:
|
68 |
-
# Use HF_TOKEN from environment variables
|
69 |
-
hf_token = os.environ.get("HF_TOKEN")
|
70 |
-
if not hf_token:
|
71 |
-
st.warning("No HuggingFace token found. Using model without authentication.")
|
72 |
-
|
73 |
-
self.llm = HuggingFaceHub(
|
74 |
-
repo_id=llm_model_name,
|
75 |
-
huggingfacehub_api_token=hf_token,
|
76 |
-
model_kwargs={"temperature": 0.7, "max_length": 1024}
|
77 |
-
)
|
78 |
-
except Exception as e:
|
79 |
-
st.error(f"Error initializing LLM: {str(e)}")
|
80 |
-
st.info("Trying to initialize with default model...")
|
81 |
-
# Fallback to a smaller model
|
82 |
-
self.llm = HuggingFaceHub(
|
83 |
-
repo_id="google/flan-t5-small",
|
84 |
-
model_kwargs={"temperature": 0.7, "max_length": 512}
|
85 |
-
)
|
86 |
-
|
87 |
-
# Initialize vector store
|
88 |
-
self.vector_store = None
|
89 |
-
self.documents_processed = 0
|
90 |
-
|
91 |
-
# Monitoring stats
|
92 |
-
self.processing_times = {}
|
93 |
-
|
94 |
-
# Initialize blockchain manager if enabled
|
95 |
-
self.blockchain = None
|
96 |
-
if use_blockchain:
|
97 |
-
try:
|
98 |
-
self.blockchain = BlockchainManagerMetaMask(
|
99 |
-
contract_address=contract_address
|
100 |
-
)
|
101 |
-
st.sidebar.success("Blockchain manager initialized. Please connect MetaMask to continue.")
|
102 |
-
except Exception as e:
|
103 |
-
st.sidebar.error(f"Failed to initialize blockchain manager: {str(e)}")
|
104 |
-
self.use_blockchain = False
|
105 |
-
|
106 |
-
def update_blockchain_connection(self, metamask_info):
|
107 |
-
"""Update blockchain connection with MetaMask info."""
|
108 |
-
if self.blockchain and metamask_info:
|
109 |
-
self.blockchain.update_connection(
|
110 |
-
is_connected=metamask_info.get("connected", False),
|
111 |
-
user_address=metamask_info.get("address"),
|
112 |
-
network_id=metamask_info.get("network_id")
|
113 |
-
)
|
114 |
-
return self.blockchain.is_connected
|
115 |
-
return False
|
116 |
-
|
117 |
-
def process_pdfs(self, pdf_files):
|
118 |
-
"""Process PDF files, create a vector store, and verify documents on blockchain."""
|
119 |
-
all_docs = []
|
120 |
-
|
121 |
-
with st.status("Processing PDF files...") as status:
|
122 |
-
# Create temporary directory for file storage
|
123 |
-
temp_dir = tempfile.mkdtemp()
|
124 |
-
st.session_state['temp_dir'] = temp_dir
|
125 |
-
|
126 |
-
# Monitor processing time and memory usage
|
127 |
-
start_time = time.time()
|
128 |
-
|
129 |
-
# Track memory before processing
|
130 |
-
mem_before = psutil.virtual_memory().used / (1024 * 1024 * 1024) # GB
|
131 |
-
|
132 |
-
# Process each PDF file
|
133 |
-
for i, pdf_file in enumerate(pdf_files):
|
134 |
-
try:
|
135 |
-
file_start_time = time.time()
|
136 |
-
|
137 |
-
# Save uploaded file to temp directory
|
138 |
-
pdf_path = os.path.join(temp_dir, pdf_file.name)
|
139 |
-
with open(pdf_path, "wb") as f:
|
140 |
-
f.write(pdf_file.getbuffer())
|
141 |
-
|
142 |
-
status.update(label=f"Processing {pdf_file.name} ({i+1}/{len(pdf_files)})...")
|
143 |
-
|
144 |
-
# Extract text from PDF
|
145 |
-
text = ""
|
146 |
-
with open(pdf_path, "rb") as f:
|
147 |
-
pdf = PyPDF2.PdfReader(f)
|
148 |
-
for page_num in range(len(pdf.pages)):
|
149 |
-
page = pdf.pages[page_num]
|
150 |
-
page_text = page.extract_text()
|
151 |
-
if page_text:
|
152 |
-
text += page_text + "\n\n"
|
153 |
-
|
154 |
-
# Create documents
|
155 |
-
docs = [Document(page_content=text, metadata={"source": pdf_file.name})]
|
156 |
-
|
157 |
-
# Split documents into chunks
|
158 |
-
split_docs = self.text_splitter.split_documents(docs)
|
159 |
-
|
160 |
-
all_docs.extend(split_docs)
|
161 |
-
|
162 |
-
# Verify document on blockchain if enabled and connected
|
163 |
-
if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
|
164 |
-
try:
|
165 |
-
# Create a unique document ID
|
166 |
-
document_id = f"{pdf_file.name}_{uuid.uuid4().hex[:8]}"
|
167 |
-
|
168 |
-
# Verify document on blockchain
|
169 |
-
status.update(label=f"Verifying {pdf_file.name} on blockchain...")
|
170 |
-
verification = self.blockchain.verify_document(document_id, pdf_path)
|
171 |
-
|
172 |
-
if verification.get('status'): # Success
|
173 |
-
st.sidebar.success(f"β
{pdf_file.name} verified on blockchain")
|
174 |
-
if 'tx_hash' in verification:
|
175 |
-
st.sidebar.info(f"Transaction: {verification['tx_hash'][:10]}...")
|
176 |
-
|
177 |
-
# Add blockchain metadata to documents
|
178 |
-
for doc in split_docs:
|
179 |
-
doc.metadata["blockchain"] = {
|
180 |
-
"verified": True,
|
181 |
-
"document_id": document_id,
|
182 |
-
"document_hash": verification.get("document_hash", ""),
|
183 |
-
"tx_hash": verification.get("tx_hash", ""),
|
184 |
-
"block_number": verification.get("block_number", 0)
|
185 |
-
}
|
186 |
-
else:
|
187 |
-
st.sidebar.warning(f"β Failed to verify {pdf_file.name} on blockchain")
|
188 |
-
if 'error' in verification:
|
189 |
-
st.sidebar.error(f"Error: {verification['error']}")
|
190 |
-
except Exception as e:
|
191 |
-
st.sidebar.error(f"Blockchain verification error: {str(e)}")
|
192 |
-
elif self.use_blockchain:
|
193 |
-
st.sidebar.warning("MetaMask not connected. Document not verified on blockchain.")
|
194 |
-
|
195 |
-
file_end_time = time.time()
|
196 |
-
processing_time = file_end_time - file_start_time
|
197 |
-
|
198 |
-
st.sidebar.success(f"Processed {pdf_file.name}: {len(split_docs)} chunks in {processing_time:.2f}s")
|
199 |
-
self.processing_times[pdf_file.name] = {
|
200 |
-
"chunks": len(split_docs),
|
201 |
-
"time": processing_time
|
202 |
-
}
|
203 |
-
|
204 |
-
except Exception as e:
|
205 |
-
st.sidebar.error(f"Error processing {pdf_file.name}: {str(e)}")
|
206 |
-
|
207 |
-
# Create vector store if we have documents
|
208 |
-
if all_docs:
|
209 |
-
status.update(label="Building vector index...")
|
210 |
-
try:
|
211 |
-
# Record the time taken to build the index
|
212 |
-
index_start_time = time.time()
|
213 |
-
|
214 |
-
# Create the vector store using FAISS
|
215 |
-
self.vector_store = FAISS.from_documents(all_docs, self.embeddings)
|
216 |
-
|
217 |
-
index_end_time = time.time()
|
218 |
-
index_time = index_end_time - index_start_time
|
219 |
-
|
220 |
-
# Track memory after processing
|
221 |
-
mem_after = psutil.virtual_memory().used / (1024 * 1024 * 1024) # GB
|
222 |
-
mem_used = mem_after - mem_before
|
223 |
-
|
224 |
-
total_time = time.time() - start_time
|
225 |
-
|
226 |
-
status.update(label=f"Completed processing {len(all_docs)} chunks in {total_time:.2f}s", state="complete")
|
227 |
-
|
228 |
-
# Save performance metrics
|
229 |
-
self.processing_times["index_building"] = index_time
|
230 |
-
self.processing_times["total_time"] = total_time
|
231 |
-
self.processing_times["memory_used_gb"] = mem_used
|
232 |
-
self.documents_processed = len(all_docs)
|
233 |
-
|
234 |
-
return True
|
235 |
-
except Exception as e:
|
236 |
-
st.error(f"Error creating vector store: {str(e)}")
|
237 |
-
status.update(label="Error creating vector store", state="error")
|
238 |
-
return False
|
239 |
-
else:
|
240 |
-
status.update(label="No content extracted from PDFs", state="error")
|
241 |
-
return False
|
242 |
-
|
243 |
-
def ask(self, query):
|
244 |
-
"""Ask a question and get an answer based on the PDFs with blockchain logging."""
|
245 |
-
if not self.vector_store:
|
246 |
-
return "Please upload and process PDF files first."
|
247 |
-
|
248 |
-
try:
|
249 |
-
# Custom prompt
|
250 |
-
prompt_template = """
|
251 |
-
You are an AI assistant that provides accurate information based on PDF documents.
|
252 |
-
|
253 |
-
Use the following context to answer the question. Be detailed and precise in your answer.
|
254 |
-
If the answer is not in the context, say "I don't have enough information to answer this question."
|
255 |
-
|
256 |
-
Context:
|
257 |
-
{context}
|
258 |
-
|
259 |
-
Question: {question}
|
260 |
-
|
261 |
-
Answer:
|
262 |
-
"""
|
263 |
-
PROMPT = PromptTemplate(
|
264 |
-
template=prompt_template,
|
265 |
-
input_variables=["context", "question"]
|
266 |
-
)
|
267 |
-
|
268 |
-
# Start timing the query
|
269 |
-
query_start_time = time.time()
|
270 |
-
|
271 |
-
# Create QA chain
|
272 |
-
chain_type_kwargs = {"prompt": PROMPT}
|
273 |
-
qa = RetrievalQA.from_chain_type(
|
274 |
-
llm=self.llm,
|
275 |
-
chain_type="stuff",
|
276 |
-
retriever=self.vector_store.as_retriever(search_kwargs={"k": 4}),
|
277 |
-
chain_type_kwargs=chain_type_kwargs,
|
278 |
-
return_source_documents=True
|
279 |
-
)
|
280 |
-
|
281 |
-
# Get answer
|
282 |
-
with st.status("Searching documents and generating answer..."):
|
283 |
-
response = qa({"query": query})
|
284 |
-
|
285 |
-
answer = response["result"]
|
286 |
-
source_docs = response["source_documents"]
|
287 |
-
|
288 |
-
# Calculate query time
|
289 |
-
query_time = time.time() - query_start_time
|
290 |
-
|
291 |
-
# Format sources
|
292 |
-
sources = []
|
293 |
-
for i, doc in enumerate(source_docs):
|
294 |
-
# Extract blockchain verification info if available
|
295 |
-
blockchain_info = None
|
296 |
-
if "blockchain" in doc.metadata:
|
297 |
-
blockchain_info = {
|
298 |
-
"verified": doc.metadata["blockchain"]["verified"],
|
299 |
-
"document_id": doc.metadata["blockchain"]["document_id"],
|
300 |
-
"tx_hash": doc.metadata["blockchain"]["tx_hash"]
|
301 |
-
}
|
302 |
-
|
303 |
-
sources.append({
|
304 |
-
"content": doc.page_content[:300] + "..." if len(doc.page_content) > 300 else doc.page_content,
|
305 |
-
"source": doc.metadata.get("source", "Unknown"),
|
306 |
-
"blockchain": blockchain_info
|
307 |
-
})
|
308 |
-
|
309 |
-
# Log query to blockchain if enabled and connected
|
310 |
-
blockchain_log = None
|
311 |
-
if self.use_blockchain and self.blockchain and self.blockchain.is_connected:
|
312 |
-
try:
|
313 |
-
with st.status("Logging query to blockchain..."):
|
314 |
-
log_result = self.blockchain.log_query(query, answer)
|
315 |
-
|
316 |
-
if log_result.get("status"): # Success
|
317 |
-
blockchain_log = {
|
318 |
-
"logged": True,
|
319 |
-
"query_id": log_result.get("query_id", ""),
|
320 |
-
"tx_hash": log_result.get("tx_hash", ""),
|
321 |
-
"block_number": log_result.get("block_number", 0)
|
322 |
-
}
|
323 |
-
else:
|
324 |
-
st.error(f"Error logging to blockchain: {log_result.get('error', 'Unknown error')}")
|
325 |
-
except Exception as e:
|
326 |
-
st.error(f"Error logging to blockchain: {str(e)}")
|
327 |
-
|
328 |
-
return {
|
329 |
-
"answer": answer,
|
330 |
-
"sources": sources,
|
331 |
-
"query_time": query_time,
|
332 |
-
"blockchain_log": blockchain_log
|
333 |
-
}
|
334 |
-
|
335 |
-
except Exception as e:
|
336 |
-
st.error(f"Error generating answer: {str(e)}")
|
337 |
-
return f"Error: {str(e)}"
|
338 |
-
|
339 |
-
def get_performance_metrics(self):
|
340 |
-
"""Return performance metrics for the RAG system."""
|
341 |
-
if not self.processing_times:
|
342 |
-
return None
|
343 |
-
|
344 |
-
return {
|
345 |
-
"documents_processed": self.documents_processed,
|
346 |
-
"index_building_time": self.processing_times.get("index_building", 0),
|
347 |
-
"total_processing_time": self.processing_times.get("total_time", 0),
|
348 |
-
"memory_used_gb": self.processing_times.get("memory_used_gb", 0),
|
349 |
-
"device": self.device,
|
350 |
-
"embedding_model": self.embedding_model_name,
|
351 |
-
"blockchain_enabled": self.use_blockchain,
|
352 |
-
"blockchain_connected": self.blockchain.is_connected if self.blockchain else False
|
353 |
-
}
|
354 |
-
|
355 |
|
356 |
# Helper function to initialize session state
|
357 |
def initialize_session_state():
|
@@ -364,6 +19,10 @@ def initialize_session_state():
|
|
364 |
st.session_state.temp_dir = None
|
365 |
if "metamask_connected" not in st.session_state:
|
366 |
st.session_state.metamask_connected = False
|
|
|
|
|
|
|
|
|
367 |
|
368 |
# Helper function to clean up temporary files
|
369 |
def cleanup_temp_files():
|
@@ -375,13 +34,22 @@ def cleanup_temp_files():
|
|
375 |
except Exception as e:
|
376 |
print(f"Error cleaning up temporary directory: {e}")
|
377 |
|
378 |
-
|
379 |
# Streamlit UI
|
380 |
def main():
|
381 |
-
st.set_page_config(
|
|
|
|
|
|
|
|
|
382 |
|
383 |
-
st.title("π
|
384 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
|
|
|
385 |
|
386 |
# Initialize session state
|
387 |
initialize_session_state()
|
@@ -425,12 +93,14 @@ def main():
|
|
425 |
st.warning("No GPU detected. Running in CPU mode.")
|
426 |
|
427 |
# Model selection
|
|
|
428 |
llm_model = st.selectbox(
|
429 |
"LLM Model",
|
430 |
options=[
|
431 |
-
"deepseek-ai/DeepSeek-V3-0324",
|
432 |
"mistralai/Mistral-7B-Instruct-v0.2",
|
433 |
-
"google/
|
|
|
|
|
434 |
"tiiuae/falcon-7b-instruct"
|
435 |
],
|
436 |
index=0
|
@@ -449,21 +119,20 @@ def main():
|
|
449 |
use_gpu = st.checkbox("Use GPU Acceleration", value=gpu_available)
|
450 |
|
451 |
# Blockchain configuration
|
452 |
-
st.
|
453 |
use_blockchain = st.checkbox("Enable Blockchain Verification", value=True)
|
454 |
|
455 |
if use_blockchain:
|
456 |
-
|
457 |
-
|
|
|
|
|
458 |
|
459 |
# Display MetaMask connection status in sidebar
|
460 |
if metamask_info and metamask_info.get("connected"):
|
461 |
st.success(f"β
MetaMask Connected: {metamask_info.get('address')[:10]}...")
|
462 |
else:
|
463 |
st.warning("β οΈ MetaMask not connected. Please connect your wallet above.")
|
464 |
-
|
465 |
-
if not contract_address or contract_address == "0x0000000000000000000000000000000000000000":
|
466 |
-
st.error("Please deploy the contract and enter its address")
|
467 |
|
468 |
# Advanced options
|
469 |
with st.expander("Advanced Options"):
|
@@ -476,7 +145,7 @@ def main():
|
|
476 |
if use_blockchain and not contract_address:
|
477 |
st.error("Contract address is required for blockchain integration")
|
478 |
else:
|
479 |
-
st.session_state.rag =
|
480 |
llm_model_name=llm_model,
|
481 |
embedding_model_name=embedding_model,
|
482 |
chunk_size=chunk_size,
|
@@ -503,7 +172,7 @@ def main():
|
|
503 |
if uploaded_files and st.button("Process PDFs"):
|
504 |
if not st.session_state.rag:
|
505 |
with st.spinner("Initializing RAG system..."):
|
506 |
-
st.session_state.rag =
|
507 |
llm_model_name=llm_model,
|
508 |
embedding_model_name=embedding_model,
|
509 |
chunk_size=chunk_size,
|
@@ -531,6 +200,20 @@ def main():
|
|
531 |
st.markdown(f"**Blockchain verification:** {'Enabled' if metrics['blockchain_enabled'] else 'Disabled'}")
|
532 |
st.markdown(f"**Blockchain connected:** {'Yes' if metrics.get('blockchain_connected') else 'No'}")
|
533 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
534 |
# Blockchain verification info
|
535 |
if st.session_state.rag and st.session_state.rag.use_blockchain:
|
536 |
if st.session_state.metamask_connected:
|
@@ -539,89 +222,141 @@ def main():
|
|
539 |
st.warning("π Blockchain verification is enabled but MetaMask is not connected. Please connect your MetaMask wallet to use blockchain features.")
|
540 |
|
541 |
# Display chat messages
|
542 |
-
|
543 |
-
|
544 |
-
|
545 |
-
|
546 |
-
|
547 |
-
|
548 |
-
|
549 |
-
|
550 |
-
|
551 |
-
st.caption(f"Response time: {message['content']['query_time']:.2f} seconds")
|
552 |
-
|
553 |
-
# Display blockchain log if available
|
554 |
-
if "blockchain_log" in message["content"] and message["content"]["blockchain_log"]:
|
555 |
-
blockchain_log = message["content"]["blockchain_log"]
|
556 |
-
st.success(f"β
Query logged on blockchain | Transaction: {blockchain_log['tx_hash'][:10]}...")
|
557 |
-
|
558 |
-
# Display sources in expander
|
559 |
-
if "sources" in message["content"] and message["content"]["sources"]:
|
560 |
-
with st.expander("π View Sources"):
|
561 |
-
for i, source in enumerate(message["content"]["sources"]):
|
562 |
-
st.markdown(f"**Source {i+1}: {source['source']}**")
|
563 |
-
|
564 |
-
# Show blockchain verification if available
|
565 |
-
if source.get("blockchain"):
|
566 |
-
st.success(f"β
Verified on blockchain | TX: {source['blockchain']['tx_hash'][:10]}...")
|
567 |
-
|
568 |
-
st.text(source["content"])
|
569 |
-
st.divider()
|
570 |
-
else:
|
571 |
st.markdown(message["content"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
572 |
|
573 |
# Chat input
|
574 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
# Add user message to chat
|
576 |
-
st.session_state.messages.append({"role": "user", "content":
|
577 |
|
578 |
# Display user message
|
579 |
-
with
|
580 |
-
st.
|
|
|
581 |
|
582 |
# Check if system is initialized
|
583 |
if not st.session_state.rag:
|
584 |
-
with
|
585 |
-
|
586 |
-
|
587 |
-
|
|
|
588 |
|
589 |
# Get response if vector store is ready
|
590 |
elif st.session_state.rag.vector_store:
|
591 |
-
with
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
if isinstance(response, dict):
|
596 |
-
st.markdown(response["answer"])
|
597 |
-
|
598 |
-
if "query_time" in response:
|
599 |
-
st.caption(f"Response time: {response['query_time']:.2f} seconds")
|
600 |
|
601 |
-
#
|
602 |
-
|
603 |
-
|
604 |
-
st.success(f"β
Query logged on blockchain | Transaction: {blockchain_log['tx_hash'][:10]}...")
|
605 |
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
615 |
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
|
|
|
|
|
|
|
|
620 |
else:
|
621 |
-
with
|
622 |
-
|
623 |
-
|
624 |
-
|
|
|
625 |
|
626 |
|
627 |
# Main entry point
|
@@ -629,5 +364,4 @@ if __name__ == "__main__":
|
|
629 |
# Register cleanup function
|
630 |
atexit.register(cleanup_temp_files)
|
631 |
|
632 |
-
main()
|
633 |
-
|
|
|
1 |
+
# app.py
|
2 |
import os
|
|
|
3 |
import shutil
|
|
|
4 |
import streamlit as st
|
5 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
import atexit
|
7 |
+
from advanced_rag import AdvancedRAG
|
8 |
from metamask_component import metamask_connector
|
9 |
+
from voice_component import voice_input_component
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Helper function to initialize session state
|
12 |
def initialize_session_state():
|
|
|
19 |
st.session_state.temp_dir = None
|
20 |
if "metamask_connected" not in st.session_state:
|
21 |
st.session_state.metamask_connected = False
|
22 |
+
if "retrieval_method" not in st.session_state:
|
23 |
+
st.session_state.retrieval_method = "enhanced"
|
24 |
+
if "voice_transcript" not in st.session_state:
|
25 |
+
st.session_state.voice_transcript = ""
|
26 |
|
27 |
# Helper function to clean up temporary files
|
28 |
def cleanup_temp_files():
|
|
|
34 |
except Exception as e:
|
35 |
print(f"Error cleaning up temporary directory: {e}")
|
36 |
|
|
|
37 |
# Streamlit UI
|
38 |
def main():
|
39 |
+
st.set_page_config(
|
40 |
+
page_title="Advanced RAG with MetaMask and Voice",
|
41 |
+
layout="wide",
|
42 |
+
initial_sidebar_state="expanded"
|
43 |
+
)
|
44 |
|
45 |
+
st.title("π Advanced RAG System with Blockchain Verification and Voice Input")
|
46 |
+
st.markdown("""
|
47 |
+
This application allows you to:
|
48 |
+
- Upload and process PDF documents
|
49 |
+
- Verify document authenticity on blockchain using MetaMask
|
50 |
+
- Ask questions using voice or text input
|
51 |
+
- Choose between direct retrieval or enhanced LLM-powered answers
|
52 |
+
""")
|
53 |
|
54 |
# Initialize session state
|
55 |
initialize_session_state()
|
|
|
93 |
st.warning("No GPU detected. Running in CPU mode.")
|
94 |
|
95 |
# Model selection
|
96 |
+
st.subheader("Model Selection")
|
97 |
llm_model = st.selectbox(
|
98 |
"LLM Model",
|
99 |
options=[
|
|
|
100 |
"mistralai/Mistral-7B-Instruct-v0.2",
|
101 |
+
"google/gemma-7b-it",
|
102 |
+
"google/flan-t5-xl",
|
103 |
+
"Salesforce/xgen-7b-8k-inst",
|
104 |
"tiiuae/falcon-7b-instruct"
|
105 |
],
|
106 |
index=0
|
|
|
119 |
use_gpu = st.checkbox("Use GPU Acceleration", value=gpu_available)
|
120 |
|
121 |
# Blockchain configuration
|
122 |
+
st.subheader("π Blockchain Configuration")
|
123 |
use_blockchain = st.checkbox("Enable Blockchain Verification", value=True)
|
124 |
|
125 |
if use_blockchain:
|
126 |
+
# Hardcoded contract address - replace with your deployed contract
|
127 |
+
contract_address = os.environ.get("CONTRACT_ADDRESS", "0x123abc...") # Your pre-deployed contract
|
128 |
+
|
129 |
+
st.info(f"Using pre-deployed contract: {contract_address[:10]}...")
|
130 |
|
131 |
# Display MetaMask connection status in sidebar
|
132 |
if metamask_info and metamask_info.get("connected"):
|
133 |
st.success(f"β
MetaMask Connected: {metamask_info.get('address')[:10]}...")
|
134 |
else:
|
135 |
st.warning("β οΈ MetaMask not connected. Please connect your wallet above.")
|
|
|
|
|
|
|
136 |
|
137 |
# Advanced options
|
138 |
with st.expander("Advanced Options"):
|
|
|
145 |
if use_blockchain and not contract_address:
|
146 |
st.error("Contract address is required for blockchain integration")
|
147 |
else:
|
148 |
+
st.session_state.rag = AdvancedRAG(
|
149 |
llm_model_name=llm_model,
|
150 |
embedding_model_name=embedding_model,
|
151 |
chunk_size=chunk_size,
|
|
|
172 |
if uploaded_files and st.button("Process PDFs"):
|
173 |
if not st.session_state.rag:
|
174 |
with st.spinner("Initializing RAG system..."):
|
175 |
+
st.session_state.rag = AdvancedRAG(
|
176 |
llm_model_name=llm_model,
|
177 |
embedding_model_name=embedding_model,
|
178 |
chunk_size=chunk_size,
|
|
|
200 |
st.markdown(f"**Blockchain verification:** {'Enabled' if metrics['blockchain_enabled'] else 'Disabled'}")
|
201 |
st.markdown(f"**Blockchain connected:** {'Yes' if metrics.get('blockchain_connected') else 'No'}")
|
202 |
|
203 |
+
# Retrieval Method Selection
|
204 |
+
st.header("π Retrieval Method")
|
205 |
+
retrieval_cols = st.columns(2)
|
206 |
+
|
207 |
+
with retrieval_cols[0]:
|
208 |
+
if st.button("π Direct Retrieval", help="Get raw document chunks without LLM processing", use_container_width=True):
|
209 |
+
st.session_state.retrieval_method = "direct"
|
210 |
+
st.info("Using Direct Retrieval: Raw document passages will be returned without LLM processing")
|
211 |
+
|
212 |
+
with retrieval_cols[1]:
|
213 |
+
if st.button("π§ Enhanced Retrieval", help="Process results through LLM for comprehensive answers", use_container_width=True):
|
214 |
+
st.session_state.retrieval_method = "enhanced"
|
215 |
+
st.info("Using Enhanced Retrieval: Documents will be processed by LLM to generate comprehensive answers")
|
216 |
+
|
217 |
# Blockchain verification info
|
218 |
if st.session_state.rag and st.session_state.rag.use_blockchain:
|
219 |
if st.session_state.metamask_connected:
|
|
|
222 |
st.warning("π Blockchain verification is enabled but MetaMask is not connected. Please connect your MetaMask wallet to use blockchain features.")
|
223 |
|
224 |
# Display chat messages
|
225 |
+
st.header("π¬ Chat")
|
226 |
+
|
227 |
+
# Chat container
|
228 |
+
chat_container = st.container(height=400, border=True)
|
229 |
+
|
230 |
+
with chat_container:
|
231 |
+
for message in st.session_state.messages:
|
232 |
+
with st.chat_message(message["role"]):
|
233 |
+
if message["role"] == "user":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
st.markdown(message["content"])
|
235 |
+
else:
|
236 |
+
if isinstance(message["content"], dict):
|
237 |
+
st.markdown(message["content"]["answer"])
|
238 |
+
|
239 |
+
if "query_time" in message["content"]:
|
240 |
+
st.caption(f"Response time: {message['content']['query_time']:.2f} seconds")
|
241 |
+
|
242 |
+
if "method" in message["content"]:
|
243 |
+
method_name = "Direct Retrieval" if message["content"]["method"] == "direct" else "Enhanced Retrieval"
|
244 |
+
st.caption(f"Method: {method_name}")
|
245 |
+
|
246 |
+
# Display blockchain log if available
|
247 |
+
if "blockchain_log" in message["content"] and message["content"]["blockchain_log"]:
|
248 |
+
blockchain_log = message["content"]["blockchain_log"]
|
249 |
+
st.success(f"β
Query logged on blockchain | Transaction: {blockchain_log['tx_hash'][:10]}...")
|
250 |
+
|
251 |
+
# Display sources in expander
|
252 |
+
if "sources" in message["content"] and message["content"]["sources"]:
|
253 |
+
with st.expander("π View Sources"):
|
254 |
+
for i, source in enumerate(message["content"]["sources"]):
|
255 |
+
st.markdown(f"**Source {i+1}: {source['source']}**")
|
256 |
+
|
257 |
+
# Show blockchain verification if available
|
258 |
+
if source.get("blockchain"):
|
259 |
+
st.success(f"β
Verified on blockchain | TX: {source['blockchain']['tx_hash'][:10]}...")
|
260 |
+
|
261 |
+
st.text(source["content"])
|
262 |
+
st.divider()
|
263 |
+
else:
|
264 |
+
st.markdown(message["content"])
|
265 |
+
|
266 |
+
# Voice Input Section
|
267 |
+
st.header("π€ Voice Input")
|
268 |
+
st.markdown("You can ask questions using your voice or type them below.")
|
269 |
+
|
270 |
+
# Voice input component
|
271 |
+
voice_transcript = voice_input_component()
|
272 |
+
|
273 |
+
# Update session state with voice transcript if not empty
|
274 |
+
if voice_transcript and voice_transcript.strip():
|
275 |
+
st.session_state.voice_transcript = voice_transcript.strip()
|
276 |
+
st.success(f"Voice input received: {voice_transcript}")
|
277 |
+
|
278 |
+
# Chat input - show the voice transcript in the text input
|
279 |
+
prompt_placeholder = "Ask a question about your PDFs..."
|
280 |
+
if st.session_state.voice_transcript:
|
281 |
+
prompt_placeholder = st.session_state.voice_transcript
|
282 |
|
283 |
# Chat input
|
284 |
+
prompt = st.chat_input(prompt_placeholder)
|
285 |
+
|
286 |
+
# Process either voice input or text input
|
287 |
+
if prompt or st.session_state.voice_transcript:
|
288 |
+
# Prioritize text input over voice input
|
289 |
+
if prompt:
|
290 |
+
user_input = prompt
|
291 |
+
else:
|
292 |
+
user_input = st.session_state.voice_transcript
|
293 |
+
# Clear voice transcript after using it
|
294 |
+
st.session_state.voice_transcript = ""
|
295 |
+
# Rerun to clear the voice input display
|
296 |
+
st.rerun()
|
297 |
+
|
298 |
# Add user message to chat
|
299 |
+
st.session_state.messages.append({"role": "user", "content": user_input})
|
300 |
|
301 |
# Display user message
|
302 |
+
with chat_container:
|
303 |
+
with st.chat_message("user"):
|
304 |
+
st.markdown(user_input)
|
305 |
|
306 |
# Check if system is initialized
|
307 |
if not st.session_state.rag:
|
308 |
+
with chat_container:
|
309 |
+
with st.chat_message("assistant"):
|
310 |
+
message = "Please initialize the system and process PDFs first."
|
311 |
+
st.markdown(message)
|
312 |
+
st.session_state.messages.append({"role": "assistant", "content": message})
|
313 |
|
314 |
# Get response if vector store is ready
|
315 |
elif st.session_state.rag.vector_store:
|
316 |
+
with chat_container:
|
317 |
+
with st.chat_message("assistant"):
|
318 |
+
# Get retrieval method
|
319 |
+
method = st.session_state.retrieval_method
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
+
# Get response using specified method
|
322 |
+
response = st.session_state.rag.ask(user_input, method=method)
|
323 |
+
st.session_state.messages.append({"role": "assistant", "content": response})
|
|
|
324 |
|
325 |
+
if isinstance(response, dict):
|
326 |
+
st.markdown(response["answer"])
|
327 |
+
|
328 |
+
if "query_time" in response:
|
329 |
+
st.caption(f"Response time: {response['query_time']:.2f} seconds")
|
330 |
+
|
331 |
+
if "method" in response:
|
332 |
+
method_name = "Direct Retrieval" if response["method"] == "direct" else "Enhanced Retrieval"
|
333 |
+
st.caption(f"Method: {method_name}")
|
334 |
+
|
335 |
+
# Display blockchain log if available
|
336 |
+
if "blockchain_log" in response and response["blockchain_log"]:
|
337 |
+
blockchain_log = response["blockchain_log"]
|
338 |
+
st.success(f"β
Query logged on blockchain | Transaction: {blockchain_log['tx_hash'][:10]}...")
|
339 |
+
|
340 |
+
# Display sources in expander
|
341 |
+
if "sources" in response and response["sources"]:
|
342 |
+
with st.expander("π View Sources"):
|
343 |
+
for i, source in enumerate(response["sources"]):
|
344 |
+
st.markdown(f"**Source {i+1}: {source['source']}**")
|
345 |
|
346 |
+
# Show blockchain verification if available
|
347 |
+
if source.get("blockchain"):
|
348 |
+
st.success(f"β
Verified on blockchain | TX: {source['blockchain']['tx_hash'][:10]}...")
|
349 |
+
|
350 |
+
st.text(source["content"])
|
351 |
+
st.divider()
|
352 |
+
else:
|
353 |
+
st.markdown(response)
|
354 |
else:
|
355 |
+
with chat_container:
|
356 |
+
with st.chat_message("assistant"):
|
357 |
+
message = "Please upload and process PDF files first."
|
358 |
+
st.markdown(message)
|
359 |
+
st.session_state.messages.append({"role": "assistant", "content": message})
|
360 |
|
361 |
|
362 |
# Main entry point
|
|
|
364 |
# Register cleanup function
|
365 |
atexit.register(cleanup_temp_files)
|
366 |
|
367 |
+
main()
|
|