YOLOv5改进 | 添加注意力篇 | 利用YOLOv10提出的PSA注意力机制助力YOLOv5有效涨点(附代码 + 详细修改教程)

 一、本文介绍

本文给大家带来的改进机制是YOLOv10提出的PSA注意力机制,自注意力在各种视觉任务中得到了广泛应用,因为它具有显著的全局建模能力。然而,自注意力机制表现出较高的计算复杂度和内存占用。为了解决这个问题,鉴于注意力头冗余的普遍存在,我们提出了一种高效的部分自注意力(PSA)模块设计,其能够在不显著增加计算成本的情况下提升YOLO模型的性能!本文附其网络结构图辅助大家理解该结构,同时本文包含YOLOv8添加该注意力机制的方法!

欢迎大家订阅我的专栏一起学习YOLO!

专栏目录:YOLOv5改进有效涨点目录 | 包含卷积、主干、检测头、注意力机制、Neck上百种创新机制


目录

 一、本文介绍

二、PSA介绍

三、PSA核心代码 

四、手把手教你添加PSA注意力机制

4.1 PSA添加步骤

4.1.1 修改一

4.1.2 修改二

4.1.3 修改三 

4.1.4 修改四

4.2 PSA的yaml文件

4.3 训练截图 

五、本文总结


二、PSA介绍

论文地址:官方论文地址点击此处即可跳转

代码地址:官方代码地址点击此处即可跳转


自注意力在各种视觉任务中得到了广泛应用,因为它具有显著的全局建模能力。然而,自注意力机制表现出较高的计算复杂度和内存占用。为了解决这个问题,鉴于注意力头冗余的普遍存在,我们提出了一种高效的部分自注意力(PSA)模块设计,如图3(c)所示。

具体来说,我们通过1×1卷积将特征均匀地划分为两部分。然后,我们仅将其中一部分输入到由多头自注意力模块(MHSA)和前馈网络(FFN)组成的NPSA块中。两部分特征随后被连接并通过1×1卷积融合。此外,我们遵循将MHSA中查询和键的维度分配为值的一半,并将LayerNorm替换为BatchNorm以加快推理速度。

PSA仅在分辨率最低的Stage 4之后放置,以避免自注意力二次复杂性带来的过多开销。通过这种方式,可以在低计算成本下将全局表示学习能力引入YOLO模型,从而增强模型能力并提高性能。

通过这些精度驱动的设计,我们能够在不显著增加计算成本的情况下提升YOLO模型的性能。


三、PSA核心代码 

核心代码的使用方式看章节四!

import torch
import torch.nn as nn

__all__ = ['PSA']

def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8,
                 attn_ratio=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.key_dim = int(self.head_dim * attn_ratio)
        self.scale = self.key_dim ** -0.5
        nh_kd = nh_kd = self.key_dim * num_heads
        h = dim + nh_kd * 2
        self.qkv = Conv(dim, h, 1, act=False)
        self.proj = Conv(dim, dim, 1, act=False)
        self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, self.num_heads, -1, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)

        attn = (
                (q.transpose(-2, -1) @ k) * self.scale
        )
        attn = attn.softmax(dim=-1)
        x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + self.pe(v.reshape(B, -1, H, W))
        x = self.proj(x)
        return x


class PSA(nn.Module):

    def __init__(self, c1, c2, e=0.5):
        super().__init__()
        assert (c1 == c2)
        self.c = int(c1 * e)
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c1, 1)

        self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
        self.ffn = nn.Sequential(
            Conv(self.c, self.c * 2, 1),
            Conv(self.c * 2, self.c, 1, act=False)
        )

    def forward(self, x):
        a, b = self.cv1(x).split((self.c, self.c), dim=1)
        b = b + self.attn(b)
        b = b + self.ffn(b)
        return self.cv2(torch.cat((a, b), 1))


四、手把手教你添加PSA注意力机制

4.1 PSA添加步骤

4.1.1 修改一

首先我们找到如下的目录'yolov5-master/models',然后在这个目录下在创建一个新的目录然后这个就是存储改进的仓库,大家可以在这里新建所有的改进的py文件,对应改进的文件名字可以根据你自己的习惯起(不影响任何但是下面导入的时候记住改成你对应的即可),然后将PSA的核心代码复制进去。


