YOLOv8改进 | 注意力机制 | 结合静态和动态上下文信息的注意力机制

秋招面试专栏推荐 :深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有50+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转


上下文Transformer(CoT)块是一种新颖的Transformer风格模块,用于视觉识别。它充分利用输入键之间的上下文信息来指导动态注意力矩阵的学习,从而加强了视觉表示的能力。CoT块首先通过3×3卷积对输入键进行上下文化编码,得到输入的静态上下文表示。然后,将编码后的键与输入查询连接起来,通过两个连续的1×1卷积来学习动态的多头注意力矩阵。最后,将静态和动态上下文表示的融合作为输出。CoT块可以轻松替换ResNet架构中的每个3×3卷积,产生一个名为上下文Transformer网络(CoTNet)的Transformer风格的主干网络。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改并将修改后的完整代码放在文章的最后方便大家一键运行小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv8改进——更新各种有效涨点方法——点击即可跳转

目录

1.原理

2. 将CoTAttention添加到YOLOv8中

2.1 CoTAttention代码实现

2.2 更改init.py文件

2.3 添加yaml文件

2.4 在task.py中进行注册

2.5 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6.总结


1.原理

论文地址:Contextual Transformer Networks for Visual Recognition——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

上下文 Transformer (CoT) 注意力是一种新颖的 Transformer 式模块,旨在增强视觉识别任务。以下是根据提供的文档对其主要原理的解释:

CoT 注意力的主要原理

1. 键的上下文编码

  • CoT 首先使用 3×3 卷积对输入键进行上下文编码。此步骤捕获输入特征图中本地邻居之间的静态上下文,从而产生静态上下文表示。

2. 动态注意力矩阵

  • 然后将上下文化的键与输入查询连接起来。此组合表示通过两个连续的 1×1 卷积来学习动态多头注意力矩阵。此步骤结合了查询-键关系和静态上下文以进行自注意力学习。

3. 动态上下文表示

  • 学习到的注意力矩阵用于加权输入值,从而产生动态上下文表示,从输入中捕获动态上下文。

4. 静态和动态上下文融合

  • 静态和动态上下文表示融合在一起,形成 CoT 块的最终输出。这种组合利用了通过自注意力学习到的局部邻域信息和更广泛的上下文。

优势和实现

  • 与 ResNet 集成

  • CoT 块可以替代 ResNet 架构中的 3×3 卷积,而无需增加参数数量或计算开销,从而创建了一个名为上下文 Transformer 网络 (CoTNet) 的新主干。

  • 性能提升

  • 与传统卷积网络和其他基于 Transformer 的架构相比,CoTNet 在各种任务(包括图像识别、对象检测和实例分割)中表现出色。

与传统自注意力的比较

  • 传统自注意力

  • 根据每个空间位置上的孤立查询键对来测量注意力,通常忽略相邻键之间的丰富上下文。

  • CoT Attention

  • 通过 3×3 卷积整合相邻键的静态上下文,并通过 1×1 卷积考虑组合查询和上下文化键来增强动态上下文学习。

视觉表示

  • 传统自注意力模块

  • 通常涉及使用查询和键之间的成对交互来计算注意力矩阵,而不考虑键之间的空间上下文。

  • CoT 模块

  • 涉及额外的 3×3 卷积步骤以进行键之间的上下文挖掘,然后进行动态注意力矩阵学习和上下文融合。

通过利用静态和动态上下文信息,CoT Attention 可以更全面地理解输入特征图,从而提高视觉识别能力。

2. 将CoTAttention添加到YOLOv8中

2.1 CoTAttention代码实现

关键步骤一: 将下面代码粘贴到在/ultralytics/ultralytics/nn/modules/block.py中,并在该文件的__all__中添加“CoTAttention”

class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.SiLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.SiLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

上下文转换器 (CoT) 注意力机制通过整合输入键之间的上下文信息来增强图像处理。以下是使用 CoT 注意力机制进行图像处理的主要工作流程的详细说明:

使用 CoT 注意力机制进行图像处理的主要工作流程

1. 输入特征图

  • 从大小为 (H \times W \times C) 的输入特征图 (X) 开始,其中 (H) 为高度,(W) 为宽度,(C) 为通道数。

2. 键的上下文编码

  • 对输入键应用 3×3 卷积以捕获本地邻居之间的静态上下文。这会产生一个表示上下文化键的新特征图: K_{contextual} = \text{Conv3x3}(X)

