YOLOv5改进 | 注意力机制 | 理解全局和局部信息的SE注意力机制

在深度学习目标检测领域,YOLOv5成为了备受关注的模型之一。本文给大家带来的是能够理解全局和局部信息的SE注意力机制。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。


专栏地址YOLOv5改进+入门——持续更新各种有效涨点方法

目录

1. 原理

2.YOLOv5添加SE模块

2.1 SE模块代码

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3.总结


1. 原理

论文地址:Squeeze-and-Excitation Networks点击即可跳转

官方代码:SENet官方代码地址点击即可跳转

对于CNN网络来说,核心计算是卷积算子,通过卷积核从输入特征图学习到新特征图。从本质上讲,卷积是对一个局部区域进行特征融合,这包括空间上(H和W维度)以及通道间(C维度)的特征融合。

对于卷积操作,很大一部分改进工作是提高感受野,即空间上融合更多特征,或者是提取多尺度空间信息,如Inception网络的多分支结构。对于channel维度的特征融合,卷积操作基本上默认对输入特征图的所有channel进行融合。而MobileNet网络中的组卷积(Group Convolution)和深度可分离卷积(Depthwise Separable Convolution)对channel进行分组也主要是为了使模型更加轻量级,减少计算量。

SENet网络的创新点在于关注channel之间的关系,希望模型可以自动学习到不同channel特征的重要程度。为此,SENet提出了Squeeze-and-Excitation (SE)模块,如下图所示

SE模块首先对卷积得到的特征图进行Squeeze操作,得到channel级的全局特征(SENet提出Squeeze操作,将一个channel上整个空间特征编码为一个全局特征,采用global average pooling来实现),然后对全局特征进行Excitation操作(这里采用sigmoid形式的gating机制),学习各channel间的关系,得到不同channel的权重,最后乘以原来的特征图得到最终特征。本质上,SE模块是在channel维度上做attention或者gating操作,这种注意力机制让模型可以更加关注信息量最大的channel特征,而抑制不重要的channel特征。另外一点是SE模块是通用的,意味着其可以嵌入到现有的网络架构中。

2.YOLOv5添加SE模块

2.1 SE模块代码

关键步骤一:将下面代码添加到 yolov5/models/common.py中

class AdaptiveAvgPool2d(_AdaptiveAvgPoolNd):

    def forward(self, input: Tensor) -> Tensor:
        return F.adaptive_avg_pool2d(input, self.output_size)

class SE(nn.Module):
    def __init__(self, c1, c2, ratio=16):
        super(SE, self).__init__()
        # c*1*1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

SE类是一种用于卷积神经网络中的注意力机制模块,旨在增强网络对输入特征的关注。它通过学习得到每个通道的权重,以自适应的方式调整特征图,从而提升网络性能。

构成部分:
1. Squeeze(压缩)阶段:
   通过全局平均池化操作将特征图压缩为一个向量,以捕获每个通道的全局特征信息。
   然后通过一个或多个全连接层对压缩后的特征向量进行变换,以获得通道相关的表示。

2. Excitation(激励)阶段:
   使用激活函数(通常是ReLU)对变换后的特征进行非线性映射,以增强模型的表达能力。
   再通过一个或多个全连接层对特征进行进一步的变换,以生成通道注意力权重。

3. Scale(缩放)阶段:
   使用Sigmoid函数将通道注意力权重归一化到0到1之间,以表示每个通道的重要性。
   将归一化后的权重乘以原始特征图,以对特征进行加权。

流程:

1. 接收输入张量x,该张量的形状为(batch_size, channels, height, width)。
2. 通过全局平均池化操作将每个通道的特征进行压缩,得到一个大小为(batch_size, channels)的全局特征向量。
3. 将全局特征向量通过一个或多个全连接层进行变换,得到每个通道的权重。
4. 对权重进行非线性变换(通常是ReLU激活函数),以增强模型的表达能力。
5. 再次通过全连接层对特征进行变换,得到每个通道的注意力权重。
6. 使用Sigmoid函数将注意力权重归一化到0到1之间,表示每个通道的重要性。
7. 将归一化后的注意力权重与原始特征图相乘,得到加权后的特征表示。
8. 返回加权后的特征张量。

目的:
SE模块的主要目的是通过自适应地调整通道的注意力权重,使模型能够更有效地捕获输入特征的重要信息,从而提高模型的性能和泛化能力。

SE类的实现包括两个主要部分:Squeeze阶段和Excitation阶段,它们共同作用于输入特征,以产生加权后的特征表示,从而增强了模型的表达能力和性能。

2.2 新增yaml文件

