文献阅读:LESS: Selecting Influential Data for Targeted Instruction Tuning

  • 文献阅读:LESS: Selecting Influential Data for Targeted Instruction Tuning
    • 1. 文章简介
    • 2. 方法介绍
      • 1. Overview
      • 2. 原理说明
        • 1. SGD上的定义
        • 2. Adam上的定义
      • 3. 具体实现
        • 1. Overview
        • 1. LoRA使用
        • 2. 数据选择
        • 3. LESS-T
    • 3. 实验考察 & 结论
      • 1. 实验设计
      • 2. 主要结果
      • 3. 细节讨论
        • 1. 计算复杂度分析
        • 2. warmup是否必要
        • 3. checkpoint的影响(N的影响)
        • 4. LoRA Dimension的影响
    • 4. 总结 & 思考
  • 文献链接:https://arxiv.org/abs/2402.04333
  • Github链接:https://github.com/princeton-nlp/LESS

1. 文章简介

这篇文章是陈丹琦大佬在今天二月给出的关于LLM Tuning的一篇新作。

这篇文章同样是一篇比较fundamental的基础研究工作,考察的是LLM训练,或者说任意模型训练时如何最优化的选择训练数据,从而在尽可能不损失模型性能的情况下,最优化模型训练的效率,使得模型收敛的又快又好。

相似的工作之前有主动学习相关的一系列工作,这里倒是有些区别,因为主动学习感觉还是对于未标注数据进行最优化的选取,但是这里的LESS方法感觉还是在已有的标注数据当中选取一个子集,使得模型获得足量且优秀的训练结果。

下面,我们就来看看文中给出的具体实现方法和对应的实验考察。

2. 方法介绍

1. Overview

首先,我们来看一下LESS的整体的原理说明和实现。

LESS的全程的话是Low-rank gradiEnt Similarity Search,其整体的思路的话其实还是比较直接的,就是通过数据在模型进行反向传播时产生的梯度大小来判断数据对于模型训练的影响程度,然后选择最有影响的这部分数据来进行模型finetune即可。

但是,这里会涉及到几个问题:

  1. 具体定义上的问题,即如何判断数据对于模型训练的影响程度大小;
  2. 对每一个数据都进行反向传播进行判断的话,基本也就等于跑完一个epoch了,这种效率的话就有点舍本逐末了,因此,需要考察一下如何对效率进行优化。

下面,我们就来看一下文中对于这两个问题的处理。

2. 原理说明

首先,我们来看一下文中是如何来定义一条数据对于模型训练影响的大小的。

1. SGD上的定义

首先,文中在SGD上面进行了一下简单的考察,显然,对于一步训练前后,我们可以将其在测试集上的变化一阶泰勒展开得到:

