Swin Transformer—— 基于Transformer的图像识别模型

概述

Swin Transformer是微软研究院于2021年在ICCV上发表的一篇论文,因其在多个视觉任务中的出色表现而被评为当时的最佳论文。它引入了移动窗口的概念,提出了一种层级式的Vision Transformer,将Shifted Windows(移动窗口)作为其主要贡献。这个概念使得Swin Transformer可以像卷积神经网络一样进行分块,并进行层级式的特征提取,从而在特征表示中引入多尺度的概念。

在OpenAI发布的Sora中也出现了视频patches的概念,这进一步表明了Vision Transformer和Swin Transformer在引入patch概念方面的重要性。目前,许多多模态模型的backbone都采用了这两种模型,因此理解和应用它们的原理对于掌握和应用这些优秀的多模态模型非常必要。

在 Swin Transformer之前,基于Transformer的图像识别模型是视觉变换器(ViT)。它将图像视为由 16x16 个单词组成的句子,是自然语言处理中使用的变换器在图像识别中的首次应用。

本文指出了文本和图像之间的差异,并提出了 Swin Transformer,使 ViT 更适应图像领域。

文字和图像的两个区别如下。

  • 与文字符号不同,图像中的视觉元素在比例上差异很大
  • 图像中的像素比文件中的文字具有更高的分辨率(更多信息)。

为了消除这些差异

  • 计算不同贴片尺寸下的关注度
  • 用较小的补丁尺寸计算关注度。

下图说明了 ViT 和 Swin Transformer在这些方面的区别。

用较小的斑块尺寸计算注意力可以获得精细的特征,但计算成本较高。

这就是在 Swin 变换器中引入基于移位窗口的自注意的原因。多个补丁被合并到一个窗口中,注意力计算只在该窗口中进行,从而减少了计算量。

在下一节中,我们将了解斯温变换器的整体情况,然后了解一些更微小的细节,包括基于移位窗口的自我关注。

Swin Transformer

大画面

下面是 Swin Transformer的全貌。

首先,对输入图像进行补丁分割。

补丁分割:将 4x4 像素分割为一个补丁;由于 ViT将 16x16像素作为一个补丁,因此斯温变换器可以提取更精细的特征。

然后进行线性嵌入。

线性嵌入:将补丁(4x4x3ch)转换为 C 维标记,其中 C 取决于模型的大小。

对于从每个补丁中获得的标记,Swin Transformer Block 会计算关注度并进行特征提取。

Swin Transformer区块:用基于移位窗口的自保持(W-MSA 和 SW-MSA)取代常规变压器区块中使用的多头自保持(MSA)。以下章节将提供更多信息。下文将对它们进行更详细的介绍。其他配置与普通变压器几乎完全相同。

目前看到的线性嵌入和变换块部分被称为第 1 阶段;共有 1 到 4 个阶段,但每个阶段的补丁大小不同,因此可以在不同尺度上进行特征提取。不同大小的补丁是由补丁合并(Patch Merging)产生的,它将邻域中的补丁聚合在一起。

补丁合并:在每个阶段,相邻的(2 × 2)补丁(标记)合并在一起,形成一个标记。具体来说,合并 2 × 2 标记,并通过线性层将所得的 4C 维向量变为 2C 维。例如,在第 2 阶段,(H/4)×(W/4)×C 维度被简化为 (H/8)×(W/8)×2C 维度。

基于移动窗口的自我关注

从计算复杂度的角度解释了普通变压器和斯温变压器模块注意力计算的区别。

法线变换器计算所有标记之间的距离,其中 h 和 w 是图像中垂直和水平斑块的数量,计算量如下

另一方面,Swin 变换器只计算由多个补丁组成的窗口内的关注度:一个窗口包含 M x M 个补丁,基本固定为 M = 7。计算复杂度如下式所示。

在普通变换器中,计算复杂度的增加与补丁数 (hw) 的平方成正比。然而,由于 M = 7,影响很小,即使是补丁数 (hw) 的增加也保持在幂级数以内。这使得 Swin变换器可以计算小尺寸的贴片。

接下来介绍将图像划分为窗口的方法:窗口的排列方式是将图像平均划分为 M x M 个补丁。以这种方式排列的每个窗口都会计算注意力,因此即使是相邻的补丁,如果它们是不同的窗口,也不会计算注意力。为了解决窗口边界问题,在计算第一个注意力(W-MSA:基于窗口的多头自注意力)后,窗口会被移动,注意力会被再次计算(SW-.MSA:基于移动窗口的多头自注意)。

如下图所示,在原窗口分割的基础上移动([M/2], [M/2])个像素。


代码模型:

