ConvMixer 论文与代码解析

paper:Patches Are All You Need?

official implementation:https://github.com/locuslab/convmixer

精度上去了,推理速度只有卷积和ViTs的四分之一!

出发点

文章讨论了卷积神经网络(CNN)在视觉任务中的主导地位,以及近期基于Transformer模型的架构(特别是Vision Transformer,ViT)在某些情况下可能超越了CNN的性能。ViT由于自注意力层的二次运行时间复杂度,需要使用patch embeddings来处理更大的图像尺寸。

作者探讨了ViT的高性能是否源于Transformer架构本身的强大能力,还是部分归因于使用patches作为输入表示。

创新点

文章提出了一个新的模型——ConvMixer,这是一个非常简单的模型,它直接在patches上操作输入,分离空间和通道维度的混合,并在整个网络中保持相同的尺寸和分辨率。ConvMixer使用标准的卷积来实现混合步骤,而不是Transformer架构。

  • 模型设计:ConvMixer的设计灵感来自于ViT和MLP-Mixer,但它只使用标准的卷积操作来处理输入patches。
  • 简化架构:与ViT和MLP-Mixer相比,ConvMixer通过简化架构,减少了模型的复杂性。
  • 性能与效率:尽管ConvMixer的实现非常简单,但作者展示了它在相似参数计数和数据集大小下,性能超过了ViT、MLP-Mixer以及传统的视觉模型如ResNet。

通过ConvMixer的性能表现,作者认为patch embeddings(图像分块嵌入)可能是导致新型架构(如Vision Transformers)性能提升的一个关键因素。通过在网络的初始阶段一次性完成所有的下采样,即减小内部分辨率并增加有效感受野大小,有助于混合远距离的空间信息。

此外,ConvMixer提供了一个强大的“等距”(isotropic)架构模板,该架构通过简单的patch embeddings stem实现,这为深度学习提供了一个有效的框架

方法介绍

如图2所示,ConvMixer的结构非常简单,包括一个patch embedding层,然后重复堆叠一个简单的全卷积block。在patch embedding后保持空间分辨率一直到网络结束,对于patch size为 \(p\) embdding维度为 \(h\) 的patch embedding层可以通过一个输入通道数为 \(c_{in}\),输出通道数为 \(h\),kernel size为 \(p\),stride为 \(p\) 的卷积实现

ConvMixer block由一个depthwise convolution和一个pointwise convolution组成。MLP和self-attention可以mix distance spatial locations,即具有很大或者全局的感受野从而可以捕获长距离依赖关系,受此启发,ConvMixer中的深度卷积采用了非常大的卷积核,比如7或 9。在每个卷积后都有一个激活函数和一个post-activation BatchNorm

最后通过一个global average pooling和一个softmax classifier得到最终的分类预测结果。

实验结果

ConvMixer不同大小的模型通过ConvMixer-h/d来命名,其中h表示hidden dimension即patch embedding的维度,d表示网络深度即图2中ConvMixer Layer的数量。和三种不同架构的代表性网络在ImageNet上的性能对比如表1所示,注意这里ConvMixer没有经过专门的调参,训练配置都是直接采用ResNet和DeiT中的一些常规设置,并且训练Epoch也更短。可以看到ConvMixer取得了几句竞争力的结果,同时参数量也很少,但存在一个非常大的缺点,吞吐很小或者说推理速度很慢。

但这里有一个问题就是ConvMixer的patch size非常小只有DeiT的一半,比ResMLP-B24/8还小1,这种情况下比较是不公平的。如果增大patch size,就达不到卷积网络或ViTs的精度,减小的patch size精度上去了延迟只有卷积和ViTs的四分之一,表明ConvMixer和MLP类的网络还存在局限。

代码解析

这里以timm中的实现为例,输入大小为(1, 3, 224, 224),模型选择"convmixer_768_32",具体配置如下

model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs)

其中核心部分block的实现非常简单,如下,每层就是7x7 depthwise conv + ReLU + BN + 1x1 conv + ReLU + BN,此外还是用了residual connection。

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


