论文解读:Masked Generative Distillation

文章汇总

话题

知识蒸馏

创新点

带掩盖的生成式蒸馏

方法旨在通过学生的遮罩特征来生成老师的特征(通过遮盖学生部分的特征来生成老师的特征),来帮助学生获得更好的表现

输入:老师:T,学生:S,输入:x,标签:y,超参数:\alpha,\lambda

1:使用S得到输入x的特征fea^S和输出\hat{y}

2:使用T得到输入x的特征fea^T

3:计算模型的原始损失:L_{original}(\hat{y},y)

4:计算公式5中的蒸馏损失:

其中:

G表示投影层,包括两个卷积层:W_{l1}W_{l2},一个激活层ReLU。在本文中,我们采用1×1卷积层为适配层f_{align}, 3×3为投影层W_{l1}W_{l2}的卷积层。

5:使用L_{all}=L_{original}+\alpha*L_{dis}更新S

输出:S

想改进的地方

摘要

知识蒸馏已成功地应用于各种任务中。目前的蒸馏算法通常通过模仿老师的输出来提高学生的表现。本文表明,教师还可以通过引导学生特征恢复来提高学生的代表性。从这个角度来看,我们提出了掩膜生成蒸馏(mask Generative Distillation, MGD),它很简单:我们掩膜学生特征的随机像素,并通过一个简单的块强制其生成教师的完整特征。MGD是一种真正通用的基于特征的蒸馏方法,可用于各种任务,包括图像分类、目标检测、语义分割和实例分割。我们用大量的数据集对不同的模型进行了实验,结果表明所有的学生都取得了很好的进步。值得注意的是,我们将ResNet-18的ImageNet顶级1精度从69.90%提高到71.69%,将ResNet-50主干的RetinaNet从37.4提高到41.0 Boundingbox mAP,基于ResNet-50的SOLO从33.1提高到36.2 Mask mAP,以及基于ResNet-18的DeepLabV3从73.20提高到76.02 mIoU。我们的代码可在https://github.com/yzd-v/MGD上获得。

关键词:知识蒸馏,图像分类,目标检测,语义分割,实例分割

介绍

深度卷积神经网络(cnn)已广泛应用于各种计算机视觉任务中。一般来说,较大的模型具有较好的性能,但推理速度较低,难以在有限的源下部署。为了克服这一问题,知识蒸馏被提出[18]。按蒸馏类型可分为两种。第一种是专门为不同的任务而设计的,例如用于分类的基于logit的蒸馏[18,40]和用于检测的基于head的蒸馏[10,39]。第二种是基于特征的蒸馏[28,17,4]。由于各种网络之间只有头部或投影仪后的特征是不同的,从理论上讲,基于特征的蒸馏方法可以可用于各种任务。然而,为特定任务设计的蒸馏方法通常不适用于其他任务。例如,OFD[17]和KR[4]对探测器的改进有限。FKD[37]和FGD[35]是专门为探测器设计的,由于缺乏颈部,无法用于其他任务。

以往基于特征的提炼方法,由于教师的特征具有更强的表征能力,通常会让学生尽可能地模仿教师的输出。然而,我们认为没有必要直接模仿老师来提高学生特征的表征能力。用于蒸馏的特征通常是通过深度网络获得的高阶语义信息。特征像素已经在一定程度上包含了相邻像素的信息。因此,如果我们可以使用部分像素通过一个简单的块来还原教师的全部特征,那么这些使用的像素的代表性也可以得到提高。从这个角度出发,我们提出了一种简单有效的基于特征的蒸馏方法——掩膜生成蒸馏(mask Generative Distillation, MGD)。

如图2所示,我们首先对学生特征的随机像素进行遮罩,然后通过一个简单的块将遮罩后的特征生成教师的完整特征。由于每次迭代都使用随机像素,因此在整个训练过程中都会使用所有像素,这意味着特征将更加鲁棒,并且其表示能力将得到提高。在我们的方法中,老师只是引导学生还原特征,并不要求学生直接模仿

为了验证我们的假设,即在不直接模仿教师的情况下,掩蔽特征生成可以提高学生的特征表征能力,我们从学生和教师的颈部对特征注意力进行了可视化。如图1所示,学生和教师的特征有很大的不同。

