import copy
import datetime
import os
import random
import time

import numpy as np
import torch
from tqdm import tqdm

from openrec.losses import build_loss
from openrec.metrics import build_metric
from openrec.modeling import build_model
from openrec.optimizer import build_optimizer
from openrec.postprocess import build_post_process
from tools.data import build_dataloader
from tools.utils.ckpt import load_ckpt, save_ckpt
from tools.utils.logging import get_logger
from tools.utils.stats import TrainingStats
from tools.utils.utility import AverageMeter

__all__ = ['Trainer']


def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters()
                        if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}


class Trainer(object):

    def __init__(self, cfg, mode='train'):
        self.cfg = cfg.cfg
        self.local_rank = (int(os.environ['LOCAL_RANK'])
                           if 'LOCAL_RANK' in os.environ else 0)
        self.set_device(self.cfg['Global']['device'])
        mode = mode.lower()
        assert mode in [
            'train_eval',
            'train',
            'eval',
            'test',
        ], 'mode should be train_eval, train, eval or test'
        if torch.cuda.device_count() > 1 and 'train' in mode:
            torch.distributed.init_process_group(backend='nccl')
            torch.cuda.set_device(self.device)
            self.cfg['Global']['distributed'] = True
        else:
            self.cfg['Global']['distributed'] = False
            self.local_rank = 0

        self.cfg['Global']['output_dir'] = self.cfg['Global'].get(
            'output_dir', 'output')
        os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)

        self.writer = None
        if self.local_rank == 0 and self.cfg['Global'][
                'use_tensorboard'] and 'train' in mode:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter(self.cfg['Global']['output_dir'])

        self.logger = get_logger(
            'openrec',
            os.path.join(self.cfg['Global']['output_dir'], 'train.log')
            if 'train' in mode else None,
        )
        cfg.print_cfg(self.logger.info)

        if self.cfg['Global']['device'] == 'gpu' and self.device.type == 'cpu':
            self.logger.info('cuda is not available, auto switch to cpu')

        self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)
        self.all_ema = self.cfg['Global'].get('all_ema', True)
        self.use_ema = self.cfg['Global'].get('use_ema', True)
        self.set_random_seed(self.cfg['Global'].get('seed', 48))

        # build data loader
        self.train_dataloader = None
        if 'train' in mode:
            cfg.save(
                os.path.join(self.cfg['Global']['output_dir'], 'config.yml'),
                self.cfg)
            self.train_dataloader = build_dataloader(self.cfg, 'Train',
                                                     self.logger)
            self.logger.info(
                f'train dataloader has {len(self.train_dataloader)} iters')

        self.valid_dataloader = None
        if 'eval' in mode and self.cfg['Eval']:
            self.valid_dataloader = build_dataloader(self.cfg, 'Eval',
                                                     self.logger)
            self.logger.info(
                f'valid dataloader has {len(self.valid_dataloader)} iters')

        # build post process
        self.post_process_class = build_post_process(self.cfg['PostProcess'],
                                                     self.cfg['Global'])

        # build model
        # for rec algorithm
        char_num = self.post_process_class.get_character_num()
        self.cfg['Architecture']['Decoder']['out_channels'] = char_num
        self.model = build_model(self.cfg['Architecture'])
        self.logger.info(get_parameter_number(model=self.model))
        self.model = self.model.to(self.device)
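        # Keep a separate copy of the model on rank 0 only; it holds the
        # averaged (EMA-style) weights that are updated near the end of
        # training and evaluated/saved independently of the live model.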
        if self.local_rank == 0:
            ema_model = build_model(self.cfg['Architecture'])
            self.ema_model = ema_model.to(self.device)
            self.ema_model.eval()

        use_sync_bn = self.cfg['Global'].get('use_sync_bn', False)
        if use_sync_bn:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)
            self.logger.info('convert_sync_batchnorm')

        # build loss
        self.loss_class = build_loss(self.cfg['Loss'])

        self.optimizer, self.lr_scheduler = None, None
        if self.train_dataloader is not None:
            # build optim
            self.optimizer, self.lr_scheduler = build_optimizer(
                self.cfg['Optimizer'],
                self.cfg['LRScheduler'],
                epochs=self.cfg['Global']['epoch_num'],
                step_each_epoch=len(self.train_dataloader),
                model=self.model,
            )

        self.eval_class = build_metric(self.cfg['Metric'])

        self.status = load_ckpt(self.model, self.cfg, self.optimizer,
                                self.lr_scheduler)

        if self.cfg['Global']['distributed']:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, [self.local_rank], find_unused_parameters=False)

        # amp
        self.scaler = (torch.cuda.amp.GradScaler() if self.cfg['Global'].get(
            'use_amp', False) else None)

        self.logger.info(
            f'run with torch {torch.__version__} and device {self.device}')

    def load_params(self, params):
        self.model.load_state_dict(params)

    def set_random_seed(self, seed):
        torch.manual_seed(seed)  # set the random seed for CPU
        if self.device.type == 'cuda':
            torch.backends.cudnn.benchmark = True
            torch.cuda.manual_seed(seed)  # set the random seed for the current GPU
            torch.cuda.manual_seed_all(seed)  # set the random seed for all GPUs
        random.seed(seed)
        np.random.seed(seed)

    def set_device(self, device):
        if device == 'gpu' and torch.cuda.is_available():
            device = torch.device(f'cuda:{self.local_rank}')
        else:
            device = torch.device('cpu')
        self.device = device

    def train(self):
        cal_metric_during_train = self.cfg['Global'].get(
            'cal_metric_during_train', False)
        log_smooth_window = self.cfg['Global']['log_smooth_window']
        epoch_num = self.cfg['Global']['epoch_num']
        print_batch_step = self.cfg['Global']['print_batch_step']
        eval_epoch_step = self.cfg['Global'].get('eval_epoch_step', 1)
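        # eval_epoch_step / eval_batch_step may be a single number or a
        # [start, interval] pair; a huge sentinel (1e111) effectively disables
        # evaluation when the eval dataset is missing or empty.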
        start_eval_epoch = 0
        if self.valid_dataloader is not None:
            if type(eval_epoch_step) == list and len(eval_epoch_step) >= 2:
                start_eval_epoch = eval_epoch_step[0]
                eval_epoch_step = eval_epoch_step[1]
                if len(self.valid_dataloader) == 0:
                    start_eval_epoch = 1e111
                    self.logger.info(
                        'No Images in eval dataset, evaluation during training will be disabled'
                    )
            self.logger.info(
                f'During the training process, after the {start_eval_epoch}th epoch, '
                f'an evaluation is run every {eval_epoch_step} epoch')
        else:
            start_eval_epoch = 1e111

        eval_batch_step = self.cfg['Global']['eval_batch_step']
        global_step = self.status.get('global_step', 0)
        start_eval_step = 0
        if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
            start_eval_step = eval_batch_step[0]
            eval_batch_step = eval_batch_step[1]
            if len(self.valid_dataloader) == 0:
                self.logger.info(
                    'No Images in eval dataset, evaluation during training '
                    'will be disabled')
                start_eval_step = 1e111
            self.logger.info(
                'During the training process, after the {}th iteration, '
                'an evaluation is run every {} iterations'.format(
                    start_eval_step, eval_batch_step))

        start_epoch = self.status.get('epoch', 1)
        best_metric = self.status.get('metrics', {})
        if self.eval_class.main_indicator not in best_metric:
            best_metric[self.eval_class.main_indicator] = 0
        ema_best_metric = self.status.get('metrics', {})
        ema_best_metric[self.eval_class.main_indicator] = 0
        train_stats = TrainingStats(log_smooth_window, ['lr'])
        self.model.train()

        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        best_iter = 0
        ema_stpe = 1
        ema_eval_iter = 0
        loss_avg = 0.
        reader_start = time.time()
        eta_meter = AverageMeter()
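        # The dataset can request to be rebuilt at the start of an epoch
        # (need_reset); the epoch index passed to the dataloader is cycled
        # with a period of 20.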
        for epoch in range(start_epoch, epoch_num + 1):
            if self.train_dataloader.dataset.need_reset:
                self.train_dataloader = build_dataloader(
                    self.cfg,
                    'Train',
                    self.logger,
                    epoch=epoch % 20 if epoch % 20 != 0 else 20,
                )
            for idx, batch in enumerate(self.train_dataloader):
                batch = [t.to(self.device) for t in batch]
                self.optimizer.zero_grad()
                train_reader_cost += time.time() - reader_start
                # use amp
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.model(batch[0], data=batch[1:])
                        loss = self.loss_class(preds, batch)
                    self.scaler.scale(loss['loss']).backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    preds = self.model(batch[0], data=batch[1:])
                    loss = self.loss_class(preds, batch)
                    avg_loss = loss['loss']
                    avg_loss.backward()
                    if self.grad_clip_val > 0:
                        torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(),
                            max_norm=self.grad_clip_val)
                    self.optimizer.step()

                if cal_metric_during_train:  # only rec and cls need
                    post_result = self.post_process_class(preds,
                                                          batch,
                                                          training=True)
                    self.eval_class(post_result, batch, training=True)
                    metric = self.eval_class.get_metric()
                    train_stats.update(metric)

                train_batch_time = time.time() - reader_start
                train_batch_cost += train_batch_time
                eta_meter.update(train_batch_time)
                global_step += 1
                total_samples += len(batch[0])

                self.lr_scheduler.step()
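
                # Weight averaging on rank 0 only, during the last ~10% of the
                # epochs: the shadow ema_model accumulates a running average of
                # the weights, folding in the current weights whenever the batch
                # loss is at or below the running-mean loss (or always, if
                # all_ema is set).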
                if self.local_rank == 0 and self.use_ema and epoch > (
                        epoch_num - epoch_num // 10):
                    with torch.no_grad():
                        loss_currn = loss['loss'].detach().cpu().numpy().mean()
                        loss_avg = ((loss_avg *
                                     (ema_stpe - 1)) + loss_currn) / (ema_stpe)
                        if ema_stpe == 1:
                            # current_weight = copy.deepcopy(self.model.module.state_dict())
                            ema_state_dict = copy.deepcopy(
                                self.model.module.state_dict()
                                if self.cfg['Global']['distributed']
                                else self.model.state_dict())
                            self.ema_model.load_state_dict(ema_state_dict)
                        # if global_step > (epoch_num - epoch_num//10)*max_iter:
                        elif loss_currn <= loss_avg or self.all_ema:
                            # eval_batch_step = 500
                            current_weight = copy.deepcopy(
                                self.model.module.state_dict()
                                if self.cfg['Global']['distributed']
                                else self.model.state_dict())
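                            # Cumulative moving average: after t snapshots each
                            # parameter is the simple mean of the collected
                            # weights (the new weight gets factor 1 / (t + 1)).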
                            k1 = 1 / (ema_stpe + 1)
                            k2 = 1 - k1
                            for k, v in ema_state_dict.items():
                                # v = (v * (ema_stpe - 1) + current_weight[k])/ema_stpe
                                v = v * k2 + current_weight[k] * k1
                                # v.req = True
                                ema_state_dict[k] = v
                            # ema_stpe += 1
                            self.ema_model.load_state_dict(ema_state_dict)
                        ema_stpe += 1
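                        # Periodically evaluate the averaged weights and save
                        # them as a separate checkpoint (ema_<n>.pth).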
                        if global_step > start_eval_step and (
                                global_step -
                                start_eval_step) % eval_batch_step == 0:
                            ema_cur_metric = self.eval_ema()
                            ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
                            self.logger.info(ema_cur_metric_str)
                            state = {
                                'epoch': epoch,
                                'global_step': global_step,
                                'state_dict': self.ema_model.state_dict(),
                                'optimizer': None,
                                'scheduler': None,
                                'config': self.cfg,
                                'metrics': ema_cur_metric,
                            }
                            save_path = os.path.join(
                                self.cfg['Global']['output_dir'],
                                'ema_' + str(ema_eval_iter) + '.pth')
                            torch.save(state, save_path)
                            self.logger.info(f'save ema ckpt to {save_path}')
                            ema_eval_iter += 1
                            if (ema_cur_metric[self.eval_class.main_indicator] >=
                                    ema_best_metric[self.eval_class.main_indicator]):
                                ema_best_metric.update(ema_cur_metric)
                                ema_best_metric['best_epoch'] = epoch
                                best_ema_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                                self.logger.info(best_ema_str)

                # logger
                stats = {
                    k: float(v)
                    if v.shape == [] else v.detach().cpu().numpy().mean()
                    for k, v in loss.items()
                }
                stats['lr'] = self.lr_scheduler.get_last_lr()[0]
                train_stats.update(stats)

                if self.writer is not None:
                    for k, v in train_stats.get().items():
                        self.writer.add_scalar(f'TRAIN/{k}', v, global_step)

                if self.local_rank == 0 and (
                        (global_step > 0 and global_step % print_batch_step == 0)
                        or (idx >= len(self.train_dataloader) - 1)):
                    logs = train_stats.log()
                    eta_sec = (
                        (epoch_num + 1 - epoch) * len(self.train_dataloader) -
                        idx - 1) * eta_meter.avg
                    eta_sec_format = str(
                        datetime.timedelta(seconds=int(eta_sec)))
                    strs = (
                        f'epoch: [{epoch}/{epoch_num}], global_step: {global_step}, {logs}, '
                        f'avg_reader_cost: {train_reader_cost / print_batch_step:.5f} s, '
                        f'avg_batch_cost: {train_batch_cost / print_batch_step:.5f} s, '
                        f'avg_samples: {total_samples / print_batch_step}, '
                        f'ips: {total_samples / train_batch_cost:.5f} samples/s, '
                        f'eta: {eta_sec_format}')
                    self.logger.info(strs)
                    total_samples = 0
                    train_reader_cost = 0.0
                    train_batch_cost = 0.0
                reader_start = time.time()
                # eval
                if (global_step > start_eval_step and
                        (global_step - start_eval_step) % eval_batch_step == 0
                        ) and self.local_rank == 0:
                    cur_metric = self.eval()
                    cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
                    self.logger.info(cur_metric_str)

                    # logger metric
                    if self.writer is not None:
                        for k, v in cur_metric.items():
                            if isinstance(v, (float, int)):
                                self.writer.add_scalar(f'EVAL/{k}',
                                                       cur_metric[k],
                                                       global_step)

                    if (cur_metric[self.eval_class.main_indicator] >=
                            best_metric[self.eval_class.main_indicator]):
                        best_metric.update(cur_metric)
                        best_metric['best_epoch'] = epoch
                        if self.writer is not None:
                            self.writer.add_scalar(
                                f'EVAL/best_{self.eval_class.main_indicator}',
                                best_metric[self.eval_class.main_indicator],
                                global_step,
                            )
                        if epoch > (epoch_num - epoch_num // 10 - 2):
                            save_ckpt(self.model,
                                      self.cfg,
                                      self.optimizer,
                                      self.lr_scheduler,
                                      epoch,
                                      global_step,
                                      best_metric,
                                      is_best=True,
                                      prefix='best_' + str(best_iter))
                            best_iter += 1
                        # else:
                        save_ckpt(self.model,
                                  self.cfg,
                                  self.optimizer,
                                  self.lr_scheduler,
                                  epoch,
                                  global_step,
                                  best_metric,
                                  is_best=True,
                                  prefix=None)
                    best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                    self.logger.info(best_str)

            if self.local_rank == 0 and epoch > start_eval_epoch and (
                    epoch - start_eval_epoch) % eval_epoch_step == 0:
                cur_metric = self.eval()
                cur_metric_str = f"cur metric, {', '.join(['{}: {}'.format(k, v) for k, v in cur_metric.items()])}"
                self.logger.info(cur_metric_str)

                # logger metric
                if self.writer is not None:
                    for k, v in cur_metric.items():
                        if isinstance(v, (float, int)):
                            self.writer.add_scalar(f'EVAL/{k}', cur_metric[k],
                                                   global_step)

                if (cur_metric[self.eval_class.main_indicator] >=
                        best_metric[self.eval_class.main_indicator]):
                    best_metric.update(cur_metric)
                    best_metric['best_epoch'] = epoch
                    if self.writer is not None:
                        self.writer.add_scalar(
                            f'EVAL/best_{self.eval_class.main_indicator}',
                            best_metric[self.eval_class.main_indicator],
                            global_step,
                        )
                    if epoch > (epoch_num - epoch_num // 10 - 2):
                        save_ckpt(self.model,
                                  self.cfg,
                                  self.optimizer,
                                  self.lr_scheduler,
                                  epoch,
                                  global_step,
                                  best_metric,
                                  is_best=True,
                                  prefix='best_' + str(best_iter))
                        best_iter += 1
                    # else:
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              best_metric,
                              is_best=True,
                              prefix=None)
                best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
                self.logger.info(best_str)

            if self.local_rank == 0:
                save_ckpt(self.model,
                          self.cfg,
                          self.optimizer,
                          self.lr_scheduler,
                          epoch,
                          global_step,
                          best_metric,
                          is_best=False,
                          prefix=None)
                if epoch > (epoch_num - epoch_num // 10 - 2):
                    save_ckpt(self.model,
                              self.cfg,
                              self.optimizer,
                              self.lr_scheduler,
                              epoch,
                              global_step,
                              best_metric,
                              is_best=False,
                              prefix='epoch_' + str(epoch))
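                # At the end of each late epoch, also evaluate the averaged
                # weights and write them out as an ema_<n>.pth checkpoint.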
                if self.use_ema and epoch > (epoch_num - epoch_num // 10):
                    # if global_step > start_eval_step and (global_step - start_eval_step) % eval_batch_step == 0:
                    ema_cur_metric = self.eval_ema()
                    ema_cur_metric_str = f"cur ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_cur_metric.items()])}"
                    self.logger.info(ema_cur_metric_str)
                    state = {
                        'epoch': epoch,
                        'global_step': global_step,
                        'state_dict': self.ema_model.state_dict(),
                        'optimizer': None,
                        'scheduler': None,
                        'config': self.cfg,
                        'metrics': ema_cur_metric,
                    }
                    save_path = os.path.join(
                        self.cfg['Global']['output_dir'],
                        'ema_' + str(ema_eval_iter) + '.pth')
                    torch.save(state, save_path)
                    self.logger.info(f'save ema ckpt to {save_path}')
                    ema_eval_iter += 1
                    if (ema_cur_metric[self.eval_class.main_indicator] >=
                            ema_best_metric[self.eval_class.main_indicator]):
                        ema_best_metric.update(ema_cur_metric)
                        ema_best_metric['best_epoch'] = epoch
                        # ema_cur_metric_str = f"best ema metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                        best_ema_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in ema_best_metric.items()])}"
                        self.logger.info(best_ema_str)

        best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in best_metric.items()])}"
        self.logger.info(best_str)
        if self.writer is not None:
            self.writer.close()
        if torch.cuda.device_count() > 1:
            torch.distributed.destroy_process_group()

    def eval(self):
        self.model.eval()
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch = [t.to(self.device) for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.model(batch[0], data=batch[1:])
                else:
                    preds = self.model(batch[0], data=batch[1:])
                total_time += time.time() - start
                # Obtain usable results from post-processing methods
                post_result = self.post_process_class(preds, batch)
                # Evaluate the results of the current batch
                self.eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # Get the final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()
            pbar.close()
        self.model.train()
        metric['fps'] = total_frame / total_time
        return metric

    def eval_ema(self):
        # self.model.eval()
        with torch.no_grad():
            total_frame = 0.0
            total_time = 0.0
            pbar = tqdm(
                total=len(self.valid_dataloader),
                desc='eval ema_model:',
                position=0,
                leave=True,
            )
            sum_images = 0
            for idx, batch in enumerate(self.valid_dataloader):
                batch = [t.to(self.device) for t in batch]
                start = time.time()
                if self.scaler:
                    with torch.cuda.amp.autocast():
                        preds = self.ema_model(batch[0], data=batch[1:])
                else:
                    preds = self.ema_model(batch[0], data=batch[1:])
                total_time += time.time() - start
                # Obtain usable results from post-processing methods
                post_result = self.post_process_class(preds, batch)
                # Evaluate the results of the current batch
                self.eval_class(post_result, batch)
                pbar.update(1)
                total_frame += len(batch[0])
                sum_images += 1
            # Get the final metric, e.g. acc or hmean
            metric = self.eval_class.get_metric()
            pbar.close()
        # self.model.train()
        metric['fps'] = total_frame / total_time
        return metric

    def test_dataloader(self):
        starttime = time.time()
        count = 0
        try:
            for data in self.train_dataloader:
                count += 1
                if count % 1 == 0:
                    batch_time = time.time() - starttime
                    starttime = time.time()
                    self.logger.info(
                        f'reader: {count}, {data[0].shape}, {batch_time}')
        except Exception:
            import traceback
            self.logger.info(traceback.format_exc())
        self.logger.info(f'finish reader: {count}, Success!')
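

# Minimal usage sketch (assumption: a config helper such as tools.engine.Config
# that loads a YAML file and exposes `.cfg`, `.print_cfg`, and `.save`; the
# config path below is only illustrative):
#
#     from tools.engine import Config
#
#     cfg = Config('configs/rec/example_config.yml')
#     trainer = Trainer(cfg, mode='train_eval')
#     trainer.train()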