atharvasc27112001 commited on
Commit
e1e0f62
·
verified ·
1 Parent(s): fdd7edb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -52
app.py CHANGED
@@ -1,70 +1,89 @@
1
  import os
 
2
  import pandas as pd
 
3
  import torch
4
  import transformers
5
- from torch.nn.functional import cosine_similarity
6
  import gradio as gr
 
 
7
 
8
- # ── 1) Constants & Device ────────────────────────────────────────────────
9
- MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
10
- MIN_FREQ = 4
11
- MAX_LEN = 256
12
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
13
 
14
- # ── 2) Load & Filter Dataset ─────────────────────────────────────────────
15
  df = pd.read_csv("medquad.csv")
16
- df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- vc = df["focus_area"].value_counts()
 
19
  keep = vc[vc >= MIN_FREQ].index
20
- df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  labels = sorted(df["focus_area"].unique())
23
  label2id = {lbl:i for i,lbl in enumerate(labels)}
24
- id2label = {i:l for l,i in label2id.items()}
25
-
26
- # ── 3) Load Tokenizer & Frozen BERT ─────────────────────────────────────
27
- tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
28
- bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()
29
 
30
  @torch.no_grad()
31
- def encode_text(s: str, max_length=MAX_LEN):
32
- toks = tokenizer(
33
- s,
34
- return_tensors="pt",
35
- truncation=True,
36
- max_length=max_length,
37
- padding=False,
38
- ).to(DEVICE)
39
- hidden = bert_model(**toks).last_hidden_state
40
- return hidden[:,0].squeeze().cpu()
41
 
42
- # ── 4) Precompute Static Label Embeddings ─────────────────────────────────
43
- label_embs = torch.stack([encode_text(lbl, max_length=16) for lbl in labels])
 
 
 
44
 
45
- # ── 5) Classification Function ────────────────────────────────────────────
46
- def predict_disease(symptoms: str):
47
- if not symptoms.strip():
 
48
  return "❗️ Please enter your symptoms."
49
  try:
50
- # 1) embed user text β†’ [hidden_size]
51
- q_emb = encode_text(symptoms)
52
-
53
- # 2) compute cosine similarities β†’ [num_labels]
54
- sims = cosine_similarity(
55
- label_embs, # [num_labels, hidden_size]
56
- q_emb.unsqueeze(0), # [1, hidden_size]
57
- dim=1
58
- )
59
-
60
- # 3) pick the best label index
61
- best = sims.argmax().item()
62
- return id2label[best]
63
-
64
  except Exception as e:
65
  return f"Error: {e}"
66
 
67
- # ── 6) Gradio Interface ───────────────────────────────────────────────────
68
  app = gr.Interface(
69
  fn=predict_disease,
70
  inputs=gr.Textbox(
@@ -73,12 +92,10 @@ app = gr.Interface(
73
  ),
74
  outputs="text",
75
  title="πŸ”¬ Symptomβ†’Disease Chatbot",
76
- description="PubMed-BERT + cosine similarity"
77
  )
78
 
79
- # ── 7) Launch ─────────────────────────────────────────────────────────────
80
- app.launch(
81
- server_name="0.0.0.0",
82
- server_port=int(os.environ.get("PORT", 7860)),
83
- share=False
84
- )
 
1
  import os
2
+ import re, random, hashlib
3
  import pandas as pd
4
+ import numpy as np
5
  import torch
6
  import transformers
 
7
  import gradio as gr
8
+ from torch import nn
9
+ from torch.nn.functional import cosine_similarity
10
 
11
+ # ── Configuration ────────────────────────────────────────────────────────────
12
+ MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"
13
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
+ MIN_FREQ = 4
15
+ MAX_LEN = 256
16
+ VERBALIZE_LABEL = True
17
 
18
+ # ── 1) Load & Clean Data ─────────────────────────────────────────────────────
19
  df = pd.read_csv("medquad.csv")
20
+ # build text field
21
+ df["text"] = df["question"].fillna("").str.strip() + " " + df["answer"].fillna("").str.strip()
22
# "text" is built from fillna("")-ed columns, so it is never NaN and dropna()
# could not remove anything; drop rows whose combined text is blank instead.
df = df[df["text"].str.strip() != ""].reset_index(drop=True)
23
+
24
+ # normalize hyphens/spaces in both text and labels
25
+ dash_pat = r"[-‐-–—]"
26
+ df["text"] = df["text"].str.replace(dash_pat, " ", regex=True)
27
+ df["focus_area"] = (
28
+ df["focus_area"]
29
+ .fillna("")
30
+ .astype(str)
31
+ .str.replace(dash_pat, " ", regex=True)
32
+ .str.lower()
33
+ .str.replace(r"\s+", " ", regex=True)
34
+ .str.strip()
35
+ )
36
 
37
+ # prune rare labels
38
# Count rows per normalised focus area so sparsely represented labels can be dropped.
vc = df["focus_area"].value_counts()
# Keep labels seen at least MIN_FREQ times. The empty string is what a missing
# focus_area becomes after fillna(""), so it is never accepted as a class.
keep = vc[(vc >= MIN_FREQ) & (vc.index != "")].index
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
41
+
42
+ # ── 2) Tokenizer & Frozen BERT ───────────────────────────────────────────────
43
+ tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
44
+ bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME) \
45
+ .to(DEVICE).eval()
46
+
47
+ # ── 3) Label ↔ ID maps & label embeddings ────────────────────────────────────
48
def verbalise(lbl: str) -> str:
    """Turn a label into the text that gets embedded for it.

    When the module flag VERBALIZE_LABEL is truthy the label is wrapped in a
    fixed sentence template; otherwise the label is returned unchanged.
    """
    if not VERBALIZE_LABEL:
        return lbl
    return f"This question is about the medical focus area of {lbl}."