与教师相比,学生特征的背景有更高的反应。教师的mAP也显著高于学生,为41.0比37.4。采用最先进的蒸馏方法FGD蒸馏后[35],迫使学生用心模仿老师的特征,学生的特征与老师的特征更加相似,mAP很大提高到40.7。而经过MGD培训后,学生与教师的特征仍有显著差异,但学生对背景的反应却大大降低。令我们惊讶的是,该学生的成绩超过了FGD,甚至达到了与老师相同的mAP。这也说明用MGD训练可以提高学生特征的表征能力。此外,我们还在图像分类和密集预测任务上做了大量的实验。结果表明,MGD对图像分类、目标检测、语义分割和实例分割等任务都有较大的改善。MGD还可以与其他基于logit或基于head的蒸馏方法相结合,以获得更大的性能收益。综上所述,本文的贡献有:

1. 我们提出了一种新的基于特征的知识提炼方法,使学生利用被掩盖的特征来生成教师的特征,而不是直接模仿教师的特征。

2. 本文提出了一种新的基于特征的蒸馏方法——掩膜生成蒸馏,该方法简单易用,只需要两个超参数。

3. 我们通过在不同数据集上的大量实验验证了我们的方法在各种模型上的有效性。对于图像分类和密集预测任务,学生在MGD的帮助下都取得了显著的进步。

相关工作

面向分类的知识蒸馏

知识蒸馏最早是由Hinton等人[18]提出的,其中学生受到来自教师最后一个线性层的标签和软标签的监督。

然而,除了logit之外,更多的蒸馏方法是基于特征映射的。FitNet[28]从中间层提取语义信息。AT[36]总结了跨渠道维度的价值,并将注意力知识转移给学生。OFD[17]提出了余量ReLU,并设计了一个测量蒸馏距离的新函数。CRD[30]利用对比学习将知识传递给学生。最近,KR[4]建立了一个审查机制,并利用多层次信息进行蒸馏。SRRL[33]将表示学习和分类解耦,利用老师的分类器来训练学生的倒数第二层特征。WSLD[40]从偏方差权衡的角度提出了加权的蒸馏软标签。

面向语义分割的知识蒸馏

Liu等人[23]提出了成对和整体蒸馏,在学生和教师的输出之间执行成对和高阶一致性。他等人[16]将教师网络的输出重新解释为一个重新表示的潜在域,并从教师网络中捕获长期依赖关系。CWD[29]最小化了概率图之间的Kullback-Leibler (KL)散度,该散度是通过对每个通道的激活图进行归一化计算得到的。

方法

对于不同的任务,模型的体系结构差别很大。此外,大多数蒸馏方法都是为特定任务而设计的。然而,基于特征的精馏可以同时应用于分类和密集预测。特征蒸馏的基本方法可表述为:

式中,F^TF^S分别表示教师和学生的特征,f_{align}是将学生的特征F^S与教师的特征F^T

对齐的适应层。C, H, W表示特征映射的形状。

这种方法有助于学生直接模仿老师的特征。然而,我们提出了掩蔽生成蒸馏(MGD),其目的是迫使学生产生教师的特征,而不是模仿它,给学生带来了分类和密集预测方面的显着改善。MGD的架构如图2所示,我们将在本节中专门介绍它。

带掩盖特征的生成

对于基于cnn的模型,更深层的特征具有更大的接受域和更好的原始输入图像表征。换句话说,特性的图像素在一定程度上已经包含了相邻像素的信息。

因此,我们可以使用部分像素来恢复完整的特征映射。我们的方法旨在通过学生的遮罩特征来生成老师的特征(通过遮盖学生部分的特征来生成老师的特征),这样可以帮助学生获得更好的表现。
我们用T^l \in R^{C \times H \times W},S^l \in R^{C \times H \times W}(l=1,...,L)表示分别为教师和学生的第l个特征图。首先我们设置第l个随机掩码来覆盖学生的第l个特征,可以表示为:

其中R_{i,j}^l为(0,1)中的随机数,i,j分别为特征图的横坐标和纵坐标。λ是表示掩码比的超参数。第
l个特征映射被第l个随机掩码覆盖。

