⌈ 传知代码 ⌋ 基于BERT的语义分析实现

💛前情提要💛

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~


📌导航小助手📌

  • 💡本章重点
  • 🍞一. 概述
  • 🍞二. 语义分类
  • 🍞三. 实现原理
  • 🍞四. 核心逻辑
  • 🍞五.实现方式&演示效果
  • 🫓总结


💡本章重点

  • 基于BERT的语义分析实现

🍞一. 概述

在之前的文章中,我们介绍了BERT模型。BERT作为一种预训练语言模型,它具有很好的兼容性,能够运用在各种下游任务中,本文的主要目的是利用数据集来对BERT进行训练,从而实现一个语义分类的模型。


🍞二. 语义分类

语义分类是自然语言处理任务中的一种,包含文本分类、情感分析

文本分类

文本分类是指给定文本a,将文本分类为n个类别中的一个或多个。常见的应用包括文本话题分类,情感分类,具体的分类方向有有二分类,多分类和多标签分类。
文本分类可以采用传统机器学习方法(贝叶斯,svm等)和深度学习方法(fastText,TextCNN等)实现。
举例而言,对于一个对话数据集,我们可以用1、2、3表示他们的话题,如家庭、学校、工作等,而文本分类的目的,则是把这些文本的话题划分到给定的三种类别中。

情感分类

情感分析是自然语言处理中常见的场景,比如商品评价等。通过情感分析,可以挖掘产品在各个维度的优劣。情感分类其实也是一种特殊的文本分类,只是他更聚焦于情感匹配词典。

举例而言,情感分类可以用0/1表示负面评价/正面评价,例子如下:

0,不好的,319房间有故臭味。要求换房说满了,我是3月去的。在路上认识了一个上海人,他说他退房前也住的319,也是一股臭味。而且这个去不掉,特别是晚上,很浓。不知道是厕所的还是窗外的。服务一般,门前有绿皮公交去莫高窟,不过敦煌宾馆也有,下次住敦煌宾馆。再也不住这个酒店了,热水要放半个小时才有。

1,不错的酒店,大堂和餐厅的环境都不错。但由于给我的是一间走廊尽头的房间,所以房型看上去有点奇怪。客厅和卧室是连在一起的,面积偏小。服务还算到位,总的来说,性价比还是不错的。

本文将以情感二分类为例,实现如何利用BERT进行语义分析。


🍞三. 实现原理

首先,基于BERT预训练模型,能将一个文本转换成向量,作为模型的输入。

在BERT预训练模型的基础上,新增一个全连接层,将输入的向量通过训练转化成一个tensor作为输出,其中这个tensor的维度则是需要分类的种类,具体的值表示每个种类的概率。例如:

[0.25,0.75] 

指代的是有0.25的概率属于第一类,有0.75的概率属于第二类,因此,理论输出结果是把该文本分为第二类。


🍞四. 核心逻辑

pre_deal.py

import csv
import random
from datasets import load_dataset

def read_file(file_path):
    csv_reader = csv.reader(open(file_path, encoding='UTF-8'))
    num = 0
    data = []
    for row in csv_reader:
        if num == 0:
            num = 1
            continue
        comment_data = [row[1], int(row[0])]
        if len(comment_data[0]) > 500:
            text=comment_data[0]
            sub_texts, start, length = [], 0, len(text)
            while start < length:
                piecedata=[text[start: start + 500], comment_data[1]]
                data.append(piecedata)
                start += 500
        else:
             data.append(comment_data)
    random.shuffle(data)
    return data

对输入的csv文件进行处理,其中我们默认csv文件的格式是[label,text],将用于训练的内容读取出来,转化为numpy格式,其中,如果遇到有些文本过长(超过模型的输入),将其截断,分为多个文本段来输入。在最后,会通过shuffle函数进行打乱。

train.py

train.py定义了几个函数,用于训练。

首先是Bertmodel类,定义了基于Bert的训练模型:

class Bertmodel(nn.Module):
    def __init__(self, output_dim, model_path):
        super(Bertmodel, self).__init__()
        # 导入bert模型
        self.bert = BertModel.from_pretrained(model_path)
        # 外接全连接层
        self.layer1 = nn.Linear(768, output_dim)

    def forward(self, tokens):
        res = self.bert(**tokens)
        res = self.layer1(res[1])
        res = res.softmax(dim=1)
        return res

该模型由Bert和一个全连接层组成,最后经过softmax激活函数。

