深度学习:GPT-1的MindSpore实践

GPT-1简介

GPT-1(Generative Pre-trained Transformer)是2018年由Open AI提出的一个结合预训练和微调的用于解决文本理解和文本生成任务的模型。它的基础是Transformer架构,具有如下创新点:

  • NLP领域的迁移学习:通过最少的任务专项数据,利用预训练模型出色地完成具体的下游任务。
  • 语言建模作为预训练任务:使用无监督学习和大规模的文本语料库来训练模型
  • 为具体任务微调:采用预训练模型来适应监督任务

和BERT类似,GPT-1同样采取pre-train + fine-tune的思路:先基于大量未标注语料数据进行预训练, 后基于少量标注数据进行微调。但GPT-1在预训练任务思路和模型结构上与BERT有所差别。

GPT-1的目标是在预训练的过程中根据现有的所有词元,预测下一个词元。这个任务被称为“自回归语言建模”。

一个简单的例子:

输入序列为:“The sun rises in the”

训练数据的原句子为:“The sun rises in the east”

所以我们的目标输出为:“east”

将输入序列输入GPT模型,GPT根据输入预测下一个词元(“east”)在语料库中的概率分布

正确词元“east”作为一个“伪标签”来帮助模型训练

模型架构

GPT主要使用Transformer Decoder架构,但因为没有Encoder,所以在Transformer Decoder的基础上移除了计算Encoder与Decoder间注意力分数的Multi-Head Attention Layer。

Masked Multi-HeadSelf-Attention

Masked Multi-Head Self-Attention 是Multi-Head Attetion的变种。 最大的不同来自于MMSA的掩码机制,掩码机制防止模型通过观测未来的词元以进行“作弊”。

一个掩码词元<mask>被用于注意力分数矩阵,所以当前词元只能注意到序列中自己和自己之前的词元。未来的次元的注意力分数将被设为0以确保其在Softmax步骤后的实际贡献为0。

为什么掩码机制非常重要?

对于自回归任务,模型必须线性地生成词元,不能基于未来的信息预测下一个词元。

损失函数

GPT使用Cross-Entropy Loss作为损失函数:\mathcal{L} = - \sum_{t=1}^N \log P(w_t | w_1, w_2, \dots, w_{t-1})

交叉熵损失是这项任务的理想选择,因为它通过测量预测的概率分布与真实分布的距离来惩罚不正确的预测。它自然适于处理多类分类任务,其中模型从大量词汇表中选择一个标记。

模型输入

GPT-1的输入同样为句子或句子对,并添加Special Tokens。

  • [BOS]:表示句子的开始,(论文中给出的token表示为[START]),添加到序列最前;
  • [EOS]:表示序列的结束,(论文中的给出的[EXTRACT]),添加到序列最后,在进行分类任务时,会将 该special token对应的输出接入输出层;我们也可以理解为该token可以学习到整个句子的语义信息;
  • [SEP]:用于间隔句子对中的两个句子;
GPT Embedding 同样分为三类:token Embedding、Position Embedding、Segment Embedding

 

GPT-1模型具体参数

模型架构

  • 12个Transformer Decoder Block
  • hidden_size为768(模型输入和输出的向量纬度)
  • 注意力头数为12
  • FFN维度为3072
  • 词表(Vocab)大小为40000
  • 序列长度为512(上下文窗口长度)

训练过程

  • Adam优化器,超参数为:0.9, 0.99
  • 学习率:最大学习率:2.5x10e-4 使用2000步作为热身,随后线性衰退
  • 批大小:64
  • 梯度剪裁:1.0
  • Dropout率:0.1

训练过程

100000步,大约花费8张NVIDIA V100 GPU训练30天,共有117M参数。使用Xavier初始化,权重衰退为0.01。 

下游任务 

GPT按照生成式的逻辑统一了下游任务的应用模板,使用最后一个token([EOS]or[EXTRACT])对应的hidden state,输出到额外的输出层中,进行分类标签预测。
任务包括:文本分类(情感分类、新闻分类)、文本蕴含(根据前提推出假设)、文本语义相似度、多类选择(在多个next token中进行选择)

