jbilcke-hf's picture
jbilcke-hf HF Staff
upgrade finetrainers + gradio
ecd5028
raw
history blame
41.1 kB
import pathlib
import random
from typing import Any, Dict, List, Optional, Tuple, Union
import datasets
import datasets.data_files
import datasets.distributed
import datasets.exceptions
import huggingface_hub
import huggingface_hub.errors
import numpy as np
import PIL.Image
import torch
import torch.distributed.checkpoint.stateful
from diffusers.utils import load_image, load_video
from huggingface_hub import list_repo_files, repo_exists, snapshot_download
from tqdm.auto import tqdm
from .. import constants
from .. import functional as FF
from ..logging import get_logger
from . import utils
import decord # isort:skip
decord.bridge.set_bridge("torch")
logger = get_logger()
# fmt: off
MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"]
# fmt: on
class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(self, root: str, infinite: bool = False) -> None:
super().__init__()
self.root = pathlib.Path(root)
self.infinite = infinite
data = []
caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
for caption_file in caption_files:
data_file = self._find_data_file(caption_file)
if data_file:
data.append(
{
"caption": (self.root / caption_file).as_posix(),
"image": (self.root / data_file).as_posix(),
}
)
data = datasets.Dataset.from_list(data)
data = data.cast_column("image", datasets.Image(mode="RGB"))
self._data = data.to_iterable_dataset()
self._sample_index = 0
self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
sample["caption"] = _read_caption_from_file(sample["caption"])
sample["image"] = _preprocess_image(sample["image"])
yield sample
if not self.infinite:
logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
break
else:
self._sample_index = 0
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
def _find_data_file(self, caption_file: str) -> str:
caption_file = pathlib.Path(caption_file)
data_file = None
found_data = 0
for extension in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
image_filename = caption_file.with_suffix(f".{extension}")
if image_filename.exists():
found_data += 1
data_file = image_filename
if found_data == 0:
return False
elif found_data > 1:
raise ValueError(
f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
f"file per caption file. The following extensions are supported:\n"
f" - Images: {constants.SUPPORTED_IMAGE_FILE_EXTENSIONS}\n"
)
return data_file.as_posix()
class VideoCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(self, root: str, infinite: bool = False) -> None:
super().__init__()
self.root = pathlib.Path(root)
self.infinite = infinite
data = []
caption_files = sorted(utils.find_files(self.root.as_posix(), "*.txt", depth=0))
for caption_file in caption_files:
data_file = self._find_data_file(caption_file)
if data_file:
data.append(
{
"caption": (self.root / caption_file).as_posix(),
"video": (self.root / data_file).as_posix(),
}
)
data = datasets.Dataset.from_list(data)
data = data.cast_column("video", datasets.Video())
self._data = data.to_iterable_dataset()
self._sample_index = 0
self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
sample["caption"] = _read_caption_from_file(sample["caption"])
sample["video"] = _preprocess_video(sample["video"])
yield sample
if not self.infinite:
logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
break
else:
self._sample_index = 0
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
def _find_data_file(self, caption_file: str) -> str:
caption_file = pathlib.Path(caption_file)
data_file = None
found_data = 0
for extension in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
video_filename = caption_file.with_suffix(f".{extension}")
if video_filename.exists():
found_data += 1
data_file = video_filename
if found_data == 0:
return False
elif found_data > 1:
raise ValueError(
f"Multiple data files found for caption file {caption_file}. Please ensure there is only one data "
f"file per caption file. The following extensions are supported:\n"
f" - Videos: {constants.SUPPORTED_VIDEO_FILE_EXTENSIONS}\n"
)
return data_file.as_posix()
class ImageFileCaptionFileListDataset(
torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
):
def __init__(self, root: str, infinite: bool = False) -> None:
super().__init__()
VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
VALID_IMAGE_FILES = ["image.txt", "images.txt"]
self.root = pathlib.Path(root)
self.infinite = infinite
data = []
existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
existing_image_files = [file for file in VALID_IMAGE_FILES if (self.root / file).exists()]
if len(existing_caption_files) == 0:
raise FileNotFoundError(
f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
)
if len(existing_image_files) == 0:
raise FileNotFoundError(
f"No image file found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
)
if len(existing_caption_files) > 1:
raise ValueError(
f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
)
if len(existing_image_files) > 1:
raise ValueError(
f"Multiple image files found in {self.root}. Must have exactly one of {VALID_IMAGE_FILES}"
)
caption_file = existing_caption_files[0]
image_file = existing_image_files[0]
with open((self.root / caption_file).as_posix(), "r") as f:
captions = f.read().splitlines()
with open((self.root / image_file).as_posix(), "r") as f:
images = f.read().splitlines()
images = [(self.root / image).as_posix() for image in images]
if len(captions) != len(images):
raise ValueError(f"Number of captions ({len(captions)}) must match number of images ({len(images)})")
for caption, image in zip(captions, images):
data.append({"caption": caption, "image": image})
data = datasets.Dataset.from_list(data)
data = data.cast_column("image", datasets.Image(mode="RGB"))
self._data = data.to_iterable_dataset()
self._sample_index = 0
self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
sample["image"] = _preprocess_image(sample["image"])
yield sample
if not self.infinite:
logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
break
else:
self._sample_index = 0
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
class VideoFileCaptionFileListDataset(
torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
):
def __init__(self, root: str, infinite: bool = False) -> None:
super().__init__()
VALID_CAPTION_FILES = ["caption.txt", "captions.txt", "prompt.txt", "prompts.txt"]
VALID_VIDEO_FILES = ["video.txt", "videos.txt"]
self.root = pathlib.Path(root)
self.infinite = infinite
data = []
existing_caption_files = [file for file in VALID_CAPTION_FILES if (self.root / file).exists()]
existing_video_files = [file for file in VALID_VIDEO_FILES if (self.root / file).exists()]
if len(existing_caption_files) == 0:
raise FileNotFoundError(
f"No caption file found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
)
if len(existing_video_files) == 0:
raise FileNotFoundError(
f"No video file found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
)
if len(existing_caption_files) > 1:
raise ValueError(
f"Multiple caption files found in {self.root}. Must have exactly one of {VALID_CAPTION_FILES}"
)
if len(existing_video_files) > 1:
raise ValueError(
f"Multiple video files found in {self.root}. Must have exactly one of {VALID_VIDEO_FILES}"
)
caption_file = existing_caption_files[0]
video_file = existing_video_files[0]
with open((self.root / caption_file).as_posix(), "r") as f:
captions = f.read().splitlines()
with open((self.root / video_file).as_posix(), "r") as f:
videos = f.read().splitlines()
videos = [(self.root / video).as_posix() for video in videos]
if len(captions) != len(videos):
raise ValueError(f"Number of captions ({len(captions)}) must match number of videos ({len(videos)})")
for caption, video in zip(captions, videos):
data.append({"caption": caption, "video": video})
data = datasets.Dataset.from_list(data)
data = data.cast_column("video", datasets.Video())
self._data = data.to_iterable_dataset()
self._sample_index = 0
self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
sample["video"] = _preprocess_video(sample["video"])
yield sample
if not self.infinite:
logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
break
else:
self._sample_index = 0
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
class ImageFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(self, root: str, infinite: bool = False) -> None:
super().__init__()
self.root = pathlib.Path(root)
self.infinite = infinite
data = datasets.load_dataset("imagefolder", data_dir=self.root.as_posix(), split="train")
self._data = data.to_iterable_dataset()
self._sample_index = 0
self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
sample["image"] = _preprocess_image(sample["image"])
yield sample
if not self.infinite:
logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
break
else:
self._sample_index = 0
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(self, root: str, infinite: bool = False) -> None:
super().__init__()
self.root = pathlib.Path(root)
self.infinite = infinite
data = datasets.load_dataset("videofolder", data_dir=self.root.as_posix(), split="train")
self._data = data.to_iterable_dataset()
self._sample_index = 0
self._precomputable_once = len(data) <= MAX_PRECOMPUTABLE_ITEMS_LIMIT
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
sample["video"] = _preprocess_video(sample["video"])
yield sample
if not self.infinite:
logger.warning(f"Dataset ({self.__class__.__name__}={self.root}) has run out of data")
break
else:
self._sample_index = 0
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(
self,
dataset_name: str,
infinite: bool = False,
column_names: Union[str, List[str]] = "__auto__",
weights: Dict[str, float] = -1,
**kwargs,
) -> None:
super().__init__()
assert weights == -1 or isinstance(
weights, dict
), "`weights` must be a dictionary of probabilities for each caption column"
self.dataset_name = dataset_name
self.infinite = infinite
data = datasets.load_dataset(dataset_name, split="train", streaming=True)
if column_names == "__auto__":
if weights == -1:
caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
if len(caption_columns) == 0:
raise ValueError(
f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
)
weights = [1] * len(caption_columns)
else:
caption_columns = list(weights.keys())
weights = list(weights.values())
if not all(column in data.column_names for column in caption_columns):
raise ValueError(
f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
)
else:
if isinstance(column_names, str):
if column_names not in data.column_names:
raise ValueError(
f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
)
caption_columns = [column_names]
weights = [1] if weights == -1 else [weights.get(column_names)]
elif isinstance(column_names, list):
if not all(column in data.column_names for column in column_names):
raise ValueError(
f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
)
caption_columns = column_names
weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
else:
raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
for column_names in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
if column_names in data.column_names:
data = data.cast_column(column_names, datasets.Image(mode="RGB"))
data = data.rename_column(column_names, "image")
break
self._data = data
self._sample_index = 0
self._precomputable_once = False
self._caption_columns = caption_columns
self._weights = weights
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
sample["caption"] = sample[caption_column]
sample["image"] = _preprocess_image(sample["image"])
yield sample
if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_index = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(
self,
dataset_name: str,
infinite: bool = False,
column_names: Union[str, List[str]] = "__auto__",
weights: Dict[str, float] = -1,
**kwargs,
) -> None:
super().__init__()
assert weights == -1 or isinstance(
weights, dict
), "`weights` must be a dictionary of probabilities for each caption column"
self.dataset_name = dataset_name
self.infinite = infinite
data = datasets.load_dataset(dataset_name, split="train", streaming=True)
if column_names == "__auto__":
if weights == -1:
caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
if len(caption_columns) == 0:
raise ValueError(
f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
)
weights = [1] * len(caption_columns)
else:
caption_columns = list(weights.keys())
weights = list(weights.values())
if not all(column in data.column_names for column in caption_columns):
raise ValueError(
f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
)
else:
if isinstance(column_names, str):
if column_names not in data.column_names:
raise ValueError(
f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
)
caption_columns = [column_names]
weights = [1] if weights == -1 else [weights.get(column_names)]
elif isinstance(column_names, list):
if not all(column in data.column_names for column in column_names):
raise ValueError(
f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
)
caption_columns = column_names
weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
else:
raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
for column_names in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
if column_names in data.column_names:
data = data.cast_column(column_names, datasets.Video())
data = data.rename_column(column_names, "video")
break
self._data = data
self._sample_index = 0
self._precomputable_once = False
self._caption_columns = caption_columns
self._weights = weights
def _get_data_iter(self):
if self._sample_index == 0:
return iter(self._data)
return iter(self._data.skip(self._sample_index))
def __iter__(self):
while True:
for sample in self._get_data_iter():
self._sample_index += 1
caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
sample["caption"] = sample[caption_column]
sample["video"] = _preprocess_video(sample["video"])
yield sample
if not self.infinite:
logger.warning(f"Dataset {self.dataset_name} has run out of data")
break
else:
# Reset offset for the next iteration
self._sample_index = 0
logger.warning(f"Dataset {self.dataset_name} is being re-looped")
def load_state_dict(self, state_dict):
self._sample_index = state_dict["sample_index"]
def state_dict(self):
return {"sample_index": self._sample_index}
class ValidationDataset(torch.utils.data.IterableDataset):
def __init__(self, filename: str):
super().__init__()
self.filename = pathlib.Path(filename)
if not self.filename.exists():
raise FileNotFoundError(f"File {self.filename.as_posix()} does not exist")
if self.filename.suffix == ".csv":
data = datasets.load_dataset("csv", data_files=self.filename.as_posix(), split="train")
elif self.filename.suffix == ".json":
data = datasets.load_dataset("json", data_files=self.filename.as_posix(), split="train", field="data")
elif self.filename.suffix == ".parquet":
data = datasets.load_dataset("parquet", data_files=self.filename.as_posix(), split="train")
elif self.filename.suffix == ".arrow":
data = datasets.load_dataset("arrow", data_files=self.filename.as_posix(), split="train")
else:
_SUPPORTED_FILE_FORMATS = [".csv", ".json", ".parquet", ".arrow"]
raise ValueError(
f"Unsupported file format {self.filename.suffix} for validation dataset. Supported formats are: {_SUPPORTED_FILE_FORMATS}"
)
self._data = data.to_iterable_dataset()
def __iter__(self):
for sample in self._data:
# For consistency reasons, we mandate that "caption" is always present in the validation dataset.
# However, since the model specifications use "prompt", we create an alias here.
sample["prompt"] = sample["caption"]
# Load image or video if the path is provided
# TODO(aryan): need to handle custom columns here for control conditions
sample["image"] = None
sample["video"] = None
if sample.get("image_path", None) is not None:
image_path = pathlib.Path(sample["image_path"])
if not image_path.is_file():
logger.warning(f"Image file {image_path.as_posix()} does not exist.")
else:
sample["image"] = load_image(sample["image_path"])
if sample.get("video_path", None) is not None:
video_path = pathlib.Path(sample["video_path"])
if not video_path.is_file():
logger.warning(f"Video file {video_path.as_posix()} does not exist.")
else:
sample["video"] = load_video(sample["video_path"])
sample = {k: v for k, v in sample.items() if v is not None}
yield sample
class IterableDatasetPreprocessingWrapper(
torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful
):
def __init__(
self,
dataset: torch.utils.data.IterableDataset,
dataset_type: str,
id_token: Optional[str] = None,
image_resolution_buckets: List[Tuple[int, int]] = None,
video_resolution_buckets: List[Tuple[int, int, int]] = None,
reshape_mode: str = "bicubic",
remove_common_llm_caption_prefixes: bool = False,
**kwargs,
):
super().__init__()
self.dataset = dataset
self.dataset_type = dataset_type
self.id_token = id_token
self.image_resolution_buckets = image_resolution_buckets
self.video_resolution_buckets = video_resolution_buckets
self.reshape_mode = reshape_mode
self.remove_common_llm_caption_prefixes = remove_common_llm_caption_prefixes
logger.info(
f"Initializing IterableDatasetPreprocessingWrapper for the dataset with the following configuration:\n"
f" - Dataset Type: {dataset_type}\n"
f" - ID Token: {id_token}\n"
f" - Image Resolution Buckets: {image_resolution_buckets}\n"
f" - Video Resolution Buckets: {video_resolution_buckets}\n"
f" - Reshape Mode: {reshape_mode}\n"
f" - Remove Common LLM Caption Prefixes: {remove_common_llm_caption_prefixes}\n"
)
def __iter__(self):
logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
for sample in iter(self.dataset):
if self.dataset_type == "image":
if self.image_resolution_buckets:
sample["_original_num_frames"] = 1
sample["_original_height"] = sample["image"].size(1)
sample["_original_width"] = sample["image"].size(2)
sample["image"] = FF.resize_to_nearest_bucket_image(
sample["image"], self.image_resolution_buckets, self.reshape_mode
)
elif self.dataset_type == "video":
if self.video_resolution_buckets:
sample["_original_num_frames"] = sample["video"].size(0)
sample["_original_height"] = sample["video"].size(2)
sample["_original_width"] = sample["video"].size(3)
sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
sample["video"], self.video_resolution_buckets, self.reshape_mode
)
if _first_frame_only:
msg = (
"The number of frames in the video is less than the minimum bucket size "
"specified. The first frame is being used as a single frame video. This "
"message is logged at the first occurence and for every 128th occurence "
"after that."
)
logger.log_freq("WARNING", "BUCKET_TEMPORAL_SIZE_UNAVAILABLE", msg, frequency=128)
sample["video"] = sample["video"][0]
if self.remove_common_llm_caption_prefixes:
sample["caption"] = FF.remove_prefix(sample["caption"], constants.COMMON_LLM_START_PHRASES)
if self.id_token is not None:
sample["caption"] = f"{self.id_token} {sample['caption']}"
yield sample
def load_state_dict(self, state_dict):
self.dataset.load_state_dict(state_dict["dataset"])
def state_dict(self):
return {"dataset": self.dataset.state_dict()}
class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
def __init__(self, datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False):
super().__init__()
self.datasets = datasets
self.buffer_size = buffer_size
self.shuffle = shuffle
logger.info(
f"Initializing IterableCombinedDataset with the following configuration:\n"
f" - Number of Datasets: {len(datasets)}\n"
f" - Buffer Size: {buffer_size}\n"
f" - Shuffle: {shuffle}\n"
)
def __iter__(self):
logger.info(f"Starting IterableCombinedDataset with {len(self.datasets)} datasets")
iterators = [iter(dataset) for dataset in self.datasets]
buffer = []
per_iter = max(1, self.buffer_size // len(iterators))
for index, it in enumerate(iterators):
for _ in tqdm(range(per_iter), desc=f"Filling buffer from data iterator {index}"):
try:
buffer.append((it, next(it)))
except StopIteration:
continue
while len(buffer) > 0:
idx = 0
if self.shuffle:
idx = random.randint(0, len(buffer) - 1)
current_it, sample = buffer.pop(idx)
yield sample
try:
buffer.append((current_it, next(current_it)))
except StopIteration:
pass
def load_state_dict(self, state_dict):
for dataset, dataset_state_dict in zip(self.datasets, state_dict["datasets"]):
dataset.load_state_dict(dataset_state_dict)
def state_dict(self):
return {"datasets": [dataset.state_dict() for dataset in self.datasets]}
# TODO(aryan): maybe write a test for this
def initialize_dataset(
dataset_name_or_root: str,
dataset_type: str = "video",
streaming: bool = True,
infinite: bool = False,
*,
_caption_options: Optional[Dict[str, Any]] = None,
) -> torch.utils.data.IterableDataset:
assert dataset_type in ["image", "video"]
try:
does_repo_exist_on_hub = repo_exists(dataset_name_or_root, repo_type="dataset")
except huggingface_hub.errors.HFValidationError:
does_repo_exist_on_hub = False
if does_repo_exist_on_hub:
return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options)
else:
return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
def combine_datasets(
datasets: List[torch.utils.data.IterableDataset], buffer_size: int, shuffle: bool = False
) -> torch.utils.data.IterableDataset:
return IterableCombinedDataset(datasets=datasets, buffer_size=buffer_size, shuffle=shuffle)
def wrap_iterable_dataset_for_preprocessing(
dataset: torch.utils.data.IterableDataset, dataset_type: str, config: Dict[str, Any]
) -> torch.utils.data.IterableDataset:
return IterableDatasetPreprocessingWrapper(dataset, dataset_type, **config)
def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infinite: bool = False):
root = pathlib.Path(dataset_name_or_root)
supported_metadata_files = ["metadata.json", "metadata.jsonl", "metadata.csv"]
metadata_files = [root / metadata_file for metadata_file in supported_metadata_files]
metadata_files = [metadata_file for metadata_file in metadata_files if metadata_file.exists()]
if len(metadata_files) > 1:
raise ValueError("Found multiple metadata files. Please ensure there is only one metadata file.")
if len(metadata_files) == 1:
if dataset_type == "image":
dataset = ImageFolderDataset(root.as_posix(), infinite=infinite)
else:
dataset = VideoFolderDataset(root.as_posix(), infinite=infinite)
return dataset
if _has_data_caption_file_pairs(root, remote=False):
if dataset_type == "image":
dataset = ImageCaptionFilePairDataset(root.as_posix(), infinite=infinite)
else:
dataset = VideoCaptionFilePairDataset(root.as_posix(), infinite=infinite)
elif _has_data_file_caption_file_lists(root, remote=False):
if dataset_type == "image":
dataset = ImageFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
else:
dataset = VideoFileCaptionFileListDataset(root.as_posix(), infinite=infinite)
else:
raise ValueError(
f"Could not find any supported dataset structure in the directory {root}. Please open an issue at "
f"https://github.com/a-r-r-o-w/finetrainers with information about your dataset structure and we will "
f"help you set it up."
)
return dataset
def _initialize_hub_dataset(
dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None
):
repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
if _has_data_caption_file_pairs(repo_file_list, remote=True):
return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
if has_tar_files:
return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options)
# TODO(aryan): This should be improved
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
try:
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
dataset = ImageFolderDataset(dataset_root, infinite=infinite)
else:
dataset = VideoFolderDataset(dataset_root, infinite=infinite)
return dataset
except Exception:
pass
raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")
def _initialize_data_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
return ImageCaptionFilePairDataset(dataset_root, infinite=infinite)
else:
return VideoCaptionFilePairDataset(dataset_root, infinite=infinite)
def _initialize_data_file_caption_file_dataset_from_hub(
dataset_name: str, dataset_type: str, infinite: bool = False
) -> torch.utils.data.IterableDataset:
logger.info(f"Downloading dataset {dataset_name} from the HF Hub")
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
if dataset_type == "image":
return ImageFileCaptionFileListDataset(dataset_root, infinite=infinite)
else:
return VideoFileCaptionFileListDataset(dataset_root, infinite=infinite)
def _initialize_webdataset(
dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None
) -> torch.utils.data.IterableDataset:
logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
_caption_options = _caption_options or {}
if dataset_type == "image":
return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options)
else:
return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options)
def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
# TODO(aryan): this logic can be improved
if not remote:
caption_files = utils.find_files(root.as_posix(), "*.txt", depth=0)
for caption_file in caption_files:
caption_file = pathlib.Path(caption_file)
for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
data_filename = caption_file.with_suffix(f".{extension}")
if data_filename.exists():
return True
return False
else:
caption_files = [file for file in root if file.endswith(".txt")]
for caption_file in caption_files:
caption_file = pathlib.Path(caption_file)
for extension in [*constants.SUPPORTED_IMAGE_FILE_EXTENSIONS, *constants.SUPPORTED_VIDEO_FILE_EXTENSIONS]:
data_filename = caption_file.with_suffix(f".{extension}").name
if data_filename in root:
return True
return False
def _has_data_file_caption_file_lists(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
# TODO(aryan): this logic can be improved
if not remote:
file_list = {x.name for x in root.iterdir()}
has_caption_files = any(file in file_list for file in COMMON_CAPTION_FILES)
has_video_files = any(file in file_list for file in COMMON_VIDEO_FILES)
has_image_files = any(file in file_list for file in COMMON_IMAGE_FILES)
return has_caption_files and (has_video_files or has_image_files)
else:
has_caption_files = any(file in root for file in COMMON_CAPTION_FILES)
has_video_files = any(file in root for file in COMMON_VIDEO_FILES)
has_image_files = any(file in root for file in COMMON_IMAGE_FILES)
return has_caption_files and (has_video_files or has_image_files)
def _read_caption_from_file(filename: str) -> str:
with open(filename, "r") as f:
return f.read().strip()
def _preprocess_image(image: PIL.Image.Image) -> torch.Tensor:
image = image.convert("RGB")
image = np.array(image).astype(np.float32)
image = torch.from_numpy(image)
image = image.permute(2, 0, 1).contiguous() / 127.5 - 1.0
return image
def _preprocess_video(video: decord.VideoReader) -> torch.Tensor:
video = video.get_batch(list(range(len(video))))
video = video.permute(0, 3, 1, 2).contiguous()
video = video.float() / 127.5 - 1.0
return video