NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

NLP文本匹配任务Text Matching [无监督训练]:SimCSE、ESimCSE、DiffCSE 项目实践

文本匹配多用于计算两个文本之间的相似度,该示例会基于 ESimCSE 实现一个无监督的文本匹配模型的训练流程。文本匹配多用于计算两段「自然文本」之间的「相似度」。

例如,在搜索引擎中,我们通常需要判断用户的搜索内容是否相似:

A:蛋黄吃多了有什么坏处    B:吃鸡蛋白过多有什么坏处  ->  不相似
A:蛋黄吃多了有什么坏处    B:蛋黄可以多吃吗         ->  相似
...

那最直觉的思路就是让人工去标注文本对,再喂给模型去学习,这种方法称为基于「监督学习」训练出的模型:

但是,如果我们今天没有这么多的标注数据,只有一大堆的「未标注」数据,我们还能训练一个匹配模型吗?这种不依赖于「人工标注数据」的方式,就叫做「无监督」(或自监督)学习方式。我们今天要讲的 SimCSE, 就是一种「无监督」训练模型。

SimCSE: Simple Contrastive Learning of Sentence Embeddings

1.SimCSE 是如何做到无监督的?

SimCSE 将对比学习(Contrastive Learning)的思想引入到文本匹配中。对比学习的核心思想就是:将相似的样本拉近,将不相似的样本推远

但现在问题是:我们没有标注数据,怎么知道哪些文本是相似的,哪些是不相似的呢?SimCSE 相出了一种很妙的办法,由于预训练模型在训练的时候通常都会使用 dropout 机制。这就意味着:即使是同一个样本过两次模型也会得到两个不同的 embedding。而因为同样的样本,那一定是相似的,模型输出的这两个 embedding 距离就应当尽可能的相近;反之,那些不同的输入样本过模型后得到的 embedding 就应当尽可能的被推远。

具体来讲,一个 batch 内每个句子会过 2 次模型,得到 2 * batch 个向量,将这些句子中通过同样句子得到的向量设置为正例,其他设置为负例。

假设 a1 和 a2 是由句子 a 过两次模型得到的结果,那么一个 batch 内的正负例构建如下所示:

a1a2b1b2c1c2
a1-10010000
a21-1000000
b100-100100
b2001-10000
c10000-1001
c200001-100

其中,对角线上的 - 100 表示自身和自身不做相似度比较。

2. SimCSE 的缺点?

从 SimCSE 的正例构建中我们可以看出来,所有的正例都是由「同一个句子」过了两次模型得到的。这就会造成一个问题:模型会更倾向于认为,长度相同的句子就代表一样的意思。由于数据样本是随机选取的,那么很有可能在一个 batch 内采样到的句子长度是不相同的。

为了解决这个问题,我们最终采取的实现方式为 ESimCSE

3. ESimCSE 解决模型对文本长度的敏感问题

ESimCSE 通过随机重复单词(Word Repetition)的方式来构建正例,巧妙的解决了句子长度敏感性的问题:

ESimCSE: Enhanced Sample Building Method for Contrastive Learning of Unsupervised Sentence Embedding

要想消除模型对句子长度的敏感,我们就需要在构建正例的时候让输入句子的长度发生改变,如下所示:

那么,改变句子长度通常有 3 种方法:随机删除、随机添加、同义词替换,但它们均存在句意变化的风险:

方法原句子变换后的句子句意是否改变
随机删除我 [不] 喜欢你我喜欢你
随机添加今天的饭好吃今天的饭 [不] 好吃
同义词替换小明长得像一只 [狼]小明长得像一只 [狗]

用语义变换后的句子去构建正例,模型效果自然会受到影响。

那如果我们随机重复一些单词呢?

方法原句子变换后的句子句意是否改变
随机重复单词今天天气很好今今天天气很好好
随机重复单词我喜欢你我我喜欢欢你

可以看到,通过随机重复单词,既能够改变句子长度,又不会轻易改变语义。

实现上,假设我们有一个 batch 的句子,我们先依次将每一个句子都进行随机单词重复(产生正例),如下:

origin ->     ['人和畜生的区别', '今天天气很好', '三星手机屏幕是不是最好的?']
repetition -> ['人人和畜生的的区别', '今今天天气很好好', '三星星手机屏屏幕是不是最最好好的?']

随后,我们将 origin 的 embedding(batch,768) 和 repetition 的 embedding(batch,768)做矩阵乘法,可以得到一个矩阵(batch,batch),矩阵对角线上就是正例,其余的均是负例:

句子 a句子 b句子 c
句子 a0.92480.23420.4242
句子 b0.31420.91230.1422
句子 c0.29030.18570.9983

矩阵中第(i,j)个元素代表 origin 列表中的第 i 个元素和 repetition 列表中第 j 个元素的相似度。

接下来就好构建训练标签了,因为 label 都在对角线上,所以第 n 行的 label 就是 n 。

labels = [i for i in range(len(origin))]     # labels = [0, 1, 2]

之后就用 CrossEntropyLoss 去计算并梯度回传就能开始训练啦。

def forward(
        self,
        query_input_ids: torch.tensor,
        query_token_type_ids: torch.tensor,
        doc_input_ids: torch.tensor,
        doc_token_type_ids: torch.tensor,
        device='cpu'
        ) -> torch.tensor:
        """
        传入query/doc对,构建正/负例并计算contrastive loss。

        Args:
            query_input_ids (torch.LongTensor): (batch, seq_len)
            query_token_type_ids (torch.LongTensor): (batch, seq_len)
            doc_input_ids (torch.LongTensor): (batch, seq_len)
            doc_token_type_ids (torch.LongTensor): (batch, seq_len)
            device (str): 使用设备

        Returns:
            torch.tensor: (1)
        """
        query_embedding = self.get_pooled_embedding(
            input_ids=query_input_ids,
            token_type_ids=query_token_type_ids
        )                                                           # (batch, self.output_embedding_dim)

        doc_embedding = self.get_pooled_embedding(
            input_ids=doc_input_ids,
            token_type_ids=doc_token_type_ids
        )                                                           # (batch, self.output_embedding_dim)
        
        cos_sim = torch.matmul(query_embedding, doc_embedding.T)    # (batch, batch)
        margin_diag = torch.diag(torch.full(                        # (batch, batch), 只有对角线等于margin值的对角矩阵
            size=[query_embedding.size()[0]], 
            fill_value=self.margin
        )).to(device)
        cos_sim = cos_sim - margin_diag                             # 主对角线(正例)的余弦相似度都减掉 margin
        cos_sim *= self.scale                                       # 缩放相似度,便于收敛

        labels = torch.arange(                                      # 只有对角上为正例,其余全是负例,所以这个batch样本标签为 -> [0, 1, 2, ...]
            0, 
            query_embedding.size()[0], 
            dtype=torch.int64
        ).to(device)
        loss = self.criterion(cos_sim, labels)

        return loss

4.DiffCSE

结合句子间差异的无监督句子嵌入对比学习方法——DiffCSE主要还是在SimCSE上进行优化(可见SimCSE的重要性),通过ELECTRA模型的生成伪造样本和RTD(Replaced Token Detection)任务,来学习原始句子与伪造句子之间的差异,以提高句向量表征模型的效果。

其思想同样来自于CV领域(采用不变对比学习和可变对比学习相结合的方法可以提高图像表征的效果)。作者提出使用基于dropout masks机制的增强作为不敏感转换学习对比学习损失和基于MLM语言模型进行词语替换的方法作为敏感转换学习「原始句子与编辑句子」之间的差异,共同优化句向量表征。

在SimCSE模型中,采用pooler层(一个带有tanh激活函数的全连接层)作为句子向量输出。该论文发现,采用带有BN的两层pooler效果更为突出,BN在SimCSE模型上依然有效。

①对于掩码概率,经实验发现,在掩码概率为30%时,模型效果最优。
②针对两个损失之间的权重值,经实验发现,对比学习损失为RTD损失200倍时,模型效果最优。

参考链接:https://blog.csdn.net/PX2012007/article/details/127696477

5. 数据集准备

