「AI模型瘦身术」——知识蒸馏技术综述

使用KD原因

遇到问题:从产业发展的角度来看工业化将逐渐过渡到智能化,边缘计算逐渐兴起预示着 AI 将逐渐与小型化智能化的设备深度融合,这也要求模型更加的便捷、高效、轻量以适应这些设备的部署。

解决方案:知识蒸馏技术

知识蒸馏的关键点

如果回归机器学习最最基础的理论,我们可以很清楚地意识到一点(而这一点往往在我们深入研究机器学习之后被忽略): 机器学习最根本的目的在于训练出在某个问题上泛化能力强的模型。

泛化能力强: 在某问题的所有数据上都能很好地反应输入和输出之间的关系,无论是训练数据,还是测试数据,还是任何属于该问题的未知数据。

而现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。

而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。

一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“soft target”。

KD的训练过程和传统的训练过程的对比

传统training过程(hard targets): 对ground truth求极大似然

KD的training过程(soft targets): 用large model的class probabilities作为soft targets

KD的训练过程为什么更有效?

softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。

【举个例子】

在手写体数字识别任务MNIST中,输出类别有10个。

假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率为0.1,而其他负标签对应的值都很小,而另一个"2"更加形似"7","7"对应的概率为0.1。这两个"2"对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。

这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。 下图为知识蒸馏的通用形式。

知识传递形式

原始知识蒸馏(Vanilla Knowledge Distillation)仅仅是从教师模型输出的软目标中学习出轻量级的学生模型。

然而,当教师模型变得更深时,仅仅学习软目标是不够的。

因此,我们不仅需要获取教师模型输出的知识,还需要学习隐含在教师模型中的其它知识,比如有输出特征知识、中间特征知识、关系特征知识和结构特征知识。

标签知识是神经网络对样本数据最终的预测输出中包含的潜在信息,这也是目前蒸馏过程中最简单、应用最多的方式。

标签知识(输出特征知识)通常指的是教师模型的最后一层特征,主要包括逻辑单元和软目标的知识。标签知识(输出特征知识)知识蒸馏的主要思想是促使学生能够学习到教师模型的最终预测,以达到和教师模型一样的预测性能。

原始知识蒸馏是针对分类任务来提出的仅包含类间相似性的软目标知识,然而其它任务(如目标检测)网络最后一层特征输出中还可能包含有目标定位的信息。

换句话说,不同任务教师模型的最后一层输出特征是不一样的。因此,本文根据任务的不同对输 出特征知识分别进行归纳和分析,如表 1 所示。

Hinton 等人最早提出的知识蒸馏方法就属于目标分类的标签知识(输出特征知识)。由于经过“蒸馏温度”调节后的软标签中具有很多不确定信息,通常的研究认为这其中反映了样本间的相似度或干扰性、样本预测的难度,因此标签知识又被称为“暗知识”。

  • 为了有效地解决基于聚类的算法中的伪标签噪声的问题,Ge等人[45]利用“同步平均教学”的蒸馏框架进行伪标签优化,核心思想是利用更为鲁棒的“软”标签对伪标签进行在线优化。

  • MLP[46]提出了基于元学习(Meta - learning)自适应生成目标分布的方法,用于教师和学生模型的伪标签学习过程.利用一个筛选网络从目标检测模型预测的伪标签中区分出正例和负例,将正例用于下一阶段的半监督自训练过程,可以有效提升标签数据的利用率[43]。

  • Xie等人[4]利用有监督训练学生模型自身,在自蒸馏训练中额外地引入无标签噪声数据产生伪标签,将ImageNet的Top-1识别结果提高了约1%.对于标签知识蒸馏方法本身,已经有非常多的变体和应用,主要是从改进蒸馏过程、挖掘标签信息、去除干扰等方面,提升学生模型的性能.

  • Gao等人[47]实现了一种简单的逐阶段的标签蒸馏训练过程,在梯度下降训练过程中,每次只更新学生网络的一个模块,从前至后直到全部更新完成。

  • 根据Mirzadeh等人[48]的研究发现,并不是教师模型性能越高对于学生模型的学习越有利,当教师-学生模型之间的差距过大时,会导致学生难以从教师模型获得提升.为此,他们提出使用辅助教师策略来逐渐缩小教师和学生之间的学习差距,取得更好的蒸馏效果.

  • 同样是为了缩小教师 - 学生之间的学习差距,Yang等人[49]则提出利用教师模型在每个训练周期更新的中间模型产生的标签知识指导学生模型.为了充分挖掘标签信息、去除干扰,Müller等人[50]采用了子类别蒸馏方法,将原标签分组合并参与软标签蒸馏学习;

  • 文献[51]则研究了蒸馏损失函数对犔2范数和归一化的软标签的作用,提出使用球面空间度量蒸馏的方法去除范数的影响;

  • Zhang等人[52]关注了样本权重的影响,通过预测不确定性自适应分配样本权重,改善蒸馏过程;

  • Wu等人[53]提出了同伴协同蒸馏,通过训练多个分支网络并将其他训练较强教师的 logits 知识转移给同伴,有利于模型的稳定和提高蒸馏的质量。

