bert提取词向量比较两文本相似度

使用 bert-base-chinese 预训练模型做词嵌入(文本转向量)

模型下载:bert预训练模型下载-CSDN博客

参考文章:使用bert提取词向量

下面这段代码是一个传入句子转为词向量的函数

from transformers import BertTokenizer, BertModel
import torch

# 加载中文 BERT 模型和分词器
model_name = "../bert-base-chinese"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)


def get_word_embedding(sentence):
    # 分词
    tokens = tokenizer.tokenize(sentence)
    # 添加特殊标记 [CLS] 和 [SEP]
    tokens = ['[CLS]'] + tokens + ['[SEP]']
    # 将分词转换为对应的编号
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # 转换为 PyTorch tensor 格式
    input_ids = torch.tensor([input_ids])

    # 获取词向量
    outputs = model(input_ids)

    # outputs[0]是词嵌入表示
    embedding = outputs[0]
    # 去除头尾标记的向量值
    word_embedding = embedding[:, 1:-1, :]

    return word_embedding

embedding[:, 1:-1, :] 这一行的意是以下,数据类型张量

[batch_size, sequence_length, hidden_size],其中:

  • batch_size 是输入文本的批次大小,即一次输入的文本样本数量。
  • sequence_length 是输入文本序列的长度,即编码器输入的词的数量。
  • hidden_size 是隐藏状态的维度大小,是 BERT 模型的超参数,通常为 768 或 1024。

比较两文本相似度

def compare_sentence(sentence1, sentence2):
    # 分词
    tokens1 = tokenizer.tokenize(sentence1)
    tokens2 = tokenizer.tokenize(sentence2)
    # 添加特殊标记 [CLS] 和 [SEP]
    tokens1 = ['[CLS]'] + tokens1 + ['[SEP]']
    tokens2 = ['[CLS]'] + tokens2 + ['[SEP]']
    # 将分词转换为对应的词表中的索引
    input_ids1 = tokenizer.convert_tokens_to_ids(tokens1)
    input_ids2 = tokenizer.convert_tokens_to_ids(tokens2)
    # 转换为 PyTorch tensor 格式
    input_ids1 = torch.tensor([input_ids1])
    input_ids2 = torch.tensor([input_ids2])

    # 获取词向量
    outputs1 = model(input_ids1)
    outputs2 = model(input_ids2)

    # outputs[0]是词嵌入表示
    embedding1 = outputs1[0]
    embedding2 = outputs2[0]
    # 提取 [CLS] 标记对应的词向量作为整个句子的表示
    sentence_embedding1 = embedding1[:, 0, :]
    sentence_embedding2 = embedding2[:, 0, :]

    # 计算词的欧氏距离
    # 计算p范数距离的函数,其中p设置为2,这意味着它将计算的是欧几里得距离(L2范数)
    euclidean_distance = torch.nn.PairwiseDistance(p=2)
    distance = euclidean_distance(sentence_embedding1, sentence_embedding2)
    # 计算余弦相似度
    # dim=1 表示将在第一个维度(通常对应每个样本的特征维度)上计算余弦相似度;eps=1e-6 是为了数值稳定性而添加的一个很小的正数,以防止分母为零
    cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
    similarity = cos(sentence_embedding1, sentence_embedding2)

    print("句1: ", sentence1)
    print("句2: ", sentence2)
    print("相似度: ", similarity.item())
    print("欧式距离: ", distance.item())


compare_sentence("黄河南大街70号8门", "皇姑区黄河南大街70号8门")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/356617.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024年了,是谁还在学C++11?(没错,是我)

今天要聊的这本书, 是真正畅行全球20年的C入门必读经典,各版本全球总销量超1300万册! 它惠及了数百万高校师生,启蒙了5代国产程序员, 令全球数千万C开发者全部为之疯狂的大!师!名&#xff01…

实现Crm系统的灵活配置,满足不同行业客户需求

目录 一:数据模型配置 二:流程配置 三:扩展性配置 实现CRM系统的可配置性需要关注以下几个方面: 一:数据模型配置 为了满足企业的个性化需求,CRM系统需要提供灵活的数据模型配置。用户可以根据自己的业…

秋招面试—计算机网络安全

2021 计算机网络安全 1.Get 和 Post 的区别 get 用于获取数据,post用于提交数据; get 的缓存保存在浏览器和web服务器日志中; get 使用明文传输,post请求保存在请求体中; get 长度限制在2048以内 2.常见的HTTP请…

CVE-2024-0352 likeshop v2.5.7文件上传漏洞分析

本次的漏洞研究基于thinkPHP开发开的一款项目..... 漏洞描述 Likeshop是Likeshop开源的一个社交商务策略的完整解决方案,开源免费版基于thinkPHP开发。Likeshop 2.5.7.20210311及之前版本存在代码问题漏洞,该漏洞源于文件server/application/api/contr…

pytest教程-8-用例参数化方法

领取资料,咨询答疑,请➕wei: June__Go 上一小节中我们学习了pytest用例前后置方法的使用,本小节我们讲解一下pytest用例的参数化方法。 参数化简介: 参数化测试是指在测试用例中通过传入不同的参数来运行多次测试,…

