模型微调入门介绍一

  备注:模型微调系列的博客部分内容来源于极客时间大模型微调训练营素材,撰写模型微调一系列博客,主要是期望把训练营的内容内化成自己的知识,我自己写的这一系列博客除了采纳部分训练营的内容外,还会扩展细化某些具体细节知识点。

  模型微调大致会有下面5大步骤,其中数据下载主要用transformers库中的datasets来完成,数据预处理部分会用到tokenizer对象。本篇博客会重点介绍数据加载和数据预处理部分,剩余的三个步骤会通过一个简单的例子来简要介绍,后面会有专门的博客来介绍超参数如何设置和结果评估等内容。

数据下载

  datasets 是由 Hugging Face 提供的一个 Python 库,用于访问和使用大量自然语言处理(NLP)数据集。该库旨在使研究人员和开发人员能够轻松地获取、处理和使用各种 NLP 数据集,从而促进自然语言处理模型的研究和开发。datasets提供的常用function如下图所示:

load_dataset(name, split=None):用于加载指定名称的数据集。可以通过 split 参数指定加载数据集的特定拆分(如 "train"、"validation"、"test" 等)。
list_datasets():列出所有可用的数据集名称。
load_metric(name):加载指定名称的评估指标,用于评估模型性能,后面会有专门的一篇博客进行介绍。
load_from_disk(path) 和 save_to_disk(path, data):用于从磁盘加载数据集或将数据集保存到磁盘。
shuffle(seed=None):用于对数据集进行随机洗牌。可以通过 seed 参数指定随机数生成器的种子。
train_test_split(test_size=0.2, seed=None):用于将数据集拆分为训练集和测试集。

数据预处理

清洗数据

  在进行数据预处理的时候,通常需要分析是否需要进行数据清洗。例如,如果原始数据中存在一些特殊符号需要进行清理,通常会自定义清理方法对原始数据进行清洗。具体demo code如下图所示,具体的clean_text方法需要结合具体的数据进行自定义。

import re
import string

def clean_text(text):
    # 将文本转换为小写
    text = text.lower()
    # 去除标点符号
    text = text.translate(str.maketrans("", "", string.punctuation))
    # 去除数字
    text = re.sub(r'\d+', '', text)
    # 去除多余的空格
    text = re.sub(r'\s+', ' ', text).strip()
    # 处理缩写词,这里只是一个简单的示例
    text = re.sub(r"won't", "will not", text)
    text = re.sub(r"can't", "can not", text)
    # 添加更多的缩写词处理..
    return text

# 示例文本
raw_text = "Hello, how are you? This is an example text with some numbers like 123 and punctuations!!!"

# 进行文本清理
cleaned_text = clean_text(raw_text)

# 输出结果
print("Original Text:")
print(raw_text)
print("\nCleaned Text:")
print(cleaned_text)

Tokenzier进行数据预处理

 除了数据清洗,在做数据预处理的时候,通常会调用tokenizer的方法进行填充、截断等预处理,那么tokenizer具体提供了哪些参数呢?初始化tokenizer对象时,主要有以下参数:

max_length:控制分词后的最大序列长度。文本将被截断或填充以适应这个长度。
truncation:控制是否对文本进行截断,以适应 max_length。可以设置为 True(默认)或 False。
padding:控制是否对文本进行填充,以适应 max_length。
return_tensors:控制返回的结果是否应该是 PyTorch 或 TensorFlow 张量。可以设置为 'pt'、'tf' 或 None(默认)。
add_special_tokens:控制是否添加特殊令牌,如 [CLS]、[SEP] 或 [MASK]。可以设置为 True(默认)或 False。
is_split_into_words:控制输入文本是否已经是分好词的形式。如果设置为 True,分词器将跳过分词步骤。可以设置为 False 或 True(默认)。
return_attention_mask:控制是否返回 attention mask,指示模型在输入序列中哪些标记是有效的。可以设置为 True 或 False(默认)。
return_offsets_mapping:控制是否返回标记的偏移映射,即每个标记在原始文本中的起始和结束位置。可以设置为 True 或 False(默认)。
return_token_type_ids:控制是否返回用于区分文本段的 token type ids。可以设置为 True 或 False(默认)

 以下面的demo code为例,当设置padding=“max_length”后,如果内容长度低于10,会对内容进行自动填充。tokenizer对象返回一个字典类型,包含inputs_ids,token_type_ids,attention_mask。其中inputs_ids是真正的对输入文本的编码,attention_mask用于标记哪些是真正的输入文本转换的内容,哪些是填充内容,标记为0的即为填充内容。

除了上面的字段外,还可以设置是否返回tensor,是否添加特殊标记等。以下面的例子为例,在encode中添加了特殊标记,设置了返回张量,则返回的内容是tensor张量。

from transformers import BertTokenizer

# 初始化 BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# 定义文本
text = "Hello, how are you? I hope everything is going well."

# 使用 tokenizer.tokenize 进行分词
tokens = tokenizer.tokenize(text)
print("Tokens after tokenization:", tokens)

# 使用 tokenizer.encode 将文本编码成模型输入的标识符序列
input_ids = tokenizer.encode(text, max_length=15, truncation=True, padding="max_length", add_special_tokens=True)
print("Input IDs after encoding:", input_ids)

# 使用 tokenizer.decode 将模型输出的标识符序列解码为文本
decoded_text = tokenizer.decode(input_ids)
print("Decoded text:", decoded_text)

# 使用 tokenizer.encode_plus 获取详细的编码结果,包括 attention mask 和 token type ids
encoding_result = tokenizer.encode_plus(text, max_length=15, truncation=True, padding="max_length", add_special_tokens=True, return_tensors="pt")
print("Detailed encoding result:", encoding_result)

 打印出来的结果如下图所示:

在上面调用tokenizer的方法时,有直接调用encode,有调用encode_plus,还有直接初始化tokenizer对象,那么他们之间有什么区别么?

encode与encode_plus的区别

encode方法:该方法用于将输入文本编码转换为模型输入的整数序列(input IDs)。它只返回输入文本的编码结果。
使用场景: 适用于单一文本序列的编码,例如一个问题或一段文本。

encode_plus方法:该方法除了生成整数序列(input IDs)外,还会生成注意力掩码(attention mask)、段落标记(segment IDs)等其他有用的信息,通常用于训练和评估中。返回一个字典,包含编码后的各种信息。
使用场景: 适用于处理多个文本序列,例如一个问题和一个上下文文本。

encode_plus与直接调用tokenizer对象本质上无区别:在 Hugging Face Transformers 库中,直接调用 tokenizer 对象和调用 tokenizer.encode 方法的本质是相同的,都是为了将文本转换为模型可接受的输入标识符序列。这两种方式实际上等效,都是通过 tokenizer 对象的编码方法完成的。

数据处理的具体例子

 在数据预处理过程中,不同的数据类型预处理的步骤不同,以huggingface中的squad数据集和yelp_review_full数据集为例,squad是从上下文context中寻找question的答案。yelp_review_full数据集是对一系列评论以及评论的分数数据。squad用于训练问答系统模型,yelp_review_full用于训练文本分类、情感分类模型。

squad数据集

yelp_review_full数据集

 下面以yelp_review_full为例子,看看如何完成数据预处理与模型微调训练。下面的代码是加载yelp_review_full的数据完成模型的微调。在数据预处理部分,调用tokenizer对象,将truncation设置为true,以及设置了padding="max_length".没有复杂的预处理过程。

from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, load_metric
import evaluate

# 1. 加载YelpReviewFull数据集
dataset = load_dataset("yelp_review_full")

# 2. 选择并加载BERT模型和标记器
model_name = "bert-base-uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=5)  # num_labels=5表示5种分类任务
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 3. 对原始数据进行标记化
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 4. 定义训练参数
training_args = TrainingArguments(
    output_dir="./yelp_review_model",  # 保存微调模型的目录
    per_device_train_batch_size=8,      # 每个设备的训练批次大小
    evaluation_strategy="steps",        # 在每个 steps 后进行评估
    eval_steps=500,                     # 每 500 个 steps 进行一次评估
    save_steps=500,                     # 每 500 个 steps 保存一次模型
    num_train_epochs=3,                 # 微调的轮数
    logging_dir="./logs"               # 保存训练日志的目录
)

# 5. 定义compute_metrics函数计算准确度
metric = evaluate.load("accuracy")
def compute_metrics(p):
    preds = p.predictions.argmax(axis=1)
    return metric.compute(predictions=preds, references=p.label_ids)

small_train_data=tokenized_datasets["train"].shuffle(seed=42).select(range(5000))
small_test_data=tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
# 6. 定义Trainer对象
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_data,
    eval_dataset=small_test_data,
#     train_dataset=tokenized_datasets["train"],
#     eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,   # 使用定义的compute_metrics函数
)

# 7. 微调BERT模型
trainer.train()

# 8. 输出评估结果
results = trainer.evaluate()
print("Results:", results)

因为只选取了部分数据进行训练,正确率是0.632.训练结果如下图所示:

对于用于训练问答系统模型的squad数据,预处理步骤会多一些,所以会在下一篇博客中做专门的介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/267097.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【贪心】单源最短路径Python实现

文章目录 [toc]问题描述Dijkstra算法Dijkstra算法应用示例时间复杂性Python实现 个人主页:丷从心 系列专栏:贪心算法 问题描述 给定一个带权有向图 G ( V , E ) G (V , E) G(V,E),其中每条边的权是非负实数,给定 V V V中的一个…

Packet Tracer -使用 Ping 和 Traceroute测试 网络的连通性

地址分配表 目标 第 1 部分:测试和恢复 IPv4 连通性 第 2 部分:测试和恢复 IPv6 连通性 场景 本练习中存在连通性方面的问题。除了 收集和记录有关网络的信息,您还需要找出 问题,并实施可行的解决方案来恢复网络的连通性。 注意…

c语言:求1/2+2/3+3/4+……n-1/n的和|练习题

一、题目 求1/22/33/4……n-1/n的和 如图&#xff1a; 二、思路分析 1、1/2、2/3、3/4……可以用(i/i1) 2、设置一个函数&#xff0c;求数的相加之和 三、代码截图【带注释】 四、源代码【带注释】 #include <stdio.h> int main() { int num; printf("输入…

【RabbitMQ】RabbitMQ详解(二)

RabbitMQ详解 死信队列死信来源消息TTL过期队列达到最大长度消息被拒绝 RabbitMQ延迟队列TTL的两种设置队列设置TTL消息设置TTL 整合SrpingBoot队列TTL延时队列TTL优化Rabbtimq插件实现延迟队列 死信队列 先从概念解释上搞清楚这个定义&#xff0c;死信&#xff0c;顾名思义就…

S7项目EMS输送线操作

C型钩装置是支撑轨道的挂件,通过和轨道配合可以组成寓任意输送网络。并且可以拆卸和调整。 轨道是承载重物并供载物车行走的部件,它是通过连接装置(支撑件)悬于辅梁或房架上。它分直轨和弯轨两种形式,与道岔配合,能组合成生产工艺所需的任意输送网络。 道岔是载物车沿 轨…

域内定位个人PC的三种方式(1)

会话搜集 在cmd下调用query session命令可以获得当前环境下的windows会话 NetSessionEnum 这个函数不允许直接查询是谁登陆&#xff0c;但是它允许查询是谁在访问此工作站的网络资源时所创建的网络会话&#xff0c;从而知道来自何处&#xff0c;此函数不需要高权限即可查询 第…

UE和Android互相调用

ue和android互调 这两种方式都是在UE打包的Android工程之上进行的。 一、首先是UE打包Android&#xff0c;勾选下面这项 如果有多个场景需要添加场景 工程文件在这个路径下 然后可以通过Android Studio打开&#xff0c;选择gradle打开 先运行一下&#xff0c;看看是否可以发布…

简述用C++实现SIP协议栈

SIP&#xff08;Session Initiation Protocol&#xff0c;会话初始协议&#xff09;是一个基于文本的应用层协议&#xff0c;用于创建、修改和终止多媒体会话&#xff08;如语音、视频、聊天、游戏等&#xff09;中的通信。SIP协议栈是实现SIP协议的一组软件模块&#xff0c;它…

【数据库系统概论】第3章-关系数据库标准语言SQL(1)

文章目录 3.1 SQL概述3.2 学生-课程数据库3.3 数据定义3.3.1 数据库定义3.3.2 模式的定义3.3.3 基本表的定义3.3.4 索引的建立与删除3.3.5 数据字典 3.1 SQL概述 动词 分类 三级模式 3.2 学生-课程数据库 3.3 数据定义 3.3.1 数据库定义 创建数据库 tips&#xff1a;[ ]表…

【数据结构入门精讲 | 第十七篇】一文讲清图及各类图算法

在上一篇中我们进行了的并查集相关练习&#xff0c;在这一篇中我们将学习图的知识点。 目录 概念深度优先DFS伪代码 广度优先BFS伪代码 最短路径算法&#xff08;Dijkstra&#xff09;伪代码 Floyd算法拓扑排序逆拓扑排序 概念 下面介绍几种在对图操作时常用的算法。 深度优先D…

uniapp自定义头部导航怎么实现?

一、在pages.json文件里边写上自定义属性 "navigationStyle": "custom" 二、在对应的index页面写上以下&#xff1a; <view :style"{ height: headheight px, backgroundColor: #24B7FF, zIndex: 99, position: fixed, top: 0px, width: 100% …

一起玩儿物联网人工智能小车(ESP32)——14. 用ESP32的GPIO控制智能小车运动起来(二)

摘要&#xff1a;本文主要讲解如何使用Mixly实现对单一车轮的运动控制。 下面就该用程序控制我们的小车轮子转起来了。打开Mixly软件&#xff0c;然后单击顶部“文件”菜单中的“新建”功能&#xff0c;我们来开启一个新程序的开发工作。 我们的工作同样是先从最简单的开始&am…

设计模式分类

不同设计模式的复杂程度、 细节层次以及在整个系统中的应用范围等方面各不相同。 我喜欢将其类比于道路的建造&#xff1a; 如果你希望让十字路口更加安全&#xff0c; 那么可以安装一些交通信号灯&#xff0c; 或者修建包含行人地下通道在内的多层互通式立交桥。 最基础的、 底…

视频编码码率控制

什么是码率控制 码率控制是编码器的一个重要模块&#xff0c;主要的作用就是用算法来控制编码器输出码流的大小。虽然它是编码器的一个非常重要的部分&#xff0c;但是它并不是编码标准的一部分&#xff0c;也就是说&#xff0c;标准并没有给码控设定规则。我们平时用的编码器…

50 个具有挑战性的概率问题 [04/50]:尝试直至首次成功

一、说明 你好&#xff0c;我最近对与概率相关的问题产生了兴趣。我偶然发现了 Frederick Mosteller 所著的《五十个具有挑战性的概率问题及其解决方案》这本书。我认为创建一个系列来讨论这些可能作为面试问题出现的迷人问题会很有趣。每篇文章仅包含 1 个问题&#xff0c;使其…

基于python的excel检查和读写软件

软件版本&#xff1a;python3.6 窗口和界面gui代码&#xff1a; class mygui:def _init_(self):passdef run(self):root Tkinter.Tk()root.title(ExcelRun)max_w, max_h root.maxsize()root.geometry(f500x500{int((max_w - 500) / 2)}{int((max_h - 300) / 2)}) # 居中显示…

Python学习路线 - Python语言基础入门 - Python基础综合案例 - 数据可视化 - 动态柱状图

Python学习路线 - Python语言基础入门 - Python基础综合案例 - 数据可视化 - 动态柱状图 基础柱状图构建案例效果通过Bar构建基础柱状图反转x和y轴数值标签在右侧 基础时间线柱状图绘制创建时间线创建时间线自动播放时间线设置主题 动态GDP柱状图绘制需求分析列表的sort方法带名…

分巧克力c语言

分析&#xff1a;分巧克力&#xff0c;把每一种大小列举出来&#xff0c;在对巧克力分解&#xff0c;在加上所以的分解块数&#xff0c;在和人数比较&#xff0c;如果够分&#xff0c;就保存这一次的结果&#xff0c;在增大巧克力&#xff0c;如果不够分了&#xff0c;就打印上…

「Verilog学习笔记」并串转换

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 串并转换操作是非常灵活的操作&#xff0c;核心思想就是移位。串转并就是把1位的输入放到N位reg的最低位&#xff0c;然后N位reg左移一位&#xff0c;在把1位输入放到左移后…

【并发设计模式】聊聊两阶段终止模式如何优雅终止线程

在软件设计中&#xff0c;抽象出了23种设计模式&#xff0c;用以解决对象的创建、组合、使用三种场景。在并发编程中&#xff0c;针对线程的操作&#xff0c;也抽象出对应的并发设计模式。 两阶段终止模式- 优雅停止线程避免共享的设计模式- 只读、Copy-on-write、Thread-Spec…