昇思MindSpore学习总结十七 —— 基于MindSpore通过GPT实现情感分类

 1、要求

2、导入了一些必要的库和模块

        以便在使用MindSpore和MindNLP进行深度学习任务时能使用各种功能,比如数据集处理、模型训练、评估和回调功能。

import os  # 导入操作系统相关功能的模块,如文件和目录操作

import mindspore  # 导入MindSpore库,这是一个深度学习框架
from mindspore.dataset import text, GeneratorDataset, transforms  # 从MindSpore的数据集模块导入处理文本、生成数据集和变换功能
from mindspore import nn  # 从MindSpore库导入神经网络模块

from mindnlp.dataset import load_dataset  # 从MindNLP库导入加载数据集的功能

from mindnlp._legacy.engine import Trainer, Evaluator  # 从MindNLP库的旧版本引擎模块导入训练器和评估器
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback  # 导入旧版本引擎模块的回调功能,用于检查点保存和最佳模型保存
from mindnlp._legacy.metrics import Accuracy  # 从MindNLP库的旧版本指标模块导入准确率指标

 3、加载IMDB数据集

        并将其分为训练集和测试集。load_dataset函数会返回一个包含数据集各个部分的字典,然后你可以通过键 'train''test' 来访问相应的数据。

imdb_ds = load_dataset('imdb', split=['train', 'test'])  # 加载IMDB数据集,并将数据集分为训练集和测试集,返回一个包含两个部分的字典

imdb_train = imdb_ds['train']  # 从字典中提取训练集数据
imdb_test = imdb_ds['test']  # 从字典中提取测试集数据

4、获取训练集数据集大小

get_dataset_size() 用于返回数据集中包含的样本数量。这个方法的返回值通常是一个整数,表示训练集中有多少个样本。 

imdb_train.get_dataset_size()  # 获取训练集数据集中样本的数量

 5、定义一个用于处理数据集的函数 process_dataset

        将输入文本数据进行tokenization,并根据设备类型选择不同的批处理方式。如果需要,还可以对数据集进行打乱和批处理。

import numpy as np  # 导入NumPy库,用于数值计算和数组操作

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    # 定义处理数据集的函数,接受数据集、tokenizer、最大序列长度、批量大小和是否打乱数据集作为参数

    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    # 检查当前设备是否为Ascend(华为的深度学习处理器),根据设备类型选择不同的tokenizer处理方式

    def tokenize(text):
        # 定义tokenize函数,用于对文本进行tokenization
        if is_ascend:
            # 如果在Ascend设备上,使用'padding'和'truncation'进行tokenization
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            # 否则只进行'truncation'和设置最大长度
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']
        # 返回tokenized的'input_ids'和'attention_mask'字段

    if shuffle:
        dataset = dataset.shuffle(batch_size)
        # 如果设置了shuffle参数为True,则对数据集进行打乱

    # 对数据集应用map操作
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    # 使用tokenize函数处理数据集中的"text"列,并生成新的列'input_ids'和'attention_mask'

    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # 将数据集中的"label"列转换为mindspore.int32类型,并生成新的列"labels"

    # 根据设备类型选择批处理方式
    if is_ascend:
        dataset = dataset.batch(batch_size)
        # 如果在Ascend设备上,直接进行批处理
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})
        # 否则使用padded_batch进行批处理,设置填充信息(input_ids和attention_mask的填充值)

    return dataset  # 返回处理后的数据集

6、初始化一个GPT模型的tokenizer

加载预训练的GPT模型,然后添加了一些特殊标记,如句子开始标记、句子结束标记和填充标记。

from mindnlp.transformers import GPTTokenizer
# 从MindNLP库中导入GPTTokenizer,用于加载和处理GPT模型的tokenizer

gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')
# 使用预训练的GPT tokenizer进行初始化,指定预训练模型名称为'openai-gpt'

# add special token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",  # 句子开始标记
    "eos_token": "<eos>",  # 句子结束标记
    "pad_token": "<pad>",  # 填充标记
}
# 定义一个字典,用于添加特定的特殊标记到tokenizer中

num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)
# 将定义的特殊标记添加到tokenizer中,并返回添加的标记数量

 7、划分训练集、验证集

将一个训练数据集 imdb_train 按照 70% 和 30% 的比例分割成两个数据集:一个用于训练 (imdb_train),另一个用于验证 (imdb_val)。

# split train dataset into train and valid datasets
# 将训练数据集拆分成训练集和验证集

imdb_train, imdb_val = imdb_train.split([0.7, 0.3])
# 将 imdb_train 数据集按 70% 和 30% 的比例分割成两个数据集
# imdb_train 包含 70% 的数据,用于继续训练
# imdb_val 包含 30% 的数据,用于验证模型

