import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.utils import spectral_norm


class SpectralConv2d(nn.Module):
    """A ``nn.Conv2d`` wrapped with ``torch.nn.utils.spectral_norm``.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        # spectral_norm registers its hooks on ``conv`` and returns the
        # same module, which is stored under ``_conv``.
        self._conv = spectral_norm(conv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the spectrally-normalized convolution to ``input``."""
        normalized_conv = self._conv
        return normalized_conv(input)


class SpectralConvTranspose2d(nn.Module):
    """A ``nn.ConvTranspose2d`` wrapped with ``torch.nn.utils.spectral_norm``.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.ConvTranspose2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        deconv = nn.ConvTranspose2d(*args, **kwargs)
        # spectral_norm registers its hooks on ``deconv`` and returns the
        # same module, which is stored under ``_conv``.
        self._conv = spectral_norm(deconv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the spectrally-normalized transposed convolution."""
        normalized_deconv = self._conv
        return normalized_deconv(input)


class Noise(nn.Module):
    """Add Gaussian noise scaled by a single learnable weight.

    One noise map of shape ``(batch, 1, height, width)`` is drawn per call
    and broadcast over the channel dimension, so every channel of a sample
    receives the same perturbation. The weight starts at zero, which makes
    the module an identity mapping until the weight is trained.
    """

    def __init__(self):
        super().__init__()
        self._weight = nn.Parameter(
            torch.zeros(1),
            requires_grad=True,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` plus weighted standard-normal noise.

        Args:
            input: 4-D tensor unpacked as ``(batch, channels, height, width)``.

        Returns:
            Tensor with the same shape as ``input``.
        """
        batch_size, _, height, width = input.shape
        # Sample in the input's dtype as well as on its device: noise drawn
        # in the default dtype would promote a reduced-precision (e.g.
        # half / bfloat16) computation to float32.
        noise = torch.randn(
            batch_size, 1, height, width,
            dtype=input.dtype,
            device=input.device,
        )
        return self._weight * noise + input


class InitLayer(nn.Module):
    """First generator stage: transposed conv, batch norm, then GLU.

    The transposed convolution uses a 4x4 kernel with stride 1 and no
    padding, so each spatial dimension grows by 3 (a 1x1 input becomes
    4x4).
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count produced after the GLU.
        """
        super().__init__()

        # GLU over dim=1 halves the channel count, so the convolution has
        # to emit twice as many channels as the block should output.
        gated_channels = out_channels * 2

        deconv = SpectralConvTranspose2d(
            in_channels=in_channels,
            out_channels=gated_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=gated_channels)
        gate = nn.GLU(dim=1)

        self._layers = nn.Sequential(deconv, norm, gate)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``input`` through the conv / norm / GLU stack."""
        stack = self._layers
        return stack(input)


class SLEBlock(nn.Module):
    """Gate one feature map with per-channel weights derived from another.

    ``low_dim`` is average-pooled to 4x4, reduced to 1x1 by a 4x4
    convolution, passed through SiLU, a 1x1 convolution and a sigmoid.
    The resulting values multiply ``high_dim``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channel count of ``low_dim``.
            out_channels: Channel count of the gate; it multiplies
                ``high_dim``, so it must be broadcast-compatible with it.
        """
        super().__init__()

        pool = nn.AdaptiveAvgPool2d(output_size=4)
        # A 4x4 kernel without padding on the 4x4 pooled map yields 1x1.
        squeeze = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
        )
        activation = nn.SiLU()
        excite = SpectralConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        to_gate = nn.Sigmoid()

        self._layers = nn.Sequential(pool, squeeze, activation, excite, to_gate)

    def forward(self, low_dim: torch.Tensor,
                high_dim: torch.Tensor) -> torch.Tensor:
        """Return ``high_dim`` scaled by the gate computed from ``low_dim``."""
        gate = self._layers(low_dim)
        return high_dim * gate


class UpsampleBlockT1(nn.Module):
    """Single-convolution upsampling block.

    Nearest-neighbour upsampling by a factor of 2, followed by a 3x3
    spectrally-normalized convolution with ``'same'`` padding, batch norm
    and a GLU over the channel dimension.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count produced after the GLU.
        """
        super().__init__()

        # GLU over dim=1 halves the channel count, so the convolution has
        # to emit twice as many channels as the block should output.
        gated_channels = out_channels * 2

        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=gated_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=gated_channels)
        gate = nn.GLU(dim=1)

        self._layers = nn.Sequential(upsample, conv, norm, gate)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Upsample ``input`` and return the gated feature map."""
        stack = self._layers
        return stack(input)


class UpsampleBlockT2(nn.Module):
    """Two-convolution upsampling block with noise injection.

    Nearest-neighbour upsampling by a factor of 2, followed by two rounds
    of: 3x3 spectrally-normalized convolution (``'same'`` padding),
    ``Noise``, batch norm and a GLU over the channel dimension.
    """

    def __init__(self, in_channels: int,
                       out_channels: int):
        """
        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count produced after each GLU.
        """
        super().__init__()

        # GLU over dim=1 halves the channel count, so each convolution has
        # to emit twice as many channels as the stage should output.
        gated_channels = out_channels * 2

        self._layers = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                SpectralConv2d(
                        in_channels=in_channels,
                        out_channels=gated_channels,
                        kernel_size=3,
                        stride=1,
                        padding='same',
                        bias=False,
                    ),
                Noise(),
                # Referenced through ``nn`` like every other batch-norm
                # layer in this file (same class as the bare import).
                nn.BatchNorm2d(num_features=gated_channels),
                nn.GLU(dim=1),
                SpectralConv2d(
                        in_channels=out_channels,
                        out_channels=gated_channels,
                        kernel_size=3,
                        stride=1,
                        padding='same',
                        bias=False,
                    ),
                Noise(),
                nn.BatchNorm2d(num_features=gated_channels),
                nn.GLU(dim=1),
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Upsample ``input`` and return the gated feature map."""
        return self._layers(input)


class DownsampleBlockT1(nn.Module):
    """Single-convolution downsampling block.

    A 4x4 spectrally-normalized convolution with stride 2 and padding 1
    halves each spatial dimension; it is followed by batch norm and
    LeakyReLU with a negative slope of 0.2.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count of the returned tensor.
        """
        super().__init__()

        strided_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2)

        self._layers = nn.Sequential(strided_conv, norm, activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Downsample ``input`` and return the activated feature map."""
        stack = self._layers
        return stack(input)


class DownsampleBlockT2(nn.Module):
    """Two-branch downsampling block whose branch outputs are averaged.

    * ``_layers_1``: strided 4x4 convolution, then a 3x3 convolution, each
      followed by batch norm and LeakyReLU(0.2).
    * ``_layers_2``: 2x2 average pooling, then a 1x1 convolution, batch
      norm and LeakyReLU(0.2).

    Both branches halve each spatial dimension and emit ``out_channels``
    channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channel count of the incoming tensor.
            out_channels: Channel count of the returned tensor.
        """
        super().__init__()

        # Convolutional branch (built first, as before, so parameter
        # initialization order is unchanged).
        strided_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        strided_norm = nn.BatchNorm2d(num_features=out_channels)
        strided_act = nn.LeakyReLU(negative_slope=0.2)
        refine_conv = SpectralConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        refine_norm = nn.BatchNorm2d(num_features=out_channels)
        refine_act = nn.LeakyReLU(negative_slope=0.2)
        self._layers_1 = nn.Sequential(
            strided_conv, strided_norm, strided_act,
            refine_conv, refine_norm, refine_act,
        )

        # Pooling branch.
        pool = nn.AvgPool2d(kernel_size=2, stride=2)
        pointwise_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        pointwise_norm = nn.BatchNorm2d(num_features=out_channels)
        pointwise_act = nn.LeakyReLU(negative_slope=0.2)
        self._layers_2 = nn.Sequential(
            pool, pointwise_conv, pointwise_norm, pointwise_act,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the mean of the two branch outputs for ``input``."""
        conv_branch = self._layers_1(input)
        pool_branch = self._layers_2(input)
        branch_sum = conv_branch + pool_branch
        return branch_sum / 2


class Decoder(nn.Module):
    """Decode a feature map into a ``tanh``-bounded output tensor.

    The input is adaptively average-pooled to 8x8, passed through four
    ``UpsampleBlockT1`` stages (each doubling the spatial size, giving
    128x128), then through a 3x3 spectrally-normalized convolution and
    ``tanh``.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Args:
            in_channels: Channel count of the incoming feature map.
            out_channels: Channel count of the decoded output.
        """
        super().__init__()

        # Feature width used at each spatial resolution.
        self._channels = {
            16: 128,
            32: 64,
            64: 64,
            128: 32,
            256: 16,
            512: 8,
            1024: 4,
        }

        # Channel widths along the upsampling path: the input width followed
        # by the widths at 16, 32, 64 and 128 pixels.
        resolutions = (16, 32, 64, 128)
        widths = [in_channels] + [self._channels[res] for res in resolutions]

        stages = [nn.AdaptiveAvgPool2d(output_size=8)]
        for src_width, dst_width in zip(widths[:-1], widths[1:]):
            stages.append(
                UpsampleBlockT1(in_channels=src_width, out_channels=dst_width)
            )
        stages.append(
            SpectralConv2d(
                in_channels=widths[-1],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            )
        )
        stages.append(nn.Tanh())

        self._layers = nn.Sequential(*stages)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Decode ``input`` into a tensor with values in ``[-1, 1]``."""
        stack = self._layers
        return stack(input)