邻里注意Transformer(CVPR2023)

Neighborhood Attention Transformer

  • 摘要
  • 1、介绍
  • 2、相关工作
    • 2.1 新的卷积基线
  • 3、方法
    • 3.1 邻居注意力
    • 3.2 Tiled NA and NATTEN
    • 3.3 邻居注意力Transformer
  • 4、结论

代码

摘要

我们提出邻居注意力(NA),第一个有效和可伸缩的滑动窗口的视觉注意机制。
NA是一种像素级的操作,将自我注意定位到最近的邻近像素上,因此与SA(自注意力)的二次复杂度相比,具有线性的时间和空间复杂度。
与Swin Transformer的窗口自我注意(WSA)不同,滑动窗口模式允许NA的接受域在不需要额外像素位移的情况下增长,并保持平移均方差。
我们开发了NATTEN(邻居注意力扩展),这是一个包含高效c++和CUDA内核的Python包,它允许NA比Swin的WSA运行速度快40%,同时使用的内存少25%。
我们进一步提出了邻居注意力Transformer(NAT),一种新的层次变压器设计基于NA,提高图像分类和下游视觉性能。
NAT实验结果具有一定的竞争力;NAT-Tiny在ImageNet上的准确率达到83.2%,在MSCOCO上达到51.4%的mAP,在ADE20K上达到48.4%的mIoU,这是1.9%的ImageNet准确率,1.0%的COCO mAP,和2.6%的ADE20K mIoU相比相同尺寸的Swin模型。
为了支持更多基于滑动窗口注意力的研究,我们开放了项目的源代码并发布了我们的检查点。

1、介绍

在这里插入图片描述
图 1. 自注意力、(转移的)窗口自注意力和邻域注意力中注意力跨度的图示。自注意力允许每个令牌关注所有事情。 Window Self Attention 将自注意力划分为不重叠的子窗口,然后是 Shifted Window Self Attention,它允许进行感受野扩展所需的窗外交互。邻域注意力将注意力集中在每个标记周围的邻域,引入局部归纳偏差,保持平移等方差,并允许感受野增长而无需额外的操作。

这些高性能的变压器类方法都是基于自我注意(SA),这是原始Transformer[31]的基本构件。SA在嵌入维数方面具有线性复杂度(不包括linear投影),但在tokens数量方面具有二次复杂度。
在视觉范围内,tokens通常与图像分辨率成线性相关。
因此,在严格使用SA(如ViT)的模型中,更高的图像分辨率会导致复杂性和内存使用量的二次增加。
二次复杂性使得这种模型难以适用于后续的视觉任务,如目标检测和分割,在这些任务中,图像分辨率通常比分类分辨率大得多。
另一个问题是,卷积受益于位置和二维空间结构等归纳偏差,而点积的自我注意是一个全局一维操作。
这意味着一些归纳偏差必须通过大量的数据[12]或高级培训技术和增强来学习。
在本研究中,我们回顾了显式滑动窗注意机制,并提出了邻居注意力(NA)。
NA将SA定位到每个像素的最近邻居,它不一定是像素周围的固定窗口。
这种定义上的改变允许所有像素保持相同的注意广度,否则在零填充替代方案(SASA)中,角像素会减少注意广度。
随着邻域大小的增加,NA也接近SA,并且在最大邻域时等于SA。
此外,与阻塞和窗口自注意不同,NA具有保持平移等方差[30]的额外优势。
我们开发了NATTEN,这是一个带有高效c++和CUDA内核的Python包,在使用更少内存的情况下,它允许NA在实践中比Swin的WSA运行得更快。
我们建立了邻居注意力Transformer(NAT),它实现了跨视觉任务的竞争结果。
贡献:
1、==提出邻居注意力(NA):一种简单而灵活的显式滑动窗口注意机制,它将每个像素的注意广度定位到其最近的邻域,随着其广度的增长而接近自我注意,并保持平移方差。==我们将NA的复杂度和内存使用量与自我注意、窗口自我注意和卷积进行了比较。
2、为NA开发高效的c++和CUDA内核,包括tile NA算法,它允许NA运行速度比Swin的WSA快40%,同时使用最多25%的内存。我们将它们发布在一个新的Python包中,用于显式滑动窗口注意机制,NATTEN,以提供易于使用的带有自研支持的模块,这些模块可以插入任何现有的PyTorch管道中。
3、==引入邻居注意力Transformer(NAT),一种新的高效、准确、可伸缩的基于NA的分级变压器。==我们证明了它在分类和下游任务上的有效性。例如,NAT-Tiny在ImageNet上达到了83.2%的top-1准确率,只有4.3 GFLOPs和28M参数,在MS-COCO上是51.4%的box mAP,在ADE20K上是48.4%的mIoU,显著优于Swin Transformer和ConvNeXt[22]。

