昇思大模型平台打卡体验活动:项目4基于MindSpore实现Roberta模型Prompt Tuning

基于MindNLP的Roberta模型Prompt Tuning

本文档介绍了如何基于MindNLP进行Roberta模型的Prompt Tuning,主要用于GLUE基准数据集的微调。本文提供了完整的代码示例以及详细的步骤说明,便于理解和复现实验。

环境配置

在运行此代码前,请确保MindNLP库已经安装。本文档基于大模型平台运行,因此需要进行适当的环境配置,确保代码可以在相应的平台上运行。

模型与数据集加载

在本案例中,我们使用 roberta-large 模型并基于GLUE基准数据集进行Prompt Tuning。GLUE (General Language Understanding Evaluation) 是自然语言处理中的标准评估基准,包括多个子任务,如句子相似性匹配、自然语言推理等。Prompt Tuning是一种新的微调技术,通过插入虚拟的“提示”Token在模型的输入中,以微调较少的参数达到较好的性能。

import mindspore
from tqdm import tqdm
from mindnlp import evaluate
from mindnlp.dataset import load_dataset
from mindnlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
from mindnlp.core.optim import AdamW
from mindnlp.transformers.optimization import get_linear_schedule_with_warmup
from mindnlp.peft import (
    get_peft_model,
    PeftType,
    PromptTuningConfig,
)

1. 定义训练参数

首先,定义模型名称、数据集任务名称、Prompt Tuning类型、训练轮数等基本参数。

batch_size = 32
model_name_or_path = "roberta-large"
task = "mrpc"
peft_type = PeftType.PROMPT_TUNING
num_epochs = 20

2. 配置Prompt Tuning

在Prompt Tuning的配置中,选择任务类型为"SEQ_CLS"(序列分类任务),并定义虚拟Token的数量。虚拟Token即为插入模型输入中的“提示”Token,通过这些Token的微调,使得模型能够更好地完成下游任务。

peft_config = PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=10)
lr = 1e-3

3. 加载Tokenizer

根据模型类型选择padding的侧边,如果模型为GPT、OPT或BLOOM类模型,则从序列左侧填充(padding),否则从序列右侧填充。

if any(k in model_name_or_path for k in ("gpt", "opt", "bloom")):
    padding_side = "left"
else:
    padding_side = "right"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)
if getattr(tokenizer, "pad_token_id") is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

4. 加载数据集

通过MindNLP加载GLUE数据集,并打印样本以便确认数据格式。在此示例中,我们使用GLUE的MRPC(Microsoft Research Paraphrase Corpus)任务,该任务用于句子匹配,即判断两个句子是否表达相同的意思。

datasets = load_dataset("glue", task)
print(next(datasets['train'].create_dict_iterator()))

5. 数据预处理

为了适配MindNLP的数据处理流程,我们定义了一个映射函数 MapFunc,用于将句子转换为 input_idsattention_mask,并对数据进行padding处理。

from mindnlp.dataset import BaseMapFunction

class MapFunc(BaseMapFunction):
    def __call__(self, sentence1, sentence2, label, idx):
        outputs = tokenizer(sentence1, sentence2, truncation=True, max_length=None)
        return outputs['input_ids'], outputs['attention_mask'], label

def get_dataset(dataset, tokenizer):
    input_colums=['sentence1', 'sentence2', 'label', 'idx']
    output_columns=['input_ids', 'attention_mask', 'labels']
    dataset = dataset.map(MapFunc(input_colums, output_columns),
                          input_colums, output_columns)
    dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'attention_mask': (None, 0)})
    return dataset

train_dataset = get_dataset(datasets['train'], tokenizer)
eval_dataset = get_dataset(datasets['validation'], tokenizer)

6. 设置评估指标

我们使用 evaluate 模块加载评估指标(accuracy 和 F1-score)来评估模型的性能。

metric = evaluate.load("./glue.py", task)

7. 加载模型并配置Prompt Tuning

加载 roberta-large 模型,并根据配置进行Prompt Tuning。可以看到,微调的参数量仅为总参数量的0.3%左右,节省了大量计算资源。

model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, return_dict=True)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

模型微调(Prompt Tuning)

在Prompt Tuning中,训练过程中仅微调部分参数(主要是虚拟Token相关的参数),相比于传统微调而言,大大减少了需要调整的参数量,使得模型能够高效适应下游任务。

1. 优化器与学习率调整

使用 AdamW 优化器,并设置线性学习率调整策略。

optimizer = AdamW(params=model.parameters(), lr=lr)

