import pathlib
import random
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
import datasets.data_files
import datasets.distributed
import datasets.exceptions
import huggingface_hub
import huggingface_hub.errors
import numpy as np
import PIL.Image
import torch
import torch.distributed.checkpoint.stateful
from diffusers.utils import load_image, load_video
from huggingface_hub import list_repo_files, repo_exists, snapshot_download
from tqdm.auto import tqdm

from .. import constants
from .. import functional as FF
from ..logging import get_logger
from . import utils


import decord  # isort:skip

decord.bridge.set_bridge("torch")

logger = get_logger()

# fmt: off
MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"]
# fmt: on


class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Iterable dataset built from `<name>.txt` caption files paired with a same-named image file.

    Args:
        root: Directory that is searched (depth 0) for `*.txt` caption files.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.
    """

    def __init__(self, root: str, infinite: bool = False) -> None:
        super().__init__()

        self.root = pathlib.Path(root)
        self.infinite = infinite

        data = []
        caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
        for caption_file in caption_files:
            data_file = self._find_data_file(caption_file)
            if data_file:
                # NOTE(review): `_find_data_file` probes `caption_file` as-is, while the paths stored here are
                # re-joined onto `self.root` - confirm `utils.find_files` output makes both forms agree.
                data.append(
                    {
                        "caption": (self.root / caption_file).as_posix(),
                        "image": (self.root / data_file).as_posix(),
                    }
                )

        data = datasets.Dataset.from_list(data)
        data = data.cast_column("image", datasets.Image(mode="RGB"))

        self._data = data.to_iterable_dataset()
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with the caption file read into text and the image converted to a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                sample["caption"] = _read_caption_from_file(sample["caption"])
                sample["image"] = _preprocess_image(sample["image"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
                break
            else:
                # Start the next pass from the beginning.
                self._sample_index = 0

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}

    def _find_data_file(self, caption_file: str) -> Optional[str]:
        """Find the single image file that shares its stem with `caption_file`.

        Returns:
            The POSIX path of the matching image, or None when no supported image file exists.

        Raises:
            ValueError: If more than one supported image file matches the caption file.
        """
        caption_path = pathlib.Path(caption_file)
        candidates = [
            caption_path.with_suffix(f".{extension}") for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS
        ]
        matches = [candidate for candidate in candidates if candidate.exists()]

        if not matches:
            return None
        if len(matches) > 1:
            raise ValueError(
                f"Multiple data files found for caption file {caption_path}. Please ensure there is only one data "
                f"file per caption file. The following extensions are supported:\n"
                f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n"
            )
        return matches[0].as_posix()


class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Iterable dataset built from `<name>.txt` caption files paired with a same-named video file.

    Args:
        root: Directory that is searched (depth 0) for `*.txt` caption files.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.
    """

    def __init__(self, root: str, infinite: bool = False) -> None:
        super().__init__()

        self.root = pathlib.Path(root)
        self.infinite = infinite

        data = []
        caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
        for caption_file in caption_files:
            data_file = self._find_data_file(caption_file)
            if data_file:
                # NOTE(review): `_find_data_file` probes `caption_file` as-is, while the paths stored here are
                # re-joined onto `self.root` - confirm `utils.find_files` output makes both forms agree.
                data.append(
                    {
                        "caption": (self.root / caption_file).as_posix(),
                        "video": (self.root / data_file).as_posix(),
                    }
                )

        data = datasets.Dataset.from_list(data)
        data = data.cast_column("video", datasets.Video())

        self._data = data.to_iterable_dataset()
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with the caption file read into text and the video converted to a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                sample["caption"] = _read_caption_from_file(sample["caption"])
                sample["video"] = _preprocess_video(sample["video"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
                break
            else:
                # Start the next pass from the beginning.
                self._sample_index = 0

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}

    def _find_data_file(self, caption_file: str) -> Optional[str]:
        """Find the single video file that shares its stem with `caption_file`.

        Returns:
            The POSIX path of the matching video, or None when no supported video file exists.

        Raises:
            ValueError: If more than one supported video file matches the caption file.
        """
        caption_path = pathlib.Path(caption_file)
        candidates = [
            caption_path.with_suffix(f".{extension}") for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS
        ]
        matches = [candidate for candidate in candidates if candidate.exists()]

        if not matches:
            return None
        if len(matches) > 1:
            raise ValueError(
                f"Multiple data files found for caption file {caption_path}. Please ensure there is only one data "
                f"file per caption file. The following extensions are supported:\n"
                f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n"
            )
        return matches[0].as_posix()


class ImageFileCaptionFileListDataset(
    torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
):
    """Iterable dataset built from one caption list file and one image list file, matched line by line.

    Args:
        root: Directory containing exactly one caption list file and exactly one image list file.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.

    Raises:
        FileNotFoundError: If the caption list file or the image list file is missing.
        ValueError: If several candidate list files exist, or the two lists differ in length.
    """

    def __init__(self, root: str, infinite: bool = False) -> None:
        super().__init__()

        VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
        VALID_IMAGE_FILES = ["image.txt", "images.txt"]

        self.root = pathlib.Path(root)
        self.infinite = infinite

        data = []
        existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
        existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()]

        if len(existing_caption_files) == 0:
            raise FileNotFoundError(
                f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
            )
        if len(existing_image_files) == 0:
            raise FileNotFoundError(
                f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
            )
        if len(existing_caption_files) > 1:
            raise ValueError(
                f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
            )
        if len(existing_image_files) > 1:
            raise ValueError(
                f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
            )

        caption_file = existing_caption_files[0]
        image_file = existing_image_files[0]

        # Decode as UTF-8 explicitly so the result does not depend on the platform's locale encoding.
        with open((self.root / caption_file).as_posix(), "r", encoding="utf-8") as f:
            captions = f.read().splitlines()
        with open((self.root / image_file).as_posix(), "r", encoding="utf-8") as f:
            images = f.read().splitlines()
            images = [(self.root / image).as_posix() for image in images]

        if len(captions) != len(images):
            raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})")

        for caption, image in zip(captions, images):
            data.append({"caption": caption, "image": image})

        data = datasets.Dataset.from_list(data)
        data = data.cast_column("image", datasets.Image(mode="RGB"))

        self._data = data.to_iterable_dataset()
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with the image converted to a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                sample["image"] = _preprocess_image(sample["image"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
                break
            else:
                # Start the next pass from the beginning.
                self._sample_index = 0

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}


class VideoFileCaptionFileListDataset(
    torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
):
    """Iterable dataset built from one caption list file and one video list file, matched line by line.

    Args:
        root: Directory containing exactly one caption list file and exactly one video list file.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.

    Raises:
        FileNotFoundError: If the caption list file or the video list file is missing.
        ValueError: If several candidate list files exist, or the two lists differ in length.
    """

    def __init__(self, root: str, infinite: bool = False) -> None:
        super().__init__()

        VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
        VALID_VIDEO_FILES = ["video.txt", "videos.txt"]

        self.root = pathlib.Path(root)
        self.infinite = infinite

        data = []
        existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
        existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()]

        if len(existing_caption_files) == 0:
            raise FileNotFoundError(
                f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
            )
        if len(existing_video_files) == 0:
            raise FileNotFoundError(
                f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
            )
        if len(existing_caption_files) > 1:
            raise ValueError(
                f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
            )
        if len(existing_video_files) > 1:
            raise ValueError(
                f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
            )

        caption_file = existing_caption_files[0]
        video_file = existing_video_files[0]

        # Decode as UTF-8 explicitly so the result does not depend on the platform's locale encoding.
        with open((self.root / caption_file).as_posix(), "r", encoding="utf-8") as f:
            captions = f.read().splitlines()
        with open((self.root / video_file).as_posix(), "r", encoding="utf-8") as f:
            videos = f.read().splitlines()
            videos = [(self.root / video).as_posix() for video in videos]

        if len(captions) != len(videos):
            raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})")

        for caption, video in zip(captions, videos):
            data.append({"caption": caption, "video": video})

        data = datasets.Dataset.from_list(data)
        data = data.cast_column("video", datasets.Video())

        self._data = data.to_iterable_dataset()
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with the video converted to a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                sample["video"] = _preprocess_video(sample["video"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
                break
            else:
                # Start the next pass from the beginning.
                self._sample_index = 0

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}


class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Iterable dataset loaded through the `datasets` "imagefolder" builder (train split).

    Args:
        root: Directory passed to the builder as `data_dir`.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.
    """

    def __init__(self, root: str, infinite: bool = False) -> None:
        super().__init__()

        self.root = pathlib.Path(root)
        self.infinite = infinite

        data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train")

        self._data = data.to_iterable_dataset()
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with the image converted to a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                sample["image"] = _preprocess_image(sample["image"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
                break
            else:
                # Start the next pass from the beginning.
                self._sample_index = 0

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}


class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Iterable dataset loaded through the `datasets` "videofolder" builder (train split).

    Args:
        root: Directory passed to the builder as `data_dir`.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.
    """

    def __init__(self, root: str, infinite: bool = False) -> None:
        super().__init__()

        self.root = pathlib.Path(root)
        self.infinite = infinite

        data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train")

        self._data = data.to_iterable_dataset()
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with the video converted to a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                sample["video"] = _preprocess_video(sample["video"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
                break
            else:
                # Start the next pass from the beginning.
                self._sample_index = 0

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}


class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Streaming image dataset whose caption is sampled from one or more caption columns.

    Args:
        dataset_name: Name passed to `datasets.load_dataset(..., split="train", streaming=True)`.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.
        column_names: Caption column name, list of caption column names, or "__auto__" to use either the keys
            of `weights` or every column listed in `COMMON_WDS_CAPTION_COLUMN_NAMES`.
        weights: Mapping of caption column name to sampling weight, or -1 to weight every column equally.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If a requested caption column is missing, no caption column can be detected, or
            `column_names` has an unsupported type.
    """

    def __init__(
        self,
        dataset_name: str,
        infinite: bool = False,
        column_names: Union[str, List[str]] = "__auto__",
        weights: Union[Dict[str, float], int] = -1,
        **kwargs,
    ) -> None:
        super().__init__()

        assert weights == -1 or isinstance(
            weights, dict
        ), "`weights` must be a dictionary of probabilities for each caption column"

        self.dataset_name = dataset_name
        self.infinite = infinite

        data = datasets.load_dataset(dataset_name, split="train", streaming=True)

        if column_names == "__auto__":
            if weights == -1:
                caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
                if len(caption_columns) == 0:
                    raise ValueError(
                        f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
                    )
                weights = [1] * len(caption_columns)
            else:
                caption_columns = list(weights.keys())
                weights = list(weights.values())
                if not all(column in data.column_names for column in caption_columns):
                    raise ValueError(
                        f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
                    )
        else:
            if isinstance(column_names, str):
                if column_names not in data.column_names:
                    raise ValueError(
                        f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
                    )
                caption_columns = [column_names]
                weights = [1] if weights == -1 else [weights.get(column_names)]
            elif isinstance(column_names, list):
                if not all(column in data.column_names for column in column_names):
                    raise ValueError(
                        f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
                    )
                caption_columns = column_names
                # One weight per column: `random.choices` rejects a weights list whose length differs from
                # the population, so the previous single-element `[1]` failed for multi-column lists.
                if weights == -1:
                    weights = [1] * len(column_names)
                else:
                    weights = [weights.get(column) for column in column_names]
            else:
                raise ValueError(f"Unsupported type for column_name: {type(column_names)}")

        # The first column named after a supported image extension is exposed as "image".
        for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
            if extension in data.column_names:
                data = data.cast_column(extension, datasets.Image(mode="RGB"))
                data = data.rename_column(extension, "image")
                break

        self._data = data
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = False
        self._caption_columns = caption_columns
        self._weights = weights

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with a randomly chosen caption column copied to "caption" and the image as a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
                sample["caption"] = sample[caption_column]
                sample["image"] = _preprocess_image(sample["image"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_index = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}


class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Streaming video dataset whose caption is sampled from one or more caption columns.

    Args:
        dataset_name: Name passed to `datasets.load_dataset(..., split="train", streaming=True)`.
        infinite: When True, iteration restarts from the first sample after the data is exhausted.
        column_names: Caption column name, list of caption column names, or "__auto__" to use either the keys
            of `weights` or every column listed in `COMMON_WDS_CAPTION_COLUMN_NAMES`.
        weights: Mapping of caption column name to sampling weight, or -1 to weight every column equally.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If a requested caption column is missing, no caption column can be detected, or
            `column_names` has an unsupported type.
    """

    def __init__(
        self,
        dataset_name: str,
        infinite: bool = False,
        column_names: Union[str, List[str]] = "__auto__",
        weights: Union[Dict[str, float], int] = -1,
        **kwargs,
    ) -> None:
        super().__init__()

        assert weights == -1 or isinstance(
            weights, dict
        ), "`weights` must be a dictionary of probabilities for each caption column"

        self.dataset_name = dataset_name
        self.infinite = infinite

        data = datasets.load_dataset(dataset_name, split="train", streaming=True)

        if column_names == "__auto__":
            if weights == -1:
                caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
                if len(caption_columns) == 0:
                    raise ValueError(
                        f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
                    )
                weights = [1] * len(caption_columns)
            else:
                caption_columns = list(weights.keys())
                weights = list(weights.values())
                if not all(column in data.column_names for column in caption_columns):
                    raise ValueError(
                        f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
                    )
        else:
            if isinstance(column_names, str):
                if column_names not in data.column_names:
                    raise ValueError(
                        f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
                    )
                caption_columns = [column_names]
                weights = [1] if weights == -1 else [weights.get(column_names)]
            elif isinstance(column_names, list):
                if not all(column in data.column_names for column in column_names):
                    raise ValueError(
                        f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
                    )
                caption_columns = column_names
                # One weight per column: `random.choices` rejects a weights list whose length differs from
                # the population, so the previous single-element `[1]` failed for multi-column lists.
                if weights == -1:
                    weights = [1] * len(column_names)
                else:
                    weights = [weights.get(column) for column in column_names]
            else:
                raise ValueError(f"Unsupported type for column_name: {type(column_names)}")

        # The first column named after a supported video extension is exposed as "video".
        for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
            if extension in data.column_names:
                data = data.cast_column(extension, datasets.Video())
                data = data.rename_column(extension, "video")
                break

        self._data = data
        # Number of samples yielded in the current pass; persisted through state_dict()/load_state_dict().
        self._sample_index = 0
        self._precomputable_once = False
        self._caption_columns = caption_columns
        self._weights = weights

    def _get_data_iter(self):
        """Return an iterator over the data, skipping the samples already consumed in this pass."""
        if self._sample_index == 0:
            return iter(self._data)
        return iter(self._data.skip(self._sample_index))

    def __iter__(self):
        """Yield samples with a randomly chosen caption column copied to "caption" and the video as a tensor."""
        while True:
            for sample in self._get_data_iter():
                self._sample_index += 1
                caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
                sample["caption"] = sample[caption_column]
                sample["video"] = _preprocess_video(sample["video"])
                yield sample

            if not self.infinite:
                logger.warning(f"Dataset {self.dataset_name} has run out of data")
                break
            else:
                # Reset offset for the next iteration
                self._sample_index = 0
                logger.warning(f"Dataset {self.dataset_name} is being re-looped")

    def load_state_dict(self, state_dict):
        """Restore the resume position from `state_dict["sample_index"]`."""
        self._sample_index = state_dict["sample_index"]

    def state_dict(self):
        """Return the resume position as `{"sample_index": int}`."""
        return {"sample_index": self._sample_index}


class ValidationDataset(torch.utils.data.IterableDataset):
    """Iterable validation dataset read from a single `.csv`, `.json`, `.parquet` or `.arrow` file.

    Args:
        filename: Path of the validation file. JSON files are read from their top-level "data" field.

    Raises:
        FileNotFoundError: If `filename` does not exist.
        ValueError: If the file suffix is not one of the supported formats.
    """

    def __init__(self, filename: str):
        super().__init__()

        self.filename = pathlib.Path(filename)

        if not self.filename.exists():
            raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist")

        if self.filename.suffix == ".csv":
            data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train")
        elif self.filename.suffix == ".json":
            data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data")
        elif self.filename.suffix == ".parquet":
            data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train")
        elif self.filename.suffix == ".arrow":
            data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train")
        else:
            _SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
            raise ValueError(
                f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
            )

        self._data = data.to_iterable_dataset()

    def __iter__(self):
        """Yield samples with a "prompt" alias and any existing image/video loaded; None values are dropped."""
        for sample in self._data:
            # For consistency reasons, we mandate that "caption" is always present in the validation dataset.
            # However, since the model specifications use "prompt", we create an alias here.
            sample["prompt"] = sample["caption"]

            # Load image or video if the path is provided
            # TODO(aryan): need to handle custom columns here for control conditions
            sample["image"] = None
            sample["video"] = None

            if sample.get("image_path", None) is not None:
                image_path = pathlib.Path(sample["image_path"])
                if not image_path.is_file():
                    logger.warning(f"Image file {image_path.as_posix()} does not exist.")
                else:
                    sample["image"] = load_image(sample["image_path"])

            if sample.get("video_path", None) is not None:
                video_path = pathlib.Path(sample["video_path"])
                if not video_path.is_file():
                    logger.warning(f"Video file {video_path.as_posix()} does not exist.")
                else:
                    sample["video"] = load_video(sample["video_path"])

            sample = {k: v for k, v in sample.items() if v is not None}
            yield sample


class IterableDatasetPreprocessingWrapper(
    torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
):
    """Wrap an iterable dataset to apply resolution bucketing and caption post-processing per sample.

    Args:
        dataset: Wrapped dataset; its state is what `state_dict()` / `load_state_dict()` persist.
        dataset_type: "image" or "video"; selects which sample key is resized.
        id_token: Optional token prepended to every caption.
        image_resolution_buckets: Buckets for image resizing; resizing is skipped when empty or None.
        video_resolution_buckets: Buckets for video resizing; resizing is skipped when empty or None.
        reshape_mode: Forwarded to the `FF.resize_to_nearest_bucket_*` helpers.
        remove_common_llm_caption_prefixes: When True, strips `constants.COMMON_LLM_START_PHRASES` prefixes.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        dataset: torch.utils.data.IterableDataset,
        dataset_type: str,
        id_token: Optional[str] = None,
        image_resolution_buckets: Optional[List[Tuple[int, int]]] = None,
        video_resolution_buckets: Optional[List[Tuple[int, int, int]]] = None,
        reshape_mode: str = "bicubic",
        remove_common_llm_caption_prefixes: bool = False,
        **kwargs,
    ):
        super().__init__()

        self.dataset = dataset
        self.dataset_type = dataset_type
        self.id_token = id_token
        self.image_resolution_buckets = image_resolution_buckets
        self.video_resolution_buckets = video_resolution_buckets
        self.reshape_mode = reshape_mode
        self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes

        logger.info(
            f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n"
            f" - Dataset Type: {dataset_type}\n"
            f" - ID Token: {id_token}\n"
            f" - Image Resolution Buckets: {image_resolution_buckets}\n"
            f" - Video Resolution Buckets: {video_resolution_buckets}\n"
            f" - Reshape Mode: {reshape_mode}\n"
            f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n"
        )

    def __iter__(self):
        """Yield samples from the wrapped dataset after bucket resizing and caption post-processing."""
        logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
        for sample in iter(self.dataset):
            if self.dataset_type == "image":
                if self.image_resolution_buckets:
                    # Record the pre-resize dimensions before the tensor is replaced.
                    sample["_original_num_frames"] = 1
                    sample["_original_height"] = sample["image"].size(1)
                    sample["_original_width"] = sample["image"].size(2)
                    sample["image"] = FF.resize_to_nearest_bucket_image(
                        sample["image"], self.image_resolution_buckets, self.reshape_mode
                    )
            elif self.dataset_type == "video":
                if self.video_resolution_buckets:
                    # Record the pre-resize dimensions before the tensor is replaced.
                    sample["_original_num_frames"] = sample["video"].size(0)
                    sample["_original_height"] = sample["video"].size(2)
                    sample["_original_width"] = sample["video"].size(3)
                    sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
                        sample["video"], self.video_resolution_buckets, self.reshape_mode
                    )
                    if _first_frame_only:
                        msg = (
                            "The number of frames in the video is less than the minimum bucket size "
                            "specified. The first frame is being used as a single frame video. This "
                            "message is logged at the first occurence and for every 128th occurence "
                            "after that."
                        )
                        logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
                        sample["video"] = sample["video"][0]

            if self.remove_common_llm_caption_prefixes:
                sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES)

            if self.id_token is not None:
                sample["caption"] = f"{self.id_token} {sample['caption']}"

            yield sample

    def load_state_dict(self, state_dict):
        """Restore the wrapped dataset's state from `state_dict["dataset"]`."""
        self.dataset.load_state_dict(state_dict["dataset"])

    def state_dict(self):
        """Return the wrapped dataset's state as `{"dataset": ...}`."""
        return {"dataset": self.dataset.state_dict()}


class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
    """Interleave several iterable datasets through a shared sample buffer.

    Args:
        datasets: Datasets to draw from.
        buffer_size: Total number of samples pre-fetched across all datasets (at least one per dataset).
        shuffle: When True, each yielded sample is picked uniformly at random from the buffer; otherwise
            the oldest buffered sample is yielded.
    """

    def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
        super().__init__()

        self.datasets = datasets
        self.buffer_size = buffer_size
        self.shuffle = shuffle

        logger.info(
            f"Initializing IterableCombinedDataset with the following configuration:\n"
            f" - Number of Datasets: {len(datasets)}\n"
            f" - Buffer Size: {buffer_size}\n"
            f" - Shuffle: {shuffle}\n"
        )

    def __iter__(self):
        """Yield buffered samples, refilling each slot from the dataset the yielded sample came from."""
        logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
        iterators = [iter(dataset) for dataset in self.datasets]
        if not iterators:
            # Nothing to draw from; also avoids dividing by zero below.
            return

        buffer = []
        per_iter = max(1, self.buffer_size // len(iterators))

        for index, it in enumerate(iterators):
            for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
                try:
                    buffer.append((it, next(it)))
                except StopIteration:
                    # This source is exhausted; polling it again for the remaining slots is pointless.
                    break

        while len(buffer) > 0:
            idx = 0
            if self.shuffle:
                idx = random.randint(0, len(buffer) - 1)
            current_it, sample = buffer.pop(idx)
            yield sample
            try:
                buffer.append((current_it, next(current_it)))
            except StopIteration:
                # The source is exhausted; the buffer simply shrinks by one.
                pass

    def load_state_dict(self, state_dict):
        """Restore every child dataset from the matching entry of `state_dict["datasets"]`."""
        for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
            dataset.load_state_dict(dataset_state_dict)

    def state_dict(self):
        """Return the child dataset states as `{"datasets": [...]}`."""
        return {"datasets": [dataset.state_dict() for dataset in self.datasets]}


# TODO(aryan): maybe write a test for this
def initialize_dataset(
    dataset_name_or_root: str,
    dataset_type: str = "video",
    streaming: bool = True,
    infinite: bool = False,
    *,
    _caption_options: Optional[Dict[str, Any]] = None,
) -> torch.utils.data.IterableDataset:
    """Create a dataset from a Hub dataset repository if one exists under that name, else from a local path.

    Args:
        dataset_name_or_root: Hub dataset repository id or local directory.
        dataset_type: "image" or "video".
        streaming: Currently unused by this function.
        infinite: Forwarded to the created dataset.
        _caption_options: Extra keyword arguments forwarded to the web datasets (Hub datasets only).
    """
    assert dataset_type in ["image", "video"]

    try:
        does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
    except huggingface_hub.errors.HFValidationError:
        # The name is not a valid repository id, so it is treated as a local path.
        does_repo_exist_on_hub = False

    if does_repo_exist_on_hub:
        return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options)
    else:
        return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)


def combine_datasets(
    datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
) -> torch.utils.data.IterableDataset:
    """Wrap `datasets` in an `IterableCombinedDataset`."""
    return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)