然后我们使用相应的掩码覆盖学生的特征图,并尝试用左边的像素生成教师的特征图,可以表示为:

G表示投影层,包括两个卷积层:W_{l1}W_{l2},一个激活层ReLU。在本文中,我们采用1×1卷积层为适配层f_{align}, 3×3为投影层W_{l1}W_{l2}的卷积层。

根据该方法,我们设计了MGD的蒸馏损失L_{dis}:

其中L为蒸馏层数和,C、H、W为特征映射的形状。S和T分别表示学生和教师的特征。

总体损失

利用提出的MGD蒸馏损失L_{dis},我们用总损失训练所有模型如下:

其中L_{original}为所有任务中模型的原始损失,α为平衡损失的超参数。

MGD是一种简单有效的蒸馏方法,可方便地应用于各种任务。算法1总结了我们的方法的过程。

方法过程汇总

算法1:带掩盖的生成式蒸馏

输入:老师:T,学生:S,输入:x,标签:y,超参数:\alpha,\lambda

1:使用S得到输入x的特征fea^S和输出\hat{y}

2:使用T得到输入x的特征fea^T

3:计算模型的原始损失:L_{original}(\hat{y},y)

4:计算公式5中的蒸馏损失:L_{dis}(fea^S, fea^T)

5:使用L_{all}=L_{original}+\alpha*L_{dis}更新S

输出:S

主要实验

MGD是一种基于特征的蒸馏,可以很容易地应用于各种任务的不同模型。在本文中,我们对分类、目标检测、语义分割和实例分割等任务进行了实验。我们针对不同的任务使用不同的模型和数据集进行实验,所有模型都通过MGD获得了出色的改进。

分类

数据集

对于分类任务,我们在包含1000个对象类别的ImageNet[11]上评估我们的知识蒸馏方法。我们用120万张图片进行训练,用5万张图片进行测试,完成所有的分类实验。我们用准确性来评价模型。

实现细节

对于分类任务,我们计算来自主干的最后一个特征映射的蒸馏损失。有关消融的研究见5.5节。MGD使用超参数α来平衡方程6中的蒸馏损失。另一个超参数λ用于调整公式2中的屏蔽比。所有分类实验均采用超参数{α = 7 × 10^(−5),λ = 0.5}。我们使用SGD优化器训练所有模型100个epoch,其中动量为0.9,权重衰减为0.0001。我们初始化学习率为0.1,并每30次衰减一次。此设置基于8个gpu。实验采用基于Pytorch[26]的MMClassification[6]和MMRazor[7]进行。

分类结果

我们用两种常用的蒸馏设置进行实验,包括均相蒸馏和非均相蒸馏。

第一个蒸馏设置是从ResNet-34[15]到ResNet-18,另一个设置是从ResNet-50到MobileNet[19]。如表1所示,我们比较了各种知识蒸馏方法[18、36、17、25、30、4、40、33],包括基于特征的方法、基于逻辑的方法和结合的方法。使用我们的方法,学生ResNet-18和MobileNet的Top-1准确率分别提高了1.68和3.14。此外,如上所述,MGD只需要计算特征图上的蒸馏损失,并且可以与其他基于逻辑的图像分类方法相结合。因此,我们尝试在WSLD中加入基于logit的蒸馏损失[40]。这样,两位同学的Top-1准确率分别达到了71.80和72.59,分别提高了0.22和0.24。

表1。不同蒸馏方法在ImageNet数据集上的结果。T和S分别表示老师和学生。

目标检测和实例分割

数据集

我们在COCO2017数据集[22]上进行实验,该数据集包含80个对象类别。我们使用120k的训练图像进行训练,5k的val图像进行测试。用平均精度对模型的性能进行了评价。

表2。不同蒸馏方法在COCO上的目标检测结果。

实现细节

我们从颈部计算所有特征映射的蒸馏损失。我们采用超参数{α = 2 × 10^(−5),λ = 0.65}对所有的单阶段模型,{α = 5 × 10^(−7),λ = 0.45}对所有的两阶段模型。我们使用SGD优化器训练所有模型,其中动量为0.9,权重衰减为0.0001。除非特别说明,否则我们训练模型为24个epoch。我们使用继承策略[20,35],用教师的颈部和头部参数初始化学生,在头部结构相同的情况下训练学生。实验采用MMDetection进行[2]。

