模型减肥秘籍:模型压缩技术 知识蒸馏

教程链接:模型减肥秘籍:模型压缩技术-课程详情 | Datawhale

知识蒸馏:让AI模型更轻更快

在人工智能快速发展的今天,我们经常需要在资源受限的设备(如手机、IoT设备)上运行AI模型。但这些设备的计算能力和内存都很有限,无法直接运行庞大的AI模型。这就带来了一个重要问题:如何将大模型的能力迁移到小设备上?知识蒸馏(Knowledge Distillation)就是解决这个问题的重要技术之一。

什么是知识蒸馏?

知识蒸馏可以形象地理解为"教师教学生"的过程。大模型(教师模型)将自己学到的"知识"传授给小模型(学生模型),帮助小模型在保持较小体积的同时,获得接近大模型的性能。

这里的"知识"主要包括:

  • 模型的输出概率分布(软标签)
  • 模型中间层的特征
  • 注意力图等信息

知识蒸馏的核心概念

1. 软标签与硬标签

  • 硬标签:传统的分类标签,比如[0,1,0]表示第二类
  • 软标签:模型输出的概率分布,比如[0.1,0.8,0.1],包含更丰富的信息

2. 温度参数

温度参数用于调节概率分布的"软硬程度":

  • 温度越高,分布越平滑
  • 温度越低,分布越接近硬标签
  • 合适的温度可以帮助学生模型更好地学习

下面是一个例子:当输入一张马的图片时,对于未调整温度(默认为1)的 Softmax 输出,正标签的概率接近 1,而负标签的概率接近 0。这种尖锐的分布对学生模型不够友好,因为它只提供了关于正确答案的信息,而忽略了错误答案的信息。即驴比汽车更像马,识别为驴的概率应该大于识别为汽车的概率。而通过温度调整后, 最后得到一个相对平滑的概率分布, 称为 “软标签” (Soft Label)。

知识蒸馏的不同方式

1. 基于输出的蒸馏

直接匹配教师模型和学生模型的输出概率分布。

2. 基于中间层特征的蒸馏

匹配模型中间层的特征,让学生模型学习教师模型的"思考过程"。

3. 基于中间层注意力图的蒸馏

传递模型的注意力机制,帮助学生模型知道"该关注什么"。

4.基于中间层权重的蒸馏

5.基于中间层稀疏模式的蒸馏

6.基于中间相关信息的蒸馏

创新的蒸馏方法

1. 自蒸馏

模型自己当老师,通过多次迭代提升性能,不需要额外的教师模型。

2. 在线蒸馏

教师模型和学生模型同时训练,相互学习,提高效率。

3.结合在线蒸馏和自蒸馏

实际应用场景

知识蒸馏在多个领域都有成功应用:

1. 目标检测

不仅传递分类知识,还包括物体定位信息。

2. 语义分割

通过像素级、成对和整体三个层面的蒸馏提升性能。

3. 生成对抗网络(GAN)

结合蒸馏、重构和对抗性损失实现模型压缩。

4. 自然语言处理

特别强调注意力机制的传递,提升文本处理能力。

网络增强:另一种思路

除了传统的知识蒸馏,网络增强(NetAug)提供了一个新视角:

  • 不是简化大模型,而是增强小模型
  • 将小模型嵌入到大模型中学习
  • 通过多重监督提升性能

代码实践

主要包含:

KD知识蒸馏        DKD解耦知识蒸馏

其区别主要集中在损失函数的不同。

现有的知识蒸馏方法主要关注于中间层的深度特征蒸馏,而对logit蒸馏的重要性认识不足。[DKD]()重新定义了传统的知识蒸馏损失函数,将其分解为目标类知识蒸馏(TCKD)和非目标类知识蒸馏(NCKD)。

- 目标类知识蒸馏(TCKD):关注于目标类的知识传递。

- 非目标类知识蒸馏(NCKD):关注于非目标类之间的知识传递。

# kd_loss
def loss(logits_student, logits_teacher, temperature):
    log_pred_student = F.log_softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    loss_kd = F.kl_div(log_pred_student, pred_teacher, reduction="none").sum(1).mean()
    loss_kd *= temperature**2
    return loss_kd
import torch
import torch.nn as nn
import torch.nn.functional as F


