atharvasc27112001 committed
Commit d1af83e · verified · 1 Parent(s): 29aa409

Create app.py

Files changed (1):
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
+ import pandas as pd
+ import torch
+ import transformers
+ from torch.nn.functional import cosine_similarity
+ import gradio as gr
+
+ # ── 1) Constants & Device ────────────────────────────────────────────────
+ MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
+ MIN_FREQ = 4    # keep only focus areas with at least this many rows
+ MAX_LEN = 256   # max tokens per encoded input
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # ── 2) Load & Filter Dataset ─────────────────────────────────────────────
+ df = pd.read_csv("medquad.csv")
+ df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
+
+ vc = df["focus_area"].value_counts()
+ keep = vc[vc >= MIN_FREQ].index
+ df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
+
+ labels = sorted(df["focus_area"].unique())
+ label2id = {lbl: i for i, lbl in enumerate(labels)}
+ id2label = {i: lbl for lbl, i in label2id.items()}
+
+ # ── 3) Load Tokenizer & Frozen BERT ──────────────────────────────────────
+ tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
+ bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()
+
+ @torch.no_grad()
+ def encode_text(s: str, max_length=MAX_LEN):
+     toks = tokenizer(
+         s,
+         return_tensors="pt",
+         truncation=True,
+         max_length=max_length,
+         add_special_tokens=True
+     ).to(DEVICE)
+     hidden = bert_model(**toks).last_hidden_state
+     return hidden[:, 0].squeeze().cpu()  # [CLS] token embedding
+
+ # ── 4) Precompute Static Label Embeddings ────────────────────────────────
+ label_embs = torch.stack([
+     encode_text(lbl, max_length=16)
+     for lbl in labels
+ ])
+
+ # ── 5) Classification Function ───────────────────────────────────────────
+ def predict_disease(symptoms: str):
+     """
+     Encode the user's input, compute cosine similarity
+     to each label embedding, and return the top label.
+     """
+     q_emb = encode_text(symptoms).unsqueeze(0)   # [1, hidden_size]
+     sims = cosine_similarity(q_emb, label_embs)  # [num_labels]
+     idx = sims.argmax().item()
+     return id2label[idx]
+
+ # ── 6) Gradio Interface ───────────────────────────────────────────────────
+ iface = gr.Interface(
+     fn=predict_disease,
+     inputs=gr.Textbox(
+         lines=3,
+         placeholder="Enter your symptoms here, e.g.\n'I have eye pain and blurred vision...'"
+     ),
+     outputs="text",
+     title="🔬 Medical Symptom → Disease Chatbot",
+     description="Type your symptoms; PubMedBERT + cosine similarity predicts the most likely disease category."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
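
For a quick local sanity check outside the Gradio UI, something like the following can be run from a Python shell (a minimal sketch: it assumes medquad.csv sits next to app.py, the model weights can be downloaded, and the example symptom text is purely illustrative). The only dependencies are the packages imported above: pandas, torch, transformers, and gradio.

    # Hypothetical smoke test; importing app loads the dataset and precomputes
    # the label embeddings, while the __main__ guard keeps the UI from launching.
    from app import predict_disease
    print(predict_disease("I have eye pain and blurred vision"))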