关键步骤二:在 /yolov5/models/ 下新建文件 yolov5_se.yaml文件并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SE, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

 温馨提示:因为本文只是对yolov5s基础上添加SE模块,如果要对yolov5n/l/m/x进行添加则只需要修改对应的depth_multiple 和 width_multiple。


2.3 注册模块

关键步骤三:在yolov5/models/yolo.py中注册

 大概在250行左右添加 ‘SE’

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_se.yaml的路径,如下图所示

 运行程序,如果出现下面的内容则说明添加成功🚀

我修改后的代码:链接: https://pan.baidu.com/s/1FzumcKWlHWHxJGWNQNepyQ?pwd=wikd 提取码: wikd

3.总结

SE(Squeeze-and-Excitation)模块是一种用于增强卷积神经网络特征表示的重要机制。这一模块通过引入通道注意力机制,使得网络能够自适应地调整每个通道的重要性,从而提高了模型的性能和泛化能力。

SE模块包含两个关键阶段:Squeeze阶段和Excitation阶段。在Squeeze阶段,通过全局平均池化操作将每个通道的特征压缩为一个全局特征向量,以捕获全局特征信息。在Excitation阶段,通过一系列线性变换和非线性激活函数,计算每个通道的注意力权重。这些权重用于调整原始特征图中每个通道的重要性,使得网络能够更加关注对任务有用的特征。

SE模块的工作流程如下:首先,接收输入特征张量;然后,在Squeeze阶段,通过全局平均池化操作将每个通道的特征压缩为一个全局特征向量;接着,在Excitation阶段,通过线性变换和激活函数计算每个通道的注意力权重;最后,使用这些注意力权重对原始特征进行加权,并返回加权后的特征表示。

SE模块的主要目的是通过自适应地调整每个通道的注意力权重,增强模型对输入特征的表达能力。这使得模型能够更好地捕获输入数据中的有用信息,从而提高了模型的性能和泛化能力。SE模块可以轻松地集成到各种卷积神经网络结构中,例如ResNet、Inception等,以提升它们在图像分类、目标检测、语义分割等任务中的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/611556.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RAG查询改写方法概述

在RAG系统中,用户的查询是丰富多样的,可能存在措辞不准确和缺乏语义信息的问题。这导致使用原始的查询可能无法有效检索到目标文档。 因此,将用户查询的语义空间与文档的语义空间对齐至关重要,目前主要有查询改写和嵌入转换两种方…

代码随想录算法训练营第六十天| LeetCode647. 回文子串 、516.最长回文子序列

一、LeetCode647. 回文子串 题目链接/文章讲解/视频讲解:https://programmercarl.com/0647.%E5%9B%9E%E6%96%87%E5%AD%90%E4%B8%B2.html 状态:已解决 1.思路 这道题我只想出来了暴力解法,动规解法并没有想出来。根据视频讲解才把它想出来。…

【hackmyvm】 Animetronic靶机

靶机测试 arp-scanporturl枚举exiftool套中套passwordsudo 提权 arp-scan arp-scan 检测局域网中活动的主机 192.168.9.203 靶机IP地址port 通过nmap扫描,获取目标主机的端口信息 ┌──(root㉿kali)-[/usr/share/seclists] └─# nmap -sT -sV -O 192.16…

基于springboot+vue+Mysql的体质测试数据分析及可视化设计

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

C语言/数据结构——(相交链表)

一.前言 今天在力扣上刷到了一道题,想着和大家一起分享一下这道题——相交链表https://leetcode.cn/problems/intersection-of-two-linked-lists废话不多说,让我们开始今天的分享吧。 二.正文 1.1题目描述 是不是感觉好长,我也这么觉得。哈…

【智能优化算法】蜜獾优化算法(Honey Badger Algorithm,HBA)

蜜獾优化算法(Honey Badger Algorithm,HBA)是期刊“MATHEMATICS AND COMPUTERS IN SIMULATION”(IF 3.6)的2022年智能优化算法 01.引言 蜜獾优化算法(Honey Badger Algorithm,HBA)受蜜獾智能觅食行为的启发,从数学上发展出一种求解优化问题的…

Linux修炼之路之初识操作系统+基础指令(1)

目录 引言 一:对操作系统(OS)的简单了解 1.操作系统(OS) 是什么 2.操作系统好坏的衡量标准 3.操作系统存在的重要性 4.理解所有在计算机上的操作 二:Linux与windows操作的特点区别 三:基础指令 1.ls 指令 1.使用 2.常用选项 2.…

【17-Ⅰ】Head First Java 学习笔记

HeadFirst Java 本人有C语言基础,通过阅读Java廖雪峰网站,简单速成了java,但对其中一些入门概念有所疏漏,阅读本书以弥补。 第一章 Java入门 第二章 面向对象 第三章 变量 第四章 方法操作实例变量 第五章 程序实战 第六章 Java…

