LLM - LLaMA-2 获取文本向量并计算 Cos 相似度

 

目录

一.引言

二.获取文本向量

1.hidden_states 与 last_hidden_states

◆ hidden_states

◆ last_hidden_states 

2.LLaMA-2 获取 hidden_states

◆ model config 

◆ get Embedding

三.获取向量 Cos 相似度

1.向量选择

2.Cos 相似度

3.BERT-whitening 特征白化

四.总结


一.引言

前面提到了两种基于统计的机器翻译评估方法: Rouge 与 BLEU,二者通过统计概率计算 N-Gram 的准确率与召回率,在机器翻译这种回答相对固定的场景该方法可以作为一定参考,但在当前大模型更加多样性的场景以及发散的回答的情况下,Rouge 与 BLEU 有时候并不能更好的描述文本之间的相似度,下面我们尝试从 LLM 大模型提取文本的 Embedding 并进行向量相似度计算。

二.获取文本向量

1.hidden_states 与 last_hidden_states

根据 LLM 模型类型的不同,有的 Model 提供 hidden_states 方法,例如 LLaMA-2-13B,有的模型提供 last_hidden_states 方法,例如 GPT-2。查找模型对应方法 API 可以在 Transformer 官网。

 hidden_states

hidden_states 类型为 typing.Optional[typing.Tuple[torch.FloatTensor]],其提供一个 Tuple[Tensor] 分别记录了每层的输出,完整的解释在参数下方: 

模型在每一层输出处的隐藏状态加上可选的初始嵌入输出。这里我们可以通过打印模型 Layer 和索引从而获取 hidden_states 中隐层的输出。

◆ last_hidden_states 

一些传统的模型例如 GPT-2,还有当下一些的新模型例如 ChatGLM2 都有 last_hidden_states 的 API,可以直接获取最后一层的 Embedding 输出,而如果使用 hidden_states 则只需要通过 [-1] 索引即可获得 last_hidden_states,相比来如前者更全面后者更方便。

2.LLaMA-2 获取 hidden_states

model config 

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": None,
        "revision": 'main',
        "use_auth_token": None,
        "output_hidden_states": True
    }

    config = AutoConfig.from_pretrained(ori_model_path, **config_kwargs)

    llama_model = AutoModelForCausalLM.from_pretrained(
        ori_model_path,
        config=config,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        revision='main'
    )

 根据 CausalLMOutputWithPast hidden_states 参数的提示,我们只需要在模型 config 中添加:

"output_hidden_states": True

get Embedding

def get_embeddings(result, llm_tokenizer, model, args):
    fw = open(args.output, 'w', encoding='utf-8')
    for qa in result:
        q = qa[0]
        a = qa[1]
        # 对输出文本进行 tokenize 和编码
        tokens = llm_tokenizer.encode_plus(a, add_special_tokens=True, padding='max_length', truncation=True,
                                           max_length=128, return_tensors='pt')
        input_ids = tokens["input_ids"]
        attention_mask = tokens['attention_mask']

        # 获取文本 Embedding
        with torch.no_grad():
            outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)
            embedding = list(outputs.hidden_states)
            last_hidden_states = embedding[-1].cpu().numpy()
            first_hidden_states = embedding[0].cpu().numpy()
            last_hidden_states = np.squeeze(last_hidden_states)
            first_hidden_states = np.squeeze(first_hidden_states)
            fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)

        log = "%s\t%s\t%s\n" % (q, a, toString(fisrt_larst_avg_status))
        fw.write(log)
    fw.close()

predict  预测       ➔  将 model 基于 Question generate 得到的 Answer 存入 result

encode 编码       ➔  对 Answer 进行编码获取对应 Token 与 input_ids、attention_mask

output 模型输出  ➔  直接调用 model 进行输出,有的也可以调用 model.transform 方法进行输出

hidden_states     ➔  outputs.hidden_states 获取各隐层输出

最后获取的向量需要先 cpu 然后再转为 numpy 数组,一般的做法是采用 mean 获得句子的平均表征。

三.获取向量 Cos 相似度

1.向量选择

在 BERT-flow 的论文中,如果不加任何后处理手段,那么基于 BERT 抽取句向量的最好 Pooling 方法是 BERT 的第一层与最后一层的所有 token 向量的平均,即 fisrt-larst-avg,对应 hidden_state 的 0 和 -1 索引,所以后面的相似度计算我们都以 fisrt-larst-avg 为基准来评估 Embedding 相似度。

# 获取文本 Embedding
with torch.no_grad():
    outputs = model(input_ids=input_ids.cuda(), attention_mask=attention_mask)
    embedding = list(outputs.hidden_states)
    last_hidden_states = embedding[-1].cpu().numpy()
    first_hidden_states = embedding[0].cpu().numpy()
    last_hidden_states = np.squeeze(last_hidden_states)
    first_hidden_states = np.squeeze(first_hidden_states)
    fisrt_larst_avg_status = np.mean(first_hidden_states + last_hidden_states, axis=0)

