import functools
import json
import math
import os
import time
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union

import datasets.distributed
import diffusers
import torch
import torch.backends
import transformers
import wandb
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_layerwise_casting
from diffusers.training_utils import cast_training_params
from diffusers.utils import export_to_video
from huggingface_hub import create_repo, upload_folder
from peft import LoraConfig, get_peft_model_state_dict
from tqdm import tqdm

from ... import data, logging, optimizer, parallel, patches, utils
from ...config import TrainingType
from ...state import State, TrainState


if TYPE_CHECKING:
    from ...args import BaseArgs
    from ...models import ModelSpecification


logger = logging.get_logger()


class SFTTrainer:
    # fmt: off
    _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
    _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
    _latent_component_names = ["vae"]
    _diffusion_component_names = ["transformer", "unet", "scheduler"]
    # fmt: on

    def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
        self.args = args
        self.state = State()
        self.state.train_state = TrainState()

        # Tokenizers
        self.tokenizer = None
        self.tokenizer_2 = None
        self.tokenizer_3 = None

        # Text encoders
        self.text_encoder = None
        self.text_encoder_2 = None
        self.text_encoder_3 = None

        # Denoisers
        self.transformer = None
        self.unet = None

        # Autoencoders
        self.vae = None

        # Scheduler
        self.scheduler = None

        # Optimizer & LR scheduler
        self.optimizer = None
        self.lr_scheduler = None

        # Checkpoint manager
        self.checkpointer = None

        self._init_distributed()
        self._init_config_options()

        # Perform any patches that might be necessary for training to work as expected
        patches.perform_patches_for_training(self.args, self.state.parallel_backend)

        self.model_specification = model_specification
        self._are_condition_models_loaded = False

    def run(self) -> None:
        try:
            self._prepare_models()
            self._prepare_trainable_parameters()
            self._prepare_for_training()
            self._prepare_dataset()
            self._prepare_checkpointing()
            self._train()
            # trainer._evaluate()
        except Exception as e:
            logger.error(f"Error during training: {e}")
            self.state.parallel_backend.destroy()
            raise e

    def _prepare_models(self) -> None:
        logger.info("Initializing models")

        diffusion_components = self.model_specification.load_diffusion_models()
        self._set_components(diffusion_components)

        if self.state.parallel_backend.pipeline_parallel_enabled:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet. This will be supported in the future."
            )

    def _prepare_trainable_parameters(self) -> None:
        logger.info("Initializing trainable parameters")

        parallel_backend = self.state.parallel_backend
        if self.args.training_type == TrainingType.FULL_FINETUNE:
            logger.info("Finetuning transformer with no additional parameters")
            utils.set_requires_grad([self.transformer], True)
        else:
            logger.info("Finetuning transformer with PEFT parameters")
            utils.set_requires_grad([self.transformer], False)

        # Layerwise upcasting must be applied before adding the LoRA adapter.
        # If we don't perform this before moving to device, we might OOM on the GPU. So, best to do it on
        # CPU for now, before support is added in Diffusers for loading and enabling layerwise upcasting directly.
if self.args.training_type == "lora" and "transformer" in self.args.layerwise_upcasting_modules: apply_layerwise_casting( self.transformer, storage_dtype=self.args.layerwise_upcasting_storage_dtype, compute_dtype=self.args.transformer_dtype, skip_modules_pattern=self.args.layerwise_upcasting_skip_modules_pattern, non_blocking=True, ) transformer_lora_config = None if self.args.training_type == TrainingType.LORA: transformer_lora_config = LoraConfig( r=self.args.rank, lora_alpha=self.args.lora_alpha, init_lora_weights=True, target_modules=self.args.target_modules, ) self.transformer.add_adapter(transformer_lora_config) # # TODO(aryan): it might be nice to add some assertions here to make sure that lora parameters are still in fp32 # # even if layerwise upcasting. Would be nice to have a test as well # self.register_saving_loading_hooks(transformer_lora_config) # Make sure the trainable params are in float32 if data sharding is not enabled. For FSDP, we need all # parameters to be of the same dtype. if parallel_backend.data_sharding_enabled: self.transformer.to(dtype=self.args.transformer_dtype) else: if self.args.training_type == TrainingType.LORA: cast_training_params([self.transformer], dtype=torch.float32) def _prepare_for_training(self) -> None: # 1. Apply parallelism parallel_backend = self.state.parallel_backend world_mesh = parallel_backend.get_mesh() model_specification = self.model_specification if parallel_backend.context_parallel_enabled: raise NotImplementedError( "Context parallelism is not supported yet. This will be supported in the future." ) if parallel_backend.tensor_parallel_enabled: # TODO(aryan): handle fp8 from TorchAO here model_specification.apply_tensor_parallel( backend=parallel.ParallelBackendEnum.PTD, device_mesh=parallel_backend.get_mesh()["tp"], transformer=self.transformer, ) # Enable gradient checkpointing if self.args.gradient_checkpointing: # TODO(aryan): support other checkpointing types utils.apply_activation_checkpointing(self.transformer, checkpointing_type="full") # Enable DDP, FSDP or HSDP if parallel_backend.data_sharding_enabled: # TODO(aryan): remove this when supported if self.args.parallel_backend == "accelerate": raise NotImplementedError("Data sharding is not supported with Accelerate yet.") if parallel_backend.data_replication_enabled: logger.info("Applying HSDP to the model") else: logger.info("Applying FSDP to the model") # Apply FSDP or HSDP if parallel_backend.data_replication_enabled or parallel_backend.context_parallel_enabled: dp_mesh_names = ("dp_replicate", "dp_shard_cp") else: dp_mesh_names = ("dp_shard_cp",) parallel.apply_fsdp2_ptd( model=self.transformer, dp_mesh=world_mesh[dp_mesh_names], param_dtype=self.args.transformer_dtype, reduce_dtype=torch.float32, output_dtype=None, pp_enabled=parallel_backend.pipeline_parallel_enabled, cpu_offload=False, # TODO(aryan): needs to be tested and allowed for enabling later ) elif parallel_backend.data_replication_enabled: logger.info("Applying DDP to the model") if world_mesh.ndim > 1: raise ValueError("DDP not supported for > 1D parallelism") parallel_backend.apply_ddp(self.transformer, world_mesh) self._move_components_to_device() # 2. Prepare optimizer and lr scheduler # For training LoRAs, we can be a little more optimal. Currently, the OptimizerWrapper only accepts torch::nn::Module. # This causes us to loop over all the parameters (even ones that don't require gradients, as in LoRA) at each optimizer # step. 
        # This is OK (see https://github.com/pytorch/pytorch/blob/2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4/torch/optim/sgd.py#L85-L99)
        # but can be optimized a bit by maybe creating a simple wrapper module encompassing the actual parameters that require
        # gradients. TODO(aryan): look into it in the future.
        model_parts = [self.transformer]
        self.state.num_trainable_parameters = sum(
            p.numel() for m in model_parts for p in m.parameters() if p.requires_grad
        )

        # Setup distributed optimizer and lr scheduler
        logger.info("Initializing optimizer and lr scheduler")
        self.state.train_state = TrainState()
        self.optimizer = optimizer.get_optimizer(
            parallel_backend=self.args.parallel_backend,
            name=self.args.optimizer,
            model_parts=model_parts,
            learning_rate=self.args.lr,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            fused=False,
        )
        self.lr_scheduler = optimizer.get_lr_scheduler(
            parallel_backend=self.args.parallel_backend,
            name=self.args.lr_scheduler,
            optimizer=self.optimizer,
            num_warmup_steps=self.args.lr_warmup_steps,
            num_training_steps=self.args.train_steps,
            # TODO(aryan): handle last_epoch
        )
        self.optimizer, self.lr_scheduler = parallel_backend.prepare_optimizer(self.optimizer, self.lr_scheduler)

        # 3. Initialize trackers, directories and repositories
        self._init_logging()
        self._init_trackers()
        self._init_directories_and_repositories()

    def _prepare_dataset(self) -> None:
        logger.info("Initializing dataset and dataloader")

        with open(self.args.dataset_config, "r") as file:
            dataset_configs = json.load(file)["datasets"]

        logger.info(f"Training configured to use {len(dataset_configs)} datasets")

        datasets = []
        for config in dataset_configs:
            data_root = config.pop("data_root", None)
            dataset_file = config.pop("dataset_file", None)
            dataset_type = config.pop("dataset_type")
            caption_options = config.pop("caption_options", {})

            if data_root is not None and dataset_file is not None:
                raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")

            dataset_name_or_root = data_root or dataset_file
            dataset = data.initialize_dataset(
                dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options
            )

            if not dataset._precomputable_once and self.args.precomputation_once:
                raise ValueError(
                    f"Dataset {dataset_name_or_root} does not support precomputing all embeddings at once."
                )
) logger.info(f"Initialized dataset: {dataset_name_or_root}") dataset = self.state.parallel_backend.prepare_dataset(dataset) dataset = data.wrap_iterable_dataset_for_preprocessing(dataset, dataset_type, config) datasets.append(dataset) dataset = data.combine_datasets(datasets, buffer_size=self.args.dataset_shuffle_buffer_size, shuffle=True) dataloader = self.state.parallel_backend.prepare_dataloader( dataset, batch_size=1, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.pin_memory ) self.dataset = dataset self.dataloader = dataloader def _prepare_checkpointing(self) -> None: parallel_backend = self.state.parallel_backend def save_model_hook(state_dict: Dict[str, Any]) -> None: if parallel_backend.is_main_process: if self.args.training_type == TrainingType.LORA: state_dict = get_peft_model_state_dict(self.transformer, state_dict) self.model_specification._save_lora_weights(self.args.output_dir, state_dict, self.scheduler) elif self.args.training_type == TrainingType.FULL_FINETUNE: self.model_specification._save_model( self.args.output_dir, self.transformer, state_dict, self.scheduler ) parallel_backend.wait_for_everyone() enable_state_checkpointing = self.args.checkpointing_steps > 0 self.checkpointer = utils.PTDCheckpointManager( dataloader=self.dataloader, model_parts=[self.transformer], optimizers=self.optimizer, schedulers=self.lr_scheduler, states={"train_state": self.state.train_state}, checkpointing_steps=self.args.checkpointing_steps, checkpointing_limit=self.args.checkpointing_limit, output_dir=self.args.output_dir, enable=enable_state_checkpointing, _callback_fn=save_model_hook, ) resume_from_checkpoint = self.args.resume_from_checkpoint if resume_from_checkpoint == "latest": resume_from_checkpoint = -1 if resume_from_checkpoint is not None: self.checkpointer.load(resume_from_checkpoint) def _train(self) -> None: logger.info("Starting training") parallel_backend = self.state.parallel_backend train_state = self.state.train_state device = parallel_backend.device memory_statistics = utils.get_memory_statistics() logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}") global_batch_size = self.args.batch_size * parallel_backend._dp_degree info = { "trainable parameters": self.state.num_trainable_parameters, "train steps": self.args.train_steps, "per-replica batch size": self.args.batch_size, "global batch size": global_batch_size, "gradient accumulation steps": self.args.gradient_accumulation_steps, } logger.info(f"Training configuration: {json.dumps(info, indent=4)}") progress_bar = tqdm( range(0, self.args.train_steps), initial=train_state.step, desc="Training steps", disable=not parallel_backend.is_local_main_process, ) generator = torch.Generator(device=device) if self.args.seed is not None: generator = generator.manual_seed(self.args.seed) self.state.generator = generator patch_size = 1 if ( getattr(self.transformer.config, "patch_size", None) is not None and getattr(self.transformer.config, "patch_size_t", None) is not None ): patch_size = self.transformer.config.patch_size * self.transformer.config.patch_size_t elif isinstance(getattr(self.transformer.config, "patch_size", None), int): patch_size = self.transformer.config.patch_size elif isinstance(getattr(self.transformer.config, "patch_size", None), (list, tuple)): patch_size = math.prod(self.transformer.config.patch_size) scheduler_sigmas = utils.get_scheduler_sigmas(self.scheduler) scheduler_sigmas = ( scheduler_sigmas.to(device=device, dtype=torch.float32) if 
        scheduler_alphas = utils.get_scheduler_alphas(self.scheduler)
        scheduler_alphas = (
            scheduler_alphas.to(device=device, dtype=torch.float32) if scheduler_alphas is not None else None
        )

        timesteps_buffer = []
        self.transformer.train()
        data_iterator = iter(self.dataloader)

        preprocessor = data.initialize_preprocessor(
            rank=parallel_backend.rank,
            num_items=self.args.precomputation_items if self.args.enable_precomputation else 1,
            processor_fn={
                "condition": self.model_specification.prepare_conditions,
                "latent": functools.partial(
                    self.model_specification.prepare_latents, compute_posterior=not self.args.precomputation_once
                ),
            },
            save_dir=self.args.precomputation_dir,
            enable_precomputation=self.args.enable_precomputation,
        )
        precomputed_condition_iterator: Iterable[Dict[str, Any]] = None
        precomputed_latent_iterator: Iterable[Dict[str, Any]] = None

        sampler = data.ResolutionSampler(
            batch_size=self.args.batch_size, dim_keys=self.model_specification._resolution_dim_keys
        )

        requires_gradient_step = True
        accumulated_loss = 0.0

        while (
            train_state.step < self.args.train_steps
            and train_state.observed_data_samples < self.args.max_data_samples
        ):
            # 1. Load & preprocess data if required
            if preprocessor.requires_data:
                # TODO(aryan): We should do the following here:
                # - Force checkpoint the trainable models, optimizers, schedulers and train state
                # - Do the precomputation
                # - Load the checkpointed models, optimizers, schedulers and train state back, and continue training
                # This way we can be more memory efficient again, since the latest rewrite of precomputation removed
                # this logic.
                precomputed_condition_iterator, precomputed_latent_iterator = self._prepare_data(
                    preprocessor, data_iterator
                )

            # 2. Prepare batch
            try:
                condition_item = next(precomputed_condition_iterator)
                latent_item = next(precomputed_latent_iterator)
                sampler.consume(condition_item, latent_item)
            except StopIteration:
                if requires_gradient_step:
                    self.optimizer.step()
                    self.lr_scheduler.step()
                    requires_gradient_step = False
                logger.info("Data exhausted. Exiting training loop.")
                break

            if sampler.is_ready:
                condition_batch, latent_batch = sampler.get_batch()
                condition_model_conditions = self.model_specification.collate_conditions(condition_batch)
                latent_model_conditions = self.model_specification.collate_latents(latent_batch)
            else:
                continue

            train_state.step += 1
            train_state.observed_data_samples += self.args.batch_size * parallel_backend._dp_degree
            lmc_latents = latent_model_conditions["latents"]
            train_state.observed_num_tokens += math.prod(lmc_latents.shape[:-1]) // patch_size

            logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")

            latent_model_conditions = utils.align_device_and_dtype(
                latent_model_conditions, device, self.args.transformer_dtype
            )
            condition_model_conditions = utils.align_device_and_dtype(
                condition_model_conditions, device, self.args.transformer_dtype
            )
            latent_model_conditions = utils.make_contiguous(latent_model_conditions)
            condition_model_conditions = utils.make_contiguous(condition_model_conditions)

            # 3. Forward pass
            sigmas = utils.prepare_sigmas(
                scheduler=self.scheduler,
                sigmas=scheduler_sigmas,
                batch_size=self.args.batch_size,
                num_train_timesteps=self.scheduler.config.num_train_timesteps,
                flow_weighting_scheme=self.args.flow_weighting_scheme,
                flow_logit_mean=self.args.flow_logit_mean,
                flow_logit_std=self.args.flow_logit_std,
                flow_mode_scale=self.args.flow_mode_scale,
                device=device,
                generator=self.state.generator,
            )
            sigmas = utils.expand_tensor_dims(sigmas, latent_model_conditions["latents"].ndim)

            pred, target, sigmas = self.model_specification.forward(
                transformer=self.transformer,
                scheduler=self.scheduler,
                condition_model_conditions=condition_model_conditions,
                latent_model_conditions=latent_model_conditions,
                sigmas=sigmas,
                compute_posterior=not self.args.precomputation_once,
            )

            timesteps = (sigmas * 1000.0).long()
            weights = utils.prepare_loss_weights(
                scheduler=self.scheduler,
                alphas=scheduler_alphas[timesteps] if scheduler_alphas is not None else None,
                sigmas=sigmas,
                flow_weighting_scheme=self.args.flow_weighting_scheme,
            )
            weights = utils.expand_tensor_dims(weights, pred.ndim)

            # 4. Compute loss & backward pass
            loss = weights.float() * (pred.float() - target.float()).pow(2)
            # Average loss across all but batch dimension
            loss = loss.mean(list(range(1, loss.ndim)))
            # Average loss across batch dimension
            loss = loss.mean()
            if self.args.gradient_accumulation_steps > 1:
                loss = loss / self.args.gradient_accumulation_steps
            loss.backward()
            accumulated_loss += loss.detach().item()
            requires_gradient_step = True

            # 5. Clip gradients
            model_parts = [self.transformer]
            grad_norm = utils.torch._clip_grad_norm_while_handling_failing_dtensor_cases(
                [p for m in model_parts for p in m.parameters()],
                self.args.max_grad_norm,
                foreach=True,
                pp_mesh=parallel_backend.get_mesh("pp") if parallel_backend.pipeline_parallel_enabled else None,
            )

            # 6. Step optimizer & log metrics
            logs = {}
            if train_state.step % self.args.gradient_accumulation_steps == 0:
                # TODO(aryan): revisit no_sync() for FSDP
                self.optimizer.step()
                self.lr_scheduler.step()
                self.optimizer.zero_grad()

                if grad_norm is not None:
                    logs["grad_norm"] = grad_norm if isinstance(grad_norm, float) else grad_norm.detach().item()

                if (
                    parallel_backend.data_replication_enabled
                    or parallel_backend.data_sharding_enabled
                    or parallel_backend.context_parallel_enabled
                ):
                    dp_cp_mesh = parallel_backend.get_mesh("dp_cp")
                    global_avg_loss, global_max_loss = (
                        parallel.dist_mean(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
                        parallel.dist_max(torch.tensor([accumulated_loss], device=device), dp_cp_mesh),
                    )
                else:
                    global_avg_loss = global_max_loss = accumulated_loss

                logs["global_avg_loss"] = global_avg_loss
                logs["global_max_loss"] = global_max_loss
                train_state.global_avg_losses.append(global_avg_loss)
                train_state.global_max_losses.append(global_max_loss)

                accumulated_loss = 0.0
                requires_gradient_step = False

            progress_bar.update(1)
            progress_bar.set_postfix(logs)
            timesteps_buffer.extend([(train_state.step, t) for t in timesteps.detach().cpu().numpy().tolist()])

            if train_state.step % self.args.logging_steps == 0:
                # TODO(aryan): handle non-SchedulerWrapper schedulers (probably not required eventually) since they might not be dicts
                # TODO(aryan): causes NCCL hang for some reason; look into later
                # logs.update(self.lr_scheduler.get_last_lr())
                # timesteps_table = wandb.Table(data=timesteps_buffer, columns=["step", "timesteps"])
                # logs["timesteps"] = wandb.plot.scatter(
                #     timesteps_table, "step", "timesteps", title="Timesteps distribution"
                # )
                timesteps_buffer = []
                logs["observed_data_samples"] = train_state.observed_data_samples
                logs["observed_num_tokens"] = train_state.observed_num_tokens
                parallel_backend.log(logs, step=train_state.step)
                train_state.log_steps.append(train_state.step)

            # 7. Save checkpoint if required
            self.checkpointer.save(
                step=train_state.step, _device=device, _is_main_process=parallel_backend.is_main_process
            )

            # 8. Perform validation if required
            if train_state.step % self.args.validation_steps == 0:
                self._validate(step=train_state.step, final_validation=False)

        # 9. Final checkpoint, validation & cleanup
        self.checkpointer.save(
            train_state.step, force=True, _device=device, _is_main_process=parallel_backend.is_main_process
        )
        parallel_backend.wait_for_everyone()

        self._validate(step=train_state.step, final_validation=True)

        self._delete_components()
        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory after training end: {json.dumps(memory_statistics, indent=4)}")

        # 10. Upload artifacts to hub
        if parallel_backend.is_main_process and self.args.push_to_hub:
            upload_folder(
                repo_id=self.state.repo_id,
                folder_path=self.args.output_dir,
                ignore_patterns=[f"{self.checkpointer._prefix}_*"],
            )

        parallel_backend.destroy()

    def _validate(self, step: int, final_validation: bool = False) -> None:
        if self.args.validation_dataset_file is None:
            return

        logger.info("Starting validation")

        # 1. Load validation dataset
        parallel_backend = self.state.parallel_backend
        dp_mesh = parallel_backend.get_mesh("dp_replicate")

        if dp_mesh is not None:
            local_rank, dp_world_size = dp_mesh.get_local_rank(), dp_mesh.size()
        else:
            local_rank, dp_world_size = 0, 1

        dataset = data.ValidationDataset(self.args.validation_dataset_file)
        dataset._data = datasets.distributed.split_dataset_by_node(dataset._data, local_rank, dp_world_size)
        validation_dataloader = data.DPDataLoader(
            local_rank,
            dataset,
            batch_size=1,
            num_workers=self.args.dataloader_num_workers,
            collate_fn=lambda items: items,
        )
        data_iterator = iter(validation_dataloader)

        main_process_prompts_to_filenames = {}  # Used to save model card
        all_processes_artifacts = []  # Used to gather artifacts from all processes

        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory before validation start: {json.dumps(memory_statistics, indent=4)}")

        seed = self.args.seed if self.args.seed is not None else 0
        generator = torch.Generator(device=parallel_backend.device).manual_seed(seed)
        pipeline = self._init_pipeline(final_validation=final_validation)

        # 2. Run validation
        # TODO(aryan): when running validation with FSDP, if the number of data points is not divisible by dp_shards, we
        # will hang indefinitely. Either pad the dataset or raise an error early on during initialization if the dataset
        # size is not divisible by dp_shards.
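        # The denoiser runs in eval mode for validation; it is switched back to train mode at the end of this
        # method unless this is the final validation pass.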
        self.transformer.eval()

        while True:
            validation_data = next(data_iterator, None)
            if validation_data is None:
                break

            logger.debug(
                f"Validating {validation_data=} on rank={parallel_backend.rank}.", local_main_process_only=False
            )

            validation_data = validation_data[0]
            validation_artifacts = self.model_specification.validation(
                pipeline=pipeline, generator=generator, **validation_data
            )

            PROMPT = validation_data["prompt"]
            IMAGE = validation_data.get("image", None)
            VIDEO = validation_data.get("video", None)
            EXPORT_FPS = validation_data.get("export_fps", 30)

            # 2.1. If there are any initial images or videos, they will be logged to keep track of them as
            # conditioning for generation.
            prompt_filename = utils.string_to_filename(PROMPT)[:25]
            artifacts = {
                "input_image": data.ImageArtifact(value=IMAGE),
                "input_video": data.VideoArtifact(value=VIDEO),
            }

            # 2.2. Track the artifacts generated from validation
            for i, validation_artifact in enumerate(validation_artifacts):
                if validation_artifact.value is None:
                    continue
                artifacts.update({f"artifact_{i}": validation_artifact})

            # 2.3. Save the artifacts to the output directory and create appropriate logging objects
            # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
            for index, (key, artifact) in enumerate(list(artifacts.items())):
                assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))

                time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension
                filename = "validation-" if not final_validation else "final-"
                filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
                output_filename = os.path.join(self.args.output_dir, filename)

                if parallel_backend.is_main_process and artifact.file_extension == "mp4":
                    main_process_prompts_to_filenames[PROMPT] = filename

                if artifact.type == "image" and artifact.value is not None:
                    logger.debug(
                        f"Saving image from rank={parallel_backend.rank} to {output_filename}",
                        local_main_process_only=False,
                    )
                    artifact.value.save(output_filename)
                    all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT))
                elif artifact.type == "video" and artifact.value is not None:
                    logger.debug(
                        f"Saving video from rank={parallel_backend.rank} to {output_filename}",
                        local_main_process_only=False,
                    )
                    export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
                    all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT))

        # 3. Cleanup & log artifacts
        parallel_backend.wait_for_everyone()

        # Remove all hooks that might have been added during pipeline initialization to the models
        pipeline.remove_all_hooks()
        del pipeline
        utils.free_memory()

        memory_statistics = utils.get_memory_statistics()
        logger.info(f"Memory after validation end: {json.dumps(memory_statistics, indent=4)}")
        torch.cuda.reset_peak_memory_stats(parallel_backend.device)

        # Gather artifacts from all processes. We also need to flatten them since each process returns a list of artifacts.
        # TODO(aryan): probably should only all gather from dp mesh process group
        all_artifacts = [None] * parallel_backend.world_size
        torch.distributed.all_gather_object(all_artifacts, all_processes_artifacts)
        all_artifacts = [artifact for artifacts in all_artifacts for artifact in artifacts]

        if parallel_backend.is_main_process:
            tracker_key = "final" if final_validation else "validation"
            artifact_log_dict = {}

            image_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Image)]
            if len(image_artifacts) > 0:
                artifact_log_dict["images"] = image_artifacts

            video_artifacts = [artifact for artifact in all_artifacts if isinstance(artifact, wandb.Video)]
            if len(video_artifacts) > 0:
                artifact_log_dict["videos"] = video_artifacts

            parallel_backend.log({tracker_key: artifact_log_dict}, step=step)

            if self.args.push_to_hub and final_validation:
                video_filenames = list(main_process_prompts_to_filenames.values())
                prompts = list(main_process_prompts_to_filenames.keys())
                utils.save_model_card(
                    args=self.args, repo_id=self.state.repo_id, videos=video_filenames, validation_prompts=prompts
                )

        parallel_backend.wait_for_everyone()

        if not final_validation:
            self.transformer.train()

    def _evaluate(self) -> None:
        raise NotImplementedError("Evaluation has not been implemented yet.")

    def _init_distributed(self) -> None:
        # TODO: Accelerate disables native_amp for MPS. Probably need to do the same with our implementation.
        world_size = int(os.environ["WORLD_SIZE"])

        # TODO(aryan): handle other backends
        backend_cls: parallel.ParallelBackendType = parallel.get_parallel_backend_cls(self.args.parallel_backend)
        self.state.parallel_backend = backend_cls(
            world_size=world_size,
            pp_degree=self.args.pp_degree,
            dp_degree=self.args.dp_degree,
            dp_shards=self.args.dp_shards,
            cp_degree=self.args.cp_degree,
            tp_degree=self.args.tp_degree,
            backend="nccl",
            timeout=self.args.init_timeout,
            logging_dir=self.args.logging_dir,
            output_dir=self.args.output_dir,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
        )

        if self.args.seed is not None:
            world_mesh = self.state.parallel_backend.get_mesh()
            utils.enable_determinism(self.args.seed, world_mesh)

    def _init_logging(self) -> None:
        transformers_log_level = transformers.utils.logging.set_verbosity_error
        diffusers_log_level = diffusers.utils.logging.set_verbosity_error

        if self.args.verbose == 0:
            if self.state.parallel_backend.is_local_main_process:
                transformers_log_level = transformers.utils.logging.set_verbosity_warning
                diffusers_log_level = diffusers.utils.logging.set_verbosity_warning
        elif self.args.verbose == 1:
            if self.state.parallel_backend.is_local_main_process:
                transformers_log_level = transformers.utils.logging.set_verbosity_info
                diffusers_log_level = diffusers.utils.logging.set_verbosity_info
        elif self.args.verbose == 2:
            if self.state.parallel_backend.is_local_main_process:
                transformers_log_level = transformers.utils.logging.set_verbosity_debug
                diffusers_log_level = diffusers.utils.logging.set_verbosity_debug
        else:
            transformers_log_level = transformers.utils.logging.set_verbosity_debug
            diffusers_log_level = diffusers.utils.logging.set_verbosity_debug

        transformers_log_level()
        diffusers_log_level()

        logging._set_parallel_backend(self.state.parallel_backend)
        logger.info("Initialized FineTrainers")

    def _init_trackers(self) -> None:
        # TODO(aryan): handle multiple trackers
        trackers = ["wandb"]
        experiment_name = self.args.tracker_name or "finetrainers-experiment"
        self.state.parallel_backend.initialize_trackers(
            trackers,
            experiment_name=experiment_name,
            config=self._get_training_info(),
            log_dir=self.args.logging_dir,
        )

    def _init_directories_and_repositories(self) -> None:
        if self.state.parallel_backend.is_main_process:
            self.args.output_dir = Path(self.args.output_dir)
            self.args.output_dir.mkdir(parents=True, exist_ok=True)
            self.state.output_dir = Path(self.args.output_dir)

            if self.args.push_to_hub:
                repo_id = self.args.hub_model_id or Path(self.args.output_dir).name
                self.state.repo_id = create_repo(token=self.args.hub_token, repo_id=repo_id, exist_ok=True).repo_id

    def _init_config_options(self) -> None:
        # Enable TF32 for faster training on Ampere GPUs: https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
        if self.args.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True

    def _move_components_to_device(
        self, components: Optional[List[torch.nn.Module]] = None, device: Optional[Union[str, torch.device]] = None
    ) -> None:
        if device is None:
            device = self.state.parallel_backend.device
        if components is None:
            components = [self.text_encoder, self.text_encoder_2, self.text_encoder_3, self.transformer, self.vae]
        components = utils.get_non_null_items(components)
        components = list(filter(lambda x: hasattr(x, "to"), components))
        for component in components:
            component.to(device)

    def _set_components(self, components: Dict[str, Any]) -> None:
        for component_name in self._all_component_names:
            existing_component = getattr(self, component_name, None)
            new_component = components.get(component_name, existing_component)
            setattr(self, component_name, new_component)

    def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
        if component_names is None:
            component_names = self._all_component_names
        for component_name in component_names:
            setattr(self, component_name, None)
        utils.free_memory()
        utils.synchronize_device()

    def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
        parallel_backend = self.state.parallel_backend
        module_names = ["text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "vae"]

        if not final_validation:
            module_names.remove("transformer")
            pipeline = self.model_specification.load_pipeline(
                tokenizer=self.tokenizer,
                tokenizer_2=self.tokenizer_2,
                tokenizer_3=self.tokenizer_3,
                text_encoder=self.text_encoder,
                text_encoder_2=self.text_encoder_2,
                text_encoder_3=self.text_encoder_3,
                # TODO(aryan): handle unwrapping for compiled modules
                # transformer=utils.unwrap_model(accelerator, self.transformer),
                transformer=self.transformer,
                vae=self.vae,
                enable_slicing=self.args.enable_slicing,
                enable_tiling=self.args.enable_tiling,
                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                training=True,
            )
        else:
            self._delete_components()

            # Load the transformer weights from the final checkpoint if performing full-finetune
            transformer = None
            if self.args.training_type == TrainingType.FULL_FINETUNE:
                transformer = self.model_specification.load_diffusion_models()["transformer"]

            pipeline = self.model_specification.load_pipeline(
                transformer=transformer,
                enable_slicing=self.args.enable_slicing,
                enable_tiling=self.args.enable_tiling,
                enable_model_cpu_offload=self.args.enable_model_cpu_offload,
                training=False,
                device=parallel_backend.device,
            )

            # Load the LoRA weights if performing LoRA finetuning
            if self.args.training_type == TrainingType.LORA:
                pipeline.load_lora_weights(self.args.output_dir)

        components = {module_name: getattr(pipeline, module_name, None) for module_name in module_names}
        self._set_components(components)
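        # Ensure the (possibly newly loaded) pipeline components are on the training device before returning,
        # so that subsequent validation runs can reuse them.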
        self._move_components_to_device(list(components.values()))
        return pipeline

    def _prepare_data(
        self,
        preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
        data_iterator,
    ):
        if not self.args.enable_precomputation:
            if not self._are_condition_models_loaded:
                logger.info(
                    "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
                )
                condition_components = self.model_specification.load_condition_models()
                latent_components = self.model_specification.load_latent_models()
                all_components = {**condition_components, **latent_components}
                self._set_components(all_components)
                self._move_components_to_device(list(all_components.values()))
                utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
            else:
                condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
                latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}

            condition_iterator = preprocessor.consume(
                "condition",
                components=condition_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                cache_samples=True,
            )
            latent_iterator = preprocessor.consume(
                "latent",
                components=latent_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                use_cached_samples=True,
                drop_samples=True,
            )
            self._are_condition_models_loaded = True
        else:
            logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")

            # TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
            # had become None after this but should have been loaded back from the checkpoint.
            # parallel_backend = self.state.parallel_backend
            # train_state = self.state.train_state
            # self.checkpointer.save(
            #     train_state.step,
            #     force=True,
            #     _device=parallel_backend.device,
            #     _is_main_process=parallel_backend.is_main_process,
            # )
            # self._delete_components(component_names=["transformer", "unet"])

            if self.args.precomputation_once:
                consume_fn = preprocessor.consume_once
            else:
                consume_fn = preprocessor.consume

            # Prepare condition iterators
            condition_components = self.model_specification.load_condition_models()
            component_names = list(condition_components.keys())
            component_modules = list(condition_components.values())
            self._set_components(condition_components)
            self._move_components_to_device(component_modules)

            condition_iterator = consume_fn(
                "condition",
                components=condition_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                cache_samples=True,
            )
            self._delete_components(component_names)
            del condition_components, component_names, component_modules

            # Prepare latent iterators
            latent_components = self.model_specification.load_latent_models()
            utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
            component_names = list(latent_components.keys())
            component_modules = list(latent_components.values())
            self._set_components(latent_components)
            self._move_components_to_device(component_modules)

            latent_iterator = consume_fn(
                "latent",
                components=latent_components,
                data_iterator=data_iterator,
                generator=self.state.generator,
                use_cached_samples=True,
                drop_samples=True,
            )
            self._delete_components(component_names)
            del latent_components, component_names, component_modules

            # self.checkpointer.load()
            # self.transformer = self.checkpointer.states["model"].model[0]

        return condition_iterator, latent_iterator

    def _get_training_info(self) -> Dict[str, Any]:
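        """Return the training configuration that is logged to the experiment trackers.

        Flow-matching specific arguments are stripped from the diffusion arguments when the configured
        scheduler is not a flow-matching scheduler.
        """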
        info = self.args.to_dict()

        # Removing flow matching arguments when not using flow-matching objective
        diffusion_args = info.get("diffusion_arguments", {})
        scheduler_name = self.scheduler.__class__.__name__ if self.scheduler is not None else ""
        if scheduler_name != "FlowMatchEulerDiscreteScheduler":
            filtered_diffusion_args = {k: v for k, v in diffusion_args.items() if "flow" not in k}
        else:
            filtered_diffusion_args = diffusion_args

        info.update({"diffusion_arguments": filtered_diffusion_args})
        return info
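
# Example usage (illustrative sketch only; the concrete `ModelSpecification` subclass and the way `BaseArgs`
# is constructed or parsed depend on the surrounding package and are assumptions here, not part of this module):
#
#     args = ...  # a fully populated `BaseArgs` instance, e.g. from the package's argument parser
#     model_specification = ...  # a concrete `ModelSpecification` implementation for the model being finetuned
#     trainer = SFTTrainer(args, model_specification)
#     trainer.run()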