论文分享|Arxiv2024‘麦吉尔大学|LLM2Vec—将LLM转换为文本编码器

LLM本身的表征直接用于Embedding,比如用于检索/聚类/STS等任务,效果其实不太好。因此才需要将Embedding模型和大模型区分开来。本文介绍一篇将LLM转换为Embedding模型的工作,代码全开源,值得好好学习。

论文题目:LLM2Vec: Large Language Models Are Secretly Powerful Text Encoders

来源:Arxiv2024/麦吉尔大学

方向:文本编码/LLM

开源地址:https://github.com/McGill-NLP/llm2vec

转换LLM为Text Encoder的三步骤

参考作者的Tutorial ,以LLaMA2FlashAttention为例介绍主要的三个步骤

第一步:实现双向注意力

img

对于llama这样的transformer模型来说,每一层都含有一个自注意力的子层,而这些自注意力在得到embedding的时候都会mask后面的词的attention,只保留前面的词语,即单向注意力。所以需要先将这个掩蔽关闭,每个词都可以和句子中所有词语进行注意力交互,从而实现双向注意力。如下图所示:

img

下面介绍下代码。首先需要修改llama_attention子层的实现。llama_attention有三种实现,在此我们需要将LlamaFlashAttention2中的is_causal设置为False,从而改为双向注意力。其实只修改了__init__函数

class ModifiedLlamaFlashAttention2(LlamaFlashAttention2):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False  # 原Trnasformer实现中是True

LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": ModifiedLlamaFlashAttention2,  # 原是`LlamaFlashAttention2'
    "sdpa": LlamaSdpaAttention,
}

接下来需要将每个Decoder层中的attention改为双向注意力attention实现。同样只需要修改__init__函数。这一段直接从transformers库中复制即可

class ModifiedLlamaDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        nn.Module.__init__(self) # Initially, super().__init__()
        self.hidden_size = config.hidden_size
        
        # 这一行将注意力类绑定为LLAMA_ATTENTION_CLASSES中自定义的类
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) 

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

最后需要将LlamaModel中的每一层都改为我们修改的decoder层,还是只需要修改__init__函数中。将每个layer中的改为修改的layer

class LlamaBiModel(LlamaModel):
    def __init__(self, config):
        LlamaPreTrainedModel.__init__(self, config) # Initially, super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [ModifiedLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]  # Initially, `LlamaDecoderLayer(config, layer_idx)`
        ) # 这一行中原是LlamaDecoderLayer(config, layer_idx)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.post_init()

至此,通过简单修改了三行代码,就完成了对llama双向注意力的实现

第二步:使用掩蔽下一词预测(masked next token prediction ,MNTP)任务训练

原LlamaForCausalLM类使用的是LlamaModel,需要先修改为上一步的LlamaBiModel

class BiLlamaForMNTP(LlamaForCausalLM):

    def __init__(self, config, attention_dropout=0.0):
        LlamaPreTrainedModel.__init__(self, config) # Initially, super().__init__(config)
        self.model = LlamaBiModel(config)  # 原是LlamaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

至此模型部分就修改完毕了,接下来实现损失函数,直接复用LlamaForCausalLM中forward函数里的下一词预测任务,即利用第i-1个token的表征的logits来预测第i个token位置mask的词,但注意由于此时模型已经是双向注意力,所以本质上这其实类似简化版本的掩蔽语言模型masked language modeling)任务。

# LlamaForCausalLM.forward()中的代码片段
loss = None
if labels is not None:
    # 拷贝表征 从而让小于n的tokens表征预测第n个token 
    # contiguous() 其实是强制拷贝一份数据防止和之前的logits对象共享存储空间
    shift_logits = logits[..., :-1, :].contiguous() # (batch_size,n-1,vocab_size) 即每句话从第1个token开始,第n-1个token结束的token表征 
    shift_labels = labels[..., 1:].contiguous() # (batch_size,n-1) 即每句话从第2个token开始,第n个token结束的token序列
    # 展开token表征
    loss_fct = CrossEntropyLoss() #交叉熵损失
    shift_logits = shift_logits.view(-1, self.config.vocab_size) # (batch_size * (n-1), vocab_size)
    shift_labels = shift_labels.view(-1) (batch_size * (n-1))
    # 允许模型并行
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)

训练时,仿造 examples/pytorch/language-modeling/run_mlm.py 脚本。 本文主要有以下修改:

  1. 将模型类别改成了自己实现的语言模型类而不是原脚本的AutoModelForMaskedLM类
  2. 使用PEFT lora进行高效微调
  3. 使用了下划线_ 作为mask token。下面展示了data collator中的mask操作代码
