论文笔记--Distilling the Knowledge in a Neural Network

论文笔记--Distilling the Knowledge in a Neural Network

  • 1. 文章简介
  • 2. 文章概括
  • 3 文章重点技术
    • 3.1 Soft Target
    • 3.2 蒸馏Distillation
  • 4. 文章亮点
  • 5. 原文传送门

1. 文章简介

  • 标题:Distilling the Knowledge in a Neural Network
  • 作者:Hinton, Geoffrey, Oriol Vinyals, Jeff Dean
  • 日期:2015
  • 期刊:arxiv

2. 文章概括

  文章提出了一种将大模型压缩的新的思路:蒸馏distillation。通过蒸馏,可以将很大的模型压缩为轻量级的模型,从而提升推理阶段的速率。

3 文章重点技术

3.1 Soft Target

  随着模型的参数量越来越大,如何从训练好的大模型(教师模型)学习一个轻量级的小模型(学生模型)是一个重要课题。传统的hard-target训练直接学习大模型的预测结果,无法学习到不正确的类别之间的相对关系。比如给定一张宝马的照片,假设教师模型给出的预测结果为宝马,学生模型只从教师模型中学习到“宝马”这一个标签信息。事实上,教师模型还会给出其它类别的信息,比如将宝马预测为垃圾车为0.02,将宝马预测为胡萝卜的概率仅为0.0001,但学生模型没有学习到垃圾车和胡萝卜之间的区别。
  我们需要一种方法来使得学生学习到正确的标签,以及错误标签的相对关系。文章提出“soft-target",即通过学习教师模型的预测概率分布来训练小模型。

3.2 蒸馏Distillation

  对一个分类模型,假设教师模型的输出层给出的logits为 z i z_i zi,然后通过计算Softmax得到预测概率: q i = exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) q_i = \frac {\exp (z_i/T)}{\sum_j \exp (z_j/T)} qi=jexp(zj/T)exp(zi/T),其中相比于传统的SoftMax增加了 T T T表示温度,用于控制输出概率分布的平滑度。 T T T越大,不同类别之间的差异越不明显,即分布越平滑。可以参考 e x p exp exp的函数曲线来理解:给定 x 1 , x 2 x_1, x_2 x1,x2,由当 T T T越大时, x 1 / T , x 2 / T x_1/T, x_2/T x1/T,x2/T对应的导数越小(导数即为 exp ⁡ ( x ) \exp(x) exp(x),也可参考下图),从而差距越小,分布越平滑。当 T = 1 T=1 T=1时,即传统的Softmax。