8、处理三个数据集:训练集、验证集和测试集。

process_dataset 函数的作用是对数据集进行预处理,包括标记化、清洗或其他转换操作。gpt_tokenizer 是用于将文本数据转换为模型可以理解的格式的标记化工具。数据集的打乱(shuffle=True)有助于防止模型训练中的过拟合和提升泛化能力。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
# 使用 process_dataset 函数处理 imdb_train 数据集
# gpt_tokenizer 用于对数据进行标记化处理
# shuffle=True 表示对数据进行随机打乱,以提高训练效果
# 处理后的数据集存储在 dataset_train 中

dataset_val = process_dataset(imdb_val, gpt_tokenizer)
# 使用 process_dataset 函数处理 imdb_val 数据集
# gpt_tokenizer 用于对数据进行标记化处理
# 处理后的数据集存储在 dataset_val 中
# 这里没有指定 shuffle 参数,默认情况下数据不会被打乱

dataset_test = process_dataset(imdb_test, gpt_tokenizer)
# 使用 process_dataset 函数处理 imdb_test 数据集
# gpt_tokenizer 用于对数据进行标记化处理
# 处理后的数据集存储在 dataset_test 中
# 这里也没有指定 shuffle 参数,默认情况下数据不会被打乱

9、从 dataset_train 中获取下一个数据样本

  • dataset_train.create_tuple_iterator():将 dataset_train 数据集转换为一个迭代器,这个迭代器返回的每一项都是一个元组(通常包含输入数据和对应的标签)。
  • next():从迭代器中获取下一个数据项。如果迭代器中没有更多的数据项,它将引发 StopIteration 异常。
next(dataset_train.create_tuple_iterator())
# 创建一个迭代器,用于遍历 dataset_train 数据集中的数据
# 使用 create_tuple_iterator() 方法将 dataset_train 转换为一个元组迭代器
# 然后调用 next() 函数从迭代器中获取下一个数据样本
# 这将返回数据集中的第一个数据项(通常是一个包含特征和标签的元组)

10、配置模型

配置一个 GPT 模型进行序列分类任务,并设置了训练过程中的优化器、评估指标、回调函数等。

  1. 模型定义和配置

    • GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2):从预训练的 GPT 模型加载,并配置为二分类任务。
    • model.config.pad_token_id = gpt_tokenizer.pad_token_id:设置模型的填充标记 ID。
    • model.resize_token_embeddings(model.config.vocab_size + 3):扩展模型的词汇表以适应新增的词汇。
  2. 优化器和学习率

    • optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5):使用 Adam 优化器,学习率设置为 0.00002。
  3. 评估指标

    • metric = Accuracy():使用准确率作为模型性能评估指标。
  4. 回调函数

    • CheckpointCallback:用于保存训练过程中的模型检查点。
    • BestModelCallback:用于保存和自动加载最佳模型检查点。
  5. 训练配置

    • Trainer:用于模型的训练,指定了模型、数据集、优化器、回调函数等参数。
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam

# set bert config and define parameters for training
# 设置模型配置和训练参数

# 创建一个 GPT 模型用于序列分类任务,num_labels=2 表示有两个分类标签
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)

# 配置模型的填充标记 ID 为 tokenzier 中的 pad_token_id
model.config.pad_token_id = gpt_tokenizer.pad_token_id

# 调整模型的词汇表大小,为模型词汇表增加 3 个新的词汇
model.resize_token_embeddings(model.config.vocab_size + 3)

# 使用 Adam 优化器,并设置学习率为 2e-5
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

# 定义准确率作为评估指标
metric = Accuracy()

# 定义回调函数以保存检查点
# ckpoint_cb 用于保存模型检查点
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)

# best_model_cb 用于保存最佳模型检查点,并自动加载最佳模型
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

# 创建一个 Trainer 对象用于训练模型
trainer = Trainer(
    network=model,                       # 训练的模型
    train_dataset=dataset_train,         # 训练数据集
    eval_dataset=dataset_train,          # 验证数据集(这里使用了训练集进行验证)
    metrics=metric,                      # 评估指标
    epochs=1,                            # 训练轮次
    optimizer=optimizer,                 # 优化器
    callbacks=[ckpoint_cb, best_model_cb], # 回调函数
    jit=False                            # 是否使用 JIT 编译
)

 11、trainer.run(tgt_columns="labels") 启动了模型的训练过程,并指定了目标列。

  • trainer.run(): 这个方法启动了模型的训练过程。它会根据之前配置的训练参数(如模型、优化器、数据集、回调函数等)开始训练。

  • tgt_columns="labels": 这是一个参数,指定了数据集中哪个列作为模型的目标列(即标签列)。在这里,"labels" 表示数据集中用于训练和验证的目标列是 labels。这个列的值用于计算损失函数并进行模型的优化。