class DataCollatorForLanguageModelingWithFullMasking(DataCollatorForLanguageModeling):
    def torch_mask_tokens(
        self,
        inputs: Any,
        special_tokens_mask: Optional[Any] = None,
    ) -> Tuple[Any, Any]:
        """
        为掩蔽语言模型准备inputs/labels,只要被选中为需要mask的,则100%都设置为mask_token
        """
        import torch

        labels = inputs.clone()
        # 使用一个概率 self.mlm_probability 来为每个句子选择少量token用于MLM训练
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(
                    val, already_has_special_tokens=True
                )
                for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # 只计算masked token处的损失

        # 对于需要mask的token,100%使用mask_token替换
        inputs[masked_indices] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token
        )

        return inputs, labels

第三步:使用无监督/有监督对比学习训练

无论是无监督还是有监督,本质上都是对于每一个锚点(可以是查询也可以是文档),构造正例和负例,利用对比学习损失进行训练。每个查询/文档的表征由最后一层的embedding取平均得到。

img

实验

无监督结果

无监督使用了英文Wikipedia,和SimCSE一致使用LLM同样的句子两次Dropout作为正例,批内其他句子作为负例。测试使用METB。

可以发现,直接使用llama2的token表征效果已经超越了之前无监督的表示,加上MNTP训练后效果有一定提升,加上SimCSE无监督训练后效果有大幅提升

img

有监督结果

有监督使用的训练数据是E5数据 中的公开部分,和对比的BGE等模型保持一致。测试使用MTEB。

在MNTP模型的基础上,直接使用有监督训练,相比先使用SimCSE无监督训练再继续有监督训练效果相差不大。且使用有监督数据训练效果远好于无监督训练。

img


大家好,我是NLP研究者BrownSearch,如果你觉得本文对你有帮助的话,不妨点赞收藏支持我的创作,您的正反馈是我持续更新的动力!如果想了解更多LLM/检索的知识,记得关注我!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/799679.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Qt Mqtt客户端 + Emqx

环境 Qt 5.14.2 qtmqtt mqttx 功能 QT Mqtt客户端 qtmqtt 下载 qtmqtt (注意下载与QT版本相符的库)并使用QT 编译 编译完成后需要的文件: emqx 1.虚拟机中安装emqx,并启动 curl -s https://assets.emqx.com/scripts/install-emqx-deb.sh | sudo bash sudo apt-get inst…

【详解】Spring Cloud概述

🎥 个人主页:Dikz12🔥个人专栏:Spring学习之路📕格言:吾愚多不敏,而愿加学欢迎大家👍点赞✍评论⭐收藏 目录 1. 认识微服务 1.1 单体架构 1.2 集群和分布式架构 1.3 集群和分布式…

【全面介绍Pip换源】

🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步! 🦁Pip换源.⛅ 🦁当使用Pip安装Python软件包时,默认情况下会…

BayesPrism(贝叶斯棱镜法)可提取单细胞数据去卷积后将信息映射至bulkRNA数据

贝叶斯棱镜法作为一种工具可以根据scRNA数据(作为先验模型)去推断bulkRNA数据中肿瘤微环境组成(不同免疫细胞组分/不同细胞群)和基因表达情况。 开发者展示的图片就很形象了,左边图展示了把标注了不同细胞类型的单细胞数据作为先验信息(prior info)的基因信息和bul…

力扣144题:二叉树的先序遍历

给你二叉树的根节点 root ,返回它节点值的 前序 遍历。 示例 1: 输入:root [1,null,2,3] 输出:[1,2,3]示例 2: 输入:root [] 输出:[]示例 3: 输入:root [1] 输出&am…

【云岚到家】-day05-6-项目迁移-门户-CMS

【云岚到家】-day05-6-项目迁移-门户-CMS 4 项目迁移-门户4.1 迁移目标4.2 能力基础4.2.1 缓存方案设计与应用能力4.2.2 静态化技术应用能力 4.3 需求分析4.3.1 界面原型 4.4 系统设计4.4.1 表设计4.4.2 接口与方案4.4.2.1 首页信息查询接口4.4.3.1 数据缓存方案4.4.3.2 页面静…

【绝命Coding助力秋招】Python实现<实习僧>海投脚本

hello hello~ ,这里是绝命Coding——老白~💖💖 ,欢迎大家点赞🥳🥳关注💥💥收藏🌹🌹🌹 💥个人主页:绝命Coding-CSDN博客 &a…

Java 实验三:数组操作以及Java中的方法

一、实验目的 1、掌握数组的声明、初始化、查找、排序等的方式; 2、掌握Java中如何定义一个方法,定义好的方法如何进行调用等。 二、实验环境 1、windows11; 2、JDK1.8,集成开发环境Eclipse。 三、实验内容 1、 定义一个函数,获取某个…

