Spaces:
Running
Running
from torch.utils.data import DataLoader, Subset | |
import torch | |
from dataset import PokemonDataset | |
import math | |
def _first_batch(dataset, batch_size, generator=None):
    """Return the first batch of *dataset*, or ``None`` when it is empty.

    When *generator* is given the data is shuffled with it, so the batch is
    reproducible for a fixed generator seed; otherwise the dataset order is
    used as-is.
    """
    if len(dataset) == 0:
        # An empty split (e.g. test_set_size == 0) has nothing to visualize.
        return None
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=generator is not None,
        generator=generator,
    )
    return next(iter(loader))


def create_training_setup(
    tokenizer,
    test_set_size,
    val_set_size,
    batch_size,
    num_workers=0,
    num_viz_samples=4,
    random_seed=42,
    train_augmentation_pipeline=None,
):
    """
    Create a complete setup for training with dataset, dataloaders and fixed batches for visualization.

    The dataset is instantiated twice: once with ``train_augmentation_pipeline``
    (used only for the training subset) and once without augmentation (used
    for the test and validation subsets). Both share the same seeded random
    permutation of indices, so the three subsets never overlap.

    Args:
        tokenizer: Forwarded to ``PokemonDataset``.
        test_set_size: Fraction of the dataset reserved for testing, in [0, 1).
        val_set_size: Fraction of the dataset reserved for validation, in [0, 1).
        batch_size: Batch size of the three returned dataloaders.
        num_workers: Worker count of the three returned dataloaders.
        num_viz_samples: Batch size of the ``fixed_*_batch`` entries.
        random_seed: Seed for the index permutation and for the selection of
            the fixed training batches.
        train_augmentation_pipeline: Forwarded to ``PokemonDataset`` as
            ``augmentation_transforms`` for the training dataset only.

    Returns:
        A dict with the keys ``train_loader``, ``val_loader``, ``test_loader``,
        ``train_dataset``, ``val_dataset``, ``test_dataset``,
        ``fixed_train_batch``, ``fixed_val_batch``, ``fixed_test_batch``,
        ``fixed_train_attention_batch``, ``fixed_val_attention_batch`` and
        ``fixed_test_attention_batch``. The fixed batches of an empty split
        are ``None``.

    Raises:
        AssertionError: If the split fractions are out of range.
        ValueError: If, after rounding the test/validation counts up, no
            sample is left for training.
    """
    assert 0 <= test_set_size < 1.0, "test_set_size must be a float between 0 and 1"
    assert 0 <= val_set_size < 1.0, "val_set_size must be a float between 0 and 1"
    assert (test_set_size + val_set_size) < 1.0, "The sum of test and validation sizes must be less than 1"

    train_full_dataset = PokemonDataset(tokenizer=tokenizer, augmentation_transforms=train_augmentation_pipeline)
    # Don't use augmentation for test and validation
    test_val_full_dataset = PokemonDataset(tokenizer=tokenizer)
    dataset_size = len(train_full_dataset)

    # Create a random reproducible permutation
    generator = torch.Generator().manual_seed(random_seed)
    shuffled_indices = torch.randperm(dataset_size, generator=generator)

    val_count = math.ceil(val_set_size * dataset_size)
    test_count = math.ceil(test_set_size * dataset_size)
    train_count = dataset_size - val_count - test_count
    # Both hold-out counts are rounded up, so on small datasets they can
    # consume every sample (or more). A non-positive train_count would make
    # the slices below wrap around and leak training samples into validation.
    if train_count <= 0:
        raise ValueError(
            f"No samples left for training: dataset_size={dataset_size}, "
            f"test_count={test_count}, val_count={val_count}"
        )

    # Partition based on the computed splits
    test_end = train_count + test_count
    train_indices = shuffled_indices[:train_count].tolist()
    test_indices = shuffled_indices[train_count:test_end].tolist()
    val_indices = shuffled_indices[test_end:].tolist()

    # Create the subsets based on the indices
    train_dataset = Subset(train_full_dataset, train_indices)
    test_dataset = Subset(test_val_full_dataset, test_indices)
    val_dataset = Subset(test_val_full_dataset, val_indices)

    # Settings shared by the three main dataloaders
    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'pin_memory': True,
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    # Batch for visualization
    vis_generator = torch.Generator().manual_seed(random_seed)
    fixed_train_batch = _first_batch(train_dataset, num_viz_samples, generator=vis_generator)
    # Since no shuffle, a generator is not needed
    fixed_test_batch = _first_batch(test_dataset, num_viz_samples)
    fixed_val_batch = _first_batch(val_dataset, num_viz_samples)

    # Batch (size 1) for attention map visualization; re-seed so the choice
    # does not depend on how much randomness the previous batch consumed.
    vis_generator.manual_seed(random_seed)
    fixed_train_attention_batch = _first_batch(train_dataset, 1, generator=vis_generator)
    fixed_test_attention_batch = _first_batch(test_dataset, 1)
    fixed_val_attention_batch = _first_batch(val_dataset, 1)

    return {
        'train_loader': train_loader,
        'val_loader': val_loader,
        'test_loader': test_loader,
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'test_dataset': test_dataset,
        'fixed_train_batch': fixed_train_batch,
        'fixed_val_batch': fixed_val_batch,
        'fixed_test_batch': fixed_test_batch,
        'fixed_train_attention_batch': fixed_train_attention_batch,
        'fixed_val_attention_batch': fixed_val_attention_batch,
        'fixed_test_attention_batch': fixed_test_attention_batch,
    }