linux进阶篇:Nginx反向代理原理与案例详解

Linux服务搭建篇:Nginx反向代理原理与案例详解 一、什么是正向代理 举个栗子: 我们在校外、公司外,是访问不到学校、公司的内网的,但是我们想要访问内网资源时,会用到VPN。而一般内网会存在一个VPN服务器&#xff0c…

安装vmware station记录

想学一下linux,花了3个多小时,才配置好了,记录一下 安装vm12,已配置linux系统 报错,VMware Workstation 与 Device/Credential Guard 不兼容解决方案,网上说有不成功的,电脑蓝屏,选择装vm16试试 vm16 在…

运维开发工程师教程之MongoDB单机版设置

MongoDB单机版设置 一、创建虚拟机 在VMware Workstation软件中新建一个虚拟机,具体操作步骤如下: ①运行VMware Workstation软件,进入到主界面,单击“创建新的虚拟机”来创建新的虚拟机,如图3-1所示。 图3-1 VMware…

算法设计与分析 动态规划/回溯

1.最大子段和 int a[N]; int maxn(int n) {int tempa[0];int ans0;ansmax(temp,ans);for(int i1;i<n;i){if(temp>0){tempa[i];}else tempa[i];ansmax(temp,ans);}return ans; } int main() {int n,ans0;cin>>n;for(int i0;i<n;i) cin>>a[i];ansmaxn(n);co…

智慧之巅:大数据与算力中心的融合演进

智慧之巅&#xff1a;大数据与算力中心的融合演进 1 引言 在这个数据驱动的时代&#xff0c;我们站在了一个前所未有的历史节点上。大数据和算力中心&#xff0c;这两个曾经各自为政的领域&#xff0c;如今正以一种前所未有的方式交织在一起&#xff0c;共同推动着数字经济的蓬…

PDF高效编辑:一键批量,PDF转图片的快速解决方案

在数字化时代&#xff0c;PDF文件已成为工作和学习中不可或缺的一部分。然而&#xff0c;有时我们可能需要将PDF转换为图片&#xff0c;以便更轻松地编辑、共享或处理。为了满足这一需求&#xff0c;许多高效的PDF编辑工具应运而生&#xff0c;其中“办公提效工具”一键批量PDF…

C++对象的拷贝构造函数

如果一个构造函数的第一个参数是类本身的引用,且没有其它参数(或者其它的参数都有默认值),则该构造函数为拷贝构造函数。 拷贝(复制)构造函数:利用同类对象构造一个新的对象 ●1.函数名和类同名 (构造函数) ●2.没有返回值 (构造函数) ●3.第一个参数必…

【甲辰雜俎】世界上最不可靠的就是人

"世界上最不可靠的就是人" 人是一個多元的複變函數, 今天經受住考驗, 明天你就有可能叛變。 過去是戰場上的仇敵, 明天就有可能成為政治上的盟友。 —— 擷取自電視劇《黑冰》 人的不可預測性, 的確是一個普遍的現象。 每個人都是一個獨特的個體, 受到不同的…

基于WPF的DynamicDataDisplay曲线显示

一、DynamicDataDisplay下载和引用 1.新建项目,下载DynamicDataDisplay引用: 如下图: 二、前端开发: <Border Grid.Row="0" Grid.Column="2" BorderBrush="Purple" BorderThickness="1" Margin="2"><Grid>…

5.11学习记录

20长安杯部分 检材 1 的操作系统版本 CentOS Linux 7.6.1810 (Core) 检材 1 中&#xff0c;操作系统的内核版本是 3.10.0-957.el7.x86_64 检材 1 中磁盘包含一个 LVM 逻辑卷&#xff0c;该 LVM 开始的逻辑区块地址&#xff08;LBA&#xff09;是 2099200 物理卷&#xff…

Web服务器--虚拟主机配置

实验1&#xff1a;建立两个基于ip地址访问的网站&#xff0c;要求如下 该网站ip地址的主机位为100&#xff0c;设置DocumentRoot为/www/ip/100&#xff0c;网页内容为&#xff1a;this is 100。 该网站ip地址主机位为200&#xff0c;设置DocumentRoot为/www/ip/200&#xff0c…

EmploLeaks:一款针对企业安全的组织员工信息收集OSINT工具

关于EmploLeaks EmploLeaks是一款针对企业安全的组织员工信息收集OSINT工具&#xff0c;在该工具的帮助下&#xff0c;企业内部的安全人员和管理员可以有效地收集组织内员工的各种信息&#xff0c;并以此来判断组织内部的网络安全态势。 工作机制 首先&#xff0c;该工具会在…