Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import pandas as pd
|
2 |
import torch
|
3 |
import transformers
|
@@ -13,7 +14,6 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
13 |
# ── 2) Load & Filter Dataset ─────────────────────────────────────────────
|
14 |
df = pd.read_csv("medquad.csv")
|
15 |
df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
|
16 |
-
|
17 |
vc = df["focus_area"].value_counts()
|
18 |
keep = vc[vc >= MIN_FREQ].index
|
19 |
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
|
@@ -36,7 +36,7 @@ def encode_text(s: str, max_length=MAX_LEN):
|
|
36 |
add_special_tokens=True
|
37 |
).to(DEVICE)
|
38 |
hidden = bert_model(**toks).last_hidden_state
|
39 |
-
return hidden[:,0].squeeze().cpu()
|
40 |
|
41 |
# ── 4) Precompute Static Label Embeddings ─────────────────────────────────
|
42 |
label_embs = torch.stack([
|
@@ -44,28 +44,29 @@ label_embs = torch.stack([
|
|
44 |
for lbl in labels
|
45 |
])
|
46 |
|
47 |
-
# ── 5) Classification
|
48 |
def predict_disease(symptoms: str):
|
49 |
-
|
50 |
-
|
51 |
-
to each label embedding, and return the top label.
|
52 |
-
"""
|
53 |
-
q_emb = encode_text(symptoms).unsqueeze(0) # [1, hidden_size]
|
54 |
-
sims = cosine_similarity(q_emb, label_embs) # [1, num_labels]
|
55 |
idx = sims.argmax(dim=1).item()
|
56 |
return id2label[idx]
|
57 |
|
58 |
# ── 6) Gradio Interface ───────────────────────────────────────────────────
|
59 |
-
|
|
|
60 |
fn=predict_disease,
|
61 |
inputs=gr.Textbox(
|
62 |
lines=3,
|
63 |
-
placeholder="Enter your symptoms here
|
64 |
),
|
65 |
outputs="text",
|
66 |
-
title="π¬
|
67 |
-
description="
|
68 |
)
|
69 |
|
70 |
-
|
71 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
import pandas as pd
|
3 |
import torch
|
4 |
import transformers
|
|
|
14 |
# ── 2) Load & Filter Dataset ─────────────────────────────────────────────
|
15 |
df = pd.read_csv("medquad.csv")
|
16 |
df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
|
|
|
17 |
vc = df["focus_area"].value_counts()
|
18 |
keep = vc[vc >= MIN_FREQ].index
|
19 |
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
|
|
|
36 |
add_special_tokens=True
|
37 |
).to(DEVICE)
|
38 |
hidden = bert_model(**toks).last_hidden_state
|
39 |
+
return hidden[:,0].squeeze().cpu()
|
40 |
|
41 |
# ── 4) Precompute Static Label Embeddings ─────────────────────────────────
|
42 |
label_embs = torch.stack([
|
|
|
44 |
for lbl in labels
|
45 |
])
|
46 |
|
47 |
+
# ── 5) Classification fn ───────────────────────────────────────────────────
|
48 |
def predict_disease(symptoms: str):
|
49 |
+
q_emb = encode_text(symptoms).unsqueeze(0)
|
50 |
+
sims = cosine_similarity(q_emb, label_embs)
|
|
|
|
|
|
|
|
|
51 |
idx = sims.argmax(dim=1).item()
|
52 |
return id2label[idx]
|
53 |
|
54 |
# ── 6) Gradio Interface ───────────────────────────────────────────────────
|
55 |
+
# Rename to "app" so Hugging Face picks it up automatically:
|
56 |
+
app = gr.Interface(
|
57 |
fn=predict_disease,
|
58 |
inputs=gr.Textbox(
|
59 |
lines=3,
|
60 |
+
placeholder="Enter your symptoms here…"
|
61 |
),
|
62 |
outputs="text",
|
63 |
+
title="💬 Symptom→Disease Chatbot",
|
64 |
+
description="PubMed-BERT + cosine similarity"
|
65 |
)
|
66 |
|
67 |
+
# Explicitly launch on HF's host/port
|
68 |
+
app.launch(
|
69 |
+
server_name="0.0.0.0",
|
70 |
+
server_port=int(os.environ.get("PORT", 7860)),
|
71 |
+
share=False
|
72 |
+
)
|