【EMNLP2024】基于多轮课程学习的大语言模型蒸馏算法 TAPIR

近日,阿里云人工智能平台PAI与复旦大学王鹏教授团队合作,在自然语言处理顶级会议EMNLP 2024 上发表论文《Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning》。文章提出了一个名为 TAPIR 的知识蒸馏框架,TAPIR 通过多任务课程规划来蒸馏黑盒大语言模型的指令回答能力,在蒸馏和多轮迭代过程中,使用教师 LLM 做为裁判找出对于学生 LLM 来说难以回答的指令,进行难度重采样。同时,TAPIR 调整多任务配比,进行训练集中的任务多样性分布的重采样,并根据相应多任务特点自动优化教师模型的回答风格。

背景

大语言模型在回答开放领域通用任务的指令上取得了很大地进步。指令微调是微调预训练模型,使其从文本补全模型成为强大的对话模型的关键。尽管已有研究探索了使用强大的黑盒教师模型(如GPT-4, Qwen-max)来自动蒸馏和标注指令的方法,但这些研究往往忽视了微调训练集中任务的多样性分布,以及训练集中指令难度的差异,这可能导致学生LLMs知识能力的不平衡和解决复杂任务的能力的不足。为了解决这些挑战,文章提出了一个名为TAPIR的新框架,它通过多任务课程规划来蒸馏黑盒大语言模型的指令回答能力,从而提高学生小模型的指令遵循能力。

算法流程

文章中提出的TAPIR(Task-Aware Curriculum Planning for Instruction Refinement)框架的算法流程是一个多轮次的蒸馏方法,旨在提升学生大型语言模型(LLMs)遵循指令的能力。整个流程从初始化一个预训练的学生模型开始,然后通过以下步骤进行:

  1. 数据集难度过滤:使用一个开源的指令数据集(如Alpaca数据集)作为基础,通过计算模型拟合难度(MFD)分数来筛选出对学生模型来说较难的指令对,过滤得到种子数据集。

  2. 多任务规划指令蒸馏:根据设定的任务类型配比,利用一个教师模型(如ChatGPT)扩展种子数据集,生成更多具有相似难度水平的指令-响应对,并提升推理类任务的采样概率,以更好的缓解能力冲突问题。

  3. 多任务回答风格增强:对于某些任务,使用特定的提示重写响应,以便从教师模型获得更精细、更详细的回答,或者是特定任务格式的回答(如思维链,代码注释),这有助于学生模型更好地理解和学习复杂任务。

  4. 模型多轮优化迭代:通过多轮训练,利用裁判模型得到学生模型的回答质量反馈奖励分数,采样得到新的蒸馏种子数据集。逐步增加新一轮蒸馏种子数据集中挑战性指令的比例,实现从易到难的泛化。

TAPIR框架通过这种逐步提升任务难度和均衡任务类型的策略,使学生模型能够在较少的训练数据下超越更大的模型,显示出更好的性能,并在多个基准测试中取得了显著的性能提升。

难度重采样

难度重采样旨在解决训练集中任务难度分布不均的问题。难度重采样的目标是确保学生大型语言模型在蒸馏微调过程中能够接触到难度逐渐增大的任务,从而在困难的任务中泛化。我们通过计算模型拟合难度(Model Fitting Difficulty, MFD)分数来评估每个指令对学生模型的难度。MFD分数是通过比较学生模型生成的响应与教师模型生成的响应之间的质量差异来确定的。我们使用教师模型来作为裁判打分。

MFD(x_i) = f_J(x_i, \tilde{y}_i) - f_J(x_i, y_i)

根据MFD分数,筛选出对学生模型来说较难的指令对,即分差大于阈值 \delta 的指令。这些指令对将被纳入种子数据集。

D_S = {(x_i, y_i) \in D \mid MFD(x_i) > \delta }

任务重采样

在TAPIR框架中,任务重采样旨在解决训练集中任务分布不均的问题。其目的是提升训练集的多样性。在均衡的任务配比下为微调学生模型,以缓解微调过程中的能力冲突和灾难性遗忘问题。

