|
import asyncio |
|
import os |
|
|
|
from datasets import Dataset, load_dataset |
|
from datasets_.util import _get_dataset_config_names, _load_dataset |
|
from langcodes import standardize_tag |
|
from models import google_supported_languages, translate_google |
|
from tqdm import tqdm |
|
from tqdm.asyncio import tqdm_asyncio |
|
|
|
# Hugging Face Hub repository ids of the GSM-style datasets used below.
slug_mgsm = "juletxara/mgsm"
slug_afrimgsm = "masakhane/afrimgsm"
slug_gsm8kx = "Eurolingua/gsm8kx"
slug_gsm_autotranslated = "fair-forward/gsm-autotranslated"

# Each ``tags_*`` dict maps the standardized language tag (with
# ``macro=True``) of a dataset config to that config's original name, so a
# language can be looked up by its standardized tag and loaded by its config.
tags_mgsm = {
    standardize_tag(config, macro=True): config
    for config in _get_dataset_config_names(slug_mgsm)
}
tags_afrimgsm = {
    standardize_tag(config, macro=True): config
    for config in _get_dataset_config_names(slug_afrimgsm)
}
# The config listing for this repository is requested with
# ``trust_remote_code=True``.
tags_gsm8kx = {
    standardize_tag(config, macro=True): config
    for config in _get_dataset_config_names(slug_gsm8kx, trust_remote_code=True)
}
tags_gsm_autotranslated = {
    standardize_tag(config, macro=True): config
    for config in _get_dataset_config_names(slug_gsm_autotranslated)
}
|
|
|
|
|
def parse_number(i):
    """Convert a numeric answer to ``int``, returning ``None`` on failure.

    Args:
        i: The value to convert. Supported inputs are ``int`` (returned
            unchanged), integral ``float`` (e.g. ``18.0``) and ``str``.

    Returns:
        The parsed integer, or ``None`` when the value cannot be interpreted
        as an integer (including ``None`` and other unsupported types).

    NOTE: for strings, both ``","`` and ``"."`` are removed before parsing,
    i.e. they are treated as separators. ``"1,234"`` and ``"1.234"`` both
    yield ``1234``, and ``"3.5"`` yields ``35``.
    """
    if isinstance(i, int):
        return i
    if isinstance(i, float):
        # Accept values like 18.0; NaN, infinities and fractions are rejected.
        return int(i) if i.is_integer() else None
    if not isinstance(i, str):
        # None, bytes, etc. have no sensible integer reading here; report
        # failure the same way as an unparsable string instead of raising.
        return None
    try:
        return int(i.replace(",", "").replace(".", ""))
    except ValueError:
        return None
|
|
|
|
|
def load_mgsm(language_bcp_47, nr):
    """Fetch row ``nr`` of the "test" split for the given language.

    The datasets are consulted in a fixed priority order: mgsm, afrimgsm,
    gsm-autotranslated and finally gsm8kx. The first one whose tag table
    contains ``language_bcp_47`` wins.

    Args:
        language_bcp_47: Language tag used as key into the ``tags_*`` tables.
        nr: Index of the row to return from the "test" split.

    Returns:
        A ``(slug, row)`` tuple, where ``slug`` identifies the dataset the
        row came from, or ``(None, None)`` if no dataset covers the language.
        For gsm8kx rows, ``answer_number`` is filled in with the stripped
        text that follows ``"####"`` in the ``answer`` field.
    """
    # Datasets whose rows are returned as-is, in lookup priority order.
    direct_sources = (
        (slug_mgsm, tags_mgsm),
        (slug_afrimgsm, tags_afrimgsm),
        (slug_gsm_autotranslated, tags_gsm_autotranslated),
    )
    for source_slug, source_tags in direct_sources:
        if language_bcp_47 in source_tags:
            dataset = _load_dataset(
                source_slug, subset=source_tags[language_bcp_47], split="test"
            )
            return source_slug, dataset[nr]

    if language_bcp_47 not in tags_gsm8kx:
        return None, None

    record = _load_dataset(
        slug_gsm8kx,
        subset=tags_gsm8kx[language_bcp_47],
        split="test",
        trust_remote_code=True,
    )[nr]
    # The final answer is the part of the answer text after "####".
    final_answer = record["answer"].split("####")[1]
    record["answer_number"] = final_answer.strip()
    return slug_gsm8kx, record
|
|
|
|
|
def translate_mgsm(languages):
    """Machine-translate the English MGSM questions into missing languages.

    Candidates are the first 100 entries of ``languages["bcp_47"]`` that are
    not covered by the human-translated datasets (mgsm, afrimgsm) and that
    appear in ``google_supported_languages``. For each candidate without an
    existing config in the auto-translated repository, the English questions
    are translated, and the resulting dataset is pushed to the Hub and also
    written to ``data/translations/mgsm/<lang>.json``.

    Args:
        languages: Table-like object whose ``"bcp_47"`` column exposes
            ``.values`` (presumably a pandas DataFrame; verify against
            the caller).

    Returns:
        None. The function works through side effects (Hub upload, local
        JSON files, progress output).
    """
    # A set gives O(1) membership checks in the filter below.
    human_translated = {*tags_mgsm, *tags_afrimgsm}
    untranslated = [
        lang
        for lang in languages["bcp_47"].values[:100]
        if lang not in human_translated and lang in google_supported_languages
    ]
    en = _load_dataset(slug_mgsm, subset=tags_mgsm["en"], split="test")
    output_dir = "data/translations/mgsm"
    for lang in tqdm(untranslated):
        try:
            # Existence probe only: the loaded data itself is not needed.
            load_dataset(slug_gsm_autotranslated, lang, split="test")
        except ValueError:
            # A ValueError is taken to mean that no config for this language
            # has been published yet, so it is translated now.
            print(f"Translating {lang}...")
            pending = [translate_google(q, "en", lang) for q in en["question"]]
            questions_tr = asyncio.run(tqdm_asyncio.gather(*pending))
            ds_lang = Dataset.from_dict(
                {
                    "question": questions_tr,
                    # Everything except the question text is copied from
                    # the English source unchanged.
                    "answer": en["answer"],
                    "answer_number": en["answer_number"],
                    "equation_solution": en["equation_solution"],
                }
            )
            ds_lang.push_to_hub(
                slug_gsm_autotranslated,
                split="test",
                config_name=lang,
                token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
            )
            # Make sure the target directory exists so the local export
            # cannot fail after the (expensive) translation and upload.
            os.makedirs(output_dir, exist_ok=True)
            ds_lang.to_json(
                f"data/translations/mgsm/{lang}.json",
                lines=False,
                force_ascii=False,
                indent=2,
            )
|
|