Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -14,6 +14,7 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
14 |
# ── 2) Load & Filter Dataset ──────────────────────────────────────────────
|
15 |
df = pd.read_csv("medquad.csv")
|
16 |
df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
|
|
|
17 |
vc = df["focus_area"].value_counts()
|
18 |
keep = vc[vc >= MIN_FREQ].index
|
19 |
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
|
@@ -28,31 +29,42 @@ bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval(
|
|
28 |
|
29 |
@torch.no_grad()
|
30 |
def encode_text(s: str, max_length=MAX_LEN):
|
31 |
-
toks
|
32 |
s,
|
33 |
return_tensors="pt",
|
34 |
truncation=True,
|
35 |
max_length=max_length,
|
36 |
-
|
37 |
).to(DEVICE)
|
38 |
hidden = bert_model(**toks).last_hidden_state
|
39 |
return hidden[:,0].squeeze().cpu()
|
40 |
|
41 |
# ── 4) Precompute Static Label Embeddings ─────────────────────────────────
|
42 |
-
label_embs = torch.stack([
|
43 |
-
encode_text(lbl, max_length=16)
|
44 |
-
for lbl in labels
|
45 |
-
])
|
46 |
|
47 |
-
# ── 5) Classification
|
48 |
def predict_disease(symptoms: str):
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
# ── 6) Gradio Interface ───────────────────────────────────────────────────
|
55 |
-
# Rename to "app" so Hugging Face picks it up automatically:
|
56 |
app = gr.Interface(
|
57 |
fn=predict_disease,
|
58 |
inputs=gr.Textbox(
|
@@ -64,7 +76,7 @@ app = gr.Interface(
|
|
64 |
description="PubMed-BERT + cosine similarity"
|
65 |
)
|
66 |
|
67 |
-
#
|
68 |
app.launch(
|
69 |
server_name="0.0.0.0",
|
70 |
server_port=int(os.environ.get("PORT", 7860)),
|
|
|
14 |
# ── 2) Load & Filter Dataset ──────────────────────────────────────────────
|
15 |
df = pd.read_csv("medquad.csv")
|
16 |
df["text"] = df["question"].str.strip() + " " + df["answer"].str.strip()
|
17 |
+
|
18 |
vc = df["focus_area"].value_counts()
|
19 |
keep = vc[vc >= MIN_FREQ].index
|
20 |
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
|
|
|
29 |
|
30 |
@torch.no_grad()
|
31 |
def encode_text(s: str, max_length=MAX_LEN):
|
32 |
+
toks = tokenizer(
|
33 |
s,
|
34 |
return_tensors="pt",
|
35 |
truncation=True,
|
36 |
max_length=max_length,
|
37 |
+
padding=False,
|
38 |
).to(DEVICE)
|
39 |
hidden = bert_model(**toks).last_hidden_state
|
40 |
return hidden[:,0].squeeze().cpu()
|
41 |
|
42 |
# ── 4) Precompute Static Label Embeddings ─────────────────────────────────
|
43 |
+
label_embs = torch.stack([encode_text(lbl, max_length=16) for lbl in labels])
|
|
|
|
|
|
|
44 |
|
45 |
+
# ── 5) Classification Function ────────────────────────────────────────────
|
46 |
def predict_disease(symptoms: str):
|
47 |
+
if not symptoms.strip():
|
48 |
+
return "βοΈ Please enter your symptoms."
|
49 |
+
try:
|
50 |
+
# 1) embed user text → [hidden_size]
|
51 |
+
q_emb = encode_text(symptoms)
|
52 |
+
|
53 |
+
# 2) compute cosine similarities → [num_labels]
|
54 |
+
sims = cosine_similarity(
|
55 |
+
label_embs, # [num_labels, hidden_size]
|
56 |
+
q_emb.unsqueeze(0), # [1, hidden_size]
|
57 |
+
dim=1
|
58 |
+
)
|
59 |
+
|
60 |
+
# 3) pick the best label index
|
61 |
+
best = sims.argmax().item()
|
62 |
+
return id2label[best]
|
63 |
+
|
64 |
+
except Exception as e:
|
65 |
+
return f"Error: {e}"
|
66 |
|
67 |
# ── 6) Gradio Interface ───────────────────────────────────────────────────
|
|
|
68 |
app = gr.Interface(
|
69 |
fn=predict_disease,
|
70 |
inputs=gr.Textbox(
|
|
|
76 |
description="PubMed-BERT + cosine similarity"
|
77 |
)
|
78 |
|
79 |
+
# ── 7) Launch ─────────────────────────────────────────────────────────────
|
80 |
app.launch(
|
81 |
server_name="0.0.0.0",
|
82 |
server_port=int(os.environ.get("PORT", 7860)),
|