# PikaPikaGen — pikapikagen/data_loader.py (commit 66347a3, "First commit" by Val-2)
from torch.utils.data import DataLoader, Subset
import torch
from dataset import PokemonDataset
import math
def _split_indices(dataset_size, val_set_size, test_set_size, random_seed):
    """Partition ``range(dataset_size)`` into (train, test, val) index lists.

    A ``torch.Generator`` seeded with ``random_seed`` drives the permutation,
    so the split is reproducible across runs.  Val/test counts are rounded up
    with ``math.ceil``; the remainder goes to train.
    """
    generator = torch.Generator().manual_seed(random_seed)
    shuffled = torch.randperm(dataset_size, generator=generator)
    val_count = math.ceil(val_set_size * dataset_size)
    test_count = math.ceil(test_set_size * dataset_size)
    train_count = dataset_size - val_count - test_count
    train_indices = shuffled[:train_count].tolist()
    test_indices = shuffled[train_count : train_count + test_count].tolist()
    val_indices = shuffled[train_count + test_count :].tolist()
    return train_indices, test_indices, val_indices


def _first_batch(dataset, batch_size, generator=None):
    """Return the first batch of a throwaway DataLoader over ``dataset``.

    Shuffles (reproducibly) when a ``generator`` is supplied, otherwise reads
    in order.  Returns ``None`` for an empty dataset instead of letting
    ``next(iter(...))`` raise StopIteration — this happens when the caller
    requested a zero-sized test or validation split.
    """
    if len(dataset) == 0:
        return None
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=generator is not None,
        generator=generator,
    )
    return next(iter(loader))


def create_training_setup(
    tokenizer,
    test_set_size,
    val_set_size,
    batch_size,
    num_workers=0,
    num_viz_samples=4,
    random_seed=42,
    train_augmentation_pipeline=None,
):
    """Create a complete setup for training: datasets, dataloaders and fixed batches for visualization.

    Args:
        tokenizer: Tokenizer forwarded to ``PokemonDataset``.
        test_set_size (float): Fraction of the dataset for testing, in [0, 1).
        val_set_size (float): Fraction of the dataset for validation, in [0, 1).
        batch_size (int): Batch size for the three main dataloaders.
        num_workers (int): DataLoader worker count (0 = load in main process).
        num_viz_samples (int): Size of the fixed visualization batches.
        random_seed (int): Seed for the split permutation and viz sampling.
        train_augmentation_pipeline: Optional transforms applied to the
            training split only; test/val always see un-augmented samples.

    Returns:
        dict: Loaders (``train_loader``/``val_loader``/``test_loader``),
        subset datasets, fixed visualization batches, and single-sample
        batches for attention-map visualization.  Fixed batches are ``None``
        when the corresponding split is empty.

    Raises:
        ValueError: If the split fractions are out of range or sum to >= 1.
    """
    # Validate with explicit exceptions: asserts are stripped under `python -O`.
    if not 0 <= test_set_size < 1.0:
        raise ValueError("test_set_size must be a float between 0 and 1")
    if not 0 <= val_set_size < 1.0:
        raise ValueError("val_set_size must be a float between 0 and 1")
    if test_set_size + val_set_size >= 1.0:
        raise ValueError("The sum of test and validation sizes must be less than 1")

    # Two dataset views over the same underlying data: augmentation is applied
    # only to the training copy, never to test/validation samples.
    train_full_dataset = PokemonDataset(tokenizer=tokenizer, augmentation_transforms=train_augmentation_pipeline)
    test_val_full_dataset = PokemonDataset(tokenizer=tokenizer)

    train_indices, test_indices, val_indices = _split_indices(
        len(train_full_dataset), val_set_size, test_set_size, random_seed
    )
    train_dataset = Subset(train_full_dataset, train_indices)
    test_dataset = Subset(test_val_full_dataset, test_indices)
    val_dataset = Subset(test_val_full_dataset, val_indices)

    def make_loader(dataset, shuffle):
        # All three loaders share batch size, workers and pinned host memory.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    train_loader = make_loader(train_dataset, shuffle=True)
    test_loader = make_loader(test_dataset, shuffle=False)
    val_loader = make_loader(val_dataset, shuffle=False)

    # Fixed batches for visualization.  The seeded generator makes the random
    # pick from the training split reproducible; test/val read in order.
    vis_generator = torch.Generator().manual_seed(random_seed)
    fixed_train_batch = _first_batch(train_dataset, num_viz_samples, vis_generator)
    fixed_test_batch = _first_batch(test_dataset, num_viz_samples)
    fixed_val_batch = _first_batch(val_dataset, num_viz_samples)

    # Single-sample batches for attention-map visualization.  Reseed so the
    # train pick does not depend on how many numbers the draw above consumed.
    vis_generator.manual_seed(random_seed)
    fixed_train_attention_batch = _first_batch(train_dataset, 1, vis_generator)
    fixed_test_attention_batch = _first_batch(test_dataset, 1)
    fixed_val_attention_batch = _first_batch(val_dataset, 1)

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'test_dataset': test_dataset,
        'fixed_train_batch': fixed_train_batch,
        'fixed_val_batch': fixed_val_batch,
        'fixed_test_batch': fixed_test_batch,
        'fixed_train_attention_batch': fixed_train_attention_batch,
        'fixed_val_attention_batch': fixed_val_attention_batch,
        'fixed_test_attention_batch': fixed_test_attention_batch,
    }