DeiT:训练ImageNet仅用4卡不到3天的平民ViT | ICML 2021

论文基于改进训练配置以及一种新颖的蒸馏方式,提出了仅用ImageNet就能训练出来的Transformer网络DeiT。在蒸馏学习时,DeiT以卷积网络作为teacher,能够结合当前主流的数据增强和训练策略来进一步提高性能。从实验结果来看,效果很不错

来源:晓飞的算法工程笔记 公众号

论文: Training data-efficient image transformers & distillation through attention

  • 论文地址:https://arxiv.org/abs/2012.12877
  • 论文代码:https://github.com/facebookresearch/deit

Introduction


  Vision Transformer一般要先在大型计算设施上预训练数以亿计的图片才能有较好的性能,这极大地提高其应用门槛。为此,论文基于ViT提出了可在ImageNet上训练的Vision Transformer模型DeiT,仅需要一台电脑(4卡)训练不到三天(53小时的预训练和可选的20小时微调)的时间。在没有外部数据预训练的情况下,在ImageNet上达到了83.1% 的最高精度。

  此外,论文还提出了一种针对Transformer的蒸馏策略,通过一个蒸馏token确保student网络通过注意力从teacher网络那里进行学习。当使用卷积网络作为teacher网络时,ImageNet上可达到85.2%的准确性。

  总体而言,论文主要有以下贡献:

  • 通过实验表明,在没有外部数据的情况下,Vision Transformer也可以在ImageNet上达到SOTA的结果,而且仅需要4卡设备训练三天。
  • 论文提出了一种基于蒸馏token的新蒸馏方法,这种用于Transformer的蒸馏方法大幅优于一般蒸馏方法。蒸馏token与class token的作用相同,都参与注意力计算中,只是蒸馏token的训练目的在于复现teacher网络的标签预测。
  • 有趣的是,论文发现在使用新蒸馏方法时,用卷积网络作为teacher要比用另一个相同准确率的transformer的作为teacher的效果要好。
  • 在Imagenet上预训练的模型可以转移到不同的下游任务(如细粒度分类),得到很不错的性能。

Distillation through attention


Soft distillation

  一般的蒸馏方法都是Soft distillation,其核心目标是最小化teacher网络和student网络的softmax输出之间的Kullback-Leibler散度。

  定义 Z t Z_t Zt为teacher网络的logits输出(输入softmax的向量), Z s Z_s Zs为student网络的logits输出。用 τ \tau τ表示蒸馏温度, λ \lambda λ表示平衡Kullback-Leibler散度损失(KL)和交叉熵损失(LCE)的权值, ψ \psi ψ表示softmax函数。定义soft distillation的目标函数为:

Hard-label distillation

  论文提出了一种蒸馏的变体,将teacher网络的预测标签作为蒸馏的GT标签。假设 y t = a r g m a x c Z t ( c ) y_t = argmax_c Z_t(c) yt=argmaxcZt(c)是teacher网络的预测标签,与之相关的hard-label distillation目标为:

  对于同一张图片,teacher网络预测的标签可能随着特定的数据增强而有所变化。从实验结果来看,将预测标签作为蒸馏目标的做法比传统的做法更好,不仅无额外参数,概念上还更简单:teacher网络预测的 y t y_t yt与真实标签 y y y是相同的作用。
  此外,hard label也可以通过label smoothing转换为软标签,其中GT标签具有 1 − ε 1 - \varepsilon 1ε的概率,其余类共享 ε \varepsilon ε概率。在相关的实验中,参数固定为 ε = 0.1 \varepsilon = 0.1 ε=0.1

Distillation token

  论文提出的蒸馏方案如如图2所示,在输入的token序列中添加一个蒸馏token。蒸馏token与class token类似,通过self-attention与其它token交互并将最后一层中的对应输出作为网络输出,其训练目标为损失函数中的蒸馏损失部分。蒸馏token使得模型可以像常规蒸馏一样从teacher网络的输出中学习,同时与class token保持互补的关系。

  论文发现,训练后的输入层class token和蒸馏token收敛到了完全不同的向量,平均余弦相似度仅为0.06。但随着在网络的计算,class和蒸馏token在越深层中的对应输出逐渐变得更加相似,最后一层达到了较高的相似度(cos=0.93),但没有完全相同。这是符合预期的,因为两个token的目标就是产生相似但不相同的目标。

  论文也尝试替代实验,用另一个class token代替teacher网络的蒸馏token进行伪蒸馏学习。但无论如何随机且独立地初始化两个class token,训练后都会收敛到相同的向量(cos=0.999),其对应的输出也是准相同的。这表明这个代替的class token不会对分类性能带来任何影响,相比之下蒸馏token则能带来显著的提升。

Fine-tuning with distillation

  在分辨率增加的fine-tuning阶段,同样使用真实标签和teacher网络预测标签进行训练。此时需要一个具有相同目标分辨率的teacher网络,可通过FixRes的做法从之前的低分辨率teacher网络中转换。论文也尝试了只用真实标签进行fine-tuning,但这导致了性能的降低。

