|
""" |
|
A PyTorch dataset of (image, caption) pairs, for use with a DataLoader.
|
Luke Melas-Kyriazi, 2021 |
|
""" |
|
import warnings |
|
from typing import Optional, Callable |
|
from pathlib import Path |
|
import numpy as np |
|
import torch |
|
import pandas as pd |
|
from torch.utils.data import Dataset |
|
from torchvision.datasets.folder import default_loader |
|
from PIL import ImageFile |
|
from PIL.Image import DecompressionBombWarning |
|
# Let PIL load images whose file data is truncated instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Silence every UserWarning and PIL's DecompressionBombWarning. The filters are
# installed in the same order as before (UserWarning first).
# NOTE(review): the UserWarning filter is process-wide, so it also hides
# warnings from torch/torchvision/pandas; consider narrowing it.
for _ignored_category in (UserWarning, DecompressionBombWarning):
    warnings.filterwarnings(action="ignore", category=_ignored_category)
del _ignored_category
|
|
|
|
|
class CaptionDataset(Dataset):
    """
    A PyTorch Dataset class for (image, text) tasks.

    By default this dataset returns the raw caption text rather than tokens.
    This is done on purpose, because it's easy to tokenize a batch of text
    after loading it from this dataset. If per-sample processing is preferred
    instead, pass a `text_transform`; it is applied to each caption in
    `__getitem__`.
    """

    def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
                 image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
                 include_captions: bool = True):
        """
        :param images_root: folder where images are stored
        :param captions_path: path (or buffer accepted by `pandas.read_csv`) to a
            tab-separated file with a header row containing at least the columns
            `image_file` (path relative to `images_root`) and `caption`
        :param text_transform: optional callable applied to each raw caption
            (e.g. a tokenizer); if None, the raw caption is returned
        :param image_transform: optional image transform pipeline
        :param image_transform_type: image transform type, either `torchvision` or
            `albumentations` (case-insensitive)
        :param include_captions: Returns a dictionary with `image`, `text` if `true`;
            otherwise returns just the images.
        :raises ValueError: if `image_transform_type` is not a supported type
        """
        self.image_transform_type = image_transform_type.lower()
        # Explicit check instead of `assert`, so validation survives `python -O`.
        # Done before reading the captions file so a bad argument fails fast.
        if self.image_transform_type not in ('torchvision', 'albumentations'):
            raise ValueError(
                "image_transform_type must be 'torchvision' or 'albumentations', "
                f"got {image_transform_type!r}")

        self.images_root = Path(images_root)

        self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
        # Filenames such as "123" would otherwise be parsed as numbers, and
        # `Path / int` raises TypeError; force them to str.
        self.captions['image_file'] = self.captions['image_file'].astype(str)

        self.text_transform = text_transform
        self.image_transform = image_transform

        self.size = len(self.captions)

        self.include_captions = include_captions

    def verify_that_all_images_exist(self):
        """
        Check that every `image_file` listed in the captions file exists under `images_root`.

        Prints one line per missing file.

        :return: list of `Path`s that are not existing files (empty if all are present)
        """
        missing = []
        for image_file in self.captions['image_file']:
            p = self.images_root / image_file
            if not p.is_file():
                print(f'file does not exist: {p}')
                missing.append(p)
        return missing

    def _get_raw_image(self, i):
        """Load image `i` from disk with torchvision's `default_loader`."""
        # Read the single cell directly instead of materializing the whole row via `iloc[i]`.
        image_file = self.captions['image_file'].iat[i]
        return default_loader(self.images_root / image_file)

    def _get_raw_text(self, i):
        """Return the raw (untransformed) caption for item `i`."""
        return self.captions['caption'].iat[i]

    def __getitem__(self, i):
        """
        Load item `i`, applying the configured image (and text) transforms.

        :return: `{'image': image, 'text': caption}` if `include_captions`, else just the image
        """
        image = self._get_raw_image(i)
        if self.image_transform is not None:
            if self.image_transform_type == 'torchvision':
                image = self.image_transform(image)
            elif self.image_transform_type == 'albumentations':
                # albumentations takes a numpy array as a keyword and returns a dict of outputs.
                image = self.image_transform(image=np.array(image))['image']
            else:
                raise NotImplementedError(f"{self.image_transform_type=}")

        if not self.include_captions:
            return image

        caption = self._get_raw_text(i)
        if self.text_transform is not None:
            caption = self.text_transform(caption)
        return {'image': image, 'text': caption}

    def __len__(self):
        """Number of (image, caption) rows in the captions file."""
        return self.size
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the dataset with an albumentations image pipeline and
    # a tokenizer, then print a summary of the first batch.
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    from transformers import AutoTokenizer

    captions_path = './images-list-clean.tsv'
    images_root = './images'

    tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')

    def tokenize(text):
        """Tokenize `text` to a fixed length of 32 tokens, returned as PyTorch tensors."""
        return tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=32,
            return_tensors='pt',
        )

    # Resize to 256x256, center-crop 256x256, normalize each channel with
    # mean=0.5 / std=0.5, then convert to a tensor.
    image_transform = A.Compose([
        A.Resize(256, 256),
        A.CenterCrop(256, 256),
        A.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
        ToTensorV2(),
    ])

    dataset = CaptionDataset(
        images_root=images_root,
        captions_path=captions_path,
        text_transform=tokenize,
        image_transform=image_transform,
        image_transform_type='albumentations',
    )

    loader = torch.utils.data.DataLoader(dataset, batch_size=2)
    first_batch = next(iter(loader))

    # Show tensor shapes instead of full tensor contents.
    summary = {}
    for key, value in first_batch.items():
        summary[key] = value.shape if isinstance(value, torch.Tensor) else value
    print(summary)
|
|
|
|
|
|
|
|
|
|
|
|