Linux系统搭建轻量级个人博客VanBlog并一键发布公网远程访问

文章目录 前言1. Linux本地部署2. VanBlog简单使用3. 安装内网穿透4. 创建公网地址5. 创建固定公网地址 前言 今天和大家分享如何在Linux Ubuntu系统搭建一款轻量级个人博客VanBlog,并结合cpolar内网穿透软件生成公网地址,轻松实现随时随地远程访问本地…

网络配置命令

文章目录 一、查看网络接口信息 ifconfig1.1 网络接口名称1.2 使用 ifconfig 查看网络接口信息1.2.1 输出示例1.2.2 输出解释 1.3 查看特定网络接口信息1.3.1 输出示例 1.4 查看所有网络接口信息1.5 特殊网络接口 二、修改网络配置文件2.1 配置文件示例2.2 使配置生效2.3 关闭 …

JavaScript日期对象倒计时案例

思路&#xff1a;1.先求出当前时间的总毫秒数 2.再求出所需要求的时间的总毫秒数 3.用所求时间的减去当前时间的可得到倒计时剩余时间 4.最后将所求的倒计时剩余时间转换为天&#xff0c;小时&#xff0c;分钟&#xff0c;秒即可 <!DOCTYPE html> <html lang"en…

1.31、基于长短记忆网络(LSTM)的发动机剩余寿命预测(matlab)

1、基于长短记忆网络(LSTM)的发动机剩余寿命预测的原理及流程 基于长短期记忆网络(LSTM)的发动机剩余寿命预测是一种常见的机器学习应用&#xff0c;用于分析和预测发动机或其他设备的剩余可用寿命。下面是LSTM用于发动机剩余寿命预测的原理和流程&#xff1a; 数据收集&#…

可观察性优势:掌握当代编程技术

反馈循环是我们开发人员工作的关键。它们为我们提供信息&#xff0c;并让我们从用户过去和现在的行为中学习。这意味着我们可以根据过去的反应进行主动开发。 TestComplete 是一款自动化UI测试工具&#xff0c;这款工具目前在全球范围内被广泛应用于进行桌面、移动和Web应用的…

C++ 类和对象 赋值运算符重载

前言&#xff1a; 在上文我们知道数据类型分为自定义类型和内置类型&#xff0c;当我想用内置类型比较大小是非常容易的但是在C中成员变量都是在类(自定义类型)里面的&#xff0c;那我想给类比较大小那该怎么办呢&#xff1f;这时候运算符重载就出现了 一 运算符重载概念&…

npm发布的包如何快速在cnpm上使用

npm发布的包如何快速在cnpm上使用 解决方案 前往淘宝npm镜像官网 搜索插件库并点击同步 等待一分钟即可查看最新版本

Nuxt.js 错误侦探:useError 组合函数

title: Nuxt.js 错误侦探&#xff1a;useError 组合函数 date: 2024/7/14 updated: 2024/7/14 author: cmdragon excerpt: 摘要&#xff1a;文章介绍Nuxt.js中的useError组合函数&#xff0c;用于统一处理客户端和服务器端的错误&#xff0c;提供statusCode、statusMessage和…

PostgreSQL修改最大连接数

在使用PostgreSQL 的时候&#xff0c;经常会遇到这样的错误提示&#xff0c; sorry, too many clients already&#xff0c;这是因为默认PostgreSQL最大连接数是 100, 一般情况下&#xff0c;个人使用时足够的&#xff0c;但是在生产环境&#xff0c;这个连接数是远远不够的&am…

内存函数(C语言)

内存函数 以下函数的头文件&#xff1a;string.h 针对内存块进行处理的函数 memcpy 函数原型&#xff1a; void* memcpy(void* destination, const void* source, size_t num);目标空间地址 源空间地址num&#xff0c;被拷贝的字节个数 返回目标空间的起始地…

火星全球彩色影像图介绍(中分辨率相机)

一、数据基本信息 该数据是利用天问一号轨道器中分辨率相机获取的影像经光度校正、几何校正、全球制图等制作而成的全火星地图数据DOM&#xff0c;每个数据包含一个tif数据文件。该影像图分辨率为76米。 任务型号&#xff1a;天问一号 搭载平台&#xff1a;环绕器 数据获…

2.The DispatcherServlet

The DispatcherServlet Spring的Web MVC框架与许多其他Web MVC框架一样&#xff0c;是请求驱动的&#xff0c;围绕一个中央Servlet&#xff08;即DispatcherServlet&#xff09;设计&#xff0c;该Servlet将请求分派给控制器&#xff0c;并提供其他功能以促进Web应用程序的开发…