# Attention modules: Shuffle (SA), CBAM, GAM, ECA, SE, SK, LSK and SOCA,
# plus RepNCSPELAN4-style wrappers that use them as bottlenecks.
import math
from collections import OrderedDict

import torch
import torch.nn as nn

from models.common import *  # Conv, RepNCSP, RepNCSPELAN4, ...
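

# The RepN* blocks below all follow the same CSP-style recipe as RepNCSP: the
# input is projected into two c_-channel branches, one branch passes through n
# attention bottlenecks (CBAM / SA / LSK / ECA), and the branches are
# concatenated and fused back to c2 channels by a 1x1 conv.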
class RepNCBAM(nn.Module):
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1)
self.m = nn.Sequential(*(CBAMBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
class RepNSA(nn.Module):
def __init__(self, c1, c2, n=1, shortcut=True, g=16, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1)
self.m = nn.Sequential(*(SABottleneck(c_, c_, 1, shortcut, g=g) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
class RepNLSK(nn.Module):
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1)
self.m = nn.Sequential(*(LSKBottleneck(c_, c_, 1, shortcut, g=g) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
class RepNECA(nn.Module):
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1)
self.m = nn.Sequential(*(ECABottleneck(c_, c_, shortcut, g=g) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
# ----------------------- Attention Mechanism ---------------------------
## CBAM ATTENTION
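# CBAM (Convolutional Block Attention Module): channel attention from an MLP
# over avg- and max-pooled descriptors, followed by spatial attention from a
# small conv over the channel-wise avg/max maps; each stage produces a sigmoid
# gate that rescales the feature map.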
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
self.act = nn.SiLU()
self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.f2(self.act(self.f1(self.avg_pool(x))))
max_out = self.f2(self.act(self.f1(self.max_pool(x))))
out = self.sigmoid(avg_out + max_out)
return out
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=3):
super().__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# 1*h*w
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
#2*h*w
x = self.conv(x)
#1*h*w
return self.sigmoid(x)
class CBAMBottleneck(nn.Module):
def __init__(self,
c1,
c2,
shortcut=True,
g=1,
e=0.5,
ratio=16,
kernel_size=3): # ch_in, ch_out, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.add = shortcut and c1 == c2
self.channel_attention = ChannelAttention(c2, ratio)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x):
x1 = self.cv2(self.cv1(x))
out = self.channel_attention(x1) * x1
# print('outchannels:{}'.format(out.shape))
out = self.spatial_attention(out) * out
return x + out if self.add else out
class CBAMC4(nn.Module):
def __init__(self, c1, c2, c3, c4, c5=1):
super(CBAMC4, self).__init__()
self.c = c3 // 2
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = nn.Sequential(RepNCSP(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
        self.channel_attention = ChannelAttention(c3 + (2 * c4))  # attention runs on the concatenated branches, not on c2
self.spatial_attention = SpatialAttention(kernel_size=3) # Specify kernel_size here
def forward(self, x):
y = list(self.cv1(x).chunk(2, 1))
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
y = torch.cat(y, 1)
# Apply channel attention
y = y * self.channel_attention(y)
# Apply spatial attention
y = y * self.spatial_attention(y)
return self.cv4(y)
def forward_split(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
y = torch.cat(y, 1)
# Apply channel attention
y = y * self.channel_attention(y)
# Apply spatial attention
y = y * self.spatial_attention(y)
return self.cv4(y)
class RepNCBAMELAN4(RepNCSPELAN4):
    # RepNCSPELAN4 variant whose cv2/cv3 branches use RepNCBAM (CBAM bottlenecks)
def __init__(self, c1, c2, c3, c4, c5=1):
super().__init__(c1, c2, c3, c4, c5)
self.cv2 = nn.Sequential(RepNCBAM(c3//2, c4, c5), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepNCBAM(c4, c4, c5), Conv(c4, c4, 3, 1))
# c_ = int(c2 * e) # hidden channels
# self.m = nn.Sequential(*(RepCBAM(c_, c_, shortcut) for _ in range(n)))
## GAM ATTENTION
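# GAM (Global Attention Mechanism): a channel MLP applied per spatial position
# (via permute), then a two-layer 7x7 convolutional spatial attention; the
# spatial gate is channel-shuffled before it rescales the features.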
class GAMAttention(nn.Module):
#https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAMAttention, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group
            else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group
            else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
            nn.BatchNorm2d(c2)
        )
def forward(self, x):
b, c, h, w = x.shape
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
x_channel_att = x_att_permute.permute(0, 3, 1, 2)
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
out = x * x_spatial_att
return out
def channel_shuffle(x, groups=2):
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    out = out.view(B, C, H, W)
    return out
## SK ATTENTION
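# SK (Selective Kernel) attention: convolve the input with several kernel
# sizes, fuse the branches by summation, squeeze to a d-dimensional channel
# descriptor, then use a softmax over per-branch channel weights to re-weight
# and sum the branch outputs.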
class SKAttention(nn.Module):
def __init__(self, channel=512,out_channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
super().__init__()
self.d=max(L,channel//reduction)
self.convs=nn.ModuleList([])
for k in kernels:
self.convs.append(
nn.Sequential(OrderedDict([
('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
('bn',nn.BatchNorm2d(channel)),
('relu',nn.ReLU())
]))
)
self.fc=nn.Linear(channel,self.d)
self.fcs=nn.ModuleList([])
for i in range(len(kernels)):
self.fcs.append(nn.Linear(self.d,channel))
self.softmax=nn.Softmax(dim=0)
def forward(self, x):
bs, c, _, _ = x.size()
conv_outs=[]
### split
for conv in self.convs:
conv_outs.append(conv(x))
feats=torch.stack(conv_outs,0)#k,bs,channel,h,w
### fuse
U=sum(conv_outs) #bs,c,h,w
### reduction channel
S=U.mean(-1).mean(-1) #bs,c
Z=self.fc(S) #bs,d
### calculate attention weight
weights=[]
for fc in self.fcs:
weight=fc(Z)
weights.append(weight.view(bs,c,1,1)) #bs,channel
        attention_weights = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weights = self.softmax(attention_weights)  # k,bs,channel,1,1
        ### fuse
        V = (attention_weights * feats).sum(0)
return V
## SHUFFLE ATTENTION
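# Shuffle Attention (SA-Net): split the channels into groups, apply channel
# attention (global average pooling + learned scale/bias) to half of each group
# and spatial attention (GroupNorm + learned scale/bias) to the other half,
# then recombine with a channel shuffle.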
from torch.nn.parameter import Parameter
from torch.nn import init
class sa_layer(nn.Module):
"""Constructs a Channel Spatial Group module.
Args:
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel, groups=16):
super(sa_layer, self).__init__()
self.groups = groups
self.channel = channel
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.gn = nn.GroupNorm(self.channel // (2 * self.groups), self.channel // (2 * self.groups))
self.cweight = Parameter(torch.zeros(1, self.channel // (2 * self.groups), 1, 1))
self.cbias = Parameter(torch.ones(1, self.channel // (2 * self.groups), 1, 1))
self.sweight = Parameter(torch.zeros(1, self.channel // (2 * self.groups), 1, 1))
self.sbias = Parameter(torch.ones(1, self.channel // (2 * self.groups), 1, 1))
self.sigmoid = nn.Sigmoid()
@staticmethod
def channel_shuffle(x, groups):
b, c, h, w = x.shape
x = x.reshape(b, groups, -1, h, w)
x = x.permute(0, 2, 1, 3, 4)
# flatten
x = x.reshape(b, -1, h, w)
return x
def forward(self, x):
b, c, h, w = x.shape
# group into subfeatures
x = x.reshape(b * self.groups, -1, h, w)
# channel_split
x_0, x_1 = x.chunk(2, dim=1)
# channel attention
xn = self.avg_pool(x_0)
xn = self.cweight * xn + self.cbias
xn = x_0 * self.sigmoid(xn)
# spatial attention
xs = self.gn(x_1)
xs = self.sweight * xs + self.sbias
xs = x_1 * self.sigmoid(xs)
# concatenate along channel axis
out = torch.cat([xn, xs], dim=1)
out = out.reshape(b, -1, h, w)
out = self.channel_shuffle(out, 2)
return out
class SABottleneck(nn.Module):
# expansion = 4
def __init__(self, c1, c2, s=1, shortcut=True, k=(1, 3), e=0.5, g=1):
super(SABottleneck, self).__init__()
c_ = c2 // 2
self.shortcut = shortcut
self.conv1 = Conv(c1, c_, k[0], s)
self.conv2 = Conv(c_, c2, k[1], s, g=g)
self.add = shortcut and c1 == c2
self.sa = sa_layer(c2, g)
def forward(self, x):
x1 = self.conv2(self.conv1(x))
y = self.sa(x1)
out = y
return x + out if self.add else out
class RepNSAELAN4(RepNCSPELAN4):
def __init__(self, c1, c2, c3, c4, c5=1):
super().__init__(c1, c2, c3, c4, c5)
self.cv2 = nn.Sequential(RepNSA(c3//2, c4, c5), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepNSA(c4, c4, c5), Conv(c4, c4, 3, 1))
## ECA
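# ECA (Efficient Channel Attention): gate channels with a 1D convolution over
# the globally average-pooled channel descriptor; the kernel size is derived
# from the channel count as k ~ |(log2(C) + b) / gamma|, forced to be odd.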
class EfficientChannelAttention(nn.Module): # Efficient Channel Attention module
def __init__(self, c, b=1, gamma=2):
super(EfficientChannelAttention, self).__init__()
t = int(abs((math.log(c, 2) + b) / gamma))
k = t if t % 2 else t + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k/2), bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.avg_pool(x)
out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
out = self.sigmoid(out)
return out * x
class ECABottleneck(nn.Module):
# Standard bottleneck
def __init__(self,
c1,
c2,
shortcut=True,
g=1,
e=0.5,
ratio=16,
k_size=3): # ch_in, ch_out, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.add = shortcut and c1 == c2
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x1 = self.cv2(self.cv1(x))
y = self.avg_pool(x1)
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
y = self.sigmoid(y)
out = x1 * y.expand_as(x1)
return x + out if self.add else out
class RepNECALAN4(RepNCSPELAN4):
def __init__(self, c1, c2, c3, c4, c5=1):
super().__init__(c1, c2, c3, c4, c5)
self.cv2 = nn.Sequential(RepNECA(c3//2, c4, c5), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepNECA(c4, c4, c5), Conv(c4, c4, 3, 1))
## LSK Attention
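# LSK (Large Selective Kernel) block: two depth-wise branches with different
# receptive fields (5x5, and 7x7 with dilation 3), selected per pixel by a
# sigmoid gate computed from their average/max channel statistics.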
class LSKblock(nn.Module):
def __init__(self, c1):
super().__init__()
self.conv0 = nn.Conv2d(c1, c1, 5, padding=2, groups=c1)
self.conv_spatial = nn.Conv2d(c1, c1, 7, stride=1, padding=9, groups=c1, dilation=3)
self.conv1 = nn.Conv2d(c1, c1//2, 1)
self.conv2 = nn.Conv2d(c1, c1//2, 1)
# self.cv2 = nn.Sequential(RepNCSP(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
self.conv = nn.Conv2d(c1//2, c1, 1)
def forward(self, x):
attn1 = self.conv0(x)
attn2 = self.conv_spatial(attn1)
attn1 = self.conv1(attn1)
attn2 = self.conv2(attn2)
attn = torch.cat([attn1, attn2], dim=1)
avg_attn = torch.mean(attn, dim=1, keepdim=True)
max_attn, _ = torch.max(attn, dim=1, keepdim=True)
agg = torch.cat([avg_attn, max_attn], dim=1)
sig = self.conv_squeeze(agg).sigmoid()
attn = attn1 * sig[:,0,:,:].unsqueeze(1) + attn2 * sig[:,1,:,:].unsqueeze(1)
attn = self.conv(attn)
return x * attn
# class LSKAttention(nn.Module):
#     def __init__(self, c1, c2, shortcut=True):
#         super().__init__()
#         self.conv1 = Conv(c1, c1, 1)
#         self.spatial_gating_unit = LSKblock(c1)
#         self.conv2 = Conv(c1, c2, 1)
#         self.add = shortcut and c1 == c2
#
#     def forward(self, x):
#         shortcut = x
#         x = self.conv1(x)
#         x = self.spatial_gating_unit(x)
#         x = self.conv2(x)
#         return x + shortcut if self.add else x
class LSKBottleneck(nn.Module):
# expansion = 4
def __init__(self, c1, c2, s=1, shortcut=True, g=1):
super(LSKBottleneck, self).__init__()
c_ = c2 // 2
self.shortcut = shortcut
self.add = shortcut and c1 == c2
self.conv1 = Conv(c1, c_, 1)
self.conv2 = Conv(c_, c2, 3, s, g= g)
self.lsk = LSKblock(c2)
def forward(self, x):
x1 = self.conv2(self.conv1(x))
y = self.lsk(x1)
out = y
return x + out if self.add else out
class RepNLSKELAN4(RepNCSPELAN4):
def __init__(self, c1, c2, c3, c4, c5=1):
super().__init__(c1, c2, c3, c4, c5)
self.cv2 = nn.Sequential(RepNLSK(c3//2, c4, c5), Conv(c4, c4, 3, 1))
self.cv3 = nn.Sequential(RepNLSK(c4, c4, c5), Conv(c4, c4, 3, 1))
## SE Attention
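# SE (Squeeze-and-Excitation): gate channels with a bottleneck MLP applied to
# the globally average-pooled feature vector.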
class SEBottleneck(nn.Module):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, ratio=16): # ch_in, ch_out, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.add = shortcut and c1 == c2
# self.se=SE(c1,c2,ratio)
self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c2, c2 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c2 // ratio, c2, bias=False)
self.sig = nn.Sigmoid()
def forward(self, x):
x1 = self.cv2(self.cv1(x))
        b, c, _, _ = x1.size()  # SE operates on the c2-channel output of cv2
y = self.avgpool(x1).view(b, c)
y = self.l1(y)
y = self.relu(y)
y = self.l2(y)
y = self.sig(y)
y = y.view(b, c, 1, 1)
out = x1 * y.expand_as(x1)
# out=self.se(x1)*x1
return x + out if self.add else out
## SOCA Attention
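# SOCA (Second-order Channel Attention): compute a channel covariance matrix
# (Covpool), take its matrix square root with a Newton-Schulz iteration
# (Sqrtm, custom forward/backward), and derive channel gates from the row
# means of the result.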
from torch.autograd import Function
class Covpool(Function):
@staticmethod
def forward(ctx, input):
x = input
batchSize = x.data.shape[0]
dim = x.data.shape[1]
h = x.data.shape[2]
w = x.data.shape[3]
M = h*w
x = x.reshape(batchSize,dim,M)
I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device)
I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype)
y = x.bmm(I_hat).bmm(x.transpose(1,2))
ctx.save_for_backward(input,I_hat)
return y
@staticmethod
def backward(ctx, grad_output):
input,I_hat = ctx.saved_tensors
x = input
batchSize = x.data.shape[0]
dim = x.data.shape[1]
h = x.data.shape[2]
w = x.data.shape[3]
M = h*w
x = x.reshape(batchSize,dim,M)
grad_input = grad_output + grad_output.transpose(1,2)
grad_input = grad_input.bmm(x).bmm(I_hat)
grad_input = grad_input.reshape(batchSize,dim,h,w)
return grad_input
class Sqrtm(Function):
@staticmethod
def forward(ctx, input, iterN):
x = input
batchSize = x.data.shape[0]
dim = x.data.shape[1]
dtype = x.dtype
I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
normA = (1.0/3.0)*x.mul(I3).sum(dim=1).sum(dim=1)
A = x.div(normA.view(batchSize,1,1).expand_as(x))
Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad = False, device = x.device)
Z = torch.eye(dim,dim,device = x.device).view(1,dim,dim).repeat(batchSize,iterN,1,1)
if iterN < 2:
ZY = 0.5*(I3 - A)
Y[:,0,:,:] = A.bmm(ZY)
else:
ZY = 0.5*(I3 - A)
Y[:,0,:,:] = A.bmm(ZY)
Z[:,0,:,:] = ZY
for i in range(1, iterN-1):
ZY = 0.5*(I3 - Z[:,i-1,:,:].bmm(Y[:,i-1,:,:]))
Y[:,i,:,:] = Y[:,i-1,:,:].bmm(ZY)
Z[:,i,:,:] = ZY.bmm(Z[:,i-1,:,:])
ZY = 0.5*Y[:,iterN-2,:,:].bmm(I3 - Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]))
y = ZY*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
ctx.save_for_backward(input, A, ZY, normA, Y, Z)
ctx.iterN = iterN
return y
@staticmethod
def backward(ctx, grad_output):
input, A, ZY, normA, Y, Z = ctx.saved_tensors
iterN = ctx.iterN
x = input
batchSize = x.data.shape[0]
dim = x.data.shape[1]
dtype = x.dtype
der_postCom = grad_output*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
der_postComAux = (grad_output*ZY).sum(dim=1).sum(dim=1).div(2*torch.sqrt(normA))
I3 = 3.0*torch.eye(dim,dim,device = x.device).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)
if iterN < 2:
            der_NSiter = 0.5*(der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
else:
dldY = 0.5*(der_postCom.bmm(I3 - Y[:,iterN-2,:,:].bmm(Z[:,iterN-2,:,:])) -
Z[:,iterN-2,:,:].bmm(Y[:,iterN-2,:,:]).bmm(der_postCom))
dldZ = -0.5*Y[:,iterN-2,:,:].bmm(der_postCom).bmm(Y[:,iterN-2,:,:])
for i in range(iterN-3, -1, -1):
YZ = I3 - Y[:,i,:,:].bmm(Z[:,i,:,:])
ZY = Z[:,i,:,:].bmm(Y[:,i,:,:])
dldY_ = 0.5*(dldY.bmm(YZ) -
Z[:,i,:,:].bmm(dldZ).bmm(Z[:,i,:,:]) -
ZY.bmm(dldY))
dldZ_ = 0.5*(YZ.bmm(dldZ) -
Y[:,i,:,:].bmm(dldY).bmm(Y[:,i,:,:]) -
dldZ.bmm(ZY))
dldY = dldY_
dldZ = dldZ_
der_NSiter = 0.5*(dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
grad_input = der_NSiter.div(normA.view(batchSize,1,1).expand_as(x))
grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
for i in range(batchSize):
grad_input[i,:,:] += (der_postComAux[i] \
- grad_aux[i] / (normA[i] * normA[i])) \
*torch.ones(dim,device = x.device).diag()
return grad_input, None
def CovpoolLayer(var):
return Covpool.apply(var)
def SqrtmLayer(var, iterN):
return Sqrtm.apply(var, iterN)
class SOCA(nn.Module):
# Second-order Channel Attention
def __init__(self, c1, c2, reduction=8):
super(SOCA, self).__init__()
self.max_pool = nn.MaxPool2d(kernel_size=2)
self.conv_du = nn.Sequential(
nn.Conv2d(c1, c1 // reduction, 1, padding=0, bias=True),
nn.SiLU(), # SiLU activation
nn.Conv2d(c1 // reduction, c1, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
batch_size, C, h, w = x.shape # x: NxCxHxW
        h1, w1 = 1000, 1000  # cap the spatial size fed to covariance pooling; larger maps are center-cropped
if h < h1 and w < w1:
x_sub = x
elif h < h1 and w > w1:
W = (w - w1) // 2
x_sub = x[:, :, :, W:(W + w1)]
elif w < w1 and h > h1:
H = (h - h1) // 2
x_sub = x[:, :, H:H + h1, :]
else:
H = (h - h1) // 2
W = (w - w1) // 2
x_sub = x[:, :, H:(H + h1), W:(W + w1)]
cov_mat = CovpoolLayer(x_sub) # Global Covariance pooling layer
cov_mat_sqrt = SqrtmLayer(cov_mat, 5) # Matrix square root layer (including pre-norm, Newton-Schulz iter. and post-com. with 5 iterations)
cov_mat_sum = torch.mean(cov_mat_sqrt, 1)
cov_mat_sum = cov_mat_sum.view(batch_size, C, 1, 1)
y_cov = self.conv_du(cov_mat_sum)
return y_cov * x
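

# Minimal shape-check sketch: it assumes models.common provides Conv with the
# (c1, c2, k, s, ..., g=...) signature used above, and simply prints the output
# shape of a few of the attention blocks when this file is run directly.
if __name__ == "__main__":
    x = torch.randn(2, 64, 32, 32)
    blocks = [
        CBAMBottleneck(64, 64),
        sa_layer(64, groups=16),
        EfficientChannelAttention(64),
        LSKblock(64),
        SOCA(64, 64),
    ]
    for block in blocks:
        y = block(x)
        print(f"{block.__class__.__name__}: {tuple(x.shape)} -> {tuple(y.shape)}")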