Spaces:
Running
Running
kovacsvi
commited on
Commit
·
247282f
1
Parent(s):
73953d0
also display second most probable label in table
Browse files
app.py
CHANGED
@@ -103,9 +103,11 @@ def predict(text, model_id, tokenizer_id):
|
|
103 |
return probs
|
104 |
|
105 |
|
106 |
-
def get_most_probable_label(probs):
|
107 |
-
|
108 |
-
|
|
|
|
|
109 |
return label, probability
|
110 |
|
111 |
|
@@ -206,8 +208,9 @@ def predict_wrapper(text, language):
|
|
206 |
results_heatmap = []
|
207 |
for sentence in tqdm(sentences):
|
208 |
probs = predict(sentence, model_id, tokenizer_id)
|
209 |
-
|
210 |
-
|
|
|
211 |
results_heatmap.append({"sentence": sentence, "emotions": probs})
|
212 |
|
213 |
# let's see...
|
@@ -256,8 +259,8 @@ with gr.Blocks(css=css) as demo:
|
|
256 |
with gr.Row():
|
257 |
with gr.Column(scale=7):
|
258 |
result_table = gr.Dataframe(
|
259 |
-
headers=["Sentence", "Prediction", "Confidence"],
|
260 |
-
column_widths=["
|
261 |
wrap=True, # important
|
262 |
)
|
263 |
with gr.Column(scale=3):
|
|
|
103 |
return probs
|
104 |
|
105 |
|
106 |
+
def get_most_probable_label(probs, idx=1):
|
107 |
+
sorted_indices = probs.argsort()[::-1]
|
108 |
+
selected_idx = sorted_indices[idx - 1]
|
109 |
+
label = id2label[selected_idx]
|
110 |
+
probability = f"{round(100 * probs[selected_idx], 2)}%"
|
111 |
return label, probability
|
112 |
|
113 |
|
|
|
208 |
results_heatmap = []
|
209 |
for sentence in tqdm(sentences):
|
210 |
probs = predict(sentence, model_id, tokenizer_id)
|
211 |
+
label1, probability1 = get_most_probable_label(probs, 1)
|
212 |
+
label2, probability2 = get_most_probable_label(probs, 2)
|
213 |
+
results.append([sentence, label1, probability1, label2, probability2])
|
214 |
results_heatmap.append({"sentence": sentence, "emotions": probs})
|
215 |
|
216 |
# let's see...
|
|
|
259 |
with gr.Row():
|
260 |
with gr.Column(scale=7):
|
261 |
result_table = gr.Dataframe(
|
262 |
+
headers=["Sentence", "Prediction (1)", "Confidence (1)", "Prediction (2)", "Confidence (2)"],
|
263 |
+
column_widths=["40%", "20%", "10%", "20%", "10%"],
|
264 |
wrap=True, # important
|
265 |
)
|
266 |
with gr.Column(scale=3):
|