bert 相似度任务训练完整版

任务

之前写了一个相似度任务的版本:bert 相似度任务训练简单版本,faiss 寻找相似 topk-CSDN博客

相似度用的是 0,1,相当于分类任务,现在我们相似度有评分,不再是 0,1 了,分数为 0-5,数字越大代表两个句子越相似,这一次的比较完整,评估,验证集,相似度模型都有了。

数据集

链接:https://pan.baidu.com/s/1B1-PKAKNoT_JwMYJx_zT1g 
提取码:er1z 
原始数据好几千条,我训练数据用了部分 2500 条,验证,测试 300 左右,使用 cpu 也用了好几个小时

train.py

import torch
import os
import time
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, AdamW, get_cosine_schedule_with_warmup
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np


# 设备选择
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = 'cpu'


# 定义文本相似度数据集类
class TextSimilarityDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_len=128):
        self.data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f.readlines():
                text1, text2, similarity_score = line.strip().split('\t')
                inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=max_len)
                inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=max_len)
                self.data.append({
                    'input_ids1': inputs1['input_ids'],
                    'attention_mask1': inputs1['attention_mask'],
                    'input_ids2': inputs2['input_ids'],
                    'attention_mask2': inputs2['attention_mask'],
                    'similarity_score': float(similarity_score),
                })

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def cosine_similarity_torch(vec1, vec2, eps=1e-8):
    dot_product = torch.mm(vec1, vec2.t())
    norm1 = torch.norm(vec1, 2, dim=1, keepdim=True)
    norm2 = torch.norm(vec2, 2, dim=1, keepdim=True)
    similarity_scores = dot_product / (norm1 * norm2.t()).clamp(min=eps)
    return similarity_scores


# 定义模型,这里我们不仅计算两段文本的[CLS] token的点积,而是整个句向量的余弦相似度
class BertSimilarityModel(torch.nn.Module):
    def __init__(self, pretrained_model):
        super(BertSimilarityModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = torch.nn.Dropout(p=0.1)  # 引入Dropout层以防止过拟合

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        embeddings1 = self.dropout(self.bert(input_ids=input_ids1, attention_mask=attention_mask1)['last_hidden_state'])
        embeddings2 = self.dropout(self.bert(input_ids=input_ids2, attention_mask=attention_mask2)['last_hidden_state'])

        # 计算两个文本向量的余弦相似度
        embeddings1 = torch.mean(embeddings1, dim=1)
        embeddings2 = torch.mean(embeddings2, dim=1)

        similarity_scores = cosine_similarity_torch(embeddings1, embeddings2)

        # 映射到[0, 5]评分范围
        normalized_similarities = (similarity_scores + 1) * 2.5
        return normalized_similarities.unsqueeze(1)


# 自定义损失函数,使用Smooth L1 Loss,更适合处理回归问题
class SmoothL1Loss(torch.nn.Module):
    def __init__(self):
        super(SmoothL1Loss, self).__init__()

    def forward(self, predictions, targets):
        diff = predictions - targets
        abs_diff = torch.abs(diff)
        quadratic = torch.where(abs_diff < 1, 0.5 * diff ** 2, abs_diff - 0.5)
        return torch.mean(quadratic)


def train_model(model, train_loader, val_loader, epochs=3, model_save_path='../output/bert_similarity_model.pth'):
    model.to(device)
    criterion = SmoothL1Loss()  # 使用自定义的Smooth L1 Loss
    optimizer = AdamW(model.parameters(), lr=5e-5)  # 调整初始学习率为5e-5
    num_training_steps = len(train_loader) * epochs
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.1*num_training_steps, num_training_steps=num_training_steps)  # 使用带有warmup的余弦退火学习率调度

    best_val_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        for batch in train_loader:
            input_ids1 = batch['input_ids1'].to(device)
            attention_mask1 = batch['attention_mask1'].to(device)
            input_ids2 = batch['input_ids2'].to(device)
            attention_mask2 = batch['attention_mask2'].to(device)
            similarity_scores = batch['similarity_score'].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)
            loss = criterion(outputs, similarity_scores.unsqueeze(1))
            loss.backward()
            optimizer.step()
            scheduler.step()

        # 验证阶段
        model.eval()
        with torch.no_grad():
            val_loss = 0
            total_val_samples = 0

            for batch in val_loader:
                input_ids1 = batch['input_ids1'].to(device)
                attention_mask1 = batch['attention_mask1'].to(device)
                input_ids2 = batch['input_ids2'].to(device)
                attention_mask2 = batch['attention_mask2'].to(device)
                similarity_scores = batch['similarity_score'].to(device)

                val_outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)
                val_loss += criterion(val_outputs, similarity_scores.unsqueeze(1)).item()
                total_val_samples += len(similarity_scores)

            val_loss /= len(val_loader)
            print(f'Epoch {epoch + 1}, Validation Loss: {val_loss:.4f}')

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), model_save_path)


