Video-to-Audio-and-Piano / src /audeo /Roll2MidiNet_enhance.py
lshzhm's picture
Upload 141 files
1991049 verified
raw
history blame
5.34 kB
import torch.nn as nn
import torch.nn.functional as F
import torch
##############################
# U-NET
##############################
class UNetDown(nn.Module):
    """Encoder block of the U-Net.

    Layer order: 3x3 convolution (stride 1, padding 1, no bias) ->
    optional BatchNorm2d -> LeakyReLU(0.2) -> optional Dropout.
    With kernel 3 / stride 1 / padding 1 the spatial size of the input is
    preserved; only the channel count changes from ``in_size`` to
    ``out_size``.

    Args:
        in_size: number of input channels.
        out_size: number of output channels.
        normalize: when true, insert a BatchNorm2d after the convolution.
        dropout: dropout probability; a falsy value (0.0) adds no Dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_size, out_size, kernel_size=3, stride=1, padding=1, bias=False
            )
        ]
        if normalize:
            # NOTE(review): the second positional argument of BatchNorm2d is
            # `eps`, so the original `BatchNorm2d(out_size, 0.8)` sets
            # eps=0.8 (momentum stays at its default). Presumably momentum
            # was intended -- confirm before changing, because a different
            # eps alters the outputs of already-trained weights.
            layers.append(nn.BatchNorm2d(out_size, eps=0.8))
        layers.append(nn.LeakyReLU(negative_slope=0.2))
        if dropout:
            layers.append(nn.Dropout(p=dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the block and return the result."""
        return self.model(x)
class UNetUp(nn.Module):
    """Decoder block of the U-Net with a skip connection.

    Layer order: 3x3 transposed convolution (stride 1, padding 1, no bias)
    -> BatchNorm2d -> ReLU -> optional Dropout. The result is concatenated
    with ``skip_input`` along the channel dimension (dim 1), so the output
    has ``out_size`` plus the skip tensor's channels.

    Args:
        in_size: number of input channels.
        out_size: number of channels produced before concatenation.
        dropout: dropout probability; a falsy value (0.0) adds no Dropout.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(
                in_size, out_size, kernel_size=3, stride=1, padding=1, bias=False
            ),
            # NOTE(review): 0.8 is BatchNorm2d's `eps` (its second positional
            # argument), not the momentum. Kept as-is to preserve behaviour;
            # confirm the intent before changing it.
            nn.BatchNorm2d(out_size, eps=0.8),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            layers.append(nn.Dropout(p=dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip_input):
        """Transform ``x`` and append ``skip_input`` on the channel axis."""
        upsampled = self.model(x)
        return torch.cat((upsampled, skip_input), dim=1)
class AttentionGate(nn.Module):
    """Additive attention gate that rescales ``x`` by a per-pixel weight.

    ``x`` and the gating tensor ``g`` are each projected to ``out_channels``
    with a 1x1 convolution, summed, reduced to a single channel by another
    1x1 convolution and squashed with a sigmoid. The resulting one-channel
    map is broadcast-multiplied with ``x``.

    Args:
        in_channels: channels of ``x``.
        g_channels: channels of the gating tensor ``g``.
        out_channels: channels of the intermediate projection.
    """

    def __init__(self, in_channels, g_channels, out_channels):
        super().__init__()
        self.theta_x = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.phi_g = nn.Conv2d(g_channels, out_channels, kernel_size=1)
        self.psi = nn.Conv2d(out_channels, 1, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, g):
        """Return ``x`` scaled by the attention map computed from ``x`` and ``g``."""
        # The two projections are added, so x and g must have matching
        # spatial dimensions.
        combined = self.theta_x(x) + self.phi_g(g)
        attention = self.sigmoid(self.psi(combined))
        return x * attention
class Generator(nn.Module):
    """U-Net generator with attention-gated skip connections.

    Six ``UNetDown`` blocks widen the channels (64 -> 128 -> 256 -> 512 ->
    1024 -> 1024). Five ``UNetUp`` blocks then narrow them again, each one
    concatenating the matching encoder feature map; after the first four,
    an ``AttentionGate`` rescales the concatenated tensor using that same
    encoder feature map as gating signal. A final 1x1 convolution followed
    by a sigmoid yields a single-channel output with values in (0, 1).

    Args:
        input_shape: ``(channels, height, width)``; only ``channels`` is
            used here.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels = input_shape[0]

        # Encoder.
        self.down1 = UNetDown(channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256, dropout=0.5)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 1024, dropout=0.5)
        self.down6 = UNetDown(1024, 1024, dropout=0.5)

        # Attention gates; the first argument is the channel count of the
        # concatenated decoder tensor, the second that of the encoder skip.
        self.att1 = AttentionGate(2048, 1024, 512)
        self.att2 = AttentionGate(1024, 512, 256)
        self.att3 = AttentionGate(512, 256, 128)
        self.att4 = AttentionGate(256, 128, 64)

        # Decoder; each input width is the previous block's output plus the
        # concatenated skip channels.
        self.up1 = UNetUp(1024, 1024, dropout=0.5)
        self.up2 = UNetUp(2048, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 256, dropout=0.5)
        self.up4 = UNetUp(512, 128)
        self.up5 = UNetUp(256, 64)

        # Despite the attribute name this is a 2-D 1x1 convolution; the name
        # is kept so that existing state_dict keys still match.
        self.conv1d = nn.Conv2d(128, 1, kernel_size=1)

    def forward(self, x):
        """Map ``x`` to a single-channel tensor with values in (0, 1)."""
        # Encoder path; every intermediate map is kept for the skips.
        e1 = self.down1(x)
        e2 = self.down2(e1)
        e3 = self.down3(e2)
        e4 = self.down4(e3)
        e5 = self.down5(e4)
        e6 = self.down6(e5)

        # Decoder path: upsampling block + skip concat, then attention gate.
        y = self.att1(self.up1(e6, e5), e5)
        y = self.att2(self.up2(y, e4), e4)
        y = self.att3(self.up3(y, e3), e3)
        y = self.att4(self.up4(y, e2), e2)
        y = self.up5(y, e1)  # last stage has no attention gate

        return torch.sigmoid(self.conv1d(y))
