从预训练的BERT中提取Embedding

文章目录

    • 背景
    • 前置准备
    • 思路
    • 利用Transformer 库实现

背景

假设要执行一项情感分析任务,样本数据如下
在这里插入图片描述
可以看到几个句子及其对应的标签,其中1表示正面情绪,0表示负面情绪。我们可以利用给定的数据集训练一个分类器,对句子所表达的情感进行分类。

前置准备

# 安装modelscope包
pip install modelscope
# 下载 bert-base-uncased 模型
modelscope download --model AI-ModelScope/bert-base-uncased

思路

  1. 分词:以第一句为例,我们使用WordPiece对句子进行分词,并得到标记(单词),如下所示。

    tokens = [I, love, Paris]

  2. 添加标记:在开头添加[CLS]标记,在结尾添加[SEP]标记,如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP] ]

  3. 填充:为了保持所有标记的长度一致,我们将数据集中的所有句子的标记长度设为7。句子I loveParis的标记长度是5,为了使其长度为7,需要添加两个标记来填充,即[PAD]。因此,新标记如下所示。

    tokens = [ [CLS], I, love, Paris, [SEP], [PAD], [PAD] ]

    添加两个[PAD]标记后,标记的长度达到所要求的7。

  4. 注意力掩码:下一步,要让模型理解[PAD]标记只是为了匹配标记的长度,而不是实际标记的一部分。为了做到这一点,我们需要引入一个注意力掩码。我们将所有位置的注意力掩码值设置为1,将[PAD]标记的位置设置为0,如下所示。

    attention_mask = [ 1, 1, 1, 1, 1, 0, 0]

  5. 映射到token id:然后,将所有的标记映射到一个唯一的标记ID。假设映射的标记ID如下所示。

    token_ids = [101, 1045, 2293, 3000, 102, 0, 0]

    ID 101表示标记[CLS],1045表示标记I,2293表示标记love,以此类推。

    现在,我们把token_ids和attention_mask一起输入预训练的BERT模型,并获得每个标记的特征向量(嵌入)。通过代码,我们可以进一步理解以上步骤。下图显示的标记+单词而不是id,但实际传入的是id
    在这里插入图片描述

以上,可以得到每个单词的Embedding,整个句子的Embedding是 R [ C L S ] R_{[CLS]} R[CLS]

利用Transformer 库实现

from transformers import BertModel, BertTokenizer
import torch
# 下载并加载预训练的模型
model = BertModel.from_pretrained('bert-base-uncased')
# 下载并加载用于预训练模型的词元分析器。
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 下面,让我们看看如何对输入进行预处理。
# 0. 对输入进行预处理假设输入句如下所示。
sentence = 'I love Paris'
# 1. 分词
tokens = tokenizer.tokenize(sentence)
print(tokens) # ['i', 'love', 'paris']

# 2. 添加标记
tokens = ['[CLS]'] + tokens + ['[SEP]']
print(tokens) # ['[CLS]', 'i', 'love', 'paris', '[SEP]']

# 3. 填充
tokens = tokens + ['[PAD]'] + ['[PAD]']
print(tokens) #['[CLS]', 'i', 'love', 'paris', '[SEP]', '[PAD]', '[PAD]' ]

# 4. 注意力掩码
attention_mask = [1 if i!= '[PAD]' else 0 for i in tokens]
print(attention_mask) # [1, 1, 1, 1, 1, 0, 0]

# 5. 将所有标记转换为它们的标记ID
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids) # [101, 1045, 2293, 3000, 102, 0, 0]

# 6. 将token_ids和attention_mask转换为张量
token_ids = torch.tensor(token_ids).unsqueeze(0)
attention_mask = torch.tensor(attention_mask).unsqueeze(0)


# 7. 将token_ids和atten-tion_mask送入模型,并得到嵌入向量。
# 需要注意,model返回的输出是一个有两个值的元组。第1个值hidden_rep表示隐藏状态的特征,它包括从顶层编码器(编码器12)获得的所有标记的特征。第2个值cls_head表示[CLS]标记的特征。
hidden_rep, cls_head = model(token_ids, attention_mask = attention_mask)
print(hidden_rep.shape) # torch.Size([1, 7, 768])

