# test_scripts/test_discriminator_backward.py
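"""Sanity check: calling backward() separately on the discriminator's real and
fake losses accumulates exactly the same gradients as one backward() call on
their summed loss. See the expected output at the bottom of the file.
"""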
import copy
import random
import torch
from torch import Tensor, nn


class ToyDiscriminator(nn.Module):
    """Tiny conv -> BN -> LeakyReLU discriminator with a sigmoid output."""

    def __init__(self) -> None:
        super().__init__()
        self.conv0 = nn.Conv2d(3, 4, 3, 1, 1, bias=True)
        self.bn0 = nn.BatchNorm2d(4, affine=True)
        self.conv1 = nn.Conv2d(4, 4, 3, 1, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(4, affine=True)
        # the stride-1, padding-1 3x3 convs preserve spatial size, so a
        # 6x6 input flattens to 4 * 6 * 6 features here
        self.linear = nn.Linear(4 * 6 * 6, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        feat = self.lrelu(self.bn0(self.conv0(x)))
        feat = self.lrelu(self.bn1(self.conv1(feat)))
        feat = feat.view(feat.size(0), -1)
        out = torch.sigmoid(self.linear(feat))
        return out


def main() -> None:
    # fix the random seeds for reproducibility
    manual_seed = 999
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    img_real = torch.rand((1, 3, 6, 6))
    img_fake = torch.rand((1, 3, 6, 6))

    # two identical discriminators: net_d_1 will backward twice,
    # net_d_2 once on the summed loss
    net_d_1 = ToyDiscriminator()
    net_d_2 = copy.deepcopy(net_d_1)
    net_d_1.train()
    net_d_2.train()

    criterion = nn.BCELoss()
    real_label = 1
    fake_label = 0

    for k, v in net_d_1.named_parameters():
        print(k, v.size())

    ###########################
    # (1) Backward the D network twice, as the official DCGAN tutorial does:
    # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
    ###########################
    net_d_1.zero_grad()
    # real
    output = net_d_1(img_real).view(-1)
    label = output.new_ones(output.size()) * real_label
    loss_real = criterion(output, label)
    loss_real.backward()
    # fake
    output = net_d_1(img_fake).view(-1)
    label = output.new_ones(output.size()) * fake_label
    loss_fake = criterion(output, label)
    loss_fake.backward()
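    # net_d_1's param.grad tensors now hold d(loss_real)/dw + d(loss_fake)/dw,
    # because backward() accumulates into .grad rather than overwriting it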

    ###########################
    # (2) Backward the D network once, on the summed loss
    ###########################
    net_d_2.zero_grad()
    # real
    output = net_d_2(img_real).view(-1)
    label = output.new_ones(output.size()) * real_label
    loss_real = criterion(output, label)
    # fake
    output = net_d_2(img_fake).view(-1)
    label = output.new_ones(output.size()) * fake_label
    loss_fake = criterion(output, label)
    loss = loss_real + loss_fake
    loss.backward()
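    # net_d_2 computed d(loss_real + loss_fake)/dw in a single pass; by
    # linearity of differentiation this equals the accumulated sum above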

    ###########################
    # Compare the accumulated gradients: every printed sum should be 0
    ###########################
    # strict=True is safe here because net_d_2 is a deep copy of net_d_1,
    # so both expose the same parameter list
    for k1, k2 in zip(net_d_1.parameters(), net_d_2.parameters(), strict=True):
        assert k1.grad is not None
        assert k2.grad is not None
        print(torch.sum(torch.abs(k1.grad - k2.grad)))
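
    # Extra check (a sketch added here, not part of the original script):
    # beyond printing the absolute differences, assert that the gradients
    # match within floating-point tolerance.
    for k1, k2 in zip(net_d_1.parameters(), net_d_2.parameters(), strict=True):
        assert k1.grad is not None and k2.grad is not None
        assert torch.allclose(k1.grad, k2.grad)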


if __name__ == "__main__":
    main()

r"""Output:
conv0.weight torch.Size([4, 3, 3, 3])
conv0.bias torch.Size([4])
bn0.weight torch.Size([4])
bn0.bias torch.Size([4])
conv1.weight torch.Size([4, 4, 3, 3])
conv1.bias torch.Size([4])
bn1.weight torch.Size([4])
bn1.bias torch.Size([4])
linear.weight torch.Size([1, 144])
linear.bias torch.Size([1])
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
tensor(0.)
"""