Swin Transformer

Swin Transformer

简介

image-20230321183426196

  • 下采样的层级设计,能够逐渐增大感受野。
  • 采用window进行注意力计算,极大降低了内存消耗,避免了整张图像尺寸大小的qkv矩阵
  • 滑窗操作包括不重叠的 local window,和重叠的 cross-window。不重叠的local windows将注意力计算限制在一个窗口(window size固定),而cross-windows则让不同窗口之间信息可以进行关联,实现了信息的交互。

整体架构

930f1a33661f56ef6e4bb0bab3062769_3_Figure_3

  1. Patch Partition结构:将图像切分重排,并进行embedding
  2. Patch Merging结构:下采样方法,实现层次化结构
  3. Swin Transformer Block:一个W-MSA ,一个SW-MSA,也即是一个window-多头注意力机制和一个shift-windows多头注意力机制,实现将自注意力机制限制在一个windows中进行计算,同时,通过shift-window解决限制在一个windows中后,不同windows之间无信息共享的问题。

Patch Embedding

在图像切分重排中,采用的是使用patch size大小的conv2d进行实现

class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding图像切分重排

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

Patch Merging

class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
    

img

SW-MSA设计

如下所示,w-msa mask避免窗口5和窗口3进行相似度计算,通过mask只在窗口内部进行计算。

通过对特征图移位,并给Attention设置mask来间接实现的。能在保持原有的window个数下,最后的计算结果等价

2023-11-18_10-20-26

2023-11-18_10-23-41

Window Attention

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q K T d + B ) V Attention(Q,K,V)=Softmax(\frac{QK^T}{\sqrt{d}}+B)V Attention(Q,K,V)=Softmax(d QKT+B)V

相对位置编码

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

img

对于相对位置编码,在2维坐标系中,当偏移从0开始时,(2,1)和(1,2)相对(0,0)的位置编码是不同的,而转为1维坐标后,却是相同数值,为了解决这个问题,采用对x坐标2 * self.window_size[1] - 1操作,从而进行区分。而该相对位置编码需要2 * self.window_size[1] - 1编码数值。

A Survey of Transformers

图解Swin Transformer - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/164719.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

volatile 无法保证原子性 案例展示

volatile 无法保证原子性 在 Java 中,原子性是指一个操作是不可中断的,要么都执行要么都不执行。 但是 volatile 修饰的变量,只是保证了从主内存加载到工作内存的值是最新的,并不能保证对变量的操作是原子性的 变量的写操作和读…

关于缓存和数据库一致性问题的深入研究

如何保证缓存和数据库一致性,这是一个老生常谈的话题了。 但很多人对这个问题,依旧有很多疑惑: 到底是更新缓存还是删缓存?到底选择先更新数据库,再删除缓存,还是先删除缓存,再更新数据库&…

Spring Boot中实现支付宝、微信和银联支付的功能

Spring Boot中实现支付宝、微信和银联支付的功能 在Spring Boot中实现支付宝、微信和银联支付的功能,通常需要使用它们各自的SDK(Software Development Kit)。以下是一个简单的示例代码,演示了如何在Spring Boot项目中集成支付宝…

操作系统:操作系统教程第六版(骆斌、葛季栋、费翔林)习题一计算机操作系统概述