l ( z ′ ; θ t + 1 ) = l ( z ′ ; θ t ) + ⟨ ∇ l ( z ′ ; θ t ) , θ t + 1 − θ t ⟩ l(z'; \theta^{t+1}) = l(z'; \theta^{t}) + \langle \nabla l(z'; \theta^{t}), \theta^{t+1} - \theta^{t} \rangle l(z;θt+1)=l(z;θt)+l(z;θt),θt+1θt

其中,参数的改变量则有可以通过训练过程中的一轮参数迭代过程来表达,即:

θ t + 1 − θ t = − η t ∇ l ( z ; θ t ) \theta^{t+1} - \theta^{t} = -\eta_t \nabla l(z; \theta^t) θt+1θt=ηtl(z;θt)

此时,我们即可得到测试集上一轮迭代测试集上loss的变化大小可以写为:

l ( z ′ ; θ t + 1 ) − l ( z ′ ; θ t ) = − η t ⋅ ⟨ ∇ l ( z ′ ; θ t ) , ∇ l ( z ; θ t ) ⟩ l(z'; \theta^{t+1}) - l(z'; \theta^{t}) = -\eta_t \cdot \langle \nabla l(z'; \theta^{t}), \nabla l(z; \theta^t) \rangle l(z;θt+1)l(z;θt)=ηtl(z;θt),l(z;θt)⟩

因此,我们就可以定义某一条训练数据对于某一条测试数据在N轮训练当中的影响程度如下:

I n f S G D ( z , z ′ ) = ∑ t = 0 N − 1 l ( z ′ ; θ t ) − l ( z ′ ; θ t + 1 ) = ∑ t = 0 N − 1 η t ⋅ ⟨ ∇ l ( z ′ ; θ t ) , ∇ l ( z ; θ t ) ⟩ \begin{aligned} \mathop{Inf}_{SGD} (z, z') &= \sum\limits_{t=0}^{N-1} l(z'; \theta^{t}) - l(z'; \theta^{t+1}) \\ &= \sum\limits_{t=0}^{N-1} \eta_t \cdot \langle \nabla l(z'; \theta^{t}), \nabla l(z; \theta^t) \rangle \end{aligned} InfSGD(z,z)=t=0N1l(z;θt)l(z;θt+1)=t=0N1ηtl(z;θt),l(z;θt)⟩

2. Adam上的定义

但是,在我们当前的训练过程中,我们更常使用的优化器并不是SGD而是Adam,因此,文中对Adam优化器的情况进行了一下调整。

文中首先回顾了一下Adam优化器的计算:

θ t + 1 − θ t = − η t Γ ( z ; θ t ) Γ ( z ; θ t ) = m t + 1 v t + 1 + ϵ m t + 1 = β 1 m t + ( 1 − β 1 ) ∇ l ( z ; θ t ) 1 − β 1 t v t + 1 = β 2 v t + ( 1 − β 2 ) ∇ l ( z ; θ t ) 2 1 − β 2 t \theta^{t+1} - \theta^{t} = -\eta_t \Gamma (z; \theta^t) \\ \Gamma (z; \theta^t) = \frac{m^{t+1}}{\sqrt{v^{t+1} + \epsilon}} \\ m^{t+1} = \frac{\beta_1 m^t + (1-\beta_1) \nabla l(z; \theta^t)}{1-\beta_{1}^{t}} \\ v^{t+1} = \frac{\beta_2 v^t + (1-\beta_2) \nabla l(z; \theta^t)^2}{1-\beta_{2}^{t}} θt+1θt=ηtΓ(z;θt)Γ(z;θt)=vt+1+ϵ mt+1mt+1=1β1tβ1mt+(1β1)l(z;θt)vt+1=1β2tβ2vt+(1β2)l(z;θt)2

因此,我们可以很直接地将influence的定义迁移至Adam优化器上,得到:

I n f A d a m ( z , z ′ ) = ∑ t = 0 N − 1 l ( z ′ ; θ t ) − l ( z ′ ; θ t + 1 ) = ∑ t = 0 N − 1 η t ⋅ ⟨ ∇ l ( z ′ ; θ t ) , Γ ( z ; θ t ) ⟩ \mathop{Inf}_{Adam} (z, z') = \sum\limits_{t=0}^{N-1} l(z'; \theta^{t}) - l(z'; \theta^{t+1}) = \sum\limits_{t=0}^{N-1} \eta_t \cdot \langle \nabla l(z'; \theta^{t}), \Gamma (z; \theta^t) \rangle InfAdam(z,z)=t=0N1l(z;θt)l(z;θt+1)=t=0N1ηtl(z;θt),Γ(z;θt)⟩

不过实际发现模型的参数梯度与文本长度强相关:

在这里插入图片描述

这就导致直接迁移上述定义公式会使得数据选择明显趋于短文本,因此文中对其进行了一下修正,将其加入了一下归一化因子,最终得到定义式如下:

I n f A d a m ( z , z ′ ) = ∑ t = 0 N − 1 η t ⋅ ⟨ ∇ l ( z ′ ; θ t ) , Γ ( z ; θ t ) ⟩ ∥ ∇ l ( z ′ ; θ t ) ∥ ⋅ ∥ Γ ( z ; θ t ) ∥ \mathop{Inf}_{Adam} (z, z') = \sum\limits_{t=0}^{N-1} \eta_t \cdot \frac{\langle \nabla l(z'; \theta^{t}), \Gamma (z; \theta^t) \rangle}{\lVert \nabla l(z'; \theta^{t}) \rVert \cdot \lVert \Gamma (z; \theta^t) \rVert} InfAdam(z,z)=t=0N1ηtl(z;θt)∥Γ(z;θt)∥l(z;θt),Γ(z;θt)⟩

3. 具体实现

1. Overview

有了上述影响程度的定义之后,文中就可以根据上述influence的大小进行数据选择策略了,具体来说的话,就是:

  • 在训练集上进行少量的tuning作为warmup,然后在验证集上计算所有训练数据当中的influence,最后挑选出影响因子最大的数据进行模型训练。

但是,如果直接使用LLM进行warmup然后进行上述定义下的influence计算时,可以想见其计算量必然极其巨大,和我们最终优化训练效率的目的显然是南辕北辙的,因此,我们必须要优化一下这里的计算效率,具体来说的话,文中就是通过引入LoRA的方法减少总的参数量,然后进行数据的选择。

因此,总的pipeline示意图如下:

在这里插入图片描述

下面,我们就来看看LoRA训练和数据选择的具体细节。

1. LoRA使用

首先的话,文中使用了LoRA来进行模型的finetune,这是因为模型本身的参数量太大了,常规的像是Llama这些都至少有着6B左右的参数量,更别说那些更大的模型了,使用全部参数finetune然后反向推导influence显然成本太大了,典型的舍本逐末,因此,这里使用LoRA进行模型的finetune,可以大幅减少模型的计算量。

2. 数据选择

然后,关于数据选择的部分,文中就是使用上述原理说明部分的内容进行数据选择,具体来说的话就是先使用少量训练数据进行一下warmup,然后使用少部分测试集来计算每一条数据对于模型的影响大小,然后选择出影响最大的几条数据即可。

对应的公式如下:

I n f A d a m ( z , D v a l ) = ∑ t = 0 N − 1 η t ⋅ ⟨ ∇ l ( D v a l ; θ t ) , Γ ( z ; θ t ) ⟩ ∥ ∇ l ( D v a l ; θ t ) ∥ ⋅ ∥ Γ ( z ; θ t ) ∥ \mathop{Inf}_{Adam} (z, D_{val}) = \sum\limits_{t=0}^{N-1} \eta_t \cdot \frac{\langle \nabla l(D_{val}; \theta^{t}), \Gamma (z; \theta^t) \rangle}{\lVert \nabla l(D_{val}; \theta^{t}) \rVert \cdot \lVert \Gamma (z; \theta^t) \rVert} InfAdam(z,Dval)=t=0N1ηtl(Dval;θt)∥Γ(z;θt)∥l(Dval;θt),Γ(z;θt)⟩

3. LESS-T

在上述基础上,文中进一步提出,这里的模型可以具有泛化性,也就是说,使用基于模型A选择出来的数据集 D D D同样有利于另一个模型 B B B的训练。

因此,文中给出了一个LESS-T的数据选择方法,固定使用Llama2 7B模型来进行数据选择,然后在其他模型上进行finetune。

3. 实验考察 & 结论

下面,我们来看一下文中给出的具体实验考察以及对应的结果如下。

1. 实验设计

首先,关于文中的实验设计的话,文中主要是使用MMLU, TYDIQA, BBH三个数据集的测试集,其具体信息如下:

在这里插入图片描述

而实验模型的话主要是Llama2 7B, 13B以及Mistral 7B三个模型,数据选择比例的话则是以5%作为标准。

然后,作为对照组的话,主要是以下几种方法:

  • 随机选择
  • BM25
  • DSIR
  • RDS

其中,BM25和DSIR都是基于词频的选择方法,RDS全称为Representationbased
Data Selection,这部分我倒是完全不知道,有兴趣的读者可以去追一下这个文献看看这具体是个啥。

2. 主要结果

下面,我们来看一下文中给出的具体实验结果。

最直接的一个实验结果显然就是在几个数据集下LESS选择的5%的数据和全量数据训练以及随机选择5%数据的效果差异:

在这里插入图片描述

然后,在Llama2 7B模型上,文中进一步考察了不同的数据选择策略下模型finetune效果的差异:

在这里插入图片描述

基于上述两张表格,文中总结了以下几个主要的实验结论:

  1. 从表2可以看到,LESS在不同模型上都有效,方法具有稳定性
  2. 同样从表2可以看到,使用LESS方法有时选择5%的优质数据的训练效果甚至可以超过全量数据finetune的效果;
  3. 同样从表2可以看到,LESS-T方法在Llama2 13B和Mistral 7B模型上同样有效,说明了LESS数据选择对于模型的泛化性
  4. 从表3可以看到,相较于其他对照组中的方法,LESS 是唯一一个在各个任务下均有效的数据选择策略。

3. 细节讨论

此外,文中还更进一步地做了一些关于LESS的细节讨论。

1. 计算复杂度分析

首先,文中分析了一下LESS方法的整体复杂度,得到结果如下:

在这里插入图片描述

可以看到,LESS的计算量其实还是非常大的。

2. warmup是否必要

然后,文中考察了一下对LoRA的warmup是否必要,得到结果如下:

在这里插入图片描述

可以看到,warmup还是非常必要的。

3. checkpoint的影响(N的影响)

此外,关于文中使用多个checkpoint进行梯度的avg这一点,文中同样说明了一下这个操作的必要性:

在这里插入图片描述

4. LoRA Dimension的影响

最后,文中还考察了一下LoRA模型当中维度对于数据选择的影响:

在这里插入图片描述

可以看到,确实维度越大数据选择效果越好,但是小维度的下已经足以选择出很好的数据带来明显的效果提升了。

4. 总结 & 思考

综上就是陈丹琦大佬提出的LESS方法了,可以看到,在数据选择方面LESS确实给出了非常强大的效果,可以在5%左右的数据上就获得非常优秀的效果,而且数据的选择对模型还有任务都有着足够的泛化性。

但是比较困惑我的一点在于LESS的数据选择计算开销实在是非常大,而且是对于已有的训练数据进行二次提纯选择,而现实中我们的问题其实个人感觉还是更多的像是主动学习那样缺少训练数据因此要对未标注数据进行选择标注,这个问题LESS似乎是无法处理的,当然LESS也不是研究的这个问题就是了。

因此个人感觉LESS的定位就多少有些尴尬了,已有足量训练数据的情况下如此大开销地精炼数据是否真的有足够的价值,多少还是有些怀疑。

不过考虑到之前像是Meta在LIMA这篇工作中提到的那样:只需要少量的优质数据,模型就足以finetune获得非常优秀的效果了。

因此,数据质量的价值可能远高于单纯的数据量的价值,数据精炼的意义可能真的会比想象的更大吧,谁知道呢。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/544035.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Jmeter三个常用组件

Jmeter三个常用组件 一、线程组二、 HTTP请求三、查看结果树 线程组:jmeter是基于线程来运行的,线程组主要用来管理线程的数量,线程的执行策略。 HTTP请求:HTTP请求是jmeter接口测试的核心部分,主要使用HTTP取样器来发…

PyQt5

Qt是基于C实现的GUI,而PyQt就是用python调用Qt. PyQt中有很多的功能模块,开发最常用的模块功能主要有3个 1) QtCore:包含核心的非GHI的功能,主要和时间,文件与文件夹,各种数据,流,URLs,进程与线程一起使用 2) QtGUi:包含窗口系统,事件处理,2D图像,基本绘画,字体和文字类 3)…