最早使用教师模型中间特征知识的是 FitNets[27],其主要思想是促使学生的隐含层能预测出与教师隐含层相近的输出。

知识传递方式中有同构蒸馏和异构蒸馏,主要就是区分 是否:教师和学生模型的架构相似或属于同一系列的、层与层(Layer -to - Layer)或块与块(Block - to - Block)之间一一对应;不过通过这几年的实验来看,这并没有什么区别

不同知识传递形式的效果

如图所示,不同的知识传递形式,相比是有差异的,使用经典的KD标签知识是还不错的;使用特征间的,有较多都不如开山鼻祖KD;不过近期又有更多优化,比如使用互信息与对比学习的方法;

温度的特点

在回答这个问题之前,先讨论一下温度T的特点

  1. 原始的softmax函数是 𝑇=1 时的特例, 𝑇<1 时,概率分布比原始更“陡峭”, 𝑇1 时,概率分布比原始更“平缓”。

  2. 温度越高,softmax上各个值的分布就越平均(思考极端情况: (i) 𝑇=∞ , 此时softmax的值是平均分布的;(ii) 𝑇→0,此时softmax的值就相当于 𝑎𝑟𝑔𝑚𝑎𝑥 , 即最大的概率处的值趋近于1,而其他值趋近于0)

  3. 不管温度T怎么取值,Soft target都有忽略相对较小的 𝑝𝑖 携带的信息的倾向

温度代表了什么,如何选取合适的温度?

温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。

实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:

  1. 从有部分信息量的负标签中学习 --> 温度要高一些

  2. 防止受负标签中噪声的影响 -->温度要低一些

总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)

CRD 对比学习

首先 CRD是2020年提出的新模式的蒸馏方法,使用对比学习,在这年对比了12个KD方法都是最好的,其中,CRD+KD两个方法合在一起更好,相当于两个维度的知识传递的监督,在2023年有基于CRD实现的CRCD,效果好一点,方案是差不多的;

知识提炼(KD)将知识从一个深度学习模型(教师)转移到另一个深度学习模型(学生)。Hinton等人(2015)最初提出的目标是将教师和学生输出之间的KL差异最小化。当输出是一个分布,例如类上的概率质量函数时,该公式具有直观意义。然而,我们通常希望传递有关representation的知识。例如,在“跨模态蒸馏”问题中,我们可能希望将图像处理网络的表示转移到声音(Aytar等人,2016)或深度(Gupta等人,2016)处理网络,这样图像的深度特征和相关的声音或深度特征高度相关。在这种情况下,KL发散是不确定的。

表征知识是结构化的——维度表现出复杂的相互依赖性。最初的KD目标(Hinton等人,2015年)将所有维度视为独立的,以输入为条件。让yT成为老师的输出,yS成为学生的输出。那么原始的KD目标函数ψ具有全因子形式:. 这种带因素的目标不足以传递结构知识,即输出维度i和j之间的依赖关系。这与图像生成中的情况类似,在图像生成中,由于输出维度之间的独立性假设,L2目标会产生模糊的结果。

为了克服这个问题,我们想要一个目标,捕捉相关性和高阶输出依赖性。为了实现这一点,在本文中,我们利用了对比目标家族(Gutmann&Hyvärinen,2010;Oord等人,2018;Arora等人,2019;Hjelm等人,2018)。近年来,这些目标函数已成功地用于密度估计和表征学习,尤其是在自我监督环境中。在这里,我们让他们适应从一个深层网络到另一个深层网络的知识蒸馏任务。我们表明,致力于研究表现空间很重要,类似于最近的工作,如Zagoruyko和Komodakis(2016a);Remero等人(2014年)。然而,请注意,这些工作中使用的损失函数并没有明确尝试捕捉表征空间中的相关性或高阶相关性。

