import functools
import json
import math
import os
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union

import datasets.distributed
import diffusers
import torch
import torch.backends
import transformers
import wandb
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_layerwise_casting
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict
from tqdm import tqdm

from ... import data, logging, optimizer, parallel, patches, utils
from ...config import TrainingType
from ...state import State, TrainState


if TYPE_CHECKING:
    from ...args import BaseArgs
    from ...models import ModelSpecification


logger = logging.get_logger()


class SFTTrainer:
    # fmt: off
    _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
    _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
    _latent_component_names = ["vae"]
    _diffusion_component_names = ["transformer", "unet", "scheduler"]
    # fmt: on

    def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
        self.args = args
        self.state = State()
        self.state.train_state = TrainState()

        # Tokenizers
        self.tokenizer = None
        self.tokenizer_2 = None
        self.tokenizer_3 = None

        # Text encoders
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None

        # Denoisers
        self.transformer = None
        self.unet = None

        # Autoencoders
        self.vae = None

        # Scheduler
        self.scheduler = None

        # Optimizer & LR scheduler
        self.optimizer = None
        self.lr_scheduler = None

        # Checkpoint manager
        self.checkpointer = None

        self._init_distributed()
        self._init_config_options()

        # Perform any patches that might be necessary for training to work as expected
        patches.perform_patches_for_training(self.args, self.state.parallel_backend)

        self.model_specification = model_specification
        self._are_condition_models_loaded = False
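
    # Typical usage (illustrative sketch; `args` is a parsed BaseArgs and `model_specification` is a concrete
    # ModelSpecification subclass; both names are placeholders here, not part of this module):
    #
    #     trainer = SFTTrainer(args, model_specification)
    #     trainer.run()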
    def run(self) -> None:
        try:
            self._prepare_models()
            self._prepare_trainable_parameters()
            self._prepare_for_training()
            self._prepare_dataset()
            self._prepare_checkpointing()
            self._train()
            # trainer._evaluate()
        except Exception as e:
            logger.error(f"Error during training: {e}")
            self.state.parallel_backend.destroy()
            raise e

    def _prepare_models(self) -> None:
        logger.info("Initializing models")

        diffusion_components = self.model_specification.load_diffusion_models()
        self._set_components(diffusion_components)

        if self.state.parallel_backend.pipeline_parallel_enabled:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet. This will be supported in the future."
            )

    def _prepare_trainable_parameters(self) -> None:
        logger.info("Initializing trainable parameters")

        parallel_backend = self.state.parallel_backend

        if self.args.training_type == TrainingType.FULL_FINETUNE:
            logger.info("Finetuning transformer with no additional parameters")
            utils.set_requires_grad([self.transformer], True)
        else:
            logger.info("Finetuning transformer with PEFT parameters")
            utils.set_requires_grad([self.transformer], False)

        # Layerwise upcasting must be applied before adding the LoRA adapter.
        # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
        # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
        if self.args.training_type == TrainingType.LORA and "transformer" in self.args.layerwise_upcasting_modules:
            apply_layerwise_casting(
                self.transformer,
                storage_dtype=self.args.layerwise_upcasting_storage_dtype,
                compute_dtype=self.args.transformer_dtype,
                skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern,
                non_blocking=True,
            )
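
        # With layerwise casting enabled, the base transformer weights are stored in
        # `layerwise_upcasting_storage_dtype` (e.g. torch.float8_e4m3fn) and upcast per layer to
        # `transformer_dtype` (e.g. torch.bfloat16) at compute time. The LoRA parameters added below stay
        # trainable and, unless data sharding requires a uniform dtype, are kept in float32 further down.
        # The example dtypes are illustrative, not requirements.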

        transformer_lora_config = None
        if self.args.training_type == TrainingType.LORA:
            transformer_lora_config = LoraConfig(
                r=self.args.rank,
                lora_alpha=self.args.lora_alpha,
                init_lora_weights=True,
                target_modules=self.args.target_modules,
            )
            self.transformer.add_adapter(transformer_lora_config)

        # # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32
        # # even if layerwise upcasting. Would be nice to have a test as well
        # self.register_saving_loading_hooks(transformer_lora_config)

        # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all
        # parameters to be of the same dtype.
        if parallel_backend.data_sharding_enabled:
            self.transformer.to(dtype=self.args.transformer_dtype)
        else:
            if self.args.training_type == TrainingType.LORA:
                cast_training_params([self.transformer], dtype=torch.float32)

    def _prepare_for_training(self) -> None:
        # 1. Apply parallelism
        parallel_backend = self.state.parallel_backend
        world_mesh = parallel_backend.get_mesh()
        model_specification = self.model_specification

        if parallel_backend.context_parallel_enabled:
            raise NotImplementedError(
                "Context parallelism is not supported yet. This will be supported in the future."
            )

        if parallel_backend.tensor_parallel_enabled:
            # TODO(aryan): handle fp8 from TorchAO here
            model_specification.apply_tensor_parallel(
                backend=parallel.ParallelBackendEnum.PTD,
                device_mesh=parallel_backend.get_mesh()["tp"],
                transformer=self.transformer,
            )

        # Enable gradient checkpointing
        if self.args.gradient_checkpointing:
            # TODO(aryan): support other checkpointing types
            utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full")

        # Enable DDP, FSDP or HSDP
        if parallel_backend.data_sharding_enabled:
            # TODO(aryan): remove this when supported
            if self.args.parallel_backend == "accelerate":
                raise NotImplementedError("Data sharding is not supported with Accelerate yet.")

            if parallel_backend.data_replication_enabled:
                logger.info("Applying HSDP to the model")
            else:
                logger.info("Applying FSDP to the model")

            # Apply FSDP or HSDP
            if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled:
                dp_mesh_names = ("dp_replicate", "dp_shard_cp")
            else:
                dp_mesh_names = ("dp_shard_cp",)

            parallel.apply_fsdp2_ptd(
                model=self.transformer,
                dp_mesh=world_mesh[dp_mesh_names],
                param_dtype=self.args.transformer_dtype,
                reduce_dtype=torch.float32,
                output_dtype=None,
                pp_enabled=parallel_backend.pipeline_parallel_enabled,
                cpu_offload=False,  # TODO(aryan): needs to be tested and allowed for enabling later
            )
        elif parallel_backend.data_replication_enabled:
            logger.info("Applying DDP to the model")

            if world_mesh.ndim > 1:
                raise ValueError("DDP not supported for > 1D parallelism")

            parallel_backend.apply_ddp(self.transformer, world_mesh)

        self._move_components_to_device()

        # 2. Prepare optimizer and lr scheduler
        # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module.
        # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer
        # step. This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99)
        # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require
        # gradients. TODO(aryan): look into it in the future.
        model_parts = [self.transformer]
        self.state.num_trainable_parameters = sum(
            p.numel() for m in model_parts for p in m.parameters() if p.requires_grad
        )

        # Setup distributed optimizer and lr scheduler
        logger.info("Initializing optimizer and lr scheduler")
        self.state.train_state = TrainState()
        self.optimizer = optimizer.get_optimizer(
            parallel_backend=self.args.parallel_backend,
            name=self.args.optimizer,
            model_parts=model_parts,
            learning_rate=self.args.lr,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            fused=False,
        )
        self.lr_scheduler = optimizer.get_lr_scheduler(
            parallel_backend=self.args.parallel_backend,
            name=self.args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=self.args.lr_warmup_steps,
            num_training_steps=self.args.train_steps,
            # TODO(aryan): handle last_epoch
        )
        self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler)

        # 3. Initialize trackers, directories and repositories
        self._init_logging()
        self._init_trackers()
        self._init_directories_and_repositories()

    def _prepare_dataset(self) -> None:
        logger.info("Initializing dataset and dataloader")

        with open(self.args.dataset_config, "r") as file:
            dataset_configs = json.load(file)["datasets"]
        logger.info(f"Training configured to use {len(dataset_configs)} datasets")

        datasets = []
        for config in dataset_configs:
            data_root = config.pop("data_root", None)
            dataset_file = config.pop("dataset_file", None)
            dataset_type = config.pop("dataset_type")
            caption_options = config.pop("caption_options", {})

            if data_root is not None and dataset_file is not None:
                raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")

            dataset_name_or_root = data_root or dataset_file
            dataset = data.initialize_dataset(
                dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options
            )

            if not dataset._precomputable_once and self.args.precomputation_once:
                raise ValueError(
                    f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once."
                )

            logger.info(f"Initialized dataset: {dataset_name_or_root}")
            dataset = self.state.parallel_backend.prepare_dataset(dataset)
            dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config)
            datasets.append(dataset)

        dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True)
        dataloader = self.state.parallel_backend.prepare_dataloader(
            dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory
        )

        self.dataset = dataset
        self.dataloader = dataloader

    def _prepare_checkpointing(self) -> None:
        parallel_backend = self.state.parallel_backend

        def save_model_hook(state_dict: Dict[str, Any]) -> None:
            if parallel_backend.is_main_process:
                if self.args.training_type == TrainingType.LORA:
                    state_dict = get_peft_model_state_dict(self.transformer, state_dict)
                    self.model_specification._save_lora_weights(self.args.output_dir, state_dict, self.scheduler)
                elif self.args.training_type == TrainingType.FULL_FINETUNE:
                    self.model_specification._save_model(
                        self.args.output_dir, self.transformer, state_dict, self.scheduler
                    )
            parallel_backend.wait_for_everyone()

        enable_state_checkpointing = self.args.checkpointing_steps > 0
        self.checkpointer = utils.PTDCheckpointManager(
            dataloader=self.dataloader,
            model_parts=[self.transformer],
            optimizers=self.optimizer,
            schedulers=self.lr_scheduler,
            states={"train_state": self.state.train_state},
            checkpointing_steps=self.args.checkpointing_steps,
            checkpointing_limit=self.args.checkpointing_limit,
            output_dir=self.args.output_dir,
            enable=enable_state_checkpointing,
            _callback_fn=save_model_hook,
        )

        resume_from_checkpoint = self.args.resume_from_checkpoint
        if resume_from_checkpoint == "latest":
            resume_from_checkpoint = -1
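        # The checkpointer is expected to interpret -1 (mapped from "latest" above) as "load the most recent
        # checkpoint in output_dir", and any other value as a specific checkpointed step (assumption about the
        # PTDCheckpointManager convention).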
        if resume_from_checkpoint is not None:
            self.checkpointer.load(resume_from_checkpoint)

    def _train(self) -> None:
        logger.info("Starting training")

        parallel_backend = self.state.parallel_backend
        train_state = self.state.train_state
        device = parallel_backend.device

        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")

        global_batch_size = self.args.batch_size * parallel_backend._dp_degree
        info = {
            "trainable parameters": self.state.num_trainable_parameters,
            "train steps": self.args.train_steps,
            "per-replica batch size": self.args.batch_size,
            "global batch size": global_batch_size,
            "gradient accumulation steps": self.args.gradient_accumulation_steps,
        }
        logger.info(f"Training configuration: {json.dumps(info, indent=4)}")

        progress_bar = tqdm(
            range(0, self.args.train_steps),
            initial=train_state.step,
            desc="Training steps",
            disable=not parallel_backend.is_local_main_process,
        )

        generator = torch.Generator(device=device)
        if self.args.seed is not None:
            generator = generator.manual_seed(self.args.seed)
        self.state.generator = generator

        patch_size = 1
        if (
            getattr(self.transformer.config, "patch_size", None) is not None
            and getattr(self.transformer.config, "patch_size_t", None) is not None
        ):
            patch_size = self.transformer.config.patch_size * self.transformer.config.patch_size_t
        elif isinstance(getattr(self.transformer.config, "patch_size", None), int):
            patch_size = self.transformer.config.patch_size
        elif isinstance(getattr(self.transformer.config, "patch_size", None), (list, tuple)):
            patch_size = math.prod(self.transformer.config.patch_size)
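        # Example (hypothetical values): for a DiT config with patch_size=2 and patch_size_t=1, patch_size
        # resolves to 2, and the token accounting below counts math.prod(latents.shape[:-1]) // 2 tokens per batch.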

        scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler)
        scheduler_sigmas = (
            scheduler_sigmas.to(device=device, dtype=torch.float32) if scheduler_sigmas is not None else None
        )
        scheduler_alphas = utils.get_scheduler_alphas(self.scheduler)
        scheduler_alphas = (
            scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None
        )

        timesteps_buffer = []
        self.transformer.train()
        data_iterator = iter(self.dataloader)

        preprocessor = data.initialize_preprocessor(
            rank=parallel_backend.rank,
            num_items=self.args.precomputation_items if self.args.enable_precomputation else 1,
            processor_fn={
                "condition": self.model_specification.prepare_conditions,
                "latent": functools.partial(
                    self.model_specification.prepare_latents, compute_posterior=not self.args.precomputation_once
                ),
            },
            save_dir=self.args.precomputation_dir,
            enable_precomputation=self.args.enable_precomputation,
        )

        precomputed_condition_iterator: Iterable[Dict[str, Any]] = None
        precomputed_latent_iterator: Iterable[Dict[str, Any]] = None

        sampler = data.ResolutionSampler(
            batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys
        )

        requires_gradient_step = True
        accumulated_loss = 0.0

        while (
            train_state.step < self.args.train_steps and train_state.observed_data_samples < self.args.max_data_samples
        ):
            # 1. Load & preprocess data if required
            if preprocessor.requires_data:
                # TODO(aryan): We should do the following here:
                # - Force checkpoint the trainable models, optimizers, schedulers and train state
                # - Do the precomputation
                # - Load the checkpointed models, optimizers, schedulers and train state back, and continue training
                # This way we can be more memory efficient again, since the latest rewrite of precomputation removed
                # this logic.
                precomputed_condition_iterator, precomputed_latent_iterator = self._prepare_data(
                    preprocessor, data_iterator
                )

            # 2. Prepare batch
            try:
                condition_item = next(precomputed_condition_iterator)
                latent_item = next(precomputed_latent_iterator)
                sampler.consume(condition_item, latent_item)
            except StopIteration:
                if requires_gradient_step:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    requires_gradient_step = False
                logger.info("Data exhausted. Exiting training loop.")
                break

            if sampler.is_ready:
                condition_batch, latent_batch = sampler.get_batch()
                condition_model_conditions = self.model_specification.collate_conditions(condition_batch)
                latent_model_conditions = self.model_specification.collate_latents(latent_batch)
            else:
                continue

            train_state.step += 1
            train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree

            lmc_latents = latent_model_conditions["latents"]
            train_state.observed_num_tokens += math.prod(lmc_latents.shape[:-1]) // patch_size

            logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
            latent_model_conditions = utils.align_device_and_dtype(
                latent_model_conditions, device, self.args.transformer_dtype
            )
            condition_model_conditions = utils.align_device_and_dtype(
                condition_model_conditions, device, self.args.transformer_dtype
            )
            latent_model_conditions = utils.make_contiguous(latent_model_conditions)
            condition_model_conditions = utils.make_contiguous(condition_model_conditions)

            # 3. Forward pass
            sigmas = utils.prepare_sigmas(
                scheduler=self.scheduler,
                sigmas=scheduler_sigmas,
                batch_size=self.args.batch_size,
                num_train_timesteps=self.scheduler.config.num_train_timesteps,
                flow_weighting_scheme=self.args.flow_weighting_scheme,
                flow_logit_mean=self.args.flow_logit_mean,
                flow_logit_std=self.args.flow_logit_std,
                flow_mode_scale=self.args.flow_mode_scale,
                device=device,
                generator=self.state.generator,
            )
            sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim)

            pred, target, sigmas = self.model_specification.forward(
                transformer=self.transformer,
                scheduler=self.scheduler,
                condition_model_conditions=condition_model_conditions,
                latent_model_conditions=latent_model_conditions,
                sigmas=sigmas,
                compute_posterior=not self.args.precomputation_once,
            )
            timesteps = (sigmas * 1000.0).long()

            weights = utils.prepare_loss_weights(
                scheduler=self.scheduler,
                alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
                sigmas=sigmas,
                flow_weighting_scheme=self.args.flow_weighting_scheme,
            )
            weights = utils.expand_tensor_dims(weights, pred.ndim)
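            # The objective below is a weighted MSE: loss = mean_over_batch( w(sigma) * mean(||pred - target||^2) ),
            # where w(sigma) comes from the configured flow weighting scheme and the inner mean runs over all
            # non-batch dimensions.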
            # 4. Compute loss & backward pass
            loss = weights.float() * (pred.float() - target.float()).pow(2)
            # Average loss across all but batch dimension
            loss = loss.mean(list(range(1, loss.ndim)))
            # Average loss across batch dimension
            loss = loss.mean()

            if self.args.gradient_accumulation_steps > 1:
                loss = loss / self.args.gradient_accumulation_steps
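            # Scaling by gradient_accumulation_steps makes the accumulated gradient match a single large-batch
            # step in expectation. Worked example (hypothetical numbers): with batch_size=2, a data-parallel
            # degree of 4 and gradient_accumulation_steps=8, each optimizer update effectively sees
            # 2 * 4 * 8 = 64 samples, and each micro-step contributes loss / 8 to the gradients.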
            loss.backward()
            accumulated_loss += loss.detach().item()
            requires_gradient_step = True

            # 5. Clip gradients
            model_parts = [self.transformer]
            grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases(
                [p for m in model_parts for p in m.parameters()],
                self.args.max_grad_norm,
                foreach=True,
                pp_mesh=parallel_backend.get_mesh("pp") if parallel_backend.pipeline_parallel_enabled else None,
            )

            # 6. Step optimizer & log metrics
            logs = {}
            if train_state.step % self.args.gradient_accumulation_steps == 0:
                # TODO(aryan): revisit no_sync() for FSDP
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                if grad_norm is not None:
                    logs["grad_norm"] = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item()

                if (
                    parallel_backend.data_replication_enabled
                    or parallel_backend.data_sharding_enabled
                    or parallel_backend.context_parallel_enabled
                ):
                    dp_cp_mesh = parallel_backend.get_mesh("dp_cp")
                    global_avg_loss, global_max_loss = (
                        parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
                        parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
                    )
                else:
                    global_avg_loss = global_max_loss = accumulated_loss

                logs["global_avg_loss"] = global_avg_loss
                logs["global_max_loss"] = global_max_loss
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)
                accumulated_loss = 0.0
                requires_gradient_step = False

            progress_bar.update(1)
            progress_bar.set_postfix(logs)
            timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()])

            if train_state.step % self.args.logging_steps == 0:
                # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts
                # TODO(aryan): causes NCCL hang for some reason. look into later
                # logs.update(self.lr_scheduler.get_last_lr())
                # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"])
                # logs["timesteps"] = wandb.plot.scatter(
                #     timesteps_table, "step", "timesteps", title="Timesteps distribution"
                # )
                timesteps_buffer = []
                logs["observed_data_samples"] = train_state.observed_data_samples
                logs["observed_num_tokens"] = train_state.observed_num_tokens
                parallel_backend.log(logs, step=train_state.step)
                train_state.log_steps.append(train_state.step)

            # 7. Save checkpoint if required
            self.checkpointer.save(
                step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process
            )

            # 8. Perform validation if required
            if train_state.step % self.args.validation_steps == 0:
                self._validate(step=train_state.step, final_validation=False)

        # 9. Final checkpoint, validation & cleanup
        self.checkpointer.save(
            train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process
        )
        parallel_backend.wait_for_everyone()
        self._validate(step=train_state.step, final_validation=True)
        self._delete_components()

        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        # 10. Upload artifacts to hub
        if parallel_backend.is_main_process and self.args.push_to_hub:
            upload_folder(
                repo_id=self.state.repo_id,
                folder_path=self.args.output_dir,
                ignore_patterns=[f"{self.checkpointer._prefix}_*"],
            )

        parallel_backend.destroy()

    def _validate(self, step: int, final_validation: bool = False) -> None:
        if self.args.validation_dataset_file is None:
            return

        logger.info("Starting validation")

        # 1. Load validation dataset
        parallel_backend = self.state.parallel_backend
        dp_mesh = parallel_backend.get_mesh("dp_replicate")

        if dp_mesh is not None:
            local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
        else:
            local_rank, dp_world_size = 0, 1

        dataset = data.ValidationDataset(self.args.validation_dataset_file)
        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, local_rank, dp_world_size)
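        # split_dataset_by_node gives each data-parallel replica a disjoint slice of the validation set, so the
        # generations below are spread across replicas and later gathered with all_gather_object.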
        validation_dataloader = data.DPDataLoader(
            local_rank,
            dataset,
            batch_size=1,
            num_workers=self.args.dataloader_num_workers,
            collate_fn=lambda items: items,
        )
        data_iterator = iter(validation_dataloader)

        main_process_prompts_to_filenames = {}  # Used to save model card
        all_processes_artifacts = []  # Used to gather artifacts from all processes

        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        seed = self.args.seed if self.args.seed is not None else 0
        generator = torch.Generator(device=parallel_backend.device).manual_seed(seed)
        pipeline = self._init_pipeline(final_validation=final_validation)

        # 2. Run validation
        # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we
        # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset
        # size is not divisible by dp_shards.
        self.transformer.eval()

        while True:
            validation_data = next(data_iterator, None)
            if validation_data is None:
                break

            logger.debug(
                f"Validating {validation_data=} on rank={parallel_backend.rank}.", local_main_process_only=False
            )

            validation_data = validation_data[0]
            validation_artifacts = self.model_specification.validation(
                pipeline=pipeline, generator=generator, **validation_data
            )

            PROMPT = validation_data["prompt"]
            IMAGE = validation_data.get("image", None)
            VIDEO = validation_data.get("video", None)
            EXPORT_FPS = validation_data.get("export_fps", 30)

            # 2.1. If there are any initial images or videos, they will be logged to keep track of them as
            # conditioning for generation.
            prompt_filename = utils.string_to_filename(PROMPT)[:25]
            artifacts = {
                "input_image": data.ImageArtifact(value=IMAGE),
                "input_video": data.VideoArtifact(value=VIDEO),
            }

            # 2.2. Track the artifacts generated from validation
            for i, validation_artifact in enumerate(validation_artifacts):
                if validation_artifact.value is None:
                    continue
                artifacts.update({f"artifact_{i}": validation_artifact})

            # 2.3. Save the artifacts to the output directory and create appropriate logging objects
            # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
            for index, (key, artifact) in enumerate(list(artifacts.items())):
                assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))

                time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension
                filename = "validation-" if not final_validation else "final-"
                filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
                output_filename = os.path.join(self.args.output_dir, filename)

                if parallel_backend.is_main_process and artifact.file_extension == "mp4":
                    main_process_prompts_to_filenames[PROMPT] = filename

                if artifact.type == "image" and artifact.value is not None:
                    logger.debug(
                        f"Saving image from rank={parallel_backend.rank} to {output_filename}",
                        local_main_process_only=False,
                    )
                    artifact.value.save(output_filename)
                    all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT))
                elif artifact.type == "video" and artifact.value is not None:
                    logger.debug(
                        f"Saving video from rank={parallel_backend.rank} to {output_filename}",
                        local_main_process_only=False,
                    )
                    export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
                    all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT))

        # 3. Cleanup & log artifacts
        parallel_backend.wait_for_everyone()

        # Remove all hooks that might have been added during pipeline initialization to the models
        pipeline.remove_all_hooks()
        del pipeline

        utils.free_memory()
        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(parallel_backend.device)

        # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
        # TODO(aryan): probably should only all gather from dp mesh process group
        all_artifacts = [None] * parallel_backend.world_size
        torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts)
        all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts]

        if parallel_backend.is_main_process:
            tracker_key = "final" if final_validation else "validation"
            artifact_log_dict = {}

            image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
            if len(image_artifacts) > 0:
                artifact_log_dict["images"] = image_artifacts

            video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
            if len(video_artifacts) > 0:
                artifact_log_dict["videos"] = video_artifacts

            parallel_backend.log({tracker_key: artifact_log_dict}, step=step)

            if self.args.push_to_hub and final_validation:
                video_filenames = list(main_process_prompts_to_filenames.values())
                prompts = list(main_process_prompts_to_filenames.keys())
                utils.save_model_card(
                    args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts
                )

        parallel_backend.wait_for_everyone()

        if not final_validation:
            self.transformer.train()

    def _evaluate(self) -> None:
        raise NotImplementedError("Evaluation has not been implemented yet.")

    def _init_distributed(self) -> None:
        # TODO: Accelerate disables native_amp for MPS. Probably need to do the same with implementation.
        world_size = int(os.environ["WORLD_SIZE"])

        # TODO(aryan): handle other backends
        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )
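
        # The product pp_degree * dp_degree * dp_shards * cp_degree * tp_degree is expected to match WORLD_SIZE
        # (this is an assumption about how the backend builds its device mesh, not something enforced here).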

        if self.args.seed is not None:
            world_mesh = self.state.parallel_backend.get_mesh()
            utils.enable_determinism(self.args.seed, world_mesh)

    def _init_logging(self) -> None:
        transformers_log_level = transformers.utils.logging.set_verbosity_error
        diffusers_log_level = diffusers.utils.logging.set_verbosity_error

        if self.args.verbose == 0:
            if self.state.parallel_backend.is_local_main_process:
                transformers_log_level = transformers.utils.logging.set_verbosity_warning
                diffusers_log_level = diffusers.utils.logging.set_verbosity_warning
        elif self.args.verbose == 1:
            if self.state.parallel_backend.is_local_main_process:
                transformers_log_level = transformers.utils.logging.set_verbosity_info
                diffusers_log_level = diffusers.utils.logging.set_verbosity_info
        elif self.args.verbose == 2:
            if self.state.parallel_backend.is_local_main_process:
                transformers_log_level = transformers.utils.logging.set_verbosity_debug
                diffusers_log_level = diffusers.utils.logging.set_verbosity_debug
        else:
            transformers_log_level = transformers.utils.logging.set_verbosity_debug
            diffusers_log_level = diffusers.utils.logging.set_verbosity_debug

        transformers_log_level()
        diffusers_log_level()

        logging._set_parallel_backend(self.state.parallel_backend)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
        # TODO(aryan): handle multiple trackers
        trackers = ["wandb"]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers, experiment_name=experiment_name, config=self._get_training_info(), log_dir=self.args.logging_dir
        )

    def _init_directories_and_repositories(self) -> None:
        if self.state.parallel_backend.is_main_process:
            self.args.output_dir = Path(self.args.output_dir)
            self.args.output_dir.mkdir(parents=True, exist_ok=True)
            self.state.output_dir = Path(self.args.output_dir)

            if self.args.push_to_hub:
                repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
                self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True

    def _move_components_to_device(
        self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None
    ) -> None:
        if device is None:
            device = self.state.parallel_backend.device
        if components is None:
            components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae]
        components = utils.get_non_null_items(components)
        components = list(filter(lambda x: hasattr(x, "to"), components))
        for component in components:
            component.to(device)

    def _set_components(self, components: Dict[str, Any]) -> None:
        for component_name in self._all_component_names:
            existing_component = getattr(self, component_name, None)
            new_component = components.get(component_name, existing_component)
            setattr(self, component_name, new_component)

    def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
        if component_names is None:
            component_names = self._all_component_names
        for component_name in component_names:
            setattr(self, component_name, None)
        utils.free_memory()
        utils.synchronize_device()

    def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
        parallel_backend = self.state.parallel_backend
        module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]

        if not final_validation:
            module_names.remove("transformer")
            pipeline = self.model_specification.load_pipeline(
                tokenizer=self.tokenizer,
                tokenizer_2=self.tokenizer_2,
                tokenizer_3=self.tokenizer_3,
                text_encoder=self.text_encoder,
                text_encoder_2=self.text_encoder_2,
                text_encoder_3=self.text_encoder_3,
                # TODO(aryan): handle unwrapping for compiled modules
                # transformer=utils.unwrap_model(accelerator, self.transformer),
                transformer=self.transformer,
                vae=self.vae,
                enable_slicing=self.args.enable_slicing,
                enable_tiling=self.args.enable_tiling,
                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                training=True,
            )
        else:
            self._delete_components()

            # Load the transformer weights from the final checkpoint if performing full-finetune
            transformer = None
            if self.args.training_type == TrainingType.FULL_FINETUNE:
                transformer = self.model_specification.load_diffusion_models()["transformer"]

            pipeline = self.model_specification.load_pipeline(
                transformer=transformer,
                enable_slicing=self.args.enable_slicing,
                enable_tiling=self.args.enable_tiling,
                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                training=False,
                device=parallel_backend.device,
            )

            # Load the LoRA weights if performing LoRA finetuning
            if self.args.training_type == TrainingType.LORA:
                pipeline.load_lora_weights(self.args.output_dir)

        components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
        self._set_components(components)
        self._move_components_to_device(list(components.values()))
        return pipeline

    def _prepare_data(
        self,
        preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
        data_iterator,
    ):
        if not self.args.enable_precomputation:
            if not self._are_condition_models_loaded:
                logger.info(
                    "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
                )
                condition_components = self.model_specification.load_condition_models()
                latent_components = self.model_specification.load_latent_models()
                all_components = {**condition_components, **latent_components}
                self._set_components(all_components)
                self._move_components_to_device(list(all_components.values()))
                utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
            else:
                condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
                latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}

            condition_iterator = preprocessor.consume(
                "condition",
                components=condition_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                cache_samples=True,
            )
            latent_iterator = preprocessor.consume(
                "latent",
                components=latent_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                use_cached_samples=True,
                drop_samples=True,
            )
            self._are_condition_models_loaded = True
        else:
            logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")

            # TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
            # had become None after this but should have been loaded back from the checkpoint.
            # parallel_backend = self.state.parallel_backend
            # train_state = self.state.train_state
            # self.checkpointer.save(
            #     train_state.step,
            #     force=True,
            #     _device=parallel_backend.device,
            #     _is_main_process=parallel_backend.is_main_process,
            # )
            # self._delete_components(component_names=["transformer", "unet"])

            if self.args.precomputation_once:
                consume_fn = preprocessor.consume_once
            else:
                consume_fn = preprocessor.consume

            # Prepare condition iterators
            condition_components = self.model_specification.load_condition_models()
            component_names = list(condition_components.keys())
            component_modules = list(condition_components.values())
            self._set_components(condition_components)
            self._move_components_to_device(component_modules)

            condition_iterator = consume_fn(
                "condition",
                components=condition_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                cache_samples=True,
            )
            self._delete_components(component_names)
            del condition_components, component_names, component_modules

            # Prepare latent iterators
            latent_components = self.model_specification.load_latent_models()
            utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
            component_names = list(latent_components.keys())
            component_modules = list(latent_components.values())
            self._set_components(latent_components)
            self._move_components_to_device(component_modules)

            latent_iterator = consume_fn(
                "latent",
                components=latent_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                use_cached_samples=True,
                drop_samples=True,
            )
            self._delete_components(component_names)
            del latent_components, component_names, component_modules

            # self.checkpointer.load()
            # self.transformer = self.checkpointer.states["model"].model[0]

        return condition_iterator, latent_iterator

    def _get_training_info(self) -> Dict[str, Any]:
        info = self.args.to_dict()

        # Removing flow matching arguments when not using flow-matching objective
        diffusion_args = info.get("diffusion_arguments", {})
        scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else ""
        if scheduler_name != "FlowMatchEulerDiscreteScheduler":
            filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
        else:
            filtered_diffusion_args = diffusion_args

        info.update({"diffusion_arguments": filtered_diffusion_args})
        return info