基于MindSpore微调GPT-1进行情感分类

# #安装mindnlp 0.4.0套件
# !pip install mindnlp
# !pip uninstall soundfile -y
# !pip install download
# !pip install jieba
# !pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

import os

import mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn

from mindnlp.dataset import load_dataset

from mindnlp.engine import Trainer

# loading dataset
imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']

imdb_train.get_dataset_size()

import numpy as np

def process_dataset(dataset, tokenizer, max_seq_len=512, batch_size=4, shuffle=False):
    is_ascend = mindspore.get_context('device_target') == 'Ascend'
    def tokenize(text):
        if is_ascend:
            tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)
        else:
            tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['attention_mask']

    if shuffle:
        dataset = dataset.shuffle(batch_size)

    # map dataset
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'attention_mask'])
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # batch dataset
    if is_ascend:
        dataset = dataset.batch(batch_size)
    else:
        dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                             'attention_mask': (None, 0)})

    return dataset

from mindnlp.transformers import OpenAIGPTTokenizer
# tokenizer
gpt_tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

# add sepcial token: <PAD>
special_tokens_dict = {
    "bos_token": "<bos>",
    "eos_token": "<eos>",
    "pad_token": "<pad>",
}
num_added_toks = gpt_tokenizer.add_special_tokens(special_tokens_dict)

#为方便体验流程,把原本数据集的十分之一拿出来体验训练和评估,
imdb_train, _ = imdb_train.split([0.1, 0.9], randomize=False)

# split train dataset into train and valid datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

dataset_train = process_dataset(imdb_train, gpt_tokenizer, shuffle=True)
dataset_val = process_dataset(imdb_val, gpt_tokenizer)
dataset_test = process_dataset(imdb_test, gpt_tokenizer)

# load GPT sequence classification model and set class=2
from mindnlp.transformers import OpenAIGPTForSequenceClassification  # Import the GPT model for sequence classification
from mindnlp import evaluate  # Import the evaluation module from MindNLP
import numpy as np  # Import NumPy for numerical operations

# Set up the GPT model for sequence classification with 2 output labels (binary classification).
model = OpenAIGPTForSequenceClassification.from_pretrained('openai-gpt', num_labels=2)

# Set the padding token ID in the model configuration to match the tokenizer's padding token ID.
model.config.pad_token_id = gpt_tokenizer.pad_token_id

# Resize the token embedding layer to account for any added tokens (e.g., special tokens).
model.resize_token_embeddings(model.config.vocab_size + 3)

from mindnlp.engine import TrainingArguments  # Import training arguments for model training configuration.

# Define training arguments.
training_args = TrainingArguments(
    output_dir="gpt_imdb_finetune",  # Directory to save model checkpoints and outputs.
    evaluation_strategy="epoch",  # Evaluate the model at the end of each epoch.
    save_strategy="epoch",  # Save model checkpoints at the end of each epoch.
    logging_strategy="epoch",  # Log metrics and progress at the end of each epoch.
    load_best_model_at_end=True,  # Automatically load the best model (based on evaluation metrics) at the end of training.
    num_train_epochs=1.0,  # Number of training epochs (default is 1 for quick experimentation).
    learning_rate=2e-5  # Learning rate for the optimizer.
)

# Load the accuracy metric for evaluation.
metric = evaluate.load("accuracy")

# Define a function to compute metrics during evaluation.
def compute_metrics(eval_pred):
    logits, labels = eval_pred  # Unpack predictions (logits) and true labels.
    predictions = np.argmax(logits, axis=-1)  # Convert logits to class predictions using argmax.
    return metric.compute(predictions=predictions, references=labels)  # Compute accuracy metric.

# Initialize the Trainer class with the model, training arguments, datasets, and metric computation function.
trainer = Trainer(
    model=model,  # The GPT model to be fine-tuned.
    args=training_args,  # Training configuration arguments.
    train_dataset=dataset_train,  # Training dataset (must be preprocessed and tokenized).
    eval_dataset=dataset_val,  # Validation dataset for evaluation.
    compute_metrics=compute_metrics  # Metric computation function for evaluation.
)