图1:我们考虑的三种提取设置:(a)压缩模型,(b)将知识从一种模式(例如RGB)转移到另一种模式(例如深度),(c)将网络集合提取到单个网络中。对比目标鼓励教师和学生将相同的输入映射到接近的表示(在某些度量空间中),并将不同的输入映射到遥远的表示,如阴影圈所示。

我们的目标是最大化教师和学生之间的互信息的下限。我们发现,这会在多个知识转移任务中产生更好的表现。我们推测,这是因为对比目标能更好地传递教师表征中的所有信息,而不仅仅是传递关于条件独立输出类概率的知识。有些令人惊讶的是,对比目标甚至改善了最初提出的提取类概率知识的任务的结果,例如,将大型CIFAR100网络压缩为性能几乎相同的较小网络。我们认为这是因为不同类别概率之间的相关性包含有用的信息,可以规范学习问题。我们的论文在两个主要独立发展的文献之间建立了联系:知识蒸馏和表征学习。这种联系使我们能够利用表征学习的强大方法,显著提高知识蒸馏的SOTA。

我们的贡献是:

1.基于对比的目标,用于在深度网络之间传递知识。

2.模型压缩、跨模态传输和整体蒸馏的应用。

3.对标12种最新蒸馏方法;CRD优于所有其他方法,例如,与原始KD相比,平均相对改善57%(Hinton等人,2015),令人惊讶的是,后者的表现次之。

这是近几年的得分,有使用crd结合其他损失的,可以在一些任务中得到较好表现,不同任务表现不一致,

多教师蒸馏

多教师蒸馏(Multi-Teacher Distillation)是一种知识蒸馏的方法,它通过同时蒸馏多个教师网络的知识来提升学生网络的性能。相比于传统的单一教师蒸馏,多教师蒸馏可以利用不同教师网络的多样性和丰富性,从而获得更全面的知识传递。

在多教师蒸馏中,通常包括一个学生网络(Student Network)和多个教师网络(Teacher Networks)。每个教师网络都是一个独立的模型,具有不同的架构或参数初始化。学生网络通过同时学习多个教师网络的知识来提高自己的性能。

多教师蒸馏的核心思想是将不同教师网络的预测结果作为辅助目标来训练学生网络。具体而言,多教师蒸馏包括以下步骤:

1、教师网络的训练:针对不同的教师网络,使用标准的监督学习方法进行训练,以获得具有丰富知识的教师模型。

2、教师网络的预测:使用已训练好的教师网络对输入样本进行预测,得到多个教师网络的预测结果。

3、学生网络的训练:将教师网络的预测结果作为辅助目标,与真实标签一起用于训练学生网络。通过最小化学生网络的预测与教师网络预测之间的差异,将教师网络的知识传递给学生网络。

4、蒸馏损失函数的定义:通常使用交叉熵损失函数来衡量学生网络的分类性能。同时,为了传递教师网络的知识,可以定义额外的辅助目标损失,如平均软标签损失(Mean Soft Labels Loss)或特定的蒸馏损失函数。

通过多教师蒸馏,学生网络能够从多个教师网络中获得更丰富的知识,并综合各个教师网络的预测结果来提高自己的性能。多教师蒸馏可以增强模型的泛化能力,减少过拟合问题,并在复杂任务中取得更好的性能表现。

好,接下来我们从源码分析;

蒸馏算法源码分析

KD

链接:https://arxiv.org/pdf/1503.02531.pd3f

发表:NIPS14

class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T #教师模型指导学生模型的程度(蒸馏温度),值越大,指导程度越高
 
    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s/self.T, dim=1)
        p_t = F.softmax(y_t/self.T, dim=1)
        
        #下面就是对两个模型的预测值,做KL散度的分布分析,如果偏差越大,则kl散度算出来的值越大。
        #p_t表示教师模型的目标值
        #p_s表示学生模型的预测值
        loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]
        return loss

核心就是一个kl_div函数,用于计算学生网络和教师网络的分布差异。输入为学生和教师模型的分类输出,经过温度可控的软化之后进行KL散度计算,简单直接粗暴有效;