《Kubernetes部署篇:基于Kylin V10+ARM架构CPU使用containerd部署K8S 1.26.15集群(一主多从)》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:企业级K8s集群运维实战 1、在当前实验环境中安装K8S1.25.14版本,出现了一个问题,就是在pod中访问百度网站,大…

【opencv】示例-stiching_detailed.cpp 使用OpenCV进行图像拼接的整体流程

#include <iostream> // 引入输入输出流库 #include <fstream> // 引入文件流库&#xff0c;用于文件输入输出 #include <string> // 引入字符串库 #include "opencv2/opencv_modules.hpp" // 引入OpenCV模块 #include <opencv2/core/utility.h…

【微信小程序——开发DAY4(黑马程序员课程)】

学习目标 自定义小程序组件自定义组件&#xff08;1.&#xff09;创建自定义组件文件夹&#xff08;2.&#xff09;引用自定义组件&#xff08;3.&#xff09;组件和页面的区别&#xff08;4.&#xff09;自定义组件的隔离性——自定义组件不影响小程序的样式——自定义组件也只…

用通俗易懂的方式讲解:大模型高级 RAG 检索策略之递归检索

节前&#xff0c;我们组织了一场算法岗技术&面试讨论会&#xff0c;邀请了一些互联网大厂朋友、参加社招和校招面试的同学&#xff0c;针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 基于大模…