2、相关工作

2.1 新的卷积基线

Liu et al.[22]提出了一种受Swin等模型影响的新的CNN架构,称为ConvNeXt。
这些模型不是基于注意力的,并且在不同的视觉任务中都优于Swin。
这项工作已经成为一个新的CNN基线,用于对卷积模型和基于注意力的模型进行公平的比较。
我们建议邻居注意力,它通过设计将接受域定位到每个查询周围的窗口,因此不需要额外的技术,例如Swin使用的循环移位。
此外,邻居注意力保持了平移的方差,这是交换效率的方法,如HaloNet和Swin。
我们用NATTEN python包演示了邻居注意力可以比Swin等方法运行得更快,同时使用更少的内存。
我们引入了一个具有这种注意力机制的分层变换式模型,称为邻居注意力Transformer,并展示了它与Swin相比在图像分类、目标检测和语义分割方面的性能。

3、方法

在这一节中,我们介绍了邻居注意力,一种考虑到视觉数据结构的自我注意定位(参见Eq.(1))。与自注意相比,这不仅降低了计算成本,而且还引入了类似于卷积的局部归纳偏差。我们发现,在限制自我注意方面,NA优于之前提出的SASA[25],但在理论成本上是等价的。然后我们介绍我们的Python包NATTEN,它为CPU和GPU加速提供了NA的有效实现。我们讨论了扩展中的新奇之处,以及它如何在使用更少内存的情况下,设法超过Swin的WSA和SWSA的速度。最后我们介绍了我们的模型,邻居注意力Transformer(NAT),它使用这种新的机制代替自我注意。此外,NAT利用了一种多级的分层设计,类似于Swin[21],这意味着特征特征图在不同的层之间向下采样,而不是一次全部采样。与Swin不同,NAT使用重叠卷积对特征特征图进行降采样,而非非重叠(打补丁)的特征特征图,后者已被证明通过引入有用的归纳偏差来提高模型性能[15,34]。
在这里插入图片描述

3.1 邻居注意力

在这里插入图片描述
图 2. 单个像素的邻域注意力 (NA) 与自注意力 (SA) 的查询键值结构图示。 SA 允许每个像素关注所有其他像素,而 NA 则将每个像素的注意力集中到其周围的邻域。因此,每个像素的注意力广度通常与下一个像素不同。

Swin的WSA可以被认为是现有限制自我注意速度最快的方法之一,以降低二次注意成本。它只是简单地划分输入,并将自我注意单独应用到每个分区。WSA需要与移位的变体SWSA配对,后者将这些分隔线移位以允许窗口外的交互。
这对扩大其接受范围至关重要。然而,限制局部自注意的最直接方法是允许每个像素对其邻近像素进行关注。这导致在大多数像素周围有一个动态移动的窗口,这扩大了接受域,因此将不需要手动移动的变体。
此外,与Swin不同,与卷积相似的是,这种动态形式的受限自我注意可以保持平移等方差30。
受此启发,我们引入了邻居注意力(NA)。给定一个输入X∈Rn×d,它是一个矩阵,其行是d维token向量,以及X的线性投影Q, K, V,和相对位置偏差B(i, j),我们定义了第i个输入的注意权值,其邻域大小为K, Ak i,作为第i个输入的查询投影的点积,它的k个最邻近的关键投影:
在这里插入图片描述
其中ρj(i)表示第j个最近的邻居。然后我们定义邻近值vki,作为一个矩阵,它的行是第i个输入的k个最近的邻近值投影:
在这里插入图片描述
邻域大小为k的第i个token的邻居注意力定义为:
在这里插入图片描述
其中√d为缩放参数。对feature map中的每个像素重复此操作。图2和图八展示了这一操作的插图。
从这个定义可以很容易看出,随着k的增长,Ak i趋向于自我注意权重,Vk i趋向于Vi本身,因此邻居注意力趋向于自我注意。这是NA和SASA[25]之间的关键区别,每个像素在输入周围使用填充来处理边缘情况。正是由于这种差异,当窗口大小增加时,NA会接近于自我注意,这在SASA中并不适用,因为输入周围的填充为零。

