『论文精读』Data-efficient image Transformers(DeiT)论文解读

『论文精读』Data-efficient image Transformers(DeiT)论文解读

文章目录

  • 一. DeiT简介
  • 二. 知识蒸馏(knowledge distillation)
  • 三. better hyperparameter
  • 四. data augmentation
  • 五. label smoothing
  • 参考文献

  • 论文下载链接:https://arxiv.org/pdf/2012.12877.pdf
  • 论文代码链接:https://github.com/facebookresearch/deit
  • 关于VIT论文的解读可以关注我之前的文章:『论文精读』Vision Transformer(VIT)论文解读

一. DeiT简介

  • 现有的基于Transformer的分类模型ViT需要在海量数据上(JFT-300M,3亿张图片)进行预训练,再在ImageNet数据集上进行fune-tuning,才能达到与CNN方法相当的性能,这需要非常大量的计算资源,这限制了ViT方法的进一步应用。

在这里插入图片描述

  • DeiT的模型和VIT的模型几乎是相同的,可以理解为本质上是在训一个VIT。
  • better hyperparameter:指的是模型初始化、learning-rate等设置。
  • data augmentation:在只有120万张图片的Imagenet,使用数据增广模拟更多数据。
  • Distillation:知识蒸馏。
  • 三部分的作用分别为:保证模型更好的收敛、可以使用小的数据训练、进一步提升性能。还有一些其他的方式,如:warmup、label smoothing、droppath等。

在这里插入图片描述

  • Data-efficient image transformers (DeiT) 无需海量预训练数据,只依靠ImageNet数据,便可以达到SOTA的结果,同时依赖的训练资源更少(4 GPUs in three days)。

在这里插入图片描述

  • 文章贡献如下:
  • 仅使用Transformer,不引入Conv的情况下也能达到SOTA效果。
  • 提出了基于token蒸馏的策略,这种针对transformer的蒸馏方法可以超越原始的蒸馏方法。
  • Deit发现使用Convnet作为教师网络能够比使用Transformer架构取得更好的效果。

在这里插入图片描述

二. 知识蒸馏(knowledge distillation)

  • Knowledge Distillation(KD)最初被Hinton提出,与Label smoothing动机类似,但是KD生成soft label的方式是通过教师网络得到的。
  • KD可以视为将教师网络学到的信息压缩到学生网络中。还有一些工作“Circumventing outlier of autoaugment with knowledge distillation”则将KD视为数据增强方法的一种。
  • KD能够以soft的方式将归纳偏置传递给学生模型,Deit中使用Conv-Based架构作为教师网络,将局部性的假设通过蒸馏方式引入Transformer中,取得了不错的效果。
  • 简单来说就是用teacher模型去训练student模型,通常teacher模型更大而且已经训练好了,student模型是我们当前需要训练的模型。在这个过程中,teacher模型是不训练的。
  • 当teacher模型和student模型拿到相同的图片时,都进行各自的前向,这时teacher模型就拿到了具有分类信息的feature,在进行softmax之前先除以一个参数 τ \tau τ,叫做temperature(蒸馏温度),然后softmax得到soft labels(区别于one-hot形式的hard-label)。
  • student模型也是除以同一个 τ \tau τ,然后softmax得到一个soft-prediction,我们希望student模型的soft-prediction和teacher模型的soft labels尽量接近,使用KLDivLoss进行两者之间的差距度量,计算一个对应的损失teacher loss
  • 在训练的时候,我们是可以拿的到训练图片的真实的ground truth(hard label)的,可以看到上面图中student模型下面一路,就是预测结果和真是标签之间计算交叉熵crossentropy。
  • 链接:损失函数|交叉熵损失函数
  • 然后两路计算的损失:KLDivLoss和CELoss,按照一个加权关系计算得到一个总损失total loss,反向修改参数的时候这个teacher模型是不做训练的,只依据total loss训练student模型。
  • 还可以使用硬蒸馏,对比上面的结构图,哪种更好没有定论。