class PatchEmbed(nn.Module):
    
    def __init__(self, patch_size=4, in_c=3, embed_dim=96, norm_layer=None):
        super().__init__()
        patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.in_chans = in_c
        self.embed_dim = embed_dim
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape

        # 如果输入图片的 H,W 不是patch_size的整数倍,需要进行padding
        pad_input = (H % self.patch_size[0] != 0) or (W % self.patch_size[1] != 0)
        if pad_input:
            # to pad the last 3 dimensions, (W_left, W_right, H_top,H_bottom, C_front, C_back)
            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1],
                          0, self.patch_size[0] - H % self.patch_size[0],
                          0, 0))

        # 下采样patch_size倍
        x = self.proj(x)
        _, _, H, W = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW];  transpose: [B, C, HW] -> [B, HW, C]
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x, H, W

移位配置的高效批量计算

SW-MSA 的窗口大小不同,窗口数量也会增加。因此,如果直接进行处理,就会出现计算量比 W-MSA 增加的问题。因此,在 SW-MSA 中,使用一种称为循环移动的方法进行伪操作,而不是实际改变窗口的排列。

如下图所示,整个图像向左上方移动,溢出区域插入空白区域 (循环移动 )。通过这种方法,它的计算方法与 W-MSA 窗口中的 Attention 计算方法相同。此外,由于窗口中可能包含不相邻的斑块,因此要对这些部分进行掩膜处理。在最终输出中,将执行循环移位的反向操作(反向循环移位),将补丁恢复到原始位置。


代码实现:

class PatchMerging(nn.Module):

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x, H, W):
        """
        x: B, H*W, C
        """
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        x = x.view(B, H, W, C)

        # 如果输入feature map的H,W不是2的整数倍,需要进行padding
        pad_input = (H % 2 == 1) or (W % 2 == 1)
        if pad_input:
            # to pad the last 3 dimensions, starting from the last dimension and moving forward.
            # (C_front, C_back, W_left, W_right, H_top, H_bottom)
            # 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.view(B, -1, 4 * C)  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x

结构变体

Swin 变压器有 T、S、B 和 L 四种尺寸,每级的尺寸(dim)、头(head)和块数各不相同,如下表所示。

试验

在 ImageNet-1K 的图像识别任务、COCO 的物体检测任务和 ADE20K 的语义分割任务中与其他模型进行了比较,结果都达到了最高准确率。(实验结果详见本文第四章表 1~表 3)。

在 SW-MSA 中进行的消融研究证实,在这两项任务中,引入 SW-MSA 的准确率都高于单独引入 W-MSA。

摘要

与在所有斑块之间计算注意力的 ViT 不同,注意力计算和斑块聚合可以在相邻斑块的窗口中重复进行,从而可以在不同尺度上提取特征。另一个优点是不在所有斑块之间计算注意力,从而降低了计算复杂度,并能从较小的斑块尺寸中提取特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/577713.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++图论基础(1)

