YOLOv8改进 | 2023 | FocusedLinearAttention实现有效涨点

论文地址:官方论文地址

代码地址:官方代码地址

一、本文介绍

本文给大家带来的改进机制是Focused Linear Attention(聚焦线性注意力)是一种用于视觉Transformer模型的注意力机制(但是其也可以用在我们的YOLO系列当中从而提高检测精度),旨在提高效率和表现力。其解决了两个在传统线性注意力方法中存在的问题:聚焦能力和特征多样性。这种方法通过一个高效的映射函数和秩恢复模块来提高计算效率和性能,使其在处理视觉任务时更加高效和有效。简言之,Focused Linear Attention是对传统线性注意力方法的一种重要改进,提高了模型的聚焦能力和特征表达的多样性。通过本文你能够了解到:Focused Linear Attention的基本原理和框架,能够在你自己的网络结构中进行添加(需要注意的是一个FLAGFLOPs从8.9涨到了9.1)。

 专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

实验效果对比:放在了第三章,有对比试验供大家参考

目录

一、本文介绍

二、Focused Linear Attention的机制原理

2.1 Softmax和线性注意力机制的对比

2.2 Focused Linear Attention的提出

2.3 效果对比

三、实验效果对比

四、FocusedLinearAttention代码

五、添加Focused Linear Attention到模型中

5.1 Focused Linear Attention的添加教程

5.2 Focused Linear Attention的yaml文件和训练截图

5.2.1 Focused Linear Attention的yaml文件

5.2.2 Focused Linear Attention的训练过程截图 

六、全文总结 


二、Focused Linear Attention的机制原理

2.1 Softmax和线性注意力机制的对比

上面的图片是关于比较Softmax注意力和线性注意力的差异。在这张图中,Q、K、V 分别代表查询、键和值矩阵,它们的维度为 R N×d。这里提到的几个关键点包括:

1. Softmax注意力:它需要计算查询和键之间的成对相似度,导致计算复杂度为 O(N^2 d)。这种方法在计算上是昂贵的,特别是当处理大规模数据时。

2. 线性注意力:通过适当的近似手段,线性注意力可以解耦Softmax操作,并通过先计算K^{T}V来改变计算顺序,从而将复杂度降低到 O(Nd^{^{2}})。由于在现代视觉Transformer设计中通道维度 d 通常小于标记数 N(例如,在DeiT中d=64, N=196,在Swin Transformer中d=32, N=49),线性注意力模块实际上降低了总体计算成本。

此处提出了线性注意力机制的优势(为了后面提出论文提到的注意力机制在线性注意力机制上的优化):线性注意力模块因此能够在节省计算成本的同时,享受更大的接收域和更高的吞吐量的好处。

总结:这张图片可能是在说明线性注意力如何在保持注意力机制核心功能的同时,提高计算效率,尤其是在处理大规模数据集时的优势。这种方法对于改善视觉Transformer的性能和效率具有重要意义(我下面会出将其用在RT-DETR的模型上看看效果)

2.2 Focused Linear Attention的提出

线性注意力的限制和改进: 尽管线性注意力降低了复杂度,但现有的线性注意力方法仍存在性能下降的问题,并可能因映射函数带来额外的计算开销。为了解决这些问题,作者提出了一个新颖的聚焦线性注意力(Focused Linear Attention)模块。该模块通过简单的映射函数调整查询和键的特征方向,使注意力权重更加明显。此外,还通过深度卷积(DWC)应用于原始注意力矩阵的秩恢复模块来增加特征多样性。

Focused Linear Attention(聚焦线性注意力)是一种用于视觉Transformer模型的注意力机制(但是其也可以用在我们的YOLO系列当中从而提高检测精度),旨在提高效率和表现力。它解决了传统线性注意力方法的两个主要问题:

1. 聚焦能力: 以往的线性注意力缺乏足够的聚焦能力,导致模型难以有效地关注重要特征。Focused Linear Attention通过改进的机制增强了这种聚焦能力。

2. 特征多样性: 传统方法在特征表达上缺乏多样性,影响了模型的表现力。Focused Linear Attention通过特殊的设计来增加特征的多样性和丰富性。