目标检测和实例分割结果

对于目标检测,我们在三种不同类型的检测器上进行了实验,包括两级检测器(Faster RCNN[27])、基于锚点的一级检测器(RetinaNet[21])和无锚点的一级检测器(RepPoints[34])。

我们将MGD与最近三种最先进的检测器蒸馏方法进行比较[37,29,35]。以分割为例,我们在SOLO[32]和Mask RCNN[14]两个模型上进行了实验。如表2和表3所示,我们的方法在两种目标检测和实例分割方面都优于其他最先进的方法。学生在MGD的帮助下获得了显著的AP改善,例如基于ResNet-50的retanet和SOLO在COCO数据集上分别获得了3.6个Boundingbox mAP和3.1个Mask mAP的改善。

表3。不同蒸馏方法对实例分割的结果。MS的意思是多尺度训练。这里的AP指掩模AP。

参考资料

论文下载(ECCV 2区 2022)

https://arxiv.org/pdf/2205.01529.pdf

📎Masked Generative Distillation.pdf

代码地址

GitHub - yzd-v/MGD: Masked Generative Distillation (ECCV 2022)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/391916.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

多模态学习综述(MultiModal Learning)

最早开始关注到多模态机器学习是看到Jeff Dean在2019年年底NeurIPS大会上的一个采访报道,讲到了2020年机器学习趋势:多任务和多模态学习将成为突破口。 Jeff Dean 谈2020年机器学习趋势:多任务和多模式学习将成为突破口 站在2022年&#xff…

HMI界面:感官与体验俱佳的智能家居界面分享

Hello,我是大千UI工场,本期分享HMI人机交互界面在智能家居领域的案例,关注大千,学习N多UI干货,有设计需求,可以联络。 设计感官和体验俱佳智能家居的UI界面时,可以考虑以下几个方面:…

一起学量化之Aroon指标

Aroon指标是由Tushar Chande于1995年开发的技术分析工具,旨在识别股票是否处于趋势中及趋势的强度。它通过分析股票价格在一定周期内创下的新高和新低来预测趋势的变化,这基于一种观念:强势趋势通常伴随着频繁的新高或新低。 1. Aroon指标的组成 Aroon指标由两个部分组成:…

【Linux内核】从0开始入门Linux Kernel源码

🌈 博客个人主页:Chris在Coding 🎥 本文所属专栏:[Linux内核] ❤️ 前置学习专栏:[Linux学习]从0到1 ⏰ 我们仍在旅途 ​ 目录 …

【测试】自动化

目 录 一.什么是自动化二.自动化测试分类三.selenium工具(web自动化测试工具)四.环境部署五.什么是驱动?1.常见的元素操作2.窗口3.执行脚本4.等待5.浏览器的操作6.弹窗7.选择器8.文件上传9.浏览器参数 一.什么是自动化 有效的减少人力的消耗&#xff0…

基于PPNSA+扰动算子的车间调度最优化matlab仿真,可以任意调整工件数和机器数,输出甘特图

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于PPNSA扰动算子的车间调度最优化matlab仿真,可以任意调整工件数和机器数,输出甘特图和优化收敛曲线。 2.测试软件版本以及运行结果展示 MATLAB2022a版本运行…

[力扣 Hot100]Day27 合并两个有序链表