4.1.2 修改二

然后在新建的目录里面我们在新建一个__init__.py文件(此文件大家只需要建立一个即可),然后我们在里面添加导入我们模块的代码。注意标记一个'.'其作用是标记当前目录。

​​


4.1.3 修改三 

然后我们找到如下文件''models/yolo.py''在开头的地方导入我们的模块按照如下修改->

(如果你看了我多个改进机制此处只需要添加一个即可,无需重复添加)

注意的添加位置要放在common的导入上面!!!!!

​​​​​


4.1.4 修改四

然后我们找到parse_model方法,按照如下修改->

到此就修改完成了,复制下面的ymal文件即可运行。


4.2 PSA的yaml文件

保持和YOLOv10相同的使用方法!

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
    [-1, 1, PSA, [1024]],  # 10
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 14

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 18 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 15], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 21 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 11], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 24 (P5/32-large)

    [[18, 21, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]

4.3 训练截图 


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv5改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

专栏目录:YOLOv5改进有效涨点目录 | 包含卷积、主干、检测头、注意力机制、Neck上百种创新机制

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/670255.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

本地公网IP是什么?

本地公网IP(Public IP Address)是指分配给计算机或设备的可在互联网上直接访问的唯一标识符。每个连接到互联网的设备都会被分配一个公网IP,用于与其他设备进行通信。本地公网IP在网络通信中起到了非常重要的作用,为用户提供了方便…

Docker 简介和安装

目录 Docker 是什么 跟普通虚拟机的对比 打包、分发、部署 Docker 部署的优势 Docker 通常用来做什么 重要概念:镜像、容器 安装 镜像加速源 Docker 是什么 Docker 是一个应用打包、分发、部署的工具 你也可以把它理解为一个轻量的虚拟机,它只虚…

【创作活动】如何让大模型更聪明:提升智能的技术策略

😄 19年之后由于某些原因断更了三年,23年重新扬帆起航,推出更多优质博文,希望大家多多支持~ 🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Mi…

前端框架安全防范

前端框架安全防范 在现代Web开发中,前端框架如Angular和React已经成为构建复杂单页面应用(SPA)的主流工具。然而,随着应用复杂度的增加,安全问题也变得越来越重要。本文将介绍如何在使用Angular和React框架时&#xf…

内网渗透-在HTTP协议层面绕过WAF

进入正题,随着安全意思增强,各企业对自己的网站也更加注重安全性。但很多web应用因为老旧,或贪图方便想以最小代价保证应用安全,就只仅仅给服务器安装waf。 本次从协议层面绕过waf实验用sql注入演示,但不限于实际应用…

mysql(数据库)可视化工具——Navicat Premium

Navicat Premium是一款功能强大的数据库管理工具,它支持多种数据库管理系统,包括MySQL、MariaDB、SQL Server、SQLite、Oracle和PostgreSQL等。Navicat Premium提供了直观的用户界面,使用户能够轻松地管理数据库结构、执行复杂的SQL查询、导入…

基于Django的博客系统之用HayStack连接elasticsearch增加搜索功能(五)

上一篇:搭建基于Django的博客系统数据库迁移从Sqlite3到MySQL(四) 下一篇:基于Django的博客系统之增加类别导航栏(六) 功能概述 添加搜索框用于搜索博客。 需求详细描述 1. 添加搜索框用于搜索博客 描…

1940java swing零售库存管理系统myeclipse开发Mysql数据库CS结构java编程

一、源码特点 java swing 零售库存管理系统 是一套完善的窗体设计系统,对理解SWING java 编程开发语言有帮助,系统具有完整的源代码和数据库,,系统主要采用C/S模式开发。 应用技术:javamysql 开发工具:…

如何快速找到 RCE

背景介绍 本文将分享国外白帽子在‘侦察’阶段如何快速发现 RCE 漏洞的经历。以Apache ActiveMQ 的 CVE-2023–46604 为特例,重点介绍如何发现类似此类的漏洞,让我们开始吧。 快速发现过程 在‘侦察’阶段,白帽小哥会保持每周更新一次目标…

域内用户枚举和密码喷洒

一. 域内用户枚举原理和流量 1. 原理 在AS-REQ阶段客户端向AS发送用户名,cname字典存放用户名,AS对用户名进行验证,用户存在和不存在返回的数据包不一样。 不同之处主要是在返回数据包中的状态码不同,根据不同的状态码来区分账…

论文阅读笔记(十二)——Augmenting large language models with chemistry tools

论文阅读笔记(十二)——Augmenting large language models with chemistry tools TOC Abstract 大型语言模型(LLMs)在跨领域任务中表现出色,但在化学相关问题上却表现不佳。这些模型也缺乏外部知识源的访问权限,限制了它们在科…

Go-知识并发控制Context

Go-知识并发控制Context 1. 介绍2. 实现原理2.1 接口定义2.2 Deadline()2.3 Done()2.4 Err()2.5 Value() 3. 空 context4. cancelCtx4.1 Done()4.2 Err()4.3 cancel()4.4 WithCancel4.5 例子4.6 总结 5. timerCtx5.1 Deadline5.2 cancel5.3 WithDeadline5.4 WithTimeout5.5 例子…

linux tomcat版本漏洞升级

Tomcat Session 反序列化代码执行漏洞(CVE-2020-9484) Tomcat 安全限制绕过漏洞(CVE-2018-8034) Tomcat远程代码执行漏洞(CVE-2017-12615) 以上均可以升级版本处理,小版本升级方法 tomcat安装请查看https://blog.csdn.net/qq_42250832/article/details/139015573 1、…

2024医美如何做抖音医美抖音号,本地团购、短视频直播双ip爆品引流,实操落地课

课程下载:https://download.csdn.net/download/m0_66047725/89307619 更多资源下载:关注我。 课程内容: 01-0-序.mp4 02-01-账号定位.mp4 03-02-误区.mp4 04-03-五件套.mp4 05-04-文案怎么来.mp4 06-05-对标怎么弄.mp4 07-06-人设怎…

计算机网络期末复习(1)计算机网络在信息时代对的作用 计算机网络的定义和分类 三种交换方法

计算机网络在信息时代扮演着至关重要的角色,它极大地改变了我们生活、工作和学习的方式。 计算机网络在信息时代的作用 信息共享与传播:计算机网络使全球范围内的信息快速共享成为可能,无论是新闻、学术研究还是娱乐内容,都可以…

MyBatis源码分析--02:SqlSession建立过程

我们再来看看MyBatis使用流程: InputStream inputStream Resources.getResourceAsStream("myBatis_config.xml"); SqlSessionFactory sqlSessionFactory new SqlSessionFactoryBuilder().build(inputStream); SqlSession session sqlSessionFactory.op…

AI编程新手快速体验SpringCloud Alibaba 集成AI功能

上周六写了一篇文章 震撼发布!Spring AI 框架重磅上线,Java 集成 AI 轻松搞定! 部分同学可能没有科学上网的条件,本地ollama 集成又比较笨重。趁着周六,写一篇基于SpringCloud Alibaba 集成AI的文章。 先简单介绍…

浅析R16移动性增强那些事儿(DAPS/CHO/MRO)

R16移动性增强相关技术总结 Dual Active Protocol Handover Dual Active Protocol Handover意为双激活协议栈切换,下文简称DAPS切换,DAPS切换的核心思想是切换过程中,在UE成功连接到目标基站前继续保持和源基站的连接和数据传输,…

MinIO 使用

MinIO自建对象存储 1、dock-compose 使用dock-compose拉取 minio:image: "minio/minio"container_name: minioports:- "9000:9000"- "9001:9001"volumes:- "./minio/data1:/data1"- "./minio/data2:/data2"restart: on-fai…

网络原理-TCP/IP --传输层(UDP)

T04BF 👋专栏: 算法|JAVA|MySQL|C语言 🫵 今天你敲代码了吗 目录 端口号UDP协议 端口号 我们在学习套接字的时候,涉及到两个概念:IP地址和端口号 IP地址是用来确定主机,这是网络层提供的概念 而端口号就是用来确定主机上的应用程序,就是传输层的概念的…