这种设置通常用于数据集中包含多个列的情况,其中一个列是模型训练的目标输出。在这个例子中,labels 列包含了分类任务中的标签。

trainer.run(tgt_columns="labels")
# 启动训练过程
# tgt_columns="labels" 指定了训练数据集中包含的目标列(标签列)
# 训练过程将使用这些目标列来计算损失和进行梯度更新

12、 设置并运行了模型的评估过程

  1. 创建 Evaluator 对象:

    • network=model:指定要评估的模型。
    • eval_dataset=dataset_test:指定用于评估的数据集。这里使用了 dataset_test 数据集进行评估。
    • metrics=metric:指定评估时使用的指标。在这里是准确率(Accuracy())。
  2. 运行评估:

    • evaluator.run(tgt_columns="labels"):启动模型的评估过程。
    • tgt_columns="labels":指定目标列(即数据集中用于计算评估指标的列)。在这里,"labels" 表示模型将使用这个列中的标签来计算准确率。

通过这种设置,Evaluator 对象会遍历 dataset_test 数据集,计算模型在测试集上的表现,并根据指定的评估指标(准确率)输出结果。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
# 创建一个 Evaluator 对象,用于评估模型
# 参数解释:
# network=model: 要评估的模型
# eval_dataset=dataset_test: 用于评估的数据集
# metrics=metric: 评估过程中使用的指标(这里是准确率)

evaluator.run(tgt_columns="labels")
# 启动评估过程
# tgt_columns="labels" 指定了数据集中包含的目标列(标签列)
# 评估过程中使用这些目标列来计算评估指标(如准确率)

打卡

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/843310.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HTTPServer改进思路2(mudou库核心思想融入)

mudou网络库思想理解 Reactor与多线程 服务器构建过程中&#xff0c;不仅仅使用一个Reactor&#xff0c;而是使用多个Reactor&#xff0c;每个Reactor执行自己专属的任务&#xff0c;从而提高响应效率。 首先Reactor是一种事件驱动处理模式&#xff0c;其主要通过IO多路复用…

基于WebGoat平台的SQL注入攻击

目录 引言 一、安装好JAVA 二、下载并运行WebGoat 三、注册并登录WebGoat 四、模拟攻击 1. 第九题 2. 第十题 3. 第十一题 4. 第十二题 5. 第十三题 五、思考体会 1. 举例说明SQL 注入攻击发生的原因。 2. 从信息的CIA 三要素&#xff08;机密性、完整性、可用性&…

JAVA:Filer过滤器+案例:请求IP访问限制和请求返回值修改

JAVA&#xff1a;Filer过滤器 介绍 Java中的Filter也被称为过滤器&#xff0c;它是Servlet技术的一部分&#xff0c;用于在web服务器上拦截请求和响应&#xff0c;以检查或转换其内容。 Filter的urlPatterns可以过滤特定地址http的请求&#xff0c;也可以利用Filter对访问请求…

[数据分析]脑图像处理工具

###############ATTENTION&#xff01;############### 非常需要注意软件适配的操作系统&#xff01;有些仅适用于Linux&#xff0c;可以点进各自软件手册查看详情。 需要自行查看支持的影像模态。 代码库和软件我没有加以区分。 不是专门预处理的博客&#xff01;&#xf…

Richteck立锜科技电源管理芯片简介及器件选择指南

一、电源管理简介 电源管理组件的选择和应用本身的电源输入和输出条件是高度关联的。 输入电源是交流或直流&#xff1f;需求的输出电压比输入电压高或是低&#xff1f;负载电流多大&#xff1f;系统是否对噪讯非常敏感&#xff1f;也许系统需要的是恒流而不是稳压 (例如 LED…

Mac装虚拟机占内存吗 Mac用虚拟机装Windows流畅吗

如今&#xff0c;越来越多的Mac用户选择在他们的设备上安装虚拟机来运行不同的操作系统。其中&#xff0c;最常见的是使用虚拟机在Mac上运行Windows。然而&#xff0c;许多人担心在Mac上装虚拟机会占用大量内存&#xff0c;影响电脑系统性能。此外&#xff0c;有些用户还关心在…

k8s中部署nacos

1 部署nfs # 在k8s的主节点上执行 mkdir -p /appdata/download cd /appdata/download git clone https://github.com/nacos-group/nacos-k8s.git 将nacos部署到middleware的命名空间中 kubectl create namespace middleware cd /appdata/download/nacos-k8s # 创建角色 kub…

图论模型-迪杰斯特拉算法和贝尔曼福特算法★★★★