3.2 Tiled NA and NATTEN

在过去,以像素方式限制自我注意还没有得到很好的研究,主要是因为它被认为是一种昂贵的操作[21,25,30],需要更低层次的重新实现。这是因为自我关注本身被分解为矩阵乘法,这是一种很容易在加速器上并行化的操作,并且在计算软件中为不同的用例定义了大量的高效算法(举几个例子:LAPACK、cuBLAS、cutass)。此外,大多数深度学习平台,如PyTorch,都是在此类软件和附加包(如cuDNN)之上编写的。这对研究人员非常有帮助,因为它允许他们使用操作的抽象,如矩阵乘法或卷积,而后端根据他们的硬件、软件和用例决定运行哪个算法。它们通常还能处理自动梯度计算,这使得设计和训练深度神经网络非常简单。由于NA的像素级结构(以及其他像素级注意机制,如SASA[25]),以及NA中邻域定义的新奇性,在这些平台上实现NA的唯一方法是堆叠大量高效的操作来提取邻域,并将其存储为一个中间张量,然后计算注意力。
这会导致操作非常慢,内存使用量呈指数增长。为了应对这些挑战,我们开发了一套高效的CPU和CUDA内核,并将它们打包为Python包(邻居注意力扩展(NATTEN))。NATTEN包括半精度支持,对1D和2D数据的支持,以及与PyTorch的自动兼容集成。这意味着用户可以简单地将NA导入为PyTorch模块,并将其集成到现有的管道中。我们还补充说,SASA也可以很容易地用这个包实现,而不需要对底层内核进行更改(只需将输入填充为零),因为这是NA的一种特殊情况。反之则不成立。它还包括我们的平铺NA算法,它通过将不重叠的查询平铺加载到共享内存来计算邻居的注意力权重,以最小化全局内存读取。与简单的实现相比,平放的NA可以减少一个数量级的延迟(技术细节见附录A),并且它允许基于NA的模型比类似的Swin模型运行速度快40%(见图4)。NATTEN的开源网址是:https://github.com/SHI-Labs/NATTEN。

3.3 邻居注意力Transformer

在这里插入图片描述
图 5.我们的模型 NAT 及其分层设计的概述。该模型从卷积下采样器开始,然后进入 4 个连续级别,每个级别由多个 NAT 块组成,这些块是类似变压器的编码器层。每层由多头邻域注意力(NA)、多层感知器(MLP)、每个模块之前的层规范(LN)和跳过连接组成。在各个级别之间,特征图被下采样至其空间大小的一半,而其深度则加倍。这样可以更轻松地通过特征金字塔转移到下游任务。
NAT 使用 2 个连续的 3 × 3 卷积(步幅为 2 × 2)嵌入输入,导致空间大小为输入大小的 1/4。这类似于使用补丁和具有 4 × 4 补丁的嵌入层,但它利用重叠卷积而不是非重叠卷积来引入有用的归纳偏差 [15,34]。另一方面,使用重叠卷积会增加成本,并且两个卷积会产生更多参数。然而,我们通过重新配置模型来处理这个问题,从而实现更好的权衡。 NAT 由 4 个级别组成,每个级别后面都有一个下采样器(最后一个除外)。下采样器将空间大小减少一半,同时通道数量增加一倍。我们使用步长为 2 × 2 的 3 × 3 卷积,而不是 Swin 使用的 2 × 2 非重叠卷积(补丁合并)。由于分词器下采样为 4 倍,因此我们的模型生成大小为 h 4 × w 4 、 h 8 × w 8 、 h 16 × w 16 和 h 32 × w 32 的特征图。这一变化是由之前成功的 CNN 结构推动的,随后是其他基于分层注意力的方法,例如 PVT [32]、ViL [38] 和 Swin Transformer [21]。此外,我们使用 LayerScale [29] 来保证较大变体的稳定性。图 5 展示了整体网络架构。我们在表 1 中总结了不同 NAT 变体。
在这里插入图片描述