def collate_to_tensors(batch):
    '''把数据处理为模型可用的数据,不同任务可能需要修改一下,'''
    input_ids1 = torch.tensor([example['input_ids1'] for example in batch])
    attention_mask1 = torch.tensor([example['attention_mask1'] for example in batch])
    input_ids2 = torch.tensor([example['input_ids2'] for example in batch])
    attention_mask2 = torch.tensor([example['attention_mask2'] for example in batch])
    similarity_score = torch.tensor([example['similarity_score'] for example in batch])

    return {'input_ids1': input_ids1, 'attention_mask1': attention_mask1, 'input_ids2': input_ids2,
            'attention_mask2': attention_mask2, 'similarity_score': similarity_score}


# 加载数据集和预训练模型
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertSimilarityModel('../bert-base-chinese')

# 加载数据并创建
train_data = TextSimilarityDataset('../data/STS-B/STS-B.train - 副本.data', tokenizer)
val_data = TextSimilarityDataset('../data/STS-B/STS-B.valid - 副本.data', tokenizer)
test_data = TextSimilarityDataset('../data/STS-B/STS-B.test - 副本.data', tokenizer)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_to_tensors)
val_loader = DataLoader(val_data, batch_size=32, collate_fn=collate_to_tensors)
test_loader = DataLoader(test_data, batch_size=32, collate_fn=collate_to_tensors)

optimizer = AdamW(model.parameters(), lr=2e-5)

# 开始训练
train_model(model, train_loader, val_loader)

# 加载最佳模型进行测试
model.load_state_dict(torch.load('../output/bert_similarity_model.pth'))
test_loss = 0
total_test_samples = 0

with torch.no_grad():
    for batch in test_loader:
        input_ids1 = batch['input_ids1'].to(device)
        attention_mask1 = batch['attention_mask1'].to(device)
        input_ids2 = batch['input_ids2'].to(device)
        attention_mask2 = batch['attention_mask2'].to(device)
        similarity_scores = batch['similarity_score'].to(device)

        test_outputs = model(input_ids1, attention_mask1, input_ids2, attention_mask2)
        test_loss += torch.nn.functional.mse_loss(test_outputs, similarity_scores.unsqueeze(1)).item()
        total_test_samples += len(similarity_scores)

test_loss /= len(test_loader)
print(f'Test Loss: {test_loss:.4f}')

predit.py

这个脚本是用来看看效果的,直接传入两个文本,使用训练好的模型来计算相似度的

import torch
from transformers import BertTokenizer, BertModel


def cosine_similarity_torch(vec1, vec2, eps=1e-8):
    dot_product = torch.mm(vec1, vec2.t())
    norm1 = torch.norm(vec1, 2, dim=1, keepdim=True)
    norm2 = torch.norm(vec2, 2, dim=1, keepdim=True)
    similarity_scores = dot_product / (norm1 * norm2.t()).clamp(min=eps)
    return similarity_scores


# 定义模型,这里我们不仅计算两段文本的[CLS] token的点积,而是整个句向量的余弦相似度
class BertSimilarityModel(torch.nn.Module):
    def __init__(self, pretrained_model):
        super(BertSimilarityModel, self).__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.dropout = torch.nn.Dropout(p=0.1)  # 引入Dropout层以防止过拟合

    def forward(self, input_ids1, attention_mask1, input_ids2, attention_mask2):
        '''如果是用来预测,forward 会被禁用'''
        pass


# 加载预训练模型和分词器
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
model = BertSimilarityModel('../bert-base-chinese')
model.load_state_dict(torch.load('../output/bert_similarity_model.pth'))  # 请确保路径正确
model.eval()  # 设置模型为评估模式