该博客为个人学习清风建模的学习笔记&#xff0c;部分课程可以在B站&#xff1a;【强烈推荐】清风&#xff1a;数学建模算法、编程和写作培训的视频课程以及Matlab等软件教学_哔哩哔哩_bilibili 目录 ​1图论基础 1.1概念 1.2在线绘图 1.2.1网站 1.2.2MATLAB 1.3无向图的…

ABAP打印WORD的解决方案

客户要求按照固定格式输出到WORD模板中&#xff0c;目前OLE和DOI研究了均不太适合用于这种需求。 cl_docx_document类可以将WORD转化为XML文件&#xff0c;利用替换字符串方法将文档内容进行填充同 时不破坏WORD现有格式。 首先需要将WORD的单元格用各种预定义的字符进行填充…

canvas:矢量点转栅格

案例描述 ArcGIS提供了“点转栅格”的工具,可以将矢量点转换为栅格数据,以下尝试基于canvas绘图技术,实现经纬度矢量点转换为canvas栅格数据,并在Cesium.js三维地图中进行渲染。 原始数据 转出栅格 案例分析 实现的关键点在于:如何将经纬度坐标与canvas画布坐标进…

【Vue3】工程创建及目录说明

【Vue3】工程创建及目录说明 背景简介开发环境开发步骤及源码 背景 随着年龄的增长&#xff0c;很多曾经烂熟于心的技术原理已被岁月摩擦得愈发模糊起来&#xff0c;技术出身的人总是很难放下一些执念&#xff0c;遂将这些知识整理成文&#xff0c;以纪念曾经努力学习奋斗的日…

Figma 中文版指南:获取和安装汉化插件

Figma是一种主流的在线团队合作设计工具&#xff0c;也是一种基于 Web 端的设计工具。在当今的设计时代&#xff0c;Figma 的使用满足了每个人的设计需求&#xff0c;不仅可以实现在线编辑&#xff0c;还可以方便日常管理&#xff0c;有效提高工作效率。然而&#xff0c;相信很…

Java查询ES报错 I/O 异常解决方法: Request cannot be executed; I/O reactor status: STOPPED

问题 ES Request cannot be executed; I/O reactor status: STOPPED 报错解决 在使用ES和SpringBoot进行数据检索时&#xff0c;在接口中第一次搜索正常。第二次在搜索时在控制台就会输出Request cannot be executed; I/O reactor status: STOPPED错误 原因 本文错误是因为在使…

51单片机14(独立按键实验)

一、按键介绍 1、按键是一种电子开关&#xff0c;使用的时候&#xff0c;只要轻轻的按下我们的这个按钮&#xff0c;按钮就可以使这个开关导通。 2、当松开这个手的时候&#xff0c;我们的这个开关&#xff0c;就断开开发板上使用的这个按键&#xff0c;它的内部结构&#xff…

用Java手写jvm之实现java -version的效果

写在前面 源码 。 本文来用纯纯的Java代码来实现java -version的效果&#xff0c;就像下面这样&#xff1a; 1&#xff1a;程序 这里输出类似这样的&#xff1a; java version "9" Java(TM) SE Runtime Environment (build 9181) Java HotSpot(TM) 64-Bit Serve…

突破•指针二

听说这是目录哦 复习review❤️野指针&#x1fae7;assert断言&#x1fae7;assert的神奇之处 指针的使用和传址调用&#x1fae7;数组名的理解&#x1fae7;理解整个数组和数组首元素地址的区别 使用指针访问数组&#x1fae7;一维数组传参的本质&#x1fae7;二级指针&#x…

filebeat,kafka,clickhouse,ClickVisual搭建轻量级日志平台

springboot集成链路追踪 springboot版本 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.6.3</version><relativePath/> <!-- lookup parent from…

python—爬虫爬取电影页面实例

下面是一个简单的爬虫实例&#xff0c;使用Python的requests库来发送HTTP请求&#xff0c;并使用lxml库来解析HTML页面内容。这个爬虫的目标是抓取一个电影网站&#xff0c;并提取每部电影的主义部分。 首先&#xff0c;确保你已经安装了requests和lxml库。如果没有安装&#x…

HTML零基础自学笔记(上)-7.18

HTML零基础自学笔记&#xff08;上&#xff09; 参考&#xff1a;pink老师一、HTML, Javascript, CSS的关系是什么?二、什么是HTML?1、网页&#xff0c;网站的概念2、THML的基本概念3、THML的骨架标签/基本结构标签 三、HTML标签1、THML标签介绍2、常用标签图像标签&#xff…

数据结构----算法复杂度

1.数据结构前言 数据是杂乱无章的&#xff0c;我们要借助结构将数据管理起来 1.1 数据结构 数据结构(Data Structure)是计算机存储、组织数据的⽅式&#xff0c;指相互之间存在⼀种或多种特定关系的数 据元素的集合。没有⼀种单⼀的数据结构对所有⽤途都有⽤&#xff0c;所…