kovacsvi committed on
Commit
9211a01
·
1 Parent(s): 1986c88

label names, num dicts

Browse files
Files changed (1) hide show
  1. interfaces/cap_minor.py +8 -15
interfaces/cap_minor.py CHANGED
@@ -129,7 +129,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
129
  filtered_probs = normalize_probs(filtered_probs)
130
 
131
  output_pred = {
132
- f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
133
  for k, v in sorted(
134
  filtered_probs.items(), key=lambda item: item[1], reverse=True
135
  )
@@ -161,7 +161,7 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
161
  print(filtered_probs) # debug
162
 
163
  output_pred = {
164
- f"[{top_major_id}] {CAP_MEDIA_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
165
  for k, v in sorted(
166
  filtered_probs.items(), key=lambda item: item[1], reverse=True
167
  )
@@ -205,23 +205,16 @@ def predict_flat(text, model_id, tokenizer_id, HF_TOKEN=None):
205
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
206
  top_indices = np.argsort(probs)[::-1][:10]
207
 
208
- CAP_MIN_MEDIA_LABEL_NAMES = CAP_MEDIA_LABEL_NAMES | CAP_MIN_LABEL_NAMES
209
-
210
  output_pred = {}
211
  for i in top_indices:
212
- code = CAP_MIN_MEDIA_NUM_DICT[i]
213
  prob = probs[i]
214
 
215
- if code in CAP_MEDIA_LABEL_NAMES:
216
- # Media (major) topic
217
- label = CAP_MEDIA_LABEL_NAMES[code]
218
- display = f"[{code}] {label}"
219
- else:
220
- # Minor topic
221
- major_code = code // 100
222
- major_label = CAP_MEDIA_LABEL_NAMES[major_code]
223
- minor_label = CAP_MIN_LABEL_NAMES[code]
224
- display = f"[{major_code}] {major_label} [{code}] {minor_label}"
225
 
226
  output_pred[display] = prob
227
 
 
129
  filtered_probs = normalize_probs(filtered_probs)
130
 
131
  output_pred = {
132
+ f"[{major_index_to_id[k]}] {CAP_LABEL_NAMES[major_index_to_id[k]]}": v
133
  for k, v in sorted(
134
  filtered_probs.items(), key=lambda item: item[1], reverse=True
135
  )
 
161
  print(filtered_probs) # debug
162
 
163
  output_pred = {
164
+ f"[{top_major_id}] {CAP_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
165
  for k, v in sorted(
166
  filtered_probs.items(), key=lambda item: item[1], reverse=True
167
  )
 
205
  probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
206
  top_indices = np.argsort(probs)[::-1][:10]
207
 
 
 
208
  output_pred = {}
209
  for i in top_indices:
210
+ code = CAP_MIN_NUM_DICT[i]
211
  prob = probs[i]
212
 
213
+ # Minor topic
214
+ major_code = code // 100
215
+ major_label = CAP_LABEL_NAMES[major_code]
216
+ minor_label = CAP_MIN_LABEL_NAMES[code]
217
+ display = f"[{major_code}] {major_label} [{code}] {minor_label}"
 
 
 
 
 
218
 
219
  output_pred[display] = prob
220