Spaces:
Running
Running
kovacsvi
committed on
Commit
·
af68a82
1
Parent(s):
ecdbfcf
flat min media pred
Browse files- interfaces/cap_minor_media.py +53 -8
- label_dicts.py +221 -0
- utils.py +4 -0
interfaces/cap_minor_media.py
CHANGED
@@ -12,7 +12,8 @@ from huggingface_hub import HfApi
|
|
12 |
from collections import defaultdict
|
13 |
|
14 |
from label_dicts import (CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES,
|
15 |
-
CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES
|
|
|
16 |
|
17 |
from .utils import is_disk_full
|
18 |
|
@@ -54,8 +55,11 @@ def check_huggingface_path(checkpoint_path: str):
|
|
54 |
except:
|
55 |
return False
|
56 |
|
57 |
-
def build_huggingface_path(language: str, domain: str):
|
58 |
-
|
|
|
|
|
|
|
59 |
|
60 |
#@spaces.GPU(duration=30)
|
61 |
def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
|
@@ -141,16 +145,57 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
|
|
141 |
|
142 |
return interpretation_info, output_pred, output_info
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
if is_disk_full():
|
150 |
os.system('rm -rf /data/models*')
|
151 |
os.system('rm -r ~/.cache/huggingface/hub')
|
152 |
|
153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
description = """
|
156 |
You can choose between two approaches for making predictions:
|
|
|
12 |
from collections import defaultdict
|
13 |
|
14 |
from label_dicts import (CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES,
|
15 |
+
CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES,
|
16 |
+
CAP_MIN_MEDIA_NUM_DICT)
|
17 |
|
18 |
from .utils import is_disk_full
|
19 |
|
|
|
55 |
except:
|
56 |
return False
|
57 |
|
58 |
+
def build_huggingface_path(language: str, domain: str, hierarchical: bool):
    """Return the HuggingFace model id(s) for CAP minor+media classification.

    hierarchical=True  -> a (major_model_id, minor_model_id) pair for the
                          two-step (media, then minor topic) pipeline.
    hierarchical=False -> the single flat media+minor model id.

    ``language`` and ``domain`` are currently unused; they are kept so the
    signature stays parallel with the other interfaces' helpers.
    """
    if not hierarchical:
        return "poltextlab/xlm-roberta-large-pooled-cap-media-minor"
    return (
        "poltextlab/xlm-roberta-large-pooled-cap-media",
        "poltextlab/xlm-roberta-large-pooled-cap-minor-v3",
    )
|
63 |
|
64 |
#@spaces.GPU(duration=30)
|
65 |
def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
|
|
|
145 |
|
146 |
return interpretation_info, output_pred, output_info
|
147 |
|
148 |
+
|
149 |
+
def predict_flat(text, model_id, tokenizer_id, HF_TOKEN=None):
    """Run the single flat media+minor model on ``text``.

    Returns the same (interpretation_info, output_pred, output_info) triple
    as the hierarchical predict(), so both paths feed the same UI outputs.
    output_pred maps "[code] label" strings to softmax probabilities for the
    ten most confident labels.
    """
    device = torch.device("cpu")
    # NOTE(review): device_map="auto" followed by .to(cpu) mirrors the
    # hierarchical path; confirm accelerate never offloads weights here.
    model = AutoModelForSequenceClassification.from_pretrained(model_id, low_cpu_mem_usage=True, device_map="auto", offload_folder="offload", token=HF_TOKEN).to(device)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

    encoded = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits

    probabilities = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    # Indices of the ten most confident labels, highest probability first.
    ranked = np.argsort(probabilities)[::-1][:10]

    # Media labels and minor-topic labels live in disjoint dicts; merge them
    # so a single lookup covers every code the flat model can emit.
    CAP_MIN_MEDIA_LABEL_NAMES = {**CAP_MEDIA_LABEL_NAMES, **CAP_MIN_LABEL_NAMES}
    output_pred = {}
    for idx in ranked:
        code = CAP_MIN_MEDIA_NUM_DICT[idx]
        output_pred[f"[{code}] {CAP_MIN_MEDIA_LABEL_NAMES[code]}"] = probabilities[idx]
    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'

    interpretation_info = """
    ## How to Interpret These Values (Flat Classification)

    This method returns predictions made by a single model. Both media codes and minor topics may appear in the output list. Only the top 10 most confident labels are displayed.
    """

    return interpretation_info, output_pred, output_info
|
181 |
+
|
182 |
+
|
183 |
+
def predict_cap(tmp, method, text, language, domain):
    """UI entry point: dispatch to hierarchical or flat CAP prediction.

    ``tmp`` is accepted but unused here (kept for the UI callback signature).
    """
    # Best-effort cache purge when the disk fills up; failures are ignored.
    if is_disk_full():
        os.system('rm -rf /data/models*')
        os.system('rm -r ~/.cache/huggingface/hub')

    # Translate the human-readable domain choice into its internal key.
    domain = domains[domain]
    tokenizer_id = "xlm-roberta-large"

    if method == "Hierarchical Classification":
        major_model_id, minor_model_id = build_huggingface_path(language, domain, True)
        return predict(text, major_model_id, minor_model_id, tokenizer_id)

    model_id = build_huggingface_path(language, domain, False)
    return predict_flat(text, model_id, tokenizer_id)
|
199 |
|
200 |
description = """
|
201 |
You can choose between two approaches for making predictions:
|
label_dicts.py
CHANGED
@@ -329,6 +329,227 @@ CAP_LABEL_NAMES = {
|
|
329 |
999: "No Policy Content"
|
330 |
}
|
331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
CAP_MIN_LABEL_NAMES = {
|
333 |
# 1. Macroeconomics
|
334 |
100: "General",
|
|
|
329 |
999: "No Policy Content"
|
330 |
}
|
331 |
|
332 |
+
# Maps the flat model's output indices (0-219) to CAP minor-topic / media
# codes. The index order follows the classifier head's label layout, so it
# must not be reordered. Codes are grouped by leading major-topic digits;
# the trailing small codes (24..99) are media labels.
_CAP_MIN_MEDIA_CODES = [
    # 1xx
    100, 101, 103, 104, 105, 107, 108, 110, 199,
    # 2xx
    200, 201, 202, 204, 205, 206, 207, 208, 209, 299,
    # 3xx
    300, 301, 302, 321, 322, 323, 324, 325, 331, 332, 333, 334, 335, 341, 342, 398, 399,
    # 4xx
    400, 401, 402, 403, 404, 405, 408, 498, 499,
    # 5xx
    500, 501, 502, 503, 504, 505, 506, 529, 599,
    # 6xx
    600, 601, 602, 603, 604, 606, 607, 698, 699,
    # 7xx
    700, 701, 703, 704, 705, 707, 708, 709, 711, 798, 799,
    # 8xx
    800, 801, 802, 803, 805, 806, 807, 898, 899,
    # 9xx
    900,
    # 10xx
    1000, 1001, 1002, 1003, 1005, 1007, 1010, 1098, 1099,
    # 12xx
    1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1210, 1211, 1227, 1299,
    # 13xx
    1300, 1302, 1303, 1304, 1305, 1308, 1399,
    # 14xx
    1400, 1401, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1498, 1499,
    # 15xx
    1500, 1501, 1502, 1504, 1505, 1507, 1520, 1521, 1522, 1523, 1524, 1525, 1526, 1598, 1599,
    # 16xx
    1600, 1602, 1603, 1604, 1605, 1606, 1608, 1610, 1611, 1612, 1614, 1615, 1616, 1617, 1619, 1620, 1698, 1699,
    # 17xx
    1700, 1701, 1704, 1705, 1706, 1707, 1708, 1709, 1798, 1799,
    # 18xx
    1800, 1802, 1803, 1804, 1806, 1807, 1808, 1899,
    # 19xx
    1900, 1901, 1902, 1905, 1906, 1910, 1921, 1925, 1926, 1927, 1929, 1999,
    # 20xx
    2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2030, 2099,
    # 21xx
    2100, 2101, 2102, 2103, 2104, 2105,
    # 23xx and no-content
    2300, 9999,
    # media codes
    24, 26, 27, 29, 30, 31, 99,
]
CAP_MIN_MEDIA_NUM_DICT = dict(enumerate(_CAP_MIN_MEDIA_CODES))
|
552 |
+
|
553 |
CAP_MIN_LABEL_NAMES = {
|
554 |
# 1. Macroeconomics
|
555 |
100: "General",
|
utils.py
CHANGED
@@ -13,6 +13,7 @@ from interfaces.illframes import domains as domains_illframes
|
|
13 |
|
14 |
from interfaces.cap import build_huggingface_path as hf_cap_path
|
15 |
from interfaces.cap_minor import build_huggingface_path as hf_cap_minor_path
|
|
|
16 |
from interfaces.cap_media_demo import build_huggingface_path as hf_cap_media_path # why... just follow the name template the next time pls
|
17 |
from interfaces.manifesto import build_huggingface_path as hf_manifesto_path
|
18 |
from interfaces.sentiment import build_huggingface_path as hf_sentiment_path
|
@@ -37,6 +38,9 @@ for language in languages_cap:
|
|
37 |
|
38 |
# cap media
|
39 |
models.append(hf_cap_media_path("", ""))
|
|
|
|
|
|
|
40 |
|
41 |
# emotion9
|
42 |
for language in languages_emotion9:
|
|
|
13 |
|
14 |
from interfaces.cap import build_huggingface_path as hf_cap_path
|
15 |
from interfaces.cap_minor import build_huggingface_path as hf_cap_minor_path
|
16 |
+
from interfaces.cap_minor_media import build_huggingface_path as hf_cap_minor_media_path
|
17 |
from interfaces.cap_media_demo import build_huggingface_path as hf_cap_media_path # why... just follow the name template the next time pls
|
18 |
from interfaces.manifesto import build_huggingface_path as hf_manifesto_path
|
19 |
from interfaces.sentiment import build_huggingface_path as hf_sentiment_path
|
|
|
38 |
|
39 |
# cap media
|
40 |
models.append(hf_cap_media_path("", ""))
|
41 |
+
|
42 |
+
# cap minor media
|
43 |
+
models.append(hf_cap_minor_media_path("", "", False))
|
44 |
|
45 |
# emotion9
|
46 |
for language in languages_emotion9:
|