David Pomerenke committed on
Commit
55406ba
·
1 Parent(s): 7283eaa

Move functions for sharing them

Browse files
Files changed (2) hide show
  1. evals/datasets_/mmlu.py +1 -15
  2. evals/datasets_/util.py +14 -0
evals/datasets_/mmlu.py CHANGED
@@ -1,24 +1,10 @@
1
  import random
2
  from collections import Counter, defaultdict
3
 
4
- from datasets import get_dataset_config_names, load_dataset
5
- from joblib.memory import Memory
6
  from langcodes import Language, standardize_tag
7
  from rich import print
8
 
9
- cache = Memory(location=".cache", verbose=0).cache
10
-
11
-
12
- @cache
13
- def _get_dataset_config_names(dataset):
14
- return get_dataset_config_names(dataset)
15
-
16
-
17
- @cache
18
- def _load_dataset(dataset, subset, **kwargs):
19
- return load_dataset(dataset, subset, **kwargs)
20
-
21
-
22
  def print_counts(slug, subjects_dev, subjects_test):
23
  print(
24
  f"{slug:<25} {len(list(set(subjects_test))):>3} test categories, {len(subjects_test):>6} samples, {len(list(set(subjects_dev))):>3} dev categories, {len(subjects_dev):>6} dev samples"
 
1
  import random
2
  from collections import Counter, defaultdict
3
 
 
 
4
  from langcodes import Language, standardize_tag
5
  from rich import print
6
 
7
+ from .util import _get_dataset_config_names, _load_dataset
 
 
 
 
 
 
 
 
 
 
 
 
8
  def print_counts(slug, subjects_dev, subjects_test):
9
  print(
10
  f"{slug:<25} {len(list(set(subjects_test))):>3} test categories, {len(subjects_test):>6} samples, {len(list(set(subjects_dev))):>3} dev categories, {len(subjects_dev):>6} dev samples"
evals/datasets_/util.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import get_dataset_config_names, load_dataset
2
+ from joblib.memory import Memory
3
+
4
+ cache = Memory(location=".cache", verbose=0).cache
5
+
6
+
7
+ @cache
8
+ def _get_dataset_config_names(dataset):
9
+ return get_dataset_config_names(dataset)
10
+
11
+
12
+ @cache
13
+ def _load_dataset(dataset, subset, **kwargs):
14
+ return load_dataset(dataset, subset, **kwargs)