def wrap_iterable_dataset_for_preprocessing(
    dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
) -> torch.utils.data.IterableDataset:
    """Wrap `dataset` in an `IterableDatasetPreprocessingWrapper`, passing `config` as keyword arguments."""
    return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)


def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
    """Pick the dataset class that matches the layout of a local directory.

    A single metadata file selects the folder datasets; otherwise caption/data file pairs are tried,
    then caption/data list files.

    Raises:
        ValueError: If several metadata files exist or no supported layout is found.
    """
    root = pathlib.Path(dataset_name_or_root)
    supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
    metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
    metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]

    if len(metadata_files) > 1:
        raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")

    if len(metadata_files) == 1:
        if dataset_type == "image":
            dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
        else:
            dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
        return dataset

    if _has_data_caption_file_pairs(root, remote=False):
        if dataset_type == "image":
            dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
        else:
            dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
    elif _has_data_file_caption_file_lists(root, remote=False):
        if dataset_type == "image":
            dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
        else:
            dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
    else:
        raise ValueError(
            f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
            f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
            f"help you set it up."
        )

    return dataset


def _initialize_hub_dataset(
    dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None
):
    """Pick the dataset class that matches the file listing of a Hub dataset repository.

    Raises:
        ValueError: If no supported layout is found or the folder-dataset fallback fails; in the latter
            case the underlying error is attached as the cause.
    """
    repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
    if _has_data_caption_file_pairs(repo_file_list, remote=True):
        return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
    elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
        return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)

    has_tar_files = any(file.endswith((".tar", ".parquet")) for file in repo_file_list)
    if has_tar_files:
        return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options)

    # TODO(aryan): This should be improved
    caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
    load_error: Optional[Exception] = None
    if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
        try:
            dataset_root = snapshot_download(dataset_name, repo_type="dataset")
            if dataset_type == "image":
                dataset = ImageFolderDataset(dataset_root, infinite=infinite)
            else:
                dataset = VideoFolderDataset(dataset_root, infinite=infinite)
            return dataset
        except Exception as error:
            # Best-effort fallback: keep the original failure so it is reported instead of being lost.
            logger.warning(f"Failed to load dataset {dataset_name} as a folder dataset: {error}")
            load_error = error

    raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub") from load_error