Classification with our approach:joint classifiers

  在测试时,网络输出的class token和蒸馏token都用于标签分类。论文的建议做法是将这两个token独立预测后再融合,即将两个分类器的softmax输出相加再进行预测。

Transformer models


  DeiT的架构设计与ViT相同,唯一的区别是训练策略和蒸馏token,训练策略的区别如表9所示。此外,在预训练时不使用MLP,仅使用线性分类器。

  为避免混淆,用ViT来指代先前工作中的结果,用DeiT来指代论文的结果。如果未指定,DeiT指的是DeiT-B,与ViT-B具有相同的架构。当以更大的分辨率fine-tune DeiT时,论文会在名字的最后附加分辨率,例如DeiT-B↑384。最后,当使用论文提出的蒸馏方法时,论文会用一个蒸馏符号将其标识为DeiT⚗.。

  如表1所示,DeiT-B的结构与ViT-B完全一样,参数固定为 D = 768 D = 768 D=768 h = 12 h = 12 h=12 d = D / h = 64 d = D/h = 64 d=D/h=64。另外,论文设计了两个较小的模型:DeiT-S和DeiT-Ti,减少了head的数量, d d d保持不变。

Experiment


  不同类型的teacher网络的蒸馏效果。

  不同蒸馏策略的对比实验。

  不同网络以及蒸馏策略之间的结果差异,值越小差异越小。

  蒸馏策略与训练周期的关系。

  整体性能的对比。

  ImageNet上预训练模型的在其它训练集上的迁移效果。

  不同优化器、数据增强、正则化的对比,寻找最佳的训练策略和配置。

  224分辨率预训练的DeiT在不同数据集上用不同分辨率fine-tune的效果。

Conclusion


  论文基于改进训练配置以及一种新颖的蒸馏方式,提出了仅用ImageNet就能训练出来的Transformer网络DeiT。在蒸馏学习时,DeiT以卷积网络作为teacher,能够结合当前主流的数据增强和训练策略来进一步提高性能。从实验结果来看,效果很不错。



如果本文对你有帮助,麻烦点个赞或在看呗~
更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/548438.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于FMC接口的Kintex-7 XC7K325T PCIeX4 3U PXIe接口卡

基于FMC接口的Kintex-7 XC7K325T PCIeX4 3U PXIe接口卡 一、板卡概述 本板卡基于Xilinx公司的FPGAXC7K325T-2FFG900 芯片,pin_to_pin兼容FPGAXC7K410T-2FFG900 ,支持PCIeX8、64bit DDR3容量2GByte,HPC的FMC连接器,板卡支持PXI…

html基础——CSS