项目中提供了一部分示例数据,我们使用未标注的用户搜索记录数据来训练一个文本匹配模型,数据在 data/LCQMC

若想使用自定义数据训练,只需要仿照示例数据构建数据集即可:

  • 训练集:
喜欢打篮球的男生喜欢什么样的女生
我手机丢了,我想换个手机
大家觉得她好看吗
晚上睡觉带着耳机听音乐有什么害处吗?
学日语软件手机上的
...
  • 测试集:
开初婚未育证明怎么弄?	初婚未育情况证明怎么开?	1
谁知道她是网络美女吗?	爱情这杯酒谁喝都会醉是什么歌	0
人和畜生的区别是什么?	人与畜生的区别是什么!	1
男孩喝女孩的尿的故事	怎样才知道是生男孩还是女孩	0
...

由于是无监督训练,因此训练集(train.txt)中不需要记录标签,只需要大量的文本即可。

测试集(dev.tsv)用于测试无监督模型的效果,因此需要包含真实标签。

每一行用 \t 分隔符分开,第一部分部分为句子A,中间部分为句子B,最后一部分为两个句子是否相似(label)

6.模型训练

修改训练脚本 train.sh 里的对应参数, 开启模型训练:

python train.py \
    --model "nghuyong/ernie-3.0-base-zh" \
    --train_path "data/LCQMC/train.txt" \
    --dev_path "data/LCQMC/dev.tsv" \
    --save_dir "checkpoints/LCQMC" \
    --img_log_dir "logs/LCQMC" \
    --img_log_name "ERNIE-ESimCSE" \
    --learning_rate 1e-5 \
    --dropout 0.3 \
    --batch_size 64 \
    --max_seq_len 64 \
    --valid_steps 400 \
    --logging_steps 50 \
    --num_train_epochs 8 \
    --device "cuda:0"

正确开启训练后,终端会打印以下信息:

...
0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 226.41it/s]
DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 477532
    })
    dev: Dataset({
        features: ['text'],
        num_rows: 8802
    })
})
global step 50, epoch: 1, loss: 0.34367, speed: 2.01 step/s
global step 100, epoch: 1, loss: 0.19121, speed: 2.02 step/s
global step 150, epoch: 1, loss: 0.13498, speed: 2.00 step/s
global step 200, epoch: 1, loss: 0.10696, speed: 1.99 step/s
global step 250, epoch: 1, loss: 0.08858, speed: 2.02 step/s
global step 300, epoch: 1, loss: 0.07613, speed: 2.02 step/s
global step 350, epoch: 1, loss: 0.06673, speed: 2.01 step/s
global step 400, epoch: 1, loss: 0.05954, speed: 1.99 step/s
Evaluation precision: 0.58459, recall: 0.87210, F1: 0.69997, spearman_corr: 
0.36698
best F1 performence has been updated: 0.00000 --> 0.69997
global step 450, epoch: 1, loss: 0.25825, speed: 2.01 step/s
global step 500, epoch: 1, loss: 0.27889, speed: 1.99 step/s
global step 550, epoch: 1, loss: 0.28029, speed: 1.98 step/s
global step 600, epoch: 1, loss: 0.27571, speed: 1.98 step/s
global step 650, epoch: 1, loss: 0.26931, speed: 2.00 step/s
...

logs/LCQMC 文件下将会保存训练曲线图:

7.模型推理

完成模型训练后,运行 inference.py 以加载训练好的模型并应用:

...
    if __name__ == '__main__':
    ...
    sentence_pair = [
        ('男孩喝女孩的故事', '怎样才知道是生男孩还是女孩'),
        ('这种图片是用什么软件制作的?', '这种图片制作是用什么软件呢?')
    ]
    ...
    res = inference(query_list, doc_list, model, tokenizer, device)
    print(res)

运行推理程序:

python inference.py

得到以下推理结果:

[0.1527191698551178, 0.9263839721679688]   # 第一对文本相似分数较低,第二对文本相似分数较高

参考链接:https://github.com/HarderThenHarder/transformers_tasks/blob/main/text_matching/supervised