FitNet

全称:Fitnets: hints for thin deep nets

链接:https://arxiv.org/pdf/1412.6550.pdf

发表:ICLR 15 Poster

很容易理解,方法使用特征间信息,对中间层进行蒸馏的开山之作,通过将学生网络的feature map扩展到与教师网络的feature map相同尺寸以后,使用均方误差MSE Loss来衡量两者差异

(1)大模型训练,小模型随机初始化

(2)将大模型特征提取器的第H层作为hint,从第一层到第H层的参数对应图(a)中Whint,,选择小模型特征提取器的第G层作为guided,从第一层到第G层对应图(a)中Wguided

(3)两者feature map大小可能不匹配,引入卷积层调整器(Wr)对guided层进行调整,对应图(b)

(4)优化均方损失函数

(5)对预训练好的小模型进行进一步知识蒸馏,对应图

 
class HintLoss(nn.Module):
    """Fitnets: hints for thin deep nets, ICLR 2015"""
    
    def __init__(self):
        super(HintLoss, self).__init__()
        self.crit = nn.MSELoss()  # 在这个类中,初始化函数中使用了nn.MSELoss(),即均方误差损失函数,
用于度量学生网络和教师网络之间的均方误差
 
'''
在前向传播函数中,接收学生网络的中间层表示f_s和教师网络的中间层表示f_t作为输入。
然后使用均方误差损失函数计算它们之间的差异,得到"hint"损失。
'''


    def forward(self, f_s, f_t):
        loss = self.crit(f_s, f_t)
        return loss
class ConvReg(nn.Module):
    """Convolutional regression for FitNet 用来对齐T-S某层feature map的特征尺寸 可学"""
    def __init__(self, s_shape, t_shape, use_relu=True):
        super(ConvReg, self).__init__()
        self.use_relu = use_relu
        s_N, s_C, s_H, s_W = s_shape
        t_N, t_C, t_H, t_W = t_shape
        if s_H == 2 * t_H:
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)
        elif s_H * 2 == t_H:
            self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)
        elif s_H >= t_H:
            self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))
        else:
            raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H))
        self.bn = nn.BatchNorm2d(t_C)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.use_relu:
            return self.relu(self.bn(x))
        else:
            return self.bn(x)

损失计算时,就先使用guided 网络处理完,送进fitloss算一次mse即可;

Fitloss 使用的特征维度做监督,效果没有kd好,可能是由于mse或者特征的提取选择不好,可以考虑多使用几个维度的特征监督;

PKT:Probabilistic Knowledge Transfer

全称:Probabilistic Knowledge Transfer for deep representation learning

链接:https://arxiv.org/abs/1803.10837

发表:CoRR18

提出一种概率知识转移方法,引入了互信息来进行建模。该方法具有可跨模态知识转移、无需考虑任务类型、可将手工特征融入网络等的优点。

 

class PKT(nn.Module):
    """Probabilistic Knowledge Transfer for deep representation learning
    Code from author: https://github.com/passalis/probabilistic_kt"""
    def __init__(self):
        super(PKT, self).__init__()
 
    def forward(self, f_s, f_t):
        return self.cosine_similarity_loss(f_s, f_t)
 
    @staticmethod
    def cosine_similarity_loss(output_net, target_net, eps=0.0000001):
        # Normalize each vector by its norm
        output_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))
        output_net = output_net / (output_net_norm + eps)
        output_net[output_net != output_net] = 0
 
        target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))
        target_net = target_net / (target_net_norm + eps)
        target_net[target_net != target_net] = 0
 
        # Calculate the cosine similarity
        model_similarity = torch.mm(output_net, output_net.transpose(0, 1))
        target_similarity = torch.mm(target_net, target_net.transpose(0, 1))
 
        # Scale cosine similarity to 0..1
        model_similarity = (model_similarity + 1.0) / 2.0
        target_similarity = (target_similarity + 1.0) / 2.0
 
        # Transform them into probabilities
        model_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)
        target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)
 
        # Calculate the KL-divergence
        loss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))
 
        return loss

这和PKT方法效果比KD好一些,主要是使用了概率传递学习先将教师和学生的网络输出进行标准化,再将输出的特征信息使用矩阵乘法、概率化方法映射到另一个空间,最后进行KL散度计算,就是在KD的基础上,将网络输出进行非线性映射成一个更简单的空间,监督这个空间下的S-T KL散度

