import os  # yaml, pickle, shutil, tarfile,
from glob import glob
import cv2
import albumentations
import PIL
import numpy as np
import torchvision.transforms.functional as TF
# from omegaconf import OmegaConf
from functools import partial
from PIL import Image
from torch.utils.data import Dataset  # , Subset
import pandas as pd
from torchvision import transforms
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
from skimage import io
from tqdm import tqdm
import base64
from io import BytesIO
from ldm.data.base import Txt2ImgIterableBaseDataset
import multiprocessing as mp
from bisect import bisect_left, bisect_right
import omegaconf
import time
import json
from torch.utils.data.dataloader import _get_distributed_settings


def _strip_prefix(name, prefix):
    """Return ``name`` without a leading ``prefix`` (unchanged if absent).

    ``str.lstrip`` must not be used for this: it strips a *set of characters*,
    so ``"output_part-0".lstrip("output_")`` yields ``"art-0"``.
    """
    return name[len(prefix):] if name.startswith(prefix) else name


def _read_lines(path):
    """Read a text file and return all of its lines."""
    with open(path, "r") as f:
        return f.readlines()


def _collect_parquet_info(parquet_path):
    """Return ``(image_path, text_path, row_index)`` for every row of the
    parquet file whose ``jpg`` column is not null.

    Image and caption live in the same file, so both paths are identical.
    """
    df = pd.read_parquet(parquet_path)
    print(parquet_path + " is successfully loaded")
    # ``tolist`` yields plain Python ints, which keeps the result JSON-serialisable.
    valid_inds = df.index[df.jpg.notnull()].tolist()
    return [(parquet_path, parquet_path, ind) for ind in valid_inds]


def _collect_parquet_info_safe(parquet_path):
    """Best-effort wrapper used by worker processes: an unreadable file is
    reported and contributes no samples instead of aborting the whole scan."""
    try:
        return _collect_parquet_info(parquet_path)
    except Exception as err:  # worker boundary: report and keep going
        print("fail to collect data info from {}: {}".format(parquet_path, err))
        return []