3. 与查询连接

  • 将上下文化键 (K{contextual}) 与输入查询 (Q) 连接起来。这种组合表示结合了原始输入和上下文信息: Q{concat} = \text{Concat}(Q, K_{contextual})

4. 动态注意矩阵学习

  • 将连接表示 (Q{concat}) 传递到两个连续的 1×1 卷积,以学习动态多头注意矩阵: A{dynamic} = \text{Conv1x1}(\text{Conv1x1}(Q_{concat}))

5. 动态上下文表示

  • 使用学习到的注意矩阵 (A{dynamic}) 加权输入值 (V),产生动态上下文表示。此步骤根据查询和键之间的关系捕获动态上下文:V{dynamic} = A_{dynamic} \cdot V

6. 静态和动态上下文融合

  • 将静态上下文表示 (K{contextual}) 与动态上下文表示 (V{dynamic}) 相结合以形成最终输出。此融合利用了局部和更广泛的上下文信息:\text{Output} = \text{Fuse}(K{contextual}, V{dynamic})

详细步骤

3×3 卷积用于上下文编码

  • 3×3 卷积扫描输入特征图以捕获相邻键之间的空间关系,从而创建反映局部依赖关系的静态上下文。

1×1 卷积用于注意力矩阵:

  • 两个连续的 1×1 卷积对连接的查询和上下文化键进行操作,以学习动态注意力矩阵,这有助于根据上下文相关性对输入值进行加权。

注意力机制:

  • CoT 中的注意力机制与传统的自注意力不同,它将静态上下文纳入动态注意力计算中,从而产生更强大、更能感知上下文的注意力矩阵。

融合机制:

  • 最后的融合步骤结合了静态和动态表示,确保模型既能从局部上下文(通过 3×3 卷积)中受益,也能从动态交互(通过学习注意力)中受益。

2.2 更改init.py文件

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数

然后在下面的__all__中声明函数

2.3 添加yaml文件

关键步骤三:在/ultralytics/ultralytics/cfg/models/v8下面新建文件yolov8_CoTA.yaml文件,粘贴下面的内容

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [ 0.33, 0.25, 1024 ]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [ -1, 1, Conv, [ 64, 3, 2 ] ]  # 0-P1/2
  - [ -1, 1, Conv, [ 128, 3, 2 ] ]  # 1-P2/4
  - [ -1, 3, C2f, [ 128, True ] ]
  - [ -1, 1, Conv, [ 256, 3, 2 ] ]  # 3-P3/8
  - [ -1, 6, C2f, [ 256, True ] ]
  - [ -1, 1, Conv, [ 512, 3, 2 ] ]  # 5-P4/16
  - [ -1, 6, C2f, [ 512, True ] ]
  - [ -1, 1, Conv, [ 1024, 3, 2 ] ]  # 7-P5/32
  - [ -1, 3, C2f, [ 1024, True ] ]
  - [ -1, 1, SPPF, [ 1024, 5 ] ]  # 9

# YOLOv8.0n head
head:
  - [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]
  - [ [ -1, 6 ], 1, Concat, [ 1 ] ]  # cat backbone P4
  - [ -1, 3, C2f, [ 512 ] ]  # 12

  - [ -1, 1, nn.Upsample, [ None, 2, 'nearest' ] ]
  - [ [ -1, 4 ], 1, Concat, [ 1 ] ]  # cat backbone P3
  - [ -1, 3, C2f, [ 256 ] ]  # 15 (P3/8-small)

  - [ -1, 1, Conv, [ 256, 3, 2 ] ]
  - [ [ -1, 12 ], 1, Concat, [ 1 ] ]  # cat head P4
  - [ -1, 3, C2f, [ 512 ] ]  # 18 (P4/16-medium)
  - [ -1, 1, CoTAttention, [ 512 ] ]    # CoTAttention https://arxiv.org/abs/2107.12292

  - [ -1, 1, Conv, [ 512, 3, 2 ] ]
  - [ [ -1, 9 ], 1, Concat, [ 1 ] ]  # cat head P5
  - [ -1, 3, C2f, [ 1024 ] ]  # 21 (P5/32-large)
  - [ -1, 1, CoTAttention, [ 1024 ] ]

  - [ [ 15, 19, 23 ], 1, Detect, [ nc ] ]  # Detect(P3, P4, P5)

温馨提示:本文只是对yolov8基础上添加模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv8n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
max_channels: 1024 # max_channels
 
# YOLOv8s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
max_channels: 1024 # max_channels
 