class Discriminator(nn.Module):
    """PatchGAN-style convolutional discriminator.

    Four blocks of 3x3 convolution (padding 1) -> optional InstanceNorm2d ->
    LeakyReLU(0.2) with strides 2, 2, 2, 1 and widths 64, 128, 256, 512,
    followed by a 3x3 stride-1 convolution down to one channel. The output
    is a grid of per-patch scores rather than a single scalar.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.

    Attributes:
        output_shape: ``(1, patch_h, patch_w)``, the shape of one sample's
            output from ``forward`` (batch dimension excluded).
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape  # e.g. (1, 51, 50)

        kernel_size, padding = 3, 1
        # (out_filters, stride, normalize) for each block, in order.
        block_specs = [
            (64, 2, False),
            (128, 2, True),
            (256, 2, True),
            (512, 1, True),
        ]

        def conv_out(size, stride):
            """Output length of one convolution along a single dimension."""
            return (size + 2 * padding - kernel_size) // stride + 1

        # Derive the patch grid from the same specs that build the layers.
        # The previous closed form, int(size / 8) + 1, over-counted by one
        # whenever a dimension was an exact multiple of 8.
        patch_h, patch_w = height, width
        for _, stride, _ in block_specs:
            patch_h = conv_out(patch_h, stride)
            patch_w = conv_out(patch_w, stride)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, stride, normalize):
            """Return the layers making up one discriminator block."""
            block = [nn.Conv2d(in_filters, out_filters, kernel_size, stride, padding)]
            if normalize:
                block.append(nn.InstanceNorm2d(out_filters))
            block.append(nn.LeakyReLU(0.2, inplace=True))
            return block

        layers = []
        in_filters = channels
        for out_filters, stride, normalize in block_specs:
            layers.extend(
                discriminator_block(in_filters, out_filters, stride, normalize)
            )
            in_filters = out_filters

        # Final stride-1 projection to one score per patch; it keeps the
        # spatial size, so it does not affect output_shape.
        layers.append(nn.Conv2d(in_filters, 1, kernel_size, 1, padding))
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the patch score map for ``img``."""
        return self.model(img)
def weights_init_normal(m):
    """Initialise a module's parameters in place, keyed on its class name.

    * Class name contains ``"Conv"``: weights drawn from N(0.0, 0.02);
      the bias, if any, is left untouched.
    * Class name contains ``"BatchNorm2d"``: weights drawn from
      N(1.0, 0.02) and bias set to 0.
    * Anything else is left unchanged.

    The single-argument signature fits ``nn.Module.apply``.
    """
    kind = type(m).__name__
    if "Conv" in kind:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if "BatchNorm2d" in kind:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, val=0.0)
if __name__ == "__main__":
    # Smoke test: build both networks, push one random batch through the
    # generator and then the discriminator, and print the resulting shapes.
    input_shape = (1, 51, 100)
    batch_size = 64

    gnet = Generator(input_shape)
    dnet = Discriminator(input_shape)
    print(dnet.output_shape)

    # Random batch laid out as (batch, channels, height, width).
    imgs = torch.rand((batch_size, *input_shape))
    gen = gnet(imgs)
    print(gen.shape)

    dis = dnet(gen)
    print(dis.shape)