import os, torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler, Sampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from paintmind.stage2.gpt import GPT_models
from paintmind.stage2.causaldit import CausalDiT_models
from paintmind.stage2.generate import generate, generate_causal_dit
from pathlib import Path
import time


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def save_img(img, save_path):
    img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]
    cv2.imwrite(save_path, img)


def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # Convert to numpy and prepare all images in one go
    imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]  # RGB to BGR for all images at once

    # ThreadPoolExecutor is the right choice here: saving files is I/O-bound,
    # so threads are preferable to a process pool.
    with ThreadPoolExecutor(max_workers=32) as pool:
        # Submit all tasks at once
        futures = [pool.submit(cv2.imwrite, path, img) for path, img in zip(save_paths, imgs)]
        # Wait for all tasks to complete
        for future in futures:
            future.result()  # This will raise any exceptions that occurred


def get_fid_stats(real_dir, rec_dir, fid_stats):
    stats = torch_fidelity.calculate_metrics(
        input1=real_dir,
        input2=rec_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats


class EMAModel:
    """Model Exponential Moving Average."""
    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    # ema <- decay * ema + (1 - decay) * param
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )


class CacheDataLoader:
    """DataLoader-like interface for cached data with epoch-based shuffling."""
    def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
        self.slots = slots
        self.targets = targets
        self.batch_size = batch_size
        self.num_augs = num_augs
        self.seed = seed
        self.epoch = 0
        # Original dataset size (before augmentations)
        self.num_samples = len(slots) // num_augs

    def set_epoch(self, epoch):
        """Set epoch for deterministic shuffling."""
        self.epoch = epoch

    def __len__(self):
        """Return number of batches based on original dataset size."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Yield batches of (slots, targets) for the current epoch."""
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)
        # Randomly sample indices from the entire augmented dataset
        indices = torch.randint(
            0, len(self.slots), (self.num_samples,), generator=g
        ).numpy()
        # Yield batches of indices
        for start in range(0, self.num_samples, self.batch_size):
            end = min(start + self.batch_size, self.num_samples)
            batch_indices = indices[start:end]
            yield (
                torch.from_numpy(self.slots[batch_indices]),
                torch.from_numpy(self.targets[batch_indices])
            )
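

# Illustrative usage of CacheDataLoader (not executed here; GPTTrainer._setup_cache builds
# one over the rank-local memmap shard). Indices are drawn with torch.randint, i.e. sampled
# with replacement from the augmented pool, and each epoch yields
# len(slots) // num_augs // batch_size batches of CPU tensors. Batch size, num_augs and seed
# below are placeholder values:
#
#   loader = CacheDataLoader(slots, targets, batch_size=256, num_augs=2, seed=42)
#   loader.set_epoch(epoch)            # deterministic shuffling per epoch
#   for slots_b, targets_b in loader:  # float32 slots, int64 class targets
#       ...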


class GPTTrainer(nn.Module):
    def __init__(
        self,
        ae_model,
        gpt_model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_test_images=50000,
        num_epoch=400,
        eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850],  # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
        lr=None,
        blr=1e-4,
        cosine_lr=False,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        cache_bs=8,
        test_bs=100,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.75,
        ae_cfg=1.5,
        diff_cfg=2.0,
        temperature=1.0,
        cfg_schedule="constant",
        diff_cfg_schedule="inv_linear",
        train_num_slots=None,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
        enable_cache_latents=True,
        cache_dir='/dev/shm/slot_cache',
        seed=42
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.ae_model = instantiate_from_config(ae_model)
        if hasattr(ae_model.params, "ema_path") and ae_model.params.ema_path is not None:
            ae_model_path = ae_model.params.ema_path
        else:
            ae_model_path = ae_model.params.ckpt_path
        assert ae_model_path.endswith(".safetensors") or ae_model_path.endswith(".pt") or ae_model_path.endswith(".pth") or ae_model_path.endswith(".pkl")
        assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
        self._load_checkpoint(ae_model_path, self.ae_model)
        self.ae_model.to(self.device)
        for param in self.ae_model.parameters():
            param.requires_grad = False
        self.ae_model.eval()

        self.model_name = gpt_model.target
        if 'GPT' in gpt_model.target:
            self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
        elif 'CausalDiT' in gpt_model.target:
            self.gpt_model = CausalDiT_models[gpt_model.target](**gpt_model.params)
        else:
            raise ValueError(f"Unknown model type: {gpt_model.target}")
        self.num_slots = ae_model.params.num_slots
        self.slot_dim = ae_model.params.slot_dim

        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Ignoring the specified precision and using autocast bf16...")
        self.precision = precision

        self.test_only = test_only
        self.test_bs = test_bs
        self.num_test_images = num_test_images
        self.num_classes = gpt_model.params.num_classes
        self.batch_size = batch_size
        if not test_only:
            self.train_ds = instantiate_from_config(dataset)
            train_size = len(self.train_ds)
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}")
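            # Note: when latent caching is enabled, the raw-image DataLoader below uses the
            # smaller cache_bs (each image is only read once, to encode its slots), while
            # training batches of size batch_size are later drawn from the cache by
            # CacheDataLoader.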
            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size if not enable_cache_latents else cache_bs,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                # linear lr scaling rule: lr = blr * effective_bs / 256
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            self.g_optim = self._create_optimizer(weight_decay=0.05, learning_rate=lr, betas=(0.9, 0.95))
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps
            )
            self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        # Prepare everything together
        if not test_only:
            self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.gpt_model, self.g_optim, self.g_sched
            )
        else:
            self.gpt_model = self.accelerator.prepare(self.gpt_model)

        # assume the '_orig_mod.' prefix does not exist in checkpoints
        if compile:
            _model = self.accelerator.unwrap_model(self.gpt_model)
            self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
            _model = torch.compile(_model, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(gpt_model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.eval_classes = eval_classes
        self.cfg = cfg
        self.ae_cfg = ae_cfg
        self.diff_cfg = diff_cfg
        self.cfg_schedule = cfg_schedule
        self.diff_cfg_schedule = diff_cfg_schedule
        self.temperature = temperature
        self.train_num_slots = train_num_slots
        self.test_num_slots = test_num_slots
        if self.train_num_slots is not None:
            self.train_num_slots = min(self.train_num_slots, self.num_slots)
        else:
            self.train_num_slots = self.num_slots
        if self.test_num_slots is not None:
            self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
        else:
            self.num_slots_to_gen = self.train_num_slots
        self.eval_fid = eval_fid
        if eval_fid:
            assert fid_stats is not None
        self.fid_stats = fid_stats

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)
        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

        self.cache_dir = Path(cache_dir)
        self.enable_cache_latents = enable_cache_latents
        self.seed = seed
        self.cache_loader = None

    @property
    def device(self):
        return self.accelerator.device

    def _create_optimizer(self, weight_decay, learning_rate, betas):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameter that is 2-D or higher will be weight decayed, otherwise not.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
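        # For example: linear/attention projection weights and embedding tables are >= 2-D
        # and therefore decay, while biases and norm-layer parameters are 1-D and do not.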
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if self.accelerator.is_main_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None

        scheduler = build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,  # if not cosine_lr, then use step_lr (warmup, then fix)
        )
        return scheduler

    def _load_state_dict(self, state_dict, model):
        """Helper to load a state dict with proper prefix handling."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Remove '_orig_mod' prefix if present
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = model.load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path, model):
        """Helper to load a safetensors checkpoint."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict, model)

    def _load_checkpoint(self, ckpt_path=None, model=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if model is None:
            model = self.accelerator.unwrap_model(self.gpt_model)

        if osp.isdir(ckpt_path):
            # ckpt_path is something like 'path/to/models/step10/'
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict, model)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path, model)
        else:
            # ckpt_path is something like 'path/to/models/step10.pt'
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path, model)
            else:
                state_dict = torch.load(ckpt_path, map_location="cpu")
                self._load_state_dict(state_dict, model)

        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _build_cache(self):
        """Build cache for slots and targets."""
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes

        # Clean up any existing cache files first
        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
        if slots_file.exists():
            os.remove(slots_file)
        if targets_file.exists():
            os.remove(targets_file)

        dataset_size = len(self.train_dl.dataset)
        shard_size = dataset_size // world_size
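        # Cache layout (per rank): two memmap files in cache_dir (default /dev/shm, a
        # RAM-backed tmpfs, so the cache must fit in host memory):
        #   slots:   float32, shape (shard_size * num_augs, train_num_slots, slot_dim)
        #   targets: int64,   shape (shard_size * num_augs,)
        # i.e. roughly shard_size * num_augs * train_num_slots * slot_dim * 4 bytes for slots.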
        # Detect number of augmentations from first batch
        with torch.no_grad():
            sample_batch = next(iter(self.train_dl))
            img, _ = sample_batch
            num_augs = img.shape[1] if len(img.shape) == 5 else 1

        print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
        os.makedirs(self.cache_dir, exist_ok=True)
        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"

        # Create memory-mapped files
        slots_mmap = np.memmap(
            slots_file,
            dtype='float32',
            mode='w+',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )
        targets_mmap = np.memmap(
            targets_file,
            dtype='int64',
            mode='w+',
            shape=(shard_size * num_augs,)
        )

        # Cache data
        with torch.no_grad():
            for i, batch in enumerate(tqdm(
                self.train_dl,
                desc=f"Rank {rank}: Caching data",
                disable=not self.accelerator.is_local_main_process
            )):
                imgs, targets = batch
                if len(imgs.shape) == 5:  # [B, num_augs, C, H, W]
                    B, A, C, H, W = imgs.shape
                    imgs = imgs.view(-1, C, H, W)  # [B*num_augs, C, H, W]
                    targets = targets.unsqueeze(1).expand(-1, A).reshape(-1)  # [B*num_augs]

                # Split imgs into n chunks
                num_splits = num_augs
                split_size = imgs.shape[0] // num_splits
                imgs_splits = torch.split(imgs, split_size)
                targets_splits = torch.split(targets, split_size)

                start_idx = i * self.train_dl.batch_size * num_augs
                for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
                    img_split = img_split.to(self.device, non_blocking=True)
                    slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]
                    split_start = start_idx + (split_idx * split_size)
                    split_end = split_start + img_split.shape[0]
                    # Write directly to mmap files
                    slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
                    targets_mmap[split_start:split_end] = targets_split.numpy()

        # Close the mmap files
        del slots_mmap
        del targets_mmap

        # Reopen in read mode
        self.cached_latents = np.memmap(
            slots_file,
            dtype='float32',
            mode='r',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )
        self.cached_targets = np.memmap(
            targets_file,
            dtype='int64',
            mode='r',
            shape=(shard_size * num_augs,)
        )

        # Store the number of augmentations for the cache loader
        self.num_augs = num_augs

    def _setup_cache(self):
        """Setup cache if enabled."""
        self._build_cache()
        self.accelerator.wait_for_everyone()

        # Initialize cache loader if cache exists
        if self.cached_latents is not None:
            self.cache_loader = CacheDataLoader(
                slots=self.cached_latents,
                targets=self.cached_targets,
                batch_size=self.batch_size,
                num_augs=self.num_augs,
                seed=self.seed + self.accelerator.process_index
            )

    def __del__(self):
        """Cleanup cache files."""
        if self.enable_cache_latents:
            rank = self.accelerator.process_index
            world_size = self.accelerator.num_processes
            # Clean up slots cache
            slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
            if slots_file.exists():
                os.remove(slots_file)
            # Clean up targets cache
            targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
            if targets_file.exists():
                os.remove(targets_file)
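
    # Training-step contract: `slots` is [B, train_num_slots, slot_dim] and `targets` holds the
    # class labels. The forward/backward pass runs under autocast bf16 inside
    # accelerator.accumulate(), so with grad_accum_steps > 1 Accelerate skips gradient sync
    # (and the prepared optimizer skips its update) on non-final micro-batches; gradient
    # clipping is likewise gated on sync_gradients. When enable_ema is set, the EMA weights
    # are refreshed at the end of every step.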
    def _train_step(self, slots, targets=None):
        """Execute single training step."""
        with self.accelerator.accumulate(self.gpt_model):
            with self.accelerator.autocast():
                loss = self.gpt_model(slots, targets)

            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.g_optim.step()
            if self.g_sched is not None:
                self.g_sched.step_update(self.steps)
            self.g_optim.zero_grad()

        # Update EMA model if enabled
        if self.enable_ema:
            self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))

        return loss

    def _train_epoch_cached(self, epoch, logger):
        """Train one epoch using cached data."""
        self.cache_loader.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.cache_loader, 20, header):
            slots, targets = (b.to(self.device, non_blocking=True) for b in batch)

            self.steps += 1
            if self.steps == 1:
                print(f"Training batch size: {len(slots)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _train_epoch_uncached(self, epoch, logger):
        """Train one epoch using raw data."""
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.train_dl, 20, header):
            img, targets = (b.to(self.device, non_blocking=True) for b in batch)

            self.steps += 1
            if self.steps == 1:
                print(f"Training batch size: {img.size(0)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _handle_periodic_ops(self, loss, logger):
        """Handle periodic operations and logging."""
        logger.update(loss=loss.item())
        logger.update(lr=self.g_optim.param_groups[0]["lr"])

        if self.steps % self.save_every == 0:
            self.save()

        if (self.steps % self.samp_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()

    def _save_config(self, config):
        """Save configuration file."""
        if config is not None and self.accelerator.is_main_process:
            import shutil
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

    def _should_skip_epoch(self, epoch):
        """Check if epoch should be skipped due to loaded checkpoint."""
        loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
        if ((epoch + 1) * len(loader)) <= self.loaded_steps:
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
            self.steps += len(loader)
            return True

        if self.steps < self.loaded_steps:
            for _ in loader:
                self.steps += 1
                if self.steps >= self.loaded_steps:
                    break
        return False

    def train(self, config=None):
        """Main training loop."""
        # Initial setup
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters//1e6}M")

        self._save_config(config)
        self.accelerator.init_trackers("gpt")

        # Handle test-only mode
        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        # Setup cache if enabled
        if self.enable_cache_latents:
            self._setup_cache()

        # Training loop
        for epoch in range(self.num_epoch):
            if self._should_skip_epoch(epoch):
                continue

            self.gpt_model.train()
            logger = MetricLogger(delimiter="  ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

            # Choose training path based on cache availability
            if self.enable_cache_latents:
                self._train_epoch_cached(epoch, logger)
            else:
                self._train_epoch_uncached(epoch, logger)

            # Synchronize and log epoch stats
            # logger.synchronize_between_processes()
            # if self.accelerator.is_main_process:
            #     print("Averaged stats:", logger)

        # Finish training
        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        self.gpt_model.eval()
        unwrapped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)

        # switch to ema params, only when eval_fid is True
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        if use_ema:
            if hasattr(self, "ema_model"):
                model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
                model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
                    if "nested_sampler" in name:
                        continue
                    ema_state_dict[name] = self.ema_model.state_dict()[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                model_without_ddp.load_state_dict(ema_state_dict)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        generate_fn = generate if 'GPT' in self.model_name else generate_causal_dit

        if not self.test_only:
            classes = torch.tensor(self.eval_classes, device=self.device)
            with self.accelerator.autocast():
                slots = generate_fn(unwrapped_gpt_model, classes, self.num_slots_to_gen, cfg_scale=self.cfg,
                                    diff_cfg=self.diff_cfg, cfg_schedule=self.cfg_schedule,
                                    diff_cfg_schedule=self.diff_cfg_schedule, temperature=self.temperature)
                if self.num_slots_to_gen < self.num_slots:
                    null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                    null_slots = null_slots[:, self.num_slots_to_gen:, :]
                    slots = torch.cat([slots, null_slots], dim=1)
                imgs = self.ae_model.sample(slots, targets=classes, cfg=self.ae_cfg)  # targets are not used for now

            imgs = concat_all_gather(imgs)
            if self.accelerator.num_processes > 16:
                imgs = imgs[:16 * len(self.eval_classes)]
            imgs = imgs.detach().cpu()
            grid = make_grid(
                imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
            )
            if self.accelerator.is_main_process:
                save_image(
                    grid,
                    os.path.join(
                        self.image_saved_dir,
                        f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
                    ),
                )

        if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
            # Create output directory (only on main process)
            save_folder = os.path.join(
                self.image_saved_dir,
                f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}"
            )
            if self.accelerator.is_main_process:
                os.makedirs(save_folder, exist_ok=True)

            # Setup for distributed generation
            world_size = self.accelerator.num_processes
            local_rank = self.accelerator.process_index
            batch_size = self.test_bs

            # Create balanced class distribution
            num_classes = self.num_classes
            images_per_class = self.num_test_images // num_classes
            class_labels = np.repeat(np.arange(num_classes), images_per_class)

            # Shuffle the class labels to ensure random ordering
            np.random.shuffle(class_labels)

            total_images = len(class_labels)
            padding_size = world_size * batch_size - (total_images % (world_size * batch_size))
            class_labels = np.pad(class_labels, (0, padding_size), 'constant')
            padded_total_images = len(class_labels)

            # Distribute workload across GPUs
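            # Illustrative arithmetic (assuming 8 GPUs with the default num_test_images=50000
            # and test_bs=100): 50000 % (8*100) = 400, so 400 padding labels (class 0) are
            # appended, padded_total_images = 50400, and each rank generates 6300 images in
            # 63 steps; the padded samples are dropped later via `samples_in_batch` before saving.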
            images_per_gpu = padded_total_images // world_size
            start_idx = local_rank * images_per_gpu
            end_idx = min(start_idx + images_per_gpu, padded_total_images)
            local_class_labels = class_labels[start_idx:end_idx]
            local_num_steps = len(local_class_labels) // batch_size

            if self.accelerator.is_main_process:
                print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")

            used_time = 0
            gen_img_cnt = 0

            for i in range(local_num_steps):
                if self.accelerator.is_main_process and i % 10 == 0:
                    print(f"Generation step {i}/{local_num_steps}")

                # Get the labels for the current batch and convert them to tensors
                batch_start = i * batch_size
                batch_end = batch_start + batch_size
                labels = local_class_labels[batch_start:batch_end]
                labels = torch.tensor(labels, device=self.device)

                # Generate images
                self.accelerator.wait_for_everyone()
                start_time = time.time()
                with torch.no_grad():
                    with self.accelerator.autocast():
                        slots = generate_fn(unwrapped_gpt_model, labels, self.num_slots_to_gen, cfg_scale=self.cfg,
                                            diff_cfg=self.diff_cfg, cfg_schedule=self.cfg_schedule,
                                            diff_cfg_schedule=self.diff_cfg_schedule, temperature=self.temperature)
                        if self.num_slots_to_gen < self.num_slots:
                            null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                            null_slots = null_slots[:, self.num_slots_to_gen:, :]
                            slots = torch.cat([slots, null_slots], dim=1)
                        imgs = self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)

                samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)

                # Update timing stats
                used_time += time.time() - start_time
                gen_img_cnt += samples_in_batch
                if self.accelerator.is_main_process and i % 10 == 0:
                    print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")

                gathered_imgs = concat_all_gather(imgs)
                gathered_imgs = gathered_imgs[:samples_in_batch]

                # Save images (only on main process)
                if self.accelerator.is_main_process:
                    gen_imgs = gathered_imgs.detach().cpu()
                    save_paths = [
                        os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
                        for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
                    ]
                    save_img_batch(gen_imgs, save_paths)

            # Calculate metrics (only on main process)
            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                generated_files = len(os.listdir(save_folder))
                print(f"Generated {generated_files} images out of {total_images} expected")

                metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
                fid = metrics_dict["frechet_inception_distance"]
                inception_score = metrics_dict["inception_score_mean"]

                metric_prefix = "fid_ema" if use_ema else "fid"
                isc_prefix = "isc_ema" if use_ema else "isc"
                self.accelerator.log({
                    metric_prefix: fid,
                    isc_prefix: inception_score,
                    "gpt_cfg": self.cfg,
                    "ae_cfg": self.ae_cfg,
                    "diff_cfg": self.diff_cfg,
                    "cfg_schedule": self.cfg_schedule,
                    "diff_cfg_schedule": self.diff_cfg_schedule,
                    "temperature": self.temperature,
                    "num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
                }, step=self.steps)

                # Print comprehensive CFG information
                cfg_info = (
                    f"{'EMA ' if use_ema else ''}CFG params: "
                    f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, diff_cfg={self.diff_cfg}, "
                    f"cfg_schedule={self.cfg_schedule}, diff_cfg_schedule={self.diff_cfg_schedule}, "
                    f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
                    f"temperature={self.temperature}"
                )
                print(cfg_info)
                print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")

                # Cleanup
                shutil.rmtree(save_folder)

        # back to no ema
        if use_ema:
            if self.accelerator.is_main_process:
print("Switch back from ema") model_without_ddp.load_state_dict(model_state_dict) self.gpt_model.train()