CRD: Contrastive Representation Distillation

全称:Contrastive Representation Distillation

链接:https://arxiv.org/abs/1910.10699v2

发表:ICLR20

将对比学习引入知识蒸馏中,其目标修正为:学习一个表征,让正样本对的教师网络与学生网络尽可能接近,负样本对教师网络与学生网络尽可能远离。

构建的对比学习问题表示如下:

整体的蒸馏Loss表示如下:

实现如下:https://github.com/HobbitLong/RepDistiller

class ContrastLoss(nn.Module):
    """
    contrastive loss, corresponding to Eq (18)
    """
    def __init__(self, n_data):
        super(ContrastLoss, self).__init__()
        self.n_data = n_data
 
    def forward(self, x):
        bsz = x.shape[0]
        m = x.size(1) - 1
 
        # noise distribution
        Pn = 1 / float(self.n_data)
 
        # loss for positive pair
        P_pos = x.select(1, 0)
        log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()
 
        # loss for K negative pair
        P_neg = x.narrow(1, 1, m)
        log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()
 
        loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bsz
 
        return loss
        
class CRDLoss(nn.Module):
    """CRD Loss function
    includes two symmetric parts:
    (a) using teacher as anchor, choose positive and negatives over the student side
    (b) using student as anchor, choose positive and negatives over the teacher side
    Args:
        opt.s_dim: the dimension of student's feature
        opt.t_dim: the dimension of teacher's feature
        opt.feat_dim: the dimension of the projection space
        opt.nce_k: number of negatives paired with each positive
        opt.nce_t: the temperature
        opt.nce_m: the momentum for updating the memory buffer
        opt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim
    """
    def __init__(self, opt):
        super(CRDLoss, self).__init__()
        self.embed_s = Embed(opt.s_dim, opt.feat_dim)
        self.embed_t = Embed(opt.t_dim, opt.feat_dim)
        self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)
        self.criterion_t = ContrastLoss(opt.n_data)
        self.criterion_s = ContrastLoss(opt.n_data)
 
    def forward(self, f_s, f_t, idx, contrast_idx=None):
        """
        Args:
            f_s: the feature of student network, size [batch_size, s_dim]
            f_t: the feature of teacher network, size [batch_size, t_dim]
            idx: the indices of these positive samples in the dataset, size [batch_size]
            contrast_idx: the indices of negative samples, size [batch_size, nce_k]
        Returns:
            The contrastive loss
        """
        f_s = self.embed_s(f_s)
        f_t = self.embed_t(f_t)
        out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)
        s_loss = self.criterion_s(out_s)
        t_loss = self.criterion_t(out_t)
        loss = s_loss + t_loss
        return loss
 

他会在训练过程中,使用contrast-memory 来记忆网络的负样本,在网络训练中互信息监督;效果不错;

超分等生成任务与蒸馏

众所周知,图像/视频超分 (SR) 是工业界非常具有应用场景的应用,但能够生产具有良好视觉效果的重建图像的SR模型的参数量和运算量都非常巨大,比如业界公认的优秀baseline模型EDSR,EDVR等的算力需求高达几百,几千GFLOPs。而业界真正需求的轻量化模型,尤其是可以部署于移动端设备的实时模型,其算力限制可能严苛到小于10GFlops。

在high-level CV tasks上得到广泛应用和验证的模型剪枝、c馏方法应用到超分任务上,即将一个训练好的大模型进行裁剪,或者用性能较强的教师大模型蒸馏原本较弱的学生小模型,使裁剪/蒸馏后的小模型能够取得相比普通训练方式更好,甚至接近原先大模型的性能。这里的challenge在于,直接的迁移应用这些算法,在超分任务上无法得到有效的性能提升,甚至可能导致非常严重的performance degradation.

  • SRKD:它将最基本的知识蒸馏直接应用到图像超分中,整体思想分类网络中的蒸馏方式基本一致,整体来看属于应用形式;

  • FAKD:它在常规知识蒸馏的基础上引入了特征关联机制,进一步提升被蒸馏所得学生网络的性能,相比直接应用有了一定程度的提升;

  • PISR:它则是利用了广义蒸馏的思想进行超分网络的蒸馏,通过充分利用训练过程中HR信息的可获取性进一步提升学生网络的性能。