LinkSage:基于 GNN 的 Pinterest理解

目录 一、背景二、动机和介绍三、技术设计3.1 数据3.2 图3.3 特征3.4 型 四、主要创新4.1 多维表示4.2 XSage的兼容性4.3 增量服务 五、离线结果5.1 召回5.2 分数分布5.3 峰度 六、在线结果6.1 面向用户的表面6.2 Ads 七、总结 LinkSage&#xff1a;基于图神经网络的Pinterest…

微服务之LoadBalancer负载均衡服务调用

一、概述 1.1什么是负载均衡 LB&#xff0c;既负载均衡&#xff08;Load Balancer&#xff09;,是高并发、高可用系统必不可少的关键组件&#xff0c;其目标是尽力将网络流量平均分发到多个服务器上&#xff0c;以提高系统整体的响应速度和可用性。 负载均衡的主要作用 高并发…

IDEA阅读Java源码 SimpleDateFormat

IDEA阅读Java源码 SimpleDateFormat 文章目录 IDEA阅读Java源码 SimpleDateFormat一、阅读的代码二、IDEA操作2.1 标记断点2.2 启用Debug2.3 按键区分2.4 强制进入方法2.5 进入指定方法2.6 多方法进入指定方法2.7 进入正确的方法2.8 真正的方法体实现 三、SimpleDateFormat源码…

网络篇08 | 运输层 tcp