exp
  我们希望学生模型满足:1) 模型可以学习到教师模型的预测概率,即soft targets; 2)学生模型可以预测真实的标签。从而我们可以考虑2个目标函数: L hard \mathcal{L}_{\text{hard}} Lhard L soft \mathcal{L}_{\text{soft}} Lsoft。首先我们记学生模型和教师模型的logits分别为 z i , v i z_i, v_i zi,vi,预测概率分别为 q i , p i q_i, p_i qi,pi,真实标签为labels,则

  • L hard = Cross Entropy ( labels , arg max ⁡ i ( exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) ) ) \mathcal{L}_{\text{hard}}=\text{Cross Entropy}\left(\text{labels}, \argmax_i \left(\frac {\exp (z_i/T)}{\sum_j \exp (z_j/T)}\right)\right) Lhard=Cross Entropy(labels,argmaxi(jexp(zj/T)exp(zi/T)))
  • L soft = Cross Entropy ( p , q ) = Cross Entropy ( ( exp ⁡ ( v i / T ) ∑ j exp ⁡ ( v j / T ) ) , ( exp ⁡ ( z i / T ) ∑ j exp ⁡ ( z j / T ) ) ) \mathcal{L}_{\text{soft}}=\text{Cross Entropy}\left(p, q \right) =\text{Cross Entropy}\left((\frac {\exp (v_i/T)}{\sum_j \exp (v_j/T)}), (\frac {\exp (z_i/T)}{\sum_j \exp (z_j/T)})\right) Lsoft=Cross Entropy(p,q)=Cross Entropy((jexp(vj/T)exp(vi/T)),(jexp(zj/T)exp(zi/T)))
    考虑 L soft \mathcal{L}_{\text{soft}} Lsoft的梯度 ∂ L soft ∂ z k = ∂ ( − ∑ i p i log ⁡ q i ) ∂ z k = − ∑ i p i q i ∂ q i ∂ z k = − p k q k ∂ q k ∂ z k − ∑ i ≠ k p i q i ∂ q i ∂ z k = − 1 T p k q k q k ( 1 − q k ) + ∑ i ≠ k p i q i exp ⁡ ( z i / T ) ( ∑ j exp ⁡ ( z j / T ) ) 2 1 T exp ⁡ ( z k / T ) = 1 T ( − p k ( 1 − q k ) + ∑ i ≠ k p i q i q i q k ) = 1 T ( − p k + ∑ i p i q k ) = 1 T ( q k − p k ) = 1 T ( exp ⁡ ( z k / T ) ∑ j exp ⁡ ( z j / T ) − exp ⁡ ( v k / T ) ∑ j exp ⁡ ( v j / T ) ) \frac {\partial \mathcal{L}_{\text{soft}}}{\partial z_k} = \frac {\partial (-\sum_i p_i \log q_i)}{\partial z_k} = -\sum_i \frac {p_i}{q_i} \frac{\partial q_i}{\partial z_k} = -\frac {p_k}{q_k} \frac{\partial q_k}{\partial z_k}-\sum_{i\neq k} \frac {p_i}{q_i} \frac{\partial q_i}{\partial z_k} \\=-\frac 1T \frac {p_k}{q_k} q_k (1-q_k) +\sum_{i\neq k} \frac {p_i}{q_i} \frac {\exp (z_i/T)}{(\sum_j \exp (z_j/T))^2} \frac 1T \exp (z_k/T) \\= \frac 1T (-p_k (1-q_k) + \sum_{i\neq k} \frac {p_i}{q_i} q_i q_k )= \frac 1T (-p_k + \sum_i p_i q_k ) \\= \frac 1T (q_k - p_k) = \frac 1T \left(\frac {\exp (z_k/T)}{\sum_j \exp (z_j/T)} - \frac {\exp (v_k/T)}{\sum_j \exp (v_j/T)}\right) zkLsoft=zk(ipilogqi)=iqipizkqi=qkpkzkqki=kqipizkqi=T1qkpkqk(1qk)+i=kqipi(jexp(zj/T))2exp(zi/T)T1exp(zk/T)=T1(pk(1qk)+i=kqipiqiqk)=T1(pk+ipiqk)=T1(qkpk)=T1(jexp(zj/T)exp(zk/T)jexp(vj/T)exp(vk/T)),当 T T T相比于 z i , v i z_i, v_i zi,vi等logits量级比较高时,有 z i / T → 0 , v i / T → 0 z_i/T\to 0, v_i/T \to 0 zi/T0,vi/T0,从而由泰勒公式上式近似为 ∂ L soft ∂ z k ≈ 1 T ( 1 + z k / T N + ∑ j z j / T − 1 + v k / T N + ∑ j v j / T ) \frac {\partial \mathcal{L}_{\text{soft}}}{\partial z_k} \approx \frac 1T \left(\frac {1+z_k/T}{N + \sum_j z_j/T} - \frac {1+v_k/T}{N + \sum_j v_j/T}\right) zkLsoftT1(N+jzj/T1+zk/TN+jvj/T1+vk/T),假设logits都是零均值的,则有 ∂ L soft ∂ z k ≈ 1 N T 2 ( z k − v k ) \frac {\partial \mathcal{L}_{\text{soft}}}{\partial z_k} \approx \frac 1{NT^2} (z_k - v_k) zkLsoftNT21(zkvk)。从而当温度比较高时,我们的目标近似为最小化 1 2 ( z k − v k ) 2 \frac 12 (z_k - v_k)^2 21(zkvk)2(上式的原函数,不考虑常数项),即最小化logits的MSE函数。温度越低,我们越关注小于均值的logits。
      最终的损失函数为上述hard和soft损失的加权求和。

4. 文章亮点

  文章提出了基于soft-target的蒸馏方法,可以让学生模型学习到教师模型的预测概率分布,从而增强学生模型的泛化能力。实验表明,在MNIST和speech recognition数据上,基于soft target的学生模型可以提取到更多有用的信息,且可以有效防止过拟合的发生。

5. 原文传送门

Distilling the Knowledge in a Neural Network

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/44536.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

iOS 单元测试之常用框架 OCMock 详解

目录 前言: 一、单元测试 1.1 单元测试的必要性 1.2 单元测试的目的 - 约束条件是否通过形式参数来传送。 1.3 单元测试依赖的两个主要框架 二、OCMock 的集成与使用 2.1 OCMock 的集成方式 2.2 OCMock 的使用方法 2.3 mock 使用限制 前言: OC…

IDEA+SpringBoot + Mybatis + Shiro+Bootstrap+Mysql资产设备管理系统

IDEASpringBoot Mybatis ShiroBootstrapMysql资产设备管理系统 一、系统介绍1.环境配置 二、系统展示1. 管理员登录2.用户新增3.用户设置4.岗位管理5. 审批节点6. 人员查询7. 组织设置8. 人员调整9.角色设置10.角色模块映射11.模块设置12.应用模块13.光纤交换机14.服务器15.网…

从实践彻底掌握MySQL的主从复制

目录 一、本次所用结构如图---一主多从级联: 二、IP。 三、配置M1: 四、从库M1S1: 五、从库M2配置: 六、 从库M2S1: 一、本次所用结构如图--- 一主多从级联: 二、IP。这里M1S1和M1S2一样的&#xff0…

hack the box—Lame

扫描 还是老方法nmapfscan得到开放的端口和服务 nmap -sV -sC -sT -v -T4 10.10.10.3 看到开了445,先来波ms17-010,发现失败。 这里还开个21,并且可以知道版本号,直接搜索ftp漏洞 msf正好有对应的模块 设置好参数后进行攻击&…

Hadoop 集群如何升级?

