# Video-to-Audio-and-Piano/src/audeo/Video2Roll_solver.py
import os
import time

import numpy as np
import torch
from sklearn import metrics
# NOTE: _classification is a private sklearn module; its internals may
# change across sklearn versions.
from sklearn.metrics import _classification
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm


class Solver(object):
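    """Training/evaluation driver for the Video2Roll key-classification model.

    Wraps the training loop, the validation loop, TensorBoard logging,
    optional Visdom loss plotting, and per-epoch / best-model checkpointing.
    """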
def __init__(self, data_loader, test_data_loader, model, criterion, optimizer, lr_scheduler, epochs):
self.save_model_path = '/ailab-train/speech/shansizhe/audeo/models/Video2Roll_50_0.4/' # change to your path
self.test_loader = test_data_loader
self.data_loader = data_loader
self.net = model
self.criterion = criterion
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
# Training config
self.epochs = epochs
# logging
self.step = 0
self.global_step = 0
        self.writer = SummaryWriter(log_dir='/ailab-train/speech/shansizhe/audeo/log/50_0.4/')  # change to your path
# visualizing loss using visdom
        # Pre-allocate per-epoch loss history (torch.zeros, not the
        # uninitialized torch.Tensor(n) constructor).
        self.tr_loss = torch.zeros(self.epochs)
        self.val_loss = torch.zeros(self.epochs)
self.visdom = False
self.visdom_epoch = 1
self.visdom_id = 'key classification'
if self.visdom:
from visdom import Visdom
self.vis = Visdom(env=self.visdom_id)
self.vis_opts = dict(title=self.visdom_id,
ylabel='Loss', xlabel='Epoch',
legend=['train loss', 'val loss'])
self.vis_window = None
self.vis_epochs = torch.arange(1, self.epochs + 1)
def train(self):
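        """Train for ``self.epochs`` epochs, validating, logging, and
        checkpointing after every epoch."""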
        pre_val_loss = 1e4  # best (lowest) validation loss seen so far
for epoch in tqdm(range(self.epochs)):
print("Training...")
self.net.train() # Turn on BatchNorm & Dropout
start = time.time()
# training loop
tr_avg_loss, tr_avg_precision, tr_avg_recall = self.train_loop()
# evaluate
self.net.eval()
val_avg_loss, val_avg_precision, val_avg_recall, val_avg_acc, val_fscore = self.validate()
print('-' * 85)
            print('Train Summary | Epoch {0} | Time {1:.2f}s | '
                  'Train Loss {2:.3f} | Precision {3:.3f} | Recall {4:.3f}'.format(
                      epoch + 1, time.time() - start, tr_avg_loss, tr_avg_precision, tr_avg_recall))
print("epoch {0} validation loss:{1:.3f} | avg precision:{2:.3f} | avg recall:{3:.3f} | avg acc:{4:.3f} | f1 score:{5:.3f}".format(
epoch+1, val_avg_loss, val_avg_precision, val_avg_recall, val_avg_acc, val_fscore))
print('-' * 85)
# Log metrics to TensorBoard
self.writer.add_scalar('Loss/train', tr_avg_loss, epoch)
self.writer.add_scalar('Precision/train', tr_avg_precision, epoch)
self.writer.add_scalar('Recall/train', tr_avg_recall, epoch)
self.writer.add_scalar('Loss/val', val_avg_loss, epoch)
self.writer.add_scalar('Precision/val', val_avg_precision, epoch)
self.writer.add_scalar('Recall/val', val_avg_recall, epoch)
self.writer.add_scalar('Accuracy/val', val_avg_acc, epoch)
self.writer.add_scalar('F1_score/val', val_fscore, epoch)
            # Save a checkpoint every epoch, plus a separate "best" checkpoint
            # whenever the validation loss improves.
            os.makedirs(self.save_model_path, exist_ok=True)
            model_save_path = f"{self.save_model_path}{epoch}.pth"
            torch.save(self.net.state_dict(), model_save_path)
            if val_avg_loss < pre_val_loss:
                pre_val_loss = val_avg_loss
                torch.save(self.net.state_dict(), f"{self.save_model_path}best.pth")
self.val_loss[epoch] = val_avg_loss
self.tr_loss[epoch] = tr_avg_loss
# visualizing loss using visdom
if self.visdom:
x_axis = self.vis_epochs[0:epoch + 1]
# train_y_axis = self.tr_loss[0:epoch+1]
# val_x_axis = self.vis_epochs[0:epoch+1:10]
# val_y_axis = self.val_loss[0:epoch//10+1]
y_axis = torch.stack(
(self.tr_loss[0:epoch + 1], self.val_loss[0:epoch + 1]), dim=1)
if self.vis_window is None:
self.vis_window = self.vis.line(
X=x_axis,
Y=y_axis,
opts=self.vis_opts,
)
else:
self.vis.line(
X=x_axis.unsqueeze(0).expand(y_axis.size(
1), x_axis.size(0)).transpose(0, 1), # Visdom fix
Y=y_axis,
win=self.vis_window,
update='replace',
)
def train_loop(self):
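        """Run one pass over the training data.

        Returns (avg loss, avg precision, avg recall), averaged over batches.
        """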
data_loader = self.data_loader
epoch_loss = 0
epoch_precision = 0
epoch_recall = 0
count = 0
start = time.time()
for i, data in tqdm(enumerate(data_loader)):
imgs, label = data
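            # NOTE (assumption): imgs and label are expected to already be on the
            # model's device; add an explicit .to(device) here if your DataLoader
            # yields CPU tensors and the model lives on GPU.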
logits = self.net(imgs)
            loss = self.criterion(logits, label)
            # Binarize the multi-label logits with a fixed sigmoid threshold of 0.4
            # (the same threshold is used in validate()).
            pred_label = torch.sigmoid(logits) >= 0.4
numpy_label = label.cpu().detach().numpy().astype(int)
numpy_pre_label = pred_label.cpu().detach().numpy().astype(int)
            precision = metrics.precision_score(numpy_label, numpy_pre_label, average='samples', zero_division=1)
            recall = metrics.recall_score(numpy_label, numpy_pre_label, average='samples', zero_division=1)
            self.writer.add_scalar('loss/step', loss.item(), self.global_step)
self.writer.add_scalar('precision/step', precision, self.global_step)
self.writer.add_scalar('recall/step', recall, self.global_step)
if self.global_step % 100 == 0:
end = time.time()
                print("step {0} loss:{1:.4f} | precision:{2:.3f} | recall:{3:.3f} | time:{4:.2f}".format(
                    self.global_step, loss.item(), precision, recall, end - start))
start = end
epoch_precision += precision
epoch_recall += recall
epoch_loss += loss.item()
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
count += 1
self.global_step += 1
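        # NOTE (assumption): stepping with the mean epoch loss matches a
        # ReduceLROnPlateau-style scheduler; step-based schedulers take no argument.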
self.lr_scheduler.step(epoch_loss / count)
return epoch_loss/count, epoch_precision/count, epoch_recall/count
def validate(self):
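        """Evaluate on the held-out loader.

        Returns (avg loss, precision, recall, accuracy, F1), with the
        classification metrics sample-averaged over the whole set.
        """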
epoch_loss = 0
count = 0
all_pred_label = []
all_label = []
with torch.no_grad():
for i, data in enumerate(self.test_loader):
imgs, label = data
logits = self.net(imgs)
loss = self.criterion(logits, label)
pred_label = torch.sigmoid(logits) >= 0.4
numpy_label = label.cpu().detach().numpy().astype(int)
numpy_pre_label = pred_label.cpu().detach().numpy().astype(int)
all_label.append(numpy_label)
all_pred_label.append(numpy_pre_label)
epoch_loss += loss.item()
count += 1
all_label = np.vstack(all_label)
all_pred_label = np.vstack(all_pred_label)
        # sklearn's private label-checking helper plus a samplewise multilabel
        # confusion matrix give per-sample accuracy as TP / (TP + FP + FN)
        # (i.e. the Jaccard index of predicted vs. true key sets).
        labels = _classification._check_set_wise_labels(all_label, all_pred_label, labels=None, pos_label=1, average='samples')
        MCM = metrics.multilabel_confusion_matrix(all_label, all_pred_label, sample_weight=None, labels=labels, samplewise=True)
        tp_sum = MCM[:, 1, 1]
        fp_sum = MCM[:, 0, 1]
        fn_sum = MCM[:, 1, 0]
        # tn_sum = MCM[:, 0, 0]
        accuracy = _prf_divide(tp_sum, tp_sum + fp_sum + fn_sum, zero_division=1)
        accuracy = np.average(accuracy)
all_precision = metrics.precision_score(all_label, all_pred_label, average='samples', zero_division=1)
all_recall = metrics.recall_score(all_label, all_pred_label, average='samples', zero_division=1)
all_f1_score = metrics.f1_score(all_label, all_pred_label, average='samples', zero_division=1)
return epoch_loss/count, all_precision, all_recall, accuracy, all_f1_score

def _prf_divide(numerator, denominator, zero_division="warn"):
    """Perform element-wise division, handling divide-by-zero.

    On zero-division, the affected result elements are set to 0 or 1
    according to ``zero_division``. This is a trimmed-down copy of
    sklearn's private helper of the same name, minus the warning machinery.
    """
    mask = denominator == 0.0
    denominator = denominator.copy()
    denominator[mask] = 1  # avoid infs/nans
    result = numerator / denominator
    if not np.any(mask):
        return result
    # if ``zero_division=1``, set entries with a zero denominator to 1
    result[mask] = 0.0 if zero_division in ["warn", 0] else 1.0
    return result
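
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how a training script might wire up this
# Solver. The model/dataset names below are hypothetical stand-ins; the real
# repo constructs its own Video2Roll network and frame dataset.
#
#   import torch.nn as nn
#   from torch.utils.data import DataLoader
#
#   net = Video2RollNet()                        # hypothetical model class
#   train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
#   test_loader = DataLoader(test_set, batch_size=64)
#   criterion = nn.BCEWithLogitsLoss()           # multi-label logits -> BCE
#   optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
#   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
#   solver = Solver(train_loader, test_loader, net, criterion, optimizer,
#                   scheduler, epochs=50)
#   solver.train()
# ---------------------------------------------------------------------------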