4、结论

在本文中,我们提出了邻域注意力(NA),这是第一个高效且可扩展的视觉滑动窗口注意力机制。 NA 是一种逐像素操作,它将每个像素的自注意力定位到其最近邻域,因此具有线性复杂度。与阻塞(HaloNet)和窗口自注意力(Swin)不同,它还引入了局部归纳偏差并保持平移等方差。与 SASA 不同,NA 随着窗口大小的增加而接近自注意力,并且在极端情况下不限制注意力跨度。我们通过开发 NATTEN 来挑战显式滑动窗口注意力模式效率不高或可并行化的普遍观念 [21]。通过使用 NATTEN,基于 NA 的模型可以比现有替代方案运行得更快,尽管后者主要运行在构建于较低级别计算包之上的高度优化的深度学习库上。我们还提出了 Neighborhood Attention Transformer (NAT) 并展示了此类模型的强大功能:NAT 在图像分类方面优于 Swin Transformer 和 ConvNeXt,并且在下游视觉任务中优于或与两者竞争。我们开源整个项目,以鼓励在这个方向进行更多研究。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/141432.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

链表题(3)

链表题 正文开始前给大家推荐个网站,前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 本篇内容继续给大家带来链表的一些练习题 链表分割 知识点: 编程基础 链表…

北京智达鑫业信息咨询有限公司专业的信息技术服务领域资质认证解决方案供应商

北京智达鑫业信息咨询有限公司成立于2014年1月8日,注册资本为500万元人民币.公司主要致力于信息化项目的资质咨询、指导、和培训服务,以及为互联网技术领域服务的企业。主要业务有:(CS)信息系统建设和服务能力评估、&a…

vue 数字软键盘 插件 封装 可拖动

1、效果图 2、使用方式 <Keyboard v-if"show" close"show false" :inputDom"$refs.input" /> 封装的数字键盘 Keyboard.vue 组件代码 <template><divclass"keyboard"ref"keyboard":style"{ left: …

《QT从基础到进阶·二十四》按钮组QButtonGroup,单选框QRadioButton和多选框QCheckBox

1、按钮组QButtonGroup 如果有多个单选按钮&#xff0c;可以统一放进一个按钮组。 图中有三个单选按钮放进了一个QGroupBox,并且设置了水平布局&#xff0c;现在要将这三个单选按钮放进一个按钮组&#xff0c;之前的想法是先把三个按钮加入按钮组&#xff0c;再把按钮组放进QG…

图的表示与基础--Java

1.图的基础知识 该图片来自于&#xff1a; https://b23.tv/KHCF2m6 2.稀疏图与稠密图 G(V,E)&#xff1a;V顶点个数&#xff0c;E边的个数 稀疏图&#xff1a;E<<V 一般用邻接表表示(数组链表) 稠密图&#xff1a;E接近V 一般用邻接矩阵表示&#xf…

S32K3基础学习 linker链接器脚本ld文件的学习(一)

一、简介 最近学习NXP新推出的S32K3系列芯片&#xff0c;我在学习容易转牛角尖&#xff0c;非得要搞明白这个芯片的启动流程&#xff0c;所以花费了一些时间&#xff0c;进行查阅资料进行学习&#xff0c;这里做下详细的记录&#xff0c;希望有用&#xff0c;如果有错误欢迎指正…

海上船舶交通事故VR模拟体验低成本高效率-深圳华锐视点

在海上运输行业&#xff0c;安全事故的防范和应对能力是企业安全教育的重中之重。针对这一问题&#xff0c;海上运输事故VR模拟逃生演练成为了一种创新且高效的教育手段。通过这种演练&#xff0c;企业能够在提升员工安全意识和技能方面获得多方面的帮助。 在VR船舶搜救演练中&…

第十五章,输入输出流例题