github无法连接的可以在:https://download.csdn.net/download/sinat_39620217/88214437 下载

更多优质内容请关注公号:汀丶人工智能;会提供一些相关的资源和优质文章,免费获取阅读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/79315.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据结构】 链表简介与单链表的实现

文章目录 ArrayList的缺陷链表链表的概念及结构链表的分类单向或者双向带头或者不带头循环或者非循环 单链表的实现创建单链表遍历链表得到单链表的长度查找是否包含关键字头插法尾插法任意位置插入删除第一次出现关键字为key的节点删除所有值为key的节点回收链表 总结 ArrayLi…

p5.js 渐变填充的实现方式

theme: smartblue 本文简介 p5.js 作为一款艺术类的 canvas 库&#xff0c;对颜色方面的支持是挺下功夫的&#xff0c;比如本文要介绍的渐变方法。 lerpColor() 要实现渐变效果&#xff0c;可以使用 lerpColor() 方法。 lerpColor 的作用是混合两个颜色以找到一个介于它们之间的…

Llama 2免费托管及API提供

Llama 2 是 Meta 最新的文本生成模型&#xff0c;目前其性能优于所有开源替代方案。 推荐&#xff1a;用 NSDT编辑器 快速搭建可编程3D场景 1、强大的Llama 2 它击败了 Falcon-40B&#xff08;之前最好的开源基础模型&#xff09;&#xff0c;与 GPT-3.5 相当&#xff0c;仅低…

Springboot 在 redis 中使用 Guava 布隆过滤器机制

一、导入SpringBoot依赖 在pom.xml文件中&#xff0c;引入Spring Boot和Redis相关依赖 <!-- Google Guava 使用google的guava布隆过滤器实现--><dependency><groupId>com.google.guava</groupId><artifactId>guava</artifactId><vers…

UniApp 制作高德地图插件

1、下载Uni插件项目 在Uni官网下载Uni插件项目&#xff0c;并参考官网插件项目创建插件项目. 开发者须知 | uni小程序SDK 如果下载下来项目运行不了可以参考下面链接进行处理 UniApp原生插件制作_wangdaoyin2010的博客-CSDN博客 2、引入高德SDK 2.1 在高德官网下载对应SD…

redis — 基于Spring Boot实现redis延迟队列

1. 业务场景 延时队列场景在我们日常业务开发中经常遇到&#xff0c;它是一种特殊类型的消息队列&#xff0c;它允许把消息发送到队列中&#xff0c;但不立即投递给消费者&#xff0c;而是在一定时间后再将消息投递给消费者。延迟队列的常见使用场景有以下几种&#xff1a; 在…

学点Selenium玩点新鲜~,让分布式测试有更多玩法

前 言 我们都知道 Selenium 是一款在 Web 应用测试领域使用的自动化测试工具&#xff0c;而 Selenium Grid 是 Selenium 中的一大组件&#xff0c;通过它能够实现分布式测试&#xff0c;能够帮助团队简单快速在不同的环境中测试他们的 Web 应用。 分布式执行测试其实并不是一…

LSTM模型

目录 LSTM模型 LSTM结构图 LSTM的核心思想 细胞状态 遗忘门 输入门 输出门 RNN模型 LRNN LSTM模型 什么是LSTM模型 LSTM (Long Short-Term Memory)也称长短时记忆结构,它是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象.同时LS…

考研算法第46天: 字符串转换整数 【字符串,模拟】

题目前置知识 c中的string判空 string Count; Count.empty(); //正确 Count ! null; //错误c中最大最小宏 #include <limits.h>INT_MAX INT_MIN 字符串使用发运算将字符加到字符串末尾 string Count; string str "liuda"; Count str[i]; 题目概况 AC代码…

QT多屏显示程序

多屏显示的原理其实很好理解&#xff0c;就拿横向扩展来说&#xff1a; 计算机把桌面的 宽度扩展成了 w1&#xff08;屏幕1的宽度&#xff09; w2(屏幕2的宽度) 。 当一个窗口的起始横坐标 > w1&#xff0c;则 他就被显示在第二个屏幕上了。 drm设备可以多用户同时打开&am…

