import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as tvm

class ResNet18(nn.Module):
    """ResNet-18 backbone that returns a dict of intermediate feature maps.

    Wraps ``torchvision.models.resnet18``. ``forward`` returns the raw input
    plus the stem and the four residual stages, keyed by nominal stride.
    BatchNorm layers are always kept in eval mode (see ``train``).

    Args:
        pretrained: Legacy torchvision flag; only used when ``weights`` is None.
        weights: Optional torchvision weights spec; when given it takes
            precedence over ``pretrained`` (same convention as the other
            backbones in this module).
    """

    def __init__(self, pretrained=False, weights=None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.resnet18(weights=weights)
        else:
            self.net = tvm.resnet18(pretrained=pretrained)

    def forward(self, x):
        """Run the backbone and collect intermediate activations.

        Args:
            x: Input batch; presumably (B, 3, H, W) images — TODO confirm
                with callers.

        Returns:
            dict ``{32, 16, 8, 4, 2, 1} -> tensor`` (in that insertion order).
            Key 1 is ``x`` itself; key 2 is the stem output after ReLU and
            before max-pooling; keys 4..32 are the outputs of layer1..layer4.
        """
        net = self.net
        x1 = x
        x2 = net.relu(net.bn1(net.conv1(x1)))
        x4 = net.layer1(net.maxpool(x2))
        x8 = net.layer2(x4)
        x16 = net.layer3(x8)
        x32 = net.layer4(x16)
        # Insertion order (coarse -> fine) kept identical to the historical
        # behaviour in case callers iterate over the dict.
        return {32: x32, 16: x16, 8: x8, 4: x4, 2: x2, 1: x1}

    def train(self, mode=True):
        """Set training mode but keep every BatchNorm2d in eval mode.

        BatchNorm layers in eval mode normalise with their running statistics
        and do not update them. Returns ``self``, matching the
        ``nn.Module.train`` contract, so ``model.eval()`` (which delegates to
        ``train(False)``) and chained calls keep returning the module.
        """
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
        return self

class ResNet50(nn.Module):
    """ResNet-50 backbone returning a dict of intermediate feature maps.

    Wraps ``torchvision.models.resnet50`` with its classification head
    (``fc``) removed. ``forward`` returns
    ``{1: input, 2: stem, 4: layer1, 8: layer2, 16: layer3, 32: layer4}``.
    Keys are nominal strides; when ``dilation`` replaces a stride, the
    corresponding stages keep a higher resolution than the key suggests.

    Args:
        pretrained: Legacy torchvision flag; only used when ``weights`` is None.
        high_res: Stored as ``self.high_res``; this class's ``forward`` does
            not read it.
        weights: torchvision weights spec; takes precedence over ``pretrained``.
        dilation: ``replace_stride_with_dilation`` list for layer2..layer4;
            defaults to ``[False, False, False]``.
        freeze_bn: If True, ``train()`` keeps every BatchNorm2d in eval mode.
        anti_aliased: Not supported; passing True raises NotImplementedError.

    Raises:
        NotImplementedError: if ``anti_aliased`` is True.
    """

    def __init__(self, pretrained=False, high_res=False, weights=None, dilation=None, freeze_bn=True, anti_aliased=False) -> None:
        super().__init__()
        if dilation is None:
            # Fresh list per instance (avoids a shared mutable default).
            dilation = [False, False, False]
        if anti_aliased:
            # There is no anti-aliased variant; without this, self.net would
            # never be assigned and `del self.net.fc` below would fail with an
            # unrelated-looking AttributeError.
            raise NotImplementedError("anti_aliased ResNet50 is not implemented")
        if weights is not None:
            self.net = tvm.resnet50(weights=weights, replace_stride_with_dilation=dilation)
        else:
            self.net = tvm.resnet50(pretrained=pretrained, replace_stride_with_dilation=dilation)

        # forward() never uses the classification head.
        del self.net.fc
        self.high_res = high_res
        self.freeze_bn = freeze_bn

    def forward(self, x):
        """Run the backbone and collect the input plus five intermediate maps.

        Args:
            x: Input batch; presumably (B, 3, H, W) images — TODO confirm
                with callers.

        Returns:
            dict mapping nominal stride (1, 2, 4, 8, 16, 32) to a tensor.
            Key 1 is ``x`` itself; key 2 is the stem output after ReLU and
            before max-pooling.
        """
        net = self.net
        feats = {1: x}
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = x
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = x
        x = net.layer2(x)
        feats[8] = x
        x = net.layer3(x)
        feats[16] = x
        x = net.layer4(x)
        feats[32] = x
        return feats

    def train(self, mode=True):
        """Set training mode, keeping BatchNorm2d layers in eval mode if ``freeze_bn``.

        BatchNorm layers in eval mode normalise with their running statistics
        and do not update them. Returns ``self``, matching the
        ``nn.Module.train`` contract, so ``model.eval()`` (which delegates to
        ``train(False)``) and chained calls keep returning the module.
        """
        super().train(mode)
        if self.freeze_bn:
            for m in self.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        return self




class ResNet101(nn.Module):
    def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.resnet101(weights = weights)
        else:
            self.net = tvm.resnet101(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5
    def forward(self, x):
        net = self.net
        feats = {1:x}
        sf = self.scale_factor
        if self.high_res:
            x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.layer2(x)
        feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.layer3(x)
        feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.layer4(x)
        feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        return feats

    def train(self, mode=True):
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
            pass


class WideResNet50(nn.Module):
    def __init__(self, pretrained=False, high_res = False, weights = None) -> None:
        super().__init__()
        if weights is not None:
            self.net = tvm.wide_resnet50_2(weights = weights)
        else:
            self.net = tvm.wide_resnet50_2(pretrained=pretrained)
        self.high_res = high_res
        self.scale_factor = 1 if not high_res else 1.5
    def forward(self, x):
        net = self.net
        feats = {1:x}
        sf = self.scale_factor
        if self.high_res:
            x = F.interpolate(x, scale_factor=sf, align_corners=False, mode="bicubic")
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        feats[2] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.maxpool(x)
        x = net.layer1(x)
        feats[4] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.layer2(x)
        feats[8] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.layer3(x)
        feats[16] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        x = net.layer4(x)
        feats[32] = x if not self.high_res else F.interpolate(x,scale_factor=1/sf,align_corners=False, mode="bilinear")    
        return feats

    def train(self, mode=True):
        super().train(mode)
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
            pass