import os
import os.path as osp
import copy
import shutil
import time
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch_fidelity
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from torch.optim import AdamW
from torch.utils.data import DataLoader, DistributedSampler
from torchvision.utils import make_grid, save_image
from tqdm.auto import tqdm

from .util import instantiate_from_config
from paintmind.engine.misc import concat_all_gather
from paintmind.stage2.causaldit import CausalDiT_models
from paintmind.stage2.generate import generate, generate_causal_dit
from paintmind.stage2.gpt import GPT_models
from paintmind.utils.logger import SmoothedValue, MetricLogger, empty_cache
from paintmind.utils.lr_scheduler import build_scheduler

def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def save_img(img, save_path):
    # CHW float tensor in [0, 1] -> HWC uint8; RGB -> BGR for OpenCV.
    img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
    img = img.astype(np.uint8)[:, :, ::-1]
    cv2.imwrite(save_path, img)


def save_img_batch(imgs, save_paths):
    """Process and save multiple images at once using a thread pool."""
    # NCHW float tensor in [0, 1] -> NHWC uint8; RGB -> BGR for OpenCV.
    imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
    imgs = imgs[:, :, :, ::-1]

    with ThreadPoolExecutor(max_workers=32) as pool:
        futures = [pool.submit(cv2.imwrite, path, img)
                   for path, img in zip(save_paths, imgs)]
        # Surface any exception raised inside a worker thread.
        for future in futures:
            future.result()


def get_fid_stats(real_dir, rec_dir, fid_stats):
    stats = torch_fidelity.calculate_metrics(
        input1=real_dir,
        input2=rec_dir,
        fid_statistics_file=fid_stats,
        cuda=True,
        isc=True,
        fid=True,
        kid=False,
        prc=False,
        verbose=False,
    )
    return stats
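
# Illustrative call (a sketch; the directory and "fid_stats.npz" file names are
# assumptions for the example):
#
#   stats = get_fid_stats("results/images/gen_step10000", None, "fid_stats.npz")
#   print(stats["frechet_inception_distance"], stats["inception_score_mean"])
#
# With input2=None, torch_fidelity scores input1 against the reference
# statistics file alone, which is how GPTTrainer.evaluate() below uses it.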


class EMAModel:
    """Model Exponential Moving Average."""

    def __init__(self, model, device, decay=0.999):
        self.device = device
        self.decay = decay
        # Track only trainable parameters; buffers are left untouched.
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(device))
            for name, param in model.named_parameters()
            if param.requires_grad
        )

    @torch.no_grad()
    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                if name in self.ema_params:
                    # ema <- decay * ema + (1 - decay) * param
                    self.ema_params[name].lerp_(param.data, 1 - self.decay)
                else:
                    self.ema_params[name] = param.data.clone().detach()

    def state_dict(self):
        return self.ema_params

    def load_state_dict(self, params):
        self.ema_params = OrderedDict(
            (name, param.clone().detach().to(self.device))
            for name, param in params.items()
        )
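
# Illustrative usage (a sketch, not executed here; `model`, `loader`, and the
# optimizer step are assumptions for the example):
#
#   ema = EMAModel(model, device="cuda", decay=0.999)
#   for batch in loader:
#       loss = model(batch)
#       # ...backward and optimizer step...
#       ema.update(model)  # refresh the average once per optimizer step
#
# At eval time, GPTTrainer.evaluate() swaps ema.state_dict() into the model
# and restores the online weights afterwards.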


class CacheDataLoader:
    """DataLoader-like interface for cached data with epoch-based shuffling."""

    def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
        self.slots = slots
        self.targets = targets
        self.batch_size = batch_size
        self.num_augs = num_augs
        self.seed = seed
        self.epoch = 0
        # One epoch covers the original dataset size, not every augmented copy.
        self.num_samples = len(slots) // num_augs

    def set_epoch(self, epoch):
        """Set epoch for deterministic shuffling."""
        self.epoch = epoch

    def __len__(self):
        """Return number of batches based on original dataset size."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Yield (slots, targets) batches in a random order for the current epoch."""
        g = torch.Generator()
        g.manual_seed((self.seed + self.epoch) if self.seed is not None else self.epoch)
        # Sample with replacement across all augmented copies, so each epoch
        # sees a random mix of augmentations.
        indices = torch.randint(
            0, len(self.slots),
            (self.num_samples,),
            generator=g
        ).numpy()
        for start in range(0, self.num_samples, self.batch_size):
            end = min(start + self.batch_size, self.num_samples)
            batch_indices = indices[start:end]
            yield (
                torch.from_numpy(self.slots[batch_indices]),
                torch.from_numpy(self.targets[batch_indices])
            )
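
# Illustrative usage (a sketch; `slots_mmap` and `targets_mmap` stand for the
# per-rank memmaps produced by GPTTrainer._build_cache below):
#
#   loader = CacheDataLoader(slots_mmap, targets_mmap, batch_size=32,
#                            num_augs=4, seed=42)
#   for epoch in range(num_epochs):
#       loader.set_epoch(epoch)  # deterministic, epoch-dependent shuffling
#       for slots, targets in loader:
#           ...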


class GPTTrainer(nn.Module):
    def __init__(
        self,
        ae_model,
        gpt_model,
        dataset,
        test_dataset=None,
        test_only=False,
        num_test_images=50000,
        num_epoch=400,
        eval_classes=(1, 7, 282, 604, 724, 207, 250, 751, 404, 850),
        lr=None,
        blr=1e-4,
        cosine_lr=False,
        lr_min=0,
        warmup_epochs=100,
        warmup_steps=None,
        warmup_lr_init=0,
        decay_steps=None,
        batch_size=32,
        cache_bs=8,
        test_bs=100,
        num_workers=0,
        pin_memory=False,
        max_grad_norm=None,
        grad_accum_steps=1,
        precision="bf16",
        save_every=10000,
        sample_every=1000,
        fid_every=50000,
        result_folder=None,
        log_dir="./log",
        steps=0,
        cfg=1.75,
        ae_cfg=1.5,
        diff_cfg=2.0,
        temperature=1.0,
        cfg_schedule="constant",
        diff_cfg_schedule="inv_linear",
        train_num_slots=None,
        test_num_slots=None,
        eval_fid=False,
        fid_stats=None,
        enable_ema=False,
        compile=False,
        enable_cache_latents=True,
        cache_dir='/dev/shm/slot_cache',
        seed=42
    ):
        super().__init__()
        kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
        self.accelerator = Accelerator(
            kwargs_handlers=[kwargs],
            mixed_precision="bf16",
            gradient_accumulation_steps=grad_accum_steps,
            log_with="tensorboard",
            project_dir=log_dir,
        )

        self.ae_model = instantiate_from_config(ae_model)
        if hasattr(ae_model.params, "ema_path") and ae_model.params.ema_path is not None:
            ae_model_path = ae_model.params.ema_path
        else:
            ae_model_path = ae_model.params.ckpt_path
        assert ae_model_path.endswith((".safetensors", ".pt", ".pth", ".pkl")), \
            f"Unsupported AE checkpoint format: {ae_model_path}"
        assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
        self._load_checkpoint(ae_model_path, self.ae_model)

        # The autoencoder is frozen; only the generative model is trained.
        self.ae_model.to(self.device)
        for param in self.ae_model.parameters():
            param.requires_grad = False
        self.ae_model.eval()

        self.model_name = gpt_model.target
        if 'GPT' in gpt_model.target:
            self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
        elif 'CausalDiT' in gpt_model.target:
            self.gpt_model = CausalDiT_models[gpt_model.target](**gpt_model.params)
        else:
            raise ValueError(f"Unknown model type: {gpt_model.target}")
        self.num_slots = ae_model.params.num_slots
        self.slot_dim = ae_model.params.slot_dim

        # Parameters are kept in fp32; mixed precision comes from autocast bf16.
        assert precision in ["bf16", "fp32"]
        precision = "fp32"
        if self.accelerator.is_main_process:
            print("Overriding specified precision and using autocast bf16...")
        self.precision = precision

        self.test_only = test_only
        self.test_bs = test_bs
        self.num_test_images = num_test_images
        self.num_classes = gpt_model.params.num_classes

        self.batch_size = batch_size
        if not test_only:
            self.train_ds = instantiate_from_config(dataset)
            train_size = len(self.train_ds)
            if self.accelerator.is_main_process:
                print(f"train dataset size: {train_size}")

            sampler = DistributedSampler(
                self.train_ds,
                num_replicas=self.accelerator.num_processes,
                rank=self.accelerator.process_index,
                shuffle=True,
            )
            # When latents are cached, this dataloader is only used to build
            # the cache, so it can run with a smaller batch size.
            self.train_dl = DataLoader(
                self.train_ds,
                batch_size=batch_size if not enable_cache_latents else cache_bs,
                sampler=sampler,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )

            effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
            if lr is None:
                # Linear scaling rule: lr = base lr * effective batch size / 256.
                lr = blr * effective_bs / 256
            if self.accelerator.is_main_process:
                print(f"Effective batch size is {effective_bs}")

            self.g_optim = self._create_optimizer(weight_decay=0.05, learning_rate=lr, betas=(0.9, 0.95))
            self.g_sched = self._create_scheduler(
                cosine_lr, warmup_epochs, warmup_steps, num_epoch,
                lr_min, warmup_lr_init, decay_steps
            )
            self.accelerator.register_for_checkpointing(self.g_sched)

        self.steps = steps
        self.loaded_steps = -1

        if not test_only:
            self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(
                self.gpt_model, self.g_optim, self.g_sched
            )
        else:
            self.gpt_model = self.accelerator.prepare(self.gpt_model)

        if compile:
            self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
            _model = self.accelerator.unwrap_model(self.gpt_model)
            # Compile the forward in place so the handle already prepared by
            # accelerate picks up the compiled version.
            _model.forward = torch.compile(_model.forward, mode="reduce-overhead")

        self.enable_ema = enable_ema
        if self.enable_ema and not self.test_only:
            self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
            self.accelerator.register_for_checkpointing(self.ema_model)

        self._load_checkpoint(gpt_model.params.ckpt_path)
        if self.test_only:
            self.steps = self.loaded_steps

        self.num_epoch = num_epoch
        self.save_every = save_every
        self.samp_every = sample_every
        self.fid_every = fid_every
        self.max_grad_norm = max_grad_norm

        self.eval_classes = eval_classes
        self.cfg = cfg
        self.ae_cfg = ae_cfg
        self.diff_cfg = diff_cfg
        self.cfg_schedule = cfg_schedule
        self.diff_cfg_schedule = diff_cfg_schedule
        self.temperature = temperature
        self.train_num_slots = train_num_slots
        self.test_num_slots = test_num_slots
        if self.train_num_slots is not None:
            self.train_num_slots = min(self.train_num_slots, self.num_slots)
        else:
            self.train_num_slots = self.num_slots
        if self.test_num_slots is not None:
            self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
        else:
            self.num_slots_to_gen = self.train_num_slots
        self.eval_fid = eval_fid
        if eval_fid:
            assert fid_stats is not None, "eval_fid=True requires a fid_stats file"
        self.fid_stats = fid_stats

        self.result_folder = result_folder
        self.model_saved_dir = os.path.join(result_folder, "models")
        os.makedirs(self.model_saved_dir, exist_ok=True)

        self.image_saved_dir = os.path.join(result_folder, "images")
        os.makedirs(self.image_saved_dir, exist_ok=True)

        self.cache_dir = Path(cache_dir)
        self.enable_cache_latents = enable_cache_latents
        self.seed = seed
        self.cache_loader = None

    @property
    def device(self):
        return self.accelerator.device

    def _create_optimizer(self, weight_decay, learning_rate, betas):
        # Only trainable parameters are optimized; weight decay is applied to
        # tensors with dim >= 2 (weights) but not to biases and norm scales.
        param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()}
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}

        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        if self.accelerator.is_main_process:
            print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
            print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer

    def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
        if warmup_epochs is not None:
            warmup_steps = warmup_epochs * len(self.train_dl)
        else:
            assert warmup_steps is not None

        scheduler = build_scheduler(
            self.g_optim,
            num_epoch,
            len(self.train_dl),
            lr_min,
            warmup_steps,
            warmup_lr_init,
            decay_steps,
            cosine_lr,
        )
        return scheduler

    def _load_state_dict(self, state_dict, model):
        """Helper to load a state dict with proper prefix handling."""
        if 'state_dict' in state_dict:
            state_dict = state_dict['state_dict']
        # Strip the prefix torch.compile adds to parameter names.
        state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        missing, unexpected = model.load_state_dict(
            state_dict, strict=False
        )
        if self.accelerator.is_main_process:
            print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")

    def _load_safetensors(self, path, model):
        """Helper to load a safetensors checkpoint."""
        from safetensors.torch import safe_open
        with safe_open(path, framework="pt", device="cpu") as f:
            state_dict = {k: f.get_tensor(k) for k in f.keys()}
        self._load_state_dict(state_dict, model)

    def _load_checkpoint(self, ckpt_path=None, model=None):
        if ckpt_path is None or not osp.exists(ckpt_path):
            return

        if model is None:
            model = self.accelerator.unwrap_model(self.gpt_model)

        if osp.isdir(ckpt_path):
            # Accelerate checkpoint directory named like ".../step50000".
            self.loaded_steps = int(
                ckpt_path.split("step")[-1].split("/")[0]
            )
            if not self.test_only:
                self.accelerator.load_state(ckpt_path)
            else:
                if self.enable_ema:
                    model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
                    if osp.exists(model_path):
                        state_dict = torch.load(model_path, map_location="cpu")
                        self._load_state_dict(state_dict, model)
                        if self.accelerator.is_main_process:
                            print(f"Loaded ema model from {model_path}")
                else:
                    model_path = osp.join(ckpt_path, "model.safetensors")
                    if osp.exists(model_path):
                        self._load_safetensors(model_path, model)
        else:
            # Plain checkpoint file.
            if ckpt_path.endswith(".safetensors"):
                self._load_safetensors(ckpt_path, model)
            else:
                state_dict = torch.load(ckpt_path, map_location="cpu")
                self._load_state_dict(state_dict, model)

        if self.accelerator.is_main_process:
            print(f"Loaded checkpoint from {ckpt_path}")

    def _build_cache(self):
        """Build per-rank memmap caches of slots and targets."""
        rank = self.accelerator.process_index
        world_size = self.accelerator.num_processes

        slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
        targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"

        # Start from a clean slate if a stale cache exists.
        if slots_file.exists():
            os.remove(slots_file)
        if targets_file.exists():
            os.remove(targets_file)

        dataset_size = len(self.train_dl.dataset)
        shard_size = dataset_size // world_size

        # Probe one batch to detect how many augmentations each image carries.
        with torch.no_grad():
            sample_batch = next(iter(self.train_dl))
            img, _ = sample_batch
            num_augs = img.shape[1] if len(img.shape) == 5 else 1

        print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
        os.makedirs(self.cache_dir, exist_ok=True)

        slots_mmap = np.memmap(
            slots_file,
            dtype='float32',
            mode='w+',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )
        targets_mmap = np.memmap(
            targets_file,
            dtype='int64',
            mode='w+',
            shape=(shard_size * num_augs,)
        )

        with torch.no_grad():
            for i, batch in enumerate(tqdm(
                self.train_dl,
                desc=f"Rank {rank}: Caching data",
                disable=not self.accelerator.is_local_main_process
            )):
                imgs, targets = batch
                if len(imgs.shape) == 5:
                    # (B, A, C, H, W) -> (B * A, C, H, W), replicating targets per augmentation.
                    B, A, C, H, W = imgs.shape
                    imgs = imgs.view(-1, C, H, W)
                    targets = targets.unsqueeze(1).expand(-1, A).reshape(-1)

                # Encode one augmentation group at a time to bound peak memory.
                num_splits = num_augs
                split_size = imgs.shape[0] // num_splits
                imgs_splits = torch.split(imgs, split_size)
                targets_splits = torch.split(targets, split_size)

                start_idx = i * self.train_dl.batch_size * num_augs

                for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
                    img_split = img_split.to(self.device, non_blocking=True)
                    slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]

                    split_start = start_idx + (split_idx * split_size)
                    split_end = split_start + img_split.shape[0]

                    slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
                    targets_mmap[split_start:split_end] = targets_split.numpy()

        # Flush the caches, then reopen them read-only.
        del slots_mmap
        del targets_mmap

        self.cached_latents = np.memmap(
            slots_file,
            dtype='float32',
            mode='r',
            shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
        )
        self.cached_targets = np.memmap(
            targets_file,
            dtype='int64',
            mode='r',
            shape=(shard_size * num_augs,)
        )

        self.num_augs = num_augs

    def _setup_cache(self):
        """Setup cache if enabled."""
        self._build_cache()
        self.accelerator.wait_for_everyone()

        if self.cached_latents is not None:
            self.cache_loader = CacheDataLoader(
                slots=self.cached_latents,
                targets=self.cached_targets,
                batch_size=self.batch_size,
                num_augs=self.num_augs,
                seed=self.seed + self.accelerator.process_index
            )

    def __del__(self):
        """Cleanup cache files."""
        if self.enable_cache_latents:
            rank = self.accelerator.process_index
            world_size = self.accelerator.num_processes

            slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
            if slots_file.exists():
                os.remove(slots_file)

            targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
            if targets_file.exists():
                os.remove(targets_file)

    def _train_step(self, slots, targets=None):
        """Execute single training step."""
        with self.accelerator.accumulate(self.gpt_model):
            with self.accelerator.autocast():
                loss = self.gpt_model(slots, targets)

            self.accelerator.backward(loss)
            if self.accelerator.sync_gradients and self.max_grad_norm is not None:
                self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
            self.g_optim.step()
            if self.g_sched is not None:
                self.g_sched.step_update(self.steps)
            self.g_optim.zero_grad()

        if self.enable_ema:
            self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))

        return loss

    def _train_epoch_cached(self, epoch, logger):
        """Train one epoch using cached data."""
        self.cache_loader.set_epoch(epoch)
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.cache_loader, 20, header):
            slots, targets = (b.to(self.device, non_blocking=True) for b in batch)

            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {len(slots)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _train_epoch_uncached(self, epoch, logger):
        """Train one epoch using raw data."""
        header = f'Epoch: [{epoch}/{self.num_epoch}]'

        for batch in logger.log_every(self.train_dl, 20, header):
            img, targets = (b.to(self.device, non_blocking=True) for b in batch)

            self.steps += 1

            if self.steps == 1:
                print(f"Training batch size: {img.size(0)}")
                print(f"Hello from index {self.accelerator.local_process_index}")

            # Slots are encoded on the fly when latent caching is disabled.
            slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
            loss = self._train_step(slots, targets)
            self._handle_periodic_ops(loss, logger)

    def _handle_periodic_ops(self, loss, logger):
        """Handle periodic operations and logging."""
        logger.update(loss=loss.item())
        logger.update(lr=self.g_optim.param_groups[0]["lr"])

        if self.steps % self.save_every == 0:
            self.save()

        if (self.steps % self.samp_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()

    def _save_config(self, config):
        """Save configuration file."""
        if config is not None and self.accelerator.is_main_process:
            from omegaconf import OmegaConf

            if isinstance(config, str) and osp.exists(config):
                shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
            else:
                config_save_path = osp.join(self.result_folder, "config.yaml")
                OmegaConf.save(config, config_save_path)

    def _should_skip_epoch(self, epoch):
        """Check if epoch should be skipped due to loaded checkpoint."""
        loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
        if ((epoch + 1) * len(loader)) <= self.loaded_steps:
            if self.accelerator.is_main_process:
                print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
            self.steps += len(loader)
            return True

        # Fast-forward a partially completed epoch to the loaded step count.
        if self.steps < self.loaded_steps:
            for _ in loader:
                self.steps += 1
                if self.steps >= self.loaded_steps:
                    break
        return False

    def train(self, config=None):
        """Main training loop."""
        n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
        if self.accelerator.is_main_process:
            print(f"number of learnable parameters: {n_parameters / 1e6:.2f}M")

        self._save_config(config)
        self.accelerator.init_trackers("gpt")

        if self.test_only:
            empty_cache()
            self.evaluate()
            self.accelerator.wait_for_everyone()
            empty_cache()
            return

        if self.enable_cache_latents:
            self._setup_cache()

        for epoch in range(self.num_epoch):
            if self._should_skip_epoch(epoch):
                continue

            self.gpt_model.train()
            logger = MetricLogger(delimiter="  ")
            logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

            if self.enable_cache_latents:
                self._train_epoch_cached(epoch, logger)
            else:
                self._train_epoch_uncached(epoch, logger)

        self.accelerator.end_training()
        self.save()
        if self.accelerator.is_main_process:
            print("Train finished!")

    def save(self):
        self.accelerator.wait_for_everyone()
        self.accelerator.save_state(
            os.path.join(self.model_saved_dir, f"step{self.steps}")
        )

    @torch.no_grad()
    def evaluate(self, use_ema=True):
        self.gpt_model.eval()
        unwrapped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)

        # Only swap in EMA weights for full FID runs during training.
        use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
        if use_ema:
            if hasattr(self, "ema_model"):
                model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
                model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
                for name, _ in model_without_ddp.named_parameters():
                    if "nested_sampler" in name:
                        continue
                    ema_state_dict[name] = self.ema_model.state_dict()[name]
                if self.accelerator.is_main_process:
                    print("Switch to ema")
                model_without_ddp.load_state_dict(ema_state_dict)
            else:
                print("EMA model not found, using original model")
                use_ema = False

        generate_fn = generate if 'GPT' in self.model_name else generate_causal_dit
        if not self.test_only:
            # Qualitative sample grid over the fixed evaluation classes.
            classes = torch.tensor(self.eval_classes, device=self.device)
            with self.accelerator.autocast():
                slots = generate_fn(
                    unwrapped_gpt_model, classes, self.num_slots_to_gen,
                    cfg_scale=self.cfg, diff_cfg=self.diff_cfg,
                    cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule,
                    temperature=self.temperature
                )
                if self.num_slots_to_gen < self.num_slots:
                    # Pad the remaining positions with the AE's learned null condition.
                    null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                    null_slots = null_slots[:, self.num_slots_to_gen:, :]
                    slots = torch.cat([slots, null_slots], dim=1)
                imgs = self.ae_model.sample(slots, targets=classes, cfg=self.ae_cfg)

            imgs = concat_all_gather(imgs)
            if self.accelerator.num_processes > 16:
                imgs = imgs[:16 * len(self.eval_classes)]
            imgs = imgs.detach().cpu()
            grid = make_grid(
                imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
            )
            if self.accelerator.is_main_process:
                save_image(
                    grid,
                    os.path.join(
                        self.image_saved_dir,
                        f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}"
                        f"_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}"
                        f"_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
                    ),
                )

        if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
            # Generate self.num_test_images samples and score them against the
            # reference statistics.
            save_folder = os.path.join(
                self.image_saved_dir,
                f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}"
                f"_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}"
                f"_slots{self.num_slots_to_gen}_temp{self.temperature}"
            )
            if self.accelerator.is_main_process:
                os.makedirs(save_folder, exist_ok=True)

            world_size = self.accelerator.num_processes
            local_rank = self.accelerator.process_index
            batch_size = self.test_bs

            # Build a balanced, shuffled list of class labels.
            num_classes = self.num_classes
            images_per_class = self.num_test_images // num_classes
            class_labels = np.repeat(np.arange(num_classes), images_per_class)
            np.random.shuffle(class_labels)
            total_images = len(class_labels)

            # Pad so the labels split evenly into world_size * batch_size chunks.
            padding_size = (-total_images) % (world_size * batch_size)
            class_labels = np.pad(class_labels, (0, padding_size), 'constant')
            padded_total_images = len(class_labels)

            images_per_gpu = padded_total_images // world_size
            start_idx = local_rank * images_per_gpu
            end_idx = min(start_idx + images_per_gpu, padded_total_images)
            local_class_labels = class_labels[start_idx:end_idx]
            local_num_steps = len(local_class_labels) // batch_size

            if self.accelerator.is_main_process:
                print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")

            used_time = 0
            gen_img_cnt = 0

            for i in range(local_num_steps):
                if self.accelerator.is_main_process and i % 10 == 0:
                    print(f"Generation step {i}/{local_num_steps}")

                batch_start = i * batch_size
                batch_end = batch_start + batch_size
                labels = local_class_labels[batch_start:batch_end]
                labels = torch.tensor(labels, device=self.device)

                self.accelerator.wait_for_everyone()
                start_time = time.time()
                with torch.no_grad():
                    with self.accelerator.autocast():
                        slots = generate_fn(unwrapped_gpt_model, labels, self.num_slots_to_gen,
                                            cfg_scale=self.cfg, diff_cfg=self.diff_cfg,
                                            cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule,
                                            temperature=self.temperature)
                        if self.num_slots_to_gen < self.num_slots:
                            null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
                            null_slots = null_slots[:, self.num_slots_to_gen:, :]
                            slots = torch.cat([slots, null_slots], dim=1)
                        imgs = self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)

                # Drop padded samples in the final batch.
                samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)

                used_time += time.time() - start_time
                gen_img_cnt += samples_in_batch
                if self.accelerator.is_main_process and i % 10 == 0:
                    print(f"Avg generation time: {used_time / gen_img_cnt:.5f} sec/image")

                gathered_imgs = concat_all_gather(imgs)
                gathered_imgs = gathered_imgs[:samples_in_batch]

                # Only the main process writes images to disk.
                if self.accelerator.is_main_process:
                    gen_imgs = gathered_imgs.detach().cpu()
                    save_paths = [
                        os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
                        for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
                    ]
                    save_img_batch(gen_imgs, save_paths)

            self.accelerator.wait_for_everyone()
            if self.accelerator.is_main_process:
                generated_files = len(os.listdir(save_folder))
                print(f"Generated {generated_files} images out of {total_images} expected")

                metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
                fid = metrics_dict["frechet_inception_distance"]
                inception_score = metrics_dict["inception_score_mean"]

                metric_prefix = "fid_ema" if use_ema else "fid"
                isc_prefix = "isc_ema" if use_ema else "isc"

                self.accelerator.log({
                    metric_prefix: fid,
                    isc_prefix: inception_score,
                    "gpt_cfg": self.cfg,
                    "ae_cfg": self.ae_cfg,
                    "diff_cfg": self.diff_cfg,
                    "cfg_schedule": self.cfg_schedule,
                    "diff_cfg_schedule": self.diff_cfg_schedule,
                    "temperature": self.temperature,
                    "num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
                }, step=self.steps)

                cfg_info = (
                    f"{'EMA ' if use_ema else ''}CFG params: "
                    f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, diff_cfg={self.diff_cfg}, "
                    f"cfg_schedule={self.cfg_schedule}, diff_cfg_schedule={self.diff_cfg_schedule}, "
                    f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
                    f"temperature={self.temperature}"
                )
                print(cfg_info)
                print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")

                # Clean up generated images once metrics are computed.
                shutil.rmtree(save_folder)

        # Restore the online weights if EMA weights were swapped in.
        if use_ema:
            if self.accelerator.is_main_process:
                print("Switch back from ema")
            model_without_ddp.load_state_dict(model_state_dict)

        self.gpt_model.train()