atharvasc27112001 commited on
Commit
c3c7cca
Β·
verified Β·
1 Parent(s): d1af83e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -15
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import pandas as pd
2
  import torch
3
  import transformers
@@ -13,7 +14,6 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
13
  # ── 2) Load & Filter Dataset ─────────────────────────────────────────────
14
  df = pd.read_csv("medquad.csv")
15
  df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
16
-
17
  vc = df["focus_area"].value_counts()
18
  keep = vc[vc >= MIN_FREQ].index
19
  df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
@@ -36,7 +36,7 @@ def encode_text(s: str, max_length=MAX_LEN):
36
  add_special_tokens=True
37
  ).to(DEVICE)
38
  hidden = bert_model(**toks).last_hidden_state
39
- return hidden[:,0].squeeze().cpu() # CLS token
40
 
41
  # ── 4) Precompute Static Label Embeddings ─────────────────────────────────
42
  label_embs = torch.stack([
@@ -44,28 +44,29 @@ label_embs = torch.stack([
44
  for lbl in labels
45
  ])
46
 
47
- # ── 5) Classification Function ────────────────────────────────────────────
48
  def predict_disease(symptoms: str):
49
- """
50
- Encode the user's input, compute cosine similarity
51
- to each label embedding, and return the top label.
52
- """
53
- q_emb = encode_text(symptoms).unsqueeze(0) # [1, hidden_size]
54
- sims = cosine_similarity(q_emb, label_embs) # [1, num_labels]
55
  idx = sims.argmax(dim=1).item()
56
  return id2label[idx]
57
 
58
  # ── 6) Gradio Interface ───────────────────────────────────────────────────
59
- iface = gr.Interface(
 
60
  fn=predict_disease,
61
  inputs=gr.Textbox(
62
  lines=3,
63
- placeholder="Enter your symptoms here, e.g.\n'I have eye pain and blurred vision...'"
64
  ),
65
  outputs="text",
66
- title="πŸ”¬ Medical Symptomβ†’Disease Chatbot",
67
- description="Type your symptoms; PubMed‐BERT + cosine similarity predicts the most likely disease category."
68
  )
69
 
70
- if __name__ == "__main__":
71
- iface.launch()
 
 
 
 
 
1
+ import os
2
  import pandas as pd
3
  import torch
4
  import transformers
 
14
  # ── 2) Load & Filter Dataset ─────────────────────────────────────────────
15
  df = pd.read_csv("medquad.csv")
16
  df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
 
17
  vc = df["focus_area"].value_counts()
18
  keep = vc[vc >= MIN_FREQ].index
19
  df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
 
36
  add_special_tokens=True
37
  ).to(DEVICE)
38
  hidden = bert_model(**toks).last_hidden_state
39
+ return hidden[:,0].squeeze().cpu()
40
 
41
  # ── 4) Precompute Static Label Embeddings ─────────────────────────────────
42
  label_embs = torch.stack([
 
44
  for lbl in labels
45
  ])
46
 
47
+ # ── 5) Classification fn ───────────────────────────────────────────────────
48
  def predict_disease(symptoms: str):
49
+ q_emb = encode_text(symptoms).unsqueeze(0)
50
+ sims = cosine_similarity(q_emb, label_embs)
 
 
 
 
51
  idx = sims.argmax(dim=1).item()
52
  return id2label[idx]
53
 
54
  # ── 6) Gradio Interface ───────────────────────────────────────────────────
55
+ # Rename to "app" so Hugging Face picks it up automatically:
56
+ app = gr.Interface(
57
  fn=predict_disease,
58
  inputs=gr.Textbox(
59
  lines=3,
60
+ placeholder="Enter your symptoms here…"
61
  ),
62
  outputs="text",
63
+ title="πŸ”¬ Symptomβ†’Disease Chatbot",
64
+ description="PubMed-BERT + cosine similarity"
65
  )
66
 
67
+ # Explicitly launch on HF’s host/port
68
+ app.launch(
69
+ server_name="0.0.0.0",
70
+ server_port=int(os.environ.get("PORT", 7860)),
71
+ share=False
72
+ )