从零学习大模型(九)-----P-Tuning(下)

代码展示P-Tuning的全过程

import torch
from torch import nn
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# 1. 数据准备
dataset = load_dataset("imdb")

# 2. 构建提示
def add_prompt(examples):
    examples['text'] = ["这段文本的情感是:'{}'".format(text) for text in examples['text']]
    return examples

dataset = dataset.map(add_prompt)

# 3. 模型选择
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# 4. 添加可训练的嵌入向量
class PromptEmbedding(nn.Module):
    def __init__(self, prompt_length, embedding_dim):
        super(PromptEmbedding, self).__init__()
        self.prompt_embedding = nn.Parameter(torch.randn(prompt_length, embedding_dim))

    def forward(self, x):
        prompt = self.prompt_embedding.unsqueeze(0).repeat(x.size(0), 1, 1)  # 扩展到batch大小
        return torch.cat((prompt, x), dim=1)

# 定义新模型
class P_Tuning_BERT(nn.Module):
    def __init__(self, base_model, prompt_length):
        super(P_Tuning_BERT, self).__init__()
        self.base_model = base_model
        self.prompt_embedding = PromptEmbedding(prompt_length, base_model.bert.config.hidden_size)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # 获取原始的输入嵌入
        embeddings = self.base_model.bert.embeddings(input_ids)
        # 添加prompt嵌入
        embeddings = self.prompt_embedding(embeddings)
        outputs = self.base_model.bert(inputs_embeds=embeddings, attention_mask=attention_mask)
        logits = self.base_model.classifier(outputs[1])  # 只取池化输出
        return (logits,)

# 设置P-Tuning模型
prompt_length = 5  # Prompt的长度
p_tuning_model = P_Tuning_BERT(model, prompt_length)

# 冻结原模型参数
for param in p_tuning_model.base_model.parameters():
    param.requires_grad = False

# 5. 数据预处理
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, padding=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

# 6. 微调过程
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=p_tuning_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
)

# 7. 训练模型
trainer.train()

# 8. 测试模型
trainer.evaluate()

# 9. 应用模型
def predict(text):
    p_tuning_model.eval()
    inputs = tokenizer("这段文本的情感是:'{}'".format(text), return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = p_tuning_model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
    logits = outputs[0]
    predicted_class = torch.argmax(logits, dim=-1)
    return "积极" if predicted_class.item() == 1 else "消极"

# 测试应用
print(predict("这家餐厅的服务很好。"))

P-Tuning的实验结果

文本分类任务

  • P-Tuning:在小数据集上,F1-score为0.85,训练时间为1小时。
  • 全参数微调:在相同数据集上,F1-score为0.88,训练时间为3小时,但在验证集上过拟合(F1-score为0.80)。

对话生成任务

  • P-Tuning:生成的回复自然性评分为4.2/5,训练时间为2小时。
  • 全参数微调:生成的回复自然性评分为4.5/5,但训练时间为5小时。

P-Tuning的优点

1. 计算效率高

  • 参数更新少:P-Tuning仅更新与提示相关的嵌入向量,减少了训练过程中需要优化的参数数量。这意味着在同样的计算资源下,可以更快速地进行实验和模型调整。

2. 减少过拟合风险

  • 冻结预训练模型的参数:通过冻结大部分模型参数,P-Tuning降低了在小数据集上过拟合的风险。对于数据量有限的任务,P-Tuning能够更好地泛化。

3. 灵活性和适应性强

  • 任务适应性:可以通过简单地调整提示内容来适应不同的任务,无需修改整个模型架构。这使得在多任务场景中,P-Tuning能够快速切换和调整。
  • Prompt设计自由:研究者可以根据具体任务设计不同的提示,以探索对模型性能的影响。这种灵活性允许在多个任务之间共享同一模型,而只需修改提示。

4. 易于实现和部署

  • 实现简单:相较于全参数微调,P-Tuning的实现更加简便,尤其是在不需要重新训练整个模型的情况下。只需在输入中添加提示即可。
  • 资源需求低:由于只更新部分参数,P-Tuning对计算资源的需求较低,可以在较小的硬件上进行训练和部署。

5. 在小数据集上的表现良好

  • 数据效率高:P-Tuning特别适用于小数据集场景,在这些场景中,训练整个模型可能导致性能下降,而P-Tuning可以利用预训练的知识,有效提升模型的性能。

6. 提升模型的可解释性

  • 可解释性增强:由于P-Tuning强调了提示的作用,研究者可以更清晰地理解模型如何通过特定提示来做出不同的决策。这对于分析模型的行为和结果非常有帮助。

7. 迁移学习效果好

  • 知识迁移:P-Tuning能够有效地利用预训练模型中存储的知识,通过适当的提示,将这种知识迁移到新任务中。这使得在许多下游任务中,P-Tuning能够实现与全参数微调相当甚至更好的性能。

P-Tuning的局限性

1. 提示设计的依赖性

  • 提示的有效性:P-Tuning的性能高度依赖于提示的设计和选择。不同的提示可能会导致模型产生不同的预测结果。如果提示设计不当,可能会影响模型的理解和预测能力。
  • 提示选择的挑战:设计有效的提示需要领域知识和经验,这对于非专业人士来说可能是一个挑战。

2. 学习到的提示嵌入的复杂性

  • 提示嵌入的可解释性:虽然P-Tuning提供了一定的可解释性,但学习到的提示嵌入的具体意义和如何影响模型决策可能仍然不够清晰。研究者可能难以解读这些嵌入的具体作用。
  • 相似性问题:不同任务或数据集可能会导致提示嵌入相似性较高,导致模型在迁移到新任务时表现不佳。

3. 数据集和任务的限制

  • 适用性问题:P-Tuning在小数据集上表现良好,但在大规模和复杂任务中,可能无法完全发挥预训练模型的潜力。在某些情况下,全参数微调可能仍然是更优的选择。
  • 数据分布差异:如果训练和测试数据的分布差异较大,P-Tuning的效果可能受到影响,特别是如果提示未能充分捕捉任务的关键特征。

4. 对训练资源的需求

  • 额外的训练时间:尽管P-Tuning的训练参数较少,但学习提示嵌入仍然需要一定的训练时间和计算资源。在资源有限的情况下,可能仍需权衡使用全参数微调与P-Tuning的选择。

5. 任务特定性

  • 领域适应性:某些领域的特定任务可能不适合使用P-Tuning,尤其是在需要高度专业化的知识和上下文理解的情况下。全参数微调可能更好地适应这些特定的领域。

6. 模型性能的极限

  • 性能瓶颈:由于只更新部分参数,P-Tuning在某些情况下可能无法突破预训练模型的性能极限。在需要极高性能的任务中,全参数微调可能更能挖掘模型的潜力。

P-Tuning的未来发展方向

1. 大规模模型的适应性

  • 模型架构的调整:为适应更大规模的模型,P-Tuning可以通过调整提示嵌入的维度和数量来保持与模型的对齐。这意味着需要为每个新任务设计适当的嵌入结构。
  • 分层提示:对于大型模型,可以设计分层的提示结构,允许在不同层次上进行信息传递,从而使模型更有效地利用提示信息。

2. 多任务学习

  • 共享提示嵌入:在多任务设置中,可以设计共享的提示嵌入,以便在不同任务之间传递信息。这有助于提高模型的训练效率,并减少为每个任务单独训练提示的需求。
  • 动态提示调整:利用动态生成的提示来适应不同任务的需求。通过实时分析任务特征,生成适合特定任务的提示,从而增强模型的适应性。

3. 增强训练方法

  • 自适应学习率:为不同任务的提示嵌入设置不同的学习率,以便更好地适应每个任务的特性。这可以通过监控每个任务的性能来动态调整学习率。
  • 数据增强:结合数据增强技术,在训练过程中引入多样化的训练样本,从而提高模型在新任务和大规模数据集上的泛化能力。

4. 集成方法

  • 与其他技术结合:将P-Tuning与其他微调技术(如LoRA、Adapter等)结合使用,可以进一步提升模型的性能。这些技术可以帮助在不大幅增加模型参数的情况下,增强模型对新任务的适应性。
  • 知识蒸馏:通过知识蒸馏技术,将大型模型的知识迁移到较小的模型中,同时利用P-Tuning进行微调,可以在资源有限的情况下实现较好的性能。

5. 任务定制化

  • 针对性任务提示设计:针对特定任务或领域设计专门的提示嵌入,以确保它们能有效捕捉任务特征。这可能包括对领域特定的语言和上下文的理解。
  • 领域适应性:在特定领域(如医疗、法律等)中,通过细化提示以增强对领域术语和上下文的理解,提升模型在特定领域任务上的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/904404.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

中间件安全(三)

本文仅作为学习参考使用,本文作者对任何使用本文进行渗透攻击破坏不负任何责任。 前言: 本文主要讲解apache命令执行漏洞(cve_2021_41773)。 靶场链接:Vulfocus 漏洞威胁分析平台 一,漏洞简介。 cve_2021_41773漏洞…

双十一我都入手了啥大件?这几款超值好物分享给你

​马上就到一年一度的“双11”大促,简单与大家分享,最近自己买过或者是看好的生活好物。以数码为主,平常的一点生活会提及一些。 耳机党必备,听歌不伤耳朵!——南卡OE MIX开放式耳机 一句话推荐:百元旗舰…

PHP海外矿物矿机理财投资源码-金融理财投资源码

PHP海外矿物矿机理财投资源码/金融理财投资源码 海外矿物矿机理财投资源码 测试不错,可以做其他产品理财,功能都没啥太大问题

持续更新...记录

一、Random类 1、构造方法: ①有参:通过指定种子数进行创建 (使用相同的种子数创建多个Random对象,这些对象生成的随机数序列将完全相同‌) (适用于需要可重复生成相同随机数序列的场景,如科学…

Redis项目中应用

1. Redis简介 Redis是一个基于内存的key-value结构数据库。Redis 是互联网技术领域使用最为广泛的存储中间件。 官网:https://redis.io 中文网:Redis中文网 2. Redis下载与安装 2.1 Redis下载 Redis安装包分为windows版和Linux版: Wind…

cursor连接远程jupyter

cursor的步骤跟vscode应该是基本一样的,主要需要两个插件,一个是remote-ssh,另一个是jupyter 第一步 首先连接远程的ssh,因为我已经新建好了,所以直接选207,没有连接过的就选Add New SSH Host&#xff…

9款热门CRM客户关系管理系统大盘点

在当今竞争激烈的商业环境中,客户关系管理(CRM)系统已成为企业不可或缺的工具。CRM系统不仅帮助企业管理客户信息,还能提高销售效率、改善客户服务、增强客户满意度。本文将为您盘点9款热门的CRM客户关系管理系统,并重…

IMX6ULL裸机-汇编_反汇编_机器码

程序处理的4个步骤 我们编写的C程序是不能直接在ARM等平台上运行的,必须经过一系列的程序处理才可以,我们的第一个LED程序涉及两个文件:start.S、main.c,它们的处理过程如下: 对于汇编程序,经过汇编之后&a…

【Unity】游戏UI中添加粒子特效导致穿层问题的解决

这里介绍一下简易的ui系统中,添加粒子特效导致的穿层问题 首先是在ui界面中添加粒子特效预制体,这个时候,控制这个粒子显示层级的有两个方面 上图中,如果你的Sorting Layer ID的值(Layer排序)是大于当前C…

SAP 根据不同生产版本创建销售预测简介

SAP 根据不同生产版本创建销售预测简介 业务场景前台操作1、创建BOM2、创建工艺路线3、创建生产版本4、创建销售预测5、调整销售预测6、查看物料需求业务场景 很多工厂一个物料可能会存在多个BOM,当有多个BOM存在的情况下就会存在多个生产版本,当创建计划独立需求的时候,系…

STM32 RTC 驱动代码(解决了使用HAL库函数导致的复位或者掉电后导致RTC年月日日期清零的问题)

问题背景:在RTC中断里面使用HAL库HAL_RTC_GetDate()和HAL_RTC_GetTime()来获取RTC时间日期。 源码如下图: 问题描述:单片机断电或者复位后的时分秒的时间可以接上,但年月日的日期就会被清零。如图: 导致问题的根本原因…

SpringBoot3+SpringSecurity6基于若依系统整合自定义登录流程

SpringBoot3SpringSecurity6基于若依系统整合自定义登录流程 问题背景 在做项目时遇到了要对接统一认证的需求,但是由于框架的不兼容性(我们项目是springboot3,jdk17,springsecurity6.1.5)等因素,不得不使…

Mount Image Pro,在取证安全的环境中挂载和访问镜像文件内容

天津鸿萌科贸发展有限公司从事数据安全服务二十余年,致力于为各领域客户提供专业的数据恢复、数据备份解决方案与服务,并针对企业面临的数据安全风险,提供专业的相关数据安全培训。 天津鸿萌科贸发展有限公司是 GetData 公司数据恢复与取证工…

PHP合成图片,生成海报图,poster-editor使用说明

之前写过一篇使用Grafika插件生成海报图的文章,但是当我再次使用时,却发生了错误,回看Grafika文档,发现很久没更新了,不兼容新版的GD,所以改用了intervention/image插件来生成海报图。 但是后来需要对海报…

React 前端框架全面教程:从入门到进阶

React 前端框架全面教程:从入门到进阶 引言 在现代前端开发中,React 作为一款流行的 JavaScript 库,以其组件化、声明式的特性和强大的生态系统,成为了开发者的首选。无论是构建单页应用(SPA)还是复杂的用…

基于Python的自然语言处理系列(42):Token Classification(标注分类)

在本篇文章中,我们将探讨如何进行 Token Classification(标注分类),这是一类为句子中的每个 token(词或子词)分配标签的任务。该任务可以解决很多问题,例如命名实体识别(NER&#xf…

用Pyhon写一款简单的益智类小游戏——2048

文字版——代码及讲解 代码—— import random# 初始化游戏棋盘 def init_board():return [[0] * 4 for _ in range(4)]# 在棋盘上随机生成一个2或4 def add_new_tile(board):empty_cells [(i, j) for i in range(4) for j in range(4) if board[i][j] 0]if empty_cells:i,…

『Linux学习笔记』如何在 Ubuntu 22.04 上安装和配置 VNC

『Linux学习笔记』如何在 Ubuntu 22.04 上安装和配置 VNC 文章目录 一. 『Linux学习笔记』如何在 Ubuntu 22.04 上安装和配置 VNC1. 介绍 二. 参考文献 一. 『Linux学习笔记』如何在 Ubuntu 22.04 上安装和配置 VNC 如何在 Ubuntu 22.04 上安装和配置 VNC 1. 介绍 虚拟网络计算…

【Java】方法的使用 —— 语法要求、方法的重载和签名、方法递归

目录 1. 方法基础知识 1.1 方法的概念 1.2 语法格式 * 注意事项【与C不同】 1.3 return —— 返回值的严格检查【比C语言严格】 2. 形参与实参的关系 3. 方法重载 3.1 什么是方法重载?为什么要方法重载? 3.2 方法重载的规则 4. 方法签名 5. 递…

HT7178 带输出关断的20V,14A全集成同步升压转换器

1、特点 输入电压范围VpIN:2.7V-20V 输出电压范围VouT:4.5V-20V 可编程峰值电流:14A 高转换效率: 95%(VPIN7.2V, VoUT 16V, IouT3A) 94%(VPIN12V,VoUT18V,IoUT4A) 90%(VPIN3.3, VoUT-9V,IOUT3A) 轻载条件下两种调制方式:脉频调制(PFM)和 强制脉宽调试(PWM) 集成输出关断的栅极…