class LAIONBase(Dataset):
    def __init__(self, img_folder, caption_folder=None,
                 recollect_data_info=False,
                 first_stage_key="jpg", cond_stage_key="txt", do_flip=False,
                 size=None, degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
                 flip_p=0.5, random_crop=True, max_items=5000):
        """
        LAION map-style dataset.

        Each sample is produced by the following ops, in order:
        1. crop a square of side s from the image (random or center crop)
        2. resize the crop so its smallest side equals ``size`` (cv2 area interpolation)
        3. optionally flip horizontally

        :param img_folder: folder scanned for ``*/*.parquet`` files; also where
            the ``data_info.json`` index cache is read from / written to
        :param caption_folder: not used by the constructor
        :param recollect_data_info: rebuild the index even if a cache file exists
        :param first_stage_key: key under which the image is stored in a sample
        :param cond_stage_key: key under which the caption is stored in a sample
        :param do_flip: enable random horizontal flipping
        :param size: resizing to size after cropping (required)
        :param degradation: currently unused (low-resolution degradation is disabled)
        :param downscale_f: currently unused (low-resolution degradation is disabled)
        :param min_crop_f: determines crop size s, where s = c * min_img_side_len
            with c sampled from interval (min_crop_f, max_crop_f)
        :param max_crop_f: see ``min_crop_f``; must be <= 1
        :param flip_p: flip probability when ``do_flip`` is set
        :param random_crop: random crop if True, center crop otherwise
        :param max_items: keep only the first ``max_items`` index entries
            (``None`` keeps everything); defaults to the previously hard-coded 5000
        """
        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        # Prefix substituted for "/scratch" in the paths stored in the index.
        self.root = img_folder.split("/laion")[0]
        self.images = []
        self.texts = []
        self.paired_data = []
        self.parquet_info = {}

        data_info_file = os.path.join(img_folder, "data_info.json")
        collect_data_info = True
        if not recollect_data_info:
            try:
                with open(data_info_file, "r") as f:
                    self.data_info = json.load(f)
                collect_data_info = False
            except (OSError, ValueError):
                # Missing, unreadable or corrupt cache: fall back to a rescan.
                print("fail to load data info from {}".format(data_info_file))
        if collect_data_info:
            print("start to collect data info to {}".format(data_info_file))
            self.data_info = []
            self.load_data_par(img_folder)
            with open(data_info_file, "w") as f:
                f.write(json.dumps(self.data_info))

        # The cache on disk always holds the full index; only the in-memory
        # copy is truncated.
        self.data_info = self.data_info[:max_items]
        self.indices = range(len(self))

        self.do_flip = do_flip
        if self.do_flip:
            self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        assert size
        self.size = size
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert (max_crop_f <= 1.)
        self.center_crop = not random_crop
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

    def __len__(self):
        """Number of entries in the (possibly truncated) sample index."""
        return len(self.data_info)

    def load_from_origin_data(self, folder):
        """Download the images listed (by ``URL``/``TEXT`` columns) in the
        parquet files of ``folder`` and cache them under ``./data/<folder name>``.

        When both cache files already exist they are loaded instead.
        Results are stored in ``self.images`` and ``self.texts``.
        """
        store_processed_folder = os.path.join("./data", os.path.basename(folder))
        os.makedirs(store_processed_folder, exist_ok=True)
        image_path = os.path.join(store_processed_folder, "images.npy")
        text_path = os.path.join(store_processed_folder, "prompts.txt")

        if os.path.exists(image_path) and os.path.exists(text_path):
            # NOTE(review): allow_pickle=True executes pickled code on load;
            # only safe because this cache is produced locally by this method.
            self.images = np.load(image_path, allow_pickle=True)
            with open(text_path, "r") as f:
                self.texts = [line.rstrip() for line in f]
            return

        text_list = []
        image_list = []
        for file_name in glob(folder + "/*"):
            if not file_name.endswith(".parquet"):
                continue
            data = pd.read_parquet(file_name)
            for _, row in tqdm(data.iterrows()):
                try:
                    img = row.URL
                    txt = row.TEXT
                    image = io.imread(img)
                except Exception:
                    # Best effort: skip rows whose image cannot be fetched/decoded.
                    continue
                if len(image.shape) == 2:
                    # Grayscale -> RGB
                    image = Image.fromarray(image).convert("RGB")
                    image = np.array(image).astype(np.uint8)
                image_list.append(image)
                text_list.append(txt)
            del data

        # NOTE(review): images generally differ in shape; recent numpy versions
        # may reject building an array from such a ragged list - confirm.
        self.images = np.array(image_list)
        self.texts = text_list
        with open(text_path, "w") as f:
            f.writelines([text + "\n" for text in text_list])
        np.save(image_path, self.images)

    def load_from_tsv(self, image_path, original_data):
        """Decode every ``<row index>\\t<base64 image>`` line of a tsv file and
        append the image plus the matching ``TEXT`` entry of ``original_data``
        to ``self.images`` / ``self.texts``."""
        with open(image_path, "r") as f:
            for line_ in tqdm(f):
                list_ = line_.split("\t")
                idx = int(list_[0])
                code_ = base64.b64decode(list_[1])
                image = Image.open(BytesIO(code_)).convert("RGB")
                image = np.array(image).astype(np.uint8)
                self.images.append(image)
                self.texts.append(original_data.iloc[idx].TEXT)

    def load_data(self, img_folder, caption_folder):
        """For each sub-directory of ``img_folder`` load the raw tsv lines and
        the matching caption parquet, appending ``[lines, captions]`` to
        ``self.paired_data``.

        A sub-directory ``output_<name>`` is paired with
        ``<caption_folder>/<name>.parquet``.
        """
        # Threads instead of processes: the work is pure file I/O, and the
        # reader does not have to be picklable.
        from multiprocessing.pool import ThreadPool

        for subfolder in glob(img_folder + "/*"):
            if not os.path.isdir(subfolder):
                continue
            caption_name = _strip_prefix(os.path.basename(subfolder), "output_")
            caption_path = os.path.join(caption_folder, caption_name + ".parquet")
            par_data = pd.read_parquet(caption_path)

            imgstr_list = []
            tsv_paths = glob(subfolder + "/*.tsv")
            if tsv_paths:
                with ThreadPool(min(30, len(tsv_paths))) as pool:
                    for lines in pool.map(_read_lines, tsv_paths):
                        imgstr_list.extend(lines)
            self.paired_data.append([imgstr_list, par_data])

    def load_from_parquet(self, parquet_path):
        """Return a list of ``(image_path, text_path, row_index)`` tuples for
        all rows of ``parquet_path`` that have a non-null ``jpg`` entry."""
        return _collect_parquet_info(parquet_path)

    def merge_data(self, items):
        """Append every index list in ``items`` to ``self.data_info``."""
        for item in items:
            self.data_info.extend(item)

    def load_data_par(self, folder):
        """Scan ``<folder>/*/*.parquet`` with a process pool and collect the
        sample index into ``self.data_info`` (which must already exist).

        Files that cannot be read are reported and skipped individually.
        """
        parquet_paths = glob(folder + "/*/*.parquet")
        if not parquet_paths:
            return
        with mp.Pool(min(20, len(parquet_paths))) as pool:
            # imap keeps the input order, so the index is built in glob order.
            results = pool.imap(_collect_parquet_info_safe, parquet_paths)
            for info_list in tqdm(results, total=len(parquet_paths)):
                self.merge_data([info_list])

    def _random_item(self):
        """Return a randomly chosen sample; used when the requested one is unusable."""
        return self.__getitem__(np.random.randint(0, len(self.indices)))

    def _augment(self, image):
        """Crop, rescale and (optionally) flip a uint8 HxWx3 image."""
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)
        # Stored on the instance as before, in case external code inspects it.
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
        image = self.cropper(image=image)["image"]
        image = self.image_rescaler(image=image)["image"]
        if self.do_flip:
            image = self.flip(Image.fromarray(image))
            image = np.array(image).astype(np.uint8)
        return image

    @staticmethod
    def _lookup_text(df, file_idx):
        """Return the caption at row ``file_idx`` from the ``TEXT`` column,
        falling back to ``caption``.

        :raises ValueError: if no caption table is available or neither column
            provides the row
        """
        if df is not None:
            for column in ("TEXT", "caption"):
                try:
                    return df[column].iloc[file_idx]
                except (KeyError, IndexError):
                    continue
        raise ValueError("no TEXT/caption entry found for row {}".format(file_idx))

    def __getitem__(self, i):
        """Load, augment and return sample ``i``.

        Samples whose image cannot be decoded, is stored in an unsupported
        format, or is smaller than ``self.size`` are replaced by a randomly
        drawn sample.

        :return: dict with the image (float32, scaled to [-1, 1]) under
            ``first_stage_key`` and the caption under ``cond_stage_key``
        :raises ValueError: if the caption cannot be loaded
        """
        example = dict()
        index_ = self.indices[i]
        imgfile_name, textfile_name, file_idx = self.data_info[index_]
        # Indexed paths are rooted at "/scratch"; remap them to the local root.
        imgfile_name = imgfile_name.replace("/scratch", self.root)
        textfile_name = textfile_name.replace("/scratch", self.root)

        df = None
        if imgfile_name.endswith(".parquet"):
            df = pd.read_parquet(imgfile_name)
            img = df.jpg.iloc[file_idx]
        elif imgfile_name.endswith(".tsv"):
            with open(imgfile_name, "r") as f:
                line_ = f.readlines()[file_idx]
            # The tsv line carries the caption row index next to the image.
            file_idx, img = line_.split("\t")
            img = base64.b64decode(img)
            file_idx = int(file_idx)
        else:
            return self._random_item()

        try:
            image = Image.open(BytesIO(img)).convert("RGB")
            image = np.array(image).astype(np.uint8)
        except Exception:
            # Corrupt or missing image bytes: fall back to another sample.
            return self._random_item()

        if image.shape[0] < self.size or image.shape[1] < self.size:
            return self._random_item()

        image = self._augment(image)
        example[self.first_stage_key] = (image / 127.5 - 1.0).astype(np.float32)

        if imgfile_name != textfile_name:
            if not textfile_name.endswith(".parquet"):
                message = "the format {} of the text file is not supported".format(
                    os.path.splitext(textfile_name)[1]
                )
                print(message)
                raise ValueError(message)
            df = pd.read_parquet(textfile_name)
        example[self.cond_stage_key] = self._lookup_text(df, file_idx)
        return example