# Instantiate scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0.06 * (len(train_dataset) * num_epochs),
    num_training_steps=(len(train_dataset) * num_epochs),
)

2. 训练逻辑定义

训练步骤如下:

  1. 构建正向计算函数 forward_fn
  2. 定义梯度计算函数 grad_fn
  3. 定义每一步的训练逻辑 train_step
  4. 遍历数据集进行训练和评估,在每个 epoch 结束时,计算评估指标。
def forward_fn(**batch):
    outputs = model(**batch)
    loss = outputs.loss
    return loss

grad_fn = mindspore.value_and_grad(forward_fn, None, tuple(model.parameters()))

def train_step(**batch):
    loss, grads = grad_fn(**batch)
    optimizer.step(grads)
    return loss

for epoch in range(num_epochs):
    model.set_train()
    train_total_size = train_dataset.get_dataset_size()
    for step, batch in enumerate(tqdm(train_dataset.create_dict_iterator(), total=train_total_size)):
        loss = train_step(**batch)
        lr_scheduler.step()

    model.set_train(False)
    eval_total_size = eval_dataset.get_dataset_size()
    for step, batch in enumerate(tqdm(eval_dataset.create_dict_iterator(), total=eval_total_size)):
        outputs = model(**batch)
        predictions = outputs.logits.argmax(axis=-1)
        predictions, references = predictions, batch["labels"]
        metric.add_batch(
            predictions=predictions,
            references=references,
        )

    eval_metric = metric.compute()
    print(f"epoch {epoch}:", eval_metric)

在每个 epoch 后,程序输出当前模型的评估指标(accuracy 和 F1-score)。从结果中可以看到,模型的准确率和 F1-score 会随着训练的进展逐渐提升。
7797b4532920b53cb41371e07cfa81c6.png
7797b4532920b53cb41371e07cfa81c6.png

总结

本案例通过Prompt Tuning技术,在Roberta模型上进行了微调以适应GLUE数据集任务。通过控制微调参数量,Prompt Tuning展示了较强的高效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/914787.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL】数据库表连接简明解释