# YOLOv8l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
max_channels: 512 # max_channels
 
# YOLOv8m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
max_channels: 768 # max_channels
 
# YOLOv8x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple
max_channels: 512 # max_channels

2.4 在task.py中进行注册

关键步骤四:在task.py的parse_model函数中进行注册,

elif m is CoTAttention:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

2.5 执行程序

关键步骤五:在ultralytics文件中新建train.py,将model的参数路径设置为yolov8_CoTA.yaml的路径即可

from ultralytics import YOLO
 
# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
 
model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_CoTA.yaml')  # build from YAML and transfer weights
 
# Train the model
model.train( batch=16)

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1wyZnAKNVMhMhuaH2lDJ7jA?pwd=yzye 

 提取码:yzye 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

img

改进后的GFLOPs

5. 进阶

可以结合损失函数或者卷积模块进行多重改进

6.总结

上下文变换注意 (CoTAttention) 是一种新颖的机制,旨在通过整合静态和动态上下文信息来增强视觉识别任务。它首先对输入键应用 3×3 卷积,以捕获本地邻居之间的静态上下文。然后将上下文化的键与输入查询连接起来,并将此组合表示通过两个连续的 1×1 卷积来学习动态多头注意矩阵。此矩阵用于加权输入值,从而产生动态上下文表示。最后,将静态和动态上下文表示融合以形成最终输出。此过程使 CoTAttention 能够利用通过自注意力学习到的局部邻域信息和更广泛的上下文,从而提高图像识别、对象检测和实例分割任务的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/778148.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

207 课程表

题目 你这个学期必须选修 numCourses 门课程,记为 0 到 numCourses - 1 。 在选修某些课程之前需要一些先修课程。 先修课程按数组 prerequisites 给出,其中 prerequisites[i] [ai, bi] ,表示如果要学习课程 ai 则 必须 先学习课程 bi 。 …

跨越语言的界限:Vue I18n 国际化指南

前言 📫 大家好,我是南木元元,热爱技术和分享,欢迎大家交流,一起学习进步! 🍅 个人主页:南木元元 目录 国际化简介 vue-i18n 安装和配置 创建语言包 基本使用 切换语言 动态翻…

使用Python绘制堆积柱形图

使用Python绘制堆积柱形图 堆积柱形图效果代码 堆积柱形图 堆积柱形图(Stacked Bar Chart)是一种数据可视化图表,用于显示不同类别的数值在某一变量上的累积情况。每一个柱状条显示多个子类别的数值,子类别的数值在柱状条上堆积在…

电商视角如何理解动态IP与静态IP

在电子商务的蓬勃发展中,网络基础设施的稳定性和安全性是至关重要的。其中,IP地址作为网络设备间通信的基础,扮演着举足轻重的角色。从电商的视角出发,我们可以将动态IP和静态IP比作电商平台上不同类型的店铺安排,以此…

数据结构1:C++实现边长数组

数组作为线性表的一种,具有内存连续这一特点,可以通过下标访问元素,并且下标访问的时间复杂的是O(1),在数组的末尾插入和删除元素的时间复杂度同样是O(1),我们使用C实现一个简单的边长数组。 数据结构定义 class Arr…

C++(Qt)-GIS开发-QGraphicsView显示瓦片地图简单示例

C(Qt)-GIS开发-QGraphicsView显示瓦片地图简单示例 文章目录 C(Qt)-GIS开发-QGraphicsView显示瓦片地图简单示例1、概述2、实现效果3、主要代码4、源码地址 更多精彩内容👉个人内容分类汇总 👈👉GIS开发 👈 1、概述 支持多线程加…

系统安全与应用

目录 1. 系统账户清理 2. 密码安全性控制 2.1 密码复杂性 2.2 密码时限 3 命令历史查看限制 4. 终端自动注销 5. su权限以及sudo提权 5.1 su权限 5.2 sudo提权 6. 限制更改GRUB引导 7. 网络端口扫描 那天不知道为什么,心血来潮看了一下passwd配置文件&am…

在 PostgreSQL 中,如何处理大规模的文本数据以提高查询性能?

文章目录 一、引言二、理解 PostgreSQL 中的文本数据类型三、数据建模策略四、索引选择与优化五、查询优化技巧六、示例场景与性能对比七、分区表八、数据压缩九、定期维护十、总结 在 PostgreSQL 中处理大规模文本数据以提高查询性能 一、引言 在当今的数据驱动的世界中&…

Android 集成OpenCV

