lora微调过程

import os
import pickle
from transformers import AutoModelForCausalLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType


device = "cuda:0"

#1.创建lora微调基本的配置
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

 2.通过调用get_peft_model方法包装基础的Transformer模型

#通过调用get_peft_model方法包装基础的Transformer模型
model = AutoModelForCausalLM("/root/paddlejob/workspace/llama-2-7b-chat")
model = get_peft_model(model, peft_config)

下面是lora微调的模型结构,可以看到多了两个矩阵,一个降维一个升维 

3.训练

# optimizer and lr scheduler
'''len(train_dataloader) 是训练数据集中的批次数量,num_epochs 是训练过程中的迭代次数。因此,len(train_dataloader) * num_epochs 表示整个训练过程中的总迭代次数,即总共要遍历训练数据集的批次数'''
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),)

#training and evaluation
model = model.to('cuda:0')
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to('cuda:0') for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss = total_loss + loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
   for step, batch in enumerate(eval_dataloader):
        batch = {k: v.to('cuda:0') for k, v in bacth.items()}
        with torch.nograd():
            outputs = model(**bacth)
        loss = outputs.loss
        eval_loss = eval_loss + loss.detach().float()
        eval_preds.extend(
            tokenizer.bacth_decode(torch.argmax(outputs.logits, -1), skip_special_tokens=True)
            )
    
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
        

'''
在 Python 中,** 运算符用于将字典解包为关键字参数传递给函数或方法。在 PyTorch 中,model(**batch) 中的 **batch 将字典 batch 中的键值对作为关键字参数传递给模型的方法(通常是前向传播方法)。

具体来说,**batch 将字典 batch 中的每个键值对解包为一组关键字参数。例如,如果 batch 字典包含键值对 {'input_ids': tensor1, 'attention_mask': tensor2},那么 model(**batch) 实际上就等价于 model(input_ids=tensor1, attention_mask=tensor2)。

这种方式可以方便地将字典中的数据传递给函数或方法,并且使代码更加简洁和易读。
'''

'''
loss.detach().float() 的作用是将计算图中的 loss 张量分离出来并转换为浮点数类型。

具体来说:

loss.detach() 会创建一个新的张量,其值与 loss 相同,但不再跟踪梯度信息。这样做是因为在训练过程中,我们通常只需要保存当前步骤的损失值,而不需要其相关的计算图和梯度信息。
.float() 将张量转换为浮点数类型。这是因为通常情况下,损失值是作为浮点数来计算和累加的。
所以,total_loss += loss.detach().float() 的作用就是将当前步骤的损失值添加到总损失值中,保证总损失值是一个浮点数。
'''

'''
在使用 PyTorch 进行梯度下降优化时,optimizer.zero_grad() 的作用是将模型参数的梯度归零,以便进行新一轮的梯度计算和更新。这是因为在 PyTorch 中,每次调用 .backward() 方法都会累积梯度,而不是覆盖之前的梯度。因此,在每次迭代更新参数之前,需要先将之前的梯度清零,以免影响当前迭代的梯度计算。

简而言之,optimizer.zero_grad() 用于初始化梯度,确保每次迭代都是基于当前 batch 的梯度计算和参数更新,而不会受到之前迭代的影响。
'''

'''

outputs.logits 是模型生成的原始输出,通常是一个三维张量,其中包含了模型对于每个词汇的得分(未经过 softmax 处理)。在语言模型中,这个张量的维度通常是 (batch_size, sequence_length, vocab_size),其中 batch_size 表示批量大小,sequence_length 表示每个序列的长度,vocab_size 表示词汇表的大小。
在生成文本任务中,outputs.logits 的每个元素表示模型在当前位置生成每个词汇的得分。通常,需要对这些得分进行 softmax 处理以获得每个词汇的概率分布,然后根据概率分布进行采样或选择最高概率的词汇作为模型生成的下一个词。

torch.argmax 是 PyTorch 库中的一个函数,用于返回张量中指定维度上的最大值的索引。具体而言,对于一个输入张量,torch.argmax(input, dim=None, keepdim=False) 函数将返回指定维度 dim 上最大值的索引。如果不指定 dim,则默认返回整个张量中最大值的索引。
例如,对于一个形状为 (batch_size, seq_length, vocab_size) 的张量,torch.argmax(outputs.logits, -1) 将返回在 vocab_size 维度上每个位置上的最大值对应的索引,即得分最高的词的索引。

这行代码的作用是将模型输出的 logits (对应每个词的得分)经过 torch.argmax 函数找到得分最高的词的索引,然后使用 tokenizer 对这些索引进行解码,将索引转换为对应的词,并通过 skip_special_tokens=True 参数去除特殊标记(如 [CLS], [SEP] 等)。最终得到的是模型生成的文本内容。
'''

 4.模型保存