网络篇08 | 运输层 tcp 01 简介1&#xff09;运输层的作用2&#xff09;与应用层的关系3&#xff09;两个协议的应用场景4&#xff09;传输的数据单位 02 功能特性1&#xff09;面向连接2&#xff09;停止等待协议3&#xff09;流水线传输协议4&#xff09;滑动窗口机制5&#…

011、Python+fastapi,第一个后台管理项目走向第11步:建立python+fastapi项目,简单测试一下

一、说明 本文章就是记录自己的学习过程&#xff0c;如果有用您可以参考&#xff0c;没用你就略过&#xff0c;没有好与不好之分&#xff0c;今天主要是参考了gitee上的一些项目&#xff0c;一步一步的往后i建立 对于学习来说&#xff0c;如果您有java c等经验&#xff0c;py…

Redis的哨兵机制

引入&#xff1a; 主从复制最大的问题还是在主节点上&#xff0c;主节点挂了&#xff0c;从节点就迷茫了&#xff0c;虽然能够提供读操作&#xff0c;但是从节点不能自动生成主节点&#xff0c;不能替换原有主节点对应的角色&#xff1b;此时&#xff0c;就需要程序员/运维手工…

绿联HDMI延长器40265使用AG7120芯片放大器方案

HDMI延长器和放大器 延长器&#xff1a;主要用于HDMI线的延长&#xff0c;有HDMI对接头方式延长&#xff0c;或HDMI公头加HDMI母头的HDMI线进行延长&#xff0c;或通过网线方式延长&#xff0c;早期为双网线&#xff0c;目前已发展为单网线&#xff0c;需要注意的是&#xff0…

L45 【哈工大_操作系统】操作系统接口 系统调用实现

L4 操作系统接口 本节比较简单&#xff0c;故与第五节课程笔记一起发布。本节主要是研究 上层应用 是怎么穿过边界进入 操作系统。 接口&#xff1a;操作系统提供的重要函数/指令( system call )&#xff0c;用来连接硬件&#xff0c;所以OS接口就是系统调用POSIX&#xff08;…

Res2Net网络

Res2Net网络 摘要Abstract1. Res2Net网络1.1 文献摘要1.2 背景1.3 创新点1.4 网络结构1.5 实验1.5.1 在ImageNet数据集上进行实验1.5.2 在CIFAR数据集上进行实验 2. Res2Net代码实现3. 总结 摘要 Res2Net是一种神经网络架构&#xff0c;旨在改善类似ResNet的网络在计算机视觉任…

vscode开发 vue3+ts 的 uni-app 微信小程序项目

创建uni-app项目&#xff1a; # 创建用ts开发的uni-app npx degit dcloudio/uni-preset-vue#vite-ts 项目名称 # 创建用js开发的uni-app npx degit dcloudio/uni-preset-vue#vite 项目名称VS Code 配置 为什么选择 VS Code &#xff1f; HbuilderX 对 TS 类型支持暂不完善VS…

unity记一下如何播放动画

我使用的版本是2022.3.14fc 展开你的模型树&#xff0c;是会出现这个三角形的东西的 然后在资源面板创建一个animation controller 进去之后&#xff0c;把三角形拖进去&#xff0c;就会出现一个动画&#xff0c;然后点击他 在左侧给他创建这么个状态名字&#xff0c;类型…

AskManyAI:一个GPT、Claude、Gemini、Kimi等顶级AI的决斗场

一直以来很多人问我能不能有个稳定&#xff0c;不折腾的全球AI大模型测试网站&#xff0c;既能够保证真实靠谱&#xff0c;又能够保证稳定、快速&#xff0c;不要老动不动就挂了、出错或者漫长的响应。 直到笔者遇到了AskManyAI&#xff0c;直接就惊艳住了&#xff01; 话不多…

第17天:信息打点-语言框架开发组件FastJsonShiroLog4jSpringBoot等

第十七天 本课意义 1.CMS识别到后期漏洞利用和代码审计 2.开发框架识别到后期漏洞利用和代码审计 3.开发组件识别到后期漏洞利用和代码审计 一、CMS指纹识别-不出网程序识别 1.概念 CMS指纹识别一般能识别到的都是以PHP语言开发的网页为主&#xff0c;其他语言开发的网页识…

springboot-tomcat冲突

maven tomcat版本冲突异常 记录一个错误 在项目启动时报错&#xff1a; java.lang.AbstractMethodError: org.apache.tomcat.websocket.server.WsSessionListener.sessionCreated(Ljavax/servlet/http/HttpSessionEvent;)V 引入Session报错。 原因&#xff1a;Springboot默…