上图给出了SRKD的蒸馏示意图,它采用了最基本的知识蒸馏思想对老师网络与学生网络的不同阶段特征进行蒸馏。考虑到老师网络与学生网络的通道数可能是不相同的,SRKD则是对中间特征的统计信息进行监督。该文考虑了如下四种统计信息:

owards Compact Single Image Super-Resolution via Contrastive Self-distillation

链接:

code:GitHub - Booooooooooo/CSD: Towards Compact Single Image Super-Resolution via Contrastive Self-distillation, IJCAI21

发表:IJCAI21

团队:Yonsei University

1.背景

卷积神经网络在超分任务上取得了很好的成果,但是依然存在着参数繁重、显存占用大、计算量大的问题,为了解决这些问题,作者提出利用对比自蒸馏实现超分模型的压缩和加速。

我们的目标是同时压缩和加速SR模型。我们提出了一个简单的自蒸馏框架,其中学生网络通过在每层使用教师的部分通道从教师(目标)网络中分离出来。我们将这种学生网络称为信道分割超分辨率网络(CSSRNet)。教师网络和学生网络共同训练,形成两个计算方式不同的SR模型。根据设备中计算资源的不同,我们可以动态分配这两种模型,即在资源有限的设备中,如果超过所需的计算开销,则选择CSSR-Net,否则选择教师模型.

主要贡献

作者提出的对比自蒸馏(CSD)框架可以作为一种通用的方法来同时压缩和加速超分网络,在落地应用中的运行时间也十分友好。

自蒸馏被引用进超分领域来实现模型的加速和压缩,同时作者提出利用对比学习进行有效的知识迁移,从而 进一步的提高学生网络的模型性能。

在Urban100数据集上,加速后的EDSR+可以实现4倍的压缩比例和1.77倍的速度提高,带来的性能损失仅为0.13 dB PSNR。

2.方法

我们的CSD包括两个部分:CSSR-Net和对比损失(CL)。首先,我们描述了CSSR-Net。然后,我们给出了构造CSSR-Net的上界和下界的正则表达式。

最后,给出了CSD方案的总体损失函数,并用一种新的优化策略对其进行了求解。

总结

回顾

近年来,知识蒸馏(Knowledge Distillation)方法在深度学习领域中备受关注,它是一种模型压缩技术,旨在将一个复杂的模型(通常被称为教师模型)的知识转移到一个简化的模型(通常被称为学生模型)中,从而使学生模型能够在保持性能的同时具有更小的模型尺寸和计算成本。

一些近年来的知识蒸馏方法和拓展包括:

  1. Teacher-Student Architecture: 最常见的知识蒸馏方法之一是使用教师模型和学生模型之间的监督信号。教师模型通常是一个大型、复杂的模型,而学生模型则是一个较小、简化的模型。通过让学生模型学习教师模型的输出,学生模型可以在学习到教师模型的知识的同时获得更好的泛化性能。

  2. Soft Target Training: 传统的监督学习使用的是硬标签(one-hot编码),即只有正确类别的概率为1,其余为0。而软目标训练则使用教师模型的输出概率分布作为目标。这种方法能够提供更丰富的信息,使得学生模型可以学习到更多的知识。

  3. Attention Mechanisms: 在知识蒸馏中引入注意力机制可以帮助学生模型更好地关注教师模型的重要信息,从而提高模型性能。

  4. Self-Distillation: 自蒸馏是一种方法,其中学生模型在训练过程中不仅要学习来自教师模型的知识,还要学习自身的输出。这种方法可以进一步提高学生模型的性能,同时减少对教师模型的依赖。

  5. Multi-Teacher Distillation: 多教师蒸馏是一种将多个教师模型的知识融合到学生模型中的方法。每个教师模型可能具有不同的视角或专长,通过结合它们的知识,学生模型可以获得更全面和鲁棒的学习。

未来

随着深度学习模型的不断发展和复杂化,未来的知识蒸馏方法可能会涉及更复杂的模型结构。这可能包括对于更深、更宽的神经网络架构的探索,以及对于更复杂的模型组合和蒸馏技术的研究。例如,结合Transformer模型的自注意力机制与知识蒸馏技术可能会带来更加高效的模型压缩和知识传递方式。

