【DeepSeek背后的技术】系列二:大模型知识蒸馏(Knowledge Distillation)

目录

  • 1 引言
  • 2 操作步骤和公式说明
    • 2.1 准备教师模型(Teacher Model)和学生模型(Student Model)
    • 2.2 生成软标签(Soft Labels)
    • 2.3 定义蒸馏损失函数
    • 2.4 训练学生模型
    • 2.5 调整超参数
    • 2.6 评估与部署
  • 3 其他知识蒸馏技术
  • 4 实践
  • 参考文献

1 引言

近年来,随着Transformer、MOE架构的提出,使得深度学习模型轻松突破上万亿规模参数,从而导致模型变得越来越大,因此,我们需要一些大模型压缩技术来降低模型部署的成本,并提升模型的推理性能。而大模型压缩主要分为如下几类::剪枝(Pruning)、知识蒸馏(Knowledge Distillation)、量化(Quantization)、低秩分解(Low-Rank Factorization)。

模型压缩方法分类:
模型压缩

技术概述图:
技术介绍

大规模语言模型(LLM)近年来在自然语言处理领域取得了巨大进步,使得人类对话和文本生成成为可能。然而,开源LLM模型由于参数规模较小,性能难以达到商业LLM的水平。知识蒸馏技术可以解决这一问题,它通过利用商业LLM的高性能,将其知识“蒸馏”(Knowledge Distillation,知识蒸馏,简称KD)到更小的开源模型中,从而实现高性能和低成本。

KD

模型蒸馏(Model Distillation)最初由Hinton等人于2015年在论文《Distilling the Knowledge in a Neural Network》提出,其核心思想是通过知识迁移的方式,将一个复杂的大模型(教师模型)的知识传授给一个相对简单的小模型(学生模型),简单概括就是利用教师模型的预测概率分布作为软标签对学生模型进行训练,从而在保持较高预测性能的同时,极大地降低了模型的复杂性和计算资源需求,实现模型的轻量化和高效化。

下面是模型蒸馏的要点:

  • 首先需要训练一个大的模型,这个大模型也称为 teacher 模型。
  • 利用 teacher 模型输出的概率分布训练小模型,小模型称为 student 模型。
  • 训练 student 模型时,包含两种 label,soft label 对应了 teacher 模型输出的概率分布,而 hard label 是原来的 one-hot label。
  • 模型蒸馏训练的小模型会学习到大模型的表现以及泛化能力。

2 操作步骤和公式说明

2.1 准备教师模型(Teacher Model)和学生模型(Student Model)

  • 教师模型:已经训练好的高性能大型模型(如BERT、GPT等)。
  • 学生模型:结构更简单的小型模型(如TinyBERT、DistilBERT等),参数量远小于教师模型。

公式说明
假设教师模型的输出概率分布为 pt,学生模型的输出概率分布为 ps

2.2 生成软标签(Soft Labels)

  • 操作
    对训练数据中的每个样本 x,用教师模型计算其输出概率分布(软标签)。
    通过引入温度参数T 平滑概率分布,使类别间的关系更明显。

公式说明
教师模型的软标签计算:

软标签

其中:

  • zt 是教师模型的原始输出(logits),
  • T是温度参数(通常 T > 1,如 T=3)。

高温T使概率分布更平滑,学生模型能学习到类别间的隐含关系(例如“猫”和“狗”的相似性),硬标签(One-hot编码)仅包含0/1信息,而软标签包含更多知识。

2.3 定义蒸馏损失函数

  • 操作
    设计总损失函数,结合蒸馏损失(模仿教师模型)和任务损失(拟合真实标签)。

公式说明

  1. 蒸馏损失(KL散度):

公式2

其中
公式3

zs是学生模型的logits,乘以 T2 是为了平衡温度缩放对梯度的影响。

  1. 任务损失(交叉熵):

公式4

其中
公式5

是学生模型的原始概率分布。

  1. 总损失
    公式6

其中 α 是蒸馏损失的权重,取值范围通常是 [0.5, 0.9]。

蒸馏损失强制学生模仿教师的概率分布,任务损失确保学生模型不偏离真实标签,温度T和权重 α需调参以平衡两者。

2.4 训练学生模型

  • 操作
    使用教师生成的软标签和真实标签联合训练学生模型,优化总损失 Ltotal
    训练时需注意:
    • 温度参数T:训练阶段使用高温(如T=3),推理阶段恢复T=1。
    • 梯度更新:同时优化学生对教师分布和真实标签的拟合。

公式说明
反向传播时,总损失的梯度计算为:
公式7

其中 θs是学生模型的参数。

注意:训练时高温 ( T ) 增强知识迁移,推理时恢复标准概率分布。

2.5 调整超参数

  • 温度T
    • 较高的T(如3~10)增强软标签的平滑性,适合复杂任务。
    • 较低的T(如1~2)贴近原始分布,适合简单任务。
  • 损失权重 α
    • 若教师模型质量高,可增大 α(如0.7~0.9)。
    • 若真实标签噪声小,可增大任务损失权重。

