"""Distributed-training, gradient-scaling and checkpoint helpers."""

import builtins
import copy
import datetime
import os
import os.path as osp
import pdb
import socket
import sys
import time
from collections import defaultdict, deque
from pathlib import Path

import torch
import torch.distributed as dist
from torch import inf


def print_available_port():
    """Return a TCP port number that was free at the time of the call."""
    return _find_free_port()


def _find_free_port():
    """Ask the OS for an unused TCP port and return its number."""
    # The context manager closes the socket even if bind/getsockname raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(("", 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port


def ensure_dir(dirpath):
    """Create ``dirpath`` (including parents) if the path does not exist."""
    if not osp.exists(dirpath):
        os.makedirs(dirpath, exist_ok=True)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process.

    It replaces ``builtins.print`` with a wrapper that prefixes every message
    with the current time. The wrapper only prints when ``is_master`` is true,
    when the caller passes ``force=True``, or when the world size exceeds 8.
    """
    builtin_print = builtins.print

    def _timestamped_print(*args, **kwargs):
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = _timestamped_print


def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.

    Returns the per-rank tensors concatenated along dim 0.
    """
    world_size = torch.distributed.get_world_size()
    tensors_gather = [torch.ones_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
    return torch.cat(tensors_gather, dim=0)


def is_main_process():
    """Return True when this process has rank 0."""
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save``, but only on the rank-0 process."""
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Initialise torch.distributed from ``args`` and environment variables.

    Reads rank / world size / local GPU from OpenMPI variables (when
    ``args.dist_on_itp`` is set), from ``RANK``/``WORLD_SIZE``/``LOCAL_RANK``,
    or from ``SLURM_PROCID``. If none apply, sets ``args.distributed = False``
    and returns. Otherwise sets ``args.distributed = True``, selects the CUDA
    device, initialises an NCCL process group and restricts printing to rank 0.

    ``args`` is mutated in place; nothing is returned.
    """
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        # `and` short-circuits, so a non-int port fails the assertion instead
        # of raising TypeError from the comparison.
        assert isinstance(args.port, int) and 0 < args.port < (1 << 30)
        # NOTE(review): args.port is validated above but the URL below uses a
        # port picked independently by this process — confirm that all ranks
        # are expected to agree on it.
        port = _find_free_port()
        args.dist_url = f'tcp://127.0.0.1:{port}'
        # Export the variables that env-based launch code reads
        # (RANK, WORLD_SIZE, LOCAL_RANK).
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # args.dist_url is not set in this branch; it must already be on args.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # args.dist_url and args.world_size are not set in this branch; they
        # must already be on args.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend,
                                         init_method=args.dist_url,
                                         world_size=args.world_size,
                                         rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    """Wrap ``torch.cuda.amp.GradScaler`` and report the gradient norm per step."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None,
                 create_graph=False, update_grad=True):
        """Back-propagate the scaled loss and optionally step the optimizer.

        When ``update_grad`` is true the gradients are unscaled, then either
        clipped to ``clip_grad`` (``parameters`` is required in that case) or
        just measured, and the optimizer is stepped through the scaler.

        Returns the gradient norm, or None when ``update_grad`` is false.
        """
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if not update_grad:
            return None

        if clip_grad is not None:
            assert parameters is not None
            # unscale the gradients of optimizer's assigned params in-place
            self._scaler.unscale_(optimizer)
            norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
        else:
            self._scaler.unscale_(optimizer)
            norm = get_grad_norm_(parameters)
        self._scaler.step(optimizer)
        self._scaler.update()
        return norm

    def state_dict(self):
        """Return the wrapped scaler's state dict."""
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped scaler."""
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the total ``norm_type`` norm of the gradients of ``parameters``.

    ``parameters`` may be a single tensor or an iterable of tensors; entries
    without a gradient are ignored. Returns ``torch.tensor(0.)`` when no
    parameter has a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    # Per-parameter norms are moved to the first gradient's device before
    # being combined.
    device = parameters[0].grad.device
    if norm_type == inf:
        return max(p.grad.detach().abs().max().to(device) for p in parameters)
    per_param_norms = [
        torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters
    ]
    return torch.norm(torch.stack(per_param_norms), norm_type)


def _save_checkpoint(args, epoch, epoch_name, model, model_without_ddp,
                     optimizer, loss_scaler, ema_params):
    """Write a checkpoint tagged ``epoch_name``; shared by the public savers."""
    if loss_scaler is None:
        # No loss scaler: delegate to the model's own save_checkpoint method.
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir,
                              tag="checkpoint-%s" % epoch_name,
                              client_state=client_state)
        return

    checkpoint_path = Path(args.output_dir) / ('checkpoint-%s.pth' % epoch_name)

    # ema
    if ema_params is not None:
        # Start from a copy of the full state dict, then overwrite each
        # parameter entry with its EMA value; ema_params is indexed in
        # named_parameters() order.
        ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
        for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
            assert name in ema_state_dict
            ema_state_dict[name] = ema_params[i]
    else:
        ema_state_dict = None

    to_save = {
        'model': model_without_ddp.state_dict(),
        'model_ema': ema_state_dict,
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(to_save, checkpoint_path)


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler,
               ema_params=None):
    """Save a checkpoint named after ``epoch`` under ``args.output_dir``.

    With a loss scaler, rank 0 writes ``checkpoint-<epoch>.pth`` containing
    model, optional EMA, optimizer, epoch, scaler and args. Without one,
    ``model.save_checkpoint`` is called with tag ``checkpoint-<epoch>``.
    """
    _save_checkpoint(args, epoch, str(epoch), model, model_without_ddp,
                     optimizer, loss_scaler, ema_params)


def save_model_last(args, epoch, model, model_without_ddp, optimizer,
                    loss_scaler, ema_params=None):
    """Same as ``save_model`` but always uses the name ``checkpoint-last``."""
    _save_checkpoint(args, epoch, 'last', model, model_without_ddp,
                     optimizer, loss_scaler, ema_params)


def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Restore model (and, unless evaluating, optimizer/scaler) from ``args.resume``.

    If ``args.resume`` contains ``checkpoint-last.pth`` that file is loaded;
    otherwise ``args.resume`` itself is used as the checkpoint path. Does
    nothing when ``args.resume`` is empty or None. When optimizer state is
    restored, ``args.start_epoch`` is set to the saved epoch + 1.
    """
    # Check first: joining a None path would raise before the falsy test.
    if not args.resume:
        return

    last_checkpoint = osp.join(args.resume, "checkpoint-last.pth")
    resume_path = last_checkpoint if osp.exists(last_checkpoint) else args.resume

    # NOTE(review): checkpoints written by save_model include the `args`
    # object; recent PyTorch releases default torch.load to weights_only=True,
    # which may reject it — confirm against the PyTorch version in use.
    checkpoint = torch.load(resume_path, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    print("Resume checkpoint %s" % resume_path)

    evaluating = hasattr(args, 'evaluate') and args.evaluate
    if 'optimizer' in checkpoint and 'epoch' in checkpoint and not evaluating:
        optimizer.load_state_dict(checkpoint['optimizer'])
        args.start_epoch = checkpoint['epoch'] + 1
        if 'scaler' in checkpoint:
            loss_scaler.load_state_dict(checkpoint['scaler'])
        print("With optim & sched!")


def all_reduce_mean(x):
    """Return the mean of ``x`` across all ranks (``x`` itself if world size is 1)."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    x_reduce = torch.tensor(x).cuda()
    dist.all_reduce(x_reduce)
    x_reduce /= world_size
    return x_reduce.item()