kovacsvi commited on
Commit
af68a82
·
1 Parent(s): ecdbfcf

flat min media pred

Browse files
Files changed (3) hide show
  1. interfaces/cap_minor_media.py +53 -8
  2. label_dicts.py +221 -0
  3. utils.py +4 -0
interfaces/cap_minor_media.py CHANGED
@@ -12,7 +12,8 @@ from huggingface_hub import HfApi
12
  from collections import defaultdict
13
 
14
  from label_dicts import (CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES,
15
- CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES)
 
16
 
17
  from .utils import is_disk_full
18
 
@@ -54,8 +55,11 @@ def check_huggingface_path(checkpoint_path: str):
54
  except:
55
  return False
56
 
57
- def build_huggingface_path(language: str, domain: str):
58
- return ("poltextlab/xlm-roberta-large-pooled-cap-media", "poltextlab/xlm-roberta-large-pooled-cap-minor-v3")
 
 
 
59
 
60
  #@spaces.GPU(duration=30)
61
  def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
@@ -141,16 +145,57 @@ def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
141
 
142
  return interpretation_info, output_pred, output_info
143
 
144
- def predict_cap(tmp, method, text, language, domain):
145
- domain = domains[domain]
146
- major_model_id, minor_model_id = build_huggingface_path(language, domain)
147
- tokenizer_id = "xlm-roberta-large"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
 
 
 
 
 
 
 
 
 
149
  if is_disk_full():
150
  os.system('rm -rf /data/models*')
151
  os.system('rm -r ~/.cache/huggingface/hub')
152
 
153
- return predict(text, major_model_id, minor_model_id, tokenizer_id)
 
 
 
 
 
 
 
 
 
 
154
 
155
  description = """
156
  You can choose between two approaches for making predictions:
 
12
  from collections import defaultdict
13
 
14
  from label_dicts import (CAP_MEDIA_NUM_DICT, CAP_MEDIA_LABEL_NAMES,
15
+ CAP_MIN_NUM_DICT, CAP_MIN_LABEL_NAMES,
16
+ CAP_MIN_MEDIA_NUM_DICT)
17
 
18
  from .utils import is_disk_full
19
 
 
55
  except:
56
  return False
57
 
58
def build_huggingface_path(language: str, domain: str, hierarchical: bool):
    """Resolve the Hugging Face model id(s) for CAP minor-media prediction.

    Note: ``language`` and ``domain`` are currently ignored — one pooled
    multilingual model serves every combination.

    Returns:
        A ``(major_model_id, minor_model_id)`` tuple when ``hierarchical``
        is True, otherwise the single flat model id as a string.
    """
    if hierarchical:
        return (
            "poltextlab/xlm-roberta-large-pooled-cap-media",
            "poltextlab/xlm-roberta-large-pooled-cap-minor-v3",
        )
    return "poltextlab/xlm-roberta-large-pooled-cap-media-minor"
63
 
64
  #@spaces.GPU(duration=30)
65
  def predict(text, major_model_id, minor_model_id, tokenizer_id, HF_TOKEN=None):
 
145
 
146
  return interpretation_info, output_pred, output_info
147
 
148
+
149
def predict_flat(text, model_id, tokenizer_id, HF_TOKEN=None):
    """Score *text* with the single flat CAP media+minor model on CPU.

    Loads the classifier and tokenizer from the Hub, runs one forward
    pass, and keeps the 10 highest-probability labels.

    Returns:
        ``(interpretation_info, output_pred, output_info)`` — markdown help
        text, a ``{label: probability}`` dict, and an HTML note naming the
        model used.
    """
    device = torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        low_cpu_mem_usage=True,
        device_map="auto",
        offload_folder="offload",
        token=HF_TOKEN,
    ).to(device)
    model.eval()

    encoded = tokenizer(
        text,
        max_length=256,
        truncation=True,
        padding="do_not_pad",
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits

    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy().flatten()
    # Indices of the ten most probable classes, most confident first.
    top_indices = np.argsort(probs)[::-1][:10]

    # The flat label space merges the media labels with the minor-topic labels.
    flat_label_names = CAP_MEDIA_LABEL_NAMES | CAP_MIN_LABEL_NAMES
    output_pred = {}
    for idx in top_indices:
        code = CAP_MIN_MEDIA_NUM_DICT[idx]
        output_pred[f"[{code}] {flat_label_names[code]}"] = probs[idx]

    output_info = f'<p style="text-align: center; display: block">Prediction was made using the <a href="https://huggingface.co/{model_id}">{model_id}</a> model.</p>'

    interpretation_info = """
    ## How to Interpret These Values (Flat Classification)

    This method returns predictions made by a single model. Both media codes and minor topics may appear in the output list. Only the top 10 most confident labels are displayed.
    """

    return interpretation_info, output_pred, output_info
181
+
182
+
183
def predict_cap(tmp, method, text, language, domain):
    """Dispatch a CAP request to the hierarchical or the flat pipeline.

    ``method`` selects the approach: "Hierarchical Classification" runs the
    two-stage major/minor pipeline, anything else runs the single flat model.
    """
    # Best-effort cache cleanup when model downloads have filled the disk.
    if is_disk_full():
        os.system('rm -rf /data/models*')
        os.system('rm -r ~/.cache/huggingface/hub')

    domain = domains[domain]
    tokenizer_id = "xlm-roberta-large"

    if method == "Hierarchical Classification":
        major_model_id, minor_model_id = build_huggingface_path(language, domain, True)
        return predict(text, major_model_id, minor_model_id, tokenizer_id)

    model_id = build_huggingface_path(language, domain, False)
    return predict_flat(text, model_id, tokenizer_id)
199
 
200
  description = """