在HTML中,CSS的作用是用于控制网页的样式,包括字体、颜色、背景、布局等方面的设计。通过一个样例来说明CSS的作用: 如下是一个名为global.css的CSS文件: .C1{font-size: 10px;color: blue;border:1px solid red;height: 200px;…

【Redis 神秘大陆】009 案例实践进阶

九、案例实践&进阶方案 9.1 本地缓存组件选型 使用缓存组件时需要重点关注集群方式、集群、缓存命中率。 需要关注集群组建方式、缓存统计;还需要考虑缓存开发语言对缓存的影响,如对于JAVA开发的缓存需要考虑GC的影响;最后还要特别关注…

vue快速入门(二十六)生命周期钩子函数

注释很详细&#xff0c;直接上代码 上一篇 新增内容 生命周期钩子函数的解析生命周期函数效果演示 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevic…

【原创教程】海为PLC与RS-WS-ETH-6传感器的MUDBUS_TCP通讯

一、关于RS-WS-ETH-6传感器的准备工作 要完成MODBUS_TCP通讯,我们必须要知道设备的IP地址如何分配,只有PLC和设备的IP在同一网段上,才能建立通讯。然后还要选择TCP的工作模式,来建立设备端和PC端的端口号。接下来了解设备的报文格式,方便之后发送报文完成数据交互。 1、…

【Altium Designer 20 笔记】PCB层

Top Overlay & Bottom Overlay (顶部丝印层和底部丝印层)&#xff1a; 用于标记元件、连接和其他重要信息。丝印层是 PCB 表面的一层&#xff0c;上面印上文字、图标或标记。 Top Solder & Bottom Solder (顶部阻焊层和底部阻焊层)&#xff1a; 阻焊层、开窗层、绿油层…

【电控笔记2.3】速度回路+系统延迟

2.3.1速度回路pi控制器设计 pi伯德图近似设计(不考虑延时理想情况下) Tl:负载转矩 PI控制器的转折频率:Ki/Kp

金融数字化能力成熟度指引

1 范围 本文件提出了金融数字化能力成熟度模型、成熟度计算方法&#xff0c;明确了不同维度金融数字化转型能力 相应的分档要求。 本文件适用于金融机构衡量金融科技应用和数字化转型发展水平&#xff0c;检视自身数字化发展优势与短板&#xff0c; 加快数字化转型&#xff0c…

Fatal error in launcher: Unable to create process using【解决方案】

拷贝python 项目到其他电脑以后&#xff0c;执行pip list 命令时报如下错误&#xff1a; Fatal error in launcher: Unable to create process using ‘“d:\python37\python.exe” “C:\Python\Scripts\pip.exe” list’: ??? 解决方法&#xff1a; 先试这条&#xff1a; …

什么是One-Class SVM

1. 简介 单类支持向量机&#xff0c;简称One-Class SVM(One-Class Support Vector Machine)&#xff0c;用于异常检测和离群点检测(无监督学习&#xff0c;其他svm属于有监督的)&#xff0c;可以在没有大量异常样本的情况下有效地检测异常。其目标是通过仅使用正常数据来建模&a…

Gin框架小结

Gin 简介 Gin是一个轻量级的Web框架&#xff0c;用于构建高性能的Go语言Web应用程序。提供了路由管理、中间件支持、参数绑定和验证、错误处理、静态文件服务等功能。 Gin框架解决了什么问题和痛点 1.golang http 标准库本身提供了比较简单的路由注册能力&#xff0c;只支持…

企业内部知识库:帮助你提高工作效率的好帮手

在现代企业中&#xff0c;知识和信息是一种无形资产&#xff0c;对企业的成长至关重要。员工之间有效地共享知识&#xff0c;可以大幅提高工作效率和团队的整体执行力。为了实现这一点&#xff0c;越来越多的企业开始构建自己的内部知识库&#xff0c;为员工提供一个集中的信息…

华为服务Fellow、首席项目管理专家,华为H5M项目管理标准制定主导者孙虎受邀为PMO大会演讲嘉宾

全国PMO专业人士年度盛会 华为服务Fellow、首席项目管理专家&#xff0c;华为H5M项目管理标准制定主导者孙虎先生受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾&#xff0c;演讲议题为“落地项目管理标准&#xff0c;打赢班长的战争”。大会将于5月25-26日在北京举办&am…

液晶触摸屏中应用的电容式触摸芯片

随着多媒体信息查询的与日俱增&#xff0c;人们越来越多地谈到触摸屏&#xff0c;因为触摸屏不仅适用于中国多媒体信息查询的国情&#xff0c;而且触摸屏具有坚固耐用、反应速度快、节省空间、易于交流等许多优点。利用这种技术&#xff0c;用户只要用手指轻轻地碰计算机显示屏…

PCL 高斯滤波(C++详细过程版)

目录 一、概述二、代码实现三、结果展示1、滤波前2、滤波后3、对比PCL 高斯滤波(C++详细过程版)由CSDN点云侠原创,爬虫自重。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫。 一、概述 高斯滤波在PCL里有现成的调用函数,具体算法原理和实现代码见:

(八)C++自制植物大战僵尸游戏植物基类讲解

植物大战僵尸游戏开发教程专栏地址http://t.csdnimg.cn/m0EtD 在植物大战僵尸游戏中&#xff0c;最重要的两个类别就是植物与僵尸。植物可以对僵尸进行攻击&#xff0c;不同的植物攻击方式千差万别&#xff0c;但是不同植物又有许多相同的属性。在基类&#xff08;父类&#xf…

temux安装debian自用记录

http://ip:9001/ user/123 http://ip:5705/index admin/drpy 一、安装Ubuntu1804 1&#xff0e;首先安装termux.app 2&#xff0e;启动该app&#xff0c;输入命令 curl -Lo l l.tmoe.me; sh l 3&#xff0e;运行过程中连续选“Y”&…

面向对象设计模式之概念

一、面向对象设计模式 按目的分为创建型&#xff08;creational&#xff09;、结构型&#xff08;structural&#xff09;和行为型&#xff08;behavioural&#xff09;&#xff1b;按范围分为类模式和对象模式。 二、设计原则 两大基础设计原则 程序设计的原则&#xff1a;模…

electron项目打包慢、打包报错

项目使用了electron框架&#xff0c;在第一次打包或者网络条件不好的环境下进行打包时熟速度慢的出奇&#xff0c;甚至经常出现打包失败的情况&#xff08;如上面图片的报错&#xff09;。 这是因为&#xff0c;在electron打包的过程中&#xff0c;需要去官方源https://github.…

安卓手机怎么连接电脑?这三个方法即刻解决!

随着智能手机的普及&#xff0c;我们越来越依赖手机进行工作和娱乐。然而&#xff0c;有时候我们需要将手机上的数据或文件传输到电脑上&#xff0c;或者需要在电脑上进行某些操作。那么&#xff0c;安卓手机怎么连接电脑呢&#xff1f;这篇文章将为你介绍三种简单的方法&#…