#save model
peft_model_id = f"{model_name_path}_{peft_config.peft_type}_{peft_config.peft.task_type}"
model.save_pretrained(peft_model_id)

5.模型训练的其余部分无需更改,当模型训练完成后,保存高效微调的模型权重部分以供模型推理 

#加载微调后的权重
from peft import PeftModel, PeftConfig

config = PeftConfig.from_pretrained(peft_model_id)
##加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
##加载peft模型
model = PeftModel.from_pretrained(model, peft_model_id)

##加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokennizer.pad_token = tokenizer.eos_token

6.加载微调后的权重文件,并进行推理 

#利用微调后的模型进行推理
##tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")

##模型推理
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=10,
    eos_token_id=3
)

##tokenizer解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/529644.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

记一次SQL优化

问题描述: 原本执行此查询,需要占用546G内存数据, 但经过与实施人员沟通,以及对于业务的排查 (精简SQL,站在业务的角度优化SQL) 去掉排序功能(运维,及生产人员可接受&am…

HarmonyOS 开发-应用异常处理案例

介绍 本示例介绍了通过应用事件打点hiAppEvent获取上一次应用异常信息的方法,主要分为应用崩溃、应用卡死以及系统查杀三种。 效果图预览 使用说明: 点击构建应用崩溃事件,3s之后应用退出,然后打开应用进入应用异常页面&#x…

Maven与Jave web结构

Maven 简介 https://www.liaoxuefeng.com/wiki/1252599548343744/1255945359327200 java web module web目录 –src 应用程序源代码和测试程序代码的根目录 –main –java  应用程序源代码目录     --package1     --class1     --class2 –resources  应用…

P8707 [蓝桥杯 2020 省 AB1] 走方格

原题链接:[蓝桥杯 2020 省 AB1] 走方格 - 洛谷 目录 1.题目描述 2.思路分析 3.代码实现 1.题目描述 2.思路分析 题目大意:现在有个人站在第 1 行第 1 列,要走到第 i 行第 j 列(每次只能向右或者向下走)&#xff0…

Linux操作系统(六):文件系统组件

参考资料:阿秀的笔记 文件系统 1. 文件系统的基本组成2. 文件的使用3.文件如何存储3.1 目录怎么存储 4.Linux继承于Unix系统的Unix文件实现方式4.1 Linux Ext 2/3 文件系统4.2 Linux Ext 4 文件系统4.3 磁盘空闲空间的管理机制4.3.1 空闲表法4.3.2 空闲链表法4.3.3…

js语法---简单理解promise

promise语法结构 创建一个promise对象 let p new Promise(function(resolve,reject){// 执行的操作...// 判断操作的结果并执行对应的回调函数if(){resolve()}else{reject()} } 以上实例化了一个promise对象,其中包含了一个参数function,这个函数会在…

从二维数组到一维数组——探索01背包问题的动态规划优化

文章目录 题目前知背包问题 二维dp数组一、思路二、解题方法三、Code 一维dp数组一、思路二、解题方法三、Code 总结 本文将继续上一篇博客爬楼梯之后继续讲解同样用到了动态规划的 01背包问题 在解决动态规划问题时,我们经常面临着空间复杂度的挑战。01背包问题是…

书生·浦语大模型-第三节课笔记/作业

笔记 作业 原版 prompt控制节奏,实现类似关键词检索、主题、信息抽取等功能注意这里根据llm返回的topic (prompt: 告诉我这句话的主题,直接说主题不要解释)进行召回检索(CacheRetriever), 并再次让大模型判断query与返回的检索的相关程度. 如果本地检索…

【工具-工具指南】

项目-开发工具 ■ 编辑器■ Xmind ■ UI交互设计■ AxureRP9 ■ 项目管理■ boardmix■ excalidraw ■ Markdown■ MarkText■ Typora■ Ulysses■ Notable■ VNote■ Mou■ Bears■ Notion■ 有道云■ 印象笔记 ■ 硬件画图■ AD■ Allegro■ PADS■ Eagle■ Altium■ Fritzin…

保研线性代数复习4

一.范数(Norms) 1.什么是范数? 范数是一个向量空间V的函数,每一个属于向量空间V的向量x都匹配了一个实数(它的长度): 2.范数的性质? 齐次性: 正定性: 三…

SpringBoot整合MyBatis四种常用的分页方式

目录 方式1 一、准备工作 1. 创建表结构 2. 导入表数据 3. 导入pom.xml依赖 4. 配置application.yml文件 5. 创建公用的实体类 项目结构 2. 创建controller层 3. 创建service层 4. 创建mapper层 5. 创建xml文件 6. 使用postman进行测试,测试结果如下…

第6章 6.1.1 文本格式化 sprintf函数(MATLAB入门课程)

sprintf函数源自 C 语言标准库中的同名函数,这个函数在 C 语言中用于创建格式化的字符串,且使用频率非常高。作为一门高级编程语言,MATLAB借鉴了 C 语言和其他编程语言中的许多特性和命名惯例。在MATLAB中,sprintf函数主要有两种用…

学习记录14-运算放大器2

目录 前言 一、理想放大器 二、虚断 二、虚短 虚短的两个使用条件 1.虚短概念 2.如果我们将运放的同相端和反相端颠倒会怎样呢? 总结 前言 主要讲述运算放大器的虚短虚断 一、理想放大器 如果没有基础或只是想简单了解,可以看我前一篇文章&am…

数学基础:常见函数图像

来自: https://www.desmos.com/calculator/l3u8133jwj?langzh-CN 推荐: https://www.shuxuele.com/index.html 一、三角函数 1.1 正弦 sin(x) 1.2 余弦 cos(x) 1.3 正切 tan(x) 1.4 余切 cot(x) 1.5 正弦余弦综合 1.6 正切余切综合 二、指数对数

【数据结构与算法】力扣 19. 删除链表的倒数第 N 个结点

题目描述 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。 示例 1: 输入: head [1,2,3,4,5], n 2 输出: [1,2,3,5]示例 2: 输入: head [1], n 1 输出: []示例…

[方案实操|数据技术]数据要素十大创新模式(3):深数所-数据交易动态合规体系

“ 推动数据要素更好发展,政策创新是前提,数据质量管理是基础,安全和隐私保护是关键,合规性遵循是条件,数据共享和交易平台是手段。” 数据要素十大创新模式系列文章。 [方案实操|数据技术]数据要素十大创新模式(1)&a…

uniapp 表单使用Uview校验 包括城市选择器

<view><!-- 注意&#xff0c;如果需要兼容微信小程序&#xff0c;最好通过setRules方法设置rules规则 --><u--form labelPosition"left" :model"model1" :rules"rules" ref"uForm" labelWidth"174"><u…

生产端消息可靠性保证: 确认(Confirm)机制

1.PostConstruct注解 PostConstruct注解是Java EE规范中的一部分&#xff0c;主要用于标记在一个Bean初始化完成后需要执行的方法。这个注解由JSR-250定义&#xff0c;并且在Spring框架以及其他遵循Java EE标准的应用服务器中广泛支持。 功能与用途&#xff1a;初始化方法,当…

扫描IP开放端口该脚本用于对特定目标主机进行常见端口扫描(加载端口字典)或者指定端口扫描,判断目标主机开

扫描IP开放端口该脚本用于对特定目标主机进行常见端口扫描(加载端口字典)或者指定端口扫描,判断目标主机开 #/bin/bash #该脚本用于对特定目标主机进行常见端口扫描(加载端口字典)或者指定端口扫描,判断目标主机开放来哪些端口 #用telnet方式 IP$1 #IP119.254.3.28 #获得IP的前…

UML学习

UML(Unified Modeling Language)&#xff1a;统一建模语言&#xff0c;提供了一套符号和规则来帮助分析师和设计师表达系统的架构、行为和交互 类图&#xff1a;描绘类、接口之间的关系(继承、实现、关联、依赖等)以及类的内部结构(属性和方法)&#xff0c;直观展现系统的静态…