YOLOv8改进 | 注意力机制 | 在主干网络中添加MHSA模块【原理+附完整代码】

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

多头自注意力机制(Multi-Head Self-Attention)是Transformer模型中的一个核心概念,它允许模型在处理序列数据时同时关注不同的位置和表示子空间。这种机制是“自注意力”(Self-Attention)的一种扩展,自注意力又称为内部注意力(Intra-Attention),是一种注意力机制,用于对序列中的每个位置进行加权,以便在编码每个位置时能够考虑到序列中的其他位置。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv8改进——更新各种有效涨点方法——点击即可跳转

目录

1. 原理

2. 多头自注意力机制代码实现

2.1 将MHSA添加到YOLOv8代码中

2.2 更改init.py文件

2.3 添加yaml文件

2.4 在task.py中进行注册

2.5 执行程序

3. 完整代码分享

4. GFLOPs

5. 总结


1. 原理

多头自注意力机制(Multi-Head Self-Attention, MHSA)是深度学习中的一种机制,主要用于提升模型捕捉复杂关系和不同尺度特征的能力。它是自注意力机制的扩展和增强版本,广泛应用于Transformer模型中,如BERT和GPT等。以下是多头自注意力机制的主要原理:

自注意力机制

首先,了解自注意力机制(Self-Attention Mechanism)的基础原理非常重要。在自注意力机制中,输入序列的每个元素(通常是词或词向量)都会根据其与其他元素的相关性进行加权。具体步骤如下:

  1. 输入表示:假设输入序列为x_1, x_2, \ldots, x_n,每个元素是一个d维的向量。

  2. 线性变换:将输入向量通过三个不同的线性变换,得到查询(Query)、键(Key)和值(Value)矩阵。 Q = XW_Q, \quad K = XW_K, \quad V = XW_V其中,W_QW_KW_V是可训练的权重矩阵。

  3. 计算注意力得分:通过点积来计算查询和键之间的相似性(即注意力得分),并应用缩放处理。 \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V其中,(d_k)是键向量的维度,\text{softmax}用于将得分转化为概率分布。

  4. 加权求和:用注意力得分对值向量进行加权求和,得到最终的输出。

多头自注意力机制

多头自注意力机制通过引入多个注意力头,扩展了自注意力机制的表示能力。具体步骤如下:

  1. 多头注意力:将输入向量通过不同的线性变换,得到多个不同的查询、键和值。 \text{head}_i = \text{Attention}(Q_i, K_i, V_i)其中,每个头都有独立的权重矩阵(W{Q_i})(W{K_i})(W_{V_i})

  2. 并行计算:对每个头并行计算注意力,并得到多个不同的输出。 \text{head}_i = \text{Attention}(XW_{Q_i}, XW_{K_i}, XW_{V_i})

  3. 拼接和线性变换:将所有头的输出拼接在一起,并通过一个线性变换得到最终输出。 \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)W_O 其中,是用于将拼接后的向量映射回原始维度的权重矩阵。

主要优点

  • 捕捉多种特征:多头机制允许模型在不同的子空间中捕捉输入的多种特征和关系。

  • 增强表示能力:通过多头注意力,模型可以同时关注输入序列的不同部分,提高表示的多样性和丰富性。

  • 稳定训练:多头机制还可以缓解单头注意力可能出现的不稳定性问题。

总之,多头自注意力机制通过并行计算多个注意力头,有效增强了模型的表示能力,使其能够更好地捕捉序列数据中的复杂模式和关系。这一机制在自然语言处理和其他序列数据任务中表现出色,是Transformer模型成功的关键组件之一。

2. 多头自注意力机制代码实现

2.1 将MHSA添加到YOLOv8代码中

关键步骤一: 将下面代码粘贴到在/ultralytics/ultralytics/nn/modules/conv.py中,并在该文件的__all__中添加“MHSA”

