import os
import time

import numpy as np
import torch
from sklearn import metrics
from sklearn.metrics import _classification
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class Solver(object):
    """Training/validation driver for the multi-label key-classification net.

    Runs ``epochs`` passes over ``data_loader``, evaluates on
    ``test_data_loader`` after every epoch, logs per-step and per-epoch
    metrics to TensorBoard, checkpoints the model each epoch, and keeps a
    ``best.pth`` snapshot of the lowest-validation-loss weights.
    """

    def __init__(self, data_loader, test_data_loader, model, criterion,
                 optimizer, lr_scheduler, epochs):
        # Checkpoint directory; created lazily in train().
        self.save_model_path = '/ailab-train/speech/shansizhe/audeo/models/Video2Roll_50_0.4/'  # change to your path
        self.test_loader = test_data_loader
        self.data_loader = data_loader
        self.net = model
        # NOTE(review): the 0.4 sigmoid threshold used in train_loop()/validate()
        # implies `criterion` is a logits-based loss (e.g. BCEWithLogitsLoss) — confirm.
        self.criterion = criterion
        self.optimizer = optimizer
        # Stepped with the epoch-average train loss, so presumably a
        # ReduceLROnPlateau-style scheduler — verify against the caller.
        self.lr_scheduler = lr_scheduler
        # Training config
        self.epochs = epochs
        # logging
        self.step = 0
        self.global_step = 0
        self.writer = SummaryWriter(log_dir='/ailab-train/speech/shansizhe/audeo/log/50_0.4/')
        # Per-epoch loss history, used by the optional visdom plot below.
        # FIX: torch.Tensor(n) returns uninitialized memory; use zeros so both
        # histories start from a defined state (consistent with val_loss).
        self.tr_loss = torch.zeros(self.epochs)
        self.val_loss = torch.zeros(self.epochs)
        # visualizing loss using visdom (disabled by default)
        self.visdom = False
        self.visdom_epoch = 1
        self.visdom_id = 'key classification'
        if self.visdom:
            from visdom import Visdom
            self.vis = Visdom(env=self.visdom_id)
            self.vis_opts = dict(title=self.visdom_id,
                                 ylabel='Loss', xlabel='Epoch',
                                 legend=['train loss', 'val loss'])
            self.vis_window = None
            self.vis_epochs = torch.arange(1, self.epochs + 1)

    def train(self):
        """Train the model for ``self.epochs`` epochs, validating after each."""
        pre_val_loss = 1e4  # best (lowest) validation loss seen so far
        for epoch in tqdm(range(self.epochs)):
            print("Training...")
            self.net.train()  # Turn on BatchNorm & Dropout
            start = time.time()
            # training loop
            tr_avg_loss, tr_avg_precision, tr_avg_recall = self.train_loop()
            # evaluate
            self.net.eval()
            val_avg_loss, val_avg_precision, val_avg_recall, val_avg_acc, val_fscore = self.validate()
            print('-' * 85)
            # FIX: the original format string had only placeholders {0}-{2} but
            # was passed 5 arguments, silently dropping precision and recall.
            print('Train Summary | Epoch {0} | Time {1:.2f}s | '
                  'Train Loss {2:.3f} | Precision {3:.3f} | Recall {4:.3f}'.format(
                      epoch + 1, time.time() - start, tr_avg_loss,
                      tr_avg_precision, tr_avg_recall))
            print("epoch {0} validation loss:{1:.3f} | avg precision:{2:.3f} | avg recall:{3:.3f} | avg acc:{4:.3f} | f1 score:{5:.3f}".format(
                epoch + 1, val_avg_loss, val_avg_precision, val_avg_recall,
                val_avg_acc, val_fscore))
            print('-' * 85)
            # Log metrics to TensorBoard
            self.writer.add_scalar('Loss/train', tr_avg_loss, epoch)
            self.writer.add_scalar('Precision/train', tr_avg_precision, epoch)
            self.writer.add_scalar('Recall/train', tr_avg_recall, epoch)
            self.writer.add_scalar('Loss/val', val_avg_loss, epoch)
            self.writer.add_scalar('Precision/val', val_avg_precision, epoch)
            self.writer.add_scalar('Recall/val', val_avg_recall, epoch)
            self.writer.add_scalar('Accuracy/val', val_avg_acc, epoch)
            self.writer.add_scalar('F1_score/val', val_fscore, epoch)
            # Save model each epoch, plus a "best" snapshot on improvement.
            os.makedirs(self.save_model_path, exist_ok=True)
            model_save_path = f"{self.save_model_path}{epoch}.pth"
            torch.save(self.net.state_dict(), model_save_path)
            if val_avg_loss < pre_val_loss:
                pre_val_loss = val_avg_loss
                torch.save(self.net.state_dict(), f"{self.save_model_path}best.pth")
            # Record the epoch losses for plotting.
            self.val_loss[epoch] = val_avg_loss
            self.tr_loss[epoch] = tr_avg_loss
            # visualizing loss using visdom
            if self.visdom:
                x_axis = self.vis_epochs[0:epoch + 1]
                y_axis = torch.stack(
                    (self.tr_loss[0:epoch + 1], self.val_loss[0:epoch + 1]), dim=1)
                if self.vis_window is None:
                    self.vis_window = self.vis.line(
                        X=x_axis,
                        Y=y_axis,
                        opts=self.vis_opts,
                    )
                else:
                    self.vis.line(
                        X=x_axis.unsqueeze(0).expand(y_axis.size(
                            1), x_axis.size(0)).transpose(0, 1),  # Visdom fix
                        Y=y_axis,
                        win=self.vis_window,
                        update='replace',
                    )

    def train_loop(self):
        """One training epoch.

        Returns ``(avg_loss, avg_precision, avg_recall)`` averaged over all
        batches of ``self.data_loader``.
        """
        data_loader = self.data_loader
        epoch_loss = 0
        epoch_precision = 0
        epoch_recall = 0
        count = 0
        start = time.time()
        for i, data in tqdm(enumerate(data_loader)):
            imgs, label = data
            logits = self.net(imgs)
            loss = self.criterion(logits, label)
            # set the threshold of the logits (sigmoid prob >= 0.4 -> positive)
            pred_label = torch.sigmoid(logits) >= 0.4
            numpy_label = label.cpu().detach().numpy().astype(int)
            numpy_pre_label = pred_label.cpu().detach().numpy().astype(int)
            # Sample-averaged multi-label precision/recall for this batch.
            precision = metrics.precision_score(
                numpy_label, numpy_pre_label, average='samples', zero_division=1)
            recall = metrics.recall_score(
                numpy_label, numpy_pre_label, average='samples', zero_division=1)
            self.writer.add_scalar('loss/step', loss, self.global_step)
            self.writer.add_scalar('precision/step', precision, self.global_step)
            self.writer.add_scalar('recall/step', recall, self.global_step)
            if self.global_step % 100 == 0:
                end = time.time()
                print("step {0} loss:{1:.4f} | precision:{2:.3f} | recall:{3:.3f} | time:{4:.2f}".format(
                    self.global_step, loss.item(), precision, recall, end - start))
                start = end
            epoch_precision += precision
            epoch_recall += recall
            epoch_loss += loss.item()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            count += 1
            self.global_step += 1
        # Scheduler is driven by the epoch-average training loss.
        self.lr_scheduler.step(epoch_loss / count)
        return epoch_loss / count, epoch_precision / count, epoch_recall / count

    def validate(self):
        """Evaluate on ``self.test_loader``.

        Returns ``(avg_loss, precision, recall, accuracy, f1)`` where the
        classification metrics are sample-averaged over the whole test set and
        ``accuracy`` is the per-sample tp / (tp + fp + fn) ratio (Jaccard-style
        multi-label accuracy).
        """
        epoch_loss = 0
        count = 0
        all_pred_label = []
        all_label = []
        with torch.no_grad():
            for i, data in enumerate(self.test_loader):
                imgs, label = data
                logits = self.net(imgs)
                loss = self.criterion(logits, label)
                pred_label = torch.sigmoid(logits) >= 0.4
                numpy_label = label.cpu().detach().numpy().astype(int)
                numpy_pre_label = pred_label.cpu().detach().numpy().astype(int)
                all_label.append(numpy_label)
                all_pred_label.append(numpy_pre_label)
                epoch_loss += loss.item()
                count += 1
        all_label = np.vstack(all_label)
        all_pred_label = np.vstack(all_pred_label)
        # NOTE(review): _check_set_wise_labels is a private sklearn API and may
        # break across sklearn versions — consider pinning or inlining it.
        labels = _classification._check_set_wise_labels(
            all_label, all_pred_label, labels=None, pos_label=1, average='samples')
        MCM = metrics.multilabel_confusion_matrix(
            all_label, all_pred_label, sample_weight=None, labels=labels, samplewise=True)
        tp_sum = MCM[:, 1, 1]
        fp_sum = MCM[:, 0, 1]
        fn_sum = MCM[:, 1, 0]
        # tn_sum = MCM[:, 0, 0]
        accuracy = _prf_divide(tp_sum, tp_sum + fp_sum + fn_sum, zero_division=1)
        accuracy = np.average(accuracy)
        all_precision = metrics.precision_score(
            all_label, all_pred_label, average='samples', zero_division=1)
        all_recall = metrics.recall_score(
            all_label, all_pred_label, average='samples', zero_division=1)
        all_f1_score = metrics.f1_score(
            all_label, all_pred_label, average='samples', zero_division=1)
        return epoch_loss / count, all_precision, all_recall, accuracy, all_f1_score


def _prf_divide(numerator, denominator, zero_division="warn"):
    """Perform elementwise division, handling divide-by-zero.

    Elements whose denominator is zero are set to 0.0 (when ``zero_division``
    is ``"warn"`` or ``0``) or 1.0 (otherwise). Adapted from sklearn's private
    helper of the same name; the warning machinery is intentionally dropped.

    Returns the resulting numpy array.
    """
    mask = denominator == 0.0
    denominator = denominator.copy()
    denominator[mask] = 1  # avoid infs/nans
    result = numerator / denominator
    if not np.any(mask):
        return result
    # if ``zero_division=1``, set those with denominator == 0 equal to 1
    result[mask] = 0.0 if zero_division in ["warn", 0] else 1.0
    # FIX: the original fell off the end (implicitly returning None) when
    # zero_division == "warn"; always return the computed result.
    return result