YOLOv5、v8改进:CrissCrossAttention注意力机制

目录

1.简介

2. yolov5添加方法:

2.1common.py构建CrissCrossAttention模块

2.2yolo.py中注册 CrissCrossAttention模块

2.3修改yaml文件。


1.简介

这是ICCV2019的用于语义分割的论文,可以说和CVPR2019的DANet遥相呼应。

和DANet一样,CCNet也是想建模像素之间的long range dependencies,来做更加丰富的contextual information,来补充特征图,以此来提升语义分割的性能。但是和DANet不一样,CCNet仅考虑空间分辨上的建模,不考虑建模通道之间的联系。作者提出的模块,criss-cross attention module,针对空间维度上的建模,对于空间位置的一个点u,仅考虑建模和u在同一行或者同一列的其他位置的像素之间的联系。相比DANet,能减少很多计算量,但是不足的是,对一个点的特征向量,尽管有同一行或者同一列的其他像素信息作为补充,对于语义分割任务,contextual information仍然是稀疏的(sparse),因为语义分割更在意一个像素和它周围的一些像素的关系。针对这个问题,作者提出了recurrent criss-cross attention module,来建模一个像素和全局所有像素的关系。方式是通过重复criss-cross attention module来实现的。这些module也是参数shared的。

同样是建模空间维度的pixel-wise contextual information,CCNet的计算量相较于self attention,可小太多了。一个CC module,要处理的是一个像素点和同一行、同一列一共(H+W-1)这么多的像素,那么应用在所有像素上,计算量就是O(HW(H+W-1))。回顾DANet的空间注意力分支(position attention module),每一个像素就要和(HW)个像素建模之间的联系,应用在所有相素,计算量就是O(HW*(H*W))。
通过递归的方式用CC module,可以对一个像素捕捉到全局的contextual information,提到了语义分割任务的效果。
个人看法,简单且有效的,就是极其优秀的方法,CCNet就属于这一类方法。
 

在这里插入图片描述

1.首先一个原图送进backbone,这个backbone是修改过的,把最后两个stage的stride改为1,同时应用空洞卷积来增大感受野。得到的特征图是原图的1/8.

2.然后经过1*1的卷积降维。得到H

3.H经过一个criss-cross attention module 得到H ′ 这个时候,H’中的每个位置都捕捉到了和u在同一行或者同一列的context information

4.H’经过一个相同结构、相同参数的cc module,得到了H’’。在H‘’中的每个位置,捕捉的是全局性的contextual information
5..最后经过一个分割层输出最后的预测结果。
在这里插入图片描述

 

之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入CrissCrossAttention注意力机制,该注意力机制为应用在语义分割中的模块,用于可以让网络更加关注待检测目标,提高检测效果

基本原理:

       语义分割的Criss-Cross网络(CCNet)的细节。我们首先介绍了CCNet的总体框架。然后,将介绍在水平和垂直方向捕获上下文信息的2D交叉注意力模块。为了获取密集的全局上下文信息,我们建议对交叉注意力模块采用循环操作。为了进一步改进RCCA,我们引入了判别损失函数来驱动RCCA学习类别一致性特征。最后,我们提出了同时利用时间和空间上下文信息的三维交叉注意模块。

2. yolov5添加方法:

2.1common.py构建CrissCrossAttention模块

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax


def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))


    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x

2.2yolo.py中注册 CrissCrossAttention模块

elif m is CrissCrossAttention:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, *args[1:]]

2.3修改yaml文件。

# YOLOAir 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOAir v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOAir v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, CrissCrossAttention, [1024]], #修改

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

YOLOv8和v5的改法是一致的

有什么问题可以评论区私聊

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/92936.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

maven下载不了仓库地址为https的依赖jar,配置参数忽略ssl安全检查

问题原因 私服使用的https地址,然后安全证书过期的或没有,使用maven命令时,可以添加以下参数,忽略安全检查 mvn -Dmaven.wagon.http.ssl.insecuretrue -Dmaven.wagon.http.ssl.allowalltrue -Dmaven.wagon.http.ssl.ignore.vali…

《机器学习核心技术》分类算法 - 决策树

「作者主页」:士别三日wyx 「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:小白零基础《Python入门到精通》 决策树 1、决策树API2、决策时实际应用2.1、获取数据集2.2、划分数据集2.3、决策…

【算法专题突破】双指针 - 快乐数(3)

目录 1. 题目解析 2. 算法原理 3. 代码编写 写在最后: 1. 题目解析 题目链接:202. 快乐数 - 力扣(Leetcode) 这道题的题目也很容易理解, 看一下题目给的示例就能很容易明白, 但是要注意一个点&#…

redux中间件理解,常见的中间件,实现原理。

文章目录 一、Redux中间件介绍1、什么是Redux中间件2、使用redux中间件 一、Redux中间件介绍 1、什么是Redux中间件 redux 提供了类似后端 Express 的中间件概念,本质的目的是提供第三方插件的模式,自定义拦截 action -> reducer 的过程。变为 actio…

Ceph入门到精通-大流量10GB/s LVS+OSPF 高性能架构

LVS 和 LVSkeepalived 这两种架构在平时听得多了,最近才接触到另外一个架构LVSOSPF。这个架构实际上是LVSKeepalived 的升级版本,我们所知道LVSKeepalived 架构是这样子的: 随着业务的扩展,我们可以对web服务器做水平扩展&#xf…

4G WiFi LoRa无线外夹式超声波管道流量计MQTT/http协议 json数据说明

