1. Introduction to CWD Feature Distillation
Most KD methods align the normalized feature maps of the student and teacher networks, minimizing the difference between the activation values on those feature maps.
Channel-wise knowledge distillation (CWD) instead normalizes each channel's feature map to obtain a soft probability map. By simply minimizing the KL divergence between the channel-wise probability maps of the two networks, the distillation process focuses on the most salient regions within each channel, which is valuable for dense prediction tasks.
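Written out, the loss this article implements is the following (a sketch using the code's notation: N is the batch size, C the channel count, H*W the spatial size, T the temperature, and y^S, y^T the student and teacher feature maps; the division by N*C mirrors the code below rather than any particular normalization in the paper):

$$
\varphi(y_c)_i = \frac{\exp\left(y_{c,i}/T\right)}{\sum_{j=1}^{H \cdot W}\exp\left(y_{c,j}/T\right)},
\qquad
\mathcal{L}_{\mathrm{CWD}} = \frac{T^2}{N \cdot C}\sum_{n=1}^{N}\sum_{c=1}^{C}\sum_{i=1}^{H \cdot W}\varphi\left(y^{T}_{c,i}\right)\log\frac{\varphi\left(y^{T}_{c,i}\right)}{\varphi\left(y^{S}_{c,i}\right)}
$$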
Reference paper: https://arxiv.org/pdf/2011.13256
2. Implementation Workflow of CWD Feature Distillation
(1) Generate soft probability maps with softmax
class ChannelNorm(nn.Module):
    def __init__(self):
        super(ChannelNorm, self).__init__()

    def forward(self, featmap):
        # flatten each channel's H x W map and softmax over the spatial positions
        n, c, h, w = featmap.shape
        featmap = featmap.reshape((n, c, -1))
        featmap = featmap.softmax(dim=-1)
        return featmap

self.normalize = ChannelNorm()

# 1. Generate soft probability maps with softmax
norm_s = self.normalize(s_pred / self.temperature)
norm_t = self.normalize(t_pred.detach() / self.temperature)
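As a quick sanity check (the shapes here are illustrative), each channel of the normalized output is a probability distribution over its H*W spatial positions:

import torch

feat = torch.randn(2, 4, 8, 8)   # illustrative (N, C, H, W)
norm = ChannelNorm()(feat)       # -> shape (2, 4, 64)
print(norm.sum(dim=-1))          # every entry is ~1.0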
(2) Take the log of the student probability map. nn.KLDivLoss expects its input (the student side) as log-probabilities, so only the student map is logged.

# 2. Take the log of the student probability map
norm_s = norm_s.log()
(3) Compute the KL divergence between the student and teacher probability maps

# 3. Compute the KL divergence between the student and teacher probability maps
loss = nn.KLDivLoss(reduction='sum')(norm_s, norm_t)
loss /= n * c
return loss * (self.temperature ** 2)
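For reference, nn.KLDivLoss(reduction='sum') with a log-probability input and a probability target matches the textbook KL divergence; a minimal check with hypothetical tensors:

import torch
import torch.nn as nn

p_t = torch.randn(2, 3, 16).softmax(dim=-1)          # teacher: probabilities
log_p_s = torch.randn(2, 3, 16).log_softmax(dim=-1)  # student: log-probabilities

kl_builtin = nn.KLDivLoss(reduction='sum')(log_p_s, p_t)
kl_manual = (p_t * (p_t.log() - log_p_s)).sum()      # sum of p_t * log(p_t / p_s)
print(torch.allclose(kl_builtin, kl_manual))         # True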
3. Complete CWD Feature Distillation Implementation
import torch.nn as nn


class ChannelNorm(nn.Module):
    def __init__(self):
        super(ChannelNorm, self).__init__()

    def forward(self, featmap):
        # flatten each channel's H x W map and softmax over the spatial positions
        n, c, h, w = featmap.shape
        featmap = featmap.reshape((n, c, -1))
        featmap = featmap.softmax(dim=-1)
        return featmap


class CriterionCWD(nn.Module):
    def __init__(self, temperature=1.0):
        super(CriterionCWD, self).__init__()
        # define the normalize function
        self.normalize = ChannelNorm()
        self.temperature = temperature

    def forward(self, s_pred, t_pred):
        n, c, h, w = s_pred.shape
        # 1. Generate soft probability maps with softmax
        norm_s = self.normalize(s_pred / self.temperature)
        norm_t = self.normalize(t_pred.detach() / self.temperature)
        # 2. Take the log of the student probability map
        norm_s = norm_s.log()
        # 3. Compute the KL divergence between the student and teacher probability maps
        loss = nn.KLDivLoss(reduction='sum')(norm_s, norm_t)
        loss /= n * c
        return loss * (self.temperature ** 2)
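A minimal usage sketch (the shapes, temperature, and loss weight are illustrative, not values from the paper). In practice the CWD term is scaled by a weight and added to the task loss:

import torch

criterion = CriterionCWD(temperature=4.0)                # illustrative temperature
s_feat = torch.randn(2, 19, 64, 64, requires_grad=True)  # student feature map
t_feat = torch.randn(2, 19, 64, 64)                      # teacher feature map

cwd_loss = criterion(s_feat, t_feat)
total_loss = 3.0 * cwd_loss   # hypothetical loss weight; add to the task loss
total_loss.backward()         # gradients reach the student only (teacher detached)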