昇思25天学习打卡营第11天|基于MindSpore通过GPT实现情感分类

学AI还能赢奖品?每天30分钟,25天打通AI任督二脉 (qq.com)

基于MindSpore通过GPT实现情感分类

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 该案例在 mindnlp 0.3.1 版本完成适配,如果发现案例跑不通,可以指定mindnlp版本,执行`!pip install mindnlp==0.3.1`
!pip install mindnlp
!pip install jieba
%env HF_ENDPOINT=https://hf-mirror.com
import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp.dataset import load_dataset

from mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']
imdb_train.get_dataset_size()

加载IMDB数据集。将IMDB数据集分为训练集和测试集。IMDB (Internet Movie Database) 数据集包含来自著名在线电影数据库 IMDB 的电影评论。每条评论都被标注为正面(positive)或负面(negative),因此该数据集是一个二分类问题,也就是情感分类问题。

import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset

定义数据预处理函数。这个函数输入参数为数据集、分词器(GPT Tokenizer)以及一些可选参数,如最大序列长度、批量大小和是否打乱数据。预处理包括将文本转换为模型可以理解的输入格式(如input_ids和attention_mask),并将标签转换为整数类型。

from mindnlp.transformers import GPTTokenizer
# tokenizer
gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

加载GPT分词器并增加特殊标记。

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

将训练集划分为训练集和验证集。

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

用 process_dataset 函数对训练集、验证集和测试集进行处理,得到相应的数据集对象。

next(dataset_train.create_tuple_iterator())
from mindnlp.transformers import GPTForSequenceClassification
from mindspore.experimental.optim import Adam

# set bert config and define parameters for training
model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.config.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)

optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()

# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='gpt_imdb_finetune_best', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_train, metrics=metric,
                  epochs=1, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb],
                  jit=False)

导入 GPTForSequenceClassification 模型和 Adam 优化器。设置GPT模型的配置信息,包括pad_token_id和词汇表大小。使用Adam优化器对模型的可训练参数进行优化(从这里没有看出是更新部分参数,还是全部参数,有可能是部分参数。通常会改变最后一层分类器的权重和偏置,其他层的权重被冻结不变或者只微小更新些许参数。)。

Accuracy作为评价指标。

定义回调函数用于保存检查点:

   - CheckpointCallback:用于定期保存模型权重,save_path 指定了保存路径,ckpt_name保存文件的前缀,epochs=1 每个epoch保存一次,keep_checkpoint_max=2 表示最多保留2个检查点文件。
   - BestModelCallback:用于保存验证集上表现最好的模型,auto_load=True表示在训练结束后自动加载最优模型的权重。

创建 Trainer 对象,传入以下参数:
      - network:要训练的模型。
      - train_dataset:训练数据集。
      - eval_dataset:验证数据集。
      - metrics:评估指标。
      - epochs:训练轮数。
      - optimizer:优化器。
      - callbacks:回调函数列表,包括检查点保存和最佳模型保存。
      - jit:是否启用JIT编译,这里设置为False。

trainer.run(tgt_columns="labels")

通过 Trainer 的 run 方法启动训练,指定了训练过程中的目标标签列为 "labels"。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

创建 Evaluator 对象,传入以下参数:
      - network:要评估的模型。
      - eval_dataset:测试数据集。
      - metrics:评估指标。

用MindSpore通过GPT实现情感分类(Sentiment Classification)的示例。首先加载了IMDB影评数据集,并将其划分为训练集、验证集和测试集。然后使用GPTTokenizer对文本进行了标记化和转换。接下来,使用GPTForSequenceClassification构建了情感分类模型,并定义了优化器和评估指标。使用Trainer进行模型的训练,并设置了保存检查点的回调函数。训练完成后,通过Evaluator对测试集进行评估,输出分类准确率。通过对IMDB影评数据集进行训练和评估,模型可以自动进行情感分类,识别出正面或负面情感。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/755663.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Mysql常用SQL:日期转换成周_DAYOFWEEK(date)

有时候需要将查询出来的日期转换成周几&#xff0c;Mysql本身语法就是支持这种转换的&#xff0c;就是DAYOFWEEK()函数 语法格式&#xff1a;DAYOFWEEK(date) &#xff08;date&#xff1a;可以是指定的具体日期&#xff08; 如2024-06-29 &#xff09;&#xff0c;也可以是日期…

一个项目学习IOS开发---创建一个IOS开发项目

前提&#xff1a; 由于IOS开发只能在MacOS上开发&#xff0c;所以黑苹果或者购买一台MacBook Pro是每个IOS开发者必备的技能或者工具之一 Swift开发工具一般使用MacOS提供的Xcode开发工具 首先Mac Store下载Xcode工具 安装之后打开会提醒你安装IOS的SDK&#xff0c;安装好之…

媒体宣发套餐的概述及推广方法-华媒舍

在今天的数字化时代&#xff0c;对于产品和服务的宣传已经变得不可或缺。媒体宣发套餐作为一种高效的宣传方式&#xff0c;在帮助企业塑造品牌形象、扩大影响力方面扮演着重要角色。本文将揭秘媒体宣发套餐&#xff0c;为您呈现一条通往成功的路。 1. 媒体宣发套餐的概述 媒体…

使用Tailwindcss之后,vxe-table表头排序箭头高亮消失的问题解决

环境 vue2.7.8 vxe-table3.5.9 tailwindcss/postcss7-compat2.2.17 postcss7.0.39 autoprefixer9.8.8 问题 vxe-table 表格表头 th 的排序箭头在开启正序或逆序排序时&#xff0c;会显示蓝色高亮来提示用户表格数据处在排序情况下。在项目开启运行了tailwindcss之后&#xff0…

Kafka入门-基础概念及参数

