Update app.py
Browse files
app.py
CHANGED
@@ -1,70 +1,89 @@
|
|
1 |
import os
|
|
|
2 |
import pandas as pd
|
|
|
3 |
import torch
|
4 |
import transformers
|
5 |
-
from torch.nn.functional import cosine_similarity
|
6 |
import gradio as gr
|
|
|
|
|
7 |
|
8 |
-
# ββ
|
9 |
-
MODEL_NAME
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
13 |
|
14 |
-
# ββ
|
15 |
df = pd.read_csv("medquad.csv")
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
|
|
19 |
keep = vc[vc >= MIN_FREQ].index
|
20 |
-
df
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
labels = sorted(df["focus_area"].unique())
|
23 |
label2id = {lbl:i for i,lbl in enumerate(labels)}
|
24 |
-
id2label = {i:
|
25 |
-
|
26 |
-
# ββ 3) Load Tokenizer & Frozen BERT βββββββββββββββββββββββββββββββββββββ
|
27 |
-
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
|
28 |
-
bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()
|
29 |
|
30 |
@torch.no_grad()
|
31 |
-
def encode_text(s: str, max_length=MAX_LEN):
|
32 |
-
toks
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
padding=False,
|
38 |
-
).to(DEVICE)
|
39 |
-
hidden = bert_model(**toks).last_hidden_state
|
40 |
-
return hidden[:,0].squeeze().cpu()
|
41 |
|
42 |
-
#
|
43 |
-
label_embs = torch.stack([
|
|
|
|
|
|
|
44 |
|
45 |
-
# ββ
|
46 |
-
def predict_disease(symptoms: str):
|
47 |
-
|
|
|
48 |
return "βοΈ Please enter your symptoms."
|
49 |
try:
|
50 |
-
#
|
51 |
-
q_emb = encode_text(symptoms)
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
q_emb.unsqueeze(0), # [1, hidden_size]
|
57 |
-
dim=1
|
58 |
-
)
|
59 |
-
|
60 |
-
# 3) pick the best label index
|
61 |
-
best = sims.argmax().item()
|
62 |
-
return id2label[best]
|
63 |
-
|
64 |
except Exception as e:
|
65 |
return f"Error: {e}"
|
66 |
|
67 |
-
# ββ
|
68 |
app = gr.Interface(
|
69 |
fn=predict_disease,
|
70 |
inputs=gr.Textbox(
|
@@ -73,12 +92,10 @@ app = gr.Interface(
|
|
73 |
),
|
74 |
outputs="text",
|
75 |
title="π¬ SymptomβDisease Chatbot",
|
76 |
-
description="PubMed-BERT + cosine similarity"
|
77 |
)
|
78 |
|
79 |
-
|
80 |
-
app.launch(
|
81 |
-
|
82 |
-
|
83 |
-
share=False
|
84 |
-
)
|
|
|
1 |
import os
|
2 |
+
import re, random, hashlib
|
3 |
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
import torch
|
6 |
import transformers
|
|
|
7 |
import gradio as gr
|
8 |
+
from torch import nn
|
9 |
+
from torch.nn.functional import cosine_similarity
|
10 |
|
11 |
+
# ── Configuration ────────────────────────────────────────────────────────────
# Hugging Face checkpoint loaded for both the tokenizer and the encoder.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Focus areas that occur fewer than this many times are pruned from the data.
MIN_FREQ = 4
# Default token budget for encode_text(); longer inputs are truncated.
MAX_LEN = 256
# When True, verbalise() wraps every label in a full sentence before encoding.
VERBALIZE_LABEL = True
|
17 |
|
18 |
+
# ── 1) Load & Clean Data ─────────────────────────────────────────────────────
df = pd.read_csv("medquad.csv")

# ASCII hyphen-minus, the Unicode hyphen/dash range U+2010..U+2015 and the
# minus sign U+2212.  Written with escapes so the pattern cannot be damaged by
# a wrong file encoding.
# NOTE(review): the previous literal was encoding-garbled; this character set
# is a reconstruction - confirm it covers the dashes present in the data.
dash_pat = r"[-\u2010-\u2015\u2212]"

# Build the text field: question and answer joined by a single space, dashes
# turned into spaces, outer whitespace removed.
df["text"] = (
    (
        df["question"].fillna("").str.strip()
        + " "
        + df["answer"].fillna("").str.strip()
    )
    .str.replace(dash_pat, " ", regex=True)
    .str.strip()
)

# Normalise labels the same way: dashes -> spaces, lower-case, single spaces.
df["focus_area"] = (
    df["focus_area"]
    .fillna("")
    .astype(str)
    .str.replace(dash_pat, " ", regex=True)
    .str.lower()
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
)

# The fillna("") calls above mean neither column can be NaN any more, so a
# dropna() would never remove anything; filter on emptiness instead.  This
# also keeps the empty string from becoming a label.
df = df[(df["text"] != "") & (df["focus_area"] != "")]

# Prune rare labels: keep only focus areas seen at least MIN_FREQ times.
vc = df["focus_area"].value_counts()
keep = vc[vc >= MIN_FREQ].index
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
|
41 |
+
|
42 |
+
# ── 2) Tokenizer & Frozen BERT ───────────────────────────────────────────────
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME)
# Move the encoder to the selected device and switch it to evaluation mode;
# it is only ever called under torch.no_grad() in encode_text().
bert_model = bert_model.to(DEVICE).eval()
|
46 |
+
|
47 |
+
# ── 3) Label ↔ ID maps & label embeddings ────────────────────────────────────
def verbalise(lbl: str) -> str:
    """Return the text that is embedded to represent the label *lbl*.

    With ``VERBALIZE_LABEL`` enabled the label is wrapped in a full sentence;
    otherwise the bare label is returned unchanged.
    """
    if not VERBALIZE_LABEL:
        return lbl
    return f"This question is about the medical focus area of {lbl}."
|
52 |
|
53 |
# Alphabetically ordered list of the focus areas that survived pruning; a
# label's position in this list is its integer id.
labels = sorted(df["focus_area"].unique())
id2label = dict(enumerate(labels))
label2id = {lbl: i for i, lbl in id2label.items()}
|
|
|
|
|
|
|
|
|
56 |
|
57 |
@torch.no_grad()
def encode_text(s: str, max_length=MAX_LEN) -> torch.Tensor:
    """Embed one string with the BERT encoder and return its [CLS] vector.

    Args:
        s: Text to encode.
        max_length: Token limit passed to the tokenizer; longer inputs are
            truncated.

    Returns:
        The first-token hidden state with size-1 dimensions squeezed away,
        moved to the CPU.
    """
    encoded = tokenizer(
        s,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    encoded = encoded.to(DEVICE)
    hidden_states = bert_model(**encoded).last_hidden_state
    cls_vec = hidden_states[:, 0]  # CLS
    return cls_vec.squeeze().cpu()
|
|
|
|
|
|
|
|
|
64 |
|
65 |
+
# Precompute one embedding per label at import time so that a prediction only
# has to encode the user's query.  Row i corresponds to labels[i].
label_embs = torch.stack(
    [encode_text(text, max_length=32) for text in map(verbalise, labels)]
)
|
70 |
|
71 |
+
# ── 4) Prediction function ───────────────────────────────────────────────────
def predict_disease(symptoms: str) -> str:
    """Return the focus-area label whose embedding is closest to *symptoms*.

    Args:
        symptoms: Free-text description entered by the user.  ``None`` is
            treated like an empty string.

    Returns:
        The best-matching label, a prompt asking for input when *symptoms* is
        blank, or an ``"Error: ..."`` message if encoding or scoring fails.
    """
    # The strip happens outside the try block, so a None input would otherwise
    # escape as an unhandled AttributeError instead of yielding the prompt.
    symptoms = (symptoms or "").strip()
    if not symptoms:
        # NOTE(review): the leading characters of this message look like a
        # mis-decoded emoji; kept verbatim - confirm the intended glyph.
        return "βοΈ Please enter your symptoms."
    try:
        # embed user input
        q_emb = encode_text(symptoms).unsqueeze(0)  # [1, hidden]
        # cosine with each label embedding
        sims = cosine_similarity(label_embs, q_emb, dim=1)  # [num_labels]
        return labels[sims.argmax().item()]
    except Exception as e:
        # UI boundary: report the failure to the user rather than crash.
        return f"Error: {e}"
|
85 |
|
86 |
+
# ── 5) Gradio App ────────────────────────────────────────────────────────────
|
87 |
app = gr.Interface(
|
88 |
fn=predict_disease,
|
89 |
inputs=gr.Textbox(
|
|
|
92 |
),
|
93 |
outputs="text",
|
94 |
title="π¬ SymptomβDisease Chatbot",
|
95 |
+
description="PubMed-BERT frozen embeddings + cosine similarity"
|
96 |
)
|
97 |
|
98 |
+
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT when it is set and
    # defaults to 7860 otherwise.
    port = int(os.environ.get("PORT", 7860))
    app.launch(server_name="0.0.0.0", server_port=port, share=False)
|
|
|
|