# Attention add-ons: Shuffle, CBAM, GAM, ECA, SE, SK, LSK, SOCA
import math
from collections import OrderedDict

import torch
import torch.nn as nn

from models.common import *  # Conv, RepNCSP, RepNCSPELAN4, ...

class RepNCBAM(nn.Module):
    # CSP-style block whose bottlenecks use CBAM attention
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*(CBAMBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class RepNSA(nn.Module):
    # CSP-style block whose bottlenecks use shuffle attention
    def __init__(self, c1, c2, n=1, shortcut=True, g=16, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*(SABottleneck(c_, c_, 1, shortcut, g=g) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class RepNLSK(nn.Module):
    # CSP-style block whose bottlenecks use LSK attention
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*(LSKBottleneck(c_, c_, 1, shortcut, g=g) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class RepNECA(nn.Module):
    # CSP-style block whose bottlenecks use ECA attention
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*(ECABottleneck(c_, c_, shortcut, g=g) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))

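
# Illustrative shape check for the blocks above (a minimal sketch, not part of the
# model definition; the helper name and channel counts are made up, and Conv is
# assumed to come from the models.common star-import). Each RepN* block maps
# (B, c1, H, W) -> (B, c2, H, W) without changing the spatial size.
def _example_repn_attention_blocks():
    x = torch.randn(2, 64, 32, 32)
    for block in (RepNCBAM(64, 64), RepNSA(64, 64), RepNLSK(64, 64), RepNECA(64, 64)):
        assert block(x).shape == (2, 64, 32, 32)
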
# ----------------------- Attention Mechanisms ---------------------------
## CBAM ATTENTION
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.act = nn.SiLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.act(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.act(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)    # 1*h*w
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # 1*h*w
        x = torch.cat([avg_out, max_out], dim=1)        # 2*h*w
        x = self.conv(x)                                # 1*h*w
        return self.sigmoid(x)

class CBAMBottleneck(nn.Module):
    def __init__(self,
                 c1,
                 c2,
                 shortcut=True,
                 g=1,
                 e=0.5,
                 ratio=16,
                 kernel_size=3):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        self.channel_attention = ChannelAttention(c2, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        x1 = self.cv2(self.cv1(x))
        out = self.channel_attention(x1) * x1
        out = self.spatial_attention(out) * out
        return x + out if self.add else out

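
# Illustrative CBAM usage (a minimal sketch; the helper and sizes are made up):
# channel attention returns a (B, C, 1, 1) gate, spatial attention a (B, 1, H, W)
# gate, and CBAMBottleneck applies both in sequence after its 1x1 -> 3x3 convolutions.
def _example_cbam():
    x = torch.randn(2, 64, 32, 32)
    ca, sa = ChannelAttention(64), SpatialAttention(kernel_size=3)
    assert ca(x).shape == (2, 64, 1, 1)
    assert sa(x).shape == (2, 1, 32, 32)
    assert CBAMBottleneck(64, 64)(x).shape == x.shape
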
class CBAMC4(nn.Module):
    def __init__(self, c1, c2, c3, c4, c5=1):
        super(CBAMC4, self).__init__()
        self.c = c3 // 2
        self.cv1 = Conv(c1, c3, 1, 1)
        self.cv2 = nn.Sequential(RepNCSP(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
        self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), Conv(c4, c4, 3, 1))
        self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)
        # attention runs on the concatenated (c3 + 2*c4)-channel features, before cv4
        self.channel_attention = ChannelAttention(c3 + (2 * c4))
        self.spatial_attention = SpatialAttention(kernel_size=3)

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, 1))
        y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
        y = torch.cat(y, 1)
        # apply channel attention, then spatial attention, before the final 1x1 conv
        y = y * self.channel_attention(y)
        y = y * self.spatial_attention(y)
        return self.cv4(y)

    def forward_split(self, x):
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
        y = torch.cat(y, 1)
        y = y * self.channel_attention(y)
        y = y * self.spatial_attention(y)
        return self.cv4(y)

class RepNCBAMELAN4(RepNCSPELAN4):
    # RepNCSPELAN4 variant with RepNCBAM (CBAM bottlenecks) in place of RepNCSP
    def __init__(self, c1, c2, c3, c4, c5=1):
        super().__init__(c1, c2, c3, c4, c5)
        self.cv2 = nn.Sequential(RepNCBAM(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
        self.cv3 = nn.Sequential(RepNCBAM(c4, c4, c5), Conv(c4, c4, 3, 1))

## GAM ATTENTION
class GAMAttention(nn.Module):
    # https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAMAttention, self).__init__()
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3),
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # final channel shuffle
        out = x * x_spatial_att
        return out


def channel_shuffle(x, groups=2):
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    out = out.view(B, C, H, W)
    return out

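
# Illustrative GAM usage (a minimal sketch; the helper and sizes are made up): an MLP
# over the channel dimension gates the channels, two 7x7 convolutions gate the spatial
# positions, and the spatial gate is channel-shuffled. With group=True, c1 and
# c1 // rate should be divisible by rate, and c2 by 4 for the final shuffle.
def _example_gam():
    x = torch.randn(2, 64, 32, 32)
    assert GAMAttention(64, 64)(x).shape == (2, 64, 32, 32)
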
## SK ATTENTION
class SKAttention(nn.Module):
    def __init__(self, channel=512, out_channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w
        ### fuse
        U = sum(conv_outs)  # bs,c,h,w
        ### reduce channels
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d
        ### calculate attention weights
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel,1,1
        attention_weights = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weights = self.softmax(attention_weights)  # k,bs,channel,1,1
        ### select
        V = (attention_weights * feats).sum(0)
        return V

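
# Illustrative SKAttention usage (a minimal sketch; the helper and sizes are made up):
# the 1/3/5/7 branches are summed, squeezed to d = max(L, channel // reduction)
# features, and each branch is re-weighted by a softmax over the branch axis, so the
# output keeps the input channel count.
def _example_sk():
    x = torch.randn(2, 64, 32, 32)
    assert SKAttention(channel=64, out_channel=64)(x).shape == (2, 64, 32, 32)
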
## SHUFFLE ATTENTION
from torch.nn.parameter import Parameter
from torch.nn import init


class sa_layer(nn.Module):
    """Constructs a Channel-Spatial Group (shuffle attention) module.

    Args:
        channel: number of input channels (must be divisible by 2 * groups)
        groups: number of sub-feature groups
    """

    def __init__(self, channel, groups=16):
        super(sa_layer, self).__init__()
        self.groups = groups
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(self.channel // (2 * self.groups), self.channel // (2 * self.groups))
        self.cweight = Parameter(torch.zeros(1, self.channel // (2 * self.groups), 1, 1))
        self.cbias = Parameter(torch.ones(1, self.channel // (2 * self.groups), 1, 1))
        self.sweight = Parameter(torch.zeros(1, self.channel // (2 * self.groups), 1, 1))
        self.sbias = Parameter(torch.ones(1, self.channel // (2 * self.groups), 1, 1))
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.shape
        # group into sub-features
        x = x.reshape(b * self.groups, -1, h, w)
        # channel split
        x_0, x_1 = x.chunk(2, dim=1)
        # channel attention
        xn = self.avg_pool(x_0)
        xn = self.cweight * xn + self.cbias
        xn = x_0 * self.sigmoid(xn)
        # spatial attention
        xs = self.gn(x_1)
        xs = self.sweight * xs + self.sbias
        xs = x_1 * self.sigmoid(xs)
        # concatenate along the channel axis
        out = torch.cat([xn, xs], dim=1)
        out = out.reshape(b, -1, h, w)
        out = self.channel_shuffle(out, 2)
        return out

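
# Illustrative shuffle-attention usage (a minimal sketch; the helper and sizes are
# made up): the input is split into `groups` sub-features, one half of each gets
# channel attention (global pooling) and the other spatial attention (group norm),
# and a channel shuffle mixes the groups back; `channel` must be divisible by 2 * groups.
def _example_shuffle_attention():
    x = torch.randn(2, 64, 32, 32)
    assert sa_layer(64, groups=16)(x).shape == (2, 64, 32, 32)
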
class SABottleneck(nn.Module):
    # expansion = 4
    def __init__(self, c1, c2, s=1, shortcut=True, k=(1, 3), e=0.5, g=1):
        super(SABottleneck, self).__init__()
        c_ = c2 // 2
        self.shortcut = shortcut
        self.conv1 = Conv(c1, c_, k[0], s)
        self.conv2 = Conv(c_, c2, k[1], s, g=g)
        self.add = shortcut and c1 == c2
        self.sa = sa_layer(c2, g)

    def forward(self, x):
        x1 = self.conv2(self.conv1(x))
        out = self.sa(x1)
        return x + out if self.add else out


class RepNSAELAN4(RepNCSPELAN4):
    # RepNCSPELAN4 variant with RepNSA (shuffle-attention bottlenecks) in place of RepNCSP
    def __init__(self, c1, c2, c3, c4, c5=1):
        super().__init__(c1, c2, c3, c4, c5)
        self.cv2 = nn.Sequential(RepNSA(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
        self.cv3 = nn.Sequential(RepNSA(c4, c4, c5), Conv(c4, c4, 3, 1))

## ECA ATTENTION
class EfficientChannelAttention(nn.Module):  # Efficient Channel Attention module
    def __init__(self, c, b=1, gamma=2):
        super(EfficientChannelAttention, self).__init__()
        t = int(abs((math.log(c, 2) + b) / gamma))
        k = t if t % 2 else t + 1  # kernel size adapted to the channel count, forced odd
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=k // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        out = self.sigmoid(out)
        return out * x

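
# Illustrative ECA usage (a minimal sketch; the helper and sizes are made up): the 1D
# kernel size is derived from the channel count as t = int(|log2(C) + b| / gamma) and
# made odd, so C = 64 with b = 1, gamma = 2 gives k = 3.
def _example_eca():
    x = torch.randn(2, 64, 32, 32)
    assert EfficientChannelAttention(64)(x).shape == (2, 64, 32, 32)
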
class ECABottleneck(nn.Module):
    # Standard bottleneck followed by efficient channel attention
    def __init__(self,
                 c1,
                 c2,
                 shortcut=True,
                 g=1,
                 e=0.5,
                 ratio=16,
                 k_size=3):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.cv2(self.cv1(x))
        y = self.avg_pool(x1)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        out = x1 * y.expand_as(x1)
        return x + out if self.add else out


class RepNECALAN4(RepNCSPELAN4):
    # RepNCSPELAN4 variant with RepNECA (ECA bottlenecks) in place of RepNCSP
    def __init__(self, c1, c2, c3, c4, c5=1):
        super().__init__(c1, c2, c3, c4, c5)
        self.cv2 = nn.Sequential(RepNECA(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
        self.cv3 = nn.Sequential(RepNECA(c4, c4, c5), Conv(c4, c4, 3, 1))

## LSK ATTENTION
class LSKblock(nn.Module):
    def __init__(self, c1):
        super().__init__()
        self.conv0 = nn.Conv2d(c1, c1, 5, padding=2, groups=c1)  # 5x5 depthwise
        self.conv_spatial = nn.Conv2d(c1, c1, 7, stride=1, padding=9, groups=c1, dilation=3)  # dilated 7x7 depthwise
        self.conv1 = nn.Conv2d(c1, c1 // 2, 1)
        self.conv2 = nn.Conv2d(c1, c1 // 2, 1)
        self.conv_squeeze = nn.Conv2d(2, 2, 7, padding=3)
        self.conv = nn.Conv2d(c1 // 2, c1, 1)

    def forward(self, x):
        attn1 = self.conv0(x)
        attn2 = self.conv_spatial(attn1)
        attn1 = self.conv1(attn1)
        attn2 = self.conv2(attn2)
        attn = torch.cat([attn1, attn2], dim=1)
        avg_attn = torch.mean(attn, dim=1, keepdim=True)
        max_attn, _ = torch.max(attn, dim=1, keepdim=True)
        agg = torch.cat([avg_attn, max_attn], dim=1)
        sig = self.conv_squeeze(agg).sigmoid()
        attn = attn1 * sig[:, 0, :, :].unsqueeze(1) + attn2 * sig[:, 1, :, :].unsqueeze(1)
        attn = self.conv(attn)
        return x * attn

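
# Illustrative LSKblock usage (a minimal sketch; the helper and sizes are made up):
# the 5x5 depthwise and dilated 7x7 depthwise branches give two receptive fields,
# which are mixed by a sigmoid gate computed from their pooled statistics; the output
# keeps the input shape.
def _example_lsk():
    x = torch.randn(2, 64, 32, 32)
    assert LSKblock(64)(x).shape == (2, 64, 32, 32)
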
# class LSKAttention(nn.Module):
#     def __init__(self, c1, c2, shortcut=True):
#         super().__init__()
#         self.conv1 = Conv(c1, c1, 1)
#         self.spatial_gating_unit = LSKblock(c1)
#         self.conv2 = Conv(c1, c2, 1)
#         self.add = shortcut and c1 == c2
#
#     def forward(self, x):
#         shortcut = x
#         x = self.conv1(x)
#         x = self.spatial_gating_unit(x)
#         x = self.conv2(x)
#         return x + shortcut if self.add else x

class LSKBottleneck(nn.Module):
    # expansion = 4
    def __init__(self, c1, c2, s=1, shortcut=True, g=1):
        super(LSKBottleneck, self).__init__()
        c_ = c2 // 2
        self.shortcut = shortcut
        self.add = shortcut and c1 == c2
        self.conv1 = Conv(c1, c_, 1)
        self.conv2 = Conv(c_, c2, 3, s, g=g)
        self.lsk = LSKblock(c2)

    def forward(self, x):
        x1 = self.conv2(self.conv1(x))
        out = self.lsk(x1)
        return x + out if self.add else out


class RepNLSKELAN4(RepNCSPELAN4):
    # RepNCSPELAN4 variant with RepNLSK (LSK bottlenecks) in place of RepNCSP
    def __init__(self, c1, c2, c3, c4, c5=1):
        super().__init__(c1, c2, c3, c4, c5)
        self.cv2 = nn.Sequential(RepNLSK(c3 // 2, c4, c5), Conv(c4, c4, 3, 1))
        self.cv3 = nn.Sequential(RepNLSK(c4, c4, c5), Conv(c4, c4, 3, 1))

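
# Illustrative drop-in usage of the attention ELAN variants (a minimal sketch; the
# helper and the (c1, c2, c3, c4) values are made up). They keep the RepNCSPELAN4
# interface, so the channels must satisfy the same constraints plus the
# attention-specific ones (e.g. the shuffle-attention variant with its default of
# 16 groups needs c4 divisible by 64).
def _example_attention_elan_blocks():
    x = torch.randn(2, 64, 32, 32)
    for block_cls in (RepNCBAMELAN4, RepNSAELAN4, RepNECALAN4, RepNLSKELAN4):
        assert block_cls(64, 64, 64, 64)(x).shape == (2, 64, 32, 32)
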
## SE ATTENTION
class SEBottleneck(nn.Module):
    # Standard bottleneck followed by squeeze-and-excitation
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5, ratio=16):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # the SE branch operates on the c2-channel output of cv2
        self.l1 = nn.Linear(c2, c2 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c2 // ratio, c2, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        x1 = self.cv2(self.cv1(x))
        b, c, _, _ = x1.size()
        y = self.avgpool(x1).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        out = x1 * y.expand_as(x1)
        return x + out if self.add else out

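
# Illustrative SEBottleneck usage (a minimal sketch; the helper and sizes are made
# up): squeeze (global average pooling), excite (two linear layers with reduction
# `ratio`), then rescale the c2-channel output of the 3x3 conv.
def _example_se():
    x = torch.randn(2, 64, 32, 32)
    assert SEBottleneck(64, 64)(x).shape == (2, 64, 32, 32)
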
## SOCA ATTENTION
from torch.autograd import Function


class Covpool(Function):
    # Global covariance pooling: (B, C, H, W) -> (B, C, C) covariance matrices
    @staticmethod
    def forward(ctx, input):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        I_hat = (-1. / M / M) * torch.ones(M, M, device=x.device) + (1. / M) * torch.eye(M, M, device=x.device)
        I_hat = I_hat.view(1, M, M).repeat(batchSize, 1, 1).type(x.dtype)
        y = x.bmm(I_hat).bmm(x.transpose(1, 2))
        ctx.save_for_backward(input, I_hat)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, I_hat = ctx.saved_tensors
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        h = x.data.shape[2]
        w = x.data.shape[3]
        M = h * w
        x = x.reshape(batchSize, dim, M)
        grad_input = grad_output + grad_output.transpose(1, 2)
        grad_input = grad_input.bmm(x).bmm(I_hat)
        grad_input = grad_input.reshape(batchSize, dim, h, w)
        return grad_input

class Sqrtm(Function):
    # Matrix square root via Newton-Schulz iteration (pre-normalization + post-compensation)
    @staticmethod
    def forward(ctx, input, iterN):
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        normA = (1.0 / 3.0) * x.mul(I3).sum(dim=1).sum(dim=1)
        A = x.div(normA.view(batchSize, 1, 1).expand_as(x))
        Y = torch.zeros(batchSize, iterN, dim, dim, requires_grad=False, device=x.device)
        Z = torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, iterN, 1, 1)
        if iterN < 2:
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
        else:
            ZY = 0.5 * (I3 - A)
            Y[:, 0, :, :] = A.bmm(ZY)
            Z[:, 0, :, :] = ZY
            for i in range(1, iterN - 1):
                ZY = 0.5 * (I3 - Z[:, i - 1, :, :].bmm(Y[:, i - 1, :, :]))
                Y[:, i, :, :] = Y[:, i - 1, :, :].bmm(ZY)
                Z[:, i, :, :] = ZY.bmm(Z[:, i - 1, :, :])
            ZY = 0.5 * Y[:, iterN - 2, :, :].bmm(I3 - Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]))
        y = ZY * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        ctx.save_for_backward(input, A, ZY, normA, Y, Z)
        ctx.iterN = iterN
        return y

    @staticmethod
    def backward(ctx, grad_output):
        input, A, ZY, normA, Y, Z = ctx.saved_tensors
        iterN = ctx.iterN
        x = input
        batchSize = x.data.shape[0]
        dim = x.data.shape[1]
        dtype = x.dtype
        der_postCom = grad_output * torch.sqrt(normA).view(batchSize, 1, 1).expand_as(x)
        der_postComAux = (grad_output * ZY).sum(dim=1).sum(dim=1).div(2 * torch.sqrt(normA))
        I3 = 3.0 * torch.eye(dim, dim, device=x.device).view(1, dim, dim).repeat(batchSize, 1, 1).type(dtype)
        if iterN < 2:
            der_NSiter = 0.5 * (der_postCom.bmm(I3 - A) - A.bmm(der_postCom))
        else:
            dldY = 0.5 * (der_postCom.bmm(I3 - Y[:, iterN - 2, :, :].bmm(Z[:, iterN - 2, :, :])) -
                          Z[:, iterN - 2, :, :].bmm(Y[:, iterN - 2, :, :]).bmm(der_postCom))
            dldZ = -0.5 * Y[:, iterN - 2, :, :].bmm(der_postCom).bmm(Y[:, iterN - 2, :, :])
            for i in range(iterN - 3, -1, -1):
                YZ = I3 - Y[:, i, :, :].bmm(Z[:, i, :, :])
                ZY = Z[:, i, :, :].bmm(Y[:, i, :, :])
                dldY_ = 0.5 * (dldY.bmm(YZ) -
                               Z[:, i, :, :].bmm(dldZ).bmm(Z[:, i, :, :]) -
                               ZY.bmm(dldY))
                dldZ_ = 0.5 * (YZ.bmm(dldZ) -
                               Y[:, i, :, :].bmm(dldY).bmm(Y[:, i, :, :]) -
                               dldZ.bmm(ZY))
                dldY = dldY_
                dldZ = dldZ_
            der_NSiter = 0.5 * (dldY.bmm(I3 - A) - dldZ - A.bmm(dldY))
        grad_input = der_NSiter.div(normA.view(batchSize, 1, 1).expand_as(x))
        grad_aux = der_NSiter.mul(x).sum(dim=1).sum(dim=1)
        for i in range(batchSize):
            grad_input[i, :, :] += (der_postComAux[i] - grad_aux[i] / (normA[i] * normA[i])) \
                                   * torch.ones(dim, device=x.device).diag()
        return grad_input, None

def CovpoolLayer(var):
    return Covpool.apply(var)


def SqrtmLayer(var, iterN):
    return Sqrtm.apply(var, iterN)

class SOCA(nn.Module):
    # Second-order Channel Attention
    def __init__(self, c1, c2, reduction=8):
        super(SOCA, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.conv_du = nn.Sequential(
            nn.Conv2d(c1, c1 // reduction, 1, padding=0, bias=True),
            nn.SiLU(),  # SiLU activation
            nn.Conv2d(c1 // reduction, c1, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        batch_size, C, h, w = x.shape  # x: NxCxHxW
        # crop very large feature maps to at most h1 x w1 before covariance pooling
        h1 = 1000
        w1 = 1000
        if h < h1 and w < w1:
            x_sub = x
        elif h < h1 and w > w1:
            W = (w - w1) // 2
            x_sub = x[:, :, :, W:(W + w1)]
        elif w < w1 and h > h1:
            H = (h - h1) // 2
            x_sub = x[:, :, H:H + h1, :]
        else:
            H = (h - h1) // 2
            W = (w - w1) // 2
            x_sub = x[:, :, H:(H + h1), W:(W + w1)]
        cov_mat = CovpoolLayer(x_sub)  # global covariance pooling
        cov_mat_sqrt = SqrtmLayer(cov_mat, 5)  # matrix square root (pre-norm, Newton-Schulz iter. and post-compensation, 5 iterations)
        cov_mat_sum = torch.mean(cov_mat_sqrt, 1)
        cov_mat_sum = cov_mat_sum.view(batch_size, C, 1, 1)
        y_cov = self.conv_du(cov_mat_sum)
        return y_cov * x

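
# Illustrative SOCA usage (a minimal sketch; the helper and sizes are made up):
# covariance pooling over the (optionally cropped) feature map, a 5-step
# Newton-Schulz matrix square root, then a 1x1 bottleneck that produces the
# per-channel gate. Note that c2 is not used by this module.
def _example_soca():
    x = torch.randn(2, 64, 32, 32)
    assert SOCA(64, 64)(x).shape == (2, 64, 32, 32)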