STM32定时器TIM控制

一、CubeMX的设置 1、新建工程&#xff0c;进行基本配置 2、配置定时器TIM2 1&#xff09;定时器计算公式&#xff1a;&#xff08;以下两条公式相同&#xff09; Tout ((ARR1) * PSC1)) / Tclk TimeOut ((Prescaler 1) * (Period 1)) / TimeClockFren Tout TimeOut&…

WordPress更换域名后-后台无法进入,网站模版错乱,css失效,网页中图片不显示。完整解决方案(含宝塔设置)

我在实际解决问题时用到了 【简单暴力解决方案】的《方法一&#xff1a;修改wp-config.php》 和 【简单暴力-且特别粗暴-的解决方案】 更换域名时经常遇到的几个问题&#xff1a; 1、更换域名后&#xff0c;后台无法进入 2、更换域名后&#xff0c;网站模版错乱&#xff0c;c…

第 4 章 链表(1)

4.1链表(Linked List)介绍 链表是有序的列表&#xff0c;但是它在内存中是存储如下 小结: 链表是以节点的方式来存储,是链式存储每个节点包含 data 域&#xff0c; next 域&#xff1a;指向下一个节点.如图&#xff1a;发现链表的各个节点不一定是连续存储.链表分带头节点的链…

Django之定时任务--apscheduler

Django--定时任务apscheduler的使用 apscheduler定时任务的使用1、安装包2、配置settings.py3、在manage.py的文件同级目录下创建文件scheduler.py4、在项目的urls.py中调用这个定时计划5、然后启动项目 python manage.py runserver,在admin中查看就能看到你的定时任务及执行的…

React 高阶组件(HOC)

React 高阶组件(HOC) 高阶组件不是 React API 的一部分&#xff0c;而是一种用来复用组件逻辑而衍生出来的一种技术。 什么是高阶组件 高阶组件就是一个函数&#xff0c;且该函数接受一个组件作为参数&#xff0c;并返回一个新的组件。基本上&#xff0c;这是从 React 的组成…

NVIDIA vGPU License许可服务器高可用全套部署秘籍

第1章 前言 近期遇到比较多的场景使用vGPU&#xff0c;比如Citrix 3D场景、Horizon 3D场景&#xff0c;还有AI等&#xff0c;都需要使用显卡设计研发等&#xff0c;此时许可服务器尤为重要&#xff0c;许可断掉会出现掉帧等情况&#xff0c;我们此次教大家部署HA许可服务器。 …

【腾讯云 Cloud Studio 实战训练营】在线 IDE 编写 canvas 转换黑白风格头像

关于 Cloud Studio Cloud Studio 是基于浏览器的集成式开发环境(IDE)&#xff0c;为开发者提供了一个永不间断的云端工作站。用户在使用Cloud Studio 时无需安装&#xff0c;随时随地打开浏览器就能在线编程。 Cloud Studio 作为在线IDE&#xff0c;包含代码高亮、自动补全、Gi…

基于Redis实现关注、取关、共同关注及消息推送(含源码)

微信公众号访问地址&#xff1a;基于Redis实现关注、取关、共同关注及消息推送(含源码) 推荐文章&#xff1a; 1、springBoot对接kafka,批量、并发、异步获取消息,并动态、批量插入库表; 2、SpringBoot用线程池ThreadPoolTaskExecutor异步处理百万级数据; 3、为什么引入Rediss…

程序员如何利用公网远程访问查询本地硬盘【内网穿透】

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《高效编程技巧》《cpolar》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 公网远程访问本地硬盘文件【内网穿透】 文章目录 公网远程访问本地硬盘文件【内网穿透】前言1. 下载cpolar和Everything软件1.…

小象课堂在线授课教育系统

此项目包含后端全部代码&#xff0c;前端包括后台和web界面的源码&#xff0c;数据库用的mysql,可当作课设或者毕设&#xff0c;还可写入自己的简历中 web界面展示&#xff1a; 前端后台界面展示&#xff1a; 用户管理 课程管理 内容配置 订单管理 系统管理 系统监控