# start training
trainer.train()

trainer.evaluate(dataset_test)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/923515.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

过滤条件包含 OR 谓词,如何进行查询优化——OceanBase SQL 优化实践

这篇博客涉及两个点&#xff0c;一个是 “OR Expansion 改写”&#xff0c;另一个是 “基于代价的改写”。 背景 在写SQL查询时&#xff0c;难以避免在过滤条件中使用 OR 谓词&#xff0c;但其往往会导致索引利用效率下降的问题 。本文将分享如何通过查询改写的2种方式进行优化…

C语言函数递归经典题型——汉诺塔问题

一.汉诺塔问题介绍 Hanoi&#xff08;汉诺&#xff09;塔问题。古代有一个梵塔&#xff0c;塔内有3个座A、B、C&#xff0c;开始时&#xff21;座上有64个盘子&#xff0c;盘子大小不等&#xff0c;大的在下&#xff0c;小的在上。有一个老和尚想把这64个盘子从&#xff21;座移…

【Python】九大经典排序算法:从入门到精通的详解(冒泡排序、选择排序、插入排序、归并排序、快速排序、堆排序、计数排序、基数排序、桶排序)

文章目录 1. 冒泡排序&#xff08;Bubble Sort&#xff09;2. 选择排序&#xff08;Selection Sort&#xff09;3. 插入排序&#xff08;Insertion Sort&#xff09;4. 归并排序&#xff08;Merge Sort&#xff09;5. 快速排序&#xff08;Quick Sort&#xff09;6. 堆排序&…

lua除法bug

故事背景&#xff0c;新来了一个数值&#xff0c;要改公式。神奇的一幕出现了&#xff0c;公式算出一个非常大的数。排查是lua有一个除法bug,1除以大数得到一个非常大的数。 function div(a, b)return tonumber(string.format("%.2f", a/b)) end print(1/73003) pri…

STM32 USART串口发送+接收

单片机学习&#xff01; 目录 前言 一、串口发送配置步骤 二、详细步骤 2.1 RCC开启USART和GPIO时钟 2.2 GPIO初始化 2.3 配置USART 2.4 开启USART 2.5 总初始化代码 三、接收数据 3.1 查询方法 3.2 中断方法 3.2.1 中断配置 3.2.2 接收函数 总结 前言 上篇博文介…

网络安全事件管理

一、背景 信息化技术的迅速发展已经极大地改变了人们的生活&#xff0c;网络安全威胁也日益多元化和复杂化。传统的网络安全防护手段难以应对当前繁杂的网络安全问题&#xff0c;构建主动防御的安全整体解决方案将更有利于防范未知的网络安全威胁。 国内外的安全事件在不断增…

AIGC--AIGC与人机协作:新的创作模式

AIGC与人机协作&#xff1a;新的创作模式 引言 人工智能生成内容&#xff08;AIGC&#xff09;正在以惊人的速度渗透到创作的各个领域。从生成文本、音乐、到图像和视频&#xff0c;AIGC使得创作过程变得更加快捷和高效。然而&#xff0c;AIGC并非完全取代了人类的创作角色&am…

【数据结构实战篇】用C语言实现你的私有队列

&#x1f3dd;️专栏&#xff1a;【数据结构实战篇】 &#x1f305;主页&#xff1a;f狐o狸x 在前面的文章中我们用C语言实现了栈的数据结构&#xff0c;本期内容我们将实现队列的数据结构 一、队列的概念 队列&#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端…

RHCSA作业

课后练习 将整个 /etc 目录下的文件全部打包并用 gzip 压缩成/back/etcback.tar.gz [rootlocalhost ~]# tar -czvf /back/etcback.tar.gz -C / etc 使当前用户永久生效的命令别名&#xff1a;写一个命令命为hello,实现的功能为每输入一次hello命令&#xff0c;就有hello&#…

java:拆箱和装箱,缓存池概念简单介绍

