kovacsvi commited on
Commit
247282f
·
1 Parent(s): 73953d0

also display second most probable label in table

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -103,9 +103,11 @@ def predict(text, model_id, tokenizer_id):
103
  return probs
104
 
105
 
106
- def get_most_probable_label(probs):
107
- label = id2label[probs.argmax()]
108
- probability = f"{round(100 * probs.max(), 2)}%"
 
 
109
  return label, probability
110
 
111
 
@@ -206,8 +208,9 @@ def predict_wrapper(text, language):
206
  results_heatmap = []
207
  for sentence in tqdm(sentences):
208
  probs = predict(sentence, model_id, tokenizer_id)
209
- label, probability = get_most_probable_label(probs)
210
- results.append([sentence, label, probability])
 
211
  results_heatmap.append({"sentence": sentence, "emotions": probs})
212
 
213
  # let's see...
@@ -256,8 +259,8 @@ with gr.Blocks(css=css) as demo:
256
  with gr.Row():
257
  with gr.Column(scale=7):
258
  result_table = gr.Dataframe(
259
- headers=["Sentence", "Prediction", "Confidence"],
260
- column_widths=["65%", "25%", "10%"],
261
  wrap=True, # important
262
  )
263
  with gr.Column(scale=3):
 
103
  return probs
104
 
105
 
106
+ def get_most_probable_label(probs, idx=1):
107
+ sorted_indices = probs.argsort()[::-1]
108
+ selected_idx = sorted_indices[idx - 1]
109
+ label = id2label[selected_idx]
110
+ probability = f"{round(100 * probs[selected_idx], 2)}%"
111
  return label, probability
112
 
113
 
 
208
  results_heatmap = []
209
  for sentence in tqdm(sentences):
210
  probs = predict(sentence, model_id, tokenizer_id)
211
+ label1, probability1 = get_most_probable_label(probs, 1)
212
+ label2, probability2 = get_most_probable_label(probs, 2)
213
+ results.append([sentence, label1, probability1, label2, probability2])
214
  results_heatmap.append({"sentence": sentence, "emotions": probs})
215
 
216
  # let's see...
 
259
  with gr.Row():
260
  with gr.Column(scale=7):
261
  result_table = gr.Dataframe(
262
+ headers=["Sentence", "Prediction (1)", "Confidence (1)", "Prediction (2)", "Confidence (2)"],
263
+ column_widths=["40%", "20%", "10%", "20%", "10%"],
264
  wrap=True, # important
265
  )
266
  with gr.Column(scale=3):