知识蒸馏详解及pytorch官网demo案例

知识蒸馏Knowledge Distillation(KD)

1、简介

一种模型压缩方法

知识蒸馏的一般框架(如下图)
三部分:知识、蒸馏算法、师生架构。
知识蒸馏的师生架构

知识

将知识分为三种形式:基于响应的(response-based)、基于特征的(feature-based)、基于关系的(relation-based)。
在这里插入图片描述

①基于响应的知识(response-based)【常用】
学习的知识是教师模型最后一个输出层logits。由于logits实际上是类别概率分布,因此基于响应的知识蒸馏限制在监督学习
在这里插入图片描述

最流行的基于响应的图像分类知识被称为软目标(soft target)

基于响应的知识蒸馏具体架构如下图。后面具体介绍该类知识蒸馏。
在这里插入图片描述
②基于特征的知识(feature-based)
学习的知识是教师模型中间层的基于特征的知识。下图为基于特征的知识蒸馏模型的通常架构。
在这里插入图片描述

③基于关系的知识(relation-based)
基于响应和基于特征的知识都使用了教师模型中特定层的输出,基于关系的知识进一步探索了不同层或数据样本的关系。下图为实例关系的知识蒸馏架构。

在这里插入图片描述

蒸馏机制

根据教师模型是否与学生模型同时更新,知识蒸馏的学习方案可分为离线(offline)蒸馏、在线(online)蒸馏、自蒸馏(self-distillation)

离线蒸馏(常用)
在离线蒸馏中,学生模型仅使用知识进行训练,而不与教师模型同时更新。学生模型独立地使用知识进行训练,目标是使学生模型的输出尽可能接近教师模型的输出。
大多数之前的知识蒸馏方法都是离线的。最初的知识蒸馏中,知识从预训练的教师模型转移到学生模型中,整个训练过程包括两个阶段:1)大型教师模型蒸馏前在训练样本训练;2)教师模型以logits(基于响应,生成软目标(soft target))或中间特征(基于特征)的形式提取知识,将其在蒸馏过程中指导学生模型的训练。

在线蒸馏
在线蒸馏时,教师模型和学生模型同步更新,而整个知识蒸馏框架都是端到端可训练的。
在线蒸馏是一种具有高效并行计算的单阶段端到端训练方案。然而,现有的在线方法(如相互学习)通常无法解决在线环境中的高容量教师,这使进一步探索在线环境中教师和学生模式之间的关系成为一个有趣的话题。

自蒸馏
在自蒸馏中,教师和学生模型使用相同的网络,这可以看作是在线蒸馏的一个特例。
在这里插入图片描述
从人类师生学习的角度可以直观地理解离线、在线和自蒸馏。
离线蒸馏是指知识渊博的教师教授学生知识;
在线蒸馏是指教师和学生一起学习;
自我蒸馏是指学生自己学习知识。

师生架构

教师模型(cumbersome model):已经训练好的,较为笨重的模型。
学生模型:通过蒸馏,将教师模型中已经学习到的知识迁移到的新的轻量级的模型。


2、学生模型的训练(基于响应的离线知识蒸馏)

hard target(硬目标)与 soft target(软目标)

hard target仅包含正样本信息
soft target具有更多信息,不仅包含正样本信息,还有相似负样本信息,比如左图的正样本标签为2,但由于写法与3相像,因此对标签3也给予一定的关注通过增大概率值;而右图的正样本标签2写法与7相像,因此对标签7也给予一定的关注。
具体到代码中就是加入蒸馏温度T。

在这里插入图片描述

蒸馏温度 T T T

原来的softmax 将多分类的输出结果映射为概率值。 q i = e z i ∑ j = 1 n e z j q_i=\frac{e^{z_i}}{\sum_{j=1}^n{e^{z_j}}} qi=j=1nezjezi,其中 z i z_i zi是模型的softmax层输出logits。

在进行知识蒸馏时,如果将教师模型的softmax输出,作为学生模型的 s o f t − t a r g e t soft-target softtarget,那么负标签的值接近于0,对学生模型的损失函数贡献非常小,使得模型难以利用教师模型学到的知识。因此,提出蒸馏温度T的概念,使得softmax是输出更加平滑。

加入蒸馏温度 T T T后的softmax
q i = e ( z i / T ) ∑ j = 1 n e ( z j / T ) q_i=\frac{e^{(z_i/T)}}{\sum_{j=1}^n{e^{(z_j/T)}}} qi=j=1ne(zj/T)e(zi/T)

实验:当温度 T T T越高时,负标签的概率值的变化。

在这里插入图片描述正标签为第1个元素,当温度 T T T越高时,负标签的概率值相对被放得越大。在训练时,由于损失函数的惩罚,模型需要对负标签给予一定的关注;从而达到在学习老师模型时,一次训练不仅仅可以学到正样本的特征,也可以学到相似负样本的特征。

import numpy as np

def softmax(x):
    x_exp = np.exp(x)
    return x_exp/x_exp.sum()

def softmax_t(x, T):
    # T是蒸馏温度
    x_exp = np.exp(x/T)
    return x_exp/x_exp.sum()