前言 本文隶属于专栏《大数据技术体系》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢! 本专栏目录结构和参考文献请见大数据技术体系 正文 升级 Hadoop 集群需要细致的规划,特…

【设计模式】单例设计模式详解(包含并发、JVM)

文章目录 1、背景2、单例模式3、代码实现1、第一种实现(饿汉式)为什么属性都是static的?2、第二种实现(懒汉式,线程不安全)3、第三种实现(懒汉式,线程安全)4、第四种实现…

树莓派刷机和登入

1.打开映像工具 2.选择映像文件写入 3.拔出卡插入树莓派上电 4.树莓派登入 1.HDMI视频线,连接到显示屏幕 2.串口登录 修改系统配置,启用串口登录树莓派 (1) 打开SD卡根目录的"config.txt文件",停止蓝牙,…

使用lua脚本操作redis

redis中实现事务有两种方法: 1.WATCH监视键的变动,然后MULTI开始事务,EXEC提交事务 WATCH key [key…]:监视一个或多个键,如果在事务执行之前被修改,则事务被打断。 MULTI:标记一个事务的开始。…

Rust学习01:D-day

以前自学过Python,开发了一些小程序,用于工作中提升效率。 Python的确好学易用,但用来做一个真正意义上的产品,哪怕是比较简单的产品,差点意思,特别是在移动端开发领域。 Rust看了两本书,准备动…

Chrome 115 有哪些值得关注的新特性?

今天带大家一起来了解一下 Chrome 115 值得关注的新特性。 滚动动画 用滚动驱动的动画是网站上非常常见的用户体验模式,比如当页面向前或向后滚动时,对应的动画也会向前或向后移动。 比如下面图中这种比较常见的,页面顶部的进度条随着滚动…

C语言-print字符串打印-转义字符妙用

这里有两个有关打印的小知识 打印的字符串内容由两部分组成:可见字符、转义字符;各种字母、数字、以及空格,均属于可见字符,“\”等属于转义字符 举例: 1.直接print里面打印内容,内容直接出现 2.这里想将一…

appscan 应用

HCL appscan是个常见的web app DAST 扫描工具 有企业版和standalone 版本。大家常用的都是单机版本。企业版平台,集成了IAST。 appscan 使用比较简单,基本输入url 账号密码就开扫了。 用了一段时间几点体验 1 还是需要手动explore的,他自…

TSN -促进IT/OT 融合的网络技术

时间敏感网络(tsn)技术是IT/OT 融合的一项关键的基础网络技术,它实现了在一个异构网络中,实现OT的实时数据和IT系统的交互数据的带宽共享。 TSN允许将经典的高确定性现场总线系统和IT应用(如大数据传输)的功…

flutter开发实战-自定义相机camera功能

flutter开发实战-自定义相机camera功能。 Flutter 本质上只是一个 UI 框架,运行在宿主平台之上,Flutter 本身是无法提供一些系统能力,比如使用蓝牙、相机、GPS等,因此要在 Flutter 中调用这些能力就必须和原生平台进行通信。 实现…

vue/cli 自定义配置

vue/cli 自定义配置 1、更改默认的端口号8080 只需要更改vue.config.js文件 1、更改默认的端口号8080 只需要更改vue.config.js文件

openlayers系列:加载arcgis和geoserver在线离线切片

https://www.freesion.com/article/1751396517/ 1.背景 有个项目需要使用openlayer加载各种服务上发布的数据,坐标系也不同,我们都知道openalyer默认可以加载EPAG:3857,要加载4490的坐标系的数据需要重新定义一下,之后再加载。一想起要重新…

脑电信号处理与特征提取——4.脑电信号的预处理及数据分析要点(彭微微)

目录 四、脑电信号的预处理及数据分析要点 4.1 脑电基础知识回顾 4.2 伪迹 4.3 EEG预处理 4.3.1 滤波 4.3.2 重参考 4.3.3 分段和基线校正 4.3.4 坏段剔除 4.3.5 坏导剔除/插值 4.3.6 独立成分分析ICA 4.4 事件相关电位(ERPs) 4.4.1 如何获…

【STM32】 强大的 STM32Cube 生态 STM32CubeIDE 无伤速通

本文介绍的软件,均可以在ST官网st.com免费下载(你需要注册登录),首选官网下载最新版本,如果有问题,可以在我的公众号回复:Cube,获取截止今日的最新版本软件安装包。 目录 一、STM32C…

什么是框架?为什么要学框架?

一、什么是框架 框架是整个或部分应用的可重用设计,是可定制化的应用骨架。它可以帮开发人员简化开发过程,提高开发效率。 项目里有一部分代码:和业务无关,而又不得不写的代码>框架 项目里剩下的部分代码:实现业务…

Maven-----进阶

目录 1 分模块开发1.1 分模块开发的意义1.2 分模块开发实现 2 依赖管理2.1 依赖传递2.2 依赖传递冲突问题2.3 可选依赖和排除依赖 3 继承与聚合3.1 聚合3.2 继承3.2 聚合与继承的区别 4 属性4.1 属性4.2 资源文件引用属性4.3 版本管理 5 多环境配置与使用5.1 多环境开发5.2 跳过…