首先,我们训练了一个指令任务分类模型(Deberta v3)识别和分类训练集中的任务类型,给每条指令打上显示的任务标签。然后通过任务标签重采样,使数据集中的任务分布更均衡,并且增强逻辑推理和编程任务的占比。基于我们的采样概率,教师模型扩展种子数据生成了新指令问答对,这些新数据与原有数据在难度上相近。

设指令对(x_i, y_i)的任务采样概率为 \Pr(\mathcal{T} | (x_i, y_i),则学生模型微调的自回归损失可以写作:

\mathcal{L}(\Phi) = - \sum_{(x_i, y_i) \in D_P} \Pr(\mathcal{T} | (x_i, y_i)) \cdot \log \Pr(\hat{y}_i|x_i; \Phi)

我们针对任务特点增强了教师模型标注的回答格式。如下所示:

多轮迭代优化

在多轮迭代的过程中,我们可以不断更新计算学生模型在新的微调数据集上的模型拟合难度来动态调整新一轮的蒸馏种子数据集难度配比。如下面的公式所示,当 \alpha_r 设置为 1 时,整个训练语料库仅由这些“困难”样本作为种子蒸馏。通过逐渐增加 \alpha_r , 系统地提高学习任务的复杂性。同时,为了保证指令的多样性,在每一轮中通过教师模型扩展难度重采样后的数据集,并将扩展后的数据集表示为D_P^{(r)}。第 r 轮的损失函数定义如下:

\mathcal{L}(\Phi, r) = - \alpha_r \sum_{(x_i, y_i) \in D_P^{(r)}} \mathbb{1}{(x_i, y_i)} \cdot \log \Pr(\hat{y}i | x_i; \Phi) - (1 - \alpha_r) \sum{(x_i, y_i) \in D \setminus D_S} \mathbb{1}{(x_i, y_i)} \cdot \log \Pr(\hat{y}_i | x_i; \Phi)

在每轮之后,更新规则为:

\alpha_{r+1} = \alpha_r + \Delta_\alpha

其中 \Delta_\alpha 是一个预定义的常数,用来逐渐增加学习任务的难度。

实验结果

实验结果表明,使用TAPIR框架训练的学生语言模型在较少的训练数据下,其性能超过了更大的指令调整模型和其他蒸馏基线方法。具体地说,TAPIR训练的模型在AlpacaEval 2.0基准测试中取得了7.80的胜率,优于Vicuna 13B和LLaMA2-Chat 13B,即便其培训数据和参数量仅为后者的一半。此外,在MT-Bench基准测试中,该模型在角色扮演、推理、数学、编程和人文学科等子任务中,表现优于LLaMA2 7B Chat 基线模型。为验证TAPIR框架在不同规模模型上的一致性,我们在Qwen1.5-Chat系列模型上做了实验,结果显示TAPIR能有效提升模型的指令遵循能力。

TAPIR-7B模型例子如下所示。在角色扮演任务中,语言模型扮演体育解说评论员。TAPIR-7B 生动地描述了比赛的最后胜利时刻并表现出色,而 Lion-7B 只是提供了如何评论的分析,没有完全执行任务,LLaMA2-Chat则误解了指令。

参考文献

  • Li, M., Chen, L., Chen, J., He, S., Huang, H., Gu, J., & Zhou, T. Reflection-Tuning: Data Recycling Improves LLM Instruction-Tuning. ArXiv, abs/2310.11716.

  • Song, C., Zhou, Z., Yan, J., Fei, Y., Lan, Z., & Zhang, Y. Dynamics of Instruction Tuning: Each Ability of Large Language Models Has Its Own Growth Pace. ArXiv, abs/2310.19651.

  • Jiang, Y., Chan, C., Chen, M., & Wang, W. Lion: Adversarial Distillation of Proprietary Large Language Models. EMNLP 2023.

论文信息

论文名字:Distilling Instruction-following Abilities of Large Language Models with Task-aware Curriculum Planning

论文作者:岳元浩、汪诚愚、黄俊、王鹏

论文pdf链接:https://arxiv.org/pdf/2405.13448

阿里云人工智能平台PAI长期招聘研究实习生。团队专注于深度学习算法研究与应用,重点聚焦大语言模型和多模态AIGC大模型的应用算法研究和应用。简历投递和咨询:chengyu.wcy@alibaba-inc.com

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/912049.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Web服务nginx基本实验

安装软件: 启动服务: 查看Nginx服务器的网络连接信息,监听的端口: 查看默认目录: 用Windows访问服务端192.168.234.111的nginx服务:(防火墙没有放行nginx服务,访问不了) …

github使用基础

要通过终端绑定GitHub账号并进行文件传输,你需要使用Git和SSH密钥来实现安全连接和操作。以下是一个基本流程: 设置GitHub和SSH 检查Git安装 通过终端输入以下命令查看是否安装Git: bash 复制代码 git --version配置Git用户名和邮箱 bash …

excel常用技能

1.基础技能 1.1 下拉框设置 a. 选中需要设置的列或单元格,数据 ---》 数据验证 b.验证条件 ---> 序列(多个值逗号隔开) 2.函数 2.1 统计函数-count a.count(区域,区域,......) 统计数量,只针…

Flipper Zero BadUSB反弹shell

Flipper Zero BadUSB反弹shell 前置知识点: Flipper Zero BadUSB 以及其他几个 BadUSB 设备使用用 DuckyScript 编写的有效负载。一种简单的脚本语言,用于执行导致键盘注入攻击的击键。 步骤 创建rev_shell_win.txt文件,并将其拖到badusb文件夹中. 相…

【GPTs】Email Responder Pro:高效生成专业回复邮件

博客主页: [小ᶻZ࿆] 本文专栏: AIGC | GPTs应用实例 文章目录 💯GPTs指令💯前言💯Email Responder Pro主要功能适用场景优点缺点 💯小结 💯GPTs指令 Email Craft is a specialized assistant for cra…

检测敏感词功能

今天策划给我一个任务 —— 检测昵称中是否含有敏感词功能,然后丢给我两个压缩包,我解压一看: 有的txt文件是一行一个词: 有的txt文件是按逗号分隔开: 不管是什么格式的总之量非常多,把我这辈子脏话都囊括…

【SpringBoot】19 文件/图片下载(MySQL + Thymeleaf)

Git仓库 https://gitee.com/Lin_DH/system 介绍 从 MySQL 中,下载保存的 blob 格式的文件。 代码实现 第一步:配置文件 application.yml spring:jackson:date-format: yyyy-MM-dd HH:mm:sstime-zone: GMT8datasource:driver-class-name: com.mysql.…

Coppelia Sim (v-REP)仿真 机器人3D相机手眼标定与实时视觉追踪 (三)

使用标定好的结果进行跟踪标定板的位置 坐标转换的步骤为: 1.图像坐标点转到相机坐标系下的点 2.相机坐标系下的点转为夹爪坐标系下的点 3.夹爪坐标系下的点转为机械手极坐标系下的点 跟踪的方式 1.采用标定板的第一个坐标点作为跟踪点 3.机器人每次移动到该点位&a…

easyui +vue v-slot 注意事项

https://www.jeasyui.com/demo-vue/main/index.php?pluginDataGrid&themematerial-teal&dirltr&pitemCheckBox%20Selection&sortasc 接口说明 <template><div><h2>Checkbox Selection</h2><DataGrid :data"data" style&…

运动【跑步 03】安踏冠军3的10KM和15KM*2体验(对比必迈PURE LIGHT)

这里写目录标题 1. 前言2. 两双鞋2.1 必迈 PURE LIGHT2.2 安踏 冠军 3 3. 主观对比4. 问题4.1 必迈 PURE LIGHT4.2 冠军 3 5. 总结 1. 前言 我是程序员&#xff0c;并不是专业的运动员&#xff0c;对跑步鞋的研究也不深&#xff0c;至今也就买过两双相对比较专业的跑鞋&#x…

O-RAN Fronthual CU/Sync/Mgmt 平面和协议栈

O-RAN Fronthual CU/Sync/Mgmt 平面和协议栈 O-RAN Fronthual CU/Sync/Mgmt 平面和协议栈O-RAN前端O-RAN 前传平面C-Plane&#xff08;控制平面&#xff09;&#xff1a;控制平面消息定义数据传输、波束形成等所需的调度、协调。U-Plane&#xff08;用户平面&#xff09;&#…

【JavaEE进阶】导读

本节⽬标 了解什么是JavaEE 在JavaEE中, 我们学习什么, 如何学, 难点是什么 一、Java EE 发展历程 Java EE(Java Platform Enterprise Edition), Java 平台企业版. 是JavaSE的扩展, ⽤于解决企业级的开发需求, 所以也可以称之为是⼀组⽤于企业开发的Java技术标准. 所以, 学习…

Javascript事件循环流程分析

基础概念 事件循环&#xff08;Event Loop&#xff09;&#xff1a;事件循环是JavaScript运行时环境中的一个循环机制&#xff0c;它不断地检查调栈用和任务队列。当调用栈为空时&#xff0c;事件循环会首先检查微任务队列&#xff0c;并执行其中的所有任务。只有当微任务队列…

单元/集成测试解决方案

在项目开发的前期针对软件单元/模块功能开展单元/集成测试&#xff0c;可以尽早地发现软件Bug&#xff0c;避免将Bug带入系统测试阶段&#xff0c;有效地降低HIL测试的测试周期&#xff0c;也能有效降低开发成本。单元/集成测试旨在证明被测软件实现其单元/架构设计规范、证明被…

【C++】C++的单例模式、跟踪内存分配的简单方法

二十四、C的单例模式、跟踪内存分配的简单方法 1、C的单例模式 本小标题不是讨论C的语言特性&#xff0c;而是一种设计模式&#xff0c;用于确保一个类在任何情况下都只有一个实例&#xff0c;并提供一个全局访问点来获取这个实例。即C的单例模式。这种模式常用于资源管理&…

LangGPT结构化提示词编写实践

基础任务 如果直接询问大模型strawberry有几个r&#xff0c;大模型会给出错误的答案&#xff1a; 这里我们引入思维连Chain of Thought&#xff0c;我们让大模型遍历一遍单词&#xff0c;每次累加得到最终结果 之前怎么都做不对的题&#xff0c;让大模型一步一步思考&#xf…

【Python系列】使用 Poetry 进行 Python 项目管理

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

Linux内核USB2.0驱动框架分析--USB设备枚举过程

一 USB特点 1.1 USB协议版本介绍&#xff1a; USB1.0/1.1&#xff08;low/fullspeed&#xff09;&#xff1a;传输速率最大为12Mbps&#xff0c;是较早的USB协议版本。 USB2.0&#xff08;highspeed&#xff09;&#xff1a;传输速率最大为480Mbps&#xff0c;相比USB1.0/1.1…

解决ultralytics的YOLO模型训练中验证集Loss为NaN(或mAP为0)的问题

前言 在使用ultralytics库的YOLO模型时&#xff0c;比如YOLOv8进行目标检测模型训练&#xff0c;遇到一个非常奇怪的问题&#xff1a;训练过程中的验证损失&#xff08;loss&#xff09;出现了NaN&#xff0c;而验证的评价指标如mAP50却能正常计算&#xff08;有时mAP都也为0&…

微信支付现金红包,实现转账到零钱包功能

大家好&#xff0c;我是小悟。 上次说到微信商家转账到零钱要出新玩法&#xff0c;可能会对某些特定的业务产生影响&#xff0c;详细请阅读【微信商家转账到零钱新玩法&#xff0c;却是个不好接受的消息】。 微信支付还有个现金红包的产品&#xff0c;也可以实现转账到用户零…