Reference link:
Knowledge Distillation Tutorial — PyTorch Tutorials 2.2.1+cu121 documentation
Method 1:
The distillation loss used here (a cosine embedding loss) only accepts two inputs of the same dimensionality, so the teacher's and student's representations must be brought to the same size before they reach the loss. We do this by applying an average pooling layer to the flattened output of the teacher's convolutional layers, reducing it to the same dimension as the student's.
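A minimal, self-contained sketch of this dimension matching (the 2048/1024 sizes follow the tutorial's CIFAR-10 models and are only illustrative here):

import torch
import torch.nn.functional as F

# Illustrative shapes: the teacher's flattened conv output (2048 features in
# the tutorial) is twice as wide as the student's (1024 features).
teacher_feat = torch.randn(4, 2048)
student_feat = torch.randn(4, 1024)

# avg_pool1d with kernel size 2 halves the last dimension, so the teacher's
# representation now matches the student's.
teacher_pooled = F.avg_pool1d(teacher_feat, 2)
print(teacher_pooled.shape)  # torch.Size([4, 1024])

# Both tensors can now go into a loss that expects equal dimensions, e.g.
# nn.CosineEmbeddingLoss with a target of ones ("these pairs should be similar").
loss = torch.nn.CosineEmbeddingLoss()(teacher_pooled, student_feat, torch.ones(4))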
Method 2:
Original teacher model:
Only the forward function is adjusted in the modified versions below.
class DeepNN(nn.Module):
    def forward(self, x):
        x = self.features(x)       # convolutional feature extractor
        x = torch.flatten(x, 1)    # flatten to (batch, feature_dim)
        x = self.classifier(x)     # fully connected head -> logits
        return x
Teacher model with pooling added:
class ModifiedDeepNNCosine(nn.Module):
    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        # pool the flattened conv output so its width matches the student's
        flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)
        return x, flattened_conv_output_after_pooling  # logits, hidden representation
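A quick shape check of the two outputs (a sketch that assumes the full ModifiedDeepNNCosine definition from the tutorial, where the constructor takes num_classes and the input is a CIFAR-10 sized batch; the sizes are illustrative):

import torch

teacher = ModifiedDeepNNCosine(num_classes=10)        # layers as defined in the tutorial
logits, hidden = teacher(torch.randn(2, 3, 32, 32))   # CIFAR-10 sized batch
print(logits.shape)   # torch.Size([2, 10])   - classifier output
print(hidden.shape)   # torch.Size([2, 1024]) - flattened conv output after pooling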
Original student model:
class LightNN(nn.Module):
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
Updated student model:
class ModifiedLightNNCosine(nn.Module):
    def forward(self, x):
        x = self.features(x)
        flattened_conv_output = torch.flatten(x, 1)
        x = self.classifier(flattened_conv_output)
        return x, flattened_conv_output  # logits, hidden representation (no pooling needed)
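How the two outputs are combined when training the student, as a self-contained sketch of the loss wiring only (random tensors stand in for real model outputs; the 0.25/0.75 weights mirror the tutorial's hidden_rep_loss_weight and ce_loss_weight but are tunable assumptions):

import torch
import torch.nn as nn

batch = 4
# Stand-ins for the outputs of the modified models above: (logits, hidden).
teacher_hidden = torch.randn(batch, 1024)                    # already pooled
student_logits = torch.randn(batch, 10, requires_grad=True)
student_hidden = torch.randn(batch, 1024, requires_grad=True)
labels = torch.randint(0, 10, (batch,))

ce_loss = nn.CrossEntropyLoss()
cosine_loss = nn.CosineEmbeddingLoss()

# The cosine loss pulls the student's hidden representation toward the teacher's;
# a target of ones marks every (student, teacher) pair as "should be similar".
hidden_loss = cosine_loss(student_hidden, teacher_hidden, torch.ones(batch))
label_loss = ce_loss(student_logits, labels)

# Weighted sum of the two objectives; in the tutorial the teacher's output is
# computed under torch.no_grad() and only the student is updated.
loss = 0.25 * hidden_loss + 0.75 * label_loss
loss.backward()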
Method 3:
Teacher accuracy: 75.60%
Student accuracy without teacher: 70.41%
Student accuracy with CE + KD: 70.19%
Student accuracy with CE + CosineLoss: 70.87%
Student accuracy with CE + RegressorMSE: 71.40%