一、本文介绍
本文给大家带来的改进是Triplet Attention三重注意力机制。这个机制,它通过三个不同的视角来分析输入的数据,就好比三个人从不同的角度来观察同一幅画,然后共同决定哪些部分最值得注意。三重注意力机制的主要思想是在网络中引入了一种新的注意力模块,这个模块包含三个分支,分别关注图像的不同维度。比如说,一个分支可能专注于图像的宽度,另一个分支专注于高度,第三个分支则聚焦于图像的深度,即色彩和纹理等特征。这样一来,网络就能够更全面地理解图像内容,就像是得到了一副三维眼镜,能够看到图片的立体效果一样。
推荐指数:⭐⭐⭐⭐
专栏回顾:YOLOv5改进专栏——持续复现各种顶会内容——内含100+创新
训练结果对比图->
二、Triplet Attention机制原理
论文地址:官方论文地址
代码地址:官方代码地址
2.1 Triplet Attention的基本原理
三重注意力(Triplet Attention)的基本原理是利用三支结构捕获输入数据的跨维度交互,从而计算注意力权重。这个方法能够构建输入通道或空间位置之间的相互依赖性,而且计算代价小。三重注意力由三个分支组成,每个分支负责捕获空间维度H或W与通道维度C之间的交互特征。通过对每个分支中的输入张量进行排列变换,然后通过Z池操作和一个大小为k×k的卷积层,生成注意力权重。这些权重是通过一个S形激活层生成的,然后应用于排列变换后的输入张量,再变换回原来的输入形状
三重注意力(Triplet Attention)的主要改进点包括:
-
跨维度的注意力权重计算: 通过一个创新的三支结构捕获通道、高度、宽度三个维度之间的交互关系来计算注意力权重。
-
旋转操作和残差变换: 通过旋转输入张量和应用残差变换来建立不同维度间的依赖,这是三重注意力机制中的关键步骤。
-
维度间依赖性的重要性: 强调在计算注意力权重时,捕获跨维度依赖性的重要性,这是三重注意力的核心直觉和设计理念。
下面的图片是三重注意力的一个抽象表示图,展示了三个分支如何捕获跨维度交互。图中的每个子图表示三重注意力中的一个分支:
1. 分支(a): 这个分支直接处理输入张量,没有进行旋转,然后通过残差变换来提取特征。
2. 分支(b): 这个分支首先沿着宽度(W)和通道(C)的维度旋转输入张量,然后进行残差变换。
3. 分支(c): 这个分支沿着高度(H)和通道(C)的维度旋转输入张量,之后同样进行残差变换。
总结:通过这样的设计,三重注意力模型能够有效地捕获输入张量中的空间和通道维度之间的交互关系。这种方法使模型能够构建通道与空间位置之间的相互依赖性,提高模型对特征的理解能力。
2.2 Triplet Attention和其它简单注意力机制的对比
下面的图片是论文中三重注意力机制和其它注意力机制的一个对比大家有兴趣可以看看,横向扩展以下自己的知识库。
这张图片是一幅对比不同注意力模块的图示,其中包括:
1.Squeeze Excitation (SE) Module:
这个模块使用全局平均池化 (Global Avg Pool) 生成通道描述符,接着通过两个全连接层(1x1 Conv),中间使用ReLU激活函数,最后通过Sigmoid函数生成每个通道的权重。
2. Convolutional Block Attention Module (CBAM):
首先使用全局平均池化和全局最大池化(GAP + GMP)结合,再通过一个卷积层和ReLU激活函数,最后经过另一个卷积层和Sigmoid函数生成注意力权重。
3. Global Context (GC) Module:
从一个1x1卷积层开始,经过Softmax函数进行归一化,接着进行另一个1x1卷积,然后使用LayerNorm和最终的1x1卷积,通过广播加法结合原始特征图。
4. Triplet Attention (我们的方法):
分为三个分支,每个分支进行不同的处理:通道池化后的7x7卷积,Z池化,再接一个7x7卷积,然后是批量归一化和Sigmoid函数。每个分支都有一个Permute操作来调整维度。最后,三个分支的结果通过平均池化聚合起来生成最终的注意力权重。
每种模块都设计用于处理特征图(C x H x W),其中C是通道数,H是高度,W是宽度。这些模块通过不同方式计算注意力权重,增强网络对特征的重要部分的关注度,从而在各种视觉任务中提高性能。图片中的符号⊗代表矩阵乘法,⊕代表广播元素级加法。
2.3 Triplet Attention的实现流程
下面的图片是三重注意力(Triplet Attention)的具体实现流程图。图中详细展示了三个分支如何处理输入张量,并最终合成三重注意力。下面是对这个过程的描述:
-
上部分支: 负责计算通道维度C和空间维度W的注意力权重。这个分支对输入张量进行Z池化(Z-Pool)操作,然后通过一个卷积层(Conv),接着用Sigmoid函数生成注意力权重。
-
中部分支: 负责捕获通道维度C与空间维度H和W之间的依赖性。这个分支首先进行相同的Z池化和卷积操作,然后同样通过Sigmoid函数生成注意力权重。
-
下部分支: 用于捕获空间维度之间的依赖性。这个分支保持输入的身份(Identity,即不改变输入),执行Z池化和卷积操作,之后也通过Sigmoid函数生成注意力权重。
每个分支在生成注意力权重后,会对输入进行排列(Permutation),然后将三个分支的输出进行平均聚合(Avg),最终得到三重注意力输出。
这种结构通过不同的旋转和排列操作,能够综合不同维度上的信息,更好地捕获数据的内在特征,同时这种方法在计算上是高效的,并且可以作为一个模块加入到现有的网络架构中,增强网络对复杂数据结构的理解和处理能力。
三、Triplet Attention的核心代码
我们找到如下的目录'yolov5-master/models'在这个目录下创建一个文件目录(注意是目录,因为我这个专栏会出很多的更新,这里用一种一劳永逸的方法)文件目录起名modules,然后在下面新建一个文件,将我们的代码复制粘贴进去。
import torch
import torch.nn as nn
from ..common import Conv
class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class ZPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
class AttentionGate(nn.Module):
def __init__(self):
super(AttentionGate, self).__init__()
kernel_size = 7
self.compress = ZPool()
self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.conv(x_compress)
scale = torch.sigmoid_(x_out)
return x * scale
class TripletAttention(nn.Module):
def __init__(self, no_spatial=False):
super(TripletAttention, self).__init__()
self.cw = AttentionGate()
self.hc = AttentionGate()
self.no_spatial = no_spatial
if not no_spatial:
self.hw = AttentionGate()
def forward(self, x):
x_perm1 = x.permute(0, 2, 1, 3).contiguous()
x_out1 = self.cw(x_perm1)
x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
x_perm2 = x.permute(0, 3, 2, 1).contiguous()
x_out2 = self.hc(x_perm2)
x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
if not self.no_spatial:
x_out = self.hw(x)
x_out = 1 / 3 * (x_out + x_out11 + x_out21)
else:
x_out = 1 / 2 * (x_out11 + x_out21)
return x_out
class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c_, c2, 3, 1, g=g)
self.Dattention = TripletAttention()
self.add = shortcut and c1 == c2
def forward(self, x):
return x + self.Dattention(self.cv2(self.cv1(x))) if self.add else self.Dattention(self.cv2(self.cv1(x)))
class C3_TripleA(nn.Module):
# CSP Bottleneck with 3 convolutions
def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
c_ = int(c2 * e) # hidden channels
self.cv1 = Conv(c1, c_, 1, 1)
self.cv2 = Conv(c1, c_, 1, 1)
self.cv3 = Conv(2 * c_, c2, 1) # optional act=FReLU(c2)
self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
def forward(self, x):
return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
四、手把手教你添加Triplet Attention
4.1 细节修改教程
4.1.1 修改一
我们找到如下的目录'yolov5-master/models'在这个目录下创建一个文件目录(注意是目录,因为我这个专栏会出很多的更新,这里用一种一劳永逸的方法)文件目录起名modules,然后在下面新建一个文件,将我们的代码复制粘贴进去。
4.1.2 修改二
然后新建一个__init__.py文件,然后我们在里面添加一行代码。注意标记一个'.'其作用是标记当前目录。
4.1.3 修改三
然后我们找到如下文件''models/yolo.py''在开头的地方导入我们的模块按照如下修改->
(如果你看了我多个改进机制此处只需要添加一个即可,无需重复添加)
4.1.4 修改四
然后我们找到parse_model方法,按照如下修改->
到此就修改完成了,复制下面的ymal文件即可运行。
4.2 Triplet Attention的yaml文件
4.2.1 Triplet Attention的yaml文件一
下面的配置文件为我修改的Triplet Attention的位置,参数的位置里面什么都不用添加空着就行,大家复制粘贴我的就可以运行,同时我提供多个版本给大家,根据我的经验可能涨点的位置。
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3_TripleA, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3_TripleA, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3_TripleA, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3_TripleA, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3_TripleA, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3_TripleA, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3_TripleA, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3_TripleA, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
4.2.2 Triplet Attention的yaml文件二
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3_TripleA, [256, False]], # 17 (P3/8-small)
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3_TripleA, [512, False]], # 20 (P4/16-medium)
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3_TripleA, [1024, False]], # 23 (P5/32-large)
[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
4.2.3 Triplet Attention的yaml文件三
注意此版本的我再大目标,小目标,中目标三个曾的后面添加了一个注意力机制,此版本需要显存较大,可以根据自己的需求增删,如果修改大家要注意修改Detect里面的检测层数。
# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license
# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.25 # layer channel multiple
anchors:
- [10,13, 16,30, 33,23] # P3/8
- [30,61, 62,45, 59,119] # P4/16
- [116,90, 156,198, 373,326] # P5/32
# YOLOv5 v6.0 backbone
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]],
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
]
# YOLOv5 v6.0 head
head:
[[-1, 1, Conv, [512, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 6], 1, Concat, [1]], # cat backbone P4
[-1, 3, C3, [512, False]], # 13
[-1, 1, Conv, [256, 1, 1]],
[-1, 1, nn.Upsample, [None, 2, 'nearest']],
[[-1, 4], 1, Concat, [1]], # cat backbone P3
[-1, 3, C3, [256, False]], # 17 (P3/8-small)
[-1, 1, TripletAttention, []], # 18
[-1, 1, Conv, [256, 3, 2]],
[[-1, 14], 1, Concat, [1]], # cat head P4
[-1, 3, C3, [512, False]], # 21 (P4/16-medium)
[-1, 1, TripletAttention, []], # 22
[-1, 1, Conv, [512, 3, 2]],
[[-1, 10], 1, Concat, [1]], # cat head P5
[-1, 3, C3, [1024, False]], # 25 (P5/32-large)
[-1, 1, TripletAttention, []], # 26
[[18, 22, 26], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
]
4.3 Triplet Attention运行成功截图
附上我的运行记录确保我的教程是可用的。
4.4 推荐Triplet Attention可添加的位置
Triplet Attention是一种即插即用的可替换注意力机制的模块,其可以添加的位置有很多,添加的位置不同效果也不同,所以我下面推荐几个添加的位,置大家可以进行参考,当然不一定要按照我推荐的地方添加。
残差连接中:在残差网络的残差连接中加入Triplet Attention(yaml文件一)。
Neck部分:YOLOv8的Neck部分负责特征融合,这里添加修改后的C3_TripletA可以帮助模型更有效地融合不同层次的特征(yaml文件二)。
检测头:可以再检测头前面添加(yaml文件三)
检测头中:可以再检测头的内部添加该机制(未提供因为需要修改检测头比较麻烦,后期专栏收费后大家购买专栏之后大家会得到一个包含上百个机制的v5文件里面包含所有的改进机制)
五、本文总结
到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv5改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~
专栏回顾:YOLOv5改进专栏——持续复现各种顶会内容——内含100+创新