from dataclasses import dataclass, field
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from datasets import Dataset, load_dataset
from flax.training.common_utils import shard

from .text import TextNormalizer


@dataclass
class Dataset:
    """
    Loads the train/validation splits with the `datasets` library and prepares
    them as model-ready, device-sharded batches.
    """

    dataset_repo_or_path: str
    train_file: str = None
    validation_file: str = None
    dataset_type: str = "dataset"
    streaming: bool = True
    use_auth_token: bool = False
    text_column: str = "caption"
    encoding_column: str = "encoding"
    max_source_length: int = 128
    max_train_samples: int = None
    max_eval_samples: int = None
    preprocessing_num_workers: int = None
    overwrite_cache: bool = False
    do_train: bool = False
    do_eval: bool = True
    seed_dataset: int = None
    train_dataset: Dataset = field(init=False)
    eval_dataset: Dataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)

    def __post_init__(self):
        if self.train_file is not None or self.validation_file is not None:
            data_files = {
                "train": self.train_file,
                "validation": self.validation_file,
            }
        else:
            data_files = None

        dataset = load_dataset(
            self.dataset_repo_or_path,
            data_files=data_files,
            streaming=self.streaming,
            use_auth_token=self.use_auth_token,
        )
        if self.do_train:
            if "train" not in dataset:
                raise ValueError("Training requires a training dataset")
            self.train_dataset = dataset["train"]
            if self.max_train_samples is not None:
                self.train_dataset = (
                    self.train_dataset.take(self.max_train_samples)
                    if self.streaming
                    else self.train_dataset.select(range(self.max_train_samples))
                )
        if self.do_eval:
            if "validation" not in dataset:
                raise ValueError("Evaluating requires a validation dataset")
            self.eval_dataset = dataset["validation"]
            if self.max_eval_samples is not None:
                self.eval_dataset = (
                    self.eval_dataset.take(self.max_eval_samples)
                    if self.streaming
                    else self.eval_dataset.select(range(self.max_eval_samples))
                )

    def preprocess(self, tokenizer, decoder_start_token_id, normalize_text):
        if self.streaming:
            # streaming datasets are shuffled on the fly with a buffer
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=1000, seed=self.seed_dataset
                )
        else:
            # prepare a rng used to reshuffle the data at each epoch
            if self.seed_dataset is None:
                # fall back to a seed drawn from numpy's global random state
                self.seed_dataset = np.random.get_state()[1][0]
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        if normalize_text:
            text_normalizer = TextNormalizer()
            partial_normalize_function = partial(
                normalize_function,
                text_column=self.text_column,
                text_normalizer=text_normalizer,
            )
            for ds in ["train_dataset", "eval_dataset"]:
                if hasattr(self, ds):
                    setattr(
                        self,
                        ds,
                        (
                            getattr(self, ds).map(partial_normalize_function)
                            if self.streaming
                            else getattr(self, ds).map(
                                partial_normalize_function,
                                num_proc=self.preprocessing_num_workers,
                                load_from_cache_file=not self.overwrite_cache,
                                desc="Normalizing datasets",
                            )
                        ),
                    )

        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            text_column=self.text_column,
            encoding_column=self.encoding_column,
            max_source_length=self.max_source_length,
            decoder_start_token_id=decoder_start_token_id,
        )
        for ds in ["train_dataset", "eval_dataset"]:
            if hasattr(self, ds):
                setattr(
                    self,
                    ds,
                    (
                        getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                        )
                        if self.streaming
                        else getattr(self, ds).map(
                            partial_preprocess_function,
                            batched=True,
                            # `ds` is an attribute name, so the dataset must be
                            # looked up on `self` before reading `column_names`
                            remove_columns=getattr(self, ds).column_names,
                            num_proc=self.preprocessing_num_workers,
                            load_from_cache_file=not self.overwrite_cache,
                            desc="Preprocessing datasets",
                        )
                    ),
                )

    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            batch_size: int,
            rng: jnp.ndarray = None,
        ):
            """
            Returns batches of size `batch_size` from `dataset` (truncated to a
            multiple of `batch_size`), sharded over all local devices.
            Batches are shuffled when an `rng` key is provided.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            # drop the last incomplete batch and group the indices per batch
            batch_idx = batch_idx[: steps_per_epoch * batch_size]
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                # `shard` splits the leading axis over local devices, so
                # `batch_size` must be divisible by `jax.local_device_count()`
                batch = shard(batch)
                yield batch

        def _dataloader_datasets_streaming(dataset: Dataset, batch_size: int):
            keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
            batch = {k: [] for k in keys}
            for item in dataset:
                # collect only the model inputs; streamed examples may still
                # carry raw columns since `remove_columns` is not applied here
                for k in keys:
                    batch[k].append(item[k])
                if len(batch[keys[0]]) == batch_size:
                    batch = {k: jnp.array(v) for k, v in batch.items()}
                    batch = shard(batch)
                    yield batch
                    batch = {k: [] for k in keys}

        if split == "train":
            ds = self.train_dataset
        elif split == "eval":
            ds = self.eval_dataset
        else:
            raise ValueError(f'split must be "train" or "eval", got {split}')

        if self.streaming:
            if split == "train":
                ds.set_epoch(epoch)
            return _dataloader_datasets_streaming(ds, batch_size)
        else:
            # only training batches are shuffled; evaluation keeps dataset order
            input_rng = None
            if split == "train":
                self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, batch_size, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # streaming datasets have no len(); fall back to the configured caps
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if self.max_eval_samples is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


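# Usage sketch for the `Dataset` wrapper above (illustrative only: the tokenizer,
# `decoder_start_token_id` and the dataset repo name are assumed to come from the
# surrounding training script and are not defined in this module):
#
#     dataset = Dataset(
#         dataset_repo_or_path="user/encoded-dataset",
#         streaming=True,
#         do_train=True,
#     )
#     dataset.preprocess(
#         tokenizer=tokenizer,
#         decoder_start_token_id=decoder_start_token_id,
#         normalize_text=True,
#     )
#     for batch in dataset.dataloader("train", batch_size=64, epoch=0):
#         ...  # batches are already sharded across local devices

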
def shift_tokens_right(input_ids: np.array, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    # zeros_like keeps the integer dtype of the token ids
    # (np.zeros would silently promote them to float)
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id
    return shifted_input_ids


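# Worked example of shift_tokens_right (illustrative), with decoder_start_token_id=0:
#   shift_tokens_right(np.array([[5, 6, 7]]), 0) -> array([[0, 5, 6]])

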
def normalize_function(example, text_column, text_normalizer):
    example[text_column] = text_normalizer(example[text_column])
    return example


def preprocess_function(
    examples,
    tokenizer,
    text_column,
    encoding_column,
    max_source_length,
    decoder_start_token_id,
):
    inputs = examples[text_column]

    model_inputs = tokenizer(
        inputs,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # targets are the pre-computed encodings stored in `encoding_column`
    labels = examples[encoding_column]
    labels = np.asarray(labels)
    model_inputs["labels"] = labels

    # teacher forcing: the decoder input is the target sequence shifted right
    decoder_input_ids = shift_tokens_right(labels, decoder_start_token_id)
    model_inputs["decoder_input_ids"] = decoder_input_ids

    return model_inputs