IF3D / code /lib /datasets /__init__.py
leobcc's picture
vid2avatar baseline
6325697
raw
history blame
713 Bytes
from .dataset import Dataset, ValDataset, TestDataset
from torch.utils.data import DataLoader
def find_dataset_using_name(name):
    """Return the dataset class registered under ``name``.

    Recognised names are ``"Video"``, ``"VideoVal"`` and ``"VideoTest"``,
    which map to ``Dataset``, ``ValDataset`` and ``TestDataset``
    respectively.

    Args:
        name: Key identifying the dataset class to look up.

    Returns:
        The class object (not an instance) associated with ``name``.

    Raises:
        ValueError: If ``name`` is not one of the recognised keys.
    """
    registry = {
        "Video": Dataset,
        "VideoVal": ValDataset,
        "VideoTest": TestDataset,
    }
    # Guard first so an unknown key surfaces as ValueError, not KeyError.
    if name not in registry:
        raise ValueError(f"Fail to find dataset {name}")
    return registry[name]
def create_dataset(metainfo, split):
    """Instantiate the dataset selected by ``split.type`` and wrap it in a DataLoader.

    Args:
        metainfo: Passed through unchanged as the first argument of the
            dataset constructor.
        split: Configuration object. Its ``type`` attribute selects the
            dataset class; ``batch_size``, ``drop_last``, ``shuffle`` and
            ``worker`` are forwarded to the ``DataLoader``. The object
            itself is also passed as the dataset constructor's second
            argument.

    Returns:
        A ``DataLoader`` over the newly constructed dataset, with
        ``pin_memory`` always enabled.

    Raises:
        ValueError: Propagated from ``find_dataset_using_name`` when
            ``split.type`` is not a recognised dataset name.
    """
    # Build the dataset before reading the loader settings from ``split``.
    data = find_dataset_using_name(split.type)(metainfo, split)
    loader_options = {
        "batch_size": split.batch_size,
        "drop_last": split.drop_last,
        "shuffle": split.shuffle,
        "num_workers": split.worker,  # config key is ``worker``, loader kwarg is ``num_workers``
        "pin_memory": True,
    }
    return DataLoader(data, **loader_options)