import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform

class KAE(nn.Module):
    """Tied-weight k-sparse autoencoder.

    The encoder is a single bias-free linear layer; decoding multiplies the
    sparsified code by the same weight matrix, adds a learned output bias and
    applies a sigmoid.

    Args:
        w, h, c: input width, height and channel count; the flattened input
            size is ``w * h * c``.
        nb_hidden: number of hidden units.
        nb_active: number of hidden units kept per sample (see ``decode``).
    """

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, nb_active=16):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.nb_active = nb_active
        self.encode = nn.Sequential(
            nn.Linear(w*h*c, nb_hidden, bias=False)
        )
        self.bias = nn.Parameter(torch.zeros(w*h*c))
        # The same bias tensor is also registered under ``params``; this is
        # kept so the module's state_dict keys do not change.
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        """Reconstruct ``X``; the output has the same shape as the input."""
        size = X.size()
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels-last batches, are accepted instead of raising.
        flat = X.reshape(X.size(0), -1)
        Xr, _ = self.decode(self.encode(flat))
        return Xr.view(size)

    def decode(self, h):
        """Sparsify the code ``h`` and map it back to input space.

        Each row keeps only the activations strictly greater than its
        ``(nb_active + 1)``-th largest value, so ties at the threshold can
        leave fewer than ``nb_active`` units. When ``nb_active`` is not
        smaller than the number of hidden units there is no such threshold
        and every activation is kept.

        Returns:
            ``(Xr, h)``: the sigmoid reconstruction and the sparsified code.
        """
        if self.nb_active < h.size(1):
            top_values = torch.topk(h, self.nb_active + 1, dim=1).values
            thetas = top_values[:, -1:]
            h = h * (h > thetas).float()
        # Tied weights: the encoder matrix is reused for decoding.
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        return torch.sigmoid(Xr), h


class ZAE(nn.Module):
    """Tied-weight autoencoder with a fixed activation threshold.

    The encoder is a single bias-free linear layer. Hidden activations that
    are not strictly greater than ``theta`` are zeroed, then the code is
    decoded with the same weight matrix, a learned output bias and a sigmoid.

    Args:
        w, h, c: input width, height and channel count; the flattened input
            size is ``w * h * c``.
        nb_hidden: number of hidden units.
        theta: threshold applied to the hidden activations.
    """

    def __init__(self, w=32, h=32, c=1, nb_hidden=300, theta=1):
        super().__init__()
        self.nb_hidden = nb_hidden
        self.theta = theta
        self.encode = nn.Sequential(
            nn.Linear(w*h*c, nb_hidden, bias=False)
        )
        self.bias = nn.Parameter(torch.zeros(w*h*c))
        # The same bias tensor is also registered under ``params``; this is
        # kept so the module's state_dict keys do not change.
        self.params = nn.ParameterList([self.bias])
        self.apply(_weights_init)

    def forward(self, X):
        """Reconstruct ``X``; the output has the same shape as the input."""
        size = X.size()
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels-last batches, are accepted instead of raising.
        flat = X.reshape(X.size(0), -1)
        Xr, _ = self.decode(self.encode(flat))
        return Xr.view(size)

    def decode(self, h):
        """Threshold the code ``h`` and map it back to input space.

        Returns:
            ``(Xr, h)``: the sigmoid reconstruction and the thresholded code.
        """
        h = h * (h > self.theta).float()
        # Tied weights: the encoder matrix is reused for decoding.
        Xr = torch.matmul(h, self.encode[0].weight) + self.bias
        return torch.sigmoid(Xr), h



class DenseAE(nn.Module):
    """Fully connected autoencoder with optional k-sparsity and input masking.

    Args:
        w, h, c: input width, height and channel count; the flattened input
            size is ``w * h * c``.
        encode_hidden: sizes of the encoder's Linear+ReLU layers, in order.
        decode_hidden: sizes of the decoder's Linear+ReLU layers, in order;
            a final Linear+Sigmoid maps back to ``w * h * c``.
        ksparse: if true, the code is passed through ``ksparse`` before
            decoding.
        nb_active: number of code units kept per sample when ``ksparse`` is
            enabled.
        denoise: if not ``None``, every input element is kept when a uniform
            sample in [0, 1) is ``<= denoise`` and zeroed otherwise.
            NOTE: the masking is applied on every forward pass, regardless of
            the module's train/eval mode.
    """

    def __init__(self, w=32, h=32, c=1, encode_hidden=(300,), decode_hidden=(300,), ksparse=True, nb_active=10, denoise=None):
        super().__init__()
        self.encode_hidden = encode_hidden
        self.decode_hidden = decode_hidden
        self.ksparse = ksparse
        self.nb_active = nb_active
        self.denoise = denoise

        nb_inputs = w * h * c

        # encode layers
        layers = []
        hid_prev = nb_inputs
        for hid in encode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        self.encode = nn.Sequential(*layers)

        # decode layers (start from the width of the last encoder layer)
        layers = []
        for hid in decode_hidden:
            layers.extend([
                nn.Linear(hid_prev, hid),
                nn.ReLU(True)
            ])
            hid_prev = hid
        layers.extend([
            nn.Linear(hid_prev, nb_inputs),
            nn.Sigmoid()
        ])
        self.decode = nn.Sequential(*layers)

        self.apply(_weights_init)

    def forward(self, X):
        """Reconstruct ``X``; the output has the same shape as the input."""
        size = X.size()
        if self.denoise is not None:
            # Sample the mask directly on X's device instead of on the CPU
            # followed by a host-to-device copy.
            keep = torch.rand(size, device=X.device) <= self.denoise
            X = X * keep.float()
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels-last batches, are accepted instead of raising.
        X = X.reshape(X.size(0), -1)
        h = self.encode(X)
        if self.ksparse:
            h = ksparse(h, nb_active=self.nb_active)
        Xr = self.decode(h)
        return Xr.view(size)



