昇思25天学习打卡营第12天|LLM-基于MindSpore实现的BERT对话情绪识别

打卡

目录

打卡

预装环境

BERT

任务说明

数据集

数据加载和数据预处理:process_dataset 函数

模型构建与训练

运行示例

模型验证

模型推理

自定义推理数据集

运行结果示例

代码


预装环境

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install mindnlp

pip show mindspore
pip show mindnlp

BERT

BERT 全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型,例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。

BERT模型的主要创新点都在pre-train方法上,即用了 Masked Language Model Next Sentence Prediction 两种方法分别捕捉词语和句子级别的representation

CASE 1 : 在用 Masked Language Model 方法训练BERT的时候随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况

1)80%的单词直接用[Mask]替换

2)10%的单词直接替换成另一个新的单词

3)10%的单词保持不变

CASE 2 : 因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了 Next Sentence Prediction 预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句

BERT预训练之后,会保存它的 Embedding table 12 层Transformer权重(BERT-BASE)或 24 层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

任务说明

对话情绪识别(Emotion Detection,简称EmoTect),专注于识别智能对话场景中用户的情绪,针对智能对话场景中的用户文本,自动判断该文本的情绪类别并给出相应的置信度,情绪类型分为积极、消极、中性。 对话情绪识别适用于聊天、客服等多个场景,能够帮助企业更好地把握对话质量、改善产品的用户交互体验,也能分析客服服务质量、降低人工质检成本。

数据集

  • 来源:百度飞桨团队
  • 数据集说明:已标注的、经过分词预处理的机器人聊天数据集。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本。文件为 utf8 编码。
wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
tar xvf emotion_detection.tar.gz

数据加载和数据预处理:process_dataset 函数

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理

import numpy as np

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    column_names = ["label", "text_a"]
    
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)
    def tokenize_and_pad(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']
    # map dataset
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})

    return dataset


from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

dataset_train = process_dataset(
                    SentimentDataset("data/train.tsv"), 
                    tokenizer)
dataset_val = process_dataset(
                    SentimentDataset("data/dev.tsv"), 
                    tokenizer)
dataset_test = process_dataset(
                    SentimentDataset("data/test.tsv"), 
                    tokenizer, 
                    shuffle=False)

部分执行结果演示,看到词表大小为 21128 ,模型维度长 512 ,右侧截断,一共有5种特殊的token,其中训练、验证、测试集数据分别有302、34、33个。

模型构建与训练

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练.

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision


### 模型配置
# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')

### 优化算法选择
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)


### 评估指标
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])


### 模型训练
# start training
trainer.run(tgt_columns="labels")

运行示例

loss降低到了0.0663,精度达到了 0.9917。非常不错。

 

模型验证

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, 
                      eval_dataset=dataset_test, 
                      metrics=metric)
evaluator.run(tgt_columns="labels")

模型推理


from mindspore import Tensor

dataset_infer = SentimentDataset("data/infer.tsv")

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)



for label, text in dataset_infer:
    predict(text, label)

自定义推理数据集

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")
predict("起开 我要开始发功了")

运行结果示例

代码

import os
import numpy as np
from mindspore import Tensor
import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, context

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
from mindnlp.transformers import BertTokenizer
from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision


# prepare dataset
class SentimentDataset:
    """Sentiment Dataset"""

    def __init__(self, path):
        self.path = path
        self._labels, self._text_a = [], []
        self._load()

    def _load(self):
        with open(self.path, "r", encoding="utf-8") as f:
            dataset = f.read()
        lines = dataset.split("\n")
        for line in lines[1:-1]:
            label, text_a = line.split("\t")
            self._labels.append(int(label))
            self._text_a.append(text_a)

    def __getitem__(self, index):
        return self._labels[index], self._text_a[index]

    def __len__(self):
        return len(self._labels)
    
    

def process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):
    """
    这个函数 process_dataset 用于处理文本数据集,包括文本的标记化、填充、类型转换和批处理,并针对不同的设备(Ascend或非Ascend)进行了不同的处理策略。
    source: 数据集的来源,可以是文件路径或数据生成器。
    tokenizer: 用于将文本转换为数字表示的标记器。
    max_seq_len: 最大序列长度,默认为64。
    batch_size: 批处理大小,默认为32。
    shuffle: 是否打乱数据集,默认为True。
    """
    ## 检查当前设备目标是否为Ascend(华为的AI处理器),并将结果存储在变量 is_ascend 中。
    is_ascend = mindspore.get_context('device_target') == 'Ascend'

    # 定义一个列表 column_names,包含数据集中列的名称。
    column_names = ["label", "text_a"]
    
    # 创建一个 GeneratorDataset 对象,它从 source 生成数据集,指定列名,并根据 shuffle 参数决定是否打乱数据。
    dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)
        
    # transforms
    type_cast_op = transforms.TypeCast(mindspore.int32)
    ## 定义一个内部函数 tokenize_and_pad,用于对文本进行标记化和填充。
    def tokenize_and_pad(text):
        ## 在 tokenize_and_pad 函数内部,根据 is_ascend 的值来决定是否对文本进行填充和截断。如果设备是Ascend,则使用 max_length 填充策略和截断。函数返回标记化后的 input_ids 和 attention_mask。
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text)
        return tokenized['input_ids'], tokenized['attention_mask']
    # map dataset
    ## 使用 map 操作将 tokenize_and_pad 函数应用到数据集的 “text_a” 列上,并将输出列命名为 ‘input_ids’ 和 ‘attention_mask’。
    dataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])
    ## 使用 map 操作将 type_cast_op 应用于数据集的 “label” 列,并将输出列命名为 ‘labels’。
    dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')
    # batch dataset
    ## 根据 is_ascend 的值,决定使用普通的批处理还是填充后的批处理。如果设备是Ascend,则直接进行批处理;否则,使用 padded_batch 来确保每个批次中的序列长度一致。
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})

    return dataset    