ip:114.128.112.131 port:1883 uname:scwl_flowmeter pwd:b123 topic:iot/data/scwlflowmeter { “deviceId”:“设备序列号”, “flow”:“瞬时流量(浮点数)”, “heatFlow”:“瞬时热流量&am…

Vue脚手架中安装ElementUi

目录 ElementUi简介: ElementUi下载: npm 安装: 引入ElementUi: 测试是否引入成功: Element-ui官网:组件 | Element ElementUi简介: ElementUi,是由国内的饿了么团队开发并开源的一套为开…

DML语句的用法(MySQL)

文章目录 前言一、DML介绍二、DML语句操作1、给指定字段添加数据2、给全部字段添加数据3、批量添加数据4、修改数据5、删除数据 总结 前言 本文主要介绍SQL语句中DML语句的用法。 在实验开始之前我们先创建一下所要使用表,如下图所示: 一、DML介绍 DM…

PowerDesigner学习笔记

备注:文章主要对概念数据模型进行深入分析 1.对各种模型图初步认识 1.1.概念数据模型 (CDM) (Conceptual Data Model) 对数据和信息进行建模,利用实体-关系图(E-R图)的形式组织数据,检验数据设计的有效性和合理性。 …

Linux驱动之设备树下的platform驱动

目录 一、设备树下的 platform 驱动简介 二、修改设备树文件 2.1 添加 LED 设备节点 2.2 添加 pinctrl 节点 2.3 检查 PIN 是否被其他外设使用 三、platform 驱动程序编写 四、测试 APP 编写 五、运行测试 5.1 编译 5.2 运行测试 前面一篇我们讲解了传统的、未采用设备…

用香港服务器域名需要备案吗?

​  在选择服务器的时候,很多人会考虑使用香港服务器。香港服务器的一个优势就是不需要备案。不管是虚拟主机还是云主机,无论是个人网站还是商业网站,都不需要进行备案手续。 域名实名认证 虽然不需要备案,但使用香港服务器搭建…

webrtc的Sdp中的Plan-b和UnifiedPlan

在一些类似于视频会议场景下,媒体会话参与者需要接收或者发送多个流,例如一个源端,同时发送多个左右音轨的音频,或者多个摄像头的视频流;在2013年,提出了2个不同的SDP IETF草案Plan B和Unified Plan&#x…

数据库表结构导出为word、html、markdown【转载,已解决,已验证,开源】

注&#xff1a;本文为gitcode代码验证&#xff0c;转载gitcode gitcode&#xff1a;https://gitcode.net/mirrors/pingfangushi/screw?utm_sourcecsdn_github_accelerator 整理数据库文档&#xff1a;https://mp.weixin.qq.com/s/Bo_U5_cl82hfQ6GmRs2vtA <!--数据库文档核…

【Day-21慢就是快】代码随想录-栈与队列-逆波兰表达式求值

逆波兰表达式&#xff1a;是一种后缀表达式&#xff0c;所谓后缀就是指运算符写在后面。 平常使用的算式则是一种中缀表达式&#xff0c;如 ( 1 2 ) * ( 3 4 ) 。 该算式的逆波兰表达式写法为 ( ( 1 2 ) ( 3 4 ) * ) 。 逆波兰表达式主要有以下两个优点&#xff1a; 去掉…

Spring AOP基于注解方式实现和细节

目录 一、Spring AOP底层技术 二、初步实现AOP编程 三、获取切点详细信息 四、 切点表达式语法 五、重用&#xff08;提取&#xff09;切点表达式 一、Spring AOP底层技术 SpringAop的核心在于动态代理&#xff0c;那么在SpringAop的底层的技术是依靠了什么技术呢&#x…

数据结构--队列与循环队列

队列 队列是什么&#xff0c;先联想一下队&#xff0c;排队先来的人排前面先出&#xff0c;后来的人排后面后出&#xff1b;队列的性质也一样&#xff0c;先进队列的数据先出&#xff0c;后进队列的后出&#xff1b;就像图一的样子&#xff1a; 图1 如图1&#xff0c;1号元素是…

【Rust】Rust学习 第十八章模式用来匹配值的结构

模式是 Rust 中特殊的语法&#xff0c;它用来匹配类型中的结构&#xff0c;无论类型是简单还是复杂。结合使用模式和 match 表达式以及其他结构可以提供更多对程序控制流的支配权。模式由如下一些内容组合而成&#xff1a; 字面值解构的数组、枚举、结构体或者元组变量通配符占…

Java【手撕双指针】LeetCode 18. “四数之和“, 图文详解思路分析 + 代码

文章目录 前言一、四数之和1, 题目2, 思路分析3, 代码 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: &#x1f4d5; JavaSE基础: 基础语法, 类和对象, 封装继承多态, 接口, 综合小练习图书管理系统等 &#x1f4d7; Java数据结构: 顺序表, 链表, 堆…

React笔记(二)JSX

一、JSX JSX是javascript XML的简写&#xff0c;实际上是javascript的扩展&#xff0c;既有javascript的语法结构&#xff0c;又有XML的结构 1、JSX的规则要求 jsx必须要有一个根节点 如果不想产生无用的根标签&#xff0c;但是还要遵守JSX的语法的要求&#xff0c;可以使用…

【javaweb】学习日记Day6 - Mysql 数据库 DDL DML

之前学习过的SQL语句笔记总结戳这里→【数据库原理与应用 - 第六章】T-SQL 在SQL Server的使用_Roye_ack的博客-CSDN博客 目录 一、概述 1、如何安装及配置路径Mysql&#xff1f; 2、SQL分类 二、DDL 数据定义 1、数据库操作 2、IDEA内置数据库使用 &#xff08;1&…