图像复原的天花板在哪里?SUPIR:开创性结合文本引导先验和模型规模扩大

SUPIR(Scaling-UP Image Restoration),这是一种开创性的图像复原方法,利用生成先验和模型扩大规模的力量。通过利用多模态技术和先进的生成先验,SUPIR在智能和逼真的图像复原方面取得了重大进展。作为SUPIR中的关键催化…

纵向拼接,一键高效,让图片处理更简单!

你是否曾经因为需要批量处理图片而感到烦恼?现在,有了我们的图片处理工具,你可以轻松地纵向拼接图片,一键批量处理,让图片处理工作更加高效!这款工具采用先进的技术,能够快速准确地完成图片纵向…

Android SystemUI 介绍

目录 一、什么是SystemUI 二、SystemUI应用源码 三、学习 SystemUI 的核心组件 四、修改状态与导航栏测试 本篇文章,主要科普的是Android SystemUI , 下一篇文章我们将介绍如何把Android SystemUI 应用转成Android Studio 工程项目。 一、什么是Syst…

大数据 - Spark系列《一》- 分区 partition数目设置详解

目录 🐶3.2.1 分区过程 🐶3.2.2 SplitSize计算和分区个数计算 🐶3.2.3 Partition的数目设置 1. 🥙对于数据读入阶段,输入文件被划分为多少个InputSplit就会需要多少初始task. 2. 🥙对于转换算子产生的…

在centos 7 中安装配置Jdk、Tomcat、及Tomcat自启动

目录 一、安装配置Jdk 1.创建目录并上传文件 2.解压JDK压缩包 3.配置JDK环境变量 4.设置环境变量生效 二、安装配置Tomcat 1.上传Tomcat并解压 2.启停Tomcat 3.修改tomcat-user.xml配置 4.配置远程访问Tomcat 5.远程项目发布 三.Tomcat自启动配置 1.配置Tomcat自启…

imx6ull学习记录(一)

这一块主要是了解linux系统驱动部分,编译镜像相关的知识,这里记录一下。 使用板子如下: 教程用的这一个版本: 1、基本环境搭建 这个比较简单,只是注意一下就是正点原子的教程用了一个NFS文件系统,简单来…

MongoDB介绍及安装

文章目录 MongoDB介绍什么是MongoDBMongoDB技术优势MongoDB应用场景 MongoDB快速开始linux安装MongoDB启动MongoDB Server关闭MongoDB服务 Mongo shell使用mongo shell常用命令数据库操作集合操作 安全认证创建管理员账号常用权限创建应用数据库用户 Docker安装MongoDB工具官方…

物流平台如何与电商平台进行自动化流程管理

为什么要实现物流与电商平台进行自动化管理 实现物流平台与电商平台的自动化流程管理对企业和消费者都有着重要的意义,比如以下几点: 提高效率:自动化流程管理可以减少人为操作的错误和延误,提高订单处理和物流配送的效率。通过定…

What is Rust? Why Rust?

why Rust? 目前,Rust 变得越来越流行。然而,仍然有很多人(和公司!)误解了 Rust 的主张价值是什么,甚至误解了它是什么。在本文中,我们将讨论 Rust 是什么以及为什么它是一种可以增强…

Pytest单元测试框架

第一章、pytest概述 Pytest is a framework that makes building simple and scalable tests easy. Tests are expressive and readable—no boilerplate code required. Get started in minutes with a small unit test or complex functional test for your application or l…

Linux提权:Docker组挂载 Rsync未授权 Sudo-CVE Polkit-CVE

目录 Rsync未授权访问 docker组挂载 Sudo-CVE漏洞 Polkit-CVE漏洞 这里的提权手法是需要有一个普通用户的权限,一般情况下取得的webshell权限可能不够 Rsync未授权访问 Rsync是linux下一款数据备份工具,默认开启873端口 https://vulhub.org/#/envir…

第九节HarmonyOS 常用基础组件17-ScrollBar

1、描述 滚动条组件ScrollBar,用于配合可滚动组件使用,如List、Grid、Scroll。 2、接口 可包含子组件 ScrollBar(value:{scroller:Scroller, direction?: ScrollBarDirection, state?: BarState}) 3、参数 参数名 参数类型 必填 描述 scrolle…

148基于matlab的带有gui的轮轨接触几何计算程序

基于matlab的带有gui的轮轨接触几何计算程序,根据不同的踏面和轨头,计算不同横移量下面的接触点位置。程序已调通,可直接运行。 148 matlab 轮轨接触 横移量 (xiaohongshu.com)

Android App开发基础(2)—— App的工程结构

本专栏文章 上一篇 Android开发修炼之路——(一)Android App开发基础-1 2 App的工程结构 本节介绍App工程的基本结构及其常用配置,首先描述项目和模块的区别,以及工程内部各目录与配置文件的用途说明;其次阐述两种级别…

【qt】switchBtn

方法1 在qtdesigner中设置按钮图标的三个属性,normal off 、normal on和checkabletrue。 from PyQt5.QtWidgets import * from PyQt5.QtGui import * from PyQt5.QtCore import * from PyQt5 import uic from switchBtn import Ui_Dialogclass Test(QDialog, Ui_…