2.Cos 相似度

# 计算 Cos 相似度
def compute_cosine(a_vec, b_vec):
    norms1 = np.linalg.norm(a_vec, axis=1)
    norms2 = np.linalg.norm(b_vec, axis=1)
    dot_products = np.sum(a_vec * b_vec, axis=1)
    cos_similarities = dot_products / (norms1 * norms2)
    return cos_similarities

a_vec 为预测文本转化得到的 Embedding,b_vec 为人工标注正样本文本转化得到的 Embedding,通过计算二者相似度,评估预测文本与人工文本的相似程度。

3.BERT-whitening 特征白化

苏神在 BERT-whitening 一文中提出了一种基于 PCA 降维的无监督 Embedding 评估方式,Bert-whitening 又叫特征白化,其思路与 PCA 降维类似,意在对 SVD 分解后的主成分矩阵取前 λ 个特征向量构造特征值矩阵,提取向量中的关键信息,使输出向量矩阵每个维度均值为零,协方差矩阵为单位阵,λ 个特征值也对应前 λ 个主成分。其算法逻辑如下:

 下面我们调用 Sklearn 的 PCA 库简单实现下:

from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize

    # 取出句子的平均表示 -> 使用 PCA 降维 -> 白化处理
    concatenate = np.concatenate((answer_vector, predict_vector))
    pca = PCA(n_components=2048)
    pca.fit(concatenate)
    ans_white_vec = pca.transform(answer_vector)
    ans_norm_vec = normalize(ans_white_vec)
    pre_white_vec = pca.transform(predict_vector)
    pre_norm_vec = normalize(pre_white_vec)

    pca_cos_similarities = compute_cosine(ans_norm_vec, pre_norm_vec)

answec_vector 和 predict_vector 均通过 first_and_last 方法从 hidden_states 中获取,n_components 即 top_k 的选择,以 LLaMA-2 为例,原始得到的向量维度为 5120,原文中也有使用 n_components = 256 实验。

四.总结

博主采用 1500+ 样本分别使用 cos、pca 和 self_pca [自己实现 SVD 与特征矩阵] 三种方法对向量相似度进行评估,n_components 设为 1024:

可以看到 SVD 处理后得到的 W 和 mu 的 shape,通过下述操作可完成向量的降维:

vecs = (vecs + bias).dot(kernel)

最终得到的结果 Cosine 与 PCA 降维的相似度差距较大,由于自然语言生成的样本没有严格意义的正样本,上面计算采用的参考文本也是人工标注,有一定的不确定性,所以基于不同的度量,我们也可以统计分析,定一个 threshold,认为大于该 threshold 的输入样本为可用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/98988.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

.ssh文件夹下缺失known_hosts文件

.ssh文件夹下缺失known_hosts文件 先确认工蜂或github 添加了git生成的密钥 然后 桌面打开git bash 1、执行ssh -T gitgitlab.com 2、输入yes

b站手机缓存文件转MP4

b站缓存的文件 音频、视频、弹幕是分开的 这里我只用到了音频和视频所以只介绍这一部分 b站的缓存视频文件和路径结构如下 默认缓存路径 内部存储\Android\data\tv.danmaku.bilil\download\89720189 文件夹结构 文件夹 c_738583 这是单个视频的缓存文件夹 进入c_738583文件夹…

电脑可以上网,微信都可以用,但浏览器打不开网页

可以试试设置DNS(其他windows版本步骤): 1.打开控制面板 2.网络和Internet 3.查看网络计算机和设备 4.按照下图步骤: 5.按下图进行

Unity 之 方括号[ ] 的用法以及作用

文章目录 在Unity中,方括号 [ ] 通常用于表示属性、特性(Attributes)或者元数据(Metadata)。这些标记提供了附加信息,可以用于修改类、方法、字段等的行为或者在编辑器中进行设置。 以下是一些常见的用法&…

基于stm32的烟雾浓度检测报警proteus仿真设计(仿真+程序+讲解)

基于STM32的烟雾浓度检测报警仿真设计(仿真程序讲解) 1.主要功能2.仿真3. 程序4. 资料清单&下载链接 基于STM32的烟雾浓度检测报警仿真设计(仿真程序讲解) 仿真图proteus 8.9 程序编译器:keil 5 编程语言:C语言 设计编号&a…

Huggingface托管机器学习模型及API提供

推荐:用 NSDT编辑器 快速搭建可编程3D场景 我想在我的网络和移动应用程序中使用机器学习模型,但要做到这一点,我必须在某个地方托管我的机器学习应用程序。 托管预先训练的 ML 模型称为推理。 我只想添加一些 Python ML 代码并快速获得 REST…