def _initialize_data_caption_file_dataset_from_hub(
    dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
    """Download a Hub dataset snapshot and load it as a caption/data file pair dataset."""
    logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
    dataset_root = snapshot_download(dataset_name, repo_type="dataset")
    if dataset_type == "image":
        return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
    else:
        return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)


def _initialize_data_file_caption_file_dataset_from_hub(
    dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
    """Download a Hub dataset snapshot and load it as a caption/data list file dataset."""
    logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
    dataset_root = snapshot_download(dataset_name, repo_type="dataset")
    if dataset_type == "image":
        return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite)
    else:
        return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite)


def _initialize_webdataset(
    dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None
) -> torch.utils.data.IterableDataset:
    """Create a streaming web dataset, forwarding `_caption_options` as keyword arguments."""
    logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
    _caption_options = _caption_options or {}
    if dataset_type == "image":
        return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options)
    else:
        return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options)


def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
    """Return True if any `.txt` file has a same-named image or video file next to it.

    Args:
        root: Local directory when `remote` is False, otherwise a list of repository file names.
        remote: Selects between filesystem checks and lookups in the given file list.
    """
    # TODO(aryan): this logic can be improved
    data_extensions = [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]

    if not remote:
        caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
        for caption_file in caption_files:
            caption_file = pathlib.Path(caption_file)
            for extension in data_extensions:
                data_filename = caption_file.with_suffix(f".{extension}")
                if data_filename.exists():
                    return True
        return False
    else:
        # Build the lookup set once instead of scanning the list for every candidate name.
        known_files = set(root)
        caption_files = [file for file in root if file.endswith(".txt")]
        for caption_file in caption_files:
            caption_file = pathlib.Path(caption_file)
            for extension in data_extensions:
                # Only the bare file name is compared against the listing entries.
                data_filename = caption_file.with_suffix(f".{extension}").name
                if data_filename in known_files:
                    return True
        return False


