bert 相似度任务训练,简单版本

目录

任务

代码

train.py

predit.py

数据


任务

使用 bert-base-chinese 训练相似度任务,参考:微调BERT模型实现相似性判断 - 知乎

参考他上面代码,他使用的是 BertForNextSentencePrediction 模型,BertForNextSentencePrediction 原本是设计用于下一个句子预测任务的。在BERT的原始训练中,模型会接收到一对句子,并试图预测第二个句子是否紧跟在第一个句子之后;所以使用这个模型标签(label)只能是 0,1,相当于二分类任务了

但其实在相似度任务中,我们每一条数据都是【text1\ttext2\tlabel】的形式,其中 label 代表相似度,可以给两个文本打分表示相似度,也可以映射为分类任务,0 代表不相似,1 代表相似,他这篇文章利用了这种思想,对新手还挺有用的。

现在我搞了一个招聘数据,里面有办公区域列,处理过了,每一行代表【地址1\t地址2\t相似度】

只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

利用这数据微调,没有使用验证数据集,就最后使用测试集来看看效果。

代码

train.py

import json
import torch
from transformers import BertTokenizer, BertForNextSentencePrediction
from torch.utils.data import DataLoader, Dataset


# 能用gpu就用gpu
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

bacth_size = 32
epoch = 3
auto_save_batch = 5000
learning_rate = 2e-5


# 准备数据集
class MyDataset(Dataset):
    def __init__(self, data_file_paths):
        self.texts = []
        self.labels = []
        # 分词器用默认的
        self.tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
        # 自己实现对数据集的解析
        with open(data_file_paths, 'r', encoding='utf-8') as f:
            for line in f:
                text1, text2, label = line.split('\t')
                self.texts.append((text1, text2))
                self.labels.append(int(label))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text1, text2 = self.texts[idx]
        label = self.labels[idx]
        encoded_text = self.tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        return encoded_text, label


# 训练数据文件路径
train_dataset = MyDataset('../data/train.txt')

# 定义模型
# num_labels=5 定义相似度评分有几个
model = BertForNextSentencePrediction.from_pretrained('../bert-base-chinese', num_labels=6)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
train_loader = DataLoader(train_dataset, batch_size=bacth_size, shuffle=True)
trained_data = 0
batch_after_last_save = 0
total_batch = 0
total_epoch = 0

for epoch in range(epoch):
    trained_data = 0
    for batch in train_loader:
        inputs, labels = batch
        # 不知道为啥,出来的数据维度是 (batch_size, 1, 128),需要把第二维去掉
        inputs['input_ids'] = inputs['input_ids'].squeeze(1)
        inputs['token_type_ids'] = inputs['token_type_ids'].squeeze(1)
        inputs['attention_mask'] = inputs['attention_mask'].squeeze(1)
        # 因为要用GPU,将数据传输到gpu上
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(**inputs, labels=labels)
        loss, logits = outputs[:2]
        loss.backward()
        optimizer.step()
        trained_data += len(labels)
        trained_process = float(trained_data) / len(train_dataset)
        batch_after_last_save += 1
        total_batch += 1
        # 每训练 auto_save_batch 个 batch,保存一次模型
        if batch_after_last_save >= auto_save_batch:
            batch_after_last_save = 0
            model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
            print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))
        print("训练进度:{:.2f}%, loss={:.4f}".format(trained_process * 100, loss.item()))
    total_epoch += 1
    model.save_pretrained(f'../output/cn_equal_model_{total_epoch}_{total_batch}.pth')
    print("保存模型:cn_equal_model_{}_{}.pth".format(total_epoch, total_batch))

训练好后的文件,输出的最后一个文件夹才是效果最好的模型:

predit.py

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction


tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertForNextSentencePrediction.from_pretrained('../output/cn_equal_model_3_171.pth')

