CvT(ICCV 2021)论文与代码解读

paper:CvT: Introducing Convolutions to Vision Transformers

official implementation:https://github.com/microsoft/CvT

出发点

该论文的出发点是改进Vision Transformer (ViT) 的性能和效率。传统的ViT在处理图像分类任务时虽然表现出色,但在数据量较小的情况下,其表现不如同等规模的卷积神经网络(CNN)。研究人员认为这是因为ViT缺乏CNN固有的一些有利特性,如对局部空间信息的捕捉能力。本文提出通过在ViT结构中引入卷积操作来弥补这一不足,以获得更好的性能和鲁棒性。

创新点

本文解决了如何在保持ViT优点(如动态注意力机制、全局上下文建模和更好的泛化能力)的同时,引入卷积神经网络的优点(如局部感受野、权重共享和空间下采样)。具体来说,论文通过引入卷积的方式来增强ViT的局部信息捕捉能力和计算效率,从而在各种图像分类任务中取得更好的表现。具体如下

  1. 卷积token embedding层:在ViT的结构中引入卷积embedding层,通过卷积操作将图像转换为token,同时保留局部空间信息。这种方法使模型能够在多个阶段逐步减少令token序列长度,同时增加token特征维度,类似于CNN的设计。
  2. 卷积projection:标准Transformer模块中的线性投影替换为卷积投影。通过深度可分离卷积操作,进一步捕捉局部空间上下文,并减少注意力机制中的语义模糊性。此外,卷积投影的步幅可用于对key和value矩阵进行下采样,从而显著提高计算效率。
  3. 无需位置编码:实验表明,CvT模型可以在不使用位置编码的情况下取得良好的性能,这简化了模型设计,尤其适用于处理高分辨率图像任务。

方法介绍

CvT的整体pipeline如图2所示。作者将两种基于卷积的operation引入Vision Transformer中,即Convolutional Token Embedding和Convolutional Projection。如图2(a)所示,借鉴了CNN采用了一个多个stage的层级设计,本文一共包含三个stage。每个stage包括两部分,首先输入图片(或reshape后的二维token map)经过Convolutional Token Embedding层的处理,具体是通过一个重叠的卷积实现。这使得每个stage可以逐渐减少token的数量(即特征分辨率)并增加token的宽度(即特征的维度),从而实现空间降采样并增加特征表示的丰富性。和之前的各种视觉Transformer不同,本文在这里并没有加上一个位置编码。接下来是堆叠的多个本文提出的Convolutional Transformer Block如图2(b)所示, 其中一个深度可分离卷积作为卷积投影分别作用于query、key和value。class token只在最后一个stage添加,最后通过一个MLP head得到最终的输出预测类别。 

Convolutional Token Embedding

给定一张图片或前一个stage输出并reshape成二维的token map \(x_{i-1}\in \mathbb{R}^{H_{i-1}\times W_{i-1}\times C_{i-1}}\) 作为当前stage \(i\) 的输入,我们学习一个卷积 \(f(\cdot)\) 将 \(x_{i-1}\) 映射到新的token \(f(x_{i-1})\),卷积核大小为 \(s\times s\),步长为 \(s-o\),padding为 \(p\)。新的token map \(f(x_{i-1})\in \mathbb{R}^{H_i\times W_i\times C_i}\) 的高和宽分别为

\(f(x_{i-1})\) 然后展平成 \(H_iW_i\times C_i\) 的shape并经过layer normalization处理,然后作为输入到stage \(i\) 的后续transformer block中。

Convolution Token Embedding层使得我们可以通过调整卷积的参数来调整每个stage的token特征维度和数量。通过这种方式,每个stage我们逐渐减少token序列的长度同时增加token特征的维度,使得token能够在越来越大的空间中表示越来越复杂的视觉模式,类似于CNN的特征层。

Convolutional Projection for Attention

本文提出的卷积映射层的目的是实现对局部context的额外建模,并通过对 \(K\) 和 \(V\) 矩阵降采样来提高效率。

图3(a)展示了ViT中使用的position-wise线性投影,图3(b)展示了本文提出的 \(s\times s\) 卷积投影。如图3(b)所示,tokens首先reshape成一个2D token map,然后通过一个深度可分离卷积实现卷积投影。最后再将projected tokens展平成1D作为后续的输入,如下