题目描述 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 出处 思路 简单题,两个指针就能解决。 代码 class Solution { public:ListNode* mergeTwoLists(ListNode* list1, ListNode* list2) {if(!list1)…

蓝桥杯嵌入式学习记录——按键的使用

目录 一、按键原理简介 二、cubeMX的配置 三、按键的短按代码 四、按键的长按代码 一、按键原理简介 在STM32中,按键连接通常使用GPIO(通用输入/输出)端口来实现。当按键未被按下时,GPIO端口处于高电平状态(即1&am…

CSS 评分器星星效果

<template><view class="rating"><!-- 5颗星 --><input value="5" name="rating" id="star5" type="radio"><label for="star5"></label><!-- 4颗星 --><input val…

每日一练:LeeCode-530、二叉搜索树的最小绝对差【二叉搜索树+pre辅助节点+DFS】

本文是力扣LeeCode-530、二叉搜索树的最小绝对差【二叉搜索树pre辅助节点DFS】 学习与理解过程&#xff0c;本文仅做学习之用&#xff0c;对本题感兴趣的小伙伴可以出门左拐LeeCode。 给你一个二叉搜索树的根节点 root &#xff0c;返回 树中任意两不同节点值之间的最小差值 。…

Docker容器运行

1、通过--name参数显示地为容器命名&#xff0c;例如:docker run --name “my_http_server” -d httpd 2、容器重命名可以使用docker rename。 3、两种进入容器的方法&#xff1a; 3.1、Docker attach 例如&#xff1a; 每间隔一秒打印”Hello World”。 Sudo docker run…

杂题——字符串——试题 算法训练 二元函数

分析&#xff1a; 关键在于&#xff0c;如果处理输入的字符串成表达式字符串分三种情况&#xff1a; 如果 S 中只包含一个整数&#xff0c;则该整数就是表达式 S 的值&#xff1b;如果 S 中包含 f 函数&#xff0c;则递归计算 f 函数的参数&#xff0c;并将计算结果代入 f 函数…

第三百五十一回

文章目录 1. 概念介绍2. 获取方法3. 示例代码4. 对比与总结4.1 横向对比4.2 内容总结 我们在上一章回中介绍了"如何获取当前系统语言"相关的内容&#xff0c;本章回中将介绍获取当前时区.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我们使用的北京…

线程池 ThreadPool

文章目录 线程池一、线程池概述1、什么是线程池&#xff1f;2、为什么需要线程池&#xff1f;3、线程池的优势4、基本原理 二、线程池相关接口与方法1、Executor2、ExecutorService3、ScheduledExecutorService4、Runnable & Callable5、Future & FutureTask6、execute…

永久内核映射

内存就像墙上的气球&#xff0c;32体系就好比是小孩&#xff0c;64位体系好比是大人。对于位置比较低的气球&#xff0c;抬抬脚就可以够到&#xff0c;这些气球相当于DMA内存(可用于设备直接内存访问)&#xff1b;位置再高一点的气球&#xff0c;小孩伸手可以够到&#xff0c;这…

AutoKeras(Python自动化机器学习)多模态数据和多任务

要点拓扑 AutoKeras 拓扑 要点 常规机器学习&#xff1a;scikit-learn示例探索性数据分析和数据预处理&#xff0c;线性回归&#xff0c;决策树图像分类ResNet模型示例&#xff0c;合成数据集DenseNet模型示例绘图线性回归和决策树模型使用Python工具seaborn、matplotlib、pan…

import tensorflow_hub报错

问题&#xff1a; 导入tensorflow_hub报ModuleNotFoundError: No module named ‘tensorflow.python.checkpoint’ 解决&#xff1a; tensorflow-estimator版本不对 和tensorflow&#xff08;2.6.0&#xff09;版本一致 。 pip install -U tensorflow-estimator2.6.0 验证&a…

一个收集了大量的C#/.NET/.NET Core项目宝库组织

项目宝库介绍 为.NET开发者提供一个寻找优秀C#/.NET/.NET Core项目和框架的入口&#xff0c;通过了解和对比更多的项目和框架来选择最适合我们自己学习、工作开发的一套项目或者框架。优秀的项目不应该被埋没&#xff0c;欢迎大家一起加入这个组织共同完善、发展.NET社区&…

线程和进程【并发和并行、线程上下文切换、线程的状态】

线程和进程【并发和并行、线程上下文切换、线程的状态】 什么是并发与并行&#xff1f;什么是线程上下文切换&#xff1f;线程状态&#xff1a;一个线程的一生 转自 极客时间 进程&#xff1a;是指内存中运行的一个应用程序&#xff0c;每个进程都有自己独立的内存空间&#x…

RapidMiner数据挖掘2 —— 初识RapidMiner

本节由一系列练习与问题组成&#xff0c;这些练习与问题有助于理解多个基本概念。它侧重于各种特定步骤&#xff0c;以进行直接的探索性数据分析。因此&#xff0c;其主要目标是测试一些检查初步数据特征的方法。大多数练习都是关于图表技术&#xff0c;通常用于数据挖掘。 为此…