atharvasc27112001 committed on
Commit
14b56fb
·
verified ·
1 Parent(s): c3c7cca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -13
app.py CHANGED
@@ -14,6 +14,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  # ── 2) Load & Filter Dataset ─────────────────────────────────────────────
15
  df = pd.read_csv("medquad.csv")
16
  df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
 
17
  vc = df["focus_area"].value_counts()
18
  keep = vc[vc >= MIN_FREQ].index
19
  df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
@@ -28,31 +29,42 @@ bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval(
28
 
29
  @torch.no_grad()
30
  def encode_text(s: str, max_length=MAX_LEN):
31
- toks = tokenizer(
32
  s,
33
  return_tensors="pt",
34
  truncation=True,
35
  max_length=max_length,
36
- add_special_tokens=True
37
  ).to(DEVICE)
38
  hidden = bert_model(**toks).last_hidden_state
39
  return hidden[:,0].squeeze().cpu()
40
 
41
  # ── 4) Precompute Static Label Embeddings ─────────────────────────────────
42
- label_embs = torch.stack([
43
- encode_text(lbl, max_length=16)
44
- for lbl in labels
45
- ])
46
 
47
- # ── 5) Classification fn ───────────────────────────────────────────────────
48
  def predict_disease(symptoms: str):
49
- q_emb = encode_text(symptoms).unsqueeze(0)
50
- sims = cosine_similarity(q_emb, label_embs)
51
- idx = sims.argmax(dim=1).item()
52
- return id2label[idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # ── 6) Gradio Interface ───────────────────────────────────────────────────
55
- # Rename to "app" so Hugging Face picks it up automatically:
56
  app = gr.Interface(
57
  fn=predict_disease,
58
  inputs=gr.Textbox(
@@ -64,7 +76,7 @@ app = gr.Interface(
64
  description="PubMed-BERT + cosine similarity"
65
  )
66
 
67
- # Explicitly launch on HF’s host/port
68
  app.launch(
69
  server_name="0.0.0.0",
70
  server_port=int(os.environ.get("PORT", 7860)),
 
14
  # ── 2) Load & Filter Dataset ─────────────────────────────────────────────
15
  df = pd.read_csv("medquad.csv")
16
  df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
17
+
18
  vc = df["focus_area"].value_counts()
19
  keep = vc[vc >= MIN_FREQ].index
20
  df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
 
29
 
30
  @torch.no_grad()
31
  def encode_text(s: str, max_length=MAX_LEN):
32
+ toks = tokenizer(
33
  s,
34
  return_tensors="pt",
35
  truncation=True,
36
  max_length=max_length,
37
+ padding=False,
38
  ).to(DEVICE)
39
  hidden = bert_model(**toks).last_hidden_state
40
  return hidden[:,0].squeeze().cpu()
41
 
42
  # ── 4) Precompute Static Label Embeddings ─────────────────────────────────
43
+ label_embs = torch.stack([encode_text(lbl, max_length=16) for lbl in labels])
 
 
 
44
 
45
+ # ── 5) Classification Function ────────────────────────────────────────────
46
  def predict_disease(symptoms: str):
47
+ if not symptoms.strip():
48
+ return "❗️ Please enter your symptoms."
49
+ try:
50
+ # 1) embed user text → [hidden_size]
51
+ q_emb = encode_text(symptoms)
52
+
53
+ # 2) compute cosine similarities → [num_labels]
54
+ sims = cosine_similarity(
55
+ label_embs, # [num_labels, hidden_size]
56
+ q_emb.unsqueeze(0), # [1, hidden_size]
57
+ dim=1
58
+ )
59
+
60
+ # 3) pick the best label index
61
+ best = sims.argmax().item()
62
+ return id2label[best]
63
+
64
+ except Exception as e:
65
+ return f"Error: {e}"
66
 
67
  # ── 6) Gradio Interface ───────────────────────────────────────────────────
 
68
  app = gr.Interface(
69
  fn=predict_disease,
70
  inputs=gr.Textbox(
 
76
  description="PubMed-BERT + cosine similarity"
77
  )
78
 
79
+ # ── 7) Launch ─────────────────────────────────────────────────────────────
80
  app.launch(
81
  server_name="0.0.0.0",
82
  server_port=int(os.environ.get("PORT", 7860)),