一、Kafka术语 Kafka属于分布式的消息引擎系统&#xff0c;它的主要功能是提供一套完备的消息发布与订阅解决方案。可以为每个业务、每个应用甚至是每类数据都创建专属的主题。 Kafka的服务器端由被称为Broker的服务进程构成&#xff0c;即一个Kafka集群由多个Broker组成&#…

dledger原理源码分析系列(二)-心跳

简介 dledger是openmessaging的一个组件&#xff0c; raft算法实现&#xff0c;用于分布式日志&#xff0c;本系列分析dledger如何实现raft概念&#xff0c;以及dledger在rocketmq的应用 本系列使用dledger v0.40 本文分析dledger的心跳 关键词 Raft Openmessaging 心跳/…

Android14之RRO资源文件替换策略(二百二十一)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒体系统工程师系列【原创干货持续更新中……】🚀 优质视频课程:AAOS车载系统+AOSP…

.NET 一款利用内核驱动关闭AV/EDR的工具

01阅读须知 此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等&#xff08;包括但不限于&#xff09;进行检测或维护参考&#xff0c;未经授权请勿利用文章中的技术资料对任何计算机系统进行入侵操作。利用此文所提供的信息而造成的直接或间接后果和损失&#xf…

微服务 | Springboot整合GateWay+Nacos实现动态路由

1、简介 路由转发 执行过滤器链。 ​ 网关&#xff0c;旨在为微服务架构提供一种简单有效的统一的API路由管理方式。同时&#xff0c;基于Filter链的方式提供了网关的基本功能&#xff0c;比如&#xff1a;鉴权、流量控制、熔断、路径重写、黑白名单、日志监控等。 基本功能…

搜维尔科技:「研讨会」惯性动捕技术在工效学领域应用研讨会

Movella将于7月2日&#xff08;周二&#xff09;下午2点举行主题为惯性动捕技术在工效学领域应用的研讨会。来自Movella的伙伴赋能经理Jeffrey Muller作为嘉宾出席&#xff0c;届时主讲人将为大家带来Xsens惯性动捕技术在工效学领域的应用分享。同时&#xff0c;研讨会还邀请多…

最近写javaweb出现的一个小bug---前端利用 form 表单传多项数据,后端 Servlet 取出的各项数据均为空

目录&#xff1a; 一. 问题引入二 解决问题 一. 问题引入 近在写一个 java web 项目时&#xff0c;遇到一个让我头疼了晚上的问题&#xff1a;前端通过 post 提交的 form 表单数据可以传到后端&#xff0c;但当我从 Servlet 中通过 request.getParameter(“name”) 拿取各项数…

竞赛选题 python的搜索引擎系统设计与实现

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; python的搜索引擎系统设计与实现 &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&#xff1a;3分工作量&#xff1a;5分创新点&#xff1a;3分 该项目较为新颖&#xff…

如何用CSS样式实现一个优雅的渐变效果?

CSS渐变效果 CSS渐变&#xff08;Gradients&#xff09;是一种让两种或多种颜色平滑过渡的视觉效果&#xff0c;广泛应用于网页背景、按钮、边框等&#xff0c;以创造丰富的视觉体验。CSS提供了线性渐变&#xff08;Linear Gradients&#xff09;和径向渐变&#xff08;Radial…

【软件实施】软件实施概论

目录 软件实施概述定义主要工作软件项目的实施工作区别于一般的项目&#xff08;如&#xff1a;房地产工程项目&#xff09;软件实施的重要性挑战与对策软件项目实施的流程软件项目实施的周期 软件企业软件企业分类产品型软件企业业务特点产品型软件企业的分类产品型软件企业的…

web安全渗透测试十大常规项(一):web渗透测试之深入JAVA反序列化

渗透测试之PHP反序列化 1. Java反序列化1.1 FastJson反序列化链知识点1.2 FastJson反序列化链分析1.3.1 FastJson 1.2.24 利用链分析1.3.2 FastJson 1.2.25-1.2.47 CC链分析1.3.2.1、开启autoTypeSupport:1.2.25-1.2.411. Java反序列化 1.1 FastJson反序列化链知识点 1、为什…

【scau大数据原理】期末复习——堂测题

一、集群安装知识 启动集群的命令start-all.sh位于 Hadoop安装目录的sbin文件夹 目录下。 bin文件夹下包含常见的Hadoop,yarn命令&#xff1b;sbin命令下包含集群的启动、停止命令。 启动集群的命令start-all.sh包含 同时启动start-dfs.sh和start-yarn.sh 功能。…

JetBrains PyCharm 2024 mac/win版编程艺术,智慧新篇

JetBrains PyCharm 2024是一款功能强大的Python集成开发环境(IDE)&#xff0c;专为提升开发者的编程效率和体验而设计。这款IDE不仅继承了前代版本的优秀特性&#xff0c;还在多个方面进行了创新和改进&#xff0c;为Python开发者带来了全新的工作体验。 JetBrains PyCharm 20…

LED封装技术中SMD、COB和GOB的优缺点

在小间距LED显示屏的封装技术中&#xff0c;SMD、COB和GOB各有其优缺点&#xff0c;以下是对这些技术的详细分析&#xff1a; SMD&#xff08;Surface Mounted Devices&#xff09;表贴工艺技术 SMD技术是将LED灯珠焊接在电路板上的一种成熟技术&#xff0c;广泛应用于LED显示屏…

java+mysql通讯录管理

完整代码地址如果控制台打印出现乱码&#xff0c;进行如下设置

10分钟完成微信JSAPI支付对接过程-JAVA后端接口

引入架包 <dependency><groupId>com.github.javen205</groupId><artifactId>IJPay-WxPay</artifactId><version>${ijapy.version}</version></dependency>配置类 package com.joolun.web.config;import org.springframework.b…