Mindspore 公开课 - GPT

GPT Task

在模型 finetune 中,需要根据不同的下游任务来处理输入,主要的下游任务可分为以下四类:

  • 分类(Classification):给定一个输入文本,将其分为若干类别中的一类,如情感分类、新闻分类等;
  • 蕴含(Entailment):给定两个输入文本,判断它们之间是否存在蕴含关系(即一个文本是否可以从另一个文本中推断出来);
  • 相似度(Similarity):给定两个输入文本,计算它们之间的相似度得分;
  • 多项选择题(Multiple choice):给定一个问题和若干个答案选项,选择最佳的答案。

在这里插入图片描述
我们使用IMDb数据集,通过finetune GPT进行情感分类任务。

IMDb数据集是一个常用的情感分类数据集,其中包含50,000条影评文本,其中25,000条用作训练数据,另外25,000条用作测试数据。每个样本都有一个二元标签,表示影评的情感是正面还是负面。

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp import load_dataset
from mindnlp.transforms import PadTransform, GPTTokenizer

from mindnlp.engine import Trainer, Evaluator
from mindnlp.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp.metrics import Accuracy

数据预处理

# 加载数据集
imdb_train, imdb_test = load_dataset('imdb', shuffle=False)

通过load_dataset加载IMDb数据集后,我们需要对数据进行如下处理:

  • 将文本内容进行分词,并映射为对应的数字索引;
  • 统一序列长度:超过进行截断,不足通过占位符进行补全;
  • 按照分类任务的输入要求,在句首和句末分别添加Start与Extract占位符(此处用 <bos>与 <eos>表示);
  • 批处理。
import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False):
    """数据集预处理"""
    def pad_sample(text):
        if len(text) + 2 >= max_seq_len:
            return np.concatenate(
                [np.array([tokenizer.bos_token_id]), text[: max_seq_len-2], np.array([tokenizer.eos_token_id])]
            )
        else:
            pad_len = max_seq_len - len(text) - 2
            return np.concatenate( 
                [np.array([tokenizer.bos_token_id]), text,
                 np.array([tokenizer.eos_token_id]),
                 np.array([tokenizer.pad_token_id] * pad_len)]
            )

    column_names = ["text", "label"]
    rename_columns = ["input_ids", "label"]

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    dataset = dataset.map(operations=[tokenizer, pad_sample], input_columns="text")
    dataset = dataset.rename(input_columns=column_names, output_columns=rename_columns)
    dataset = dataset.batch(batch_size)

    return dataset

加载 GPT tokenizer,并添加上述使用到的 <bos>, <eos>, <pad>占位符。

gpt_tokenizer = GPTTokenizer.from_pretrained('openai-gpt')

special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

由于IMDb数据集本身不包含验证集,我们手动将其分割为训练和验证两部分,比例取0.7, 0.3。

imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

模型训练

同BERT课件中的情感分类任务实现,这里我们依旧使用了混合精度。另外需要注意的一点是,由于在前序数据处理中,我们添加了3个特殊占位符,所以在token embedding中需要调整词典的大小(vocab_size + 3)。

from mindnlp.models import GPTForSequenceClassification
from mindnlp._legacy.amp import auto_mixed_precision

model = GPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)
model.pad_token_id = gpt_tokenizer.pad_token_id
model.resize_token_embeddings(model.config.vocab_size + 3)
model = auto_mixed_precision(model, 'O1')

loss = nn.CrossEntropyLoss()
optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)

metric = Accuracy()

ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='sentiment_model', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', auto_load=True)

trainer = Trainer(network=model, train_dataset=dataset_train,
                  eval_dataset=dataset_val, metrics=metric,
                  epochs=3, loss_fn=loss, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb], jit=True)

trainer.run(tgt_columns="label")

模型评估

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="label")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/323959.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式软件工程师面试题——2025校招社招通用(十八)

说明&#xff1a; 面试群&#xff0c;群号&#xff1a; 228447240面试题来源于网络书籍&#xff0c;公司题目以及博主原创或修改&#xff08;题目大部分来源于各种公司&#xff09;&#xff1b;文中很多题目&#xff0c;或许大家直接编译器写完&#xff0c;1分钟就出结果了。但…

pandas列与列之间的计算

1.简单计算 &#xff08;1&#xff09;首先导入pandas模块并读取数据 import pandas as pd adress"D:/pandas练习文件/计算.xlsx" datapd.read_excel(adress) &#xff08;2&#xff09; 计算销售额 销售额售价*销售数量 data["销售额"]data["售价&…

Unet系列网络解析

Unet UNet最早发表在2015的MICCAI上&#xff0c;到2020年中旬的引用量已经超过了9700多次&#xff0c;估计现在都过万了&#xff0c;从这方面看足以见得其影响力。当然&#xff0c;UNet这个基本的网络结构有太多的改进型&#xff0c;应用范围已经远远超出了医学图像的范畴。我…

Oracle常见操作

知识点1&#xff1a;格式化日期 select to_char(sysdate,yyyy-MM-dd HH:mm:ss) as time from dual;运行截图&#xff1a; 知识点2&#xff1a;解锁用户 alter user test account unlock;知识点3&#xff1a;修改密码 alter user test identified by test2;知识点4&#xff…

JVM-Arthas高效的监控工具

一、arthas介绍 3.选择监控哪个进程 4.进入具体进程 二、arthas的基础命令与基本操作 1.查询包含Java的系统属性&#xff1a; 命令&#xff1a;sysprop |grep java 1.查询不含Java的系统属性&#xff1a; 命令&#xff1a;sysprop | grep -v java 3.打印历史命令 命令&#…

vector容器解决杨辉三角