2.1. KLDivloss

  • 这里可以参考下我之前的文章:〖ML笔记〗信息量、信息熵、交叉熵、KL散度(相对熵)、JS散度以及逻辑损失+面试知识点!
  • KL divergence(KL散度又叫相对熵): 它表示用分布 q ( x ) q(x) q(x) 模拟真实分布 p ( x ) p(x) p(x) 所需要的额外信息。同时也叫KL距离,就是是两个随机分布间距离的度量。
  • 取值范围: [ 0 , + ∞ ] [0, +\infty ] [0,+]当两个分布接近相同的时候KL散度取值为0,当两个分布差异越来越大的时候KL散度值就会越来越大。
    D K L ( p ∣ q ) = H ( p , q ) ⏟ 交叉熵 − H ( p ) ⏟ 信息熵 = − ∑ i = 1 n p ( x i ) log ⁡ q ( x i ) + ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) = ∑ i = 1 n p ( x i ) log ⁡ p ( x i ) q ( x i ) (1) \begin{aligned} {D}_{K L}({p} | {q})&=\underbrace{H(p, q)}_{\text {交叉熵}}-\underbrace{H(p)}_{\text {信息熵}}\\&=-\sum_{i=1}^{n}{p}(x_i) \log {q}(x_i)+\sum_{i=1}^{n} {p}(x_i) \log {p}(x_i) \\ &=\sum_{i=1}^{n} {p}(x_i) \log \frac{{p}(x_i)}{{q}(x_i)}\tag{1} \end{aligned} DKL(pq)=交叉熵 H(p,q)信息熵 H(p)=i=1np(xi)logq(xi)+i=1np(xi)logp(xi)=i=1np(xi)logq(xi)p(xi)(1) 注意: 直观来说,由于 p ( x ) p(x) p(x) 是已知的分布(真实分布), H ( p ) H(p) H(p) 是个常数,交叉熵和KL散度之间相差一个这样的常数(信息熵)
  • 当两个分布完全一致时候,KL散度就等于0。KLDivloss定义和使用方式为:

2.2. 蒸馏温度 τ \tau τ

  • 蒸馏温度 τ \tau τ 的作用,回想之前VIT中在self-attention里面计算 q , k \mathbf {q,k} q,k间的加权因子的时候,计算完了要scale(除以 k k k 的维度),然后再做softmax,然后用它们对 v \mathbf v v 加权相加得到对应的表示向量。
  • 如果是[1.0,20.0,400.0]直接做softamx,那结果是[0.0,0.0,1.0],可见结果完全借鉴第三个引子。而先进行处理(比如除以1000)后变为[0.001,0.02,0.4]时,在做softamx结果为[0.28,0.29,0.42]结果总综合考虑了三部分,这显然是更合理的结果。实际中,看我是更希望结果偏向于更大的值,还是偏向于综合考虑来决定是否使用softmax前输入的预处理。

2.3. distillation in transformer

这一节主要弄清楚,如何在transformer中进行蒸馏操作。

在这里插入图片描述

  • 先说一下,在这DeiT篇论文出来的时候,teacher model使用的是Regnet(一个CNN)
  • 在VIT中时使用class tokens去做分类的,相当于是一个额外的patch,这个patch去学习和别的patch之间的关系,然后连classifier,计算CELoss。在DeiT中为了做蒸馏,又额外加一个distill token,这个distill token也是去学和其他tokens之间的关系,然后连接teacher model计算KLDivLoss,那CELoss和KLDivLoss共同加权组合成一个新的loss取指导student model训练(知识蒸馏中teacher model不训练)。
  • 在预测阶段,class token和distill token分别产生一个结果,然后将其加权(分别0.5),再加在一起,得到最终的结果做预测。

L global  = ( 1 − λ ) L C E ( ψ ( Z s ) , y ) + λ τ 2 K L ( ψ ( Z s / τ ) , ψ ( Z t / τ ) ) (2) \mathcal{L}_{\text {global }}=(1-\lambda) \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_{\mathrm{s}}\right), y\right)+\lambda \tau^2 \mathrm{KL}\left(\psi\left(Z_{\mathrm{s}} / \tau\right), \psi\left(Z_{\mathrm{t}} / \tau\right)\right)\tag{2} Lglobal =(1λ)LCE(ψ(Zs),y)+λτ2KL(ψ(Zs/τ),ψ(Zt/τ))(2)

L global  hardill  = 1 2 L C E ( ψ ( Z s ) , y ) + 1 2 L C E ( ψ ( Z s ) , y t ) (3) \mathcal{L}_{\text {global }}^{\text {hardill }}=\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_s\right), y\right)+\frac{1}{2} \mathcal{L}_{\mathrm{CE}}\left(\psi\left(Z_s\right), y_{\mathrm{t}}\right)\tag{3} Lglobal hardill =21LCE(ψ(Zs),y)+21LCE(ψ(Zs),yt)(3)

三. better hyperparameter

  • DeiT中第二个优化点在于better hyperparameter,也就是更好的参数配置,看看其都包含哪些部分。

在这里插入图片描述

  • 参数初始化方式:truncated normal distribution(截断标准分布)
  • learning-rate:CNN中的结论:当batch size越大的时候,learning rate设置的越大。
  • learning rate decay:cosine,在warm-up阶段lr先线性升上去,然后通过余弦方式lr降下来

四. data augmentation

