import os
import torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from torchmetrics.functional.image import (
    peak_signal_noise_ratio as psnr,
    structural_similarity_index_measure as ssim,
)


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def save_img(img, save_path):
    img = np.clip(img.numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]
    cv2.imwrite(save_path, img)


def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # Convert to numpy and prepare all images in one go
    imgs = np.clip(imgs.numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]  # RGB to BGR for all images at once
    # ThreadPoolExecutor is the right choice here: saving files is I/O-bound,
    # so threads overlap the disk writes (ProcessPoolExecutor suits CPU-bound work).
    with ThreadPoolExecutor(max_workers=32) as pool:
        # Submit all tasks at once
        futures = [pool.submit(cv2.imwrite, path, img) for path, img in zip(save_paths, imgs)]
        # Wait for all tasks to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred


def get_fid_stats(real_dir, rec_dir, fid_stats):
    stats = torch_fidelity.calculate_metrics(
        input1=rec_dir,
        input2=real_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats


class EMAModel:
    """Model Exponential Moving Average."""

    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )


class DiffusionTrainer(nn.Module):
    def __init__(
        self,
        model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_epoch=400,
        valid_size=32,
        lr=None,
        blr=1e-4,
        cosine_lr=True,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        eval_bs=32,
        test_bs=64,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.0,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        use_multi_epochs_dataloader=False,
        compile=False,
        overfit=False,
        making_cache=False,
        cache_mode=False,
        latent_cache_file=None,
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.model = instantiate_from_config(model)
        self.num_slots = model.params.num_slots

        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overriding specified precision and using autocast bf16...")
        self.precision = precision

        if test_dataset is not None:
            test_dataset = instantiate_from_config(test_dataset)
            self.test_ds = test_dataset

            # Calculate padded dataset size to ensure even distribution
            total_size = len(test_dataset)
            world_size = self.accelerator.num_processes
            padding_size = world_size * test_bs - (total_size % (world_size * test_bs))
            self.test_dataset_size = total_size

            # Create a padded dataset wrapper
            class PaddedDataset(torch.utils.data.Dataset):
                def __init__(self, dataset, padding_size):
                    self.dataset = dataset
                    self.padding_size = padding_size

                def __len__(self):
                    return len(self.dataset) + self.padding_size

                def __getitem__(self, idx):
                    if idx < len(self.dataset):
                        return self.dataset[idx]
                    return self.dataset[0]

            self.test_ds = PaddedDataset(self.test_ds, padding_size)
            self.test_dl = DataLoader(
                self.test_ds,
                batch_size=test_bs,
                num_workers=num_workers,
                pin_memory=pin_memory,
                shuffle=False,
                drop_last=True,
            )
            if self.accelerator.is_main_process:
                print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
        else:
            self.test_dl = None
        self.test_only = test_only

        if not test_only:
            dataset = instantiate_from_config(dataset)
            train_size = len(dataset) - valid_size
            self.train_ds, self.valid_ds = random_split(
                dataset,
                [train_size, valid_size],
                generator=torch.Generator().manual_seed(42),
            )
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            self.valid_dl = DataLoader(
                self.valid_ds,
                batch_size=eval_bs,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            params = filter(lambda p: p.requires_grad, self.model.parameters())
            self.g_optim = AdamW(params, lr=lr, betas=(0.9, 0.95), weight_decay=0)
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps,
            )
            if self.g_sched is not None:
                self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        # Prepare everything together
        if not test_only:
            self.model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.model, self.g_optim, self.g_sched
            )
        else:
            self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)

        if compile:
            _model = self.accelerator.unwrap_model(self.model)
            _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
            _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
            # _model.encoder = torch.compile(_model.encoder, mode="reduce-overhead")  # nan loss when compiled together with dit, no idea why
            _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm
        self.cache_mode = cache_mode

        self.cfg = cfg
        self.test_num_slots = test_num_slots
        if self.test_num_slots is not None:
            self.test_num_slots = min(self.test_num_slots, self.num_slots)
        else:
            self.test_num_slots = self.num_slots

        eval_fid = eval_fid or model.params.eval_fid  # legacy
        self.eval_fid = eval_fid
        if eval_fid:
            if fid_stats is None:
                fid_stats = model.params.fid_stats  # legacy
            assert fid_stats is not None
            assert test_dataset is not None
            self.fid_stats = fid_stats

        self.use_vq = model.params.use_vq if hasattr(model.params, "use_vq") else False
        self.vq_beta = model.params.code_beta if hasattr(model.params, "code_beta") else 0.25

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)

        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

    @property
    def device(self):
        return self.accelerator.device

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                          lr_min, warmup_lr_init, decay_steps):
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None

        scheduler = build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,  # if not cosine_lr, then use step_lr (warmup, then fix)
        )
        return scheduler

    def _load_state_dict(self, state_dict):
        """Helper to load a state dict with proper prefix handling."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Remove '_orig_mod' prefix if present
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = self.accelerator.unwrap_model(self.model).load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
    def _load_safetensors(self, path):
        """Helper to load a safetensors checkpoint."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict)

    def _load_checkpoint(self, ckpt_path=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if osp.isdir(ckpt_path):
            # ckpt_path is something like 'path/to/models/step10/'
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path)
        else:
            # ckpt_path is something like 'path/to/models/step10.pt'
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path)
            else:
                state_dict = torch.load(ckpt_path)
                self._load_state_dict(state_dict)
        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def train(self, config=None):
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters / 1e6:.2f}M")

        if config is not None:
            # save the config
            import shutil
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                # If it's a path, copy the file to config.yaml
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                # If it's an OmegaConf object, dump it
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

        self.accelerator.init_trackers("vqgan")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        for epoch in range(self.num_epoch):
            if ((epoch + 1) * len(self.train_dl)) <= self.loaded_steps:
                if self.accelerator.is_main_process:
                    print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
                self.steps += len(self.train_dl)
                continue

            if self.steps < self.loaded_steps:
                for _ in self.train_dl:
                    self.steps += 1
                    if self.steps >= self.loaded_steps:
                        break

            self.accelerator.unwrap_model(self.model).current_epoch = epoch
            self.model.train()  # Set model to training mode

            logger = MetricLogger(delimiter=" ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
            header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
            print_freq = 20
            for data_iter_step, batch in enumerate(logger.log_every(self.train_dl, print_freq, header)):
                # Move batch to device once
                if isinstance(batch, tuple) or isinstance(batch, list):
                    batch = tuple(b.to(self.device, non_blocking=True) for b in batch)
                    if self.cache_mode:
                        img, latent, targets = batch[0], batch[1], batch[2]
                        img = img.to(self.device, non_blocking=True)
                        latent = latent.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                    else:
                        latent = None
                        img, targets = batch[0], batch[1]
                        img = img.to(self.device, non_blocking=True)
                        targets = targets.to(self.device, non_blocking=True)
                else:
                    img = batch
                    latent = None
                    targets = None  # dataset returned images only

                self.steps += 1

                with self.accelerator.accumulate(self.model):
                    with self.accelerator.autocast():
                        if self.steps == 1:
                            print(f"Training batch size: {img.size(0)}")
                            print(f"Hello from index {self.accelerator.local_process_index}")
                        losses = self.model(img, targets, latents=latent, epoch=epoch)
                        # combine
                        loss = sum([v for _, v in losses.items()])
                        diff_loss = losses["diff_loss"]
                        if self.use_vq:
                            vq_loss = losses["vq_loss"]

                    self.accelerator.backward(loss)
                    if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                    self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
                    self.g_optim.step()
                    if self.g_sched is not None:
                        self.g_sched.step_update(self.steps)
                    self.g_optim.zero_grad()

                # synchronize_processes()

                # update ema with state dict
                if self.enable_ema:
                    self.ema_model.update(self.accelerator.unwrap_model(self.model))

                logger.update(diff_loss=diff_loss.item())
                if self.use_vq:
                    logger.update(vq_loss=vq_loss.item() / self.vq_beta)
                if 'kl_loss' in losses:
                    logger.update(kl_loss=losses["kl_loss"].item())
                if 'repa_loss' in losses:
                    logger.update(repa_loss=losses["repa_loss"].item())
                logger.update(lr=self.g_optim.param_groups[0]["lr"])

                if self.steps % self.save_every == 0:
                    self.save()

                if (self.steps % self.samp_every == 0) or (self.steps % self.fid_every == 0):
                    empty_cache()
                    self.evaluate()
                    self.accelerator.wait_for_everyone()
                    empty_cache()

                # omitted all_gather here
                # write_dict = dict(epoch=epoch)
                # write_dict.update(diff_loss=diff_loss.item())
                # if "kl_loss" in losses:
                #     write_dict.update(kl_loss=losses["kl_loss"].item())
                # if self.use_vq:
                #     write_dict.update(vq_loss=vq_loss.item() / self.vq_beta)
                # write_dict.update(lr=self.g_optim.param_groups[0]["lr"])
                # self.accelerator.log(write_dict, step=self.steps)

            logger.synchronize_between_processes()
            if self.accelerator.is_main_process:
                print("Averaged stats:", logger)

        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        self.model.eval()
        # switch to ema params, only when eval_fid is True
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        # use_ema = False
        if use_ema:
            if hasattr(self, "ema_model"):
                model_without_ddp = self.accelerator.unwrap_model(self.model)
                model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                    if "nested_sampler" in name:
                        continue
                    if name in self.ema_model.state_dict():
                        ema_state_dict[name] = self.ema_model.state_dict()[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                msg = model_without_ddp.load_state_dict(ema_state_dict)
                if self.accelerator.is_main_process:
                    print(msg)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        if not self.test_only:
            with tqdm(
                self.valid_dl,
                dynamic_ncols=True,
                disable=not self.accelerator.is_main_process,
            ) as valid_dl:
                for batch_i, batch in enumerate(valid_dl):
                    if isinstance(batch, tuple) or isinstance(batch, list):
                        img, targets = batch[0], batch[1]
                    else:
                        img = batch
                        targets = None  # no labels available

                    with self.accelerator.autocast():
                        rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)

                    imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                    imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
                    imgs_and_recs = imgs_and_recs.detach().cpu().float()

                    grid = make_grid(
                        imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
                    )
                    if self.accelerator.is_main_process:
                        save_image(
                            grid,
                            os.path.join(
                                self.image_saved_dir,
                                f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg",
                            ),
                        )

                    if self.cfg != 1.0:
                        with self.accelerator.autocast():
                            rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)

                        imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
                        imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
                        imgs_and_recs = imgs_and_recs.detach().cpu().float()

                        grid = make_grid(
                            imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
                        )
                        if self.accelerator.is_main_process:
                            save_image(
                                grid,
                                os.path.join(
                                    self.image_saved_dir,
                                    f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg",
                                ),
                            )

        if (self.eval_fid and self.test_dl is not None) and (self.test_only or (self.steps % self.fid_every == 0)):
            # Create output directories
            if self.test_dataset_size > 10000:
                real_dir = "./dataset/imagenet/val256"
            else:
                real_dir = "./dataset/coco/val2017_256"
            rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
            os.makedirs(rec_dir, exist_ok=True)

            if self.cfg != 1.0:
                rec_cfg_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}")
                os.makedirs(rec_cfg_dir, exist_ok=True)

            def process_batch(cfg_value, save_dir, header):
                logger = MetricLogger(delimiter=" ")
                print_freq = 5
                psnr_values = []
                ssim_values = []
                total_processed = 0
                for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
                    imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)

                    # Skip processing if we've already processed all real samples
                    if total_processed >= self.test_dataset_size:
                        break

                    imgs = imgs.to(self.device, non_blocking=True)
                    if targets is not None:
                        targets = targets.to(self.device, non_blocking=True)

                    with self.accelerator.autocast():
                        recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)

                    psnr_val = psnr(recs, imgs, data_range=1.0)
                    ssim_val = ssim(recs, imgs, data_range=1.0)

                    recs = concat_all_gather(recs).detach()
                    psnr_val = concat_all_gather(psnr_val.view(1))
                    ssim_val = concat_all_gather(ssim_val.view(1))

                    # Remove padding after gathering from all GPUs
                    samples_in_batch = min(
                        recs.size(0),  # Always use the gathered size
                        self.test_dataset_size - total_processed,
                    )
                    recs = recs[:samples_in_batch]
                    psnr_val = psnr_val[:samples_in_batch]
                    ssim_val = ssim_val[:samples_in_batch]
                    psnr_values.append(psnr_val)
                    ssim_values.append(ssim_val)

                    if self.accelerator.is_main_process:
                        rec_paths = [
                            os.path.join(
                                save_dir,
                                f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png",
                            )
                            for j in range(recs.size(0))
                        ]
                        save_img_batch(recs.cpu(), rec_paths)

                    total_processed += samples_in_batch

                self.accelerator.wait_for_everyone()
                return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()

            # Helper function to calculate and log metrics
            def calculate_and_log_metrics(real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
                if self.accelerator.is_main_process:
                    metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
                    fid = metrics_dict["frechet_inception_distance"]
                    inception_score = metrics_dict["inception_score_mean"]
                    metric_prefix = "fid_ema" if use_ema else "fid"
                    isc_prefix = "isc_ema" if use_ema else "isc"
                    self.accelerator.log({
                        metric_prefix: fid,
                        isc_prefix: inception_score,
                        f"psnr_{'ema' if use_ema else 'test'}": psnr_val,
                        f"ssim_{'ema' if use_ema else 'test'}": ssim_val,
                        "cfg": cfg_value,
                    }, step=self.steps)

                    print(f"{'EMA ' if use_ema else ''}CFG: {cfg_value} "
                          f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
                          f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")

            # Process without CFG
            if self.cfg == 1.0 or not self.test_only:
                psnr_val, ssim_val = process_batch(1.0, rec_dir, 'Testing: w/o CFG')
                calculate_and_log_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)

            # Process with CFG if needed
            if self.cfg != 1.0:
                psnr_val, ssim_val = process_batch(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
                calculate_and_log_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)

            # Cleanup
            if self.accelerator.is_main_process:
                shutil.rmtree(rec_dir)
                if self.cfg != 1.0:
                    shutil.rmtree(rec_cfg_dir)

        # back to no ema
        if use_ema:
            if self.accelerator.is_main_process:
                print("Switch back from ema")
            model_without_ddp.load_state_dict(model_state_dict)

        self.model.train()
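

# ---------------------------------------------------------------------------
# Usage sketch (not part of the trainer itself): a minimal, hypothetical
# example of how this class is typically driven from a config file. The
# config path and the `trainer` node layout below are assumptions, not part
# of this module; adapt them to the project's actual config schema.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from omegaconf import OmegaConf

    # Assumed: a YAML with a `trainer` node whose `target` points at this
    # DiffusionTrainer and whose `params` mirror the __init__ arguments above
    # (model, dataset, test_dataset, result_folder, ...).
    cfg_path = "configs/example.yaml"  # hypothetical path
    cfg = OmegaConf.load(cfg_path)

    # instantiate_from_config builds the trainer (and, inside __init__,
    # the model and datasets) from the config node.
    trainer = instantiate_from_config(cfg.trainer)

    # Passing the config path lets train() copy it into result_folder
    # as config.yaml before the training loop starts.
    trainer.train(cfg_path)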