self.blocks = nn.Sequential(
            *[nn.Sequential(
                    Residual(nn.Sequential(
                        nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                        act_layer(),
                        nn.BatchNorm2d(dim)
                    )),
                    nn.Conv2d(dim, dim, kernel_size=1),
                    act_layer(),
                    nn.BatchNorm2d(dim)
            ) for i in range(depth)]
        )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/748914.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

#### 广告投放 ####

以巨量引擎为例: 计费模式 eCPM(expected Cost Per Mile,估计千次展示收入) 概括: ecpm为千次展示的预估收益,是广告平台用来给广告排序的指标。 注意是展示而不是千次点击收益,展示了可能不…

从0到1:亮数据浏览器,为数据采集工作注入全新动力

亮数据浏览器提升数据采集效率 一、 导言1.1 引入亮数据浏览器的重要性1.2 简要介绍本文将涉及的主题和内容 二、 亮数据浏览器简介2.1. 什么是亮数据浏览器2.2. 亮数据浏览器的特点和优势 三、优化数据采集的核心功能3.1 自动化数据采集3.1.1 通过亮数据浏览器实现自动化数据采…

LangChain入门之 GPT 和小范大人不太熟?

前言 嗨,大家好!我是海鸽。 《庆余年2》刚刚完结,热度不减,我忍不住好奇:我们的AI伙伴GPT,是否也对剧中那位机智过人的小范大人有所耳闻? 不仅如此,最近我们还尝试了LangChain的调…

Xcode安装Simulator失败问题解决方法

Xcode安装Simulator_Runtime失败,安装包离线安装保姆级教程 Xcode更新之后有时候会提示要安装模拟器运行时环境,但是用Xcode更新会因为网络原因,我觉得基本上就是因为苹果服务器的连接不稳定导致的,更可气的是不支持断点续…

介绍几种 MySQL 官方高可用方案

前言: MySQL 官方提供了多种高可用部署方案,从最基础的主从复制到组复制再到 InnoDB Cluster 等等。本篇文章以 MySQL 8.0 版本为准,介绍下不同高可用方案架构原理及使用场景。 1.MySQL Replication MySQL Replication 是官方提供的主从同…

记录dinky0.6.7+flink1.14.5集成问题

先说一句mmp,这个jar包冲突搞吐我。如果有遇到math3问题需要注意少个包 看相关issue 以下为flink的lib目录 一、yarn-application和perjob模式 yarn session模式不依赖dlink-app-1.14-0.6.7-jar-with-dependencies.jar这个包,。但是yarn-application…

新能源行业知识体系-------蒙西电网需求侧响应

新能源行业知识体系-------主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/139946830 目录 一、背景介绍二、需求响应电能量收益介绍三、超额回收需求响应减免收益介绍四、参与需求侧响应五、蒙西电力现货特点六、交易中…

1012:Joseph

网址如下&#xff1a; OpenJudge - 1012:Joseph 其中一个解法 只想到了一个快速找到下一位处决的人的方法&#xff0c;本质上还是遍历&#xff0c;暂时没想到更优的方法了 代码如下&#xff1a; #include<cstdio> int k;bool judge(int tt, int m, int r){if(tt k) …

GPU技术全景:推动未来计算的新动力-4

7.中国厂家 在中国市场&#xff0c;也有几家本土企业在GPU领域崭露头角&#xff0c;虽然市场份额相对较小&#xff0c;但在国产替代和自主可控的浪潮下发展迅速&#xff0c;包括但不限于&#xff1a; •沐曦集成电路、壁仞科技、燧原科技、登临科技、摩尔线程等&#xff0c…

信号处理——时频分析

经典傅里叶变换的限制&#xff1a; 1、只能反映信号的整体特性&#xff1b;&#xff08;完全是时域或频域&#xff09; 2、要求信号满足平稳条件&#xff1b; 3、必须获得时域中的全部信息。 所以引入时频分析&#xff0c;同时使用时间和频率的联合函数来表示信号。 1 时频…

单段时间最优S型速度规划算法