记录自己在学习使用OpenCV的过程 我使用的是4.10.0 版本 Android 集成OpenCV 步骤 下载OpenCV新建工程依赖OpenCV初始化及逻辑处理 1、下载OpenCV 并解压到自己的电脑 官网 地址:https://opencv.org/releases/ 个人地址:https://pan.baidu.com/s/19f…

前端必修技能:高手进阶核心知识分享 - CSS mix-blend-mode 图片混合模式详解

标签定义及使用说明 mix-blend-mode 属性描述了元素的内容应该与元素的直系父元素的内容和元素的背景如何混合。 语法 mix-blend-mod: 使用mix-blend-mode 各种混合模式实例 注意: Internet Explorer 或 Edge 浏览器不支持 mix-blend-mode 属性。 (还是那个熟…

收银系统源码-千呼新零售2.0

千呼新零售2.0系统是零售行业连锁店一体化收银系统,包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体,线上线下数据全部打通。 适用于商超、便利店、水果、生鲜、母婴、服装、零食、百货、宠物等连锁店使用。 详细介绍请…

24-7-6-读书笔记(八)-《蒙田随笔集》[法]蒙田 [译]潘丽珍

文章目录 《蒙田随笔集》阅读笔记记录总结 《蒙田随笔集》 《蒙田随笔集》蒙田(1533-1592),是个大神人,这本书就是250页的样子,但是却看了好长好长时间,体会还是挺深的,但看的也是不大仔细&…

【Oracle】Oracle常用函数

目录 聚合函数数字函数1. ABS函数:返回一个数的绝对值。2. CEIL函数:返回大于等于给定数的最小整数。3. FLOOR函数:返回小于等于给定数的最大整数。4. ROUND函数:将一个数四舍五入到指定的小数位。5. MOD函数:返回两个…

Ubuntu固定虚拟机的ip地址

1、由于虚拟机网络是桥接,所以ip地址会不停地变化,接下来我们就讲述ip如何固定 2、如果apt安装时报错W: Target CNF (multiverse/cnf/Commands-all) is configured multiple times in /etc/apt/sources.list:10, 检查 /etc/apt/sources.list…

SpringBoot新手快速入门系列教程二:MySql5.7.44的免安装版本下载和配置,以及简单的Mysql生存指令指南。

我们要如何选择MySql 目前主流的Mysql有5.0、8.0、9.0 主要区别 MySQL 5.0 发布年份:2005年特性: 基础事务支持存储过程、触发器、视图基础存储引擎(如MyISAM、InnoDB)外键支持基本的全文搜索性能和扩展性: 相对较…

HTML+CSS+JavaScript入门学习

目录 1. 前言2. HTML2.1 HTML简介2.2 HTML标签 3. CSS3.1 CSS知识整理及总结3.2 CSS之flex布局 4. JavaScript4.1 JavaScript知识整理及总结1-基础篇4.2 JavaScript知识整理及总结2-进阶篇 1. 前言 本文主要采用转载的形式,偶尔发现了一个比较不错的博客站点&#…

华为ENSP防火墙+路由器+交换机的常规配置

(防火墙区域DHCP基于接口DHCP中继服务器区域有线区域无线区域)配置 一、适用场景: 1、普通企业级网络无冗余网络环境,防火墙作为边界安全设备,分trust(内部网络信任区域)、untrust(外部网络非信…

计算机网络-IP组播基础

一、概述 在前面的学习交换机和路由协议,二层通信是数据链路层间通信,在同一个广播域间通过源MAC地址和目的MAC地址进行通信,当两台主机第一次通信由于不清楚目的MAC地址需要进行广播泛洪,目的主机回复自身MAC地址,然后…

JSP WEB开发(一) JSP语言基础

目录 JSP JSP简介: JSP页面 JSP运行原理 JSP脚本元素 JAVA程序片 局部变量 全局变量和方法的声明 全局变量 方法的声明 程序片执行特点 synchronized关键字 表达式 JSP指令标记 page指令 include指令 JSP动作标记 JSP动作元素include和include指令的…

【C++】B树及其实现

写目录 一、B树的基本概念1.引入2.B树的概念 二、B树的实现1.B树的定义2.B树的查找3.B树的插入操作4.B树的删除5.B树的遍历6.B树的高度7.整体代码 三、B树和B*树1.B树2.B*树3.总结 一、B树的基本概念 1.引入 我们已经学习过二叉排序树、AVL树和红黑树三种树形查找结构&#x…