其次,未来的知识蒸馏方法可能会更加注重模型的智能化和个性化。这意味着,蒸馏过程将更加关注于学生模型的个性化需求和特征提取,以及对于不同学习任务和场景的适应性。这可能会涉及到更加精细的目标函数设计、更加智能化的蒸馏策略以及更加灵活的模型结构。

目前有的蒸馏方法效果提升不大,知识蒸馏还有很大提升空间,因为网络中有大量的参数,而实际使用到的很少,所以可以在蒸馏方法上优化,将特征提取和知识传递做得更通用,或者更准确,甚至像大模型的预训练与微调一样,或者是自监督蒸馏,或者是自动地结合上剪枝量化,感知量化等等方法。

reference

1、crd https://arxiv.org/abs/1910.1069

2、crd code https://github.com/HobbitLong/RepDistiller

3、cls kd https://blog.csdn.net/akaweige/article/details/131520764

4、sr kd https://zhuanlan.zhihu.com/p/346422123

5、cls kd https://zhuanlan.zhihu.com/p/102038521

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/632185.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

记一次:mysql统计的CAST函数与json字段中的某个字段

前言&#xff1a;因为需求的问题&#xff0c;会遇到将某个json存入到一个字段中&#xff0c;但在统计的时候&#xff0c;又需要将这个json中的某个字段作为条件来统计&#xff0c;所以整理了一下cast函数和json中某个字段的条件判断 一、浅谈mysql的json 1.1 上例子 SELECTli…

管仲故乡是颍川,何分颍上或颍下

第一仲父管仲&#xff0c;故乡在哪里&#xff1f;依然像许多名人故里一样存在争议&#xff0c;但是这个争议却很不一般&#xff0c;引出了一个大话题。 管子是安徽颍上县人&#xff0c;《史记》记载: “管仲&#xff0c;颍上人也。”颍上县有管鲍祠&#xff0c;是安徽省重点文物…

【小项目】简单实现博客系统(一)(前后端结合)

一、实现逻辑 1&#xff09;实现博客列表页 让页面从服务器拿到博客数据&#xff08;数据库&#xff09; 2&#xff09;实现博客详情页 点击博客的时候&#xff0c;可以从服务器拿到博客的完整数据 3&#xff09;实现登录功能&#xff08;跟之前写的登录页面逻辑一致&…

羊大师解析,春季羊奶助力健康成长

羊大师解析&#xff0c;春季羊奶助力健康成长 随着春天的到来&#xff0c;万物复苏&#xff0c;大自然呈现出一派生机勃勃的景象。在这个充满希望的季节里&#xff0c;我们不仅要关注外界环境的变化&#xff0c;更要关注身体的健康和成长。羊大师发现羊奶作为一种营养丰富的食…

探索未来:苹果如何在 Apple Vision Pro 上进行创新

视觉体验的演进 在当今快节奏的数字化时代&#xff0c;技术创新不断塑造着我们与周围世界互动的方式。在这些进步中&#xff0c;苹果视觉专业技术凭借其创新精神脱颖而出&#xff0c;彻底改变了我们感知和参与视觉内容的方式。 无与伦比的显示技术 苹果视觉专业技术的核心是…

error Component name “Child4“ should always be multi-word

error Component name "Child4" should always be multi-word 这个错误是来自于ESLint的规则&#xff0c;它强制要求组件的名称必须是多单词的。这是因为单单一个单词可能与HTML的内建标签或者其他组件的名称产生冲突&#xff0c;从而导致意外的行为。 解决方法&am…

C#知识|(实例)大乐透双色球随机选号器项目实现(二)

