大模型微调---Prompt-tuning微调

目录

    • 一、前言
    • 二、Prompt-tuning实战
      • 2.1、下载模型到本地
      • 2.2、加载模型与数据集
      • 2.3、处理数据
      • 2.4、Prompt-tuning微调
      • 2.5、训练参数配置
      • 2.6、开始训练
    • 三、模型评估
    • 四、完整训练代码

一、前言

Prompt-tuning通过修改输入文本的提示(Prompt)来引导模型生成符合特定任务或情境的输出,而无需对模型的全量参数进行微调。
在这里插入图片描述
Prompt-Tuning 高效微调只会训练新增的Prompt的表示层,模型的其余参数全部固定,其核心在于将下游任务转化为预训练任务

在这里插入图片描述
新增的 Prompt 内容可以分为 Hard PromptSoft Prompt 两类:

  • Soft prompt 通常指的是一种较为宽泛或模糊的提示,允许模型在生成结果时有更大的自由度,通常用于启发模型进行创造性的生成;
  • Hard prompt 是一种更为具体和明确的提示,要求模型按照给定的信息生成精确的结果,通常用于需要模型提供准确答案的任务;

Soft Prompt 在 peft 中一般是随机初始化prompt的文本内容,而 Hard prompt 则一般需要设置具体的提示文本内容;

对于不同任务的Prompt的构建示例如下:
在这里插入图片描述

例如,假设我们有兴趣将英语句子翻译成德语。我们可以通过各种不同的方式询问模型,如下图所示。

1)“Translate the English sentence ‘{english_sentence}’ into German: {german_translation}”
2)“English: ‘{english sentence}’ | German: {german translation}”
3)“From English to German:‘{english_sentence}’-> {german_translation}”

上面说明的这个概念被称为硬提示调整

软提示调整(soft prompt tuning)将输入标记的嵌入与可训练张量连接起来,该张量可以通过反向传播进行优化,以提高目标任务的建模性能。

例如下方伪代码:

# 定义可训练的软提示参数
# 假设我们有 num_tokens 个软提示 token,每个 token 的维度为 embed_dim
soft_prompt = torch.nn.Parameter(
    torch.rand(num_tokens, embed_dim)  # 随机初始化软提示向量
)

# 定义一个函数,用于将软提示与原始输入拼接
def input_with_softprompt(x, soft_prompt):
    # 假设 x 的维度为 (batch_size, seq_len, embed_dim)
    # soft_prompt 的维度为 (num_tokens, embed_dim)
    # 将 soft_prompt 在序列维度上与 x 拼接
    # 拼接后的张量维度为 (batch_size, num_tokens + seq_len, embed_dim)
    x = concatenate([soft_prompt, x], dim=seq_len)
    return x

# 将包含软提示的输入传入模型
output = model(input_with_softprompt(x, soft_prompt))

  1. 软提示参数:

使用 torch.nn.Parameter 将随机初始化的向量注册为可训练参数。这意味着在训练过程中,soft_prompt 中的参数会随梯度更新而优化。

  1. 拼接输入:

函数 input_with_softprompt 接收原始输入 x(通常是嵌入后的 token 序列)和 soft_prompt 张量。通过 concatenate(伪代码中使用此函数代指张量拼接操作),将软提示向量沿着序列长度维度与输入拼接在一起。

  1. 传递给模型:

将包含软提示的输入张量传给模型,以引导模型在执行特定任务(如分类、生成、QA 等)时更好地利用这些可训练的软提示向量。

二、Prompt-tuning实战

预训练模型与分词模型——Qwen/Qwen2.5-0.5B-Instruct
数据集——lyuricky/alpaca_data_zh_51k

2.1、下载模型到本地

# 下载数据集
dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
ds = load_dataset("./data/alpaca_data", split="train")

# 下载分词模型
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
tokenizer.save_pretrained("./local_tokenizer_model")

#下载与训练模型
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下载模型的路径
    torch_dtype="auto",
    low_cpu_mem_usage=True,
    cache_dir="./local_model_cache"  # 指定本地缓存目录
)

2.2、加载模型与数据集

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")

# 加载数据集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")

# 记载模型
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",
    torch_dtype="auto",
    device_map="cuda:0")

2.3、处理数据

