import os
import torch
import torch.backends.cudnn as cudnn
import streamlit as st
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Enable CUDA optimizations if available
if torch.cuda.is_available():
    cudnn.benchmark = True


# Step 1: Load the PDF and create a vector store
@st.cache_resource
def load_pdf_to_vectorstore(pdf_path):
    # Load the PDF and split it into overlapping chunks
    loader = PyPDFLoader(pdf_path)
    documents = loader.load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=20,
        separators=["\n\n", "\n", ".", " ", ""]
    )
    chunks = text_splitter.split_documents(documents)

    # Create embeddings and the FAISS vector store
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vectorstore = FAISS.from_documents(chunks, embeddings)
    return vectorstore


# Step 2: Initialize the LaMini model
@st.cache_resource
def setup_model():
    model_id = "MBZUAI/LaMini-Flan-T5-248M"  # Smaller model for faster inference
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )

    if torch.cuda.is_available():
        model = model.cuda()

    pipe = pipeline(
        "text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=256,
        do_sample=False,  # greedy decoding, so sampling parameters are omitted
        device=0 if torch.cuda.is_available() else -1,
        batch_size=1
    )
    return pipe


# Step 3: Generate a response using the model and vector store
def generate_response(pipe, vectorstore, user_input):
    # Retrieve the most relevant context chunks
    docs = vectorstore.similarity_search(user_input, k=2)
    context = "\n".join([
        f"Page {doc.metadata.get('page', 'unknown')}: {doc.page_content}"
        for doc in docs
    ])

    # Create the prompt
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Using the following medical text excerpts, answer the question.
If the information isn't clearly provided in the context, or if you're unsure,
please say so and recommend consulting a healthcare professional.

Context: {context}

Question: {question}

Answer (citing relevant page numbers when possible):"""
    )

    # Format the prompt and run the pipeline
    prompt_text = prompt.format(context=context, question=user_input)
    response = pipe(prompt_text)[0]['generated_text']

    return response


# Cache responses for repeated questions
@st.cache_data
def cached_generate_response(user_input, _pipe, _vectorstore):
    return generate_response(_pipe, _vectorstore, user_input)


# Batch processing for multiple questions
def batch_generate_responses(pipe, vectorstore, questions, batch_size=4):
    responses = []
    for i in range(0, len(questions), batch_size):
        batch = questions[i:i + batch_size]
        batch_responses = [generate_response(pipe, vectorstore, q) for q in batch]
        responses.extend(batch_responses)
    return responses


# Streamlit UI
def main():
    st.title("Medical Chatbot Assistant 🏥")

    # Use the PDF file from the root directory
    pdf_path = "Medical_book.pdf"

    if os.path.exists(pdf_path):
        # Load the vector store and model with progress indication
        with st.spinner("Loading PDF and initializing model..."):
            vectorstore = load_pdf_to_vectorstore(pdf_path)
            pipe = setup_model()
        st.success("Ready to answer questions!")

        # Create a chat-like interface
        if "messages" not in st.session_state:
            st.session_state.messages = []

        # Display chat history
        for message in st.session_state.messages:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        # User input
        if prompt := st.chat_input("Ask your medical question:"):
            # Add the user message to the chat history
            st.session_state.messages.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            # Generate and display the response
            with st.chat_message("assistant"):
                with st.spinner("Generating response..."):
                    response = cached_generate_response(prompt, pipe, vectorstore)
                    st.markdown(response)

            # Add the assistant message to the chat history
            st.session_state.messages.append({"role": "assistant", "content": response})
    else:
        st.error("The file 'Medical_book.pdf' was not found in the root directory.")


if __name__ == "__main__":
    main()
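
# Illustrative usage sketch (commented out so it does not run with the Streamlit app):
# the batch helper above can also be called outside the UI, e.g. for offline evaluation.
# The questions below are made-up examples, and the sketch assumes the same
# "Medical_book.pdf" is present in the working directory.
#
#     questions = [
#         "What are common symptoms of anemia?",
#         "How is hypertension diagnosed?",
#     ]
#     vectorstore = load_pdf_to_vectorstore("Medical_book.pdf")
#     pipe = setup_model()
#     answers = batch_generate_responses(pipe, vectorstore, questions)
#     for question, answer in zip(questions, answers):
#         print(f"Q: {question}\nA: {answer}\n")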