教程链接:模型减肥秘籍:模型压缩技术-课程详情 | Datawhale
知识蒸馏:让AI模型更轻更快
在人工智能快速发展的今天,我们经常需要在资源受限的设备(如手机、IoT设备)上运行AI模型。但这些设备的计算能力和内存都很有限,无法直接运行庞大的AI模型。这就带来了一个重要问题:如何将大模型的能力迁移到小设备上?知识蒸馏(Knowledge Distillation)就是解决这个问题的重要技术之一。
什么是知识蒸馏?
知识蒸馏可以形象地理解为"教师教学生"的过程。大模型(教师模型)将自己学到的"知识"传授给小模型(学生模型),帮助小模型在保持较小体积的同时,获得接近大模型的性能。
这里的"知识"主要包括:
- 模型的输出概率分布(软标签)
- 模型中间层的特征
- 注意力图等信息
知识蒸馏的核心概念
1. 软标签与硬标签
- 硬标签:传统的分类标签,比如[0,1,0]表示第二类
- 软标签:模型输出的概率分布,比如[0.1,0.8,0.1],包含更丰富的信息
2. 温度参数
温度参数用于调节概率分布的"软硬程度":
- 温度越高,分布越平滑
- 温度越低,分布越接近硬标签
- 合适的温度可以帮助学生模型更好地学习
下面是一个例子:当输入一张马的图片时,对于未调整温度(默认为1)的 Softmax 输出,正标签的概率接近 1,而负标签的概率接近 0。这种尖锐的分布对学生模型不够友好,因为它只提供了关于正确答案的信息,而忽略了错误答案的信息。即驴比汽车更像马
,识别为驴的概率应该大于识别为汽车的概率。而通过温度调整后, 最后得到一个相对平滑的概率分布, 称为 “软标签” (Soft Label)。
知识蒸馏的不同方式
1. 基于输出的蒸馏
直接匹配教师模型和学生模型的输出概率分布。
2. 基于中间层特征的蒸馏
匹配模型中间层的特征,让学生模型学习教师模型的"思考过程"。
3. 基于中间层注意力图的蒸馏
传递模型的注意力机制,帮助学生模型知道"该关注什么"。
4.基于中间层权重的蒸馏
5.基于中间层稀疏模式的蒸馏
6.基于中间相关信息的蒸馏
创新的蒸馏方法
1. 自蒸馏
模型自己当老师,通过多次迭代提升性能,不需要额外的教师模型。
2. 在线蒸馏
教师模型和学生模型同时训练,相互学习,提高效率。
3.结合在线蒸馏和自蒸馏
实际应用场景
知识蒸馏在多个领域都有成功应用:
1. 目标检测
不仅传递分类知识,还包括物体定位信息。
2. 语义分割
通过像素级、成对和整体三个层面的蒸馏提升性能。
3. 生成对抗网络(GAN)
结合蒸馏、重构和对抗性损失实现模型压缩。
4. 自然语言处理
特别强调注意力机制的传递,提升文本处理能力。
网络增强:另一种思路
除了传统的知识蒸馏,网络增强(NetAug)提供了一个新视角:
- 不是简化大模型,而是增强小模型
- 将小模型嵌入到大模型中学习
- 通过多重监督提升性能
代码实践
主要包含:
KD知识蒸馏 DKD解耦知识蒸馏
其区别主要集中在损失函数的不同。
现有的知识蒸馏方法主要关注于中间层的深度特征蒸馏,而对logit蒸馏的重要性认识不足。[DKD]()重新定义了传统的知识蒸馏损失函数,将其分解为目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。
- 目标类知识蒸馏(TCKD):关注于目标类的知识传递。
- 非目标类知识蒸馏(NCKD):关注于非目标类之间的知识传递。
# kd_loss
def loss(logits_student, logits_teacher, temperature):
log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
loss_kd *= temperature**2
return loss_kd
import torch
import torch.nn as nn
import torch.nn.functional as F
def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
# 使用 _get_gt_mask 和 _get_other_mask 函数创建掩码,分别用于标识真实标签和其他类别。这使得损失计算可以选择性地关注特定类别。
gt_mask = _get_gt_mask(logits_student, target)
other_mask = _get_other_mask(logits_student, target)
pred_student = F.softmax(logits_student / temperature, dim=1)
pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
# 使用 cat_mask 函数将掩码应用于学生和教师的预测,得到只关注特定类别的输出。
pred_student = cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
log_pred_student = torch.log(pred_student)
# 计算针对真实标签的 KL 散度损失(tckd_loss),并进行温度缩放
tckd_loss = (
F.kl_div(log_pred_student, pred_teacher, size_average=False)
* (temperature**2)
/ target.shape[0]
)
# 计算针对其他类别的 KL 散度损失(nckd_loss),通过从 logits 中减去一个大的值(1000.0)来忽略真实标签的影响。
pred_teacher_part2 = F.softmax(
logits_teacher / temperature - 1000.0 * gt_mask, dim=1
)
log_pred_student_part2 = F.log_softmax(
logits_student / temperature - 1000.0 * gt_mask, dim=1
)
nckd_loss = (
F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
* (temperature**2)
/ target.shape[0]
)
# 原论文中这里加入了一个 WarmUP
return alpha * tckd_loss + beta * nckd_loss
def _get_gt_mask(logits, target):
# 生成一个与 logits 形状相同的全零张量,并在真实标签对应的位置设置为 1,最终返回一个布尔掩码。这个掩码用于在损失计算中关注真实类别。
target = target.reshape(-1)
mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
return mask
def _get_other_mask(logits, target):
# 生成一个与 logits 形状相同的全一张量,并在真实标签对应的位置设置为 0,最终返回一个布尔掩码。这个掩码用于在损失计算中关注其他类别。
target = target.reshape(-1)
mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
return mask
def cat_mask(t, mask1, mask2):
# 将输入张量 t 与两个掩码结合,计算出只关注特定类别的输出。
# 由于 mask1 只保留真实类别的概率,因此这个求和操作给出了每个样本的真实类别的总概率。
t1 = (t * mask1).sum(dim=1, keepdims=True)
t2 = (t * mask2).sum(1, keepdims=True)
rt = torch.cat([t1, t2], dim=1)
return rt
完整代码:
- KD知识蒸馏
- DKD解耦知识蒸馏