基于模型蒸馏的模型加速方案总结

1.简介

1.1目的

在过去的一段时间里,对基于模型蒸馏技术的模型加速方案的方法在多个数据集上进行了一系列的实验。所谓的模型蒸馏技术,简单的来说就是利用一个设计简单的小网络去学习一个设计比较复杂的大网络。特别的有,本次实验针对每一个复杂的(teacher model)大模型,都设计了多个简单的(student model)小模型去学习,并且针对不同的超参数组合,本文给出了多组对比实验。详细的实验结果以及相应实验现象的分析和总结将在下文给出。

1.2范围

本文档描述的代码修改以及实验方法都是基于caffe框架进行的,添加的新层有SoftmaxWithLossWithSoftTargetLabel和SoftmaxWithLossWithLabelSmooth。主要的加速策略是利用参数少的(student model)小模型去学习参数多的(teacher model)大模型,所有的实验都在mnist数据集、cifar10数据集、以及年龄和性别属性相关数据集进行的,其中需要调节的超参数包括温度超参数T,loss比例超参数LAMDA。

1.3定义、首字母缩写词和缩略语

序号

术语或缩略语

说明性定义

1

2

3

4

5

6

7

1.4 参考资料

《基于模型蒸馏技术的模型加速方案实验设计v2.pdf》

Distill the Knowledge in a Neural Network

2.实验的方法——模型蒸馏

        本文档中的实验是基于caffe框架进行的,修改了其中的源码,并添加了新的层,使得这个框架可以按照制定的模型加速方案进行运行和测试。

2.1 为什么需要模型蒸馏

一个很大的DNN往往训练出来的效果会比较好,并且多个DNN一起ensemble的话效果会更加的好,但是当用在实际的应用中的话,过于庞大的DNN ensemble在一起会增大计算量,从而影响应用。于是一个问题就被提出了:有没有一个方法,能使降低网络的规模,但是保持(一定程度上)精确度呢?

Hinton举了一个仿生学的例子,就是昆虫在幼生期的时候往往都是一样的,适于它们从环境中摄取能量和营养;然而当它们成长到成熟期,会基于不同的环境或者身份,变成另外一种形态以适应这种环境。那么对于DNN是不是存在类似的方法?在一开始training的过程中比较的庞杂但是后来当需要拿去deploy的时候,可以转换成一个更小的模型。他把这种方法叫做Knowledge Distillation(KD)。

2.2模型蒸馏的基本原理

这里的distillation方法其实主要用的就是通过一个performance非常好的大网络(有可能是ensemble的)来教一个小网络进行学习。这里我们可以把大网络叫为:teacher network,小网络叫为:student network。至于为什么是希望通过大网络来教小网络而不是直接利用ground truth label来学习,hinton也给了一个例子:比如说在MNIST数据集中,有两个数字“2”,但是写法是不一样的:一个可能写的比较像3(后面多出了一点头),一个写的比较像7(出的头特别的短)。在这样的情况下,ground truth label都是“2”,然而一个学习的很好的大网络会给label “3” 和 “7” 都有一定的概率值,如图1所示。通常叫这种信息为 “soft targets”;相对的,ground truth label 是一种 “hard target” 因为它是one-hot label。总的来说就是,通过大网络的“soft targets”,能得到更加多的信息来更好的训练小网络。

图1 hard target vs. soft target

论文中所提出的上述soft target实际上就是已经训练好的复杂模型的softmax层的输出概率,而其中所提出的“蒸馏”方法在softmax层中引入了一个”温度”参数T,如公式(1)所示:

qi=expzi/Tjexpzj/T

(1)

      其中zi 表示的是logit,即softmax层的输入;qi 表示经过softmax层计算后的每个类别的概率;T 表示的就是上述的温度参数,通常设置为1。不过通过上述温度参数的调整,softmax层的映射曲线更加平缓,因而实例的概率映射将更为集中,便使得目标更加地"soft"。并且有论文中还指出,当transfer set中的标签可得时,将soft target和实际标签的两个目标共同使用作为目标函数将使得其性能更加提高。在训练过程中,作者将迁移样本集中样例输入原复杂模型并通过上述蒸馏softmax得到soft target,并将其作为目标,并在迭代过程中更新温度,训练出细粒度的模型。

蒸馏”最简单的形式就是:以从复杂模型得到的“软目标”为目标(这时T比较大),用“转化”训练集训练小模型。训练小模型时T不变仍然较大,训练完之后T改为1。 

