Add an attention_layers.py file under the mmdetection/mmdet/models/layers/ directory:
import torch.nn as nn
from mmdet.registry import MODELS
# Custom attention modules
from .attention.CBAM import CBAMBlock as _CBAMBlock
from .attention.BAM import BAMBlock as _BAMBlock
from .attention.SEAttention import SEAttention as _SEAttention
from .attention.ECAAttention import ECAAttention as _ECAAttention
from .attention.ShuffleAttention import ShuffleAttention as _ShuffleAttention
from .attention.SGE import SpatialGroupEnhance as _SpatialGroupEnhance
from .attention.A2Atttention import DoubleAttention as _DoubleAttention
from .attention.PolarizedSelfAttention import SequentialPolarizedSelfAttention as _SequentialPolarizedSelfAttention
from .attention.CoTAttention import CoTAttention as _CoTAttention
from .attention.TripletAttention import TripletAttention as _TripletAttention
from .attention.CoordAttention import CoordAtt as _CoordAtt
from .attention.ParNetAttention import ParNetAttention as _ParNetAttention

@MODELS.register_module()
class CBAMBlock(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(CBAMBlock, self).__init__()
        print("====== Attention module [CBAMBlock] activated ======")
        self.module = _CBAMBlock(channel=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class BAMBlock(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(BAMBlock, self).__init__()
        print("====== Attention module [BAMBlock] activated ======")
        self.module = _BAMBlock(channel=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SEAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(SEAttention, self).__init__()
        print("====== Attention module [SEAttention] activated ======")
        self.module = _SEAttention(channel=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class ECAAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(ECAAttention, self).__init__()
        print("====== Attention module [ECAAttention] activated ======")
        # ECA does not need the channel count, only the kernel size
        self.module = _ECAAttention(**kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class ShuffleAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(ShuffleAttention, self).__init__()
        print("====== Attention module [ShuffleAttention] activated ======")
        self.module = _ShuffleAttention(channel=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SpatialGroupEnhance(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(SpatialGroupEnhance, self).__init__()
        print("====== Attention module [SpatialGroupEnhance] activated ======")
        # SGE does not need the channel count, only the number of groups
        self.module = _SpatialGroupEnhance(**kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class DoubleAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(DoubleAttention, self).__init__()
        print("====== Attention module [DoubleAttention] activated ======")
        # kwargs are ignored; the remaining arguments are fixed values
        self.module = _DoubleAttention(in_channels, 128, 128, True)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SequentialPolarizedSelfAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(SequentialPolarizedSelfAttention, self).__init__()
        print("====== Attention module [Polarized Self-Attention] activated ======")
        self.module = _SequentialPolarizedSelfAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class CoTAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(CoTAttention, self).__init__()
        print("====== Attention module [CoTAttention] activated ======")
        self.module = _CoTAttention(dim=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class TripletAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(TripletAttention, self).__init__()
        print("====== Attention module [TripletAttention] activated ======")
        # TripletAttention needs no channel information
        self.module = _TripletAttention()

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class CoordAtt(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(CoordAtt, self).__init__()
        print("====== Attention module [CoordAtt] activated ======")
        # input and output channel counts are kept identical
        self.module = _CoordAtt(in_channels, in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class ParNetAttention(nn.Module):
    def __init__(self, in_channels, **kwargs):
        super(ParNetAttention, self).__init__()
        print("====== Attention module [ParNetAttention] activated ======")
        self.module = _ParNetAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)
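Once the attention files described in the next step are in place, the registered wrappers can be sanity-checked on their own, before touching any detector config. The snippet below is a minimal sketch; the choice of SEAttention and the 256-channel test tensor are arbitrary examples.

import torch
from mmdet.registry import MODELS
import mmdet.models.layers.attention_layers  # noqa: F401, runs the register_module() calls

# Build one of the wrappers from a config dict, just as mmdetection would.
block = MODELS.build(dict(type='SEAttention', in_channels=256, reduction=8))
x = torch.randn(2, 256, 64, 64)
assert block(x).shape == x.shape  # a backbone plugin must preserve the feature-map shape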
Create an attention folder next to attention_layers.py and place the 12 attention-mechanism implementation files inside it.
Download link: the 12 attention mechanisms for mmdetection 3 (CSDN): https://download.csdn.net/download/lanyan90/89513979
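After this step the layout should look roughly as follows (the file names simply mirror the imports in attention_layers.py above):

mmdet/models/layers/
├── attention_layers.py
└── attention/
    ├── A2Atttention.py
    ├── BAM.py
    ├── CBAM.py
    ├── CoordAttention.py
    ├── CoTAttention.py
    ├── ECAAttention.py
    ├── ParNetAttention.py
    ├── PolarizedSelfAttention.py
    ├── SEAttention.py
    ├── SGE.py
    ├── ShuffleAttention.py
    └── TripletAttention.py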
Usage:
Taking faster-rcnn_r50 as an example, create faster-rcnn_r50_fpn_1x_coco_attention.py:
_base_ = 'configs/detection/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'
custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)
model = dict(
    backbone=dict(
        plugins=[
            dict(
                position='after_conv3',
                # cfg=dict(type='CBAMBlock', reduction=16, kernel_size=7)
                # cfg=dict(type='BAMBlock', reduction=16, dia_val=1)
                # cfg=dict(type='SEAttention', reduction=8)
                # cfg=dict(type='ECAAttention', kernel_size=3)
                # cfg=dict(type='ShuffleAttention', G=8)
                # cfg=dict(type='SpatialGroupEnhance', groups=8)
                # cfg=dict(type='DoubleAttention')
                # cfg=dict(type='SequentialPolarizedSelfAttention')
                # cfg=dict(type='CoTAttention', kernel_size=3)
                # cfg=dict(type='TripletAttention')
                # cfg=dict(type='CoordAtt', reduction=32)
                # cfg=dict(type='ParNetAttention')
            )
        ]
    )
)
To use a particular attention mechanism, simply uncomment the corresponding cfg line inside plugins (see the sketch below for a fully enabled entry).
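For instance, with SEAttention chosen the plugin entry reduces to the sketch below; mmdetection's ResNet additionally accepts an optional stages key in each plugin dict if the attention should only be inserted in some stages.

plugins=[
    dict(
        cfg=dict(type='SEAttention', reduction=8),
        position='after_conv3')
]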
Taking mask-rcnn_r50 as an example, create mask-rcnn_r50_fpn_1x_coco_attention.py:
_base_ = 'configs/segmentation/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py'
custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)
model = dict(
    backbone=dict(
        plugins=[
            dict(
                position='after_conv3',
                # cfg=dict(type='CBAMBlock', reduction=16, kernel_size=7)
                # cfg=dict(type='BAMBlock', reduction=16, dia_val=1)
                # cfg=dict(type='SEAttention', reduction=8)
                # cfg=dict(type='ECAAttention', kernel_size=3)
                # cfg=dict(type='ShuffleAttention', G=8)
                # cfg=dict(type='SpatialGroupEnhance', groups=8)
                # cfg=dict(type='DoubleAttention')
                # cfg=dict(type='SequentialPolarizedSelfAttention')
                # cfg=dict(type='CoTAttention', kernel_size=3)
                # cfg=dict(type='TripletAttention')
                # cfg=dict(type='CoordAtt', reduction=32)
                # cfg=dict(type='ParNetAttention')
            )
        ]
    )
)
The usage is exactly the same!
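Training then runs like any other mmdetection config, for example (assuming the new config file sits in the repository root; adjust the path to wherever you saved it):

python tools/train.py faster-rcnn_r50_fpn_1x_coco_attention.py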