目录 无向图 无向图度 无向图性质 有向图 有向图度 有向图性质 图的分类: 稀疏图: 稠密图: 零图: 有向完全图: 无向完全图: 度序列: 图是由顶点集合(简称点集)和顶点间的边(简称边…

mac上安装Tomcat

1. 简介 Tomcat 是一个开源的 Java 服务器,它实现了 Java Servlet、JavaServer Pages(JSP)和Java WebSocket 技术。Tomcat 是 Apache 软件基金会的一个项目,是一个轻量级、高性能的 Web 容器。作为一个 Web 服务器,To…

【Java EE】CAS原理和实现以及JUC中常见的类的使用

˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN 如…

11.JAVAEE之网络原理1

1.应用层(和程序员接触最密切) 应用程序 在应用层这里,很多时候, 都是程序员"自定义"应用层协议的,(当然,也是有一些现成的应用层协议)(这里的自定义协议,其实是非常简单的~~协议 >约定,程序员在代码中规定好,数据如何进行传输) 1.根据需求, 明确要传…

了解HTTP代理服务器:优势、分类及应用实践

在我们日常的网络使用中,我们经常听到HTTP代理服务器这个术语。那么,HTTP代理服务器到底是什么?它有什么优势和分类?又如何应用于实践中呢?让我们一起来了解一下。 HTTP代理服务器是一种位于客户端和服务器之间的中间…

中电金信:向“新”而行——探索融合架构的项目管理在保险行业的应用

近年来,险企在政策推动、市场牵引、自身发展、新技术应用日趋成熟等内外部因素的驱动下,积极投身到数字化转型的浪潮中。在拜访各类保险客户和合作项目的过程中,我们发现不少险企在数字化转型中或多或少都面临着战略如何落地、技术如何承接和…

美国洛杉矶站群服务器如何提高网站排名?

美国洛杉矶站群服务器怎么样?美国洛杉矶站群服务器如何提高网站排名?Rak部落小编为您整理发布美国洛杉矶站群服务器如何提高网站排名? 美国洛杉矶站群服务器可以通过以下几种方式帮助提高网站排名: - **提升网站性能**:美国站群服务器通常配备高速CPU…

python-pytorch官方示例Generating Names with a Character-Level RNN的部分理解0.5.03

pytorch官方示例Generating Names with a Character-Level RNN的部分理解 模型结构功能关键技术模型输入模型输出预测实现 模型结构 功能 输入一个类别名和一个英文字符,就可以自动生成这个类别,且以英文字符开始的姓名 关键技术 将字符进行one-hot编…

抖音小店怎么做?新店铺起店就做这3步,核心玩法来了

大家好,我是电商笨笨熊 做抖音小店迟迟不起店,店铺一直没有销量怎么办? 新店铺玩家前期一定都遇到过这种烦恼,毫无头绪不知道该从哪入手; 实际上,想要店铺快速起店,只需要做对三步就够了。 作…

基于Rust的多线程 Web 服务器

构建多线程 Web 服务器 在 socket 上监听 TCP 连接解析少量的 HTTP 请求创建一个合适的 HTTP 响应使用线程池改进服务器的吞吐量优雅的停机和清理注意:并不是最佳实践 创建项目 ~/rust ➜ cargo new helloCreated binary (application) hello package~/rust ➜ma…

一 SSM 整合理解

SSM整合理解 一 SSM整合什么 ​ 以spring框架为基础,整合springmvc,mybatis框架,以更好的开发。 ​ spring管理一切组件,为开发更好的解耦,以及提供框架的组件,如aop,tx。springmvc是表述层框…

Bytebase 2.16.0 - 支持 Oracle 和 SQL Server DML 变更的事前备份

🚀 新功能 支持 Oracle 和 SQL Server DML 变更的事前备份。 支持在 SQL 编辑器中显示存储过程和函数。 支持兼容 TDSQL 的 MySQL 和 PostgreSQL 版本。 支持把数据库密码存储在 AWS Secrets Manager 和 GCP Secret Manager。 支持通过 IAM 连接到 Google Clou…

积极应对半导体测试挑战 加速科技助力行业“芯”升级

在全球半导体产业高速发展的今天,中国“芯”正迎来前所未有的发展机遇。AI、5G、物联网、自动驾驶、元宇宙、智慧城市等终端应用方兴未艾,为测试行业带来新的市场规模突破点,成为测试设备未来重要的增量市场。新兴领域芯片产品性能不断提升、…

如何有效的将丢失的mfc140u.dll修复,几种mfc140u.dll丢失的解决方法

当你在运行某个程序或应用程序时,突然遭遇到mfc140u.dll丢失的错误提示,这可能会对你的电脑运行产生一些不利影响。但是,不要担心,以下是一套详细的mfc140u.dll丢失的解决方法。 mfc140u.dll缺失问题的详细解决步骤 步骤1&#x…

通过一篇文章让你了解STL是什么

STL 前言一、什么是STL二、STL的版本原始版本P. J. 版本RW版本SGI版本 三、STL的六大组件四、STL的重要性试题面经 五、如何学习STL六、STL的缺陷 前言 STL(Standard Template Library)是C编程语言的一个标准库,包含了一系列模板类和函数&am…

Jmeter之Beanshell详解

一、 Beanshell概念 Beanshell: BeanShell是一种完全符合Java语法规范的脚本语言,并且又拥有自己的一些语法和方法;BeanShell是一种松散类型的脚本语言(这点和JS类似);BeanShell是用Java写成的,一个小型的、免费的、可以下载的、嵌入式的Java源代码解释器,具有对象脚本语言特性…

RGB灯珠的控制-单片机通用模板

RGB灯珠的控制-单片机通用模板 一、RGB控制的原理二、RGB.c的实现三、RGB.h的实现四、color色彩空间变换以及控制渐变一、RGB控制的原理 ①通过IO发送脉冲识别0/1编码,组合24Bit的RGB数据,从而控制RGB;②每个RGB灯珠通过DIN、DOU进行级联起来;③通过HSV色彩转换成RGB从而控…

BUUCTF--web(2)

1、[HCTF 2018]admin1 打开题目后发现有注册和登录两个页面,因为题目提示admin,尝试用admin进行爆破 爆破得到密码为123 登录得到flag 2、[护网杯 2018]easy_tornado1 打开题目后有三个文件,分别打开查看 在url地址栏中发现包含两个参数&a…

[图解]领域驱动设计伪创新-为什么互联网是重灾区-01

0 00:00:00,840 --> 00:00:03,270 今天我们来说一下 1 00:00:03,650 --> 00:00:06,255 领域驱动设计伪创新 2 00:00:06,255 --> 00:00:08,860 为什么互联网是重灾区 3 00:00:09,500 --> 00:00:12,610 这个我们分几个视频来讲 4 00:00:15,620 --> 00:00:17,5…

FIB和RIB基础

1.思考以下的topo从数据层面和控制层面分别是如何通信的 (1)数据层面;数据包从PC1经过AR1 AR2最后到达PC2,这就是数据层面的通信。 (2)控制层面:PC2所在的网段192.168.2.0/24是经过AR2传递给AR…