package 例题;import java.io.File;public class 例题1 {public static void main(String[] args) {//创建文件对象File file new File("D:\\Java15-1.docx");//判断&#xff0c;如果该文件存在。exists存在的意思if (file.exists()) {//删除//file.delete();//Syst…

轻量封装WebGPU渲染系统示例<28>- MRT纹理(源码)

当前示例源码github地址: https://github.com/vilyLei/voxwebgpu/blob/feature/rendering/src/voxgpu/sample/MRT.ts 当前示例运行效果: 此示例基于此渲染系统实现&#xff0c;当前示例TypeScript源码如下: export class MRT {private mRscene new RendererScene();initial…

拼多多商品详情API接口接入流程如下:

拼多多商品详情API接口可以用于获取拼多多商品的具体信息&#xff0c;包括商品ID、商品名称、价格、销量、评价等。以下是使用拼多多商品详情API接口的步骤&#xff1a; 进入拼多多开放平台&#xff0c;注册并登录账号。在开放平台页面中&#xff0c;找到“商品详情”或“商品…

DDD领域驱动设计模式结构图面向接口编程

DDD领域驱动设计模式结构图面向接口编程 9.资源库 在刚接触资源库(Repository)时&#xff0c;第一反应便是这就是个 DAO 层&#xff0c;访问数据库&#xff0c;然后吧啦吧啦&#xff0c;但是&#xff0c;当接触的越久&#xff0c;越发认识到第一反应是错的&#xff0c;资源库更…

No194.精选前端面试题,享受每天的挑战和学习

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云课上架的前后端实战课程《Vue.js 和 Egg.js 开发企业级健康管理项目》、《带你从入…

azkaban的安装

一、下载上传文件 二、创建目录 mkdir /opt/soft/azkaban 三、解压 tar -zxvf /opt/install/azkaban-db-3.84.4.tar.gz -C /opt/soft/azkaban tar -zxvf /opt/install/azkaban-exec-server-3.84.4.tar.gz -C /opt/soft/azkaban tar -zxvf /opt/install/azkaban-web-server-…

Ultipa 支持OpenCypher,助力企业级应用发展

OpenCypher 是欧美图数据库厂家 Neo4j 基于其图查询语言Cypher 开发的一套开源图查询语言&#xff0c;该语言也是开发者们较为熟悉的图查询语言之一。 Ulitpa Graph&#xff08;嬴图&#xff09;于2022年6月实现的对OpenCypher 的支持&#xff0c;旨在让用户能够通过自己熟悉的…

APP攻防-资产收集篇反证书检验XP框架反代理VPN数据转发反模拟器

文章目录 常见问题防护手段 常见问题 没有限制过滤的抓包问题&#xff1a; 1、抓不到-工具证书没配置好 2、抓不到-app走的不是http/s 有限制过滤的抓包问题&#xff1a; 3、抓不到-反模拟器调试 4、抓不到-反代理VPN 5、抓不到-反证书检验 做移动安全测试时&#xff0c;设置…

Java中所有的运算符,以及运算符优先级(总结)

运算法是一种特殊的符号&#xff0c;用于表示数据的运算、复制、比较等。 1、算数运算符 // % 取余运算&#xff1a;结果的符号和被模数的符号一致 12 % 5 2 -12 % 5 -2 12 % -5 2 -12 % -5 -2int a1 10; int b1 a1; // a111, b111 int a2 10; int b2 a2; // a211, …

频域分析实践介绍

频域分析实践介绍 此示例说明如何执行和解释基本频域信号分析。该示例讨论使用信号的频域表示相对于时域表示的优势&#xff0c;并使用仿真数据和真实数据说明基本概念。该示例回答一些基本问题&#xff0c;例如&#xff1a;FFT 的幅值和相位的含义是什么&#xff1f;我的信号是…

谈谈Vue双向数据绑定的原理

目录 一、什么是Vue.js 二、什么是双向数据绑定 三、双向数据绑定的原理 一、什么是Vue.js Vue.js是一款流行的JavaScript前端框架&#xff0c;用于构建用户界面。它是一个轻量级、灵活而高效的框架&#xff0c;被广泛应用于单页应用程序和可交互的前端界面开发。Vue.js的设…

Jmeter+ant+Jenkins持续集成

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

同济 MBA 携手和鲸课程共建,以数智人才培养持续赋能企业数字化转型

数智化的浪潮席卷全球&#xff0c;我国产业界应如何做出应变&#xff1f;各企业又该如何深化数字化转型&#xff1f;在任重道远的持续探索中&#xff0c;数智人才培养作为企业实现成功转型的关键要素&#xff0c;已然成为大势所趋。 同济大学综合 MBA 项目高度重视工商管理人才…