Spaces:
Running
Running
| import random | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| from torchvision.transforms import Compose | |
| from .abinet_aug import CVColorJitter, CVDeterioration, CVGeometry, SVTRDeterioration, SVTRGeometry | |
| from .parseq_aug import rand_augment_transform | |
class PARSeqAugPIL(object):
    """Run the transform built by ``rand_augment_transform()`` on a sample.

    The value stored under ``data['image']`` is passed to the transform
    unchanged (no array conversion happens in this class), so it is
    presumably already a PIL image -- verify against the caller.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted but not used.
        self.transforms = rand_augment_transform()

    def __call__(self, data):
        """Augment ``data['image']`` and return the same ``data`` dict."""
        data['image'] = self.transforms(data['image'])
        return data
class PARSeqAug(object):
    """Run the transform built by ``rand_augment_transform()`` on an array image.

    ``data['image']`` is converted to a PIL image for the transform and the
    result is converted back to a NumPy array.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted but not used.
        self.transforms = rand_augment_transform()

    def __call__(self, data):
        """Augment ``data['image']`` and return the same ``data`` dict."""
        pil_image = Image.fromarray(data['image'])
        augmented = self.transforms(pil_image)
        data['image'] = np.array(augmented)
        return data
class ABINetAug(object):
    """Chain of geometry, deterioration and colour-jitter transforms.

    Args:
        geometry_p: value forwarded as ``p`` to ``CVGeometry``.
        deterioration_p: value forwarded as ``p`` to ``CVDeterioration``.
        colorjitter_p: value forwarded as ``p`` to ``CVColorJitter``.
        **kwargs: accepted but not used.
    """

    def __init__(self,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        # Stages are created in the order they are applied.
        geometry = CVGeometry(
            degrees=45,
            translate=(0.0, 0.0),
            scale=(0.5, 2.0),
            shear=(45, 15),
            distortion=0.5,
            p=geometry_p,
        )
        deterioration = CVDeterioration(
            var=20, degrees=6, factor=4, p=deterioration_p)
        color_jitter = CVColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.1,
            p=colorjitter_p,
        )
        self.transforms = Compose([geometry, deterioration, color_jitter])

    def __call__(self, data):
        """Augment ``data['image']`` and return the same ``data`` dict."""
        data['image'] = self.transforms(data['image'])
        return data
class SVTRAug(object):
    """Chain of SVTR geometry, SVTR deterioration and colour-jitter transforms.

    Args:
        aug_type: value forwarded as ``aug_type`` to ``SVTRGeometry``.
        geometry_p: value forwarded as ``p`` to ``SVTRGeometry``.
        deterioration_p: value forwarded as ``p`` to ``SVTRDeterioration``.
        colorjitter_p: value forwarded as ``p`` to ``CVColorJitter``.
        **kwargs: accepted but not used.
    """

    def __init__(self,
                 aug_type=0,
                 geometry_p=0.5,
                 deterioration_p=0.25,
                 colorjitter_p=0.25,
                 **kwargs):
        # Stages are created in the order they are applied.
        geometry = SVTRGeometry(
            aug_type=aug_type,
            degrees=45,
            translate=(0.0, 0.0),
            scale=(0.5, 2.0),
            shear=(45, 15),
            distortion=0.5,
            p=geometry_p,
        )
        deterioration = SVTRDeterioration(
            var=20, degrees=6, factor=4, p=deterioration_p)
        color_jitter = CVColorJitter(
            brightness=0.5,
            contrast=0.5,
            saturation=0.5,
            hue=0.1,
            p=colorjitter_p,
        )
        self.transforms = Compose([geometry, deterioration, color_jitter])

    def __call__(self, data):
        """Augment ``data['image']`` and return the same ``data`` dict."""
        data['image'] = self.transforms(data['image'])
        return data
