Haseeb-001 commited on
Commit
685cf87
·
verified ·
1 Parent(s): 100f218

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +137 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
4
+ import numpy as np
5
+ from groq import Groq
6
+ import io
7
+ import faiss
8
+
9
+ # Groq API Key (and other initializations)
10
+ groq_api_key = os.environ.get("GROQ_API_KEY")
11
+ if groq_api_key is None:
12
+ st.error("GROQ_API_KEY environment variable not set.")
13
+ st.stop()
14
+
15
+ try:
16
+ client = Groq(api_key=groq_api_key)
17
+ except Exception as e:
18
+ st.error(f"Error initializing Groq client: {e}")
19
+ st.stop()
20
+
21
+ try:
22
+ pubmedbert_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
23
+ pubmedbert_model = AutoModelForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext")
24
+ pubmedbert_pipeline = pipeline('feature-extraction', model=pubmedbert_model, tokenizer=pubmedbert_tokenizer, device=-1)
25
+ except Exception as e:
26
+ st.error(f"Error loading PubMedBERT: {e}")
27
+ st.stop()
28
+
29
+ embedding_dim = 768
30
+ index = faiss.IndexFlatL2(embedding_dim)
31
+
32
+ if "all_conversations" not in st.session_state:
33
+ st.session_state.all_conversations = {}
34
+ if "current_conversation_id" not in st.session_state:
35
+ st.session_state.current_conversation_id = 0
36
+ if "current_conversation_messages" not in st.session_state:
37
+ st.session_state.current_conversation_messages = []
38
+ if "embeddings" not in st.session_state:
39
+ st.session_state.embeddings = []
40
+
41
+ # Functions
42
+ def preprocess_query(query):
43
+ tokens = query.lower().split()
44
+ keywords = [keyword for keyword in tokens if keyword in ["seizure", "symptoms", "jerks", "confusion", "epilepsy"]]
45
+ is_medical_related = any(keyword in keywords for keyword in ["seizure", "symptoms", "jerks", "confusion", "epilepsy", "medical"])
46
+ return tokens, keywords, is_medical_related
47
+
48
+ def generate_response(user_query):
49
+ tokens, keywords, is_medical_related = preprocess_query(user_query)
50
+ enhanced_query = " ".join(tokens)
51
+ symptom_insights = ""
52
+
53
+ conversation_history = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in st.session_state.current_conversation_messages])
54
+
55
+ if is_medical_related:
56
+ try:
57
+ pubmedbert_embeddings = pubmedbert_pipeline(user_query)
58
+ embedding_mean = np.mean(pubmedbert_embeddings[0], axis=0)
59
+ st.session_state.embeddings.append(embedding_mean)
60
+ index.add(np.array([embedding_mean]))
61
+ pubmedbert_insights = "PubMedBERT analysis..."
62
+ model_name = "PubMedBERT"
63
+ model_response = pubmedbert_insights
64
+ if "seizure" in keywords or "symptoms" in keywords:
65
+ remedy_recommendations = "\n\n**General Recommendations:**\n..."
66
+ else:
67
+ remedy_recommendations = ""
68
+ except Exception as e:
69
+ model_response = f"Error during PubMedBERT: {e}"
70
+ remedy_recommendations = ""
71
+ else:
72
+ model_name = "LLaMA 2 / Mistral 7B (via Groq)"
73
+ try:
74
+ prompt = f"""
75
+ Conversation History:
76
+ {conversation_history}
77
+
78
+ User: {user_query}
79
+ Bot:
80
+ """
81
+ chat_completion = client.chat.completions.create(
82
+ messages=[{"role": "user", "content": prompt}],
83
+ model="llama-3.3-70b-versatile",
84
+ stream=False,
85
+ )
86
+ model_response = chat_completion.choices[0].message.content.strip()
87
+ except Exception as e:
88
+ model_response = f"Error from Groq: {e}"
89
+ remedy_recommendations = ""
90
+
91
+ final_response = f"**Enhanced Query:** {enhanced_query}\n\nChatbot Analysis:...\n\nModel Response/Insights:\n{model_response}\n{remedy_recommendations}"
92
+ return final_response, model_response
93
+ # Streamlit Interface (and other parts)
94
+ st.set_page_config(page_title="Epilepsy Chatbot", layout="wide")
95
+ st.markdown("<style>.chat-message.user {background-color: #e6f7ff; padding: 8px; border-radius: 8px; margin-bottom: 8px;}.chat-message.bot {background-color: #f0f0f0; padding: 8px; border-radius: 8px; margin-bottom: 8px;}.stTextArea textarea {background-color: #f8f8f8;}</style>", unsafe_allow_html=True)
96
+
97
+ with st.sidebar:
98
+ st.title("Conversations")
99
+ if st.button("New Conversation"):
100
+ st.session_state.current_conversation_id += 1
101
+ st.session_state.current_conversation_messages = []
102
+ st.session_state.embeddings = []
103
+ index.reset()
104
+ for conv_id in st.session_state.all_conversations:
105
+ if st.button(f"Conversation {conv_id}"):
106
+ st.session_state.current_conversation_id = conv_id
107
+ st.session_state.current_conversation_messages = st.session_state.all_conversations[conv_id]
108
+ st.session_state.embeddings = []
109
+ index.reset()
110
+
111
+ st.title("Epilepsy & Seizure Chatbot")
112
+ st.write("Ask questions related to epilepsy and seizures.")
113
+
114
+ for message in st.session_state.current_conversation_messages:
115
+ with st.chat_message(message["role"]):
116
+ st.markdown(message["content"])
117
+
118
+ if prompt := st.chat_input("Enter your query here:"):
119
+ st.session_state.current_conversation_messages.append({"role": "user", "content": prompt})
120
+ with st.chat_message("user"):
121
+ st.markdown(prompt)
122
+
123
+ with st.chat_message("bot"):
124
+ with st.spinner("Generating response..."):
125
+ try:
126
+ full_response, model_only_response = generate_response(prompt)
127
+ st.markdown(model_only_response)
128
+ st.session_state.current_conversation_messages.append({"role": "bot", "content": model_only_response})
129
+ except Exception as e:
130
+ st.error(f"Error processing query: {e}")
131
+
132
+ st.session_state.all_conversations[st.session_state.current_conversation_id] = st.session_state.current_conversation_messages
133
+
134
+ # Download Chat
135
+ if st.session_state.current_conversation_messages:
136
+ conversation_text = "\n".join([f"{message['role'].capitalize()}: {message['content']}" for message in st.session_state.current_conversation_messages])
137
+ st.download_button("Download Chat", data=conversation_text, file_name=f"chat_history_{st.session_state.current_conversation_id}.txt")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ transformers
2
+ streamlit
3
+ numpy
4
+ torch
5
+ groq
6
+ faiss-cpu