在这里插入图片描述

  • mixup之后的图片的label不再是单一的label,而是soft-label,比如[cat,dog]=[0.5,0.5]
  • cutmix之后的图片label是按所占据的比例给的,比如[cat,dog]=[0.3,0.7]

在这里插入图片描述

  • randomaug其实是由autoaug来的,autoaug是选取了25中增强策略,每种策略中有两个操作,这两种操作都要被执行。每次为一张图随机从25中策略中选取一种,将这两种操作对该图执行。至于这25中策略是怎么组成的,每种里面的操作的概率是如何确立的,这些是由搜索算法的实现的,总之认为这么搭配有效就行了。对于randomaug,相当于对于autoaug的简化,它是13种增强策略,然后从中一次选取6种策略依次对图片进行操作,完成增强操作。
  • model EMA(Exponential Moving Average)指数滑动平均,使得模型权重更新与一段时间内的历史取值有关。 m t m_{t} mt 是当前的模型权重, m t − 1 m_{t-1} mt1 是上一轮模型权重, θ t \theta_{t} θt为模型当前权重的值,举一个例子:

在这里插入图片描述
在这里插入图片描述

  • 三种更新参数方式的更新参数结果曲线:

在这里插入图片描述

  • 实际使用的时候,设置上面例子中的 β \beta β 值例如为0.99996,保证模型的参数值不会乱动。

五. label smoothing

  • label smoothing:原本hard-label变成soft-label,设置参数,给其余非标签平均一些label概率。
     Label  one hot  = [ 1 , 0 , 0 , 0 , 0 , 0 ]  Label  smoothing  = [ 0.9 , 0.02 , 0.02 , 0.02 , 0.02 , 0.02 ] , α = 0.1 \begin{aligned} & \text { Label }_{\text {one hot }}=[1,0,0,0,0,0] \\ & \text { Label }_{\text {smoothing }}=[0.9,0.02,0.02,0.02,0.02,0.02], \alpha=0.1 \end{aligned}  Label one hot =[1,0,0,0,0,0] Label smoothing =[0.9,0.02,0.02,0.02,0.02,0.02],α=0.1

在这里插入图片描述
在这里插入图片描述

  • 上图来自论文:When Does Label Smoothing Help

参考文献

  • 以上内容主要参考自大神Transformer学习(四)—DeiT
  • DeiT | Training data-efficient image transformers & distillation through attention
  • DeiT:使用Attention蒸馏Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/77588.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言快速回顾(三)

前言 在Android音视频开发中,网上知识点过于零碎,自学起来难度非常大,不过音视频大牛Jhuster提出了《Android 音视频从入门到提高 - 任务列表》,结合我自己的工作学习经历,我准备写一个音视频系列blog。C/C是音视频必…

一分钟上手Vue VueI18n Internationalization(i18n)多国语言系统开发、国际化、中英文语言切换!

这里以Vue2为例子 第一步:安装vue-i18n npm install vue-i18n8.26.5 第二步:在src下创建js文件夹,继续创建language文件夹 在language文件夹里面创建zh.js、en.js、index.js这仨文件 这仨文件代码分别如下: zh.js export de…

踩坑串口通信 serialPort.RtsEnable = true

背景: 最近在调试一个激光模块,使用的是422通信,然后买了一个485转422的转换器。 通过串口监控软件观察,明明和串口助手发的东西一模一样,但是就是不返回! 解决方案: 我加了,这句&…

Java课题笔记~ 日期处理

2.8 日期处理 2.8.1 日期注入 日期类型不能自动注入到方法的参数中。需要单独做转换处理。 使用DateTimeFormat注解,需要在springmvc.xml文件中添加mvc:annotation-driven/标签。 (1)在方法的参数上使用DateTimeFormat注解 RequestMappi…

【mysql】—— 表的增删改查

目录 序言 (一)Create操作 1、单行数据 全列插入 2、多行数据 指定列插入 3、插入否则更新 4、直接替换 (二)Retrieve操作 1、SELECT 列 1️⃣全列查询 2️⃣指定列查询 3️⃣查询字段为表达式 4️⃣为查询结果指定…

机器学习基础笔记

文章目录 1.机器学习简介1.1 机器学习的一般功能1.2 机器学习的应用1.3 机器学习的方法1.4 机器学习的种类1.5 机器学习的常用框架 2. Spark机器学习2.1 MLlib介绍2.2 MLlib的数据格式2.2.1 本地向量2.2.2 标签数据 2.3 MLlib与ml2.4 MLlib的应用场景 3.Spark环境搭建4.向量与矩…

【C语言】回调函数,qsort排序函数的使用和自己实现,超详解

