# Source: old_tok/paintmind/engine/gpt_trainer.py
# Uploaded by tennant (commit af7c0ce, message "upload"), 37.3 kB.
import os, torch
import os.path as osp
import cv2
import shutil
import numpy as np
import copy
import torch_fidelity
import torch.nn as nn
from tqdm.auto import tqdm
from collections import OrderedDict
from einops import rearrange
from accelerate import Accelerator
from .util import instantiate_from_config
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader, random_split, DistributedSampler, Sampler
from paintmind.utils.lr_scheduler import build_scheduler
from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
from torch.optim import AdamW
from concurrent.futures import ThreadPoolExecutor
from paintmind.stage2.gpt import GPT_models
from paintmind.stage2.causaldit import CausalDiT_models
from paintmind.stage2.generate import generate, generate_causal_dit
from pathlib import Path
import time
def requires_grad(model, flag=True):
    """Enable or disable gradient tracking for every parameter of ``model``.

    Args:
        model: a ``torch.nn.Module``.
        flag: value assigned to each parameter's ``requires_grad`` attribute.
    """
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad_(flag)
def save_img(img, save_path):
    """Write a single channels-first image tensor to ``save_path`` with OpenCV.

    The tensor is converted to float, moved to channels-last layout, scaled by 255,
    clipped to [0, 255] and cast to uint8. The channel axis is then reversed
    (RGB -> BGR, the order ``cv2.imwrite`` expects).
    """
    hwc = img.float().numpy().transpose([1, 2, 0])
    pixels = np.clip(hwc * 255, 0, 255).astype(np.uint8)
    bgr = pixels[:, :, ::-1]
    cv2.imwrite(save_path, bgr)
def save_img_batch(imgs, save_paths):
    """Save a batch of channels-first image tensors to disk in parallel.

    Args:
        imgs: image tensor laid out as ``(N, C, H, W)``. Values are scaled by 255 and
            clipped to [0, 255], so they are expected in [0, 1]. Channels are flipped
            RGB -> BGR because ``cv2.imwrite`` expects BGR.
        save_paths: one output path per image, in the same order as ``imgs``.

    Raises:
        ValueError: if the number of paths differs from the number of images.
        OSError: if OpenCV reports that any image could not be written.
    """
    save_paths = list(save_paths)
    if len(save_paths) != len(imgs):
        # zip() would otherwise silently drop the extra images or paths.
        raise ValueError(
            f"Got {len(imgs)} images but {len(save_paths)} save paths"
        )
    if not save_paths:
        return

    # Convert the whole batch at once: (N, C, H, W) float -> (N, H, W, C) uint8, BGR.
    arrays = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    arrays = arrays[:, :, :, ::-1]

    # Writing files is I/O-bound and OpenCV's Python bindings release the GIL during
    # calls, so a thread pool gives real parallelism here without process overhead.
    with ThreadPoolExecutor(max_workers=32) as pool:
        futures = [pool.submit(cv2.imwrite, path, arr) for path, arr in zip(save_paths, arrays)]
        # future.result() re-raises exceptions, but cv2.imwrite reports most
        # failures (bad extension, unwritable directory) by returning False.
        failed = [path for path, fut in zip(save_paths, futures) if not fut.result()]
    if failed:
        raise OSError(f"cv2.imwrite failed for {len(failed)} image(s), e.g. {failed[0]}")
def get_fid_stats(real_dir, rec_dir, fid_stats):
    """Compute Inception Score and FID with torch-fidelity on CUDA.

    Args:
        real_dir: passed as ``input1`` (``evaluate`` passes its folder of generated images).
        rec_dir: passed as ``input2``; ``evaluate`` passes None and relies on ``fid_stats``.
        fid_stats: path given as ``fid_statistics_file``.

    Returns:
        The metrics dict returned by ``torch_fidelity.calculate_metrics``.
    """
    enabled_metrics = dict(isc=True, fid=True, kid=False, prc=False)
    return torch_fidelity.calculate_metrics(
        input1=real_dir,
        input2=rec_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        verbose=False,
        **enabled_metrics,
    )