这种方法通过一个高效的映射函数和秩恢复模块来提高计算效率和性能,使其在处理视觉任务时更加高效和有效。

总结:Focused Linear Attention是对传统线性注意力方法的一种重要改进,提高了模型的聚焦能力和特征表达的多样性。

2.3 效果对比

上面的图片显示了多个视觉Transformer模型的性能和计算复杂度的比较。图中分为四个部分:

1. PVT: 对比了不同版本的PVT(Pyramid Vision Transformer),DeiT(Data-efficient Image Transformer),以及T2T(Tokens-to-Token ViT)的Top-1准确率和计算量(FLOPs)。

2. PVT v2: 类似地,展示了PVT v2、ConvNext、DAT(Deformable Attention Transformer)的性能对比。

3. Swin: 对比了Swin Transformer、CvT(Convolutional vision Transformer),以及CoTNet(Contextual Transformer Network)的模型。

4. CSwin: 展示了CSwin Transformer、MViTv2、CoAtNet的性能对比。

在每个图中,还包括了作者提出的FLatten版本的Transformer模型(标记为“Ours”),其在每个分类中都显示了相对较高的准确率或者在相似的FLOPs计算量下具有竞争力的准确率。

右侧的表格详细列出了不同模型的分辨率(Reso)、参数数量(#Params)、计算量(Flops)和Top-1准确率。表中突出了FLatten版本的Transformer模型在Top-1准确率上相对于原始模型的提升(括号中的百分点)。

个人总结:这张图片展示了通过改进的线性注意力模块,即FLatten模型,在保持或稍微增加计算量的前提下,提高了Transformer架构的图像识别准确率。

三、实验效果对比

实验效果图如下所示-> 

因为资源有限我发的文章都要做对比实验所以本次实验我只用了一百张图片检测的是安全帽训练了一百个epoch,该结果只能展示出该机制有效,但是并不能产生决定性结果,因为具体的效果还要看你的数据集和实验环境所影响。

 

四、FocusedLinearAttention代码

在场的FocusedLinearAttention代码是用于Transformer的想要将其用于YOLO上是需要进行很大改动的,所以我这里进行了挺多的改动的,创作不易而且免费给大家看,所以如果能够帮助到大家希望大家能给点个赞和关注支持一下。

import torch.nn as nn
import torch
from einops import rearrange


class FocusedLinearAttention(nn.Module):
    def __init__(self, dim, num_patches=64, num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0, sr_ratio=1,
                 focusing_factor=3.0, kernel_size=5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.focusing_factor = focusing_factor
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))
        # self.positional_encoding = nn.Parameter(torch.zeros(size=(1, num_patches // (sr_ratio * sr_ratio), dim)))


    def forward(self, x):
        B, C, H, W = x.shape  # 输入为四维:[批次大小, 通道数, 高度, 宽度]
        dtype, device = x.dtype, x.device
        # 调整输入以匹配原始模块的预期格式
        x = rearrange(x, 'b c h w -> b (h w) c')
        q = self.q(x)
        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
        else:
            kv = self.kv(x).reshape(B, -1, 2, C).permute(2, 0, 1, 3)
        k, v = kv[0], kv[1]
        N = H * W  # 序列长度
        # 重新生成位置编码
        positional_encoding = nn.Parameter(torch.zeros(size=(1, N, self.dim), device=device))
        k = k + positional_encoding
        focusing_factor = self.focusing_factor
        kernel_function = nn.ReLU()
        scale = nn.Softplus()(self.scale)
        q = kernel_function(q) + 1e-6
        k = kernel_function(k) + 1e-6
        q = q / scale
        k = k / scale
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        q = q ** focusing_factor
        k = k ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm
        bool = False
        if dtype == torch.float16:
            q = q.float()
            k = k.float()
            v = v.float()
            bool = True
        q, k, v = (rearrange(x, "b n (h c) -> (b h) n c", h=self.num_heads) for x in [q, k, v])
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]
        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            x = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            x = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)
        if self.sr_ratio > 1:
            v = nn.functional.interpolate(v.permute(0, 2, 1), size=x.shape[1], mode='linear').permute(0, 2, 1)
        if bool:
            v = v.to(torch.float16)
            x = x.to(torch.float16)

        num = int(v.shape[1] ** 0.5)
        feature_map = rearrange(v, "b (w h) c -> b c w h", w=num, h=num)
        feature_map = rearrange(self.dwc(feature_map), "b c w h -> b (w h) c")
        x = x + feature_map
        x = rearrange(x, "(b h) n c -> b n (h c)", h=self.num_heads)

        x = self.proj(x)
        x = self.proj_drop(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)
        return x