def predict(text, label=None):
    label_map = {0: "消极", 1: "中性", 2: "积极"}

    text_tokenized = Tensor([tokenizer(text).input_ids])
    logits = model(text_tokenized)
    predict_label = logits[0].asnumpy().argmax()
    info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"
    if label is not None:
        info += f" , label: '{label_map[label]}'"
    print(info)
    
    
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)
print(f"train dataset columns name: ", dataset_train.get_col_names())

print(next(dataset_train.create_tuple_iterator()))


##### 模型构建
# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)  ### 3个标签输出
model = auto_mixed_precision(model, 'O1')
## 优化器
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)
print(f"模型结构输出 model:", model)
print(f"优化器选择 optimizer:", optimizer)

####### 模型训练配置项
metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])


######### 模型训练 
# start training
trainer.run(tgt_columns="labels")

########## 模型评估
evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")


#############
dataset_infer = SentimentDataset("data/infer.tsv")

for label, text in dataset_infer:
    predict(text, label)

print(predict("家人们咱就是说一整个无语住了 绝绝子叠buff"))
print(predict("起开 我要开始发功了"))

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/801880.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用ant+jmeter如何生成html报告

一、安装ant 下载apache-ant,放到对应路径下,然后进行环境变量配置。系统变量的CLASSPATH添加E:\Installation Pack\eclipse\apache-ant-1.9.6\lib;用户变量的PATH添加:E:\Installation Pack\eclipse\apache-ant-1.9.6\bin。打开cmd&#xf…

持续集成02--Linux环境更新/安装Java新版本

前言 在持续集成/持续部署(CI/CD)的旅程中,确保开发环境的一致性至关重要。本篇“持续集成02--Linux环境更新/安装Java新版本”将聚焦于如何在Linux环境下高效地更新或安装Java新版本。Java作为广泛应用的编程语言,其版本的更新对…

前端框架入门之Vue的模版语法与数据单向绑定 数据双向绑定

目录 vue的模版语法 数据绑定 vue的模版语法 关于模版这个概念 root容器里面被称为模版 我们的语法分为插值语法和插值语法 这样就是实现了插值语法 接下来我们实现指令语法 首先我们写一个a标签 链一个超链接上去 <h1>指令语法</h1><a href"https:/…

【46 Pandas+Pyecharts | 当当网畅销图书榜单数据分析可视化】

文章目录 &#x1f3f3;️‍&#x1f308; 1. 导入模块&#x1f3f3;️‍&#x1f308; 2. Pandas数据处理2.1 读取数据2.2 查看数据信息2.3 去除重复数据2.4 书名处理2.5 提取年份 &#x1f3f3;️‍&#x1f308; 3. Pyecharts数据可视化3.1 作者图书数量分布3.2 图书出版年份…

openeuler 终端中文显示乱码、linux vim中文乱码

1、解决终端乱码 网上很多教程试了都不生效&#xff0c;以下方法有效&#xff1a; 确认终端支持中文显示&#xff1a; echo $LANG 输出应该包含 UTF-8&#xff0c;例如 en_US.UTF-8。如果不是&#xff0c;您可以通过以下命令设置为 UTF-8&#xff1a; export LANGzh_CN.UTF-8…

昇思25天学习打卡营第12天|Vision Transformer图像分类、SSD目标检测

Vision Transformer&#xff08;ViT&#xff09;简介 近些年&#xff0c;随着基于自注意&#xff08;Self-Attention&#xff09;结构的模型的发展&#xff0c;特别是Transformer模型的提出&#xff0c;极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩…

【数据结构(邓俊辉)学习笔记】高级搜索树02——B树

文章目录 1. 大数据1.1 640 KB1.2 越来越大的数据1.3 越来越小的内存1.4 一秒与一天1.5 分级I/O1.6 1B 1KB 2. 结构2.1 观察体验2.2 多路平衡2.3 还是I/O2.4 深度统一2.5 阶次含义2.6 紧凑表示2.7 BTNode2.8 BTree 3. 查找3.1 算法过程3.2 操作实例3.3 算法实现3.4 主次成本3.…

django踩坑(四):终端输入脚本可正常执行,而加入crontab中无任何输出

使用crontab执行python脚本时&#xff0c;有时会遇到脚本无法执行的问题。这是因为crontab在执行任务时使用的环境变量与我们在终端中使用的环境变量不同。具体来说&#xff0c;crontab使用的环境变量是非交互式(non-interactive)环境变量&#xff0c;而终端则使用交互式(inter…

