khalifssa committed · Commit 26cb36a · verified · 1 Parent(s): c67895f

Update app.py

Files changed (1)
  1. app.py +63 -65
app.py CHANGED
@@ -16,6 +16,7 @@ if torch.cuda.is_available():
 # Step 1: Load the PDF and create a vector store
 @st.cache_resource
 def load_pdf_to_vectorstore(pdf_path):
+    # Load and split PDF
     loader = PyPDFLoader(pdf_path)
     documents = loader.load()

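The hunk above only shows the loading half of load_pdf_to_vectorstore; the splitter configuration and the vector store class are not visible in this commit. For orientation, a minimal sketch of how such a LangChain loader/splitter/embedding pipeline is usually wired up; the chunk_size, chunk_overlap, and the FAISS store here are illustrative assumptions, not values taken from this repository, and the import paths assume a recent LangChain split into langchain_community:

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

def build_vectorstore(pdf_path: str):
    # Load each PDF page as a Document carrying page metadata
    documents = PyPDFLoader(pdf_path).load()
    # Split pages into overlapping chunks so retrieved context stays small
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = splitter.split_documents(documents)
    # Embed chunks and index them for similarity search
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(chunks, embeddings)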
@@ -27,6 +28,7 @@ def load_pdf_to_vectorstore(pdf_path):

     chunks = text_splitter.split_documents(documents)

+    # Create embeddings and vector store
     embeddings = HuggingFaceEmbeddings(
         model_name="sentence-transformers/all-MiniLM-L6-v2"
     )
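The embedding model kept here, sentence-transformers/all-MiniLM-L6-v2, maps each chunk to a 384-dimensional vector, and similarity search later compares the query vector against those chunk vectors. A quick sanity check of the embeddings object, as a sketch (same caveat as above about the langchain_community import path):

from langchain_community.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vector = embeddings.embed_query("What are the symptoms of anemia?")
print(len(vector))  # 384 dimensions for all-MiniLM-L6-v2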
@@ -37,10 +39,11 @@ def load_pdf_to_vectorstore(pdf_path):
 # Step 2: Initialize the LaMini model
 @st.cache_resource
 def setup_model():
-    model_id = "MBZUAI/LaMini-Flan-T5-248M"
+    model_id = "MBZUAI/LaMini-Flan-T5-248M"  # Using smaller model for faster inference
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     model = AutoModelForSeq2SeqLM.from_pretrained(
         model_id,
+        # Removed low_cpu_mem_usage parameter
         torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
     )

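The torch_dtype expression keeps the model in float16 only when CUDA is available; on CPU it stays in float32, since half precision on CPU is typically slow or unsupported for many ops. A small standalone check, as a sketch using the same model id:

import torch
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "MBZUAI/LaMini-Flan-T5-248M",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
# Prints torch.float16 when CUDA was detected, torch.float32 otherwise
print(next(model.parameters()).dtype)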
@@ -51,16 +54,17 @@ def setup_model():
         "text2text-generation",
         model=model,
         tokenizer=tokenizer,
-        max_length=512,  # Increased max length for better context
+        max_length=256,
         do_sample=False,
         temperature=0.3,
         top_p=0.95,
         device=0 if torch.cuda.is_available() else -1,
+        batch_size=1
     )
     return pipe

-# Step 3: Generate response with conversation history
-def generate_response(pipe, vectorstore, user_input, chat_history):
+# Step 3: Generate a response using the model and vector store
+def generate_response(pipe, vectorstore, user_input):
     # Get relevant context
     docs = vectorstore.similarity_search(user_input, k=2)
     context = "\n".join([
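Two notes on this pipeline configuration: max_length bounds the generated sequence length (lowered here from 512 to 256), and because do_sample=False keeps greedy decoding, the temperature and top_p values have no effect on the output (recent transformers versions emit a warning about this); the added batch_size=1 simply makes the default batching explicit. If those sampling parameters are actually meant to shape answers, do_sample must be enabled. A sketch of that alternative configuration, not the one in this commit:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

model_id = "MBZUAI/LaMini-Flan-T5-248M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Sampling enabled so that temperature and top_p actually influence decoding
pipe = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=256,   # upper bound on the generated sequence length
    do_sample=True,
    temperature=0.3,
    top_p=0.95,
    device=0 if torch.cuda.is_available() else -1,
)
print(pipe("Summarize: Aspirin is a common over-the-counter analgesic.")[0]["generated_text"])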
@@ -68,89 +72,83 @@ def generate_response(pipe, vectorstore, user_input, chat_history):
         for doc in docs
     ])

-    # Format conversation history
-    history_text = "\n".join(
-        [f"{msg['role'].capitalize()}: {msg['content']}"
-         for msg in chat_history]
-    ) if chat_history else "No previous conversation"
-
-    # Create contextual prompt
-    prompt_template = PromptTemplate(
-        input_variables=["history", "context", "question"],
+    # Create prompt
+    prompt = PromptTemplate(
+        input_variables=["context", "question"],
         template="""
-        Previous Conversation:
-        {history}
-
-        Medical Context:
-        {context}
-
-        Current Question: {question}
-
-        Instructions:
-        1. Answer based on context and conversation history
-        2. Cite page numbers when possible
-        3. If unsure, recommend consulting a professional
-        4. Maintain natural conversation flow
-
-        Assistant Response:
-        """
-    )
-
-    prompt = prompt_template.format(
-        history=history_text,
-        context=context,
-        question=user_input
+        Using the following medical text excerpts, answer the question.
+        If the information isn't clearly provided in the context, or if you're unsure, please say so and recommend consulting a healthcare professional.
+
+        Context: {context}
+
+        Question: {question}
+
+        Answer (citing relevant page numbers when possible):"""
     )

-    # Generate response
-    response = pipe(prompt, max_length=512)[0]['generated_text']
+    # Generate response using the new method
+    prompt_text = prompt.format(context=context, question=user_input)
+    response = pipe(prompt_text)[0]['generated_text']

     return response

-# Streamlit UI with conversation memory
+# Cache responses for repeated questions
+@st.cache_data
+def cached_generate_response(user_input, _pipe, _vectorstore):
+    return generate_response(_pipe, _vectorstore, user_input)
+
+# Batch processing for multiple questions
+def batch_generate_responses(pipe, vectorstore, questions, batch_size=4):
+    responses = []
+    for i in range(0, len(questions), batch_size):
+        batch = questions[i:i + batch_size]
+        batch_responses = [generate_response(pipe, vectorstore, q) for q in batch]
+        responses.extend(batch_responses)
+    return responses
+
+# Streamlit UI
 def main():
     st.title("Medical Chatbot Assistant 🏥")

-    # Initialize session state for chat history
-    if "messages" not in st.session_state:
-        st.session_state.messages = []
-
-    # Load resources
+    # Use the PDF file from the root directory
     pdf_path = "Medical_book.pdf"
+
     if os.path.exists(pdf_path):
-        with st.spinner("Initializing system..."):
+        # Initialize progress
+        progress_text = "Operation in progress. Please wait."
+
+        # Load vector store and model with progress indication
+        with st.spinner("Loading PDF and initializing model..."):
             vectorstore = load_pdf_to_vectorstore(pdf_path)
             pipe = setup_model()
+        st.success("Ready to answer questions!")

-        # Display chat messages
+        # Create a chat-like interface
+        if "messages" not in st.session_state:
+            st.session_state.messages = []
+
+        # Display chat history
         for message in st.session_state.messages:
             with st.chat_message(message["role"]):
                 st.markdown(message["content"])

-        # User input handling
+        # User input
         if prompt := st.chat_input("Ask your medical question:"):
-            # Add user message to history
+            # Add user message to chat history
             st.session_state.messages.append({"role": "user", "content": prompt})
-
-            # Generate response with conversation context
-            with st.spinner("Analyzing question..."):
-                response = generate_response(
-                    pipe,
-                    vectorstore,
-                    prompt,
-                    chat_history=st.session_state.messages[:-1]  # Exclude current prompt
-                )
-
-            # Add and display assistant response
-            st.session_state.messages.append({"role": "assistant", "content": response})
-
-            # Display conversation
             with st.chat_message("user"):
                 st.markdown(prompt)
+
+            # Generate and display response
             with st.chat_message("assistant"):
-                st.markdown(response)
+                with st.spinner("Generating response..."):
+                    response = cached_generate_response(prompt, pipe, vectorstore)
+                st.markdown(response)
+                # Add assistant message to chat history
+                st.session_state.messages.append({"role": "assistant", "content": response})
+
     else:
-        st.error("Medical reference book not found!")
+        st.error("The file 'Medical_book.pdf' was not found in the root directory.")
+

-if __name__ == "__main__":
-    main()
+main()
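Two details of the new UI code are worth noting. First, st.cache_data builds its cache key by hashing the function's arguments, and a leading underscore (as in _pipe and _vectorstore) tells Streamlit to skip hashing that argument; this is what lets the cache work with unhashable objects like the pipeline, while user_input still drives cache hits for repeated questions. Second, replacing the if __name__ == "__main__": guard with a bare main() call fits how Streamlit re-executes the whole script on every interaction, though the guard would also have worked. Note also that progress_text and batch_generate_responses are introduced but not referenced inside main() in this commit. A minimal sketch of the caching pattern in isolation; the function and the lambda stand-in below are illustrative, not part of the app:

import streamlit as st

@st.cache_data
def cached_answer(question: str, _generate):
    # `question` is hashed into the cache key; `_generate` is skipped
    # because parameter names starting with "_" are not hashed.
    return _generate(question)

# Hypothetical stand-in for the real pipeline call, just to show the pattern
st.write(cached_answer("What is anemia?", lambda q: f"(placeholder answer for: {q})"))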