其中 \(x_i^{q/k/v}\) 是 \(i\) 层 \(Q/K/V\) 矩阵的token输入,\(conv2d\) 是一个深度可分离卷积具体实现为:\(Depthwise\ Con2d\rightarrow BatchNorm2d\rightarrow Pointwise\ Conv2d\),\(s\) 表示卷积核大小。原始的position-wise线性投影可以通过1x1卷积实现,因此这里新的卷积投影可以看作是一种推广。

实验结果

作者设计三种不同size的模型如表2所示,其中CvT-X中的X表示模型总共的transformer block的数量。CvT-224中的W表示Wide。

表3是在ImageNet数据集上和其它SOTA模型的对比。

代码解析

这里的代码是官方实现,convolutional token embedding的代码如下,在每个stage的开始都会首先经过ConvEmbed,以cvt-13为例,一共3个stage,patch_size=[7, 3, 3],patch_stride=[4, 2, 2],patch_padding=[2, 1, 1]。

class ConvEmbed(nn.Module):
    """ Image to Conv Embedding

    """

    def __init__(self,
                 patch_size=7,
                 in_chans=3,
                 embed_dim=64,
                 stride=4,
                 padding=2,
                 norm_layer=None):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_chans, embed_dim,
            kernel_size=patch_size,
            stride=stride,
            padding=padding
        )
        self.norm = norm_layer(embed_dim) if norm_layer else None

    def forward(self, x):
        x = self.proj(x)

        B, C, H, W = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        if self.norm:
            x = self.norm(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)

        return x

Attention的代码如下,在forward函数中会首先调用forward_conv得到q、k、v,这里的forward_conv就是本文提出的conv projection,在函数_build_projection中method='dw_bn',因此三个投影都是通过深度可分离卷积实现的。在self.forward_conv后就是普通的计算attention的过程了。