【Map和Set】

目录 1&#xff0c;搜索树 1.1 概念 1.2 查找 1.3 插入 1.4 删除&#xff08;难点&#xff09; 1.5 性能分析 1.6 和 java 类集的关系 2&#xff0c;搜索 2.1 概念及场景 2.2 模型 3&#xff0c;Map 的使用 3.1 关于Map的说明 3.2 关于Map.Entry的说明 3.3 Map的…

JAVA-----异常处理

一、定义 在 Java 中&#xff0c;异常&#xff08;Exception&#xff09;是指程序在执行过程中遇到的不正常情况&#xff0c;这些情况可能导致程序无法继续执行或产生错误的结果。异常可以是 Java 标准库中提供的内置异常类&#xff0c;也可以是开发人员自定义的异常类。 二、…

小程序-2(WXML数据模板+WXSS模板样式+网络数据请求)

目录 1.WXML数据模板 数据绑定 事件绑定 小程序中常用的事件 事件对象的属性列表 target和currentTarget的区别 bindtap的语法格式 在事件处理事件中为data中的数据赋值 事件传参与数据同步 事件传参 bindinput的语法绑定事件 文本框和data的数据同步 条件渲染 w…

Spring Data Redis + Redis数据缓存学习笔记

文章目录 1 Redis 入门1.1 简介1.2 Redis服务启动与停止&#xff08;Windows&#xff09;1.2.1 服务启动命令1.2.2 客户端连接命令1.2.3 修改Redis配置文件1.2.4 Redis客户端图形工具 2. Redis数据类型2.1 五种常用数据类型介绍 3. Redis常用命令3.1 字符串操作命令3.2 哈希操作…

数据库的约束条件和用户管理

约束条件&#xff1a; 主键&#xff1a;主键约束 primary key 用于标识表中的主键列的值&#xff0c;而且这个值是全表当中唯一的&#xff0c;而且只不能为null 一个表只能有一个主键。 外键&#xff1a;用来建立表与表之间的关系。确保外键中的值于另一个表的主键值匹配&a…

golang AST语法树解析

1. 源码示例 package mainimport ("context" )// Foo 结构体 type Foo struct {i int }// Bar 接口 type Bar interface {Do(ctx context.Context) error }// main方法 func main() {a : 1 }2. Golang中的AST golang官方提供的几个包&#xff0c;可以帮助我们进行A…

集线器、交换机、路由器的区别,冲突域、广播域

冲突域 定义&#xff1a;同一时间内只能有一台设备发送信息的范围。 分层&#xff1a;基于OSI模型的第一层物理层。 广播域 定义&#xff1a;如果某个站点发出一个广播信号&#xff0c;所有能接受到这个信号的设备的范围称为一个广播域。 分层&#xff1a;基于OSI模型的第二…

为ppt中的文字配色

文字的颜色来源于ppt不可删去的图像的颜色 从各类搜索网站中搜索ppt如何配色&#xff0c;有如下几点&#xff1a; 1.可以使用对比色&#xff0c;表示强调。 2.可以使用近似色&#xff0c;使得和谐统一。 3.最好一张ppt中&#xff0c;使用的颜色不超过三种主要颜色。 但我想强调…

第二证券:电影暑期档持续升温 农机自动驾驶驶入快车道

农机自动驾驶打开驶入快车道 得益于农机补贴、土地流通、高标准农田制造等方针引导&#xff0c;叠加技术突围和用户降本增效的内生需求&#xff0c;我国正处于农业2.0向农业3.0的过渡阶段。其间农机自动驾驶系统是结束农业3.0&#xff08;即自动化&#xff09;的要害并迎来快速…

【瑞吉外卖 | day07】移动端菜品展示、购物车、下单

文章目录 瑞吉外卖 — day71. 导入用户地址簿相关功能代码1.1 需求分析1.2 数据模型1.3 代码开发 2. 菜品展示2.1 需求分析2.2 代码开发 3. 购物车3.1 需求分析3.2 数据模型3.3 代码开发 4. 下单4.1 需求分析4.2 数据模型4.3 代码开发 瑞吉外卖 — day7 移动端相关业务功能 —…

Java实验4

实验内容 考试题 要求在一个界面内至少显示5道选择题&#xff0c;每道题4个选项。题目从数据库读取。表结构自定义。 另有2个命令按钮&#xff0c;分别为“重新答题”&#xff08;全部选项及正确答题数清空&#xff09;和“提交”&#xff08;计算&#xff09;&#xff0c;在…

【芯片设计- RTL 数字逻辑设计入门 9.1 -- CRG模块】

请阅读【芯片设计 RTL 数字逻辑设计扫盲 】 转自&#xff1a;芯片设计基础 – CRG模块 文章目录 CRG模块CRG时钟系统CRG复位系统同步复位同步复位的优点同步复位的缺点 异步复位异步复位的优点异步复位的缺点 异步复位同步释放 CRG模块 CRG是芯片里的时钟和复位生成模块&#…