'''
数组[1, 7, 768]表示[batch_size, se-quence_length, hidden_size],也就是说,批量大小设为1,序列长度等于标记长度,即7。因为有7个标记,所以序列长度为7。隐藏层的大小等于特征向量(嵌入向量)的大小,在BERT-base模型中,其为768。
* hidden_rep[0][0]给出了第1个标记[CLS]的特征。   
* hidden_rep[0][1]给出了第2个标记I的特征。   
* hidden_rep[0][2]给出了第3个标记love的特征
'''

print(cls_head.shape) # torch.Size([1, 768])

'''
大小[1, 768]表示[batch_size, hid-den_size]。我们知道cls_head持有句子的总特征,所以,可以用cls_head作为句子I love Paris的整句特征。
'''


以上获得的是从顶层编码器(编码器12)获得的特征,如果要获取所有编码器的特征,需要修改以下两个地方。

# 下载并加载预训练的模型时,设置output_hidden_states = True
model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states = True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# 调用模型时,产生的是三元组
last_hidden_state, pooler_output, hidden_states = model(token_ids, attention_mask = attention_mask)

'''
* last_hidden_state,它仅有从最后的编码器(编码器12)中获得的所有标记的特征
* pooler_output表示来自最后的编码器的[CLS]标记的特征,它被一个线性激活函数和tanh激活函数进一步处理。
* hidden_states包含从所有编码器层获得的所有标记的特征。它是一个包含13个值的元组,含有所有编码器层(隐藏层)的特征,即从输入嵌入层h到最后的编码器层h。
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/952249.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Kafka 会丢消息吗?

目录 01 生产者(Producer) 02 消息代理(Broker) 03 消费者(Consumer) 来源:Kafka 会丢消息吗? Kafka 会丢失信息吗? 许多开发人员普遍认为,Kafka 的设计本身就能保证不会丢失消息。然而,Kafka 架构和配置的细微差别会导致消息的丢失。我们需要了解它如何以及何时…

RPM包安装Nginx部署Zr.Admin Vue2前端

0、确认node环境 安装node 参考Linux Red Hat安装包安装nodejs 设置 registry 执行 npm set registry https://registry.npmmirror.com/ 1、安装依赖 进入代码根目录 cd /lsp/code/zradmin/ZR.Vue 安装依赖 执行 npm install 最后生成node_modules文件夹 2、测试运行 …

【SpringBoot】入门

【SpringBoot】入门 一、前言1 什么是 Spring Boot ?2 特点与优势3 发展历程 二、项目构建【基于IDEA】1 创建 Spring Boot 项目2 项目结构3 运行 Spring Boot 项目 三、补充1 Maven操作2 项目打 jar 包在本地运行 一、前言 1 什么是 Spring Boot ? Spring Boot最开始基于S…

如何在本地部署大模型并实现接口访问( Llama3、Qwen、DeepSeek等)

如何在本地部署大模型并实现接口访问( Llama3、Qwen、DeepSeek等) 如何在本地部署大模型并实现接口访问( Llama3、Qwen、DeepSeek等)模型地址模型下载模型部署指定显卡运行app.py 运行环境requirements 调用接口代码调用 结语 如何…

vmware-ubuntu22.04配置虚拟机win10,重新上网成功

打开问题显示 Hardware配置 Options配置 最后的Advanced,第一次用了BIOS,然后启动中有更新,然后关闭,再用UEFI启动

GDPU Android移动应用 重点习题集

目录 程序填空 ppt摘选 题目摘选 “就这两页ppt,你还背不了吗” “。。。” 打开ppt后 “Sorry咯,还真背不了😜” 更新日志 考后的更新日志 没想到重点勾了一堆,还愣是没考到其中的内容,翻了一下,原…

运行.Net 7 Zr.Admin项目(后端)

1.下载Zr.Admin代码压缩包 https://codeload.github.com/izhaorui/Zr.Admin.NET/zip/refs/heads/main 2.打开项目 我这里装的是VS2022社区版 进入根目录,双击ZRAdmin.sln打开项目 3.安装.net7运行时 我当时下载的代码版本是.net7的 点击安装 点击安装&#xff0…

代码随想录算法训练营第3天(链表1)| 203.移除链表元素 707.设计链表 206.反转链表

一、203.移除链表元素 题目:203. 移除链表元素 - 力扣(LeetCode) 视频:手把手带你学会操作链表 | LeetCode:203.移除链表元素_哔哩哔哩_bilibili 讲解:代码随想录 注意: 针对头结点和非头结点的…

2024年总结及2025年目标之关键字【稳进】

1. 感受 时光荏苒,都731天(2年时间)下来了,从第一年的【坚持】,到第二年的【提速】,定目标,现在回头看,还是那句话【事非经过不知难】,那又怎么样呢,再难不是…

qt QLabel QPushButton 控件重写paintEvent后 控件消失

qt 继承自PushButton控件的类 重写paintEvent后 控件消失 解决办法,在paintevent结尾加上这条语句:QPushButton::paintEvent(event); void MyButton::paintEvent(QPaintEvent *event) {QPushButton::paintEvent(event); } 这里QPushButton不能写成Q…

苹果手机(IOS系统)出现安全延迟进行中如何关闭?

苹果手机(IOS系统)出现安全延迟进行中如何关闭? 一、设置二、隐私与安全性三、失窃设备保护关闭 一、设置 二、隐私与安全性 三、失窃设备保护关闭

ELK的搭建

ELK elk:elasticsearch logstatsh kibana统一日志收集系统 elasticsearch:分布式的全文索引引擎点非关系型数据库,存储所有的日志信息,主和从,最少需要2台 logstatsh:动态的从各种指定的数据源,获取数据…

从CentOS到龙蜥:企业级Linux迁移实践记录(龙蜥开局)

引言: 在我们之前的文章中,我们详细探讨了从CentOS迁移到龙蜥操作系统的基本过程和考虑因素。今天,我们将继续这个系列,重点关注龙蜥系统的实际应用——特别是常用软件的安装和配置。 龙蜥操作系统(OpenAnolis&#…

【AI自动化渗透】大模型支持的自动化渗透测试,看蚂蚁和浙大的

参考文章: https://mp.weixin.qq.com/s/WTaO54zRxtNMHaiI1tfdGw 最近,美国西北大学,浙江大学,蚂蚁集团的一些专家学者联手发表了一篇论文,介绍了一个PentestAgent的方案,实现了渗透测试自动化。 01 技术方案 图的字…

探秘block原理

01 概述 在iOS开发中,block大家用的都很熟悉了,是iOS开发中闭包的一种实现方式,可以对一段代码逻辑进行封装,使其可以像数据一样被传递、存储、调用,并且可以保存相关的上下文状态。 很多block原理性的文章都比较老&am…

【Java】-- 利用 jar 命令将配置文件添加到 jar 中

目录 1、准备 2、目标 3、步骤 3.1、安装 jdk 3.2、添加配置文件 3.3、校验 1、准备 java 环境hadoop-core-1.2.1.jar 和 core-site.xml 2、目标 将 core-site.xml 添加到 hadoop-core-1.2.1.jar 中。 3、步骤 3.1、安装 jdk 3.2、添加配置文件 jar -cvf hadoop-core-…

概率论与数理统计总复习

复习课本:中科大使用的教辅《概率论和数理统计》缪柏其、张伟平版本 目录 0.部分积分公式 1.容斥原理 2.条件概率 3.全概率公式 4.贝叶斯公式 5.独立性 6.伯努利分布(两点分布) 7.二项分布 8.帕斯卡分布(负二项分布&am…

阿里mod_asr3.0集成webrtc静音算法

alibabacloud-nls-cpp-sdk-master 先到阿里官网下载nls库的源代码,编译生成对应的库文件和头文件。 我编译的放到了以下目录。 /home/jp/2025/alibabacloud-nls-cpp-sdk-master/build/install/NlsSdk3.X_LINUX/include/ /home/jp/2025/alibabacloud-nls-cpp-sdk-…

机器人碳钢去毛刺,用大扭去毛刺主轴可轻松去除

在碳钢精密加工的最后阶段,去除毛刺是确保产品质量的关键步骤。面对碳钢这种硬度较高的材料,采用大扭矩的SycoTec去毛刺主轴,成为了行业内的高效解决方案。SycoTec作为精密加工领域的领军品牌,其生产的高速电主轴以其卓越的性能&a…

iOS 逆向学习 - Inter-Process Communication:进程间通信

iOS 逆向学习 - Inter-Process Communication:进程间通信 一、进程间通信概要二、iOS 进程间通信机制详解1. URL Schemes2. Pasteboard3. App Groups 和 Shared Containers4. XPC Services 三、不同进程间通信机制的差异四、总结 一、进程间通信概要 进程间通信&am…