|
from .dataset import Dataset, ValDataset, TestDataset |
|
from torch.utils.data import DataLoader |
|
|
|
def find_dataset_using_name(name): |
|
mapping = { |
|
"Video": Dataset, |
|
"VideoVal": ValDataset, |
|
"VideoTest": TestDataset, |
|
} |
|
cls = mapping.get(name, None) |
|
if cls is None: |
|
raise ValueError(f"Fail to find dataset {name}") |
|
return cls |
|
|
|
|
|
def create_dataset(metainfo, split): |
|
dataset_cls = find_dataset_using_name(split.type) |
|
dataset = dataset_cls(metainfo, split) |
|
return DataLoader( |
|
dataset, |
|
batch_size=split.batch_size, |
|
drop_last=split.drop_last, |
|
shuffle=split.shuffle, |
|
num_workers=split.worker, |
|
pin_memory=True |
|
) |