52
 
53
  labels = sorted(df["focus_area"].unique())
54
  label2id = {lbl:i for i,lbl in enumerate(labels)}
55
+ id2label = {i:lbl for lbl,i in label2id.items()}
 
 
 
 
56
 
57
  @torch.no_grad()
58
def encode_text(s: str, max_length=MAX_LEN) -> torch.Tensor:
    """Embed a single string with the module-level tokenizer and BERT model.

    Args:
        s: Text to embed.
        max_length: Token limit passed to the tokenizer; longer input is
            truncated.

    Returns:
        The hidden state at the first token position (the CLS slot), with
        size-1 dimensions squeezed away and moved to the CPU.
    """
    encoded = tokenizer(
        s,
        padding=False,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = encoded.to(DEVICE)
    hidden_states = bert_model(**encoded).last_hidden_state
    cls_vector = hidden_states[:, 0]  # first token position of each sequence
    return cls_vector.squeeze().cpu()
 
 
 
 
64
 
65
+ # precompute one vector per label
66
+ label_embs = torch.stack([
67
+ encode_text(verbalise(lbl), max_length=32)
68
+ for lbl in labels
69
+ ])
70
 
71
+ # ── 4) Prediction function ──────────────────────────────────────────────────
72
def predict_disease(symptoms: str) -> str:
    """Return the focus-area label whose embedding is closest to the input.

    Args:
        symptoms: Free-text description typed by the user. ``None`` is
            treated the same as an empty string.

    Returns:
        The best-matching entry of ``labels``; a prompt asking for input when
        ``symptoms`` is missing or blank; or ``"Error: ..."`` if embedding or
        scoring raises.
    """
    # Normalise before the try block: a None value would otherwise raise
    # AttributeError from .strip() and escape the error handling below.
    symptoms = (symptoms or "").strip()
    if not symptoms:
        return "❗️ Please enter your symptoms."
    try:
        # Embed the user input as a single-row batch: [1, hidden].
        q_emb = encode_text(symptoms).unsqueeze(0)
        # Cosine similarity against every precomputed label vector: [num_labels].
        sims = cosine_similarity(label_embs, q_emb, dim=1)
        idx = sims.argmax().item()
        return labels[idx]
    except Exception as e:
        # UI boundary: report the failure as text instead of raising into Gradio.
        return f"Error: {e}"
85
 
86
+ # ── 5) Gradio App ───────────────────────────────────────────────────────────
87
  app = gr.Interface(
88
  fn=predict_disease,
89
  inputs=gr.Textbox(
 
92
  ),
93
  outputs="text",
94
  title="πŸ”¬ Symptomβ†’Disease Chatbot",
95
+ description="PubMed-BERT frozen embeddings + cosine similarity"
96
  )
97
 
98
+ if __name__ == "__main__":
99
+ app.launch(server_name="0.0.0.0",
100
+ server_port=int(os.environ.get("PORT", 7860)),
101
+ share=False)