class LAIONTrain(LAIONBase):
    def __init__(self, store_folder, *args, ratio=0.7, **kwargs):
        """Training split of :class:`LAIONBase`.

        If ``<store_folder>/train.txt`` does not exist, the samples are
        randomly permuted, the first ``ratio`` fraction becomes the training
        split and both ``train.txt`` and ``val.txt`` are written. Otherwise
        the indices are read from ``train.txt``.
        """
        super().__init__(*args, **kwargs)
        train_file = os.path.join(store_folder, "train.txt")
        val_file = os.path.join(store_folder, "val.txt")
        if os.path.exists(train_file):
            with open(train_file, "r") as f:
                self.indices = [int(s.strip()) for s in f if s.strip()]
            return

        rand_inds = np.random.permutation(len(self))
        num_train = int(len(rand_inds) * ratio)
        # Plain list of ints, matching what is read back from train.txt.
        self.indices = rand_inds[:num_train].tolist()
        lines = [str(i) + "\n" for i in rand_inds]
        if store_folder:
            os.makedirs(store_folder, exist_ok=True)
        with open(train_file, "w") as f:
            f.writelines(lines[:num_train])
        with open(val_file, "w") as f:
            f.writelines(lines[num_train:])
class LAIONValidation(LAIONBase):
    def __init__(self, store_folder, *args, **kwargs):
        """Validation split of ``LAIONBase``; the sample indices are read from
        ``<store_folder>/val.txt``.

        :raises ValueError: if ``val.txt`` does not exist
        """
        super().__init__(*args, **kwargs)
        val_file = os.path.join(store_folder, "val.txt")
        if not os.path.exists(val_file):
            raise ValueError("validation index file {} does not exist".format(val_file))
        with open(val_file, "r") as f:
            self.indices = [int(s.strip()) for s in f if s.strip()]