class MHSA(nn.Module):
    def __init__(self, n_dims, width=14, height=14, heads=4, pos_emb=False):
        super(MHSA, self).__init__()

        self.heads = heads
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.pos = pos_emb
        if self.pos:
            self.rel_h_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, 1, int(height)]),
                                             requires_grad=True)
            self.rel_w_weight = nn.Parameter(torch.randn([1, heads, (n_dims) // heads, int(width), 1]),
                                             requires_grad=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)  # 1,C,h*w,h*w
        c1, c2, c3, c4 = content_content.size()
        if self.pos:
            content_position = (self.rel_h_weight + self.rel_w_weight).view(1, self.heads, C // self.heads, -1).permute(
                0, 1, 3, 2)  # 1,4,1024,64

            content_position = torch.matmul(content_position, q)  # ([1, 4, 1024, 256])
            content_position = content_position if (
                    content_content.shape == content_position.shape) else content_position[:, :, :c3, ]
            assert (content_content.shape == content_position.shape)
            energy = content_content + content_position
        else:
            energy = content_content
        attention = self.softmax(energy)
        out = torch.matmul(v, attention.permute(0, 1, 3, 2))  # 1,4,256,64
        out = out.view(n_batch, C, width, height)
        return out

多头自注意力机制在处理图片时,主要应用于视觉Transformer(Vision Transformer, ViT)等模型。这些模型将图片处理流程分为几个关键步骤,下面详细解释这些步骤:

1. 图像分块

输入图片:假设输入图像的尺寸为H \times W \times C,其中H是高度,W是宽度,C是通道数(例如RGB图像的C=3)。

分块:将输入图像划分为多个固定大小的小块(patch)。假设每个小块的尺寸为P \times P \times C,则图片会被划分成N = \frac{H \times W}{P \times P}个小块。

2. 小块向量化

线性变换:将每个小块展平(flatten)为一个向量,然后通过线性变换将其映射到固定的维度,得到小块的嵌入表示。假设映射后的维度为D,则每个小块被表示为一个D维向量。

3. 加入位置编码

位置编码:由于Transformer不具备内建的位置感知能力,需加入位置编码(positional encoding)以保留小块在图像中的位置信息。位置编码与小块嵌入相加,形成带有位置信息的输入序列。

4. 多头自注意力机制

输入准备:将带有位置编码的小块嵌入向量作为Transformer的输入序列。

多头自注意力

  • 查询、键和值:将每个小块嵌入向量通过不同的线性变换,得到查询(Query)、键(Key)和值(Value)。

  • 并行计算多个注意力头:每个注意力头独立地计算注意力得分,然后通过加权求和值向量,得到每个头的输出。

  • 拼接和线性变换:将所有头的输出拼接在一起,并通过一个线性变换得到最终输出。

5. 叠加Transformer层

多层堆叠:重复上述的多头自注意力和前馈神经网络结构,通常会堆叠多个这样的Transformer层。每一层都会进一步处理和组合小块的表示,提取更高级别的特征。

6. 分类头

CLS标记:在输入序列中添加一个特殊的分类标记(CLS token),它的嵌入表示会被用于最终的分类任务。

全连接层:通过一个全连接层(或多层感知机)将CLS标记的最终表示映射到目标类别数。

优点

  • 全局信息捕捉:多头自注意力机制可以捕捉图像中远距离像素之间的关系。

  • 灵活性强:不依赖于卷积操作,可以更灵活地处理不同分辨率的图像。

  • 表示能力强:多头机制允许模型在不同的子空间中捕捉多种特征。

单头的注意力机制

 

单头注意力机制

 

2.2 更改init.py文件

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数

然后在下面的__all__中声明函数

2.3 添加yaml文件

关键步骤三:在/ultralytics/ultralytics/cfg/models/v8下面新建文件yolov8_MHSA.yaml文件,粘贴下面的内容

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
  - [-1, 1, MHSA, [14,14,4]]  #10

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

温馨提示:本文只是对yolov8基础上添加模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。


# YOLOv8n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
 
# YOLOv8s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
 
# YOLOv8l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
 
# YOLOv8m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
 
# YOLOv8x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple

2.4 在task.py中进行注册

关键步骤四:在parse_model函数中进行注册,添加MHSA,

        elif m in {MHSA}:
            args=[ch[f],*args]

2.5 执行程序

关键步骤五: 在ultralytics文件中新建train.py,将model的参数路径设置为yolov8_MHSA.yaml的路径即可

建议大家写绝对路径,确保一定能找到

from ultralytics import YOLO
 
# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
 
model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_MHSA.yaml')  # build from YAML and transfer weights
 
# Train the model
model.train(device = [3], batch=16)

 🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

img

改进后的GFLOPs

5. 总结

自注意力机制是一种在处理序列数据时,通过计算序列中每个元素与其他元素之间的相关性来加权组合输入元素的方法。具体来说,自注意力机制首先对输入序列进行线性变换,生成查询(Query)、键(Key)和值(Value)向量。然后,通过计算查询和键的点积并进行缩放处理,再应用softmax函数,将相关性转换为概率分布,得到注意力得分。接着,用这些得分对值向量进行加权求和,得到每个元素的自注意力表示。这种机制允许模型在全局范围内捕捉输入序列中各个位置之间的依赖关系和特征。多头自注意力机制通过并行计算多个独立的自注意力,并将这些注意力头的输出拼接起来,再通过线性变换,进一步增强了模型捕捉多种特征和复杂关系的能力。这种机制在自然语言处理、计算机视觉等领域表现出色,如在Transformer模型中广泛应用,使得模型能够处理长距离依赖和捕捉全局上下文信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/702221.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024-6-12-IXI(npy存储)应用SR的数据处理代码解读

数据集:https://drive.google.com/drive/folders/1i2nj-xnv0zBRC-jOtu079Owav12WIpDEhttps://drive.google.com/drive/folders/1i2nj-xnv0zBRC-jOtu079Owav12WIpDE import numpy as np from torch.utils.data import DataLoader, Dataset import torch from skimage.measur…

nodejs湖北省智慧乡村旅游平台-计算机毕业设计源码 00232

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,旅游行业当然也不能排除在外。智慧乡村旅游平台是以实际运用为开发背景,运用软件工程开发方法,采…

电子招投标系统:企业战略布局下的采购寻源利器

在当今商业环境中,企业对于采购管理的效率和透明度要求日益提高。鸿鹄电子招投标系统,一款基于Java技术的电子招标采购软件,旨在为企业提供一个公平、公开、公正的采购平台,同时降低成本,提升采购质量和速度。 项目说…

MySQL系列-语法说明以及基本操作(二)

1、MySQL数据表的约束 1.1、MySQL主键 “主键(PRIMARY KEY)”的完整称呼是“主键约束”。 MySQL 主键约束是一个列或者列的组合,其值能唯一地标识表中的每一行。这样的一列或多列称为表的主键,通过它可以强制表的实体完整性。 …

VMware Workstation虚拟机固定IP配置(主机互通、外网可访问)

VMware Workstation虚拟机固定IP配置 环境问题配置过程配置虚拟机网络适配器配置虚拟机网络配置虚拟网卡网络适配器配置虚拟机固定IP 结果验证结束语参考 环境 主机:Windows 11 VMware Workstation: 17.5.2 虚拟机:Ubuntu 24.02 LTS 注: 主…

抖音和快手都做电商,到底有啥区别?

抖音的电商叫兴趣电商,而快手的电商叫信任电商。 有什么区别呢?电商的交易逻辑完全不一样。 兴趣电商是小姐妹聚会逛街的。逛街是什么?是对时尚流行内容的消费,出门之前不一定有购买需求,但逛着逛着,看到商…

tim定时器 输入捕获模式下 TIM–ICStructinit(TIM–ICStructinit) 这个值 解析

主要需要看着图来理解 1.这是stm中文手册的图 2.这是解析 我觉得写的不错 注:有个很坑的地方 我觉得是stm32中文手册的问题 他写的解释只写了tim输入2 3 4和ic1 2 3 4,少写了一个输入1 第一次看见很不好理解

SortTable.js + vxe-table 实现多条批量排序

环境: vue3+vxe-table+sorttable.js 功能: 实现表格拖动排序,支持单条排序,多条排序 实现思路: sorttable.js官网只有单条排序的例子,网上也都是简单的使用,想要实现多条排序,就要结合着表格的复选框功能,在对其勾选的行统一计算! 最终效果: 实现代码 <template>…

技术干货分享:初识分布式版本控制系统Git

初识Git版本控制 自动化测试代码反复执行&#xff0c;如果借用持续集成工具会提高测试效率&#xff0c;那么需要我们把自动化测试代码发布到正式环境中&#xff0c;这时候用Git版本控制工具高效、稳定、便捷。 分布式版本控制 Git可以把代码仓库完整地镜像下来&#xff0c;有…

心理咨询系统源码|心理咨询系统开发|心理咨询系统

心理咨询系统&#xff0c;作为一种集现代化科技与专业心理服务于一体的工具&#xff0c;正逐渐渗透到我们生活的各个角落。它不仅为个人提供了便捷的心理支持&#xff0c;还为企业和组织带来了全新的管理方式。下面&#xff0c;我们将深入探讨心理咨询系统的可应用范围及其带来…

Visual Studio扩展开发

对于Roslyn(编译平台)的扩展 内容来源:https://learn.microsoft.com/zh-cn/dotnet/csharp/roslyn-sdk/tutorials/how-to-write-csharp-analyzer-code-fix 创建项目 解决方案项目介绍 Resources.resx介绍 填入的内容会在错误列表中显示

Github入门教程,适合新手学习(非常详细)

前言&#xff1a;本篇博客为手把手教学的 Github 代码管理教程&#xff0c;属于新手入门级别的难度。教程简单易操作&#xff0c;能够基本满足读者朋友日常项目寄托于 Github 平台上进行代码管理的需求。Git 与 Github 是一名合格程序员 coder 必定会接触到的工具与平台&#x…

JAVA代码审计之SQL注入代码审计

前言 SQL注入漏洞是对数据库进行的一种攻击方式。其主要形成方式是在数据交互中&#xff0c;前端数据通过后台在对数据库进行操作时&#xff0c;由于没有做好安全防护&#xff0c;导致攻击者将恶意代码拼接到请求参数中&#xff0c;被当做SQL语句的一部分进行执行&#xff0c;…

Stack详解(含动画演示)

目录 Stack详解1、栈数据结构动画演示2、Stack的继承体系3、Stack的push (入栈)方法4、Stack的pop (出栈)方法5、Stack是如何利用Vector实现栈数据结构的&#xff1f;6、自己实现栈(不借助JDK提供的集合)7、自己实现栈(借助JDK提供的集合)利用 ArrayDeque 实现高性能的非线程安…

SyntaxError: Non-UTF-8 code starting with ‘\xbd‘ in file错误解决

在运用python的pandas和numpy的内容环境下&#xff0c;运行代码时发生以下错误&#xff1a; C:\ProgramData\Anaconda3\python.exe D:/zafile/py数据分析与应用/数据分析代码/14.2、紧急电话数据分析.pyFile "D:/zafile/py数据分析与应用/数据分析代码/14.2、紧急电话数据…

7.枚举和模式匹配

一、enum枚举 1.1 定义枚举类型和对应的数据 //定义枚举 #[derive(Debug)] enum IpAddrKind{IPv4,IPv6, }struct Ipaddr{kind: IpAddrKind, //设置kind为IpAddrKind的枚举类型address: String, }fn route(ip_addr: &Ipaddr){println!("ip_type {:#?}", ip_a…

如何基于Nginx配置代理服务器实现邮件告警

当代企业信息化系统建设中&#xff0c;将内网与公网进行隔离是一种非常常见的措施&#xff0c;它可以有效保护企业内部数据不被外泄&#xff0c;有助于企业构建一个更加安全的网络环境&#xff0c;保护企业资产和用户隐私。但另一方面&#xff0c;内网与公网隔离也会带来一些问…

【STM32】输入捕获应用-测量脉宽或者频率(方法2)

链接&#xff1a;https://blog.csdn.net/gy3509/article/details/139629893?spm1001.2014.3001.5502&#xff0c;讲述了只使用一个捕获寄存器测量脉宽和频率的方法&#xff0c;其实测量脉宽和频率还有一个更简单的方法就是使用PWM输入模式&#xff0c;PWM输入模式需要占用两个…

推挽式B类功率放大器的基本原理

单晶体管 B 类放大器&#xff08;图 1&#xff09;使用高 Q 值储能电路作为负载来抑制高次谐波分量。通过采用高 Q 谐振电路&#xff0c;输出电压仅包含基波分量&#xff0c;使放大器能够忠实地再现输入信号。尽管集电极电流是半波整流正弦波&#xff0c;但高 Q 值储能电路将谐…

Python基础教程(十):装饰器

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; &#x1f49d;&#x1f49…