def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
    """Return True if a caption list file and an image or video list file are both present.

    Args:
        root: Local directory when `remote` is False, otherwise a list of repository file names.
        remote: Selects between a directory listing and lookups in the given file list.
    """
    # TODO(aryan): this logic can be improved
    if not remote:
        file_list = {x.name for x in root.iterdir()}
        has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
        has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
        has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
        return has_caption_files and (has_video_files or has_image_files)
    else:
        has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
        has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
        has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
        return has_caption_files and (has_video_files or has_image_files)


def _read_caption_from_file(filename: str) -> str:
    """Return the stripped text content of `filename`, decoded as UTF-8."""
    with open(filename, "r", encoding="utf-8") as f:
        return f.read().strip()


def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
    """Convert a PIL image to a channel-first float32 tensor computed as `x / 127.5 - 1.0`."""
    image = image.convert("RGB")
    image = np.array(image).astype(np.float32)
    image = torch.from_numpy(image)
    # Move the channel axis (last after np.array) to the front.
    image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
    return image


def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
    """Decode every frame of `video` into a float tensor computed as `x / 127.5 - 1.0`."""
    video = video.get_batch(list(range(len(video))))
    # Assumes decord returns frames channel-last; the channel axis is moved next to the frame axis.
    video = video.permute(0, 3, 1, 2).contiguous()
    video = video.float() / 127.5 - 1.0
    return video