当“转化”训练集中部分或者所有数据都有标签时,这种方式可以通过一起训练模型使得模型得到正确的标签来大大提升效果。一种实现方法是用正确标签来修正“软目标”,但是论文中发现一种更好的方法是:对两个目标函数设置权重系数。第一个目标函数是“软目标”的交叉熵,这个交叉熵用开始的那个比较大的T来计算。第二个目标函数是正确标签的交叉熵,这个交叉熵用小模型softmax层的logits来计算且T等于1。论文中指出当第二个目标函数权重较低时可以得到最好的结果。整体的结构如图2所示:

图2 模型蒸馏的整体结构

2.3为什么使用soft target会有用

图3 soft target的用处

信息量:

hard target 包含的信息量(信息熵)很低,soft target包含的信息量大,拥有不同类之间关系的信息(比如同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target 那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。)

软化:

问题是像图3左侧的红色0.001这部分,在cross entropy的loss function中对于权重的更新贡献微乎其微,这样就起不到作用。把soft target软化(整体除以一个数值后再softmax),就可以达到右侧绿色的0.1这个数值,这样在后来权重的更新中就有一定的贡献了。

3. 模型蒸馏实验设计

3.1 蒸馏模型训练过程

实验步骤:

1.根据提出的目标问题设计一个或者多个复杂的网络结构(N1,N2,…,Nt)。

2.收集足够多的训练数据,按照常规CNN模型训练流程,训练好1中的一个或者多个复杂网络得到(M1,M2,…,Mt),记为原始网络。

3.收集简单模型训练数据,此处的训练数据可以是训练原始网络的有标签数据,也可以是额外的无标签数据。

4.修改原始模型(M1,M2,…,Mt)的softmax层中温度参数T为一个较大值如T=20,将3中收集到的样本输入到原始复杂模型中。每一个样本在每个原始模型可以得到其最终的分类概率向量,选取其中概率至最大即为该模型对于当前样本的判定结果。对于t个原始模型就可以得到t个概率向量。那么对这t个概率向量求取均值作为当前样本最后的概率输出向量,记为soft_target label,最后保存到文件中。

5.根据(N1,N2,…,Nt)重新创建一个精简的小网络N0,该网络最后有两个loss,一个是hard loss,即传统的softmaxloss,使用one shot label;另外一个是soft loss,即T>1的softmaxloss,使用我们第4步保存下来的soft target label。

6.设置精简的小网络N0的softmax层温度参数与原始复杂模型产生soft target label时所采用的温度一致,如T=20,按照常规模型训练精简的小网络得到模型M0。

7.训练完成之后,在实际应用中将精简的小模型中的softmax温度参数重置为1,即采用最原始的softmax,来走前向作为最后输出的小模型。

上述的训练过程可以用图4简单表示:

图4 模型蒸馏的训练过程

3.2 论文中的经验

论文中作者认为,由于soft target具有更高的熵,它能比hard target提供更加多的信息,因此可以使用较少的数据以及较大的学习率。将hard和soft的target通过加权平均来作为学生网络的目标函数,soft target所占的权重更大一些。 论文中作者同时还指出,T值取一个中间值时,效果更好,而soft target所分配的权重应该为T^2,hard target的权重为1。 这样训练得到的小模型也就具有与复杂模型近似的性能效果,但是复杂度和计算量却要小很多。 
对于distilling而言,复杂模型的作用事实上是为了提高label包含的信息量。通过这种方法,可以把模型压缩到一个非常小的规模。模型压缩对模型的准确率没有造成太大影响,而且还可以应付部分信息缺失的情况。

4.实验结果及其分析

针对上述方法修改完毕之后的caffe框架,分别在mnist数据集,cifar10数据集以及年龄和性别属性识别数据集上分别对不同的温度超参数T以及loss比例超参数lamda做了多组对比实验。其中T的取值为0.5,1,2,3,5,10,20,lamda的取值为0.01,0.1,0.3。具体的实验结果以及实验分析如下所示。

4.1 mnist

图5 mnist上模型蒸馏的实验结果

    从图5中可以看出,teacher model一共有四层参数层,即两个卷积层以及两个全连接层,尺寸为20_50_500_10,其中的数字表示的为caffe中prototxt中每一层的num_output的大小。其精度可以达到很高的精度0.9914;student model这里设计了三组对比实验,相应的尺寸分别为10_25_250_10、4_10_100_10、2_5_50_10。