class EMAModel:
    """Exponential moving average (EMA) of a model's trainable parameters.

    Only parameters with ``requires_grad=True`` are tracked. All EMA tensors are kept
    on ``device``. ``state_dict`` / ``load_state_dict`` let the object be registered
    with ``Accelerator.register_for_checkpointing``.
    """

    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        self.ema_params = OrderedDict(
            (name, param.detach().clone().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        """Blend ``model``'s trainable parameters into the EMA in place.

        ``ema = decay * ema + (1 - decay) * param`` via ``lerp_``. A trainable
        parameter that is not tracked yet (e.g. unfrozen after construction) starts
        from a copy of its current value.
        """
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            ema_param = self.ema_params.get(name)
            if ema_param is None:
                # Keep late-added entries on the same device as all other EMA tensors.
                self.ema_params[name] = param.detach().clone().to(self.device)
            else:
                ema_param.lerp_(param.data, 1 - self.decay)

    def state_dict(self):
        """Return the live mapping of parameter name -> EMA tensor."""
        return self.ema_params

    def load_state_dict(self, params):
        """Replace the EMA with copies of ``params`` moved to ``self.device``."""
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )
class CacheDataLoader:
    """DataLoader-like iterator over cached ``(slots, targets)`` arrays.

    Each epoch draws ``num_samples`` row indices *with replacement* from all cached
    rows (every augmentation of every image). ``num_samples`` is the number of
    original images, ``len(slots) // num_augs``. Only full batches are yielded, so
    iteration produces exactly ``len(self)`` batches. This matches the
    ``drop_last=True`` training DataLoader and keeps step counting consistent.
    Sampling is deterministic for a given ``(seed, epoch)``.
    """

    def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
        self.slots = slots
        self.targets = targets
        self.batch_size = batch_size
        self.num_augs = num_augs
        self.seed = seed
        self.epoch = 0
        # Original dataset size (before augmentations)
        self.num_samples = len(slots) // num_augs

    def set_epoch(self, epoch):
        """Set epoch for deterministic shuffling."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of full batches per epoch."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Yield ``(slots, targets)`` tensor batches for the current epoch.

        ``targets`` is None when the loader was built without targets.
        """
        num_batches = len(self)
        if num_batches == 0:
            return
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)
        # Randomly sample indices from the entire augmented dataset
        indices = torch.randint(
            0, len(self.slots),
            (self.num_samples,),
            generator=g
        ).numpy()
        for batch_idx in range(num_batches):
            start = batch_idx * self.batch_size
            batch_indices = indices[start:start + self.batch_size]
            slots = torch.from_numpy(self.slots[batch_indices])
            targets = None if self.targets is None else torch.from_numpy(self.targets[batch_indices])
            yield slots, targets