class LAIONIterableBaseDataset(Txt2ImgIterableBaseDataset):
    '''
    Define an interface to make the IterableDatasets for text2img data chainable
    '''
    def __init__(self, img_folder, caption_folder=None, size=256,
                 first_stage_key="jpg", cond_stage_key="txt", do_flip=False,
                 min_crop_f=0.5, max_crop_f=1., flip_p=0.5, random_crop=True,
                 max_parquet_files=170):
        """
        :param img_folder: root folder of the image data. With ``caption_folder``
            set, its sub-directories are expected to hold ``*.tsv`` image files;
            otherwise it is walked recursively for ``*.parquet`` files
        :param caption_folder: folder with one caption parquet per image
            sub-directory; selects the tsv mode when given
        :param size: target size of the smallest image side (required)
        :param first_stage_key: key under which the image is stored in a sample
        :param cond_stage_key: key under which the caption is stored in a sample
        :param do_flip: enable random horizontal flipping
        :param min_crop_f: lower bound of the random crop fraction
        :param max_crop_f: upper bound of the random crop fraction; must be <= 1
        :param flip_p: flip probability when ``do_flip`` is set
        :param random_crop: random crop if True, center crop otherwise
        :param max_parquet_files: in parquet mode, use only the first
            ``max_parquet_files`` files found (``None`` uses all); defaults to
            the previously hard-coded 170
        """
        assert size
        # NOTE(review): the base class presumably stores ``self.size``, which
        # ``generate_img`` reads - confirm against Txt2ImgIterableBaseDataset.
        super().__init__(size=size)
        self.caption_folder = caption_folder
        if self.caption_folder:
            self.valid_ids = glob(img_folder + "/*/")
            self.origin_tsv_paths = {
                subfolder: glob(subfolder + "/*.tsv")
                for subfolder in self.valid_ids
            }
            # Maps the cumulative tsv-file count to the folder that ends there.
            num = 0
            self.tsv_folder_idx = {}
            for key, value in self.origin_tsv_paths.items():
                num += len(value)
                self.tsv_folder_idx[num] = key
            self.num_records = len(self.valid_ids)
            self.sample_ids = self.valid_ids
            self.tsv_paths = self.origin_tsv_paths
            # to be deprecated
            self.max_num = self.num_records * 100000
        else:
            parquet_paths = [
                os.path.join(root, file)
                for root, _, files in os.walk(os.path.abspath(img_folder))
                for file in files
                if file.endswith(".parquet")
            ]
            parquet_paths = parquet_paths[:max_parquet_files]
            self.valid_ids = parquet_paths
            self.sample_ids = self.valid_ids
            self.num_records = len(self.valid_ids)
            self.max_num = self.num_records * 1000

        self.first_stage_key = first_stage_key
        self.cond_stage_key = cond_stage_key
        self.do_flip = do_flip
        if self.do_flip:
            self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.min_crop_f = min_crop_f
        self.max_crop_f = max_crop_f
        assert (max_crop_f <= 1.)
        self.center_crop = not random_crop
        self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)

    def __iter__(self):
        """Iterate tsv+caption pairs when a caption folder was given, plain
        parquet files otherwise."""
        if self.caption_folder:
            return self.parquet_tsv_iter()
        return self.parquet_iter()

    def parquet_iter(self):
        """Endlessly yield samples from the parquet files in ``self.sample_ids``.

        Rows without image bytes, with an undecodable or too small image, or
        without a caption are skipped. After the last file the iteration
        starts over from the first one.
        """
        print("this shard on GPU {}: {}".format(_get_distributed_settings()[1], len(self.sample_ids)))
        while True:
            for parquet_path in self.sample_ids:
                df = pd.read_parquet(parquet_path)
                # Resolve the caption column once per file instead of per row;
                # ``caption`` takes precedence over ``TEXT``.
                if "caption" in df.columns:
                    texts = df["caption"]
                elif "TEXT" in df.columns:
                    texts = df["TEXT"]
                else:
                    # No caption column: no row of this file can yield a sample.
                    del df
                    continue
                for file_idx in range(len(df)):
                    img_code = df.jpg.iloc[file_idx]
                    if not img_code:
                        continue
                    try:
                        image = self.generate_img(img_code)
                    except Exception:
                        # Undecodable image bytes: skip the row.
                        continue
                    if image is None:
                        continue
                    text = texts.iloc[file_idx]
                    if text is None:
                        continue
                    example = {}
                    example[self.first_stage_key] = image
                    example[self.cond_stage_key] = text
                    yield example
                del df
            print("has gone over the whole dataset, need to start next round")

    @staticmethod
    def _caption_stem(subfolder):
        """Return the caption file stem for an image sub-directory: its name
        without the ``output_`` prefix.

        ``normpath`` drops the trailing slash that ``glob(".../*/")`` leaves on
        the path (``basename`` would otherwise be empty), and the prefix is
        removed explicitly because ``str.lstrip`` strips a character set.
        """
        name = os.path.basename(os.path.normpath(subfolder))
        prefix = "output_"
        return name[len(prefix):] if name.startswith(prefix) else name

    def parquet_tsv_iter(self):
        """Yield samples from tsv image files paired with caption parquets.

        Each tsv line is ``<caption row index>\\t<base64 image>``; the caption
        is the ``TEXT`` entry of that row in
        ``<caption_folder>/<sub-directory name without "output_">.parquet``.
        Undecodable or too small images are skipped. Runs once over the data.
        """
        for subfolder in self.sample_ids:
            caption_path = os.path.join(
                self.caption_folder, self._caption_stem(subfolder) + ".parquet"
            )
            par_data = pd.read_parquet(caption_path)
            for image_path in self.tsv_paths[subfolder]:
                with open(image_path, "r") as f:
                    # Stream the file: lines hold whole base64 images.
                    for line_ in tqdm(f):
                        idx, img_code = line_.split("\t")
                        try:
                            image = self.generate_img(base64.b64decode(img_code))
                        except Exception:
                            # Undecodable image: skip the line.
                            continue
                        # Must be an identity test: the truth value of a
                        # numpy array is ambiguous and raises.
                        if image is None:
                            continue
                        example = dict()
                        example[self.first_stage_key] = image
                        example[self.cond_stage_key] = par_data.iloc[int(idx)].TEXT
                        yield example
            del par_data

    def generate_img(self, img_code):
        """Decode image bytes and apply crop, rescale and optional flip.

        :param img_code: encoded image bytes
        :return: float32 array scaled to [-1, 1], or ``None`` if either image
            side is smaller than ``self.size``
        """
        image = Image.open(BytesIO(img_code)).convert("RGB")
        image = np.array(image).astype(np.uint8)
        if image.shape[0] < self.size or image.shape[1] < self.size:
            return None

        # crop
        min_side_len = min(image.shape[:2])
        crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
        crop_side_len = int(crop_side_len)
        # Stored on the instance as before, in case external code inspects it.
        if self.center_crop:
            self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
        else:
            self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
        image = self.cropper(image=image)["image"]

        # rescale
        image = self.image_rescaler(image=image)["image"]

        # flip
        if self.do_flip:
            image = self.flip(Image.fromarray(image))
            image = np.array(image).astype(np.uint8)
        return (image / 127.5 - 1.0).astype(np.float32)