五、添加Focused Linear Attention到模型中

5.1 Focused Linear Attention的添加教程

添加教程这里不再重复介绍、因为专栏内容有许多,添加过程又需要截特别图片会导致文章大家读者也不通顺如果你已经会添加注意力机制了,可以跳过本章节,如果你还不会,大家可以看我下面的文章,里面详细的介绍了拿到一个任意机制(C2f、Conv、Bottleneck、Loss、DetectHead)如何添加到你的网络结构中去。

注意:本文的注意力机制是有参数的!!!

这个注意力机制也可以放在C2f和Bottleneck中进行使用可以即插即用,个人觉得放在Bottleneck中效果比较好。

添加教程->YOLOv8改进 | 如何在网络结构中添加注意力机制、C2f、卷积、Neck、检测头

需要注意的是本文的task.py配置的代码如下(你现在不知道其是干什么用的可以看添加教程)-> 

from .modules.FocusLinearAttention import FocusedLinearAttention as FLAttention
        elif m is FLAttention:
            args = [ch[f], *args]

5.2 Focused Linear Attention的yaml文件和训练截图

5.2.1 Focused Linear Attention的yaml文件

下面的是放在Neck部分的截图,参数我以及设定好了,无需进行传入会根据模型输入自动计算,帮助大家省了一些事。

下面的是放在C2f中的yaml配置。 

 

5.2.2 Focused Linear Attention的训练过程截图 

下面是我添加了Focused Linear Attention的训练截图。

下面的是将FLAttention机制我添加到了C2f和Bottleneck。

下面的是我将FLAttention放在Neck中的截图。 

六、全文总结 

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,目前本专栏免费阅读(暂时,大家尽早关注不迷路~),如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏回顾:YOLOv8改进系列专栏——本专栏持续复习各种顶会内容——科研必备

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/187405.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

小程序项目:springboot+vue基本微信小程序的学生健康管理系统

项目介绍 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生,各行各业相继进入信息管理时…

基于协作搜索算法优化概率神经网络PNN的分类预测 - 附代码

基于协作搜索算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于协作搜索算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于协作搜索优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神…

“升级图片质量:批量提高或缩小像素,赋予图片全新生命力!“

如果你想让你的图片更加清晰、更加美观,或者符合特定的像素要求,那么现在有一个好消息要告诉你!我们推出了一款全新的图片处理工具,可以帮助你批量提高或缩小图片像素,让你的图片焕发出新的生机! 第一步&a…

基于人工蜂鸟算法优化概率神经网络PNN的分类预测 - 附代码

基于人工蜂鸟算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于人工蜂鸟算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于人工蜂鸟优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要:针对PNN神…

我的崩溃。。想鼠??!

身为程序员哪一个瞬间让你最奔溃? 某天一个下午崩溃产生。。。 一个让我最奔溃的瞬间是关于一个看似无害的拼写错误。我当时正在为一个电子商务网站添加支付功能,使用了一个第三方支付库。所有的配置看起来都正确,代码也没有报错,…

zookeeper 单机伪集群搭建简单记录

1、官方下载加压后,根目录下新建data和log目录,然后分别拷贝两份,分别放到D盘,E盘,F盘 2、data目录下面新建myid文件,文件内容分别为1,2,3.注意文件没有后缀,不能是txt文…

数据结构—小堆的实现

前言:前面我们已经学习了二叉树,今天我们来学习堆,堆也是一个二叉树,堆有大堆有小堆,大堆父节点大于子节点,小堆父节点总小于子节点,我们在学习C语言的时候也有一个堆的概念,那个堆是…

栈和队列OJ题目——C语言

目录 LeetCode 20、有效的括号 题目描述: 思路解析: 解题代码: 通过代码: LeetCode 225、用队列实现栈 题目描述: 思路解析: 解题代码: 通过代码: LeetCode 232、用栈…