C语言:字符函数和字符串函数(一篇拿捏字符串函数!)

目录 求字符串长度: 1. strlen(字符串长度) 长度不受限制函数: 2. strcpy(字符串拷贝) 3. strcat(字符串追加) 4. strcmp(字符串比较) 长度受限制函数: 5. strncpy(字符串拷贝) 6. strncat(字符串追加) 7. strncmp(字符串比较) 字…

MR混合现实汽车维修情景实训教学演示

MR混合现实技术应用于汽车维修课堂中,能够赋予学生更加真实,逼真地学习环境,让学生在情景体验中不断提高自己的专业能力。 MR混合现实汽车维修情景实训教学演示具体体现在: 1. 虚拟维修指导:利用MR技术,可…

upgrade pip报错:def read(rel_path: str) -> str: syntaxerror

命令行执行以下命令就可以大功告成! wget https://bootstrap.pypa.io/pip/2.7/get-pip.py python get-pip.py pip install --upgrade setuptools最后大功告成:

Vue3+ts封装一个简单版的Message组件

Vue3ts封装一个Message组件 项目中需要使用信息提示框的功能,ui组件库使用的是字节的arco-design-vue。看了一下,现有的Message不满足要是需求,直接使用message组件的话,改样式太麻烦。Notification组件样式倒是符合了&#xff0c…

使用maven创建springboot项目

创建maven快速启动项目 命令行或者idea、eclipse快捷创建也可以 pom.xml下project项目下导入springboot 父工程 <!--导入springboot 父工程--> <parent><artifactId>spring-boot-starter-parent</artifactId><groupId>org.springframework.bo…

Flink+Paimon多流拼接性能优化实战

目录 &#xff08;零&#xff09;本文简介 &#xff08;一&#xff09;背景 &#xff08;二&#xff09;探索梳理过程 &#xff08;三&#xff09;源码改造 &#xff08;四&#xff09;修改效果 1、JOB状态 2、Level5的dataFile总大小 3、数据延迟 &#xff08;五&…

研华I/O板卡 Win10+Qt+Cmake 开发环境搭建

文章目录 一.研华I/O板卡 Win10QtCmake 开发环境搭建 一.研华I/O板卡 Win10QtCmake 开发环境搭建 参考这个链接安装研华I/O板卡驱动程序系统环境变量添加研华板卡dll Qt新建一个c项目 cmakeList.txt中添加研华库文件 cmake_minimum_required(VERSION 3.5)project(advantechDA…

科技资讯|苹果发布新专利:可在车内定位苹果的智能设备

根据美国商标和专利局近期公示的清单&#xff0c;苹果公司获得了一项名为《车内定位移动设备的系统和方式》专利&#xff0c;概述了在车内狭窄空间内如何定位 iPhone 等移动设备。 Find My 服务现阶段没有使用 UWB 来追踪 iPhone 或者 iPad&#xff0c;而是依赖 GPS 等相关辅…

【Java并发】聊聊对象内存布局和syn锁升级过程

对象存储解析&#xff1a;一个空Object对象到底占据多少内存&#xff1f; 对象内存布局 Mark Word占用8字节&#xff0c;类型指针占用8个字节&#xff0c;对象头占用16个字节。 好了&#xff0c;我们来看一下一个Object对占用多少空间&#xff0c; 因为java默认是开启压缩…

Spring框架知识点汇总

01.Spring框架的基本理解 关键字&#xff1a;核心思想IOC/AOP&#xff0c;作用&#xff08;解耦&#xff0c;简化&#xff09;&#xff0c;简单描述框架组成&#xff1b; Spring框架是一款轻量级的开发框架&#xff0c;核心思想是IOC&#xff08;反转控制&#xff09;和AOP&a…

【C++】输入输出及格式控制

在各类算法竞赛和机试中&#xff0c;对测试数据和输出格式往往会有明确的规定&#xff0c;笔者结合个人刷题经历&#xff0c;得到了以下C语言输入输出控制的方法。 cin&#xff1a;从缓冲区中读取数据 cin>>从缓冲区中读取数据时&#xff0c;若缓冲区中第一个字符是空格…

DBeaver 23.1.5 发布

导读DBeaver 是一个免费开源的通用数据库工具&#xff0c;适用于开发人员和数据库管理员。DBeaver 23.1.5 现已发布&#xff0c;更新内容如下. Data editor 重新设计了词典查看器面板 UI 空间数据类型&#xff1a;曲线几何线性化已修复 数据保存时结果选项卡关闭的问题已解决…

2第一个Java程序

目录 1第一个Java代码 2类class 3运行Java文件 1第一个Java代码 public class Hello {public static void main(String[] args) {System.out.println("Hello, world!");} } 2类class public class Hello {public static void main(String[] args) {System.ou…