1.基本数据类型及其包装类&#xff1a; 举例子&#xff1a; Integer i 10; //装箱int n i; //拆箱 概念&#xff1a; 装箱就是自动将基本数据类型转换为包装器类型&#xff1b; 拆箱就是自动将包装器类型转换为基本数据类型&#xff1b; public class Main {public s…

如何选择最适合企业的ETL解决方案?

在今天的大数据时代&#xff0c;企业的数据管理和处理变得愈发重要。企业也越来越依赖于数据仓库和数据湖来提取、转换和加载&#xff08;ETL&#xff09;关键业务信息。一个高效、灵活的ETL解决方案不仅能提升数据处理能力&#xff0c;还能为企业决策提供有力支持。然而&#…

[网鼎杯 2020 朱雀组]phpweb 详细题解(反序列化绕过命令执行)

知识点: call_user_func() 函数 反序列化魔术方法 find命令查找flag 代码审计 打开题目,弹出上面的提示,是一个警告warning,而且页面每隔几秒就会刷新一次,根据warning中的信息以及信息中的时间一直在变,可以猜测是date()函数一直在被调用 查看源代码发现一些信息,但是作用…

数字图像处理(2):Verilog基础语法

&#xff08;1&#xff09;Verilog常见数据类型&#xff1a; reg型、wire型、integer型、parameter型 &#xff08;2&#xff09;Verilog 常见进制&#xff1a;二进制&#xff08;b或B&#xff09;、十进制&#xff08;d或D&#xff09;、八进制&#xff08;o或O&#xff09;、…

c++:面向对象三大特性--继承

面向对象三大特性--继承 一、继承的概念及定义&#xff08;一&#xff09;概念&#xff08;二&#xff09;继承格式1、继承方式2、格式写法3、派生类继承后访问方式的变化 &#xff08;三&#xff09;普通类继承&#xff08;四&#xff09;类模板继承 二、基类和派生类的转换&a…

数据结构 (12)串的存储实现

一、顺序存储结构 顺序存储结构是用一组连续的存储单元来存储串中的字符序列。这种存储方式类似于线性表的顺序存储结构&#xff0c;但串的存储对象仅限于字符。顺序存储结构又可以分为定长顺序存储和堆分配存储两种方式。 定长顺序存储&#xff1a; 使用静态数组存储&#xff…

在线绘制Nature Communication同款双色、四色火山图,突出感兴趣的基因

导读&#xff1a;火山图通常使用三种颜色分别表示显著上调&#xff0c;显著下调和不显著。通过为特定的数据点添加另一种颜色&#xff0c;可以创建双色或四色火山图&#xff0c;从而更直观地突出感兴趣的数据点。 《Nature Communication》文章“Molecular and functional land…

2024赣ctf-web -wp

1.你到底多想要flag??? 首先来解决第一关&#xff1a; 先了解一下stripos&#xff08;&#xff09;&#xff1b; 并且此函数处理数组返回false。而且pre_match同样遇见数组是返回false&#xff08;解释一下正则 i&#xff1a;这是正则表达式的修饰符&#xff0c;代表“不区…

计算机毕业设计Python+大模型美食推荐系统 美食可视化 美食数据分析大屏 美食爬虫 美团爬虫 机器学习 大数据毕业设计 Django Vue.js

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…

Linux 查看内核日志的方法

文章目录 1. dmesg 命令一. 介绍内核环形缓冲区的特点 二. 主要功能三. dmesg 使用 2. 查看kmsg文件/dev/kmsg 的用途使用 /dev/kmsg与 dmesg 的关系 3. 内核日志消息的打印行为 1. dmesg 命令 一. 介绍 dmesg&#xff08;display message 或 display driver message 的缩写&…

Perforce SAST专家详解:自动驾驶汽车的安全与技术挑战,Klocwork、Helix QAC等静态代码分析成必备合规性工具

自动驾驶汽车安全吗&#xff1f;现代汽车的软件包含1亿多行代码&#xff0c;支持许多不同的功能&#xff0c;如巡航控制、速度辅助和泊车摄像头。而且&#xff0c;这些嵌入式系统中的代码只会越来越复杂。 随着未来汽车的互联程度越来越高&#xff0c;这一趋势还将继续。汽车越…