"""
并将其转换成适合用于模型训练的输入格式。具体来说,
它将原始的输入数据(如用户指令、用户输入、助手输出等)转换为模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    if example["output"] is not None:
        response = tokenizer(example["output"] + tokenizer.eos_token)
    else:
        return
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


# 分词
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)

2.4、Prompt-tuning微调

soft Prompt

# Soft Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化

Hard Prompt

# Hard Prompt
prompt = "下面是一段人与机器人的对话。"
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,
                            prompt_tuning_init_text=prompt,
                            num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),
                            tokenizer_name_or_path="../local_tokenizer_model")

加载peft配置

peft_model = get_peft_model(model, config)

print(peft_model.print_trainable_parameters())

在这里插入图片描述
可以看到要训练的模型相比较原来的全量模型要少很多

2.5、训练参数配置

# 配置模型参数
args = TrainingArguments(
    output_dir="chatbot",   # 训练模型的输出目录
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    num_train_epochs=1,
)

2.6、开始训练

# 创建训练器
trainer = Trainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds,
    data_collator=DataCollatorForSeq2Seq(tokenizer_model, padding=True )
)
# 开始训练
trainer.train()

可以看到 ,损失有所下降

在这里插入图片描述

三、模型评估

# 模型推理
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model = AutoModelForCausalLM.from_pretrained("../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775", low_cpu_mem_usage=True)
peft_model = PeftModel.from_pretrained(model=model, model_id="./chatbot/checkpoint-643")
peft_model = peft_model.cuda()

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")
ipt = tokenizer_model("Human: {}\n{}".format("我们如何在日常生活中减少用水?", "").strip() + "\n\nAssistant: ", return_tensors="pt").to(
    peft_model.device)
print(tokenizer_model.decode(peft_model.generate(**ipt, max_length=128, do_sample=True)[0], skip_special_tokens=True))

print("-----------------")
#预训练的管道流
# 构建prompt
ipt = "Human: {}\n{}".format("我们如何在日常生活中减少用水?", "").strip() + "\n\nAssistant: "
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer_model)
output = pipe(ipt, max_length=256, do_sample=True, truncation=True)
print(output)

训练了一轮,感觉效果不大,可以增加训练轮数试试
在这里插入图片描述

四、完整训练代码


from datasets import load_dataset
from peft import PromptTuningConfig, TaskType, PromptTuningInit, get_peft_model, PeftModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM, TrainingArguments, \
    DataCollatorForSeq2Seq, Trainer

# 下载数据集
# dataset_file = load_dataset("lyuricky/alpaca_data_zh_51k", split="train", cache_dir="./data/alpaca_data")
# ds = load_dataset("./data/alpaca_data", split="train")
# print(ds[0])

# 下载分词模型
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Save the tokenizer to a local directory
# tokenizer.save_pretrained("./local_tokenizer_model")

#下载与训练模型
# model = AutoModelForCausalLM.from_pretrained(
#     pretrained_model_name_or_path="Qwen/Qwen2.5-0.5B-Instruct",  # 下载模型的路径
#     torch_dtype="auto",
#     low_cpu_mem_usage=True,
#     cache_dir="./local_model_cache"  # 指定本地缓存目录
# )

#加载分词模型
tokenizer_model = AutoTokenizer.from_pretrained("../local_tokenizer_model")

# 加载数据集
ds = load_dataset("../data/alpaca_data", split="train[:10%]")

# 记载模型
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path="../local_llm_model/models--Qwen--Qwen2.5-0.5B-Instruct/snapshots/7ae557604adf67be50417f59c2c2f167def9a775",
    torch_dtype="auto",
    device_map="cuda:0")


# 处理数据
"""
并将其转换成适合用于模型训练的输入格式。具体来说,
它将原始的输入数据(如用户指令、用户输入、助手输出等)转换为模型所需的格式,
包括 input_ids、attention_mask 和 labels。
"""
def process_func(example, tokenizer=tokenizer_model):
    MAX_LENGTH = 256
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["Human: " + example["instruction"], example["input"]]).strip() + "\n\nAssistant: ")
    if example["output"] is not None:
        response = tokenizer(example["output"] + tokenizer.eos_token)
    else:
        return
    input_ids = instruction["input_ids"] + response["input_ids"]
    attention_mask = instruction["attention_mask"] + response["attention_mask"]
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }


# 分词
tokenized_ds = ds.map(process_func, remove_columns=ds.column_names)




prompt = "下面是一段人与机器人的对话。"

# prompt-tuning
# Soft Prompt
# config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=10) # soft_prompt会随机初始化
# Hard Prompt
config = PromptTuningConfig(task_type=TaskType.CAUSAL_LM, prompt_tuning_init=PromptTuningInit.TEXT,
                            prompt_tuning_init_text=prompt,
                            num_virtual_tokens=len(tokenizer_model(prompt)["input_ids"]),
                            tokenizer_name_or_path="../local_tokenizer_model")

peft_model = get_peft_model(model, config)

print(peft_model.print_trainable_parameters())

# 训练参数

args = TrainingArguments(
    output_dir="./chatbot",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    logging_steps=10,
    num_train_epochs=1
)

# 创建训练器
trainer = Trainer(model=peft_model, args=args, train_dataset=tokenized_ds,
                  data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer_model, padding=True))


# 开始训练
trainer.train()




本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/939174.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

spring学习(spring-bean实例化(实现FactoryBean规范)(延迟实例化bean))

目录 一、spring容器实例化bean的几种方式。 (1)无参构造与有参构造方法。 (2)静态工厂与实例工厂。 (3)实现FactoryBean规范。 二、spring容器使用实现FactoryBean规范方式实现bean实例化。 (1…

【Java】mac安装Java17(JDK17)

文章目录 下载java17一、安装二、环境变量 下载java17 官网下载:https://www.oracle.com/java/technologies/downloads 提示:以下是本篇文章正文内容,下面案例可供参考 一、安装 直接安装后,安装完后目录会存放在下面目录下 /…

mysql免安装版配置教程

一、将压缩包解压至你想要放置的文件夹中,注意:绝对路径中要避免出现中文 二、在解压目录下新建my.ini文件,已经有的就直接覆盖 my.ini文件内容 [mysqld] # 设置3306端口 port3306 # 设置mysql的安装目录 basedirD:\\tools\\mysql-8.1.0-win…

web:pc端企业微信登录-vue版

官方文档:developer.work.weixin.qq.com/document/pa… 不需要调用ww.register,直接调用ww.createWWLoginPanel即可创建企业微信登录面板 - 文档 - 企业微信开发者中心 (qq.com) 引入 //通过 npm 引入 npm install wecom/jssdk import * as ww from we…

基于Spring Boot的无可购物网站系统

一、系统背景与意义 随着互联网的快速发展,电子商务已经成为人们日常生活的重要组成部分。构建一个稳定、高效、可扩展的电商平台后端系统,对于满足用户需求、提升用户体验、推动业务发展具有重要意义。Spring Boot作为当前流行的Java开发框架&#xff…

YOLOv11改进,YOLOv11添加DLKA-Attention可变形大核注意力,WACV2024 ,二次创新C3k2结构

摘要 作者引入了一种称为可变形大核注意力 (D-LKA Attention) 的新方法来增强医学图像分割。这种方法使用大型卷积内核有效地捕获体积上下文,避免了过多的计算需求。D-LKA Attention 还受益于可变形卷积,以适应不同的数据模式。 理论介绍 大核卷积(Large Kernel Convolu…

fpga系列 HDL:Quartus II PLL (Phase-Locked Loop) IP核 (Quartus II 18.0)

在 Quartus II 中使用 PLL (Phase-Locked Loop) 模块来将输入时钟分频或倍频,并生成多个相位偏移或频率不同的时钟信号: 1. 生成 PLL 模块 在 Quartus II 中: 打开 IP Components。 file:///C:/intelFPGA_lite/18.0/quartus/common/help/w…

JAVA:代理模式(Proxy Pattern)的技术指南

1、简述 代理模式(Proxy Pattern)是一种结构型设计模式,用于为其他对象提供一种代理,以控制对这个对象的访问。通过代理模式,我们可以在不修改目标对象代码的情况下扩展功能,满足特定的需求。 设计模式样例:https://gitee.com/lhdxhl/design-pattern-example.git 2、什…

3D Gaussian Splatting for Real-Time Radiance Field Rendering-简洁版

1. 研究背景与问题 传统的3D场景表示方法,如网格和点云,适合GPU加速的光栅化操作,但缺乏灵活性。而基于神经辐射场(NeRF)的表示方式,尽管质量高,但需要高成本的训练和渲染时间。此外&#xff0…

Unity性能优化---使用SpriteAtlas创建图集进行批次优化

在日常游戏开发中,UI是不可缺少的模块,而在UI中又使用着大量的图片,特别是2D游戏还有很多精灵图片存在,如果不加以处理,会导致很高的Batches,影响性能。 比如如下的例子: Batches是9&#xff0…

计算机网络知识点全梳理(三.TCP知识点总结)

目录 TCP基本概念 为什么需要TCP 什么是TCP 什么是TCP链接 如何唯一确定一个 TCP 连接 TCP三次握手 握手流程 为什么是三次握手,而不是两次、四次 为什么客户端和服务端的初始序列号 ISN 不同 既然 IP 层会分片,为什么 TCP 层还需要 MSS TCP四…

找出1000以内的所有回文数

找出1000以内的所有回文数 方法概述检查回文数的方法伪代码C代码实现代码解析运行结果在计算机科学中,回文数是一种具有对称性质的数,即从左向右读和从右向左读都是相同的。例如,121、1331、12321都是回文数。本文将利用数据结构、C语言和算法的知识来编写一个程序,找出100…

Go-FastDFS文件服务器一镜到底使用Docker安装

本文章介绍一镜到底安装go-fastdfs并配置数据文件到linux 由于国内镜像无法安装go-fastdfs:国内环境已经把docker官方的网站给封闭了 我们需要从国外的一台服务器,下载images镜像,然后进行转发加载到国内服务器 当然我也给你们准备了docke…

ArcGIS计算土地转移矩阵

在计算土地转移矩阵时,最常使用的方法就是在ArcGIS中将土地利用栅格数据转为矢量,然后采用叠加分析计算,但这种方法计算效率低。还有一种方法是采用ArcGIS中的栅格计算器,将一个年份的地类编号乘以个100或是1000再加上另一个年份的…

软路由系统 --- VMware安装与配置OpenWRT

目录: OpenWrt安装与配置 一、下载OpenWRT映像二、img转换vmdk格式三、创建虚拟机(OpenWRT)四、启动系统(OpenWRT)五、配置网络信息(LAN、WAN)六、登录OpenWRT后台管理界面 一、下载OpenWRT映像 OpenWrt映…

利用notepad++删除特定关键字所在的行

1、按组合键Ctrl H,查找模式选择 ‘正则表达式’,不选 ‘.匹配新行’ 2、查找目标输入 : ^.*关键字.*\r\n (不保留空行) ^.*关键字.*$ (保留空行)3、替换为:(空) 配置界面参考下图: ​​…

Moretl开箱即用日志采集

永久免费: 至Gitee下载 使用教程: Moretl使用说明 使用咨询: 用途 定时全量或增量采集工控机,电脑文件或日志. 优势 开箱即用: 解压直接运行.不需额外下载.管理设备: 后台统一管理客户端.无人值守: 客户端自启动,自更新.稳定安全: 架构简单,兼容性好,通过授权控制访问. 架…

Java中的全局异常捕获与处理

引言 Java编程中,异常情况是不可避免的,优秀的异常处理机制是保证程序稳定性和可靠性的重要因素,当程序出现异常的时候,如果没有进行适当的处理,可能导致程序崩溃,丢失数据等严重问题。全局异常捕获与处理可…

基于Spring Boot的小区车辆管理系统

一、系统背景与目的 随着城市化进程的加快,小区内的车辆数量急剧增加,车辆管理问题日益凸显。传统的车辆管理方式存在效率低、易出错、信息不透明等问题。为了解决这些问题,基于Spring Boot的小区车辆管理系统应运而生。该系统旨在通过信息化…

Mapbox-GL 的源码解读的一般步骤

Mapbox-GL 是一个非常优秀的二三维地理引擎,随着智能驾驶时代的到来,应用也会越来越广泛,关于mapbox-gl和其他地理引擎的详细对比(比如CesiumJS),后续有时间会加更。地理首先理解 Mapbox-GL 的源码是一项复…