def ksparse(x, nb_active=10):
    """Keep the ``nb_active`` largest entries of every row and zero the rest.

    The selection runs entirely on ``x``'s device with ``torch.topk``; there
    is no per-row Python loop and no round trip through host memory.

    Args:
        x: 2-D tensor of shape ``(batch, features)``.
        nb_active: number of entries kept per row. It follows Python slice
            semantics (``sorted_desc[:nb_active]`` is kept), so values larger
            than the number of features keep everything and ``0`` keeps
            nothing.

    Returns:
        ``x`` multiplied by a 0/1 mask of the same shape. Gradients flow only
        through the kept entries. Which of several exactly equal values is
        kept is unspecified.
    """
    nb_features = x.size(1)
    # Number of entries a slice ``[:nb_active]`` would keep out of nb_features.
    nb_kept = len(range(nb_features)[:nb_active])
    mask = torch.zeros_like(x)
    if nb_kept > 0:
        _, top_inds = torch.topk(x.detach(), nb_kept, dim=1)
        mask.scatter_(1, top_inds, 1.0)
    return x * mask


class ConvAE(nn.Module):
    """Convolutional autoencoder with spatial and channel sparsity.

    The encoder stacks three 5x5 convolutions (stride 1, no padding) with
    ReLUs after the first two; the decoder is a single 13x13 transposed
    convolution followed by a sigmoid.

    Args:
        w, h: accepted but not used by this constructor.
        c: number of input (and output) channels.
        nb_filters: number of feature maps in every encoder layer.
        spatial: apply ``spatial_sparsity`` to the code.
        channel: apply ``strided_channel_sparsity`` to the code.
        channel_stride: stride forwarded to ``strided_channel_sparsity``.
    """

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride

        encoder_layers = [
            nn.Conv2d(c, nb_filters, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_filters, nb_filters, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(nb_filters, nb_filters, kernel_size=5, stride=1, padding=0),
        ]
        self.encode = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(nb_filters, c, kernel_size=13, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.decode = nn.Sequential(*decoder_layers)

        self.apply(_weights_init)

    def forward(self, X):
        """Encode ``X``, sparsify the code and decode it."""
        code = self.sparsify(self.encode(X))
        return self.decode(code)

    def sparsify(self, h):
        """Apply the enabled sparsity constraints to the code ``h``."""
        sparse = h
        if self.spatial:
            sparse = spatial_sparsity(sparse)
        if self.channel:
            sparse = strided_channel_sparsity(sparse, stride=self.channel_stride)
        return sparse

class SimpleConvAE(nn.Module):
    """Single-layer convolutional autoencoder with optional sparsity.

    The encoder is one 13x13 convolution (stride 1, no padding) followed by a
    ReLU; the decoder is one 13x13 transposed convolution followed by a
    sigmoid.

    Args:
        w, h: accepted but not used by this constructor.
        c: number of input (and output) channels.
        nb_filters: number of feature maps in the code.
        spatial: apply ``spatial_sparsity`` to the code.
        channel: apply ``strided_channel_sparsity`` to the code.
        channel_stride: stride forwarded to ``strided_channel_sparsity``.
    """

    def __init__(self, w=32, h=32, c=1, nb_filters=64, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride

        encoder_layers = [
            nn.Conv2d(c, nb_filters, kernel_size=13, stride=1, padding=0),
            nn.ReLU(inplace=True),
        ]
        self.encode = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(nb_filters, c, kernel_size=13, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.decode = nn.Sequential(*decoder_layers)

        self.apply(_weights_init)

    def forward(self, X):
        """Encode ``X``, sparsify the code and decode it."""
        code = self.sparsify(self.encode(X))
        return self.decode(code)

    def sparsify(self, h):
        """Apply the enabled sparsity constraints to the code ``h``."""
        sparse = h
        if self.spatial:
            sparse = spatial_sparsity(sparse)
        if self.channel:
            sparse = strided_channel_sparsity(sparse, stride=self.channel_stride)
        return sparse

class DeepConvAE(nn.Module):
    """Convolutional autoencoder with a configurable number of layers.

    The encoder is one 5x5 Conv+ReLU from ``c`` channels followed by
    ``nb_layers - 1`` further 5x5 Conv+ReLU pairs. The decoder is
    ``nb_layers - 1`` 5x5 ConvTranspose+ReLU pairs followed by a 5x5
    transposed convolution back to ``c`` channels and a sigmoid. All
    convolutions use stride 1 and no padding.

    Args:
        w, h: accepted but not used by this constructor.
        c: number of input (and output) channels.
        nb_filters: number of feature maps in every hidden layer.
        nb_layers: number of convolutions in the encoder (and transposed
            convolutions in the decoder).
        spatial: apply ``spatial_sparsity`` to the code.
        channel: apply ``strided_channel_sparsity`` to the code.
        channel_stride: stride forwarded to ``strided_channel_sparsity``.
    """

    def __init__(self, w=32, h=32, c=1, nb_filters=64, nb_layers=3, spatial=True, channel=True, channel_stride=4):
        super().__init__()
        self.spatial = spatial
        self.channel = channel
        self.channel_stride = channel_stride

        nb_inner = nb_layers - 1

        # Encoder: first layer maps the input channels, the rest keep width.
        encoder_layers = [nn.Conv2d(c, nb_filters, 5, 1, 0), nn.ReLU(True)]
        for _ in range(nb_inner):
            encoder_layers.append(nn.Conv2d(nb_filters, nb_filters, 5, 1, 0))
            encoder_layers.append(nn.ReLU(True))
        self.encode = nn.Sequential(*encoder_layers)

        # Decoder: mirror of the encoder, ending in a sigmoid output layer.
        decoder_layers = []
        for _ in range(nb_inner):
            decoder_layers.append(nn.ConvTranspose2d(nb_filters, nb_filters, 5, 1, 0))
            decoder_layers.append(nn.ReLU(True))
        decoder_layers.append(nn.ConvTranspose2d(nb_filters, c, 5, 1, 0))
        decoder_layers.append(nn.Sigmoid())
        self.decode = nn.Sequential(*decoder_layers)

        self.apply(_weights_init)

    def forward(self, X):
        """Encode ``X``, sparsify the code and decode it."""
        code = self.sparsify(self.encode(X))
        return self.decode(code)

    def sparsify(self, h):
        """Apply the enabled sparsity constraints to the code ``h``."""
        sparse = h
        if self.spatial:
            sparse = spatial_sparsity(sparse)
        if self.channel:
            sparse = strided_channel_sparsity(sparse, stride=self.channel_stride)
        return sparse


def spatial_sparsity(x):
    """Zero everything except the maximum of each feature map.

    The maximum is taken over dimensions 2 and 3 separately for every
    (sample, channel) pair; entries that match it (as judged by ``equals``)
    are kept, all others become zero.
    """
    peak = x.amax(dim=(2, 3), keepdim=True)
    keep = equals(x, peak)
    return x * keep

def equals(x, y, eps=1e-8):
    """Return a boolean mask that is True where ``|x - y| <= eps``."""
    gap = (x - y).abs()
    return gap <= eps

def strided_channel_sparsity(x, stride=1):
    """Keep one winner per ``stride x stride`` spatial block across channels.

    ``x`` is a 4-D tensor ``(batch, channels, height, width)``. The spatial
    plane is cut into non-overlapping ``stride x stride`` blocks and, for
    every sample and block, only the entries matching the maximum over all
    channels and all positions in the block (as judged by ``equals``) are
    kept; the rest are zeroed. The reshaping uses ``view``, so height and
    width have to be divisible by ``stride``.
    """
    batch, nb_maps, height, width = x.shape
    blocks = x.view(batch, nb_maps, height // stride, stride, width // stride, stride)
    # Reduce over channels (1) and the two intra-block axes (3, 5).
    block_max = blocks.amax(dim=(1, 3, 5), keepdim=True)
    keep = equals(blocks, block_max).view(x.shape)
    return x * keep.float()


def _weights_init(m):
    """Xavier-uniform initialise ``m.weight`` and zero ``m.bias`` when present.

    Intended for ``nn.Module.apply``: modules without a ``weight`` attribute,
    or whose ``weight`` is ``None``, are left untouched. A missing or ``None``
    bias is skipped, so modules that have a weight but no ``bias`` attribute
    no longer raise.
    """
    weight = getattr(m, 'weight', None)
    if weight is None:
        return
    # In-place initialiser; the un-suffixed ``xavier_uniform`` is deprecated.
    nn.init.xavier_uniform_(weight.data)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        bias.data.fill_(0)