class GPTTrainer(nn.Module):
    """Trainer for a class-conditional slot generator on top of a frozen autoencoder.

    ``ae_model`` is frozen. ``encode_slots`` turns images into slot latents, and
    ``sample`` decodes generated slots back into images. ``gpt_model`` is a ``GPT``
    or ``CausalDiT`` variant trained on those slots with class labels as
    conditions. Slots can be pre-computed once per run and kept in per-rank
    memory-mapped files (``enable_cache_latents``).

    ``evaluate`` saves a sample grid. When ``eval_fid`` is set, it also generates
    a class-balanced image set and logs FID / Inception Score. Distributed setup,
    mixed precision, logging and checkpointing go through ``accelerate.Accelerator``.
    """

    def __init__(
        self,
        ae_model,
        gpt_model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_test_images=50000,
        num_epoch=400,
        # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
        eval_classes=(1, 7, 282, 604, 724, 207, 250, 751, 404, 850),
        lr=None,
        blr=1e-4,
        cosine_lr=False,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        cache_bs=8,
        test_bs=100,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.75,
        ae_cfg=1.5,
        diff_cfg=2.0,
        temperature=1.0,
        cfg_schedule="constant",
        diff_cfg_schedule="inv_linear",
        train_num_slots=None,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
        enable_cache_latents=True,
        cache_dir='/dev/shm/slot_cache',
        seed=42,
    ):
        """Build models, data, optimizer and scheduler, then optionally resume.

        Args (selected):
            ae_model / gpt_model / dataset: configs with ``target`` and ``params``.
            batch_size: per-process training batch size.
            cache_bs: batch size used only while filling the latent cache.
            eval_classes: class ids shown in the periodic sample grid.
            seed: seeds the cache sampler (offset by rank) and the FID label shuffle.
        """
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        # Frozen autoencoder; prefer its EMA checkpoint when one is configured.
        self.ae_model = instantiate_from_config(ae_model)
        if hasattr(ae_model.params, "ema_path") and ae_model.params.ema_path is not None:
            ae_model_path = ae_model.params.ema_path
        else:
            ae_model_path = ae_model.params.ckpt_path
        assert ae_model_path.endswith((".safetensors", ".pt", ".pth", ".pkl"))
        assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
        self._load_checkpoint(ae_model_path, self.ae_model)
        self.ae_model.to(self.device)
        for param in self.ae_model.parameters():
            param.requires_grad = False
        self.ae_model.eval()

        self.model_name = gpt_model.target
        if 'GPT' in gpt_model.target:
            self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
        elif 'CausalDiT' in gpt_model.target:
            self.gpt_model = CausalDiT_models[gpt_model.target](**gpt_model.params)
        else:
            raise ValueError(f"Unknown model type: {gpt_model.target}")
        self.num_slots = ae_model.params.num_slots
        self.slot_dim = ae_model.params.slot_dim

        assert precision in ["bf16", "fp32"]
        # Weights stay fp32; bf16 comes from the Accelerator's mixed-precision autocast.
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overlooking specified precision and using autocast bf16...")
        self.precision = precision
        self.test_only = test_only
        self.test_bs = test_bs
        self.num_test_images = num_test_images
        self.num_classes = gpt_model.params.num_classes
        self.batch_size = batch_size

        if not test_only:
            self.train_ds = instantiate_from_config(dataset)
            train_size = len(self.train_ds)
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}")
            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            self.train_dl = DataLoader(
                self.train_ds,
                # With latent caching, this loader is only used to fill the cache.
                batch_size=batch_size if not enable_cache_latents else cache_bs,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")
            self.g_optim = self._creat_optimizer(weight_decay=0.05, learning_rate=lr, betas=(0.9, 0.95))
            if enable_cache_latents:
                # Training then iterates a CacheDataLoader over the cached samples
                # (len(train_dl) * cache_bs per rank) in batches of `batch_size`,
                # not train_dl itself, so the schedule must use that epoch length.
                steps_per_epoch = (len(self.train_dl) * cache_bs) // batch_size
            else:
                steps_per_epoch = len(self.train_dl)
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps,
                steps_per_epoch=steps_per_epoch,
            )
            self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        # Prepare everything together
        if not test_only:
            self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.gpt_model, self.g_optim, self.g_sched
            )
        else:
            self.gpt_model = self.accelerator.prepare(self.gpt_model)

        # assume _ori_model does not exist in checkpoints
        if compile:
            _model = self.accelerator.unwrap_model(self.gpt_model)
            self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
            # NOTE(review): the compiled GPT wrapper is bound to a local that is never
            # used, so only the autoencoder actually runs compiled. Left unchanged:
            # routing training through it would presumably add `_orig_mod.` to
            # checkpoint keys.
            _model = torch.compile(_model, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)
        self._load_checkpoint(gpt_model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm
        self.eval_classes = list(eval_classes)
        self.cfg = cfg
        self.ae_cfg = ae_cfg
        self.diff_cfg = diff_cfg
        self.cfg_schedule = cfg_schedule
        self.diff_cfg_schedule = diff_cfg_schedule
        self.temperature = temperature
        self.train_num_slots = train_num_slots
        self.test_num_slots = test_num_slots
        if self.train_num_slots is not None:
            self.train_num_slots = min(self.train_num_slots, self.num_slots)
        else:
            self.train_num_slots = self.num_slots
        if self.test_num_slots is not None:
            self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
        else:
            self.num_slots_to_gen = self.train_num_slots

        self.eval_fid = eval_fid
        if eval_fid:
            assert fid_stats is not None
        self.fid_stats = fid_stats

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)
        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

        self.cache_dir = Path(cache_dir)
        self.enable_cache_latents = enable_cache_latents
        self.seed = seed
        self.cache_loader = None

    @property
    def device(self):
        """Device assigned to this process by the Accelerator."""
        return self.accelerator.device

    def _creat_optimizer(self, weight_decay, learning_rate, betas):
        """Create AdamW with weight decay applied only to parameters with dim >= 2."""
        # start with all of the candidate parameters, keeping those that require grad
        param_dict = {pn: p for pn, p in self.gpt_model.named_parameters() if p.requires_grad}
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if self.accelerator.is_main_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        return AdamW(optim_groups, lr=learning_rate, betas=betas)

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min,
                          warmup_lr_init, decay_steps, steps_per_epoch=None):
        """Build the LR scheduler via ``build_scheduler``.

        Args:
            steps_per_epoch: training iterations per epoch; defaults to
                ``len(self.train_dl)``. It converts ``warmup_epochs`` to steps and
                is passed to ``build_scheduler`` as the per-epoch step count.
        """
        if steps_per_epoch is None:
            steps_per_epoch = len(self.train_dl)
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * steps_per_epoch
        else:
            assert warmup_steps is not None
        return build_scheduler(
            self.g_optim,
            num_epoch,
            steps_per_epoch,
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,  # if not cosine_lr, then use step_lr (warmup, then fix)
        )

    def _load_state_dict(self, state_dict, model):
        """Load ``state_dict`` into ``model`` non-strictly, stripping ``_orig_mod.`` prefixes."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Remove '_orig_mod' prefix (added by torch.compile) if present
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path, model):
        """Load a ``.safetensors`` file (on CPU) into ``model``."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict, model)

    def _load_checkpoint(self, ckpt_path=None, model=None):
        """Load weights (and, when training, full accelerator state) from ``ckpt_path``.

        ``ckpt_path`` is either an ``accelerator.save_state`` directory named
        ``step<N>`` or a single weight file. For a directory, ``self.loaded_steps``
        is set to ``N``. A file is read as ``.safetensors`` or via ``torch.load``.
        None or non-existent paths are ignored. ``model`` defaults to the unwrapped
        generator.
        """
        if ckpt_path is None or not osp.exists(ckpt_path):
            return
        if model is None:
            model = self.accelerator.unwrap_model(self.gpt_model)

        if osp.isdir(ckpt_path):
            # ckpt_path is something like 'path/to/models/step10/'
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                # Restores model, optimizer and all objects registered for checkpointing.
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    # During training g_sched (0) then ema_model (1) are registered
                    # for checkpointing, so custom_checkpoint_1 holds the EMA weights.
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict, model)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path, model)
                if not osp.exists(model_path) and self.accelerator.is_main_process:
                    print(f"Warning: {model_path} not found; model weights were not loaded")
        else:
            # ckpt_path is something like 'path/to/models/step10.pt'
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path, model)
            else:
                state_dict = torch.load(ckpt_path, map_location="cpu")
                self._load_state_dict(state_dict, model)

        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _cache_paths(self):
        """Return this rank's ``(slots_file, targets_file)`` cache paths."""
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes
        return (
            self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap",
            self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap",
        )

    def _build_cache(self):
        """Encode this rank's training shard once and store slots/targets in memmaps.

        The files hold exactly the samples this rank's loader yields in one pass:
        ``len(train_dl) * batch_size`` images (the loader uses ``drop_last=True``),
        times the number of augmentations per image. That count can differ from
        ``len(dataset) // world_size`` because ``DistributedSampler`` pads shards.
        Sizing by the loader avoids both unwritten all-zero rows and
        out-of-bounds writes.
        """
        rank = self.accelerator.process_index
        slots_file, targets_file = self._cache_paths()
        # Clean up any existing cache files first
        for path in (slots_file, targets_file):
            if path.exists():
                os.remove(path)

        # Detect number of augmentations from first batch (5-D: [B, num_augs, C, H, W])
        with torch.no_grad():
            img, _ = next(iter(self.train_dl))
        num_augs = img.shape[1] if img.dim() == 5 else 1
        print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
        os.makedirs(self.cache_dir, exist_ok=True)

        loader_bs = self.train_dl.batch_size
        num_rows = len(self.train_dl) * loader_bs * num_augs
        slots_shape = (num_rows, self.train_num_slots, self.slot_dim)
        slots_mmap = np.memmap(slots_file, dtype='float32', mode='w+', shape=slots_shape)
        targets_mmap = np.memmap(targets_file, dtype='int64', mode='w+', shape=(num_rows,))

        with torch.no_grad():
            for i, (imgs, targets) in enumerate(tqdm(
                self.train_dl,
                desc=f"Rank {rank}: Caching data",
                disable=not self.accelerator.is_local_main_process,
            )):
                if imgs.dim() == 5:  # [B, num_augs, C, H, W]
                    B, A, C, H, W = imgs.shape
                    imgs = imgs.reshape(-1, C, H, W)  # [B*num_augs, C, H, W]
                    targets = targets.unsqueeze(1).expand(-1, A).reshape(-1)  # [B*num_augs]
                # Encode in num_augs chunks of one un-augmented batch each
                # (presumably to bound peak GPU memory).
                split_size = imgs.shape[0] // num_augs
                start_idx = i * loader_bs * num_augs
                splits = zip(torch.split(imgs, split_size), torch.split(targets, split_size))
                for split_idx, (img_split, targets_split) in enumerate(splits):
                    img_split = img_split.to(self.device, non_blocking=True)
                    slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]
                    split_start = start_idx + split_idx * split_size
                    split_end = split_start + img_split.shape[0]
                    # .float(): numpy has no bf16, and the memmap is float32 anyway.
                    slots_mmap[split_start:split_end] = slots_split.float().cpu().numpy()
                    targets_mmap[split_start:split_end] = targets_split.numpy()

        # Flush and close the writable maps, then reopen read-only.
        slots_mmap.flush()
        targets_mmap.flush()
        del slots_mmap, targets_mmap
        self.cached_latents = np.memmap(slots_file, dtype='float32', mode='r', shape=slots_shape)
        self.cached_targets = np.memmap(targets_file, dtype='int64', mode='r', shape=(num_rows,))
        # Store the number of augmentations for the cache loader
        self.num_augs = num_augs

    def _setup_cache(self):
        """Build the latent cache on every rank and create ``self.cache_loader``."""
        self._build_cache()
        self.accelerator.wait_for_everyone()
        if self.cached_latents is not None:
            self.cache_loader = CacheDataLoader(
                slots=self.cached_latents,
                targets=self.cached_targets,
                batch_size=self.batch_size,
                num_augs=self.num_augs,
                seed=self.seed + self.accelerator.process_index
            )

    def __del__(self):
        """Remove this rank's cache files; a no-op if ``__init__`` did not finish."""
        try:
            if not self.enable_cache_latents:
                return
            paths = self._cache_paths()
        except AttributeError:
            # __init__ raised before the attributes used here were assigned.
            return
        for path in paths:
            if path.exists():
                os.remove(path)

    def _train_step(self, slots, targets=None):
        """Run forward/backward/optimizer step for one (micro-)batch and return the loss."""
        with self.accelerator.accumulate(self.gpt_model):
            with self.accelerator.autocast():
                loss = self.gpt_model(slots, targets)
            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.g_optim.step()
            if self.g_sched is not None:
                self.g_sched.step_update(self.steps)
            self.g_optim.zero_grad()
            # Update the EMA only when the optimizer really stepped; with gradient
            # accumulation, intermediate micro-steps leave the weights unchanged.
            if self.enable_ema and self.accelerator.sync_gradients:
                self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))
        return loss

    def _train_epoch_cached(self, epoch, logger):
        """Train one epoch using cached slots."""
        self.cache_loader.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'
        for batch in logger.log_every(self.cache_loader, 20, header):
            slots, targets = (b.to(self.device, non_blocking=True) for b in batch)
            self.steps += 1
            if self.steps == 1:
                print(f"Training batch size: {len(slots)}")
                print(f"Hello from index {self.accelerator.local_process_index}")
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _train_epoch_uncached(self, epoch, logger):
        """Train one epoch, encoding raw images to slots on the fly."""
        # Without this, DistributedSampler reuses the same permutation every epoch.
        self.train_dl.sampler.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'
        for batch in logger.log_every(self.train_dl, 20, header):
            img, targets = (b.to(self.device, non_blocking=True) for b in batch)
            self.steps += 1
            if self.steps == 1:
                print(f"Training batch size: {img.size(0)}")
                print(f"Hello from index {self.accelerator.local_process_index}")
            slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _handle_periodic_ops(self, loss, logger):
        """Log loss/lr and run periodic checkpointing and evaluation."""
        logger.update(loss=loss.item())
        logger.update(lr=self.g_optim.param_groups[0]["lr"])
        if self.steps % self.save_every == 0:
            self.save()
        if (self.steps % self.samp_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()

    def _save_config(self, config):
        """Copy (path) or serialize (OmegaConf object) ``config`` to ``result_folder/config.yaml``."""
        if config is not None and self.accelerator.is_main_process:
            from omegaconf import OmegaConf
            config_save_path = osp.join(self.result_folder, "config.yaml")
            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, config_save_path)
            else:
                OmegaConf.save(config, config_save_path)

    def _should_skip_epoch(self, epoch):
        """Fast-forward ``self.steps`` past work covered by a loaded checkpoint.

        Returns True (advancing ``self.steps`` by one epoch) when the whole epoch
        precedes ``self.loaded_steps``. If the checkpoint lies inside this epoch,
        batches are drawn and discarded until ``self.steps`` reaches it, and False
        is returned.
        """
        loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
        if ((epoch + 1) * len(loader)) <= self.loaded_steps:
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
            self.steps += len(loader)
            return True
        if self.steps < self.loaded_steps:
            # NOTE(review): the caller then starts a fresh pass over the loader, so a
            # mid-epoch resume trains a full epoch on top of the skipped batches.
            for _ in loader:
                self.steps += 1
                if self.steps >= self.loaded_steps:
                    break
        return False

    def train(self, config=None):
        """Main training loop (or a single evaluation when ``test_only``)."""
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters//1e6}M")
        self._save_config(config)
        self.accelerator.init_trackers("gpt")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        if self.enable_cache_latents:
            self._setup_cache()

        for epoch in range(self.num_epoch):
            if self._should_skip_epoch(epoch):
                continue
            self.gpt_model.train()
            logger = MetricLogger(delimiter="  ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
            if self.enable_cache_latents:
                self._train_epoch_cached(epoch, logger)
            else:
                self._train_epoch_uncached(epoch, logger)

        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        """Save full accelerator state to ``models/step<steps>``."""
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @staticmethod
    def _fid_class_labels(num_classes, num_images, world_size, batch_size, rank, seed=None):
        """Build this rank's class labels for FID generation.

        Labels are class-balanced (``num_images // num_classes`` per class). They are
        shuffled with an RNG seeded by ``seed``, so every rank computes the same
        order regardless of global RNG state. They are then padded with class 0 to
        the next multiple of ``world_size * batch_size``.

        Labels are dealt out interleaved: at generation step ``i``, rank ``r`` gets
        chunk ``i * world_size + r``. If per-step batches are gathered in rank order,
        step ``i`` therefore covers a contiguous slice of the global list, and all
        padding sits at the tail of the final gathered batch.

        Returns:
            ``(local_labels, total_images, num_steps, images_per_class)``.
        """
        images_per_class = num_images // num_classes
        class_labels = np.repeat(np.arange(num_classes), images_per_class)
        np.random.default_rng(seed).shuffle(class_labels)
        total_images = len(class_labels)
        chunk = world_size * batch_size
        # Pad only up to the next multiple (no extra batch when already divisible).
        padding_size = (-total_images) % chunk
        class_labels = np.pad(class_labels, (0, padding_size), 'constant')
        num_steps = len(class_labels) // chunk
        local_labels = class_labels.reshape(num_steps, world_size, batch_size)[:, rank].reshape(-1)
        return local_labels, total_images, num_steps, images_per_class

    def _sample_tag(self):
        """Sampling-settings suffix shared by grid file and FID folder names."""
        return (
            f"aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}"
            f"_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}"
            f"_slots{self.num_slots_to_gen}_temp{self.temperature}"
        )

    def _generate_images(self, model, generate_fn, labels):
        """Generate slots for ``labels`` and decode them to images with the AE."""
        with self.accelerator.autocast():
            slots = generate_fn(model, labels, self.num_slots_to_gen,
                                cfg_scale=self.cfg, diff_cfg=self.diff_cfg,
                                cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule,
                                temperature=self.temperature)
            if self.num_slots_to_gen < self.num_slots:
                # Fill the slots that were not generated with the AE's null condition.
                null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                null_slots = null_slots[:, self.num_slots_to_gen:, :]
                slots = torch.cat([slots, null_slots], dim=1)
            return self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)  # targets are not used for now

    @torch.no_grad()
    def _swap_in_ema(self, model):
        """Copy EMA weights into ``model``'s parameters; return the overwritten originals.

        Parameters whose name contains ``nested_sampler``, or that the EMA does not
        track, keep their current values.
        """
        ema_params = self.ema_model.state_dict()
        backup = {}
        for name, param in model.named_parameters():
            if "nested_sampler" in name or name not in ema_params:
                continue
            backup[name] = param.detach().clone()
            param.copy_(ema_params[name])
        if self.accelerator.is_main_process:
            print("Switch to ema")
        return backup

    @torch.no_grad()
    def _restore_params(self, model, backup):
        """Copy the tensors in ``backup`` back into ``model``'s parameters."""
        params = dict(model.named_parameters())
        for name, value in backup.items():
            params[name].copy_(value)

    def _save_sample_grid(self, model, generate_fn):
        """Generate one image per eval class on every rank and save a grid (main process)."""
        classes = torch.tensor(self.eval_classes, device=self.device)
        imgs = self._generate_images(model, generate_fn, classes)
        imgs = concat_all_gather(imgs)
        if self.accelerator.num_processes > 16:
            imgs = imgs[:16 * len(self.eval_classes)]
        imgs = imgs.detach().cpu()
        grid = make_grid(
            imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
        )
        if self.accelerator.is_main_process:
            save_image(
                grid,
                os.path.join(self.image_saved_dir, f"step{self.steps}_{self._sample_tag()}.jpg"),
            )

    def _evaluate_fid(self, model, generate_fn, use_ema):
        """Generate a class-balanced image set across ranks, then log FID / IS (main process)."""
        save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_{self._sample_tag()}")
        if self.accelerator.is_main_process:
            os.makedirs(save_folder, exist_ok=True)

        world_size = self.accelerator.num_processes
        batch_size = self.test_bs
        local_class_labels, total_images, local_num_steps, images_per_class = self._fid_class_labels(
            self.num_classes, self.num_test_images, world_size, batch_size,
            self.accelerator.process_index, seed=self.seed,
        )
        if self.accelerator.is_main_process:
            print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")

        used_time = 0
        gen_img_cnt = 0
        for i in range(local_num_steps):
            if self.accelerator.is_main_process and i % 10 == 0:
                print(f"Generation step {i}/{local_num_steps}")
            labels = torch.tensor(
                local_class_labels[i * batch_size:(i + 1) * batch_size], device=self.device
            )
            self.accelerator.wait_for_everyone()
            start_time = time.time()
            imgs = self._generate_images(model, generate_fn, labels)
            # Only the last gathered batch can contain padding, and only at its tail;
            # this assumes concat_all_gather concatenates in rank order.
            samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)
            used_time += time.time() - start_time
            gen_img_cnt += samples_in_batch
            if self.accelerator.is_main_process and i % 10 == 0:
                print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")

            gathered_imgs = concat_all_gather(imgs)[:samples_in_batch]
            if self.accelerator.is_main_process:
                save_paths = [
                    os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
                    for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
                ]
                save_img_batch(gathered_imgs.detach().cpu(), save_paths)

        self.accelerator.wait_for_everyone()
        if self.accelerator.is_main_process:
            generated_files = len(os.listdir(save_folder))
            print(f"Generated {generated_files} images out of {total_images} expected")
            metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
            fid = metrics_dict["frechet_inception_distance"]
            inception_score = metrics_dict["inception_score_mean"]
            metric_prefix = "fid_ema" if use_ema else "fid"
            isc_prefix = "isc_ema" if use_ema else "isc"
            num_slots = self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
            self.accelerator.log({
                metric_prefix: fid,
                isc_prefix: inception_score,
                "gpt_cfg": self.cfg,
                "ae_cfg": self.ae_cfg,
                "diff_cfg": self.diff_cfg,
                "cfg_schedule": self.cfg_schedule,
                "diff_cfg_schedule": self.diff_cfg_schedule,
                "temperature": self.temperature,
                "num_slots": num_slots
            }, step=self.steps)
            cfg_info = (
                f"{'EMA ' if use_ema else ''}CFG params: "
                f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, diff_cfg={self.diff_cfg}, "
                f"cfg_schedule={self.cfg_schedule}, diff_cfg_schedule={self.diff_cfg_schedule}, "
                f"num_slots={num_slots}, "
                f"temperature={self.temperature}"
            )
            print(cfg_info)
            print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")
            shutil.rmtree(save_folder)

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        """Save a sample grid (when training) and, when due, compute FID / IS.

        EMA weights are swapped in only when ``use_ema``, ``enable_ema`` and
        ``eval_fid`` are all set and not in ``test_only`` mode. In test mode the EMA
        weights were already loaded as the model. The original weights are
        restored, and the model is put back in train mode, even if generation
        raises.
        """
        self.gpt_model.eval()
        unwraped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)

        # switch to ema params, only when eval_fid is True
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        backup = {}
        if use_ema:
            if hasattr(self, "ema_model"):
                backup = self._swap_in_ema(unwraped_gpt_model)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        try:
            generate_fn = generate if 'GPT' in self.model_name else generate_causal_dit
            if not self.test_only:
                self._save_sample_grid(unwraped_gpt_model, generate_fn)
            if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
                self._evaluate_fid(unwraped_gpt_model, generate_fn, use_ema)
        finally:
            # back to no ema
            if use_ema:
                if self.accelerator.is_main_process:
                    print("Switch back from ema")
                self._restore_params(unwraped_gpt_model, backup)
            self.gpt_model.train()