class Attention(nn.Module):
    def __init__(self,
                 dim_in,
                 dim_out,
                 num_heads,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 method='dw_bn',
                 kernel_size=3,
                 stride_kv=1,
                 stride_q=1,
                 padding_kv=1,
                 padding_q=1,
                 with_cls_token=True,
                 **kwargs
                 ):
        super().__init__()
        self.stride_kv = stride_kv
        self.stride_q = stride_q
        self.dim = dim_out
        self.num_heads = num_heads
        # head_dim = self.qkv_dim // num_heads
        self.scale = dim_out ** -0.5
        self.with_cls_token = with_cls_token

        self.conv_proj_q = self._build_projection(
            dim_in, dim_out, kernel_size, padding_q,
            stride_q, 'linear' if method == 'avg' else method
        )
        self.conv_proj_k = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )
        self.conv_proj_v = self._build_projection(
            dim_in, dim_out, kernel_size, padding_kv,
            stride_kv, method
        )

        self.proj_q = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_k = nn.Linear(dim_in, dim_out, bias=qkv_bias)
        self.proj_v = nn.Linear(dim_in, dim_out, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_out, dim_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def _build_projection(self,
                          dim_in,
                          dim_out,
                          kernel_size,
                          padding,
                          stride,
                          method):
        if method == 'dw_bn':
            proj = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(
                    dim_in,
                    dim_in,
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    bias=False,
                    groups=dim_in
                )),
                ('bn', nn.BatchNorm2d(dim_in)),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'avg':
            proj = nn.Sequential(OrderedDict([
                ('avg', nn.AvgPool2d(
                    kernel_size=kernel_size,
                    padding=padding,
                    stride=stride,
                    ceil_mode=True
                )),
                ('rearrage', Rearrange('b c h w -> b (h w) c')),
            ]))
        elif method == 'linear':
            proj = None
        else:
            raise ValueError('Unknown method ({})'.format(method))

        return proj

    def forward_conv(self, x, h, w):
        if self.with_cls_token:
            cls_token, x = torch.split(x, [1, h*w], 1)

        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

        if self.conv_proj_q is not None:
            q = self.conv_proj_q(x)
        else:
            q = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_k is not None:
            k = self.conv_proj_k(x)
        else:
            k = rearrange(x, 'b c h w -> b (h w) c')

        if self.conv_proj_v is not None:
            v = self.conv_proj_v(x)
        else:
            v = rearrange(x, 'b c h w -> b (h w) c')

        if self.with_cls_token:
            q = torch.cat((cls_token, q), dim=1)
            k = torch.cat((cls_token, k), dim=1)
            v = torch.cat((cls_token, v), dim=1)

        return q, k, v

    def forward(self, x, h, w):
        if (
            self.conv_proj_q is not None
            or self.conv_proj_k is not None
            or self.conv_proj_v is not None
        ):
            q, k, v = self.forward_conv(x, h, w)

        q = rearrange(self.proj_q(q), 'b t (h d) -> b h t d', h=self.num_heads)
        k = rearrange(self.proj_k(k), 'b t (h d) -> b h t d', h=self.num_heads)
        v = rearrange(self.proj_v(v), 'b t (h d) -> b h t d', h=self.num_heads)

        attn_score = torch.einsum('bhlk,bhtk->bhlt', [q, k]) * self.scale
        attn = F.softmax(attn_score, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum('bhlt,bhtv->bhlv', [attn, v])
        x = rearrange(x, 'b h t d -> b t (h d)')

        x = self.proj(x)
        x = self.proj_drop(x)

        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/697057.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

风能远程管理ARMxy嵌入式系统深度解析

智能技术正以前所未有的速度融入传统能源管理体系,而ARMxy工业计算机作为这一变革中的关键技术载体,正以其独特的性能优势,为能源管理的智能化升级铺设道路。本文将聚焦于智能电表、太阳能电站监控、风力发电站远程管理三大应用场景&#xff…

央视频官方出品,AI高考智友助你成就高考梦想

大家好,我是小麦。今天分享一款由央视频官方出品的AI工具套件,不仅支持直接使用,同时还具备了开发能力,是一款非常不错的AI产品工具,该软件的名称叫做扣子。 扣子是新一代 AI 应用开发平台。无论你是否有编程基础&…

【Java探索之旅】继承结构 继承和组合 protected final

文章目录 📑前言一、继承1.1 继承关系的代码块1.2 protected关键字1.3 继承方式1.4 final关键字1.5 继承与组合 🌤️全篇总结 📑前言 在面向对象编程中,继承是一种重要的概念,它允许我们创建一个新类,从现有…

全局异常处理器

后端: 全局异常处理器的作用: 当我们在项目中碰到很多不同的异常情况时,我们需要去处理异常 不过我们不可能每个异常都用try/catch,那样很不优雅 所以我们可以用这个全局异常处理器,来优雅的处理异常 这个全局异常…

AI大模型日报#0610:港大等1bit大模型“解决AI能源需求”、谷歌开源TimesFM时序预测模型

导读:AI大模型日报,爬虫LLM自动生成,一文览尽每日AI大模型要点资讯!目前采用“文心一言”(ERNIE 4.0)、“零一万物”(Yi-Large)生成了今日要点以及每条资讯的摘要。欢迎阅读&#xf…

43【PS 作图】颜色速途

1 通过PS让画面细节模糊,避免被过多的颜色干扰 2 分析画面的颜色 3 作图 参考网站: 色感不好要怎么提升呢?分享一下我是怎么练习色感的!_哔哩哔哩_bilibili https://www.bilibili.com/video/BV1h1421Z76p/?spm_id_from333.1007.…

OpenGL绘制简单图形

绘制了一个紫色矩形和一个三角形&#xff0c;代码如下&#xff1a; #include <Windows.h> #include <gl/glut.h> void display(void) {glClearColor(0.0f, 0.0f, 0.0f, 1.0f); //设置清屏颜色glClear(GL_COLOR_BUFFER_BIT); //刷新颜色缓冲区&#xff1b;glColor3f…

QSlider样式示例

参考代码&#xff1a; /********************QSlider横向滑动条样式**********************/ QSlider {background-color: rgba(170, 255, 255, 100); /* 设置滑动条主体*/ }QSlider::groove:horizontal {border: 1px solid #999999;height: 8px; /* 默认…

力扣 42. 接雨水

题目来源&#xff1a;https://leetcode.cn/problems/trapping-rain-water/description/ C题解1&#xff1a;双指针 按列算&#xff0c;一列一列的求雨水面积。使用双指针是记录当前列左右侧的最大元素。 class Solution { public:int trap(vector<int>& height) {in…

运维一个宝塔面板的php项目的艰辛历程【解决了http3,ssl,quic】

在这个项目的环境 使用了宝塔面板 有4个php:php5.6,php7.3,php7.4,php8.0 nignx为1.20版本 升级计划&#xff1a; 升级nginx1.26.0版本&#xff0c;添加上http3协议&#xff0c;添加ssl证书 遇到的问题&#xff1a; 升级nginx1.26版本后 无法打开php5.6的后台 原因&#xff…

力扣hot100:295. 数据流的中位数(两个优先队列维护中位数)

LeetCode&#xff1a;295. 数据流的中位数 这个题目最快的解法应该是维护中位数&#xff0c;每插入一个数都能快速得到一个中位数。 根据数据范围&#xff0c;我们应当实现一个 O ( n l o g n ) O(nlogn) O(nlogn)的算法。 1、超时—插入排序 使用数组存储&#xff0c;维持数…

MySQL数据库(二)和java复习

一.MySQL数据库学习(二) (一).DQL查询数据 DQL&#xff08;Data Query Language&#xff09;是用于从数据库中检索数据的语言。常见的 DQL 语句包括 SELECT、FROM、WHERE、GROUP BY、HAVING 和 ORDER BY 等关键字&#xff0c;用于指定要检索的数据、数据源、过滤条件、分组方…

ROS云课三分钟外传之CoppeliaSim_Edu_V4_1_0_Ubuntu16_04

三分钟热度试一试吧&#xff0c;走过路过不要错过。 参考之前&#xff1a; 从云课五分钟到一分钟之v-rep_pro_edu_v3_6_2-CSDN博客 git clone https://gitcode.net/ZhangRelay/v-rep_pro_edu_v3_6_2_ubuntu16_04.gittar -xf v-rep_pro_edu_v3_6_2_ubuntu16_04/V-REP_PRO_EDU…

字符串常量池字符串常量的几种创建方式及其位置

从JDK7开始&#xff0c;字符串常量池被移到了堆区中&#xff0c;因此Java程序中的字符串常量对象要么在堆区的字符串常量池之中&#xff0c;要么在堆区的字符串常量池之外。为了做区分&#xff0c;下文将堆区的字符串常量池区域称为字符串常量池&#xff0c;将堆区字符串常量池…

Zabbix配置中文显示及乱码问题

页面配置为中文显示 在zabbix 5.0版本开始用户菜单更改为左侧栏显示&#xff0c;找到并点击 User Settings&#xff0c;Language 修改语言为 Chinese (zh_CN) 即可。 PS&#xff1a;一般在部署后初始配置时&#xff0c;未找到 Chinese (zh_CN) 这一项&#xff0c;修改如下&…

分享一个 .NET Core Console 项目中应用 NLog 写日志的详细例子

前言 日志在软件开发中扮演着非常重要的角色&#xff0c;通常我们用它来记录应用程序运行时发生的事件、错误信息、警告以及其他相关信息&#xff0c;帮助在调试和排查问题时更快速地定位和解决 Bug。 通过日志&#xff0c;我们可以做到&#xff1a; 故障排除和调试&#xff…

探索智慧景区的总体架构与应用

背景&#xff1a; 在旅游业快速发展的今天&#xff0c;智慧景区已成为提升景区管理水平、提高游客体验的重要手段之一。智慧景区系统的总体架构设计与应用&#xff0c;将现代信息技术与景区管理相结合&#xff0c;为景区的运营管理和游客服务提供了新的思路和解决方案。本文将…

按键精灵在Win11中弹窗出现乱码并且自带的部分系统插件不能使用的解决方法

按键精灵中出现以下问题&#xff1a; 提示信息的弹窗出现乱码&#xff1a; 系统自带的部分像 plugin. 开头的插件不能使用&#xff0c;如下&#xff1a;s Plugin.Sys.GetDateTime() screenX Plugin.GetSysInfo.GetScreenResolutionX screenY Plugin.GetSysInfo.GetScreenRe…

在Linux or Windows中如何优雅的写出对拍

在Linux or Windows中如何优雅的写出对拍 一、前言二、结论1、对拍 三、对拍详解1、什么是对拍呢&#xff1f;&#x1f9d0;2、对拍的组成部分3、输入数据生成4、对拍程序5、操作流程 四、最后 一、前言 网上的对拍程序层出不穷&#xff0c;大多Linux和Windows中的对拍程序都是…

MySQL 函数与约束

MySQL 函数与约束 文章目录 MySQL 函数与约束1 函数1.1 字符串函数1.2 数值函数1.3 日期函数1.4 流程函数 2 约束2.1 概述2.2 约束演示2.3 外键约束2.4 删除/更新行为 1 函数 函数是指一段可以直接被另一程序调用的程序或代码。 1.1 字符串函数 MySQL中内置了很多字符串函数&…