201
  You can choose between two approaches for making predictions:
label_dicts.py CHANGED
@@ -329,6 +329,227 @@ CAP_LABEL_NAMES = {
329
  999: "No Policy Content"
330
  }
331
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  CAP_MIN_LABEL_NAMES = {
333
  # 1. Macroeconomics
334
  100: "General",
 
329
  999: "No Policy Content"
330
  }
331
 
332
# Output index of the flat media+minor model -> CAP code, in model order.
# Indices 0-212 are minor-topic codes; the small codes at the tail
# (indices 213-219) are presumably the media codes — verify against the
# flat model's label config.
_CAP_MIN_MEDIA_CODES = (
    100, 101, 103, 104, 105, 107, 108, 110, 199,
    200, 201, 202, 204, 205, 206, 207, 208, 209, 299,
    300, 301, 302, 321, 322, 323, 324, 325, 331, 332, 333, 334, 335,
    341, 342, 398, 399,
    400, 401, 402, 403, 404, 405, 408, 498, 499,
    500, 501, 502, 503, 504, 505, 506, 529, 599,
    600, 601, 602, 603, 604, 606, 607, 698, 699,
    700, 701, 703, 704, 705, 707, 708, 709, 711, 798, 799,
    800, 801, 802, 803, 805, 806, 807, 898, 899,
    900,
    1000, 1001, 1002, 1003, 1005, 1007, 1010, 1098, 1099,
    1200, 1201, 1202, 1203, 1204, 1205, 1206, 1207, 1208, 1210, 1211,
    1227, 1299,
    1300, 1302, 1303, 1304, 1305, 1308, 1399,
    1400, 1401, 1403, 1404, 1405, 1406, 1407, 1408, 1409, 1498, 1499,
    1500, 1501, 1502, 1504, 1505, 1507, 1520, 1521, 1522, 1523, 1524,
    1525, 1526, 1598, 1599,
    1600, 1602, 1603, 1604, 1605, 1606, 1608, 1610, 1611, 1612, 1614,
    1615, 1616, 1617, 1619, 1620, 1698, 1699,
    1700, 1701, 1704, 1705, 1706, 1707, 1708, 1709, 1798, 1799,
    1800, 1802, 1803, 1804, 1806, 1807, 1808, 1899,
    1900, 1901, 1902, 1905, 1906, 1910, 1921, 1925, 1926, 1927, 1929, 1999,
    2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010,
    2011, 2012, 2013, 2014, 2015, 2030, 2099,
    2100, 2101, 2102, 2103, 2104, 2105,
    2300,
    9999,
    24, 26, 27, 29, 30, 31, 99,
)

CAP_MIN_MEDIA_NUM_DICT = dict(enumerate(_CAP_MIN_MEDIA_CODES))
552
+
553
  CAP_MIN_LABEL_NAMES = {
554
  # 1. Macroeconomics
555
  100: "General",
utils.py CHANGED
@@ -13,6 +13,7 @@ from interfaces.illframes import domains as domains_illframes
13
 
14
  from interfaces.cap import build_huggingface_path as hf_cap_path
15
  from interfaces.cap_minor import build_huggingface_path as hf_cap_minor_path
 
16
  from interfaces.cap_media_demo import build_huggingface_path as hf_cap_media_path # why... just follow the name template the next time pls
17
  from interfaces.manifesto import build_huggingface_path as hf_manifesto_path
18
  from interfaces.sentiment import build_huggingface_path as hf_sentiment_path
@@ -37,6 +38,9 @@ for language in languages_cap:
37
 
38
  # cap media
39
  models.append(hf_cap_media_path("", ""))
 
 
 
40
 
41
  # emotion9
42
  for language in languages_emotion9:
 
13
 
14
  from interfaces.cap import build_huggingface_path as hf_cap_path
15
  from interfaces.cap_minor import build_huggingface_path as hf_cap_minor_path
16
+ from interfaces.cap_minor_media import build_huggingface_path as hf_cap_minor_media_path
17
  from interfaces.cap_media_demo import build_huggingface_path as hf_cap_media_path # why... just follow the name template the next time pls
18
  from interfaces.manifesto import build_huggingface_path as hf_manifesto_path
19
  from interfaces.sentiment import build_huggingface_path as hf_sentiment_path
 
38
 
39
  # cap media
40
  models.append(hf_cap_media_path("", ""))
41
+
42
+ # cap minor media
43
+ models.append(hf_cap_minor_media_path("", "", False))
44
 
45
  # emotion9
46
  for language in languages_emotion9: