使用LORA微调RoBERTa

模型微调是指在一个已经训练好的模型的基础上,针对特定任务或者特定数据集进行再次训练以提高性能的过程。微调可以在使其适应特定任务时产生显着的结果。

RoBERTa(Robustly optimized BERT approach)是由Facebook AI提出的一种基于Transformer架构的预训练语言模型。它是对Google提出的BERT(Bidirectional Encoder Representations from Transformers)模型的改进和优化。

“Low-Rank Adaptation”(低秩自适应)是一种用于模型微调或迁移学习的技术。一般来说我们只是使用LORA来微调大语言模型,但是其实只要是使用了Transformers块的模型,LORA都可以进行微调,本文将介绍如何利用🤗PEFT库,使用LORA提高微调过程的效率。

LORA可以大大减少了可训练参数的数量,节省了训练时间、存储和计算成本,并且可以与其他模型自适应技术(如前缀调优)一起使用,以进一步增强模型。

但是,LORA会引入额外的超参数调优层(特定于LORA的秩、alpha等)。并且在某些情况下,性能不如完全微调的模型最优,这个需要根据不同的需求来进行测试。

首先我们安装需要的包:

 !pip install transformers datasets evaluate accelerate peft

数据预处理

 import torch
 from transformers import RobertaModel, RobertaTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
 from peft import LoraConfig, get_peft_model
 from datasets import load_dataset
 
 
 
 peft_model_name = 'roberta-base-peft'
 modified_base = 'roberta-base-modified'
 base_model = 'roberta-base'
 
 dataset = load_dataset('ag_news')
 tokenizer = RobertaTokenizer.from_pretrained(base_model)
 
 def preprocess(examples):
     tokenized = tokenizer(examples['text'], truncation=True, padding=True)
     return tokenized
 
 tokenized_dataset = dataset.map(preprocess, batched=True,  remove_columns=["text"])
 train_dataset=tokenized_dataset['train']
 eval_dataset=tokenized_dataset['test'].shard(num_shards=2, index=0)
 test_dataset=tokenized_dataset['test'].shard(num_shards=2, index=1)
 
 
 # Extract the number of classess and their names
 num_labels = dataset['train'].features['label'].num_classes
 class_names = dataset["train"].features["label"].names
 print(f"number of labels: {num_labels}")
 print(f"the labels: {class_names}")
 
 # Create an id2label mapping
 # We will need this for our classifier.
 id2label = {i: label for i, label in enumerate(class_names)}
 
 data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

训练

我们训练两个模型,一个使用LORA,另一个使用完整的微调流程。这里可以看到LORA的训练时间和训练参数的数量能减少多少

以下是使用完整微调

 
 training_args = TrainingArguments(
     output_dir='./results',
     evaluation_strategy='steps',
     learning_rate=5e-5,
     num_train_epochs=1,
     per_device_train_batch_size=16,
 )
 

然后进行训练:

 def get_trainer(model):
       return  Trainer(
           model=model,
           args=training_args,
           train_dataset=train_dataset,
           eval_dataset=eval_dataset,
           data_collator=data_collator,
       )
 full_finetuning_trainer = get_trainer(
     AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label),
 )
 
 full_finetuning_trainer.train()

下面看看PEFT的LORA

 model = AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label)
 
 peft_config = LoraConfig(task_type="SEQ_CLS", inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1)
 peft_model = get_peft_model(model, peft_config)
 
 print('PEFT Model')
 peft_model.print_trainable_parameters()
 
 peft_lora_finetuning_trainer = get_trainer(peft_model)
 
 peft_lora_finetuning_trainer.train()
 peft_lora_finetuning_trainer.evaluate()

可以看到

模型参数总计:125,537,288,而LORA模型的训练参数为:888,580,我们只需要用LORA训练~0.70%的参数!这会大大减少内存的占用和训练时间。

在训练完成后,我们保存模型:

 tokenizer.save_pretrained(modified_base)
 peft_model.save_pretrained(peft_model_name)

最后测试我们的模型

 from peft import AutoPeftModelForSequenceClassification
 from transformers import AutoTokenizer
 
 # LOAD the Saved PEFT model
 inference_model = AutoPeftModelForSequenceClassification.from_pretrained(peft_model_name, id2label=id2label)
 tokenizer = AutoTokenizer.from_pretrained(modified_base)
 
 
 def classify(text):
   inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")
   output = inference_model(**inputs)
 
   prediction = output.logits.argmax(dim=-1).item()
 
   print(f'\n Class: {prediction}, Label: {id2label[prediction]}, Text: {text}')
   # return id2label[prediction]
 
 classify( "Kederis proclaims innocence Olympic champion Kostas Kederis today left hospital ahead of his date with IOC inquisitors claiming his ...")

 classify( "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.")

模型评估

我们还需要对PEFT模型的性能与完全微调的模型的性能进行对比,看看这种方式有没有性能的损失

 from torch.utils.data import DataLoader
 import evaluate
 from tqdm import tqdm
 
 metric = evaluate.load('accuracy')
 
 def evaluate_model(inference_model, dataset):
 
     eval_dataloader = DataLoader(dataset.rename_column("label", "labels"), batch_size=8, collate_fn=data_collator)
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     inference_model.to(device)
     inference_model.eval()
     for step, batch in enumerate(tqdm(eval_dataloader)):
         batch.to(device)
         with torch.no_grad():
             outputs = inference_model(**batch)
         predictions = outputs.logits.argmax(dim=-1)
         predictions, references = predictions, batch["labels"]
         metric.add_batch(
             predictions=predictions,
             references=references,
         )
 
     eval_metric = metric.compute()
     print(eval_metric)
     

首先是没有进行微调的模型,也就是原始模型

 evaluate_model(AutoModelForSequenceClassification.from_pretrained(base_model, id2label=id2label), test_dataset)

accuracy: 0.24868421052631579‘

下面是LORA微调模型

 evaluate_model(inference_model, test_dataset)

accuracy: 0.9278947368421052

最后是完全微调的模型:

 evaluate_model(full_finetuning_trainer.model, test_dataset)

accuracy: 0.9460526315789474

总结

我们使用PEFT对RoBERTa模型进行了微调和评估,可以看到使用LORA进行微调可以大大减少训练的参数和时间,但是在准确性方面还是要比完整的微调要稍稍下降。

本文代码:

https://avoid.overfit.cn/post/26e401b70f9840dab185a6a83aac06b0

作者:Achilles Moraites

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/385490.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

单片机学习笔记---DS18B20温度传感器

目录 DS18B20介绍 模拟温度传感器的基本结构 数字温度传感器的应用 引脚及应用电路 DS18B20的原理图 DS18B20内部结构框图 暂存器内部 单总线介绍 单总线电路规范 单总线时序结构 初始化 发送一位 发送一个字节 接收一位 接收一个字节 DS18B20操作流程 指令介…

L2-002 链表去重

一、题目 二、解题思路 结构体数组的下标表示该节点的地址,value 表示该节点的值,next 表示下一个结点的地址。result1 表示去重后的链表的节点的地址,result2 表示被删除的链表的节点的地址。 flag 表示节点对应的值是否出现过,…

第二节:轻松玩转书生·浦语大模型趣味Demo

参考教程:https://github.com/InternLM/tutorial/blob/main/helloworld/hello_world.md InternLM-Chat-7B 智能对话 Demo 终端运行 web demo 运行 1.首先启动服务: cd /root/code/InternLM streamlit run web_demo.py --server.address 127.0.0.1 --…

Python爬虫之文件存储#5

爬虫专栏:http://t.csdnimg.cn/WfCSx 文件存储形式多种多样,比如可以保存成 TXT 纯文本形式,也可以保存为 JSON 格式、CSV 格式等,本节就来了解一下文本文件的存储方式。 TXT 文本存储 将数据保存到 TXT 文本的操作非常简单&am…

Stable Diffusion 模型下载:DreamShaper XL(梦想塑造者 XL)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八案例九案例十 下载地址 模型介绍 DreamShaper 是一个分格多样的大模型,可以生成写实、原画、2.5D 等…

给定长度为n的数组b,求对于任意1<=l<=r<=n, 求b[i] + b[j] + b[k] - (r - l) 的最大值(l<=i, j, k<=r)

题目 思路: #include <bits/stdc++.h> using namespace std; #define int long long #define pb push_back #define fi first #define se second #define lson p << 1 #define rson p << 1 | 1 const int maxn = 1e6 + 5, inf = 1e18 + 5, maxm = 4e4 + 5,…

leetcode(双指针)11.盛最多水的容器(C++详细解释)DAY9

文章目录 1.题目示例提示 2.解答思路3.实现代码结果 4.总结 1.题目 给定一个长度为 n 的整数数组 height 。有 n 条垂线&#xff0c;第 i 条线的两个端点是 (i, 0) 和 (i, height[i]) 。 找出其中的两条线&#xff0c;使得它们与 x 轴共同构成的容器可以容纳最多的水。 返回…

利用pandas读取MongoDB库中的数据

下方代码的主要目的是从MongoDB数据库中获取数据&#xff0c;并使用pandas库将其转换为DataFrame。 # codingutf-8 from pymongo import MongoClient import pandas as pd# 创建MongoDB客户端连接 client MongoClient() # 选择数据库douban中的集合tv1 collection client[do…

java之jvm详解

JVM内存结构 程序计数器 Program Counter Register程序计数器(寄存器) 程序计数器在物理层上是通过寄存器实现的 作用&#xff1a;记住下一条jvm指令的执行地址特点 是线程私有的(每个线程都有属于自己的程序计数器)不会存在内存溢出 虚拟机栈(默认大小为1024kb) 每个线…

每日五道java面试题之java基础篇(七)

第一题. HashMap和HashTable有什么区别&#xff1f;其底层实现是什么&#xff1f; 区别 &#xff1a; HashMap⽅法没有synchronized修饰&#xff0c;线程⾮安全&#xff0c;HashTable线程安全&#xff1b;HashMap允许key和value为null&#xff0c;⽽HashTable不允许 底层实现…

【软件工程导论】实验六——建立系统对象模型(自助点餐系统)

需求描述 自助点餐系统是一站式解决预约订桌、点餐、上菜、收银等一系列餐厅经营问题的系统。 顾客在系统中填写个人信息、联系方式等信息进行用户注册。进入系统后顾客可根据餐桌特点、人数、可约时间等信息进行餐桌的预订与选择。就餐时&#xff0c;根据系统提供的菜单进行…

python基于flask的网上订餐系统769b9-django+vue

课题主要分为两大模块&#xff1a;即管理员模块和用户模块&#xff0c;主要功能包括个人中心、用户管理、菜品类型管理、菜品信息管理、留言反馈、在线交流、系统管理、订单管理等&#xff1b; 如果用户想要交换信息&#xff0c;他们需要满足双方交换信息的需要。由于时间有限…

与本地渲染相比,云渲染有哪些优势?渲染100邀请码1a12

与本地渲染相比&#xff0c;云渲染有以下几个优势&#xff1a; 1、速度快 云渲染可以利用分布式计算和并行处理技术&#xff0c;将一个大型渲染任务分割成多个小任务&#xff0c;分配给不同服务器同时执行&#xff0c;从而缩短渲染时间。2、质量高 云渲染能提供更大、更精致和…

ASCII码和EASCII码对照表

ASCII ASCII&#xff0c;是American Standard Code for Information Interchange的缩写&#xff0c; 是基于拉丁字母的一套电脑编码系统。它主要用于显示现代英语。ASCII的局限在于只能显示26个基本拉丁字母、阿拉伯数字和英式标点符号&#xff0c;因此只能用于显示现代美国英语…

服务器出现问题该怎么办?

在我们日常使用服务器的过程中&#xff0c;经常会有遇到服务器出现各种各样问题&#xff0c;服务器出错的原因有很多种&#xff0c;常见的包括系统问题、软件问题、硬件问题和网络问题。今天德迅云安全就来介绍几种比较常见的情况。 一、 服务器出现蓝屏、死机可能的原因&#…

Netty应用(十) 之 自定义编解码器 自定义通信协议

目录 25.自定义编解码器 25.1 自定义编解码器编码 25.2 自定义编解码器的总结和补充 26.自定义通信协议 26.1 关于通信协议的关注点 26.2 自定义通信协议的格式 26.3 编解码 25.自定义编解码器 有了上面这个大体框架的流程之后&#xff0c;我们来聊一个非常特殊的&#x…

用脑想问题还是用心驱动脑?

昨天回答了几个朋友的问题&#xff0c;我发现提问题的人很少&#xff0c;这让我想起之前讲的小妞子的故事&#xff0c;我问了她好几个月的同一句话&#xff1a;你有问题吗&#xff1f; 结果她很反感&#xff0c;嘿嘿。其实吧&#xff0c;我讲的很多东西都是实的&#xff0c;反而…

【新手必看】解决GitHub打不开问题,亲测有效

&#x1f44b; Hi, I’m 货又星&#x1f440; I’m interested in …&#x1f331; I’m currently learning …&#x1f49e; I’m looking to collaborate on …&#x1f4eb; How to reach me … README 目录&#xff08;持续更新中&#xff09; 各种错误处理、爬虫实战及模…

冰雪遮盖着伏尔加河

三套车 - 杨洪基词&#xff1a;李幼客 曲&#xff1a;彼得格鲁波基 冰雪遮盖着伏尔加河 冰河上跑着三套车 有人在唱着忧郁的歌 唱歌的是那赶车的人小伙子你为什么忧愁 为什么低着你的头是谁叫你这样伤心 问他的是那乘车的人 你看吧这匹可怜的老马 它跟我走遍天涯可恨那财主要把…

MQTT的学习与应用

文章目录 一、什么是MQTT二、MQTT协议特点三、MQTT应用领域四、安装Mosquitto五、如何学习 MQTT 一、什么是MQTT MQTT&#xff08;Message Queuing Telemetry Transport&#xff09;是一种轻量级的消息传输协议&#xff0c;设计用于在低带宽、不稳定的网络环境中进行高效的通信…