其次是一个评估函数,用来计算模型结果的准确性

def evaluate(net, comments_data, labels_data, device, tokenizer):
    ans = 0 # 输出结果
    i = 0
    step = 8 # 每轮一次读取多少条数据
    tot = len(comments_data)
    while i <= tot:
        print(i)
        comments = comments_data[i: min(i + step, tot)]
        tokens_X = tokenizer(comments, padding=True, truncation=True, return_tensors='pt').to(device=device)

        res = net(tokens_X)  # 获得到预测结果

        y = torch.tensor(labels_data[i: min(i + step, tot)]).reshape(-1).to(device=device)

        ans += (res.argmax(axis=1) == y).sum()
        i += step

    return ans / tot

原理就是,将文本转化为tokens,输入给模型,而后利用返回的结果,计算准确性

下面展示了开始训练的主函数,在训练的过程中,进行后向传播,储存checkpoints模型

def training(net, tokenizer, loss, optimizer, train_comments, train_labels, test_comments, test_labels,
                          device, epochs):
    max_acc = 0.5  # 初始化模型最大精度为0.5

    for epoch in tqdm(range(epochs)):
        step = 8
        i, sum_loss = 0, 0
        tot=len(train_comments)
        while i < tot:
            comments = train_comments[i: min(i + step, tot)]
            tokens_X = tokenizer(comments, padding=True, truncation=True, return_tensors='pt').to(device=device)

            res = net(tokens_X)

            y = torch.tensor(train_labels[i: min(i + step, len(train_comments))]).reshape(-1).to(device=device)

            optimizer.zero_grad()  # 清空梯度
            l = loss(res, y)  # 计算损失
            l.backward()  # 后向传播
            optimizer.step()  # 更新梯度

            sum_loss += l.detach()  # 累加损失
            i += step

        train_acc = evaluate(net, train_comments, train_labels)
        test_acc = evaluate(net, test_comments, test_labels)

        print('\n--epoch', epoch + 1, '\t--loss:', sum_loss / (len(train_comments) / 8), '\t--train_acc:', train_acc,
              '\t--test_acc', test_acc)

        # 保存模型参数,并重设最大值
        if test_acc > max_acc:
            # 更新历史最大精确度
            max_acc = test_acc

            # 保存模型
            max_acc = test_acc
            torch.save({
                'epoch': epoch,
                'state_dict': net.state_dict(),
                'optimizer': optimizer.state_dict()
            }, 'model/checkpoint_net.pth')

训练结果表示如下:

--epoch 0 	--train_acc: tensor(0.6525, device='cuda:1') 	--test_acc tensor(0.6572, device='cuda:1')

  0%|          | 0/20 [00:00<?, ?it/s]
  5%|| 1/20 [01:48<34:28, 108.88s/it]
 10%|| 2/20 [03:38<32:43, 109.10s/it]
 15%|█▌        | 3/20 [05:27<30:56, 109.20s/it]
 20%|██        | 4/20 [07:15<29:02, 108.93s/it]
 25%|██▌       | 5/20 [09:06<27:23, 109.58s/it]
 30%|███       | 6/20 [10:55<25:29, 109.26s/it]
 35%|███▌      | 7/20 [12:44<23:40, 109.28s/it]
 40%|████      | 8/20 [14:33<21:51, 109.29s/it]
 45%|████▌     | 9/20 [16:23<20:04, 109.49s/it]
 50%|█████     | 10/20 [18:13<18:15, 109.59s/it]
 55%|█████▌    | 11/20 [20:03<16:27, 109.72s/it]
 60%|██████    | 12/20 [21:52<14:35, 109.45s/it]
 65%|██████▌   | 13/20 [23:41<12:45, 109.35s/it]
 70%|███████   | 14/20 [25:30<10:54, 109.14s/it]
 75%|███████▌  | 15/20 [27:19<09:05, 109.03s/it]
 80%|████████  | 16/20 [29:07<07:15, 108.84s/it]
 85%|████████▌ | 17/20 [30:56<05:26, 108.86s/it]
 90%|█████████ | 18/20 [32:44<03:37, 108.75s/it]
 95%|█████████▌| 19/20 [34:33<01:48, 108.73s/it]
100%|██████████| 20/20 [36:22<00:00, 108.71s/it]
100%|██████████| 20/20 [36:22<00:00, 109.11s/it]