目录 前言1. 思考题2. 应用题 前言 本系列文章是针对操作系统教程第六版(骆斌、葛季栋、费翔林)的习题解答,其中简答题部分为博主自己搜索整理的,错漏之处在所难免。应用题部分有答案为依据。 1. 思考题 (1&#xf…

vscode设置代码模板

一键生成vue3模板代码 效果演示 输入vue3 显示快捷键 按回车键 一键生成自定义模板 实现方法 进入用户代码片段设置 选择片段语言 vue.json输入自定义的代码片段 prefix是触发的内容,按自己的喜好来就行; body是模板代码,写入自己需要的…

java springboot 在测试类中声明临时Bean对象

上文 java springboot在当前测试类中添加临时属性 不影响application和其他范围 中 我们讲了怎么在测试类中设置临时属性 但是 如果我们想设置临时的Bean呢? 其实做过几个项目的人都会理解 我们很多功能 需要一些第三方bean才能完成 那么 我们可能存在需要用第三方b…

MySQL 的执行原理(四)

5.5. MySQL 的查询重写规则 对于一些执行起来十分耗费性能的语句,MySQL 还是依据一些规则,竭尽全力的把这个很糟糕的语句转换成某种可以比较高效执行的形式,这个过程也可以 被称作查询重写。 5.5.1. 条件化简 我们编写的查询语句的搜索条件…

SpringMVC总结

SpringMVC简介 简介 SpringMVC是一款基于Servlet API构建的原始Web框架,从一开始就包含在Spring Framework中。正式名称“Spring Web MVC”来自其源模块的名称( spring-webmvc ),但它通常被称为“Spring MVC”。 调用流程 接收数…

清华 ChatGLM-6B 中文对话模型部署简易教程

本文目录 一、前言二、下载三、部署3.1 配置环境3.2 启动 demo 程序3.2.1 启动 cli_demo.py3.2.2 启动 web_demo.py 四、【最新】ChatGLM-6B-int4 版本教程4.1 下载4.2 配置环境4.3 启动 demo 程序 五、华为 CodeLab 免费 GPU 平台部署 ChatGLM-6B-int4(CPU版&#…

RabbitMQ消息的可靠性

RabbitMQ消息的可靠性 一 生产者的可靠性 生产者重试 有时候由于网络问题,会出现连接MQ失败的情况,可以配置重连机制 注意:SpringAMQP的重试机制是阻塞式的,重试等待的时候,当前线程会等待。 spring:rabbitmq:conne…

Blowfish在线加密解密调试校验工具

具体请前往:在线Blowfish加密解密工具

java游戏制作-拼图游戏

一.制作主界面 首先创建一个Java项目命名为puzzlegame。 再在src中创建一个包,用来制作主界面 代码: 结果: 二.设置界面 代码: 三.初始化界面 代码: 优化代码: 结果: 四.添加图片 先在Java项…

二叉搜索树的最近公共祖先

235. 二叉搜索树的最近公共祖先 给定一个二叉搜索树, 找到该树中两个指定节点的最近公共祖先。 百度百科中最近公共祖先的定义为:“对于有根树 T 的两个结点 p、q,最近公共祖先表示为一个结点 x,满足 x 是 p、q 的祖先且 x 的深度尽可能大&…

网络协议入门 笔记一

一、服务器和客户端及java的概念 JVM (Java Virtual Machine) : Java虚拟机,Java的跨平台:一次编译,到处运行,编译生成跟平台无关的字节码文件 (class文件),由对应平台的JVM解析字节码为机器指令 (010101)。 如下图所示&#xff0…

lv11 嵌入式开发 ARM指令集中(汇编指令集) 6

目录 1.指令 1.1 数据处理指令:数学运算、逻辑运算 1.1.1数据搬移指令MOV 、MVN 1.1.2立即数 1.1.3 加法指令 1.1.4 减法指令 1.1.5 逆向减法指令 1.1.6 乘法指令 1.1.7 与、或、非、异或、左移、右移指令 1.1.8 位清零指令 1.1.9 格式扩展 1.1.10 数据运算指令对条件位…

Ubuntu20.0中安装Gradle

下载Gradle到temp文件夹 wget https://services.gradle.org/distributions/gradle-8.3-bin.zip -P /tmp 然后解压文件到/opt/gradle目录 sudo unzip -d /opt/gradle /tmp/gradle-8.3.zip 配置Gradle环境变量 接下来我们会创建一个gradle.sh文件来保存Gradle的环境变量 sudo…

文心一言 VS 讯飞星火 VS chatgpt (138)-- 算法导论11.4 2题

二、用go语言,试写出HASH-DELETE 的伪代码;修改 HASH-INSERT,使之能处理特殊值DELETED。 文心一言: 首先,让我们定义一个基本的哈希表数据结构。这个结构将包括一个存储键值对的哈希表和一个存储已删除键值对的队列。我们可以用…

RobotFramework进阶之自定义的python模块(十四)

引言 RobotFramework自动化框架(以下简称RF)之前文章我们讲了通过import第三方的library(RequestsLibrary等),在实际项目中第三方的包并不能满足我们的需要,此时我们可自己编写python模块(.py文…

穷举法、回溯法、分支界限法解决旅行商(TSP)问题

文章目录 一、问题描述二、穷举法解决2.1 介绍2.2 代码 三、回溯法解决四、分支界限法4.1 介绍4.2 代码 一、问题描述 有一个旅行商由某城市出发,经过所有给定的 n n n 个城市后,再回到出发的城市。除了出发的城市外,其它城市只经过一回。这…

2018年计网408

第33题 下列 TCP/P应用层协议中, 可以使用传输层无连接服务的是()A. FTPB. DNSC. SMTPD. HTTP 本题考察TCP/IP体系结构中,应用层常用协议所使用的运输层服务。 如图所示。这是TCP/IP体系结构中常见应用层协议各自所使用的运输层端口,。在这些应用层协议中&#x…