一.是什么?
把一个大的模型(定义为教师模型)萃取,蒸馏,把它浓缩到小的模型(定义为学生模型)。
即:大的神经网络把他的知识教给了小的神经网络。
二.为什么要用知识蒸馏把大模型学习到的东西迁移到小模型呢呢?
因为大的模型很臃肿,而真正落地的终端算力有限,比如手表,安防终端。
所以要把大模型变成小模型,把小模型部署到终端上。
2.1 轻量化网络的方向
分为下面四个方向,知识蒸馏是第一个方向。
三.用蒸馏温度处理学生网络的标签
学生网络的输入是教师网络的输出
3.1 soft target
soft target使我们常用的概率版的标签值。比如输入下面的图片预测。
hard targets和soft targets的预测概率如下:
hard targets的预测结果不科学,因为马和驴比马和汽车相似的多。所以驴和汽车都是0,没有表现出这个信息,所以要用soft targets.
3.2 用教师网络预测出的soft target作为学生网络的标签。
教师网络预测出的soft target具有很多信息。
3.3 蒸馏温度
softmax有放大差异的功能。
如果值高那么一点点,经过softmax的放大就会变得很高。
如果想让soft target更加平缓,高的降低,低的升高。
这时就要对soft target使用蒸馏温度。 让soft target更soft。
实现方法是在softmax的分母处加个T。
效果如下:T=1时相当于没有蒸馏温度。T=3时确实低的更低高的更高了。
T和分布的关系如下图,T从1增加到10,值之间的差异越来越小,softmax的放大效果被冲淡。
当T=100的时候,结果直接变成一个横线,众生平等。
四.知识蒸馏训练过程
4.1 图示知识蒸馏训练过程
上面是已经训练好的教师网络。
把数据输入到教师网络,在输出时使用蒸馏温度为T的softmax.
再把数据输入到学生网络,学生网络可能是还没有训练的网络,也可能是训练一半的半成品网络。
4.2 损失函数
学生网络既要在蒸馏温度等于T时与教师网络的结果相接近。
也要保证不使用蒸馏温度时的结果与真实结果相接近。
蒸馏损失:
把教师网络使用蒸馏温度为t的输出结果 与 学生网络蒸馏温度为t的输出结果做损失。
让这个损失越小越好。
学生损失:
学生网络蒸馏温度为1(即不使用蒸馏网络)时的预测结果和真实的标签做loss.
最后对这两项加权求和。
4.3 图解损失函数计算过程
红色线条指向的是学生损失。
紫色线条指向的是蒸馏损失。
五.推理过程
此时学生网络已经训练好,把X输入到学生网络得到结果。
六.最终效果:
学生网络可以接近教师网络的识别效果,并且附加如下两个特点:
1.零样本识别
论文里面说:以手写体数字数据集为例,假如在训练学生网络时把标签为3的类别全部去掉,
但是教师网络学过3。当使用知识蒸馏将教师网络学到的东西迁移到学生网络时,学生网络虽然没有见过3,但是却能识别3,即达到了零样本的效果。
2.使用soft target训练而不是hard target,减少了过拟合
第二行和第三行是使用百分之3的训练样本并分别用hard target和soft target,结果显示
使用3%的训练样本 + hard target :
训练集的准确率为 67.3%, 测试集的准确率为44.5%。
使用3%的训练样本 + soft target :
训练集的准确率为 65.4%, 测试集的准确率为57.5%。
七.迁移学习和知识蒸馏的区别
迁移学习是把一个模型学习的领域泛化到另一个领域,比如把猫狗这些动物域迁移到医疗域。
知识蒸馏是把一个模型的知识迁移到另一个模型上。
八.参考视频
B站UP主,同济子豪兄的视频:
【精读AI论文】知识蒸馏
https://www.bilibili.com/video/BV1gS4y1k7vj/?spm_id_from=333.788&vd_source=ebc47f36e62b223817b8e0edff181613