kovacsvi committed
Commit 8cbc9a3 · 1 Parent(s): c14e676

cap minor hierarchical

Files changed (1):
  interfaces/cap_minor.py  +167 -15
interfaces/cap_minor.py CHANGED
@@ -6,8 +6,8 @@ import numpy as np
 import pandas as pd
 from transformers import AutoModelForSequenceClassification
 from transformers import AutoTokenizer
+import torch.nn.functional as F
 from huggingface_hub import HfApi
-
 from collections import defaultdict

 from label_dicts import (
@@ -41,12 +41,23 @@ domains = {
 }


-def get_label_name(idx):
-    minor_code = CAP_MIN_NUM_DICT[idx]
-    minor_label_name = CAP_MIN_LABEL_NAMES[minor_code]
-    major_code = minor_code // 100 if minor_code not in [99, 999, 9999] else 999
-    major_label_name = CAP_LABEL_NAMES[major_code]
-    return f"[{major_code}] {major_label_name} [{minor_code}] {minor_label_name}"
+CAP_MEDIA_CODES = list(CAP_MEDIA_NUM_DICT.values())
+CAP_MIN_CODES = list(CAP_MIN_NUM_DICT.values())
+
+major_index_to_id = {i: code for i, code in enumerate(CAP_MEDIA_CODES)}
+minor_id_to_index = {code: i for i, code in enumerate(CAP_MIN_CODES)}
+minor_index_to_id = {i: code for i, code in enumerate(CAP_MIN_CODES)}
+
+major_to_minor_map = defaultdict(list)
+for code in CAP_MIN_CODES:
+    major_id = int(str(code)[:-2])
+    major_to_minor_map[major_id].append(code)
+major_to_minor_map = dict(major_to_minor_map)
+
+
+def normalize_probs(probs: dict) -> dict:
+    total = sum(probs.values())
+    return {k: v / total for k, v in probs.items()}


 def check_huggingface_path(checkpoint_path: str):
@@ -64,7 +75,99 @@ def build_huggingface_path(language: str, domain: str):
     return "poltextlab/xlm-roberta-large-pooled-cap-minor-v3"


-def predict(text, model_id, tokenizer_id):
+def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
+    device = torch.device("cpu")
+
+    # Load major and minor models + tokenizer
+    major_model = AutoModelForSequenceClassification.from_pretrained(
+        major_model_id,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        offload_folder="offload",
+        token=HF_TOKEN,
+    ).to(device)
+
+    minor_model = AutoModelForSequenceClassification.from_pretrained(
+        minor_model_id,
+        low_cpu_mem_usage=True,
+        device_map="auto",
+        offload_folder="offload",
+        token=HF_TOKEN,
+    ).to(device)
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
+
+    # Tokenize input
+    inputs = tokenizer(
+        text, max_length=64, truncation=True, padding=True, return_tensors="pt"
+    ).to(device)
+
+    # Predict major topic
+    major_model.eval()
+    with torch.no_grad():
+        major_logits = major_model(**inputs).logits
+        major_probs = F.softmax(major_logits, dim=-1)
+    major_probs_np = major_probs.cpu().numpy().flatten()
+    top_major_index = int(np.argmax(major_probs_np))
+    top_major_id = major_index_to_id[top_major_index]
+
+    # Default: show major topic predictions
+    filtered_probs = {
+        i: float(major_probs_np[i]) for i in np.argsort(major_probs_np)[::-1]
+    }
+    filtered_probs = normalize_probs(filtered_probs)
+
+    output_pred = {
+        f"[{major_index_to_id[k]}] {CAP_MEDIA_LABEL_NAMES[major_index_to_id[k]]}": v
+        for k, v in sorted(
+            filtered_probs.items(), key=lambda item: item[1], reverse=True
+        )
+    }
+
+    # If eligible for minor prediction
+    if top_major_id in major_to_minor_map:
+        valid_minor_ids = major_to_minor_map[top_major_id]
+        minor_model.eval()
+        with torch.no_grad():
+            minor_logits = minor_model(**inputs).logits
+            minor_probs = F.softmax(minor_logits, dim=-1)
+
+        release_model(major_model, major_model_id)
+        release_model(minor_model, minor_model_id)
+
+        print(minor_probs)  # debug
+        # Restrict to valid minor codes
+        valid_indices = [
+            minor_id_to_index[mid]
+            for mid in valid_minor_ids
+            if mid in minor_id_to_index
+        ]
+        filtered_probs = {
+            minor_index_to_id[i]: float(minor_probs[0][i]) for i in valid_indices
+        }
+        print(filtered_probs)  # debug
+        filtered_probs = normalize_probs(filtered_probs)
+        print(filtered_probs)  # debug
+
+        output_pred = {
+            f"[{top_major_id}] {CAP_MEDIA_LABEL_NAMES[top_major_id]} [{k}] {CAP_MIN_LABEL_NAMES[k]}": v
+            for k, v in sorted(
+                filtered_probs.items(), key=lambda item: item[1], reverse=True
+            )
+        }
+
+    output_info = f'<p style="text-align: center; display: block">Prediction used <a href="https://huggingface.co/{major_model_id}">{major_model_id}</a> and <a href="https://huggingface.co/{minor_model_id}">{minor_model_id}</a>.</p>'
+
+    interpretation_info = """
+## How to Interpret These Values (Hierarchical Classification)
+
+The values are the confidences for minor topics **within a given major topic**, and they are **normalized to sum to 1**.
+"""
+
+    return interpretation_info, output_pred, output_info
+
+
+def predict_flat(text, model_id, tokenizer_id, HF_TOKEN=None):
     device = torch.device("cpu")

     # Load JIT-traced model
@@ -89,28 +192,77 @@ def predict(text, model_id, tokenizer_id):
     release_model(model, model_id)

     probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
+    top_indices = np.argsort(probs)[::-1][:10]
+
+    CAP_MIN_MEDIA_LABEL_NAMES = CAP_MEDIA_LABEL_NAMES | CAP_MIN_LABEL_NAMES
+
+    output_pred = {}
+    for i in top_indices:
+        code = CAP_MIN_MEDIA_NUM_DICT[i]
+        prob = probs[i]
+
+        if code in CAP_MEDIA_LABEL_NAMES:
+            # Media (major) topic
+            label = CAP_MEDIA_LABEL_NAMES[code]
+            display = f"[{code}] {label}"
+        else:
+            # Minor topic
+            major_code = code // 100
+            major_label = CAP_MEDIA_LABEL_NAMES[major_code]
+            minor_label = CAP_MIN_LABEL_NAMES[code]
+            display = f"[{major_code}] {major_label} [{code}] {minor_label}"
+
+        output_pred[display] = prob
+
+    interpretation_info = """
+## How to Interpret These Values (Flat Classification)
+
+This method returns predictions made by a single model. **Only the top 10 most confident labels are displayed**.
+"""

-    output_pred = {get_label_name(i): probs[i] for i in np.argsort(probs)[::-1]}
     output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'
-    return output_pred, output_info

+    return interpretation_info, output_pred, output_info

-def predict_cap(text, language, domain):
-    domain = domains[domain]
-    model_id = build_huggingface_path(language, domain)
-    tokenizer_id = "xlm-roberta-large"

+def predict_cap(tmp, method, text, language, domain):
     if is_disk_full():
         os.system("rm -rf /data/models*")
         os.system("rm -r ~/.cache/huggingface/hub")

-    return predict(text, model_id, tokenizer_id)
+    domain = domains[domain]
+
+    if method == "Hierarchical Classification":
+        major_model_id, minor_model_id = build_huggingface_path(language, domain, True)
+        tokenizer_id = "xlm-roberta-large"
+        return predict(text, major_model_id, minor_model_id, tokenizer_id)
+
+    else:
+        model_id = build_huggingface_path(language, domain, False)
+        tokenizer_id = "xlm-roberta-large"
+        return predict_flat(text, model_id, tokenizer_id)
+
+
+description = """
+You can choose between two approaches for making predictions:
+
+**1. Hierarchical Classification**
+First, the model predicts a **major topic**. Then, a second model selects the most probable **subtopic** from within that major topic's category.

+**2. Flat Classification (single model)**
+A single model directly predicts the most relevant label from all available minor topics.
+"""

 demo = gr.Interface(
     title="CAP Minor Topics Babel Demo",
     fn=predict_cap,
     inputs=[
+        gr.Markdown(description),
+        gr.Radio(
+            choices=["Hierarchical Classification", "Flat Classification"],
+            label="Prediction Mode",
+            value="Hierarchical Classification",
+        ),
         gr.Textbox(lines=6, label="Input"),
         gr.Dropdown(languages, label="Language", value=languages[0]),
         gr.Dropdown(domains.keys(), label="Domain", value=list(domains.keys())[0]),
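
For readers skimming the diff, the hierarchical branch reduces to three steps: pick the top major (media) topic, keep only the minor codes that major_to_minor_map assigns to it, and renormalize the surviving probabilities with normalize_probs. The snippet below is a minimal standalone sketch of that filtering step; the codes and probabilities are made-up toy values, not the real CAP dictionaries from label_dicts.py.

    from collections import defaultdict

    # Toy minor codes; dropping the last two digits gives the major code,
    # mirroring major_id = int(str(code)[:-2]) in the commit.
    CAP_MIN_CODES = [1401, 1402, 1403, 2001, 2002]

    major_to_minor_map = defaultdict(list)
    for code in CAP_MIN_CODES:
        major_to_minor_map[int(str(code)[:-2])].append(code)
    major_to_minor_map = dict(major_to_minor_map)

    def normalize_probs(probs: dict) -> dict:
        total = sum(probs.values())
        return {k: v / total for k, v in probs.items()}

    # Toy minor-model probabilities (in the app these come from softmax over the logits).
    minor_probs = {1401: 0.05, 1402: 0.20, 1403: 0.05, 2001: 0.60, 2002: 0.10}
    top_major_id = 14  # pretend the major model picked major topic 14

    # Keep only the minor codes under the predicted major topic, then renormalize.
    filtered = {code: minor_probs[code] for code in major_to_minor_map[top_major_id]}
    filtered = normalize_probs(filtered)
    print(filtered)  # roughly {1401: 0.167, 1402: 0.667, 1403: 0.167}

Because of this renormalization, the displayed minor-topic confidences are conditional on the chosen major topic, which is exactly what the interpretation_info text in the commit warns about.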
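The flat branch builds its display labels differently: each of the top 10 codes is looked up in the merged dictionary, shown on its own if it is a major (media) code, and prefixed with its parent major code (code // 100) if it is a minor code. Here is a sketch of that loop with toy stand-in dictionaries; the names and index-to-code mapping are invented for illustration and are not the contents of label_dicts.py.

    # Toy stand-ins for the real CAP dictionaries.
    CAP_MEDIA_LABEL_NAMES = {14: "Housing", 20: "Government Operations"}
    CAP_MIN_LABEL_NAMES = {1401: "Community Development", 2002: "Bureaucracy"}
    CAP_MIN_MEDIA_NUM_DICT = {0: 14, 1: 1401, 2: 20, 3: 2002}  # model index -> code

    probs = [0.1, 0.6, 0.05, 0.25]  # toy softmax output, one probability per index

    output_pred = {}
    for i in sorted(range(len(probs)), key=lambda i: probs[i], reverse=True):
        code = CAP_MIN_MEDIA_NUM_DICT[i]
        if code in CAP_MEDIA_LABEL_NAMES:
            # A major (media) code is displayed on its own.
            display = f"[{code}] {CAP_MEDIA_LABEL_NAMES[code]}"
        else:
            # A minor code is prefixed with its parent major code (last two digits dropped).
            major_code = code // 100
            display = f"[{major_code}] {CAP_MEDIA_LABEL_NAMES[major_code]} [{code}] {CAP_MIN_LABEL_NAMES[code]}"
        output_pred[display] = probs[i]

    print(output_pred)
    # {'[14] Housing [1401] Community Development': 0.6,
    #  '[20] Government Operations [2002] Bureaucracy': 0.25,
    #  '[14] Housing': 0.1,
    #  '[20] Government Operations': 0.05}

One further detail worth noting: gr.Markdown(description) and the gr.Radio selector are declared as Gradio input components, and gr.Interface passes one positional argument per input component in order, so the leading tmp parameter of the new predict_cap(tmp, method, text, language, domain) appears to exist only to absorb the Markdown component's value and is never used.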