def dkd_loss(logits_student, logits_teacher, target, alpha, beta, temperature):
    # 使用 _get_gt_mask 和 _get_other_mask 函数创建掩码,分别用于标识真实标签和其他类别。这使得损失计算可以选择性地关注特定类别。
    gt_mask = _get_gt_mask(logits_student, target)
    other_mask = _get_other_mask(logits_student, target)
    pred_student = F.softmax(logits_student / temperature, dim=1)
    pred_teacher = F.softmax(logits_teacher / temperature, dim=1)
    # 使用 cat_mask 函数将掩码应用于学生和教师的预测,得到只关注特定类别的输出。
    pred_student = cat_mask(pred_student, gt_mask, other_mask)
    pred_teacher = cat_mask(pred_teacher, gt_mask, other_mask)
    log_pred_student = torch.log(pred_student)
    # 计算针对真实标签的 KL 散度损失(tckd_loss),并进行温度缩放
    tckd_loss = (
        F.kl_div(log_pred_student, pred_teacher, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    # 计算针对其他类别的 KL 散度损失(nckd_loss),通过从 logits 中减去一个大的值(1000.0)来忽略真实标签的影响。
    pred_teacher_part2 = F.softmax(
        logits_teacher / temperature - 1000.0 * gt_mask, dim=1
    )
    log_pred_student_part2 = F.log_softmax(
        logits_student / temperature - 1000.0 * gt_mask, dim=1
    )
    nckd_loss = (
        F.kl_div(log_pred_student_part2, pred_teacher_part2, size_average=False)
        * (temperature**2)
        / target.shape[0]
    )
    # 原论文中这里加入了一个 WarmUP
    return alpha * tckd_loss + beta * nckd_loss


def _get_gt_mask(logits, target):
    # 生成一个与 logits 形状相同的全零张量,并在真实标签对应的位置设置为 1,最终返回一个布尔掩码。这个掩码用于在损失计算中关注真实类别。
    target = target.reshape(-1)
    mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool()
    return mask


def _get_other_mask(logits, target):
    # 生成一个与 logits 形状相同的全一张量,并在真实标签对应的位置设置为 0,最终返回一个布尔掩码。这个掩码用于在损失计算中关注其他类别。
    target = target.reshape(-1)
    mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool()
    return mask


def cat_mask(t, mask1, mask2):
    # 将输入张量 t 与两个掩码结合,计算出只关注特定类别的输出。
    # 由于 mask1 只保留真实类别的概率,因此这个求和操作给出了每个样本的真实类别的总概率。
    t1 = (t * mask1).sum(dim=1, keepdims=True)
    t2 = (t * mask2).sum(1, keepdims=True)
    rt = torch.cat([t1, t2], dim=1)
    return rt

完整代码:

  • KD知识蒸馏
  • DKD解耦知识蒸馏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/922450.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

golang实现TCP服务器与客户端的断线自动重连功能

1.服务端 2.客户端 生成服务端口程序: 生成客户端程序: 测试断线重连: 初始连接成功

React表单联动

Ant Design 1、dependencies Form.Item 可以通过 dependencies 属性,设置关联字段。当关联字段的值发生变化时,会触发校验与更新。 一种常见的场景:注册用户表单的“密码”与“确认密码”字段。“确认密码”校验依赖于“密码”字段&#x…

springboot实战(16)(Validation参数校验冲突问题、分组校验、默认分组)

目录 一、注解NotNull与NotEmpty区别。 二、Validation提供的分组校验。(参数校验冲突问题) (1)基本介绍。 (2)实际案例。 (3)大模型提问提供的方法。 1、定义分组接口。 2、在字段上…

学Linux的第九天--磁盘管理

目录 一、磁盘简介 (一)、认知磁盘 (1)结构 (2)物理设备的命名规则 (二)、磁盘分区方式 MBR分区 MBR分区类型 扩展 GPT格式 lsblk命令 使用fdisk管理分区 使用gdisk管理分…

【ubuntu+win】Win10+Ubuntu22.04双系统给ubuntu系统中的某个分区进行扩容(从400G->800G)数据无损坏

给ubuntu已分区的部分进行扩容 1. 准备扩容的空间2.进入ubuntu系统进行卸载分区3.安装图形界面的安装包4.进行对分区扩容5. 重新挂载 我的情况是这式的(可以不看,直接看后面的): 刚开始买下电脑的时候,只装了一个 1T 的…

流式上传与分片上传的原理与实现

🚀 博主介绍:大家好,我是无休居士!一枚任职于一线Top3互联网大厂的Java开发工程师! 🚀 🌟 在这里,你将找到通往Java技术大门的钥匙。作为一个爱敲代码技术人,我不仅热衷…

Ettus USRP X410

总线连接器: 以太网 RF频率范围: 1 MHz 至 7.2 GHz GPSDO: 是 输出通道数量: 4 RF收发仪瞬时带宽: 400 MHz 输入通道数量: 4 FPGA: Zynq US RFSoC (ZU28DR) 1 MHz to 7.2 GHz,400 MHz带宽,GPS驯服OCXO,USRP软件无线电设备 Ettus USRP X410集…

oracle 19c RAC到单机ogg部署安装

源端(RAC)目标端(FS)IP192.168.40.30/31192.168.40.50数据库版本Oracle 19.3.0Oracle 19.3.0主机名hfdb30/hfdb31hfogg操作系统REHL7.6REHL7.6数据库实例hfdb1/hfdb2hfogg同步用户hfdb1hfdb1同步表testtestOGG版本19.1.0.0.419.1.…

现代密码学

概论 计算机安全的最核心三个关键目标(指标)/为:保密性 Confidentiality、完整性 Integrity、可用性 Availability ,三者称为 CIA三元组 数据保密性:确保隐私或是秘密信息不向非授权者泄漏,也不被非授权者使…

QT QGridLayout控件 全面详解

本系列文章全面的介绍了QT中的57种控件的使用方法以及示例,包括 Button(PushButton、toolButton、radioButton、checkBox、commandLinkButton、buttonBox)、Layouts(verticalLayout、horizontalLayout、gridLayout、formLayout)、Spacers(verticalSpacer、horizonta…

Adobe Illustrator 2024 安装教程与下载分享

介绍一下 下载直接看文章末尾 Adobe Illustrator 是一款由Adobe Systems开发的矢量图形编辑软件。它广泛应用于创建和编辑矢量图形、插图、徽标、图标、排版和广告等领域。以下是Adobe Illustrator的一些主要特点和功能: 矢量绘图:Illustrator使用矢量…

IDEA2023设置控制台日志输出到本地文件

1、Run->Edit Configurations 2、选择要输出日志的日志,右侧,IDEA2023的Logs在 Modify option 里 选中就会展示Logs栏。注意一定要先把这个日志文件创建出来,不然不会自动创建日志文件的 IDEA以前版本的Logs会直接展示出来 3、但是…

[UE5学习] 一、使用源代码安装UE5.4

一、简介 本文介绍了如何使用源代码安装编译UE5.4,并且新建简单的项目,打包成安卓平台下的apk安装包。 二、使用源代码安装UE5.4 注意事项: 请保证可以全程流畅地科学上网。请保证C盘具有充足的空间。请保证接下来安装下载的visual studi…

细说敏捷:敏捷四会之standup meeting

上一篇文章中,我们讨论了 敏捷四会 中 冲刺计划会 的实施要点,本篇我们继续分享敏捷四会中实施最频繁,团队最容易实施但往往也最容易走形的第二个会议:每日站会 关于每日站会的误区 站会是一个比较有标志性的仪式活动&#xff0…

10M和100M网口的编码及EMC影响

10M网口编码技术 10M网口,即10Base-T,采用的是曼彻斯特编码方法 。在这种编码中,“0”由“”跳变到“-”,而“1”由“-”跳变到“” 。这种编码方式的特点是信号的DC平衡,即信号在任何一段时间内的平均电压为零&#…

docker基本使用

参考视频: 参考视频https://www.bilibili.com/video/BV1e64y1F7pJ/?share_sourcecopy_web&vd_source8fc0c76c477d3db71f89fa5ae5b258c7 docker容器操作: 拉取镜像: 拉取官网ubuntu镜像 sudo docker pull ubuntu 运行镜像:…

音频信号采集前端电路分析

音频信号采集前端电路 一、实验要求 要求设计一个声音采集系统 信号幅度:0.1mVpp到1Vpp 信号频率:100Hz到16KHz 搭建一个带通滤波器,滤除高频和低频部分 ADC采用套件中的AD7920,转换率设定为96Ksps ;96*161536 …

构建高效在线教育:SpringBoot课程管理系统

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理在线课程管理系统的相关信息成为必然。开发…

【云计算网络安全】解析 Amazon 安全服务:构建纵深防御设计最佳实践

文章目录 一、前言二、什么是“纵深安全防御”?三、为什么有必要采用纵深安全防御策略?四、以亚马逊云科技为案例了解纵深安全防御策略设计4.1 原始设计缺少安全策略4.2 外界围栏构建安全边界4.3 访问层安全设计4.4 实例层安全设计4.5 数据层安全设计4.6…

基于LiteFlow的风控系统指标版本控制

个人博客:无奈何杨(wnhyang) 个人语雀:wnhyang 共享语雀:在线知识共享 Github:wnhyang - Overview 更新日志 最近关于https://github.com/wnhyang/coolGuard此项目更新了如下内容:https://g…