声明:笔记是做项目时根据B站博主视频学习时自己编写,请勿随意转载!
一、站在巨人的肩膀上
SE模块即Squeeze-and-Excitation 模块,这是一种常用于卷积神经网络中的注意力机制!!
借鉴代码的代码链接如下:
注意力机制-SEhttps://github.com/ZhugeKongan/Attention-mechanism-implementation
需要model里面的SE_block.py文件
# -*- coding: UTF-8 -*-
"""
SE structure
"""
import torch.nn as nn # 导入PyTorch的神经网络模块
import torch.nn.functional as F # 导入PyTorch的神经网络功能函数模块
class SE(nn.Module): # 定义一个名为SE的类,该类继承自PyTorch的nn.Module,表示一个神经网络模块
def __init__(self, in_chnls, ratio): # 初始化函数,in_chnls表示输入通道数,ratio表示压缩比率
super(SE, self).__init__() # 调用父类nn.Module的初始化函数
# 使用AdaptiveAvgPool2d将输入的空间维度压缩为1x1,即全局平均池化
self.squeeze = nn.AdaptiveAvgPool2d((1, 1))
# 使用1x1卷积将通道数压缩为原来的1/ratio,实现特征压缩
self.compress = nn.Conv2d(in_chnls, in_chnls // ratio, 1, 1, 0)
# 使用1x1卷积将通道数扩展回原来的in_chnls,实现特征激励
self.excitation = nn.Conv2d(in_chnls // ratio, in_chnls, 1, 1, 0)
def forward(self, x): # 定义前向传播函数
out = self.squeeze(x) # 对输入x进行全局平均池化
out = self.compress(out) # 对池化后的输出进行特征压缩
out = F.relu(out) # 对压缩后的特征进行ReLU激活
out = self.excitation(out) # 对激活后的特征进行特征激励
# 对激励后的特征应用sigmoid函数,然后与原始输入x进行逐元素相乘,实现特征重标定
return x*F.sigmoid(out)
代码后面有附注的注释(GPT解释的,很好用),理解即可。对于使用者来说,重要关注点还是它的输入通道、输出通道、需要传入的参数等!!这个函数整体传入in_chnls, ratio两个参数。
二、开始修改网络结构
与上节的C2f修改基本流程一致,但稍有不同
- model/common.py加入新增的SE网络结构,直接复制粘贴如下,这里加在了上节的C2f之前:
上面说到这个函数整体传入in_chnls, ratio两个参数!!
- model/yolo.py设定网络结构的传参细节
上期的C2f模块之所以可以参照原本存在的C3模块属性,是因为两者相似,但这里的SE模块就不可简单的在C3x后加SE,而是需要在下面加入一段elif代码:
elif m is SE:
c1 = ch[f]
c2 = args[0]
if c2 != no: # if not output
c2 = make_divisible(c2 * gw, 8)
args = [c1, args[1]]
即当新引入的模块中存在输入输出维度时,需要使用gw调整输出维度!!
- model/yolov5s.yaml设定现有模型结构配置文件
老样子,复制一份新的配置文件命名为yolov5s-se.yaml。首先需要在backbone的最后加上SE模块(相当于多了一层为第10层);其次考虑到backbone里多了一层,且在head里的输入层来源不止上一层(-1)一个,所以输入层来源大于等于第10层的都需要改为往后递推+1层。下图左边为原始的yaml配置文件,右侧为修改后的:
即当yaml文件引入新的层后,需要修改模型结构的from参数(上期是将C3替换为C2f模块,所以不涉及这一点)!!
- train.py训练时指定模型结构配置文件
这次将parse_model函数里的第二个参数cfg改为yolov5s-se.yaml即可,运行train.py开始训练!!
可见训练时第10层已经引入了SE注意力机制模块:
100次迭代后结果如下,结果保存在runs\train\exp12文件夹,文件夹里有很多指标曲线可对比分析:
往期精彩
STM32专栏(9.9)http://t.csdnimg.cn/A3BJ2
OpenCV-Python专栏(9.9)http://t.csdnimg.cn/jFJWe
AI底层逻辑专栏(9.9)http://t.csdnimg.cn/6BVhM
机器学习专栏(免费)http://t.csdnimg.cn/ALlLlSimulink专栏(免费)http://t.csdnimg.cn/csDO4电机控制专栏(免费)http://t.csdnimg.cn/FNWM7