C/C++ 运用Npcap发送UDP数据包

Npcap 是一个功能强大的开源网络抓包库,它是 WinPcap 的一个分支,并提供了一些增强和改进。特别适用于在 Windows 环境下进行网络流量捕获和分析。除了支持通常的网络抓包功能外,Npcap 还提供了对数据包的拼合与构造,使其成为实现…

HarmonyOS简述及开发环境搭建

一、HarmonyOS简介 1、介绍 HarmonyOS是一款面向万物互联时代的、全新的分布式操作系统。有三大系统特性,分别是:硬件互助,资源共享;一次开发,多端部署;统一OS,弹性部署。 HarmonyOS通过硬件互…

洛谷P1049装箱问题 ————递归+剪枝+回溯

没没没没没没没没没错,又是一道简单的递归,只不过加了剪枝,我已经不想再多说,这道题写了一开始写了普通深搜,然后tle了一个点,后面改成剪枝,就ac了,虽然数据很水,但是不妨…

第96步 深度学习图像目标检测:FCOS建模

基于WIN10的64位系统演示 一、写在前面 本期开始,我们继续学习深度学习图像目标检测系列,FCOS(Fully Convolutional One-Stage Object Detection)模型。 二、FCOS简介 FCOS(Fully Convolutional One-Stage Object D…

iOS强引用引起的内存泄漏

项目中遇到一个问题: 1.在A页面的ViewDidLoad 方法里写了一个接收通知的方法,如下图: 然后在B页面发送通知 (注:下图的NOTI 是 [NSNotificationCenter defaultCenter] 的宏, 考虑一下可能有小白看这篇文章…

物联网中基于信任的安全性调查研究:挑战与问题

A survey study on trust-based security in Internet of Things: Challenges and issues 文章目录 a b s t r a c t1. Introduction2. Related work3. IoT security from the one-stop dimension3.1. Output data related security3.1.1. Confidentiality3.1.2. Authenticity …

【vue_2】创建一个弹出权限不足的提示框

定义了一个名为 getUserRole 的 JavaScript 函数,该函数接受一个参数 authorityId,根据这个参数的不同值返回相应的用户角色字符串。这段代码的目的是根据传入的 authorityId 值判断用户的角色,然后返回相应的角色名称。 如果 authorityId 的…

Visual Studio 中文注释乱码解决方案

在公司多人开发项目中经常遇到拉到最新代码,发现中文注释都是乱码,很是emjoy..... 这是由于编码格式不匹配造成的,如果你的注释是 UTF-8 编码,而文件编码是 GBK 或者其他编码,那么就会出现乱码现象。一般的解决办法是…

STM32-使用固件库新建工程

参考链接: 【入门篇】11-新建工程—固件库版本(初学者必须认认真真看)_哔哩哔哩_bilibili 使用的MCU是STM32F103ZET6 。 这篇参考的是野火的资料,可以在“野火大学堂”或者它的论坛上下载。(我通常是野火和正点原子的资料混着看的…

【AI认证笔记】NO.2人工智能的发展

目录 一、人工智能的发展里程碑 二、当前人工智能的发展特点 1.人工智能进入高速发展阶段 2.人工智能元年 三、人工智能高速发展的三大引擎 1.算法突破 2.算力飞跃 3.数据井喷 四、AI的机遇 五、AI人才的缺口 六、行业AI 人工智能算法,万物互联&#xff…

基于mediapipe的人手21点姿态检测模型—CPU上检测速度惊人

前期的文章,我们介绍了MediaPipe对象检测与对象分类任务,也分享了MediaPipe的人手手势识别。在进行人手手势识别前,MediaPipe首先需要进行人手的检测与人手坐标点的检测,经过以上的检测后,才能把人手的坐标点与手势结合起来,进行相关的手势识别。 MediaPipe人手坐标点检测…

案例022:基于微信小程序的行政复议在线预约系统

文末获取源码 开发语言:Java 框架:SSM JDK版本:JDK1.8 数据库:mysql 5.7 开发软件:eclipse/myeclipse/idea Maven包:Maven3.5.4 小程序框架:uniapp 小程序开发软件:HBuilder X 小程序…