从图中可以看出:

  • 在设置特定的T和lamda的超参时,train_smallnet_from_kd的实验结果都要优于train_smallnet_from_scratch的实验结果,前者表示的是模型蒸馏的结果,后者表示从头训练小模型的结果,即图中的红色部分和small net的baseline进行对比。
  • 特别的有模型大小为10_25_250_10即尺寸为原始大模型的一半的时候,当lamda设置为0.3,T设置为3的时候,小模型经过对大模型的学习是可以达到大模型的精度的。
  • 第三对于当前的实验对于mnist数据集可以看出,最优值的基本上是在T超参数设置为3左右的时候出现的。

    综上所述,当大模型的精度很高的时候,模型蒸馏(知识提取)的效果可以达到很好,并且小模型经过学习是能够达到大模型的那种效果的。

4.2 cifar10

图6 cifar10数据集上模型蒸馏的实验结果

从图6中可以看出,teacher model一共有四层参数层,即两个卷积层以及两个全连接层,尺寸为32_32_64_10,其中的数字表示的为caffe中prototxt中每一层的num_output的大小。其精度不高只有0.7937;student model这里设计了三组对比实验,相应的尺寸分别为16_16_32_10、8_8_16_10、4_4_8_10。

从图中可以看出:

  • 在设置特定的T和lamda的超参时,train_smallnet_from_kd的实验结果部分优于train_smallnet_from_scratch的实验结果。这个和mnist的实验结果有点差距。
  • 所有的train_smallnet_from_kd的实验结果都达不到最初的大模型的效果。
  • 对于当前的实验对于mnist数据集可以看出,最优值的基本上是在T超参数设置为1、2左右。

综上所述,当大模型的精度不高的时候,在特定的小模型尺寸、温度参数T以及lamda参数设置后的模型蒸馏(知识提取)也可以达到一定效果,但是最终达不到原始大模型的精度。

   

4.3 年龄和性别属性

图7 年龄和性别属性数据集上模型蒸馏的实验结果

从图7中可以看出,teacher model是一个具有12个卷积层CaffeNetConv网络,年龄和性别属性的精度都挺高分别为0.912161和0.98991。而student model分成了三种。第一种为 具有6个卷积层的CaffeNetConv网络,仅对年龄的属性进行模型蒸馏(知识提取);第二种还是具有12个卷积层的CaffeNetConv,但是其每层的num_output的大小减半,仅对年龄的属性进行模型蒸馏(知识提取);第三种的网络结构和第二种的网络结构一致,并且对年龄和性别两个属性同时进行模型蒸馏(知识提取)。其中12layers_half表示的是只对age进行模型蒸馏,12layers_half_both表示的是对age和gender同时进行模型蒸馏。

从图中可以看出:

  • 在设置特定的T和lamda的超参时,train_smallnet_from_kd的结果都要优于train_smallnet_from_scratch的结果。这个和mnist的实验结论一样。
  • 6层小网络模型在经过多组对比实验中都不能达到原始12层大网络的效果,然而在保持原始深度12层,将参数量减半的小网络在经过模型蒸馏(知识提取)后,却可以达到甚至超过原始12层大网络的精度。比如,当lamda设置为0.3,T设置为3的时候的第二种小网络以及lamda设置为0.1,T设置为3的时候的第三种小网络。
  • 对于当前的实验对于年龄以及性别属性数据集可以看出,最优值的基本上是在T超参数设置为3左右。
  • 从第二种小网络的实验结果可以看出,随着相应T和lamda超参数的设定,会使得年龄属性的精度上升,但是相反会导致未进行模型蒸馏的另一个分类任务精度的下降。
  • 从第三种小网络的实验结果可以看出,将年龄和性别的分类任务都进行模型蒸馏,可以解决上一个问题,性别精度的都有所提高,但是有可能会使得年龄的精度有些许下降(很少)。

原本的年龄属性识别网络中就已经引入了label smooth的思想,这个和模型蒸馏(知识提取)的思想很类似,所以本身年龄属性识别模型蒸馏的效果可能会被弱化。