一、题目描述 118. 杨辉三角 给定一个非负整数 numRows&#xff0c;生成「杨辉三角」的前 numRows 行。 在「杨辉三角」中&#xff0c;每个数是它左上方和右上方的数的和。 示例 1: 输入: numRows 5 输出: [[1],[1,1],[1,2,1],[1,3,3,1],[1,4,6,4,1]]示例 2: 输入: numRo…

008 Linux_文件在磁盘上的管理

前言 本文将会向你介绍文件在磁盘上是如何被管理的 磁盘的物理存储结构 系统中的文件分为被打开的文件&#xff0c;和没被打开的文件&#xff0c;被打开的文件在内存中进行管理&#xff0c;而没有被打开的文件则在磁盘中进行管理 同样地&#xff0c;也需要对这些没被打开的文…

Vulnhub靶机:driftingblues 5

一、介绍 运行环境&#xff1a;Virtualbox 攻击机&#xff1a;kali&#xff08;10.0.2.15&#xff09; 靶机&#xff1a;driftingblues5&#xff08;10.0.2.21&#xff09; 目标&#xff1a;获取靶机root权限和flag 靶机下载地址&#xff1a;https://www.vulnhub.com/entr…

实验 2:语法分析程序的设计——以LL(1)为例(更适合njupt体质的宝宝)

实验 2&#xff1a;语法分析程序的设计——以LL(1)为例 PS&#xff1a;源码私聊 1.实验软件环境 ​ 开发平台&#xff1a;Windows 11 ​ 编程语言&#xff1a;C ​ 编译器版本&#xff1a;GCC 11.20 C20 ​ IDE&#xff1a;VS Code 2.实验原理描述 求 F i r s t First Fi…

基于springboot数码论坛系统源码和论文

网络的广泛应用给生活带来了十分的便利。所以把数码论坛与现在网络相结合&#xff0c;利用java技术建设数码论坛系统&#xff0c;实现数码论坛的信息化。则对于进一步提高数码论坛发展&#xff0c;丰富数码论坛经验能起到不少的促进作用。 数码论坛系统能够通过互联网得到广泛…

如何判断售卖的医疗器械产品是二类还是三类

售卖医疗器械需关注产品本身是否为一、二、三类医疗器械。第一类医疗器械为一般项目的经营范围无需取得备案或许可证即可销售。第二类医疗器械产品需办理第二类医疗器械的备案方可销售。第三类医疗器械需取得医疗器械经营许可证且许可证上的经营范围需与销售的产品对应方可销售…

IIC协议

文章目录 IIC介绍通信距离通信速度主从方式通信方式物理结构 IIC协议空闲状态开始信号、结束信号和应答信号向从机发送数据的过程读取从机数据的过程数据有效性协议一帧构成 IIC介绍 IIC(Inter&#xff0d;Integrated Circuit)总线是一种由 PHILIPS 公司开发的两线式串行总线。…

ArkUI-X跨平台已至,何需其它!

运行环境 DevEco Studio&#xff1a;4.0Release OpenHarmony SDK API10 开发板&#xff1a;润和DAYU200 自从写了一篇ArkUI-X跨平台的文章之后&#xff0c;好多人都说对这个项目十分关注。 那么今天我们就来完整的梳理一下这个项目。 1、ArkUI-X 我们之前可能更多接触的…

class_5:在c++中一个类包含另一个类的对象叫做组合

#include <iostream> using namespace std;class Wheel{ public://成员数据string brand; //品牌int year; //年限//真正的成员函数void printWheelInfo(); //声明成员函数 };void Wheel::printWheelInfo() {cout<<"我的轮胎品牌是&#xff1a;"<…

MySQL之多表查询

1.创建student和score表 CREATE TABLE student ( id INT(10) NOT NULL UNIQUE PRIMARY KEY , name VARCHAR(20) NOT NULL , sex VARCHAR(4) , birth YEAR, department VARCHAR(20) , address VARCHAR(50) ); 创建score表。SQL代码如下&#xff1a; CREATE TABLE sc…

数据库多表查询练习题

二、多表查询 1. 创建 student 和 score 表 CREATE TABLE student ( id INT ( 10 ) NOT NULL UNIQUE PRIMARY KEY , name VARCHAR ( 20 ) NOT NULL , sex VARCHAR ( 4 ) , birth YEAR , department VARCHAR ( 20 ) , address VARCHAR ( 50 ) ); 创建 s…

02章【JAVA编程基础】

变量与标识符 变量 数学名词&#xff1a;变数或变量&#xff0c;是指没有固定的值&#xff0c;可以改变的数。变量以非数字的符号来表达&#xff0c;一般用拉丁字母。变量是常数的相反。变量的用处在于能一般化描述指令的方式。计算机解释&#xff1a;变量就是系统为程序分配…

IOS自动化测试元素定位

一、元素属性介绍 1、元素属性 2、查看各定位方式执行效率 二、iOS常用定位方法 1、accessibility_id 2、class_name 3、Xpath 4、ios_class_chain(类型链) 5、ios_predicate(谓词) 一个页面最基本组成单元是元素&#xff0c;想要定位一个元素&#xff0c;我们需…

德语B2 SampleAcademy

德语B2 SampleAcademy 一, Zweigliedrige Konnektion1, entweder... oder2, nicht nur... sondern auch3,sowohl... als auch4,einerseits... andererseits5, zwar...aber6, weder...noch7,je...desto/umso 一, Zweigliedrige Konnektion discontinuous conjunctions 1,…

基础面试题整理4

1.mybatis的#{}和${}区别 #{}是预编译处理&#xff0c;${}是字符串替换#{}可以防止SQL注入&#xff0c;提高安全性 2.mybatis隔离级别 读未提交 READ UNCOMMITED&#xff1a;读到了其他事务中未提交的数据&#xff0c;造成"脏读","不可重复读","幻读&…