哈喽,你好,我是雷工! 前面记录了UI设计,接下来记录类的设计,及相关代码。 04 类的设计 4.1、创建文件夹 为了使分类更加清晰,添加Models文件夹; 4.2、添加大乐透类 添加SuperLotto.cs类 该类的代码如下: namespace LeiGongNotes {/// <summary>/// 大乐透类…

el-upload 上传多个视频

<el-form-item label="视频" prop="video_url"><el-uploadclass="upload-demo"ref="uploadRef":multiple="true":on-change="handleChange":before-remove="beforeRemove":before-upload=&quo…

Typescript 哲学 - d.ts文件

The .d.ts syntax intentionally looks like ES Modules syntax. ES Modules was ratified by TC39 in 2015 as part of ES2015 (ES6), while it has been available via transpilers for a long time default export (esModuleInterop:true) / export 讲一个 d.ts export 的…

PCB供电夹子DIY

在刷小红书的时候&#xff0c;看到了清华卓晴教授【https://zhuoqing.blog.csdn.net/】DIY的供电夹子&#xff0c;感觉对于自己DIY PCB的时候供电会比较方便&#xff0c;物料也比较简单&#xff0c;打算复刻一下。 使用物料 1、小夹子&#xff0c;文具店都有卖&#xff0c;选…

Android手动下载Gradle的使用方法

导入新项目通常会自动下载gradle版本&#xff0c;这种方式很慢而且经常下载失败&#xff0c;按照提示手动下载的gradle应该放在那里&#xff0c;如何使用&#xff0c;本篇文章为你提供一种亲测有效的方法&#xff1a; 在Android Studio打开Setting搜索Gradle找到Gradle的存放目…

亚马逊测评真人号与自养号:如何选择?区别与作用全面解析

亚马逊卖家都希望能打造出热销产品的产品列表&#xff0c;因为评论对于列表的曝光和流量有着巨大的影响。然而&#xff0c;获取有效的产品评论并不容易&#xff0c;许多卖家为了提高自己产品在同类别中的竞争力&#xff0c;选择进行测评。测评可以快速提高产品的排名、权重和销…

Python自学之路--004:Python使用注意点(原始字符串‘r’\字符转换\‘wb’与‘w区别’\‘\‘与‘\\’区别)

目录 1、原始字符串‘r’ 2、字符转换问题 3、open与write函数’wb’与’w’区分 4、Python里面\与\\的区别 1、原始字符串‘r’ 以前的脚本通过Python2.7写的&#xff0c;通过Python3.12去编译发现不通用了&#xff0c;其实也是从一个初学者的角度去看待这些问题。 其中的\…

ROS2 - 创建项目( Ubuntu 22.04 )

本文简述&#xff1a;在 Ubuntu22.04 系统中使用 VS Code 来搭建一个ROS2开发项目。 ROS2 安装&#xff1a; 可以运行下面的命令&#xff0c;一键安装&#xff1a; wget http://fishros.com/install -O fishros && . fishros 1. 创建工作空间 本文假设配置完成 VS …

探索未来:Google I/O 2024 AI重磅发布一览

亲爱的读者们&#xff0c;大家期待已久的Google I/O开发者大会终于到来了&#xff01;今年的大会尤为特别&#xff0c;Google在发布会上大力强调了人工智能&#xff08;AI&#xff09;的重要性&#xff0c;可以说AI成为了绝对的主角。为了让大家快速了解今年的重点内容&#xf…

短视频创作者的9个免费实用的视频素材网站

在视频剪辑的过程中&#xff0c;找到高质量、无水印且可商用的视频素材是每个创作者的梦想。下面为大家推荐9个无水印素材网站&#xff0c;助你轻松获取所需的视频素材。 1. 蛙学府 - 提供丰富的高清视频素材&#xff0c;涵盖风景、人物、科技等类别。所有素材高清且可商用&…

2025秋招Java还是c++?

一、我的编程经 说说我的编程经历&#xff0c;在C和Java之间我经历了几个阶段&#xff1a; 大学期间&#xff0c;我浅尝辄止地学习了一段时间的Java&#xff0c;但后来放弃了&#xff0c;开始学习C/C。本科毕业后&#xff0c;我选择攻读硕士学位&#xff0c;并一直专注于C的学…

【错题集-编程题】空调遥控(二分 / 滑动窗口)

牛客对应题目链接&#xff1a;空调遥控 (nowcoder.com) 一、分析题目 1、滑动窗口 先排序&#xff0c;然后维护窗口内最大值与最小值的差在 2 * p 之间&#xff08;max - min&#xff09;。 2、二分查找 先排序&#xff0c;然后枚举所有的温度&#xff0c;⼆分出符合要求的…

C语言详解:数组指针

数组指针是指针 int* p[10] 这是指针数组的写法 &#xff0c;因为【】的优先级比*高&#xff0c; 所以为了解决优先级问题&#xff0c;加&#xff08;&#xff09; int(* p)[10]&arr;//数组的地址要存起来 说明p是指针&#xff08;首先与*结合&#xff09;&#xff0c…