|
import copy
|
|
import random
|
|
|
|
import torch
|
|
from torch import Tensor, nn
|
|
|
|
|
|
class ToyDiscriminator(nn.Module):
    """Minimal conv + BatchNorm discriminator with a sigmoid output.

    Two 3x3 convolutions (stride 1, padding 1, so the spatial size is
    preserved), each followed by BatchNorm and LeakyReLU, then a linear head
    producing one logit per sample, squashed by a sigmoid.

    Args:
        in_channels: Number of channels of the input image. Defaults to 3.
        num_feat: Number of channels produced by each conv layer. Defaults to 4.
        img_size: Height and width of the square input. The linear head is
            sized for ``num_feat * img_size * img_size`` features. Defaults to 6.

    With the default arguments the layers, their parameter shapes, and their
    creation order are identical to the original fixed 3 -> 4 -> 4 channel,
    6x6 network.
    """

    def __init__(self, in_channels: int = 3, num_feat: int = 4, img_size: int = 6) -> None:
        super().__init__()
        # Layers are created in the original order so that, under a fixed seed,
        # parameter initialization consumes the RNG exactly as before.
        self.conv0 = nn.Conv2d(in_channels, num_feat, 3, 1, 1, bias=True)
        self.bn0 = nn.BatchNorm2d(num_feat, affine=True)
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.bn1 = nn.BatchNorm2d(num_feat, affine=True)
        self.linear = nn.Linear(num_feat * img_size * img_size, 1)
        # Applied in place to the BatchNorm outputs in forward().
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return a per-sample score in (0, 1) with shape ``(N, 1)``.

        Args:
            x: Input batch of shape ``(N, in_channels, img_size, img_size)``.
        """
        feat = self.lrelu(self.bn0(self.conv0(x)))
        feat = self.lrelu(self.bn1(self.conv1(feat)))
        # flatten (unlike view) also accepts non-contiguous tensors.
        feat = torch.flatten(feat, start_dim=1)
        return torch.sigmoid(self.linear(feat))
|
|
|
|
|
|
def _bce_loss(net: nn.Module, img: Tensor, target: float, criterion: nn.Module) -> Tensor:
    """Run ``net`` on ``img`` and return ``criterion`` against a constant ``target`` label."""
    output = net(img).view(-1)
    label = output.new_full(output.size(), float(target))
    return criterion(output, label)


def main() -> None:
    """Compare two ways of computing discriminator gradients.

    Two identical copies of ``ToyDiscriminator`` are put in train mode and
    fed the same real and fake images:

    * ``net_d_1`` calls ``.backward()`` separately on the real loss and on the
      fake loss, so the gradients accumulate across two backward passes.
    * ``net_d_2`` calls ``.backward()`` once on the sum of both losses.

    For every parameter pair, the summed absolute difference between the two
    gradients is printed.
    """
    manual_seed = 999
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)

    img_real = torch.rand((1, 3, 6, 6))
    img_fake = torch.rand((1, 3, 6, 6))

    # deepcopy guarantees both networks start from identical weights.
    net_d_1 = ToyDiscriminator()
    net_d_2 = copy.deepcopy(net_d_1)
    net_d_1.train()
    net_d_2.train()

    criterion = nn.BCELoss()
    real_label = 1
    fake_label = 0

    for name, param in net_d_1.named_parameters():
        print(name, param.size())

    # Variant 1: one backward per loss; gradients accumulate in .grad.
    net_d_1.zero_grad()
    _bce_loss(net_d_1, img_real, real_label, criterion).backward()
    _bce_loss(net_d_1, img_fake, fake_label, criterion).backward()

    # Variant 2: a single backward on the summed loss.
    net_d_2.zero_grad()
    loss_real = _bce_loss(net_d_2, img_real, real_label, criterion)
    loss_fake = _bce_loss(net_d_2, img_fake, fake_label, criterion)
    (loss_real + loss_fake).backward()

    # strict=True: the two parameter lists must align one-to-one; a length
    # mismatch would otherwise silently drop parameters from the comparison.
    for param_1, param_2 in zip(net_d_1.parameters(), net_d_2.parameters(), strict=True):
        assert param_1.grad is not None
        assert param_2.grad is not None
        print(torch.sum(torch.abs(param_1.grad - param_2.grad)))
|
|
|
|
|
|
# Run the gradient comparison only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|
r"""Output:
|
|
conv0.weight torch.Size([4, 3, 3, 3])
|
|
conv0.bias torch.Size([4])
|
|
bn0.weight torch.Size([4])
|
|
bn0.bias torch.Size([4])
|
|
conv1.weight torch.Size([4, 4, 3, 3])
|
|
conv1.bias torch.Size([4])
|
|
bn1.weight torch.Size([4])
|
|
bn1.bias torch.Size([4])
|
|
linear.weight torch.Size([1, 144])
|
|
linear.bias torch.Size([1])
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
tensor(0.)
|
|
"""
|
|
|