--epoch 1 	--loss: tensor(1.2426, device='cuda:1') 	--train_acc: tensor(0.6759, device='cuda:1') 	--test_acc tensor(0.6789, device='cuda:1')

--epoch 2 	--loss: tensor(1.0588, device='cuda:1') 	--train_acc: tensor(0.8800, device='cuda:1') 	--test_acc tensor(0.8708, device='cuda:1')

--epoch 3 	--loss: tensor(0.8543, device='cuda:1') 	--train_acc: tensor(0.8988, device='cuda:1') 	--test_acc tensor(0.8887, device='cuda:1')

--epoch 4 	--loss: tensor(0.8208, device='cuda:1') 	--train_acc: tensor(0.9111, device='cuda:1') 	--test_acc tensor(0.8990, device='cuda:1')

--epoch 5 	--loss: tensor(0.8024, device='cuda:1') 	--train_acc: tensor(0.9206, device='cuda:1') 	--test_acc tensor(0.9028, device='cuda:1')

--epoch 6 	--loss: tensor(0.7882, device='cuda:1') 	--train_acc: tensor(0.9227, device='cuda:1') 	--test_acc tensor(0.9024, device='cuda:1')

--epoch 7 	--loss: tensor(0.7749, device='cuda:1') 	--train_acc: tensor(0.9288, device='cuda:1') 	--test_acc tensor(0.9036, device='cuda:1')

--epoch 8 	--loss: tensor(0.7632, device='cuda:1') 	--train_acc: tensor(0.9352, device='cuda:1') 	--test_acc tensor(0.9061, device='cuda:1')

--epoch 9 	--loss: tensor(0.7524, device='cuda:1') 	--train_acc: tensor(0.9421, device='cuda:1') 	--test_acc tensor(0.9090, device='cuda:1')

--epoch 10 	--loss: tensor(0.7445, device='cuda:1') 	--train_acc: tensor(0.9443, device='cuda:1') 	--test_acc tensor(0.9103, device='cuda:1')

--epoch 11 	--loss: tensor(0.7397, device='cuda:1') 	--train_acc: tensor(0.9480, device='cuda:1') 	--test_acc tensor(0.9128, device='cuda:1')

--epoch 12 	--loss: tensor(0.7321, device='cuda:1') 	--train_acc: tensor(0.9505, device='cuda:1') 	--test_acc tensor(0.9123, device='cuda:1')

--epoch 13 	--loss: tensor(0.7272, device='cuda:1') 	--train_acc: tensor(0.9533, device='cuda:1') 	--test_acc tensor(0.9140, device='cuda:1')

--epoch 14 	--loss: tensor(0.7256, device='cuda:1') 	--train_acc: tensor(0.9532, device='cuda:1') 	--test_acc tensor(0.9111, device='cuda:1')

--epoch 15 	--loss: tensor(0.7186, device='cuda:1') 	--train_acc: tensor(0.9573, device='cuda:1') 	--test_acc tensor(0.9123, device='cuda:1')

--epoch 16 	--loss: tensor(0.7135, device='cuda:1') 	--train_acc: tensor(0.9592, device='cuda:1') 	--test_acc tensor(0.9136, device='cuda:1')

--epoch 17 	--loss: tensor(0.7103, device='cuda:1') 	--train_acc: tensor(0.9601, device='cuda:1') 	--test_acc tensor(0.9128, device='cuda:1')

--epoch 18 	--loss: tensor(0.7091, device='cuda:1') 	--train_acc: tensor(0.9590, device='cuda:1') 	--test_acc tensor(0.9086, device='cuda:1')

--epoch 19 	--loss: tensor(0.7084, device='cuda:1') 	--train_acc: tensor(0.9626, device='cuda:1') 	--test_acc tensor(0.9123, device='cuda:1')

--epoch 20 	--loss: tensor(0.7038, device='cuda:1') 	--train_acc: tensor(0.9628, device='cuda:1') 	--test_acc tensor(0.9107, device='cuda:1')

最终训练结果,在训练集上达到了96.28%的准确率,在测试集上达到了91.07%的准确率

test_demo.py

这个函数提供了一个调用我们储存的checkpoint模型来进行预测的方式,将input转化为berttokens,而后输入给模型,返回输出结果。

input_text=['这里环境很好,风光美丽,下次还会再来的。']
Bert_model_path = 'xxxx'
output_path='xxxx'
device = torch.device('cpu')
checkpoint = torch.load(output_path,map_location='cpu')