output = np.array([5, 1.3, 2])

print('temperature is 5: ', softmax_t(output, 5))
print('temperature is 10: ', softmax_t(output, 10))
print('temperature is 100: ', softmax_t(output, 100))

在这里插入图片描述

知识蒸馏训练的具体步骤

①训练好Teacher模型
②利用高温 T h i g h T_{high} Thigh产生 s o f t − t a r g e t soft-target softtarget
③使用{ s o f t − t a r g e t , T h i g h soft-target, T_{high} softtarget,Thigh}和{ h a r d − t a r g e t , T = 1 hard-target, T=1 hardtarget,T=1},同时训练 Student 模型
④设置蒸馏温度 T = 1 T=1 T=1,Student模型线上做推理

高温蒸馏过程的损失函数

学生损失函数student loss即, L h a r d = − ∑ j = 1 n l j l o g ( q j ) , q i = e z i ∑ j = 1 n e z j L_{hard}=-\sum_{j=1}^nl_jlog(q_j),q_i=\frac{e^{z_i}}{\sum_{j=1}^n{e^{z_j}}} Lhard=j=1nljlog(qj)qi=j=1nezjezi
蒸馏损失函数distillation loss即, L s o f t = − ∑ j = 1 n p j T l o g ( q j T ) , p i T = e ( v i / T ) ∑ j = 1 n e ( v j / T ) , q i T = e ( z i / T ) ∑ j = 1 n e ( z j / T ) L_{soft}=-\sum_{j=1}^np_j^Tlog(q_j^T),p_i^T=\frac{e^{(v_i/T)}}{\sum_{j=1}^n{e^{(v_j/T)}}},q_i^T=\frac{e^{(z_i/T)}}{\sum_{j=1}^n{e^{(z_j/T)}}} Lsoft=j=1npjTlog(qjT)piT=j=1ne(vj/T)e(vi/T)qiT=j=1ne(zj/T)e(zi/T)

高温蒸馏过程的损失函数定义为: L = α L s o f t + β L h a r d L=\alpha L_{soft}+\beta L_{hard} L=αLsoft+βLhard
其中, l i l_i li为第i个ground truth值, z i z_i zi为学生模型的第i个输出logits值, v i v_i vi为老师模型的第i个输出logits值, α \alpha α β \beta β为超参数。

在这里插入图片描述

pytorch官网 知识蒸馏demo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/504151.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

pytest--python的一种测试框架--pytest常用断言类型

一、pytest常用断言类型 等于: 不等于&#xff1a;&#xff01; 大于&#xff1a;> 小于&#xff1a;< 属于&#xff1a;in 不属于&#xff1a;not in 大于等于&#xff1a;> 小于等于&#xff1a;< 是&#xff1a;is 不是&#xff1a;is not def test_two():ass…

酷得单片机方案 2.4G儿童遥控漂移车

电子方案开发定制&#xff0c;我们是专业的 东莞酷得智能单片机方案之2.4G遥控玩具童车具有以下比较有特色的特点&#xff1a; 1、内置充电电池&#xff1a;这款小车配备了可充电的电池&#xff0c;无需频繁更换电池&#xff0c;既环保又方便。充电方式可能为USB充电或者专用…

LATTICE进阶篇DDR2--(0)获取ddr2 IP核

前言 想要仿真lattice的DDR2由来已久&#xff0c;但苦于对其了解甚少&#xff0c;在查阅过很多资料后&#xff0c;终于对这个IP核的仿真有了一些了解。 现做一些总结&#xff0c;以备不时之需&#xff0c;也让有需要的朋友&#xff0c;少走一些弯路。 环境&#xff1a;win10…

算法学习——LeetCode力扣动态规划篇5

算法学习——LeetCode力扣动态规划篇5 198. 打家劫舍 198. 打家劫舍 - 力扣&#xff08;LeetCode&#xff09; 描述 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统…

通知中心架构:打造高效沟通平台,提升信息传递效率

随着信息技术的快速发展&#xff0c;通知中心架构作为一种关键的沟通工具&#xff0c;正逐渐成为各类应用和系统中必不可少的组成部分。本文将深入探讨通知中心架构的意义、设计原则以及在实际场景中的应用。 ### 什么是通知中心架构&#xff1f; 通知中心架构是指通过集中管…

信息学奥赛一本通T1268-完全背包问题

solution1 二维形式 #include<iostream> #include<algorithm> using namespace std; const int maxn 35, maxv 210; int w[maxn], c[maxn], dp[maxn][maxv]; int main(){int n, m;scanf("%d%d", &m, &n);for(int i 1; i < n; i){scanf(&…

电脑win10系统更新后很卡怎么办,win10电脑更新完系统特别卡

更新或者升级win10系统后发现电脑变卡了,这是什么原因呢?如果电脑硬件不是特别差,那么可以按照下面的方法来缓解卡顿,因为可能是内存不足所引起的,试试清理更新缓存和禁用开机启动项。但如果是硬件较低或者太老旧,并且本身的内存就很小的话,那么建议你还是升级硬件吧。下…