一&#xff0c;背景 在做机械臂轨迹规划的单段路径的速度规划时&#xff0c;除了参考《Trajectory Planning for Automatic Machines and Robots》等文献之外&#xff0c;还在知乎找到了这位大佬 韩冰 写的在线规划方法&#xff1a; https://zhuanlan.zhihu.com/p/585253101/e…

Java基础知识-线程

Java基础知识-线程 1、在 Java 中要想实现多线程代码有几种手段&#xff1f; 1. 一种是继承 Thread 类 2. 另一种就是实现 Runnable 接口 3. 最后一种就是实现 Callable 接口 4. 第四种也是实现 callable 接口&#xff0c;只不过有返回值而已 2、Thread 类中的 start() 和 …

AI大模型会有意识的出千吗?

1. 引言 1.1 研究背景&#xff0c;AI系统中的规范游戏问题 在人工智能(AI)系统的发展过程中&#xff0c;规范游戏(specification gaming)一直是一个令研究者们头疼的问题。规范游戏指的是AI系统学习到一些意想不到的行为&#xff0c;这些行为虽然能够获得高奖励&#xff0c;但…

万字长文,解读大模型技术原理(非常详细)零基础入门到精通,收藏这一篇就够了

大模型是指具有大规模参数和复杂计算结构的机器学习模型。 本文从大模型的发展历程出发&#xff0c;对大模型领域的各个技术细节进行详细解读&#xff0c;供大家在了解大模型基本知识的过程中起到一定参考作用。 一、大模型的定义 大语言模型作为一个被验证可行的方向&#x…

客户案例|某 SaaS 企业租户敏感数据保护实践

近年来&#xff0c;随着云计算技术的快速发展&#xff0c;软件即服务&#xff08;SaaS&#xff09;在各行业的应用逐渐增多&#xff0c;SaaS 应用给企业数字化发展带来了便捷性、成本效益与可访问性&#xff0c;同时也带来了一系列数据安全风险。作为 SaaS 产品运营服务商&…

注意!!2024下《系统架构设计师》易混淆知识点来了,赶紧收藏

宝子们&#xff0c;在复习软考系统架构设计师中&#xff0c;是不是觉得有很多知识点含义比较相近&#xff0c;很多友友刚看的时候&#xff0c;估计会像我一样把它们弄混&#xff0c;作为一个软考老鸟&#xff0c;在这里给大家整理了系构学习过程中易混淆的知识点&#xff0c;大…

Part 8.3.2 树的直径

树的直径被定义为树上最远的两点间的距离。 关于求树的直径的两种方式 HXY造公园 题目描述 现在有一个现成的公园&#xff0c;有 n n n 个休息点和 m m m 条双向边连接两个休息点。众所周知&#xff0c;HXY 是一个 SXBK 的强迫症患者&#xff0c;所以她打算施展魔法来改造…

彩虹PLM系统:引领汽车行业的数字化转型

彩虹PLM系统&#xff1a;引领汽车行业的数字化转型 彩虹PLM系统作为汽车行业数字化转型的引领者&#xff0c;凭借其卓越的技术实力和丰富的行业经验&#xff0c;为汽车行业带来了全面的解决方案。以下是彩虹PLM系统如何引领汽车行业数字化转型的详细分析&#xff1a; 一、整合全…

虚拟机使用的是此版本 VMware Workstation 不支持的硬件版本

复制了同事的VMware镜像&#xff0c;但是他的软件版本和我的不同&#xff0c;于是乎出现了这个报错&#xff1a;虚拟机使用的是此版本 VMwareWorkstation 不支持的硬件版本。 模块“Upgrade”启动失败。 解决办法&#xff0c;直接改.vmx文件的版本信息&#xff1a; 以文本格式打…

ROS学习(17):定位和地图绘制(1)

目录 0.前言 1.定位和建图 1.里程计&#xff08;Odometry&#xff09; 2.扫描匹配&#xff08;Scan Matching&#xff09; 3.结尾 0.前言 好久不见各位&#xff0c;前段时间忙着考试&#xff08;6级和一些专业课&#xff09;和摆烂断更了近30天&#xff0c;现在哥们回来更…