class BaseDataAugmentation(object):
    """Randomly apply crop, blur, HSV, jitter, noise and inversion to an image.

    Each step is triggered independently when a fresh ``random.random()``
    draw is less than or equal to the corresponding probability.

    Args:
        crop_prob: threshold for the random crop (also requires h, w >= 20).
        reverse_prob: threshold for the ``255 - img`` inversion.
        noise_prob: threshold for the Gaussian noise step.
        jitter_prob: threshold for the jitter step.
        blur_prob: threshold for the Gaussian blur step.
        hsv_aug_prob: threshold for the HSV step.
        **kwargs: accepted but not used.
    """

    def __init__(self,
                 crop_prob=0.4,
                 reverse_prob=0.4,
                 noise_prob=0.4,
                 jitter_prob=0.4,
                 blur_prob=0.4,
                 hsv_aug_prob=0.4,
                 **kwargs):
        # 1-D Gaussian kernel, applied along both axes in __call__.
        self.fil = cv2.getGaussianKernel(ksize=5, sigma=1, ktype=cv2.CV_32F)

        self.crop_prob = crop_prob
        self.blur_prob = blur_prob
        self.hsv_aug_prob = hsv_aug_prob
        self.jitter_prob = jitter_prob
        self.noise_prob = noise_prob
        self.reverse_prob = reverse_prob

    def __call__(self, data):
        """Augment ``data['image']`` and return the same ``data`` dict."""
        image = data['image']
        # The size check uses the shape before any augmentation is applied.
        height, width, _ = image.shape

        # One random draw per step, always in this order; the crop draw is
        # taken even when the image is too small to crop.
        crop_roll = random.random()
        if crop_roll <= self.crop_prob and min(height, width) >= 20:
            image = get_crop(image)

        blur_roll = random.random()
        if blur_roll <= self.blur_prob:
            # Separable Gaussian blur with the precomputed kernel.
            image = cv2.sepFilter2D(image, -1, self.fil, self.fil)

        hsv_roll = random.random()
        if hsv_roll <= self.hsv_aug_prob:
            image = hsv_aug(image)

        jitter_roll = random.random()
        if jitter_roll <= self.jitter_prob:
            image = jitter(image)

        noise_roll = random.random()
        if noise_roll <= self.noise_prob:
            image = add_gasuss_noise(image)

        reverse_roll = random.random()
        if reverse_roll <= self.reverse_prob:
            # Invert intensities.
            image = 255 - image

        data['image'] = image
        return data
def hsv_aug(img):
    """Slightly rescale the V channel of a BGR image and return a BGR image.

    The scale factor is ``1 + delta`` where ``delta`` is a random value of
    magnitude below 0.001 whose sign comes from ``flag()``.
    """
    hsv_img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    # Draw the magnitude first, then the sign, to keep the RNG call order.
    magnitude = 0.001 * random.random()
    scale = 1 + magnitude * flag()
    hsv_img[:, :, 2] = hsv_img[:, :, 2] * scale
    return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
def blur(img):
    """Apply a 5x5 Gaussian blur (sigma 1) to images larger than 10x10.

    Images whose height or width is 10 or less are returned unchanged.
    """
    height, width, _ = img.shape
    if height <= 10 or width <= 10:
        return img
    return cv2.GaussianBlur(img, (5, 5), 1)
def jitter(img):
    """Shift image content diagonally by a small random number of pixels.

    The shift count is ``int(random.random() * min(rows, cols) * 0.01)``.
    ``img`` is modified in place and also returned. Images with either of
    the first two dimensions at 10 or below are returned untouched, and no
    random number is drawn for them.
    """
    rows, cols, _ = img.shape
    if rows <= 10 or cols <= 10:
        return img

    shortest = min(rows, cols)
    max_shift = int(random.random() * shortest * 0.01)
    original = img.copy()
    # Each pass writes the original content, offset by `shift`, into the
    # lower-right part of the image; later passes overwrite earlier ones.
    for shift in range(max_shift):
        img[shift:, shift:, :] = original[:rows - shift, :cols - shift, :]
    return img
def add_gasuss_noise(image, mean=0, var=0.1):
    """Add half-strength Gaussian noise to an image and return a uint8 array.

    Args:
        image: array whose ``shape`` determines the noise shape.
        mean: mean of the normal distribution.
        var: variance of the normal distribution (its square root is used as
            the standard deviation).

    Returns:
        ``image + 0.5 * noise`` clipped to [0, 255] and cast to ``np.uint8``.
    """
    sigma = var**0.5
    noise = np.random.normal(mean, sigma, image.shape)
    noisy = np.clip(image + 0.5 * noise, 0, 255)
    return np.uint8(noisy)
def get_crop(image):
    """Randomly remove 1-8 rows from the top or the bottom of an image.

    The number of removed rows is capped at ``h - 1``. The result is a slice
    of a copy, so the input array is left untouched.
    """
    height, _, _ = image.shape
    # First draw: how many rows to drop; second draw: which edge.
    n_rows = min(random.randint(1, 8), height - 1)
    from_top = random.randint(0, 1)

    duplicate = image.copy()
    if from_top:
        return duplicate[n_rows:height, :, :]
    return duplicate[0:height - n_rows, :, :]
def flag():
    """Return 1 or -1 at random (1 when the draw exceeds 0.5000001)."""
    if random.random() > 0.5000001:
        return 1
    return -1