.NET 开发支持技术路线 .Net 7 将停止支持

.NET 开发技术路线图 微软方面强调&#xff0c;使用 .NET 7 的应用程序将在支持结束后继续运行&#xff0c;但用户可能无法获得 .NET 7 应用程序的技术支持。他们不会继续为 .NET 7 发布新的安全更新&#xff0c;用户可能会面临安全漏洞问题。 开发人员必须使用 .NET 8 SDK 构建…

Windows提权!!!

之前讲过一下提权&#xff0c;但是感觉有点不成体系&#xff0c;所以我们就成体系的来讲一下这个操作系统的提权 目录 Windows的提权 1.Widnows的内核溢出提权 1.MSF自带的提权模块&#xff08;Win11都能提上来&#xff0c;有点牛逼&#xff09; 2.CS的插件提权 3.补丁对比…

毕设论文目录设置

添加目录 选择一种格式的自动目录 更新目录 发现该目录中只有1、2章&#xff0c;3、4章 然后再点击更新目录 对应的&#xff0c;小标题添加二级目录

基于JavaSpringMVC+Mybatis+Jquery高校毕业设计管理系统设计和实现

基于JavaSpringMVCMybatisJquery高校毕业设计管理系统设计和实现 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐…

【C语言】结构体详解(一)

目录 1、什么是结构体? 2、结构体成分 3、结构体变量的定义与初始化 3.1、结构体变量的三种定义方式 3.2、结构体变量的初始化 4、结构体成员的访问&#xff08;两种方式&#xff09; 4.1、直接访问 4.2、间接访问 5、结构的特殊声明 5.1、不完全声明&#xff08;匿…

医院陪诊管理系统(源码+文档)

TOC) 文件包含内容 1、搭建视频 2、流程图 3、开题报告 4、数据库 5、参考文献 6、服务器接口文件 7、接口文档 8、任务书 9、功能图 10、环境搭建软件 11、十六周指导记录 12、答辩ppt模板 13、技术详解 14、前端后台管理&#xff08;管理端程序&#xff09; 15、项目截图 1…

06-JavaScript DOM对象

1. 从ECMA到W3C 我们知道&#xff0c;ECMA定义的是js的变量语法等基础的标准规范&#xff0c;而W3C是针对浏览器API提出的规范&#xff0c; 所以我们要工作不可能只了解语法&#xff0c;我们的代码要在浏览器上跑起来就需要我们去了解W3C的标准。 那么W3C规定了哪一系列的的A…

深入PostgreSQL中的pg_global表空间

pg_global表空间的位置 在PG当中&#xff0c;一个实例(cluster)初始化完以后&#xff0c;你会看到有下边两个与表空间相关的目录生成&#xff1a; $PGDATA/base $PGDATA/global 我们再用元命令\db以及相关视图看看相应的表空间信息&#xff1a; postgres# \db …

28. UE5 RPG同步面板属性(四)

在前面几篇中&#xff0c;我们实现了以下步骤&#xff1a; 首先我们需要通过c去实现创建GameplayTag&#xff0c;这样可以在c和UE里同时获取到Tag创建一个DataAsset类&#xff0c;用于设置tag对应的属性和显示内容创建AttributeMenuWidgetController实现对应逻辑 上面几步在前…

MySQL数据库下,使页面传入的数据与保存的数据编码一致

一、查询当前MySQL数据库的编码 &#xff08;1&#xff09;登录MySQL数据库&#xff08;Windows系统&#xff09;&#xff1a;winR打开命令终端&#xff0c;cd到MySQL的bin目录&#xff0c;输入mysql -u root -p&#xff0c;回车后输入登录密码 &#xff08;2&#xff09;查看…

【C++】C++入门第一课(c++关键字 | 命名空间 | c++输入输出 | 缺省参数)

目录 前言 C关键字 命名空间 1.命名空间的定义 A.标准命名空间定义 B.命名空间允许嵌套定义 C.同名命名空间的合并 2.命名空间的使用 加命名空间名称及作用限定符 使用using将命名空间中某个成员引入 使用using namespace命名空间名称引入 C的输入和输出 缺省参数…

C++类基础5——拷贝构造函数,拷贝赋值运算符(复制构造函数,复制赋值运算符)

拷贝控制操作 当定义一个类时&#xff0c;我们显式地或隐式地指定在此类望的对象拷贝&#xff0c;移动、赋值和销毁时做什么。 一个类通定义五种特殊的成员函数来控制这些操作&#xff0c;包括&#xff1a;拷贝构造函数(copy consinuctor)、拷贝赋值运算符(copy-assignment op…

如何修复开机但不显示任何内容的计算机?这里提供详细步骤

前言​ 计算机“无法开机”的最常见方式是PC实际开机但在显示器上不显示任何内容。你看到电脑机箱上的灯,可能看到里面的风扇在转,甚至可能听到声音,但屏幕上什么也没有显示,请按照我们提供的顺序尝试以下常见修复方法。 测试显示器 在对计算机的其余部分进行更复杂和耗时…