示例调参策略

  • 两阶段训练:
    1. 第一阶段:高T和大α,专注于学习教师知识。
    2. 第二阶段:逐渐降低T和α,贴近真实任务。

2.6 评估与部署

  • 评估指标
    • 学生模型在测试集上的准确率、F1值等任务指标。
    • 计算学生模型与教师模型的输出相似性(如KL散度)。
  • 部署
    • 学生模型以T=1运行,直接输出原始概率分布 psraw

核心思想:
通过教师模型的软标签(富含类别间关系)和学生模型的任务损失(保留真实标签信息),蒸馏实现了知识的迁移。

  • 温度T 是核心超参数,控制知识迁移的“清晰度”。
  • 两阶段训练(先学教师,再微调)是常见优化策略。

3 其他知识蒸馏技术

知识蒸馏技术分类

KD分类

前向KL和逆向KL对比

4 实践

以下是一个简单的模型蒸馏代码示例,使用一个预训练的ResNet-18模型作为教师模型,并使用一个简单的CNN模型作为学生模型。同时,将使用交叉熵损失函数和L2正则化项来优化学生模型的性能表现。

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms

# 定义教师模型和学生模型
teacher_model = models.resnet18(pretrained=True)
student_model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(128 * 7 * 7, 10)
)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=0.01, momentum=0.9)
optimizer_student = optim.Adam(student_model.parameters(), lr=0.001)