model = Bertmodel(output_dim=2,model_path=Bert_model_path)
model.load_state_dict(checkpoint,False)
# print(model)
tokenizer = BertTokenizer.from_pretrained(Bert_model_path,model_max_length=512)

tokens_X = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt').to(device='cpu')
model.eval()
output=model(tokens_X)
print(output)
out = torch.unsqueeze(output.argmax(dim=1), dim=1)
result = out.numpy()
print(result)
if result[0][0]==1:
    print("positive")
else:
    print("negative")

🍞五.实现方式&演示效果

训练阶段

首先找到能够拿来训练的数据,运行pre_deal.py进行预处理,而后可以在main.py修改模型的相关参数,运行main.py开始训练。

这个过程,可能会收到硬件条件的影响,推荐使用cuda进行训练。如果实在训练不了,可以直接调用附件中对应的训练好的模型来进行预测。

测试阶段

运行test_demo.py,测试输入文本的分类结果

输入: input_text=['这里环境很好,风光美丽,下次还会再来的。']
输出: tensor([[0.3191, 0.6809]], grad_fn=<SoftmaxBackward0>)
[[1]]
结果:positive

🫓总结

综上,我们基本了解了“一项全新的技术啦” 🍭 ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读😆

后续还会继续更新💓,欢迎持续关注📌哟~

💫如果有错误❌,欢迎指正呀💫

✨如果觉得收获满满,可以点点赞👍支持一下哟~✨

【传知科技 – 了解更多新知识】

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/666575.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发接口媒体:【@ohos.multimedia.camera (相机管理)】

相机管理 说明&#xff1a; 开发前请熟悉鸿蒙开发指导文档&#xff1a; gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到。 本模块首批接口从API version 9开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 导入模块…

【LeetCode算法】第101题:对称二叉树

目录 一、题目描述 二、初次解答 三、官方解法 四、总结 一、题目描述 二、初次解答 1. 思路&#xff1a;递归判定左子树和右子树是否对称。用一个新函数sym来递归判定左子树和右子树是否对称。该函数细节&#xff1a;判定当前传入的两个根节点是否为空&#xff0c;若均为空…

React + Taro 项目 实际书写 感受

之前我总结了部分react 基础 根据官网的内容 以及Taro 框架的内容 今天我试着开始写了一下页面和开发 说一下我的感受 我之前写的是vue3 今天是第一次真正根据需求做页面开发 和逻辑功能 代码的书写 主体就是开发了这个页面 虽说这个页面 很简单 但是如果你要是第一次写 难说…

通过nginx解决跨域问题,并测试

*表示所有域名 # 测试域名server {listen 80;server_name chat.test.com;#配置根目录location / {proxy_pass http://127.0.0.1:3000;}location /api/ {# 设置允许跨域的域&#xff0c;* 表示允许任何域&#xff0c;也可以设置特定的域add_header Access-Control-Allow-Origin …

将三个字符串通过strcat连接起来并打印输出

将三个字符串通过strcat连接起来并打印输出 #include <stdio.h> #include <string.h> int main () { char a[10]"I", b[10]" am",c[10]" happy"; strcat(a,b); strcat(a,c); printf("%s",a); printf("\n"); re…

Linux基本命令的使用(ls cd touch)

一、Windows系统常见的文件类型 • 文本文件格式&#xff1a;txt、doc、pdf、html等。 • 图像文件格式&#xff1a;jpg、png、bmp、gif等。 • 音频文件格式&#xff1a;mp3、wav、wma等。 • 视频文件格式&#xff1a;mp4、avi、wmv、mov等。 • 压缩文件格式&#xff1a;zip…

配置华为路由器通过RADIUS对接安当ASP身份认证服务器以实现上网功能解决方案

当配置华为路由器通过RADIUS对接安当ASP身份认证服务器以实现上网功能时&#xff0c;以下是一个更详细的解决方案&#xff1a; 一、前期准备 1. 确认网络环境&#xff1a; 确保华为路由器与安当ASP身份认证服务器之间的网络连接稳定可靠。确定RADIUS协议所需的端口&#xff08…

【量算分析工具-贴地距离】GeoServer改造Springboot番外系列九

【量算分析工具-概述】GeoServer改造Springboot番外系列三-CSDN博客 【量算分析工具-水平距离】GeoServer改造Springboot番外系列四-CSDN博客 【量算分析工具-水平面积】GeoServer改造Springboot番外系列五-CSDN博客 【量算分析工具-方位角】GeoServer改造Springboot番外系列…