未经许可,不得转载。 文章目录 表连接表连接的类型内连接与外连接结合 WHERE 条件交叉连接(cross join)表连接 在关系型数据库中,建模是数据组织的核心难点。数据库建模需要将数据关系理清,构建出适合存储和查询的结构。 所谓“模型”包括实体(entity) 和关系(relati…

离线安装GDAL与MapServer:在银河麒麟V10上的快速指南

✅作者简介:2022年博客新星 第八。热爱国学的Java后端开发者,修心和技术同步精进。 🍎个人主页:Java Fans的博客 🍊个人信条:不迁怒,不贰过。小知识,大智慧。 ✨特色专栏&#xff1a…

17.UE5丰富怪物、结构体、数据表、构造函数

2-19 丰富怪物,结构体、数据表格、构造函数_哔哩哔哩_bilibili 目录 1.结构体和数据表格 2.在构造函数中初始化怪物 3.实现怪物是否游荡 1.结构体和数据表格 创建蓝图:结构体蓝图 在结构体蓝图中添加变量,如下所示,为了实现不…

Kafka 快速入门(一)

1.1安装部署 1.1.1 集群规划 bigdata01bigdata02bigdata03zookeeperzookeeperzookeeperkafkakafkakafka 1.1.2 集群部署 官方下载地址:http://kafka.apache.org/downloads.html 检查三台虚拟机的zk是否启动:zkServer.sh start 默认启动方式 1)解压…

零件图纸的技术要求及标注

1零件的技术要求 零件在加工、检验时的各项技术要求,通常是指表面粗糙度、尺寸公差、形状和位置公差,材料的热处理及表面处理等。 2尺寸公差与配合 1、零件的互换性&定义、作用 在按规定要求大量制造的零件或部件中,任取一个&#xff0…

Python 的 Pygame 库,编写简单的 Flappy Bird 游戏

Pygame 是一个用 Python 编写的开源游戏开发框架,专门用于编写 2D 游戏。它提供了丰富的工具和功能,使得开发者能够快速实现游戏中的图形渲染、声音播放、输入处理和动画效果等功能。Pygame 非常适合初学者和想要快速创建游戏原型的开发者。 Pygame 的主…

【缓存策略】你知道 Cache Aside(缓存旁路)这个缓存策略吗

👉博主介绍: 博主从事应用安全和大数据领域,有8年研发经验,5年面试官经验,Java技术专家,WEB架构师,阿里云专家博主,华为云云享专家,51CTO 专家博主 ⛪️ 个人社区&#x…

2024版最新kali linux 新手教程(非常详细)零基础入门到精通,收藏这篇就够了

您是否有兴趣使用 Kali Linux,但不知道从哪里开始?您来对地方了。 Kali Linux 是一个强大的渗透测试和道德黑客工具,提供许多工具和资源。 本 Kali Linux 教程将向您展示如何下载和安装它、解释桌面并强调您应该知道的关键领域。接下来&…

Android JNI 技术入门指南

引言 在Android开发中,Java是一种主要的编程语言,然而,对于一些性能要求较高的场景(如音视频处理、图像处理、计算密集型任务等),我们可能需要使用到C或C等语言来编写底层的高效代码。为了实现Java代码与C…

国标GB28181视频平台EasyCVR私有化部署视频平台对接监控录像机NVR时,录像机“资源不足”是什么原因?

EasyCVR视频融合云平台,是TSINGSEE青犀视频“云边端”架构体系中的“云平台”系列之一,是一款针对大中型项目设计的跨区域、网络化、视频监控综合管理系统平台,通过接入视频监控设备及视频平台,实现视频数据的集中汇聚、融合管理、…

HarmonyOS NEXT:模块化项目 ——修改应用图标+启动页等

涉及官方文档 应用配置文件应用/组件级配置图标资源规范 涉及到app.json5配置文件和module.json5配置文件 1、 icon和label的校验。 IDE从5.0.3.800版本开始,不再对module.json5中的icon和label做强制校验,因此module.json5与app.json5只需要选择其一…

产品经理晋级-Axure中继器+动态面板制作美观表格

步骤如下: 将你的表格(制作好的表格复制) 在工作页面中,添加动态面板,并把刚才复制的表格添加进来

java 面向对象高级

1.final关键字 class Demo{public static void main(String[] args) {final int[] anew int[]{1,2,3};// anew int[]{4,5,6}; 报错a[0]5;//可以,解释了final修饰引用性变量,变量存储的地址不能被改变,但地址所指向的对象的内容可以改变} }什…

计算机网络:运输层 —— 运输层端口号

文章目录 运输层端口号的分类端口号与应用程序的关联应用举例发送方的复用和接收方的分用 运输层端口号的分类 端口号只具有本地意义,即端口号只是为了标识本计算机网络协议栈应用层中的各应用进程。在因特网中不同计算机中的相同端口号是没有关系的,即…

echarts引入自定义字体不起作用问题记录

echarts引入自定义字体不起作用问题记录 1、问题描述 初始化界面字体不作用,当界面更新后字体样式正常显示 2、原因描述 这通常是由于字体文件加载延迟导致的。ECharts 在初始化时可能还没有加载完字体文件,因此无法正确应用字体样式 3、解决方案 …

AscendC从入门到精通系列(一)初步感知AscendC

1 什么是AscendC Ascend C是CANN针对算子开发场景推出的编程语言,原生支持C和C标准规范,兼具开发效率和运行性能。基于Ascend C编写的算子程序,通过编译器编译和运行时调度,运行在昇腾AI处理器上。使用Ascend C,开发者…

JavaScript——函数、事件与BOM对象

一、系统函数(JS中预置的函数) JS的预置函数在遇到非数字字符时会停止解析 parseInt 转整型 parseFloat 转浮点型 isNaN !isNaN("10") 检测是否纯数字 eval 把字符串转成算式并计算 1.parseInt(string, radix); 语法: string&#x…

Python学习从0到1 day28 Python 高阶技巧 ⑤ 多线程

若事与愿违,请相信,上天自有安排,允许一切如其所是 —— 24.11.12 一、进程、线程 现代操作系统比如Mac OS X,UNIX,Linux,Windows等,都是支持“多任务”的操作系统。 进程 进程:就…

OpenCV视觉分析之目标跟踪(11)计算两个图像之间的最佳变换矩阵函数findTransformECC的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 根据 ECC 标准 78找到两幅图像之间的几何变换(warp)。 该函数根据 ECC 标准 ([78]) 估计最优变换(warpMatri…

彻底解决单片机BootLoader升级程序失败问题

文章目录 1、引言2、MicroBoot:优雅的解决升级问题问题1:bootloader 在跳转到app前没有清理干净存在的痕迹问题2: 需要 APP 传递信息给 Bootloader问题3: APP单独运行没有问题,通过Bootloader跳转到APP运行莫名死机问题…