davidpomerenke's picture
Upload from GitHub Actions: Correlation plot
b0aa389 verified
import asyncio
import os
from datasets import Dataset, load_dataset
from datasets_.util import _get_dataset_config_names, _load_dataset
from langcodes import standardize_tag
from models import google_supported_languages, translate_google
from tqdm import tqdm
from tqdm.asyncio import tqdm_asyncio
def _config_tags(slug, **kwargs):
    """Map each standardized (macro-level) BCP-47 tag to its raw config name.

    Args:
        slug: Hugging Face dataset repository id.
        **kwargs: forwarded unchanged to ``_get_dataset_config_names``.

    Returns:
        Dict ``{standardized_tag: config_name}``; if two configs standardize
        to the same tag, the later one wins.
    """
    return {
        standardize_tag(config, macro=True): config
        for config in _get_dataset_config_names(slug, **kwargs)
    }


# Human-translated MGSM.
slug_mgsm = "juletxara/mgsm"
tags_mgsm = _config_tags(slug_mgsm)

# Human-translated African-language MGSM.
slug_afrimgsm = "masakhane/afrimgsm"
tags_afrimgsm = _config_tags(slug_afrimgsm)

# GSM8KX; its configs are listed with remote code enabled.
slug_gsm8kx = "Eurolingua/gsm8kx"
tags_gsm8kx = _config_tags(slug_gsm8kx, trust_remote_code=True)

# Machine-translated GSM produced by translate_mgsm below.
slug_gsm_autotranslated = "fair-forward/gsm-autotranslated"
tags_gsm_autotranslated = _config_tags(slug_gsm_autotranslated)
def parse_number(i):
    """Parse a whole-number answer from an int, float, or numeric string.

    Commas are treated as thousands separators. Dots are also treated as
    thousands separators when every dot is followed by exactly three digits
    (``"1.000"`` -> 1000, ``"12.345.678"`` -> 12345678). Any other dotted
    value is read as a decimal and accepted only if it is integral
    (``"18.00"`` -> 18, ``"2.5"`` -> None).

    Args:
        i: the value to parse.

    Returns:
        The parsed ``int`` (an ``int`` input is returned unchanged), or
        ``None`` if ``i`` does not represent a whole number.
    """
    import numbers
    import re

    if isinstance(i, int):
        return i
    if isinstance(i, numbers.Integral):  # e.g. numpy integer scalars
        return int(i)
    if isinstance(i, numbers.Real):
        value = float(i)
        # NaN and infinities are not integral and fall through to None.
        return int(value) if value.is_integer() else None
    if not isinstance(i, str):
        # None or other unparseable objects: report "no number" instead of raising.
        return None

    text = i.strip().replace(",", "")
    # Dotted thousands separators ("1.000"): drop the dots before parsing.
    if re.fullmatch(r"[+-]?\d{1,3}(?:\.\d{3})+", text):
        return int(text.replace(".", ""))
    try:
        return int(text)
    except ValueError:
        pass
    # Decimal notation such as "18.00": accept it only when it is whole,
    # rather than deleting the dot and inflating the value (1800).
    try:
        value = float(text)
    except ValueError:
        return None
    return int(value) if value.is_integer() else None
def load_mgsm(language_bcp_47, nr):
    """Fetch row ``nr`` of the ``test`` split from the first source covering a language.

    Sources are tried in priority order: MGSM, AfriMGSM, the auto-translated
    GSM set, and finally GSM8KX.

    Args:
        language_bcp_47: standardized BCP-47 tag of the target language.
        nr: index of the row within the ``test`` split.

    Returns:
        ``(slug, row)`` for the first dataset offering the language, or
        ``(None, None)`` when none of them does.
    """
    sources = (
        (slug_mgsm, tags_mgsm, False),
        (slug_afrimgsm, tags_afrimgsm, False),
        (slug_gsm_autotranslated, tags_gsm_autotranslated, False),
        (slug_gsm8kx, tags_gsm8kx, True),
    )
    for slug, tags, is_gsm8kx in sources:
        if language_bcp_47 not in tags:
            continue
        # Only GSM8KX is loaded with remote code enabled.
        load_kwargs = {"trust_remote_code": True} if is_gsm8kx else {}
        dataset = _load_dataset(
            slug, subset=tags[language_bcp_47], split="test", **load_kwargs
        )
        row = dataset[nr]
        if is_gsm8kx:
            # GSM8KX rows get an ``answer_number`` here: the text after the
            # "####" marker in ``answer``, whitespace-stripped.
            row["answer_number"] = row["answer"].split("####")[1].strip()
        return slug, row
    return None, None
def translate_mgsm(languages):
    """Machine-translate the English MGSM test split into missing languages.

    Candidates are the first 100 entries of ``languages["bcp_47"]`` that have
    no human-translated MGSM/AfriMGSM subset and are supported by Google
    Translate. A language whose config already loads from the auto-translated
    hub dataset is skipped. Every other candidate is handled as follows:
      - its questions are translated from English;
      - answers and solutions are copied unchanged;
      - the result is pushed to the hub as a new config;
      - it is written to ``data/translations/mgsm/<lang>.json``.

    Args:
        languages: table with a ``"bcp_47"`` column exposing ``.values``
            (presumably a pandas DataFrame; confirm against callers).
    """
    # Set: only used for membership tests below.
    human_translated = {*tags_mgsm, *tags_afrimgsm}
    untranslated = [
        lang
        for lang in languages["bcp_47"].values[:100]
        if lang not in human_translated and lang in google_supported_languages
    ]
    en = _load_dataset(slug_mgsm, subset=tags_mgsm["en"], split="test")
    # Reuse the module-level slug so load_mgsm reads from the repo written here.
    slug = slug_gsm_autotranslated
    for lang in tqdm(untranslated):
        # Check if the config already exists on the hub. A ValueError from
        # load_dataset is treated as "not translated yet".
        try:
            load_dataset(slug, lang, split="test")
        except ValueError:
            pass
        else:
            continue

        print(f"Translating {lang}...")
        coroutines = [translate_google(q, "en", lang) for q in en["question"]]
        questions_tr = asyncio.run(tqdm_asyncio.gather(*coroutines))
        ds_lang = Dataset.from_dict(
            {
                "question": questions_tr,
                "answer": en["answer"],
                "answer_number": en["answer_number"],
                "equation_solution": en["equation_solution"],
            }
        )
        ds_lang.push_to_hub(
            slug,
            split="test",
            config_name=lang,
            token=os.getenv("HUGGINGFACE_ACCESS_TOKEN"),
        )
        json_path = f"data/translations/mgsm/{lang}.json"
        # to_json does not create missing parent directories. Failing here,
        # after the push, would leave the language on the hub (so later runs
        # skip it) without a local copy.
        os.makedirs(os.path.dirname(json_path), exist_ok=True)
        ds_lang.to_json(json_path, lines=False, force_ascii=False, indent=2)