思科防火墙 网线连接的端口还是down 已配置 端口还是down

环境&#xff1a; 思科防火墙fpr-2100 isco Firepower 2100 系列防火墙是思科系统&#xff08;Cisco Systems&#xff09;推出的一款中端网络安全和防火墙设备。这一系列的产品主要针对中到大型企业的需求&#xff0c;提供高性能的威胁防护和网络流量管理功能。 问题描述&am…

【算法】MT2 棋子翻转

✨题目链接&#xff1a; MT2 棋子翻转 ✨题目描述 在 4x4 的棋盘上摆满了黑白棋子&#xff0c;黑白两色棋子的位置和数目随机&#xff0c;其中0代表白色&#xff0c;1代表黑色&#xff1b;左上角坐标为 (1,1) &#xff0c;右下角坐标为 (4,4) 。 现在依次有一些翻转操作&#…

【Linux】磁盘结构文件系统软硬链接动静态库

目录 一.磁盘结构 1、磁盘的物理结构 2、磁盘的存储结构 3、磁盘的逻辑结构 二.文件系统 1、对IO单位的优化 2、磁盘分区与分组 3、对分组的具体管理方法 4、文件操作 三.软硬链接 1、理解硬链接 2、理解软连接 3、理解.和.. 四、动静态库 1、什么是动静态库 2、…

HSViT: Horizontally Scalable Vision Transformer

论文链接&#xff1a;https://arxiv.org/pdf/2404.05196 代码链接&#xff1a;https://github.com/xuchenhao001/HSViT 根据文档内容&#xff0c;我梳理出以下大纲&#xff1a; 一、引言 ViT模型在计算机视觉领域受到广泛关注&#xff0c;但需要大规模数据集进行预训练才能取…

python绘制北京汽车流量热力图:从原理到实践

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、热力图绘制原理 三、热力图绘制实践 1. 数据准备 2. 地图组件选择 3. 数据…

【Python】解决Python报错:AttributeError: ‘function‘ object has no attribute ‘xxx‘

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

关于网络编程

目录 1、InetAdress类 2、Socket套接字 3、UDP数据报套接字编程 &#xff08;1&#xff09;DatagramSocket 类 &#xff08;2&#xff09;DatagramPacket类 &#xff08;3&#xff09;处理无连接问题 UdpEchoServer.java UdpEchoClient.java 4、TCP流套接字编程 &…

设计模式23——状态模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用&#xff0c;主要是下面的UML图可以起到大作用&#xff0c;在你学习过一遍以后可能会遗忘&#xff0c;忘记了不要紧&#xff0c;只要看一眼UML图就能想起来了。同时也请大家多多指教。 状态模式&#xff08;State&am…

Mysql基础教程(12):JOIN

MySQL JOIN 在 MySQL 中&#xff0c;JOIN 语句用于将数据库中的两个表或者多个表组合起来。 比如在一个学校系统中&#xff0c;有一个学生信息表和一个学生成绩表。这两个表通过学生 ID 字段关联起来。当我们要查询学生的成绩的时候&#xff0c;就需要连接两个表以查询学生信…

内网渗透-隧道搭建ssp隧道代理工具

内网渗透-隧道搭建&ssp隧道代理工具 目录 内网渗透-隧道搭建&ssp隧道代理工具spp隧道代理工具spp工作原理图cs上线主机spp代理通信服务端配置客户端配置CS配置设置CS生成木马的监听器配置CS监听上线的监听器生成木马 spp隧道搭建服务端配置客户端配置CS配置 内网穿透&a…

hive安装-本地模式

1.安装mysql&#xff08;参考文章&#xff1a;centos7.8安装Mysql8.4-CSDN博客&#xff09; 2.将mysql驱动拷贝到/opt/module/hive/lib目录下 &#xff08;直接windows通过finalShell上传&#xff09; 3./opt/module/hive/conf目录下新建hive-site.xml文件&#xff0c;进行配置…

QT6.2.4 MSVC2019 连接MySql5.7数据库,无驱动问题

1.下载 查询一下数据库驱动 qDebug()<<QSqlDatabase::drivers(); 结果显示&#xff0c;没有QMYSQL的驱动。 QList("QSQLITE", "QMARIADB", "QODBC", "QPSQL") MySql6.2.4驱动下载地址&#xff0c;如果是别的版本&#xff0c;…