import PIL.Image

import torchvision.transforms as T


def pair(t):
    # Return t unchanged if it is already a tuple; otherwise duplicate it into (t, t).
    return t if isinstance(t, tuple) else (t, t)


def stage1_transform(img_size=256, is_train=True, scale=0.8):
    # Resize so the image is 1/scale times larger than the target, then crop down to img_size.
    resize = pair(int(img_size / scale))

    t = []
    t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
    if is_train:
        # Training: random crop plus horizontal flip for augmentation.
        t.append(T.RandomCrop(img_size))
        t.append(T.RandomHorizontalFlip(p=0.5))
    else:
        # Evaluation: deterministic center crop.
        t.append(T.CenterCrop(img_size))

    t.append(T.ToTensor())
    # Map pixel values from [0, 1] to [-1, 1].
    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    return T.Compose(t)


def stage2_transform(img_size=256, is_train=True, scale=0.8):
    # Same pipeline as stage1_transform, but without the horizontal flip.
    resize = pair(int(img_size / scale))

    t = []
    t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
    if is_train:
        t.append(T.RandomCrop(img_size))
    else:
        t.append(T.CenterCrop(img_size))

    t.append(T.ToTensor())
    # Map pixel values from [0, 1] to [-1, 1].
    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))

    return T.Compose(t)
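

# Illustrative usage sketch (an assumption, not part of the original module):
# build the stage-1 training transform and apply it to a dummy RGB image.
if __name__ == "__main__":
    dummy = PIL.Image.new("RGB", (512, 384))
    transform = stage1_transform(img_size=256, is_train=True)
    x = transform(dummy)
    print(x.shape)  # expected: torch.Size([3, 256, 256])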