CBAM
CBAM 模块概述 通道注意力模块(Channel Attention Module, CAM)和空间注意力模块(Spatial Attention Module, SAM)是注意力机制的两种主要形式,它们分别通过对通道维度和空间维度的特征图进行加权,从而使网络更加关注重要的特征。CBAM模块结合了这两种注意力机制,可以在保留空间信息的同时,有效地提取关键通道特征,提高了网络在处理复杂图像任务上的性能表现。
代码实现
import torch
from torch import nn
from torchsummary import summary
class ChannelModule(nn.Module):
    """Channel attention branch of CBAM.

    Squeezes the spatial dimensions with both adaptive max- and
    avg-pooling, feeds each descriptor through a shared bottleneck MLP,
    and returns the sigmoid of their sum as per-channel weights.

    Args:
        inputs: either an ``int`` channel count, or a representative
            4-D tensor of shape (N, C, H, W) whose channel dimension is
            read. Accepting a tensor keeps backward compatibility with
            existing callers; passing the channel count directly avoids
            building a dummy tensor at model-construction time.
        ratio: reduction ratio of the bottleneck MLP (default 16).
    """

    def __init__(self, inputs, ratio=16):
        super(ChannelModule, self).__init__()
        # Backward-compatible generalization: accept the channel count
        # directly, or read it from an example tensor as before.
        c = inputs if isinstance(inputs, int) else inputs.size(1)
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Shared two-layer MLP: C -> C//ratio -> C.
        self.share_liner = nn.Sequential(
            nn.Linear(c, c // ratio),
            nn.ReLU(),
            nn.Linear(c // ratio, c)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Return channel attention weights of shape (N, C, 1, 1)."""
        n = inputs.size(0)
        # (N, C, 1, 1) -> (N, C) so the Linear layers apply per sample.
        x = self.maxpool(inputs).view(n, -1)
        maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)  # (N, C, 1, 1)
        y = self.avgpool(inputs).view(n, -1)
        avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)
        # Sum the two descriptors, then squash into (0, 1).
        return self.sigmoid(maxout + avgout)
class SpatialModule(nn.Module):
    """Spatial attention branch of CBAM.

    Collapses the channel axis with element-wise max and mean, stacks
    the two maps, and convolves them down to a single sigmoid-gated
    spatial attention map.
    """

    def __init__(self):
        super(SpatialModule, self).__init__()
        # Plain torch functions kept as attributes: the only learnable
        # state in this branch is the 7x7 convolution below.
        self.maxpool = torch.max
        self.avgpool = torch.mean
        self.concat = torch.cat
        # kernel_size=7 with padding=3 preserves the spatial resolution.
        self.conv = nn.Conv2d(in_channels=2, out_channels=1,
                              kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Return a spatial attention map of shape (N, 1, H, W)."""
        # Channel-wise max and mean descriptors, each (N, 1, H, W).
        channel_max, _ = self.maxpool(inputs, dim=1, keepdim=True)
        channel_mean = self.avgpool(inputs, dim=1, keepdim=True)
        descriptor = self.concat([channel_max, channel_mean], dim=1)  # (N, 2, H, W)
        # Fuse the two maps into one and squash into (0, 1).
        return self.sigmoid(self.conv(descriptor))
class CBAM(nn.Module):
    """Convolutional Block Attention Module.

    Applies channel attention first, then spatial attention, each as a
    multiplicative gate on the feature map.

    Args:
        inputs: a representative input tensor whose channel count sizes
            the channel-attention branch.
    """

    def __init__(self, inputs):
        super(CBAM, self).__init__()
        self.channel_out = ChannelModule(inputs)
        self.spatial_out = SpatialModule()

    def forward(self, inputs):
        # Refine channels first, then gate spatial positions of the result.
        channel_refined = inputs * self.channel_out(inputs)
        return channel_refined * self.spatial_out(channel_refined)