# 训练数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
trainset = datasets.MNIST('../data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# 蒸馏过程
for epoch in range(10):
    running_loss_teacher = 0.0
    running_loss_student = 0.0
    
    for inputs, labels in trainloader:
        # 教师模型的前向传播
        outputs_teacher = teacher_model(inputs)
        loss_teacher = criterion(outputs_teacher, labels)
        running_loss_teacher += loss_teacher.item()
        
        # 学生模型的前向传播
        outputs_student = student_model(inputs)
        loss_student = criterion(outputs_student, labels) + 0.1 * torch.sum((outputs_teacher - outputs_student) ** 2)
        running_loss_student += loss_student.item()
        
        # 反向传播和参数更新
        optimizer_teacher.zero_grad()
        optimizer_student.zero_grad()
        loss_teacher.backward()
        optimizer_teacher.step()
        loss_student.backward()
        optimizer_student.step()
    
    print(f'Epoch {epoch+1}/10 \t Loss Teacher: {running_loss_teacher / len(trainloader)} \t Loss Student: {running_loss_student / len(trainloader)}')

在这个示例中:
(1)首先定义了教师模型和学生模型,并初始化了相应的损失函数和优化器;
(2)然后,加载了MNIST手写数字数据集,并对其进行了预处理;
(3)接下来,进入蒸馏过程:对于每个批次的数据,首先使用教师模型进行前向传播并计算损失函数值;然后使用学生模型进行前向传播并计算损失函数值(同时加入了L2正则化项以鼓励学生模型学习教师模型的输出);
(4)最后,对损失函数值进行反向传播和参数更新:打印了每个批次的损失函数值以及每个epoch的平均损失函数值。
通过多次迭代训练后,我们可以得到一个性能较好且轻量化的学生模型。

参考文献

  1. A Survey on Knowledge Distillation of Large Language Models
  2. A Survey on Model Compression for Large Language Models
  3. BERT模型蒸馏指南(知乎)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/964303.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【BUUCTF杂项题】后门查杀、webshell后门

前言:Webshell 本质上是一段可在 Web 服务器上执行的脚本代码,通常以文件形式存在于 Web 服务器的网站目录中。黑客通过利用 Web 应用程序的漏洞,如 SQL 注入、文件上传漏洞、命令执行漏洞等,将 Webshell 脚本上传到服务器&#x…

SPI(Serial Peripheral Interface)串行外围设备接口

SPI概述: SPI协议最初由Motorola公司(现为NXP Semiconductors的一部分)在20世纪80年代中期开发。最初是为了在其68000系列微控制器中实现高速、高效的串行通信。该协议旨在简化微控制器与外围设备之间的数据传输。 1980年代:SPI协…

深度学习 Pytorch 基础网络手动搭建与快速实现

为了方便后续练习的展开,我们尝试自己创建一个数据生成器,用于自主生成一些符合某些条件、具备某些特性的数据集。 导入相关的包 # 随机模块 import random# 绘图模块 import matplotlib as mpl import matplotlib.pyplot as plt# 导入numpy import nu…

10分钟快速上手DeepSeek!

DeepSeek 是一款基于命令行和配置文件的数据处理工具,支持多种数据格式(如 CSV、JSON、SQL 等)和多种数据源(如本地文件、数据库、API 等)。 它的核心功能包括: 数据导入与导出:支持从多种数据…

【现代深度学习技术】深度学习计算 | 延后初始化自定义层

【作者主页】Francek Chen 【专栏介绍】 ⌈ ⌈ ⌈PyTorch深度学习 ⌋ ⌋ ⌋ 深度学习 (DL, Deep Learning) 特指基于深层神经网络模型和方法的机器学习。它是在统计机器学习、人工神经网络等算法模型基础上,结合当代大数据和大算力的发展而发展出来的。深度学习最重…

Redis --- 秒杀优化方案(阻塞队列+基于Stream流的消息队列)

下面是我们的秒杀流程: 对于正常的秒杀处理,我们需要多次查询数据库,会给数据库造成相当大的压力,这个时候我们需要加入缓存,进而缓解数据库压力。 在上面的图示中,我们可以将一条流水线的任务拆成两条流水…

Rust HashMap :当储物袋遇上物品清单

开场白:哈希映射的魔法本质 在Rust的奇幻世界里,HashMap就像魔法师的储物袋: 键值对存储 → 每个物品都有专属咒语(键)和实体(值)快速查找 → 念咒瞬间召唤物品动态扩容 → 自动伸展的魔法空间…

LabVIEW的智能电源远程监控系统开发

在工业自动化与测试领域,电源设备的精准控制与远程管理是保障系统稳定运行的核心需求。传统电源管理依赖本地手动操作,存在响应滞后、参数调节效率低、无法实时监控等问题。通过集成工业物联网(IIoT)技术,实现电源设备…

C# Winform制作一个登录系统

using System; using System.Collections; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.Linq; using System.Text; using System.Threading.Tasks; using System.Windows.Forms;namespace 登录 {p…

尝试把clang-tidy集成到AWTK项目

前言 项目经过一段时间的耕耘终于进入了团队开发阶段,期间出现了很多问题,其中一个就是开会讨论团队的代码风格规范,目前项目代码风格比较混乱,有的模块是驼峰,有的模块是匈牙利,后面经过讨论,…

Docker技术相关学习三

一、Docker镜像仓库管理 1.docker仓库:用于存储和分发docker镜像的集中式存储库,开发者可以将自己创建的镜像推送到仓库中也可以从仓库中拉取所需要的镜像。 2.docker仓库: 公有仓库(docker hub):任何人都可…

挑战项目 --- 微服务编程测评系统(在线OJ系统)

一、前言 1.为什么要做项目 面试官要问项目,考察你到底是理论派还是实战派? 1.希望从你的项目中看到你的真实能力和对知识的灵活运用。 2.展示你在面对问题和需求时的思考方式及解决问题的能力。 3.面试官会就你项目提出一些问题,或扩展需求…

Python 与 PostgreSQL 集成:深入 psycopg2 的应用与实践

title: Python 与 PostgreSQL 集成:深入 psycopg2 的应用与实践 date: 2025/2/4 updated: 2025/2/4 author: cmdragon excerpt: PostgreSQL 作为开源关系型数据库的佼佼者,因其强大的功能与性能被广泛应用于各种项目中。而 Python 则因其简洁易用的语法、丰富的库和强大的…

计算机从何而来?计算技术将向何处发展?

计算机的前生:机械计算工具的演进 算盘是计算机的起点,它其实是一台“机械式半自动化运算器”。打算盘的“口诀”其实就是它的编程语言,算盘珠就是它的存储器。 第二阶段是可以做四则运算的加法器、乘法器。1642年,法国数学家帕斯…

【Blazor学习笔记】.NET Blazor学习笔记

我是大标题 我学习Blazor的顺序是基于Blazor University,然后实际内容不完全基于它,因为它的例子还是基于.NET Core 3.1做的,距离现在很遥远了。 截至本文撰写的时间,2025年,最新的.NET是.NET9了都,可能1…

MapReduce分区

目录 1. MapReduce分区1.1 哈希分区1.2 自定义分区 2. 成绩分组2.1 Map2.2 Partition2.3 Reduce 3. 代码和结果3.1 pom.xml中依赖配置3.2 工具类util3.3 GroupScores3.4 结果 参考 本文引用的Apache Hadoop源代码基于Apache许可证 2.0,详情请参阅 Apache许可证2.0。…

重生之我在异世界学编程之C语言:深入指针篇(上)

大家好,这里是小编的博客频道 小编的博客:就爱学编程 很高兴在CSDN这个大家庭与大家相识,希望能在这里与大家共同进步,共同收获更好的自己!!! 本文目录 引言正文(1)内置数…

deep generative model stanford lecture note3 --- latent variable

1 Introduction 自回归模型随着gpt的出现取得很大的成功,还是有很多工程上的问题并不是很适合使用自回归模型: 1)自回归需要的算力太大,满足不了实时性要求:例如在自动驾驶的轨迹预测任务中,如果要用纯自回…

STM32_SD卡的SDIO通信_DMA读写

本篇,将使用CubeMXKeil,创建一个SD卡的DMA读写工程。 目录 一、简述 二、CubeMX 配置 SDIO DMA 三、Keil 编辑代码 四、实验效果 实现效果,如下图: 一、简述 上篇已简单介绍了SD、SDIO,本篇不再啰嗦,…

互联网行业常用12个数据分析指标和八大模型

本文目录 前言 一、互联网线上业务数据分析的12个指标 1. 用户数据(4个) (1) 存量(DAU/MAU) (2) 新增用户 (3) 健康程度(留存率) (4) 渠道来源 2. 用户行为数据(4个) (1) 次数/频率…