with torch.no_grad():
    with open('../data/test.txt', 'r', encoding='utf8') as f:
        lines = f.readlines()
        correct = 0
        for i, line in enumerate(lines):
            text1, text2, label = line.split('\t')
            encoded_text = tokenizer(text1, text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
            outputs = model(**encoded_text)
            res = torch.argmax(outputs.logits, dim=1).item()
            print(text1, text2, label, res)
            if str(res) == label.strip('\n'):
                correct += 1
            print(f'{i + 1}/{len(lines)}')
        print(f'acc:{correct / len(lines)}')

可以看到还是较好的学习了我数据特征:只要两文本中有一个地址相似我就作为相似,标签为 1,否则 0

数据

链接:https://pan.baidu.com/s/1Cpr-ZD9Neakt73naGdsVTw 
提取码:eryw 
链接:https://pan.baidu.com/s/1qHYjXC7UCeUsXVnYTQIPCg 
提取码:o8py 
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/421096.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一文讲透:可视化大屏中3D元素的融入和使用方法

在可视化大屏中,3D元素融入的越来越多,贝格前端工场经常接到这类项目,很多老铁认为加个3D效果很easy,其实不然,工序非常复杂,总结如下。 一、什么是3D技术 三维展示(3D展示)是指使用…

gpt生成器,批量gpt文章生成器

GPT(生成式预训练模型)生成器软件在当今的数字化时代扮演着越来越重要的角色,它们通过人工智能技术,可以自动生成各种类型的文章内容,为用户提供了无限的创作可能性。本文将介绍6款不同的GPT生成器软件,并介…

退休教师40年教龄补贴多少钱

那些默默奉献了四十年的老教师,他们退休后能得到多少补贴?今天,就让我们一起揭开这层面纱,看看教师退休金背后的故事。 教师这份职业,不仅仅是传授知识那么简单。它更代表着一种责任,一种对下一代无尽的关爱…

什么是微前端

微前端是一种web应用构建方式。 微前端在2016年ThoughtWorks Technology Radar正式被提出。微服务这个被广泛应用于服务端的技术范式扩展到前端领域。现代的前端应用的发展趋势正在变得越来越富功能化,富交互化,也就是SPA应用;这样越来越复杂…

word文档空格不能有下划线【笔记】

word文档空格不能有下划线 2024-3-1 21:20:24 推荐 word下划线打不出来了,是怎么回事? 问题 字后面打不出来下划线 操作 1.点击文件 左上角,点击“文件”。 2.点击选项 鼠标下滑,点击“选项”。 3.点击常规与保存 点击“…

CY8C42(未知.UDB模块使用)

开发UDB模块要用到verilog了 虽然官方给出了图形配置,但是完全看不懂。 没办法,我先去学FPGA了,去买矿卡了。 后面等学一点FPGA再来更新吧,可能PSOC里面有些内容会随机更新吧。 官方给了一份文档,链接在这&#xff0…

使用el-form之表单校验自动定位到报错位置问题,,提升用户体验

需求描述 由于需要填写的表单项太多,提交的时候校验不通过, 如果没填写的表单项在最上面,用户看不到不知道发生了啥, 所以需要将页面滚动定位到第一个报错的表单项位置,提升用户体验实现步骤 1. 给form表单添加ref …

智慧楼宇的心脏:E6000物联网主机

智慧楼宇是指通过全面覆盖的感知设备和互联网技术,为建筑提供高效、舒适、安全、环保、可持续的智能化服务。 在科技快速发展的今天,智慧楼宇已经不再是遥不可及的梦想。而在这个梦想成真的过程中,物联网主机扮演着至关重要的角色。它如同智慧…

加密与安全_深入了解Hmac算法(消息认证码)

文章目录 PreHMAC概述常见的Hmac算法Code随机的key的生成 KeyGeneratorHmacMD5用Hmac算法取代原有的自定义的加盐算法 HmacMD5 VS MD5HmacSHA256 Pre 加密与安全_深入了解哈希算法中我们提到, 存储用户的哈希口令时,要加盐存储,目的就在于抵…

手写数字识别(慕课MOOC人工智能之模式识别)

问题:手写数字识别 数据集 数据集链接请点击我 代码 %mat2vector.m function [data_] mat2vector(data,num)[row,col,~] size(data);data_zeros(num,row*col);for page 1:numfor rows 1:rowfor cols1:coldata_(page,((rows-1)*colcols)) im2double(data(rows,cols…

机器人与AGI会撞出什么火花?

真正的科技变革是不是就要来临了?各方大佬都开始布局机器人,对于普通人的就业会造成什么影响? ​ 优牛企讯-企业动态信息监控专家 在优牛企讯-企业动态监控专家搜索可知,全国目前的机器人公司已经达到了26401家,近一年…

浅谈 Linux 孤儿进程和僵尸进程

文章目录 前言孤儿进程僵尸进程 前言 本文介绍 Linux 中的 孤儿进程 和 僵尸进程。 孤儿进程 在 Linux 中,就是父进程已经结束了,但是子进程还在运行,这个子进程就被称作 孤儿进程。 需要注意两点: 孤儿进程最终会进入孤儿院…

实战打靶集锦-026-Gaara

文章目录 1. 主机发现2. 端口扫描3. 服务枚举4. 服务探查5. 突破边界6. 提权6.1 枚举系统信息6.2 查看passwd文件6.3 查看定时任务6.4 查看可执行文件6.5 查看家目录6.6 Linpeas提权 7. 获取flag写在最后 靶机地址:https://download.vulnhub.com/gaara/Gaara.ova 1…

第六节:Vben Admin权限-后端控制方式

系列文章目录 第一节:Vben Admin介绍和初次运行 第二节:Vben Admin 登录逻辑梳理和对接后端准备 第三节:Vben Admin登录对接后端login接口 第四节:Vben Admin登录对接后端getUserInfo接口 第五节:Vben Admin权限-前端控制方式 文章目录 系列文章目录前言一、角色权限(后端…

7. 构建简单 IPv6 网络

7.1 实验介绍 7.1.1 关于本实验 IPv6(Internet Protocol Version 6)也被称为IPng(IP Next Generation)。它是Internet工程任务组IETF(Internet Engineering Task Force)设计的一套规范,是IPv4…

python水表识别图像识别深度学习 CNN

python水表识别,图像识别深度学习 CNN,Opencv,Keras 重点:项目和文档是本人近期原创所作!程序可以将水表图片里面的数据进行深度学习,提取相关信息训练,lw1.3万字重复15%,可以直接上交那种&…

神经网络之万能定理python-pytorch实现,可以拟合任意曲线

神经网络之万能定理python-pytorch实现,可以拟合任意曲线 博主,这几天一直在做这个曲线拟合的实验,讲道理,网上可能也有很多这方面的资料,但是博主其实试了很多,效果只能对一般的曲线还行,稍微…

Azure AD使用SAML登录OpenSearch

Azure AD的配置如下: 登录 Global Azure 控制台,选择并切换至 Azure Active Directory 服务,在左侧的侧边栏选择 Enterprise applications,All Applications 页面点击 New application. 选择 Create your own application&#x…

多行业万能预约门店小程序源码系统 支持多门店预约小程序 带完整的安装代码包以及搭建教程

随着消费者对于服务体验要求的不断提升,门店预约系统成为了许多行业提升服务质量、提高运营效率的重要工具。然而,市面上的预约系统往往功能单一,无法满足多行业、多场景的个性化需求。下面,小编集合了多年的行业经验和技术积累&a…

【数据结构与算法】动态规划法解题20240302

这里写目录标题 一、198. 打家劫舍1、动态规划五部曲 二、213. 打家劫舍 II 一、198. 打家劫舍 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间…