def calculate_similarity(text1, text2):
    # 对输入文本进行编码
    inputs1 = tokenizer(text1, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
    inputs2 = tokenizer(text2, padding='max_length', truncation=True, max_length=128, return_tensors='pt')

    # 计算相似度
    with torch.no_grad():
        embeddings1 = model.bert(**inputs1.to('cpu'))['last_hidden_state'][:, 0]
        embeddings2 = model.bert(**inputs2.to('cpu'))['last_hidden_state'][:, 0]
        similarity_score = cosine_similarity_torch(embeddings1, embeddings2).item()

    # 映射到[0, 5]评分范围(假设训练时有此步骤)
    normalized_similarity = (similarity_score + 1) * 2.5

    return normalized_similarity


# 示例
text1 = "瑞典驻利比亚班加西领事馆发生汽车炸弹袭击,无人员伤亡"
text2 = "汽车炸弹击中瑞典驻班加西领事馆,无人受伤。"
similarity = calculate_similarity(text1, text2)
print(f"两个句子的相似度为:{similarity:.2f}")

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/426795.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ChatGPT最新功能“Text To Speech (TTS,文本转语音)”详细解读!

大家好&#xff0c;我是木易&#xff0c;一个持续关注AI领域的互联网技术产品经理&#xff0c;国内Top2本科&#xff0c;美国Top10 CS研究生&#xff0c;MBA。我坚信AI是普通人变强的“外挂”&#xff0c;所以创建了“AI信息Gap”这个公众号&#xff0c;专注于分享AI全维度知识…

Windows环境MySQL全量备份+增量备份

一、环境准备 1.1.安装MySQL 在进行MySQL数据库备份和还原操作时&#xff0c;必须先提前安装好MySQL环境&#xff0c;且MySQL服务已成功开启 如果没有安装MySQL环境&#xff0c;可以参考博客&#xff1a;http://t.csdnimg.cn/h8bHl 如果已成功安装MySQL环境&#xff0c;打开…

Orbit 使用指南 02 | 在场景中生成原始对象| Isaac Sim | Omniverse

如是我闻&#xff1a; Orbit使用指南02将 深入探讨如何使用Python代码在Orbit中向场景生成各种对象&#xff08;或原始对象&#xff09;。一起探索如何生成地面平面、灯光、基本图形形状以及来自USD文件的网格。前置知识&#xff1a;如何生成空白场景&#xff0c;Orbit 使用指…

VUE实现Office文档在线编辑,支持doc/docx、xls/xlsx、ppt/pptx、pdf等

1.微软提供的在线Office预览&#xff08;只能预览&#xff0c;不能编辑&#xff09; https://view.officeapps.live.com/op/view.aspx?src服务器上文档地址&#xff08;http开头&#xff09; 2.国内在线Office方案&#xff1a; 腾讯文档、石墨文档、飞书 优势&#xff1a;跨…

paimon取消hive转filesystem

目录 概述实践关键配置spark sql 结束 概述 公司上一版本保留了 hive &#xff0c;此版优化升级后&#xff0c;取消 hive。 实践 关键配置 同步数据时&#xff0c;配置如下&#xff0c;将形成两个库 # ods库 CREATE CATALOG paimon WITH (type paimon,warehouse hdfs:///d…

CentOS配网报错:network is unreachable

常用命令&#xff1a; 打开&#xff1a; cd /etc/sysconfig/network-scripts/ 修改&#xff1a; vim ifcfg-ens33 打开修改&#xff1a; vim /etc/sysconfig/network-scripts/ifcfg-ens33 保存&#xff1a; 方法1&#xff1a;ESCZZ&#xff08;Z要大写&#xff09; 方…

熔断降级 spring事务

如果有事务处理&#xff0c;会先把事务的自动提交给关闭

Apache Flink连载(三十七):Flink基于Kubernetes部署(7)-Kubernetes 集群搭建-3

🏡 个人主页:IT贫道-CSDN博客 🚩 私聊博主:私聊博主加WX好友,获取更多资料哦~ 🔔 博主个人B栈地址:豹哥教你学编程的个人空间-豹哥教你学编程个人主页-哔哩哔哩视频 目录

32单片机基础:PWM驱动舵机,直流电机

PWM驱动舵机 接线图如上图所示。注意&#xff0c;舵机的5V 线不能接到面包板上的正极&#xff0c;面包板上的正极只有3.3V,是STM32提供的&#xff0c;所以要接到STLINK的5V, 我们如何驱动舵机呢&#xff1f;由之前我们介绍原理知道&#xff0c;要输出如下图对应的PWM波形才行…

202209 青少年软件编程等级考试Scratch二级真题

第 1 题 【 单选题 】 数字&#xff1a;1&#xff0c;2&#xff0c;3&#xff0c;4&#xff0c;6&#xff0c;9&#xff0c;13&#xff0c;19&#xff0c;28&#xff0c;...的下一项是多少&#xff1f; A&#xff1a;37 B&#xff1a;39 C&#xff1a;41 D&#xff1a;47 …

爱奇艺2023年营收319亿元:完善服务价值感知,重构影视新生态

近日&#xff0c;爱奇艺&#xff08;NASDAQ:IQ&#xff09;发布截至2023年12月31日未经审计的第四季度和全年财报&#xff0c;这份财报被外界评价为“爱奇艺交出的年度最佳业绩”。 财报显示&#xff0c;爱奇艺全年总营收319亿元&#xff0c;同比增长10%&#xff1b;非美国通用…

模拟器抓HTTP/S的包时如何绕过单向证书校验(XP框架)

模拟器抓HTTP/S的包时如何绕过单向证书校验&#xff08;XP框架&#xff09; 逍遥模拟器无法激活XP框架来绕过单向的证书校验&#xff0c;如下图&#xff1a; ​​ 解决办法&#xff1a; 安装JustMePlush.apk安装Just Trust Me.apk安装RE管理器.apk安装Xposedinstaller_逍遥64位…

Java SE:反射

反射作用 获取字节码文件里面的所有信息&#xff0c;包括构造方法、成员、成员方法&#xff0c;以及修饰他们的修饰符、类型和方法的返回值等等&#xff0c;只要是类里面的内容都能获取&#xff0c;获取之后可以动态的调用方法&#xff0c;动态的创建对象 获取类字节码文件对象…

vue3中的基本语法

目录 基础素材 vue3的优化 使用CompositionAPI理由 1. reactive() 函数 2. ref() 函数 2.1. ref的使用 2.2. 在 reactive 对象中访问 ref 创建的响应式数据 3. isRef() 函数 4. toRefs() 函数 5. computed() 5.1. 通过 set()、get()方法创建一个可读可写的计算属性 …

“耳机党”注意了!你的耳机,用对了吗?

文章目录 &#x1f4d6; 介绍 &#x1f4d6;&#x1f3e1; 什么是“3个60”原则&#xff1f; &#x1f3e1;&#x1f4d2; 如何遵循“3个60”原则&#xff1f; &#x1f4d2;&#x1f4dd; 控制音量&#x1f4dd; 适时休息&#x1f4dd; 关注外界声音 &#x1f4d6; 介绍 &…

深度学习目标检测】二十二、基于深度学习的肺炎检测系统-含数据集、GUI和源码(python,yolov8)

肺炎尽管很常见&#xff0c;但准确诊断是一项困难的任务。它要求训练有素的专家对胸部X光片进行检查&#xff0c;并通过临床病史&#xff0c;生命体征和实验室检查进行确认。肺炎通常表现为胸部X光片上一个或多个区域的阴影(opacity)增加。但是&#xff0c;由于肺部有许多其他状…

足球青训俱乐部|基于Springboot的足球青训俱乐部管理系统设计与实现(源码+数据库+文档)

足球青训俱乐部管理系统目录 目录 基于Springboot的足球青训俱乐部管理系统设计与实现 一、前言 二、系统设计 1、系统架构设计 三、系统功能设计 1、管理员登录界面 2、公告信息管理界面 3、学员管理界面 4、商品信息管理界面 5、课程安排管理界面 四、数据库设计…

机器学习:主成分分析笔记

主成分分析&#xff08;Principal Component Analysis&#xff0c;PCA&#xff09;是一种无监督的机器学习算法&#xff0c;通常用于高维数据的降维、提取主要特征、数据降噪和可视化。PCA的基本思想是将原始数据的多个变量转换为少数几个相互独立的变量&#xff08;即主成分&a…

上海亚商投顾:深成指震荡涨超1% 两市成交连续破万亿

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 一.市场情绪 沪指3月1日震荡反弹&#xff0c;深成指、创业板指午后涨超1%。充电桩概念股集体走强&#xff0c;英可瑞、欧陆…

Stable Video文本生成视频公测地址——Scaling Latent Video Diffusion Models to Large Datasets

近期&#xff0c;Stability AI发布了首个开放视频模型——"Stable Video"&#xff0c;该创新工具能够将文本和图像输入转化为生动的场景&#xff0c;将概念转换成动态影像&#xff0c;生成出电影级别的作品&#xff0c;旨在满足广泛的视频应用需求&#xff0c;包括媒…