5.总结

   在经过三个数据集上对模型蒸馏(知识提取)的方法进行实验,都表明模型蒸馏方法的有效性。当原始模型精度很高的时候,模型蒸馏的效果往往都会很好,并且在特定的模型T和lamda超参数的组合下,小的student model可以很好的学习到大的teacher model,甚至会超过原始大的网络的精度。相反,当原始teacher model的精度就不高,如cifar10中的实验效果一样,模型蒸馏的效果要差些,可能达不到原始teacher model的精度,甚至差距还挺大。特别的在属性数据集上的实验中可以看出,同样参数量的两种student网络,保持深度缩小宽度的小网络要比缩小深度保持宽度的小网络模型蒸馏的效果会更好。特别的有当原始的teacher model是一个多分类任务的时候,我们的实验表明如果仅对一个任务进行模型蒸馏,会使得其他分类任务的精度下降,而对多个分类任务都进行模型蒸馏的话,虽然没有单独模型蒸馏的效果那么好,但是所有模型的精度都会上升。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/751273.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

串口通信例子SeriaPort

本篇例子使用的虚拟串口转自这位博主:http://t.csdnimg.cn/LSGIs 串口COM: 是一种用于联接计算机和外设设备的接口,也叫串行接口,简称com,常见的串口有一般电脑应用的RS-232(使用25帧或者9帧的连接器) 通俗来讲串口就是usb接口、鼠标窗口。键…

第二十四节:带你梳理Vue2 : Vue具名插槽/作用域插槽/v-slot指令

1. 具名插槽 1.1 没有使用具名插槽的问题 有的时候我们在使用子组件时,在子组件模板上不同的位置插入不同的内容, 只有一个插槽显然没法满足我们的需求,看示例: 需求如下: 子组件是一篇文章的结构父组件在调用子组件是给文章插入标题,正文,时间信息 示例代码如下: <di…

6.26.4.1 基于交叉视角变换的未配准医学图像多视角分析

1. 介绍 许多医学成像任务使用来自多个视图或模式的数据&#xff0c;但很难有效地将这些数据结合起来。虽然多模态图像通常可以在神经网络中作为多个输入通道进行配准和处理&#xff0c;但来自不同视图的图像可能难以正确配准(例如&#xff0c;[2])。因此&#xff0c;大多数多视…

创新实训博客(十三)——admin前端工作效果

管理/教师端前端工作汇总education-admin&#xff1a; 首先是登录注册页面的展示 管理员 首页 管理员登录后的首页如下图所示 管理员拥有所有的权限 课程管理 1、可以查看、修改、增添、删除课程列表内容 2、可以对课程资源进行操作 3、可以对课程的类别信息进行管理&…

一个最简单的MySQL事务模拟测试

这里只是简单写了一个转账的小事务&#xff0c;模拟一下事务的过程 代码&#xff1a; 初始数据&#xff1a; 当你关闭自动提交 并且开启一个事务执行了下面的更新语句 但是没有提交时&#xff1a; 此时虽然你运行查询语句会发现他的值发生了变化 &#xff0c;但是当你运行回滚…

51单片机看门狗定时器配置

测试环境 单片机型号&#xff1a;STC8G1K08-38I-TSSOP20&#xff0c;其他型号请自行测试&#xff1b; IDE&#xff1a;KEIL C51&#xff1b; 寄存器配置及主要代码 手册中关于看门狗的寄存器描述如下&#xff1a; 启动看门狗&#xff0c;需将B5位EN_WDT置1即可&#xff0c;…

大数据------额外软件、插件及技术------Linux(完整知识点汇总)

Linxu 不同领域的主流操作系统 桌面操作系统 WindowsMAac OSLinux 服务器端操作系统 UNIX&#xff08;付费&#xff09;LinuxWindows Server&#xff08;付费&#xff09; 移动设备操作系统 Android&#xff08;基于Linux开源&#xff09;IOS&#xff08;不开源&#xff09; 嵌…

时间序列分析入门:概念、模型与应用【ARMA、ARIMA模型】

在这篇博客中&#xff0c;我们将全面探讨时间序列分析的基本概念和分类&#xff0c;深入理解平稳性及其检验方法&#xff0c;并介绍自回归模型&#xff08;AR&#xff09;、滑动平均模型&#xff08;MA&#xff09;、自回归滑动平均模型&#xff08;ARMA&#xff09;以及自回归…

动态流体工厂大屏