文章目录 前言一、回调函数是什么二、回调函数的使用1.使用标准库中的qsort函数2.利用qsort函数对结构体数组进行排序 三、实现qsort函数总结 先记录一下访问量突破2000啦,谢谢大家支持!!! 这里是上期指针进阶链接,方便…

Python入门【TCP建立连接的三次握手、 TCP断开连接的四次挥手、套接字编程实战、 TCP编程的实现、TCP双向持续通信】(二十七)

👏作者简介:大家好,我是爱敲代码的小王,CSDN博客博主,Python小白 📕系列专栏:python入门到实战、Python爬虫开发、Python办公自动化、Python数据分析、Python前后端开发 📧如果文章知识点有错误…

KMPBC:KMP算法及其改进(kmp with bad character)

前言 最近在看字符串匹配算法,突然灵光一闪有了想法,可以把kmp算法时间效率提高,同时保持最坏时间复杂度O(nm)不变。其中n为主串长度,m为模式串长度,经测试可以块3-10倍,以为发现了新大陆,但是…

内网ip与外网ip

一、关于IP地址 我们平时直接接触最多的是内网IP。而且还可以自己手动修改ip地址。而外网ip,我们很少直接接触,都是间接接触、因为外网ip一般都是运营商管理,而且是全球唯一的,一般我们自己是无法修改的。 内网IP和外网IP是指在…

【2024】MySQL中常用函数和窗口函数的基本使用方式

MySQL中常用函数和窗口函数的基本使用方式 一、基础函数1、聚合函数:2、字符串函数:3、日期和时间函数4、数值函数5、条件函数 二、窗口函数(*OVER*) 一、基础函数 1、聚合函数: SELECT COUNT(*) FROM table_name;:计算表中的行…

Effective C++学习笔记(8)

目录 条款49:了解new-handler的行为条款50:了解new和delete的合理替换时机条款51:编写new和delete时需固守常规条款52:写了placement new也要写placement delete条款53:不要轻忽编译器的警告条款54:让自己熟…

Spring Boot 中的 AOP,到底是 JDK 动态代理还是 Cglib 动态代理

大家都知道,AOP 底层是动态代理,而 Java 中的动态代理有两种实现方式: 基于 JDK 的动态代理 基于 Cglib 的动态代理 这两者最大的区别在于基于 JDK 的动态代理需要被代理的对象有接口,而基于 Cglib 的动态代理并不需要被代理对…

PyTorch训练简单的生成对抗网络GAN

文章目录 原理代码结果参考 原理 同时训练两个网络:辨别器Discriminator 和 生成器Generator Generator是 造假者,用来生成假数据。 Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。 目的:使…

C++中List的实现

前言 数据结构中,我们了解到了链表,但是我们使用时需要自己去实现链表才能用,但是C出现了list将这一切皆变为现。list可以看作是一个带头双向循环的链表结构,并且可以在任意的正确范围内进行增删查改数据的容器。list容器一样也是…

【CSS】CSS 布局——常规流布局

<h1>基础文档流</h1><p>我是一个基本的块级元素。我的相邻块级元素在我的下方另起一行。</p><p>默认情况下&#xff0c;我们会占据父元素 100%的宽度&#xff0c;并且我们的高度与我们的子元素内容一样高。我们的总宽度和高度是我们的内容 内边距…

如何发布自己的小程序

小程序的基础内容组件 text&#xff1a; 文本支持长按选中的效果 <text selectable>151535313511</text> rich-text: 把HTML字符串渲染为对应的UI <rich-text nodes"<h1 stylecolor:red;>123</h1>"></rich-text> 小程序的…

2023牛客暑期多校训练营8-C Clamped Sequence II

2023牛客暑期多校训练营8-C Clamped Sequence II https://ac.nowcoder.com/acm/contest/57362/C 文章目录 2023牛客暑期多校训练营8-C Clamped Sequence II题意解题思路代码 题意 解题思路 先考虑不加紧密度的情况&#xff0c;要支持单点修改&#xff0c;整体查询&#xff0…

AUTOSAR NvM Block的三种类型

Native NVRAM block Native block是最基础的NvM Block&#xff0c;可以用来存储一个数据&#xff0c;可以配置长度、CRC等。 Redundant NVRAM block Redundant block就是在Native block的基础上再加一个冗余块&#xff0c;当Native block失效&#xff08;读取失败或CRC校验失…

时序预测 | MATLAB实现基于BiLSTM双向长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价)

时序预测 | MATLAB实现基于BiLSTM双向长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价) 目录 时序预测 | MATLAB实现基于BiLSTM双向长短期记忆神经网络的时间序列预测-递归预测未来(多指标评价)预测结果基本介绍程序设计参考资料 预测结果 基本介绍 Matlab实现BiLST…