目录 一 设计原型 二 后台源码 一 设计原型 二 后台源码 namespace 动态流体工厂大屏 {public partial class Form1 : Form{public Form1(){InitializeComponent();}private void Form1_Load(object sender, EventArgs e){Task.Run(() >{while (true){this.Invoke(() >…

openEuler搭建hadoop Standalone 模式

Standalone 升级软件安装常用软件关闭防火墙修改主机名和IP地址修改hosts配置文件下载jdk和hadoop并配置环境变量配置ssh免密钥登录修改配置文件初始化集群windows修改hosts文件测试 1、升级软件 yum -y update2、安装常用软件 yum -y install gcc gcc-c autoconf automake…

模块化沙箱的优势与应用

在数字化时代&#xff0c;数据安全已成为企业乃至国家层面不可忽视的重要议题。随着云计算、大数据等技术的广泛应用&#xff0c;数据泄露、恶意攻击等安全威胁日益严峻。在这样的背景下&#xff0c;模块化沙箱技术应运而生&#xff0c;为企业提供了高效、灵活的数据安全解决方…

NAND闪存巨头铠侠(Kioxia)计划最迟于10月下旬通过首次公开募股IPO

据路透社于6月26日引用消息来源的报道&#xff0c;在半导体市场条件反弹及财务业绩迅速改善的背景下&#xff0c;NAND闪存巨头铠侠&#xff08;Kioxia&#xff09;正准备尽快提交初步申请&#xff0c;并计划最迟于10月下旬通过首次公开募股&#xff08;IPO&#xff09;在东京证…

【Hive中常见的优化手段----数据采集!Join 优化!Hive索引!数据倾斜!mapreduce本地模式!map和reduce数量调整!】

前言&#xff1a; &#x1f49e;&#x1f49e;大家好&#xff0c;我是书生♡&#xff0c;今天主要和大家分享一下Hive中常见的优化手段----数据采集&#xff01;常见的Join 优化有哪几种&#xff01;什么是Hive索引&#xff01;数据怎么发生倾斜&#xff01;什么是mapreduce的本…

Pycharm 文件标头设置

一、设置模板步骤&#xff1a; “文件File--设置Settings--编辑器Editor--File and Code Templates- Python Script” 里面设置模板 官方预设变量表 变量名 含义 ${DATE} 当前系统日期 ${DAY} 当前月的第几日 ${DAY_NAME_SHORT} 当前星期几的单词缩写&#xff08…

Vue2配置前端代理

在8080向5000请求数据 clivue2 一、cli内配置前端代理 1、使用 发送请求时写8080 在配置文件中配置 vue.config.js 2、缺点 无法配置多个代理无法控制某个请求知否要代理 二、方式二 module.exports {devServer: {proxy: {/api1:{ //匹配所有以/api1开头的请求路径…

向量化算法 doc2vec

第1关&#xff1a;认识 Doc2vec Doc2vec 算法简介 Doc2vec 又叫做 Paragraph2vec&#xff0c; Sentence embeddings&#xff0c;是一种非监督式算法&#xff0c;可以获得句子、段落、文档的向量表达&#xff0c;是 Word2vec 的拓展。学出来的向量可以通过计算距离来找句子、段…

华为笔记本电脑d盘数据丢失:原因、恢复方案与防范建议

华为笔记本电脑以其高性能和稳定的品质赢得了众多用户的青睐&#xff0c;但即使是如此优质的设备&#xff0c;也难免遭遇数据丢失的困境。本文将围绕华为笔记本电脑D盘数据丢失这一问题&#xff0c;探讨其常见原因、恢复方案&#xff0c;并提出未来防范的建议&#xff0c;以帮助…

Go 延迟调用 defer

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

Transformer教程之循环神经网络(RNN)和长短期记忆网络(LSTM)

在当今人工智能和深度学习的世界中&#xff0c;Transformer模型已经成为了主流。然而&#xff0c;回顾过去&#xff0c;循环神经网络&#xff08;RNN&#xff09;和长短期记忆网络&#xff08;LSTM&#xff09;在序列数据处理上也曾风靡一时。本文将详细讲解RNN和LSTM的原理、应…

FPC板设计

在板框属性里面选择FPC软板&#xff1a; FPC补强为什么要比焊盘单边大1mm&#xff1a;补强区域需比焊盘大1.0mm以上&#xff0c;才能有效保护焊盘与线路交接处不断裂 补强板放在功能面的背面&#xff1a; 、金手指厚度计算工具&#xff1a;https://tools.jlc.com/jlcTools/#/ca…