Pyhon 大模型常见的微调方式,LLMs常见的Finetune方式;chatglm3微调实战;大模型微调通俗易懂总结

一、 LLMs微调

        微调(Fine-tuning)是指在一个已经训练好的神经网络模型基础上,使用额外的数据集或调整超参数,以实现特定任务的训练过程。在微调中,通常会固定预训练模型的大部分参数,只调整最后几层或特定层的参数,以适应新的任务或数据。这种方法通常可以加快模型的收敛速度,提高模型在特定任务上的表现。

(1)Adapter Tuning

Adapter Tuning是一种在预训练模型上进行微调的方法,它通过添加轻量级的适配器模块来保留预训练模型的大部分参数,同时只微调适配器模块的参数。适配器模块通常是一个小的神经网络层,用于在不破坏预训练模型的情况下对其进行调整以适应特定任务。这种方法相对于直接微调整个预训练模型来说更加高效,因为适配器模块的参数数量很少,所以可以在保留预训练模型参数的同时快速适应新任务。Adapter Tuning可以帮助节省计算资源,加快训练速度,并且在一些任务上取得了很好的效果。

        模型结构如上图左侧所示, 微调时冻结预训练模型的主体,由Adapter模块学习特定下游任务的知识。其中,Adapter模块结构如上图右侧所示,包含两个前馈层和一个中间层,第一个前馈层和中间层起到一个降维的作用,后一个前馈层和中间层起到升维的作用。

        Adapter调优的参数量大约为LM参数的3.6%。

(2) Prefix Tuning

        Prefix Tuning是一种用于微调通用预训练语言模型以适应特定任务的技术。在传统的微调方法中,我们会直接将任务文本作为输入传递给预训练模型,并通过调整整个模型的参数来适应任务。然而,Prefix Tuning采用一种不同的方法,它在输入序列的前面添加一个特定的前缀,以引导模型生成特定的输出。

        具体来说,当使用Prefix Tuning时,首先需要设计一个适当的前缀,这个前缀通常包含与任务相关的信息,比如问题描述、指令等。然后,在输入序列前面添加这个前缀,形成一个带有任务相关信息的完整输入序列。接着,将这个带有前缀的序列输入给预训练模型进行生成,模型会在生成时考虑到前缀中的任务提示信息,从而生成与任务相关的输出。

        在训练过程中,Prefix Tuning通过最小化任务目标序列与生成序列之间的距离来调整模型的参数,使得模型可以更好地适应特定任务。这种方法可以提高模型在特定任务上的性能,同时避免了重新训练整个模型的需求,从而节省了大量的计算资源和时间。

实验结果表明:

(1)在完整的数据集上,Prefix-Tunning和Fine-Tuning在table-to-text上的结果是comparable的,而在summarization任务上,prefix-tuning的效果略有下降。但在low-data settings和unseen topics的情况下,Prefix-Tuning的效果更佳。

(2)与Adapter-Tuning相比,Trefix-Tuning在相同的表现下只需调节更少的参数量。

(3)不同的前缀长度有不一样的性能表现,在一定程度上长度越长,prefix的效果越明显,但也可能出现降低的问题。实验表明,prefix长度对推理速度影响不大,因为prefix上的attention是并行计算的。

        Prefix Tuning参数规模约为LM模型整体规模的0.1%。

(3)Prompt Tuning

        Prompt Tuning(提示微调)是一种用于微调预训练语言模型的技术,它专注于通过设计和修改提示文本来改善模型在特定任务上的性能。在传统的微调方法中,我们直接将任务文本作为输入传递给预训练模型,并通过调整模型的参数来适应任务。但是,Prompt Tuning采用了一种不同的策略:通过精心设计的提示文本来设置任务的起始点,以引导模型生成更准确和相关的输出。

        Prompt Tuning的关键思想是在输入序列的开头添加一个提示文本。这个提示文本可以包含问题描述、类别信息、指令或任何对任务有帮助的文本片段。这个提示文本的设计需要充分考虑任务特定的上下文和领域知识,以引导模型产生合适的回答或输出。

        在训练过程中,Prompt Tuning使用了自监督学习的方法。首先,通过生成模型的预测输出,可以根据提示文本和一些部分观察到的目标答案来操作和生成伪造的监督信号。然后,通过最大化这些伪造信号的似然性,微调模型的参数。这个过程可以在大规模的未标注数据上进行,而不需要人工标注的成本。通过Prompt Tuning,可以在不重新训练整个模型的情况下,通过微调提示文本来提高模型的性能和适应特定任务。这种方法在多种自然语言处理任务中取得了显著的成功,包括问答系统、摘要生成、文本分类等。

 

  • Prompt 长度影响:模型参数达到一定量级时,Prompt 长度为1也能达到不错的效果,Prompt 长度为20就能达到极好效果。
  • Prompt初始化方式影响:Random Uniform 方式明显弱于其他两种,但是当模型参数达到一定量级,这种差异也不复存在。
  • 预训练的方式:LM Adaptation 的方式效果好,但是当模型达到一定规模,差异又几乎没有了。
  • 微调步数影响:模型参数较小时,步数越多,效果越好。同样随着模型参数达到一定规模,zero shot 也能取得不错效果。
  • 当参数达到100亿规模与全参数微调方式效果无异。

(4)P-Tuning -v1

 

P-Tuning 提出将 Prompt 转换为可以学习的 Embedding 层,只是考虑到直接对 Embedding 参数进行优化会存在这样两个挑战:

  • Discretenes: 对输入正常语料的 Embedding 层已经经过预训练,而如果直接对输入的 prompt embedding进行随机初始化训练,容易陷入局部最优。
  • Association:没法捕捉到 prompt embedding 之间的相关关系。

作者在这里提出用 MLP + LSTM 的方式来对 prompt embedding 进行一层处理:

P-tuning 依然是固定 LLM 参数,利用多层感知机和 LSTM 对 Prompt 进行编码,编码之后与其他向量进行拼接之后正常输入 LLM。注意,训练之后只保留 Prompt 编码之后的向量即可,无需保留编码器。

(5) P-Tuning v2

        P-Tuning v2是Prompt-Tuning的改进版本,是一种用于微调预训练语言模型的方法。该方法由OpenAI提出,旨在提高模型在特定任务上的性能。

        P-Tuning v2的核心创新是引入了一个自适应的Prompt Encoder,通过对输入样本进行编码,并生成动态提示,从而使模型能够根据不同任务的需求自动调整提示语。这种自适应的方法消除了人工设计和微调提示文本的需要,提高了模型在各种任务中的适应性和性能。相比 Prompt Tuning 和 P-tuning 的方法, P-tuning v2 方法在多层加入了 Prompts tokens 作为输入,带来两个方面的好处:带来更多可学习的参数(从 P-tuning 和 Prompt Tuning 的0.1%增加到0.1%-3%),同时也足够 parameter-efficient;加入到更深层结构中的 Prompt 能给模型预测带来更直接的影响。

具体而言,P-Tuning v2包含以下关键步骤:

  1. 自适应Prompt生成:模型通过Prompt Encoder对输入样本进行编码,结合任务信息生成动态的提示。这样,模型能够根据不同的输入样本自动调整提示,提高了模型的灵活性和泛化能力。

  2. 微调:生成的动态提示被输入到模型中,模型在训练过程中根据提示指导进行微调,以适应特定任务的要求。这有助于提高模型在该任务上的性能表现。

  3. 高效性:相较于传统的Prompt-Tuning方法,P-Tuning v2减少了对提示文本的手动设计工作,提高了效率。模型能够更快地适应不同任务,并表现出更好的性能。

(6) LoRA(Low-Rank Adaptation)

        LoRA(Low-Rank Adaptation)是一种用于模型微调的技术,旨在通过将低秩矩阵附加到预训练模型的嵌入矩阵中,来提高模型在特定任务上的性能。这种方法主要用于在大规模语言模型(如BERT、GPT等)上进行微调,以适应特定任务的需求。

        在LoRA微调中,预训练的模型架构保持不变,但会添加一个额外的低秩矩阵参数层用于微调。这个低秩矩阵通常是一个小型的矩阵,其目的是在不增加过多参数的情况下,提供更多与微调任务相关的信息。通过在微调过程中同时训练嵌入矩阵和低秩矩阵,LoRA能够更好地适应特定任务的特征。

        LoRA微调的优势在于可以在保持预训练模型参数不变的情况下,针对特定任务进行快速有效的微调。这种方法可以提高模型的性能并加速微调过程,尤其适用于任务数据集相对较小的情况下。 

(7) 其他方法:AdaLoRA

        AdaLoRA(Adaptive Low-Rank Adaptation)是LoRA(Low-Rank Adaptation)的一个改进版本,旨在进一步优化模型微调过程,以更好地适应特定任务。在AdaLoRA中,与传统的LoRA不同,它引入了自适应机制来动态地调整低秩矩阵的大小和结构,以使模型在微调过程中更加灵活和高效。这种自适应机制可以根据微调任务的需求,自动确定最佳的低秩矩阵参数,以更好地解决特定任务的挑战。具体来说,AdaLoRA通过引入自适应的稀疏正交约束来调整低秩矩阵的结构,以提高模型的泛化能力和表达能力。这种方法能够在微调过程中更好地平衡模型的复杂性和性能,从而更好地适应不同的任务需求

        具体做法如下:调整增量矩分配。AdaLoRA将关键的增量矩阵分配高秩以捕捉更精细和任务特定的信息,而将较不重要的矩阵的秩降低,以防止过拟合并节省计算预算。以奇异值分解的形式对增量更新进行参数化,并根据重要性指标裁剪掉不重要的奇异值,同时保留奇异向量。由于对一个大矩阵进行精确SVD分解的计算消耗非常大,这种方法通过减少它们的参数预算来加速计算,同时,保留未来恢复的可能性并稳定训练。

二、LLMs微调实践(chatglm3为例)

ChatGLM3 是智谱AI和清华大学 KEG 实验室联合发布的新一代对话预训练模型。ChatGLM3-6B 是 ChatGLM3 系列中的开源模型,在保留了前两代模型对话流畅、部署门槛低等众多优秀特性的基础上,ChatGLM3-6B 引入了如下特性:

  1. 更强大的基础模型: ChatGLM3-6B 的基础模型 ChatGLM3-6B-Base 采用了更多样的训练数据、更充分的训练步数和更合理的训练策略。在语义、数学、推理、代码、知识等不同角度的数据集上测评显示,ChatGLM3-6B-Base 具有在 10B 以下的基础模型中最强的性能
  2. 更完整的功能支持: ChatGLM3-6B 采用了全新设计的 Prompt 格式,除正常的多轮对话外。同时原生支持工具调用(Function Call)、代码执行(Code Interpreter)和 Agent 任务等复杂场景。
  3. 更全面的开源序列: 除了对话模型 ChatGLM3-6B 外,还开源了基础模型 ChatGLM3-6B-Base、长文本对话模型 ChatGLM3-6B-32K。以上所有权重对学术研究完全开放,在填写问卷进行登记后亦允许免费商业使用

环境安装

首先需要下载本仓库:

git clone https://github.com/THUDM/ChatGLM3
cd ChatGLM3

然后使用 pip 安装依赖:

pip install -r requirements.txt
  • transformers 库版本应该 4.30.2 以及以上的版本 ,torch 库版本应为 2.0 及以上的版本,以获得最佳的推理性能。
  • 为了保证 torch 的版本正确,请严格按照 官方文档 的说明安装。
  • gradio 库版本应该为 3.x 的版本。

注意哦: transformers 版本如果较低,很可能出现模型的参数加载不全(甚至不支持模型),如果参数加载不全虽然不影响推理,但是很影响推理效果哦

本目录提供 ChatGLM3-6B 模型的微调示例,包括全量微调和 P-Tuning v2。格式上,提供多轮对话微调样例和输入输出格式微调样例。

如果将模型下载到了本地,本文和代码中的 THUDM/chatglm3-6b 字段均应替换为相应地址以从本地加载模型。

运行示例需要 python>=3.9,除基础的 torch 依赖外,示例代码运行还需要依赖

pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed

多轮对话格式

多轮对话微调示例采用 ChatGLM3 对话格式约定,对不同角色添加不同 loss_mask 从而在一遍计算中为多轮回复计算 loss

数据格式和预处理

对于数据文件,样例采用如下格式

如果您仅希望微调模型的对话能力,而非工具能力,您应该按照以下格式整理数据。

[
  {
    "conversations": [
      {
        "role": "system",
        "content": "<system prompt text>"
      },
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      }, 
       // ... Muti Turn
      {
        "role": "user",
        "content": "<user prompt text>"
      },
      {
        "role": "assistant",
        "content": "<assistant response text>"
      }
    ]
  }
  // ...
]

请注意,这种方法在微调的step较多的情况下会影响到模型的工具调用功能

如果您希望微调模型的对话和工具能力,您应该按照以下格式整理数据。

[
   {
      "tools": [
         // available tools, format is not restricted
      ],
      "conversations": [
         {
            "role": "system",
            "content": "<system prompt text>"
         },
         {
            "role": "user",
            "content": "<user prompt text>"
         },
         {
            "role": "assistant",
            "content": "<assistant thought to text>"
         },
         {
            "role": "tool",
            "name": "<name of the tool to be called",
            "parameters": {
               "<parameter_name>": "<parameter_value>"
            },
            "observation": "<observation>"
            // don't have to be string
         },
         {
            "role": "assistant",
            "content": "<assistant response to observation>"
         },
         // ... Muti Turn
         {
            "role": "user",
            "content": "<user prompt text>"
         },
         {
            "role": "assistant",
            "content": "<assistant response text>"
         }
      ]
   }
   // ...
]
  • 关于工具描述的 system prompt 无需手动插入,预处理时会将 tools 字段使用 json.dumps(..., ensure_ascii=False) 格式化后插入为首条 system prompt。

  • 每种角色可以附带一个 bool 类型的 loss 字段,表示该字段所预测的内容是否参与 loss 计算。若没有该字段,样例实现中默认对 systemuser 不计算 loss,其余角色则计算 loss

  • tool 并不是 ChatGLM3 中的原生角色,这里的 tool 在预处理阶段将被自动转化为一个具有工具调用 metadata 的 assistant 角色(默认计算 loss)和一个表示工具返回值的 observation 角色(不计算 loss)。

  • 目前暂未实现 Code interpreter的微调任务。

  • system 角色为可选角色,但若存在 system 角色,其必须出现在 user 角色之前,且一个完整的对话数据(无论单轮或者多轮对话)只能出现一次 system 角色。

作为示例,我们使用 ToolAlpaca 数据集来进行微调。首先,克隆 ToolAlpaca 数据集,并使用

./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"

将数据集处理成上述格式。在这里,我们有意将工具处理成了了 list[str] 这样的自然语言形式,以观察模型在微调前后对工具定义的理解能力。

微调模型

以下脚本提供了微调模型的参考方式。

./scripts/finetune_ds_multiturn.sh  # 全量微调
./scripts/finetune_pt_multiturn.sh  # P-Tuning v2 微调

部署

我们更新了 ChatGLM3 的综合 Demo,使其可以部署微调后的模型 checkpoint。

对于全量微调,可以使用以下方式进行部署

cd ../composite_demo
MODEL_PATH="path to finetuned model checkpoint" TOKENIZER_PATH="THUDM/chatglm3-6b" streamlit run main.py

对于 P-Tuning v2 微调,可以使用以下方式进行部署 

cd ../composite_demo
MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit run main.py

输入输出格式

对于输入-输出格式,样例采用如下输入格式

[
  {
    "prompt": "<prompt text>",
    "response": "<response text>"
  }
  // ...

预处理时,不会拼接任何角色标识符。作为示例,我们使用 AdvertiseGen 数据集来进行微调。从 Google Drive 或者 Tsinghua Cloud 下载处理好的 AdvertiseGen 数据集,将解压后的 AdvertiseGen 目录放到本目录下。

./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"

来下载和将数据集处理成上述格式。

微调模型

以下脚本提供了微调模型的参考方式。

./scripts/finetune_ds.sh  # 全量微调
./scripts/finetune_pt.sh  # P-Tuning v2 微调

推理验证

对于输入输出格式的微调,可使用 inference.py 进行基本的推理验证。

python inference.py \
    --pt-checkpoint "path to p-tuning checkpoint" \
    --model THUDM/chatglm3-6b 
python inference.py \
    --tokenizer THUDM/chatglm3-6b \
    --model "path to finetuned model checkpoint" 

提示

微调代码在开始训练前,会先打印首条训练数据的预处理信息,显示为如下:

Sanity Check >>>>>>>>>>>>>
         '[gMASK]':  64790 ->   -100
             'sop':  64792 ->   -100
      '<|system|>':  64794 ->   -100
                '':  30910 ->   -100
              '\n':     13 ->   -100
          'Answer':  20115 ->   -100
             'the':    267 ->   -100
       'following':   1762 ->   -100
                  ...
            'know':    683 ->   -100
             'the':    267 ->   -100
        'response':   3010 ->   -100
         'details':   3296 ->   -100
               '.':  30930 ->   -100
   '<|assistant|>':  64796 ->   -100
                '':  30910 ->  30910
              '\n':     13 ->     13
               'I':    307 ->    307
            'need':    720 ->    720
              'to':    289 ->    289
             'use':    792 ->    792
                  ...
<<<<<<<<<<<<< Sanity Check

字样,每行依次表示一个 detokenized string, token_id 和 target_id。可在日志中查看这部分的 loss_mask 是否符合预期。若不符合,可能需要调整代码或数据。 

参考显存用量:

参数解释: 

PRE_SEQ_LEN=128     ---这是一个环境变量,代表序列的预设长度为128
LR=2e-2    ---代表学习率为0.02
NUM_GPUS=1     ---使用GPU的数量,为1
MAX_SOURCE_LEN=1024    ---输入序列的最大长度
MAX_TARGET_LEN=128     ---目标序列的最大长度
DEV_BATCH_SIZE=1    ---每个batch样本数量
GRAD_ACCUMULARION_STEPS=32    ---在进行一次参数更新之前,要进行的梯度累积步骤的数量
MAX_STEP=1000    ---训练步数的最大数量,一个batch一次
SAVE_INTERVAL=500    ---保存模型检查点的步数间隔

假设2000条数据 

 epoch计算:数据划分为batch:2000 /(DEV_BATCH_SIZE * GRAD_ACCUMULARION_STEPS * NUM_GPUS) = 62.5

epoch = 1000/62.5  = 16

所以epoch为16

基础模型推理代码

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True, device='cuda')
model = model.eval()
response, history = model.chat(tokenizer, "你好", history=[])
print(response)
###你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。
response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
print(response)
###

微调模型代码:finetune.py

#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for sequence to sequence.
"""
# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
# Adapted from 


import logging
import os
import sys
import torch
import json
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
from trainer import PrefixTrainer

from arguments import ModelArguments, DataTrainingArguments

from preprocess_utils import sanity_check, MultiTurnDataset, InputOutputDataset

logger = logging.getLogger(__name__)

def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    if model_args.ptuning_checkpoint is not None:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half()
        model.transformer.prefix_encoder.float()
    else:
        # Finetune
        model = model.float()
    
    with open(data_args.train_file, "r", encoding="utf-8") as f:
        if data_args.train_file.endswith(".json"):
            train_data = json.load(f)
        elif data_args.train_file.endswith(".jsonl"):
            train_data = [json.loads(line) for line in f]

    if data_args.train_format == "multi-turn":
        train_dataset = MultiTurnDataset(
            train_data,
            tokenizer,
            data_args.max_seq_length,
        )
    elif data_args.train_format == "input-output":
        train_dataset = InputOutputDataset(
            train_data,
            tokenizer,
            data_args.max_source_length,
            data_args.max_target_length,
        )
    else:
        raise ValueError(f"Unknown train format: {data_args.train_format}")
    if training_args.local_rank < 1:
        sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)

    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
        pad_to_multiple_of=None,
        padding=False
    )

    # Initialize our Trainer
    trainer = PrefixTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        save_changed=model_args.pre_seq_len is not None
    )

    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()
    trainer.train(resume_from_checkpoint=checkpoint)
    trainer.save_model()  # Saves the tokenizer too for easy upload
    trainer.save_state()

if __name__ == "__main__":
    main()

 微调模型推理代码:inference.py

import argparse
from transformers import AutoConfig, AutoModel, AutoTokenizer
import torch
import os

parser = argparse.ArgumentParser()
parser.add_argument("--pt-checkpoint", type=str, default=None, help="The checkpoint path")
parser.add_argument("--model", type=str, default=None, help="main model weights")
parser.add_argument("--tokenizer", type=str, default=None, help="main model weights")
parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="The pre-seq-len used in p-tuning")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--max-new-tokens", type=int, default=128)

args = parser.parse_args()

if args.tokenizer is None:
    args.tokenizer = args.model

if args.pt_checkpoint:
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    config = AutoConfig.from_pretrained(args.model, trust_remote_code=True, pre_seq_len=128)
    model = AutoModel.from_pretrained(args.model, config=config, trust_remote_code=True)
    prefix_state_dict = torch.load(os.path.join(args.pt_checkpoint, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
    model = AutoModel.from_pretrained(args.model, trust_remote_code=True)

model = model.to(args.device)

while True:
    prompt = input("Prompt:")
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = inputs.to(args.device)
    response = model.generate(input_ids=inputs["input_ids"], max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens)
    response = response[0, inputs["input_ids"].shape[-1]:]
    print("Response:", tokenizer.decode(response, skip_special_tokens=True))

参考文章:

大模型微调总结 - 知乎

大模型高效微调综述下: DiffPruning、BitFit、LoRa、AdaLoRA、MAM Adapters、UniPELT-CSDN博客

大模型微调技术(Adapter-Tuning、Prefix-Tuning、Prompt-Tuning(P-Tuning)、P-Tuning v2、LoRA)_nlp_渣渣崔-GitCode 开源社区 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/507377.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Windows安装TortoiseSVN客户端结合Cpolar实现公网提交文件到本地服务器

文章目录 前言1. TortoiseSVN 客户端下载安装2. 创建检出文件夹3. 创建与提交文件4. 公网访问测试 前言 TortoiseSVN是一个开源的版本控制系统&#xff0c;它与Apache Subversion&#xff08;SVN&#xff09;集成在一起&#xff0c;提供了一个用户友好的界面&#xff0c;方便用…

盲水印脚本安装说明_bwm、_bwmforpy

此工具需要python2/python3 脚本下载地址https://gitcode.com/chishaxie/BlindWaterMark/tree/master?utm_sourcecsdn_blog_hover 直接下载压缩包解压 在python里面添加两个库&#xff0c;python.exe目录上方输入cmd pip install opencv-python python.exe -m pip install …

docker部署实用的运维开发手册

下载镜像 docker pull registry.cn-beijing.aliyuncs.com/wuxingge123/reference:latestdocker-compose部署 vim docker-compose.yml version: 3 services:reference:container_name: referenceimage: registry.cn-beijing.aliyuncs.com/wuxingge123/reference:latestports:…

Gparted工具 初始化磁盘

Gparted工具 初始化磁盘 1、安装 没有此工具请先安装&#xff1a; yum install epel-release yum install gparted yum install yum-utils git gnome-common gcc-c yum-builddep gparted 2、打开Gparted工具&#xff0c;初始化磁盘 使用具有root权限的普通用户打开gparted&…

回溯算法|40.组合总和II

力扣题目链接 class Solution { private:vector<vector<int>> result;vector<int> path;void backtracking(vector<int>& candidates, int target, int sum, int startIndex, vector<bool>& used) {if (sum target) {result.push_back…

OSPF不规则区域以及OSPF的数据库和优化OSPF的LSA

OSPF的不规则区域 远离骨干非骨干区域不连续骨干-----区域水平分割 解决方案&#xff1a; 1.tunnel ---点到点GRE 在合法与非法ABR(在两个区域之间&#xff0c;但没有连到骨干area0)间建立隧道&#xff0c;然后将其宣告于OSPF协议中&#xff1b; 缺点&#xff1a;1、周期和…

Web应急响应

2024年护网将至&#xff0c;最近我将分享一些红蓝对抗的一些技巧&#xff0c;应急响应、信息收集相关的知识概念以及相关技巧。 目录 1. 黑客攻击流程 2. webshell流量特征 1.1.菜刀特征 1.2.冰蝎3.0 &#xff1a; 1.3.冰蝎2.0&#xff1a; 1.4.冰蝎3.11流量特征 1.5.蚁…

cocos使用playable ads adapter打包试玩广告报错RangeError: Invalid string length

前言 最近有做试玩广告的需求&#xff0c;引擎用的cocos&#xff0c;打包使用的playable ads adapter插件。不过最近打包遇到个奇怪的问题&#xff0c;就是通过插件打包报错RangeError: Invalid string length。因为之前也用空包和早期项目测试过都能顺利打包&#xff0c;经过…

数码管时钟--LABVIEW编程

一、程序的前面板 1.获取系统时钟&#xff0c;年月日&#xff0c;时分秒&#xff0c;用14个数码管显示。 2.闹钟设定小时和分钟。 二、程序的后面板 三、程序运行图 四、程序源码 源程序可以在百度网盘自行下载&#xff0c;地址链接见下方。 链接&#xff1a;https://pan.b…

健身运动蓝牙耳机什么牌子好?五大业内顶级优品推荐

在当下这个健身热潮席卷的时代&#xff0c;越来越多的人开始注重运动与健康&#xff0c;而音乐作为运动时的最佳伴侣&#xff0c;无疑为锻炼过程增添了不少乐趣。为了在运动时享受音乐&#xff0c;一款优质的健身运动蓝牙耳机显得尤为重要&#xff0c;市场上各大品牌纷纷推出自…

python对接百度云车牌识别

注册百度智能云&#xff0c;选择产品服务。 https://console.bce.baidu.com/ 每天赠送200次&#xff0c;做开发测试足够了。 在应用列表复制 AppID , API Key ,Secret Key 备用。 SDK下载地址 https://ai.baidu.com/sdk#ocr 下载SDK文件&#xff0c;解压&#xff0c;…

如何在Plesk面板备份网站

本周有一个客户&#xff0c;购买Hostease的Windows虚拟主机&#xff0c;咨询我们的在线客服&#xff0c;询问Windows虚拟主机Plesk面板是否提供备份功能。我们为用户提供教程&#xff0c;用户很快完成了数据备份。在此&#xff0c;我们分享这个操作教程&#xff0c;希望可以对您…

差点引爆全球的核弹,深度分析XZ-Utils供应链后门投毒事件

处心积虑的投毒者蛰伏三年多&#xff0c;精心选择对象&#xff0c;通过复杂的攻击手法、专业的技战术&#xff0c;一步步支起一张大网&#xff0c;企图掌控全球主流linux发行版&#xff0c;一旦成功他将可以随意侵入全球绝大多数的服务器&#xff0c;这将是足以引爆全球的核弹危…

AI技术创业:挖掘行业解决方案、智能产品服务及教育培训的无限机遇

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…

1 导入图片后 调整图片大小

导入图片 如下图&#xff0c;通过“文件 → 打开”在PS中导入一张图片&#xff0c;但是图片有点小 有三种改变大小的方法 1 只要部分图片&#xff0c;画布大小不变 方式&#xff1a;按住ctrlt&#xff0c;就会出现如图所示的选框 画布大小不变&#xff0c;但是拖动选框&…

吴恩达深度学习笔记:浅层神经网络(Shallow neural networks)3.9-3.11

目录 第一门课&#xff1a;神经网络和深度学习 (Neural Networks and Deep Learning)第三周&#xff1a;浅层神经网络(Shallow neural networks)3.9 神 经 网 络 的 梯 度 下 降 &#xff08; Gradient descent for neural networks&#xff09;3.10&#xff08;选修&#xff0…

使用Redis集合List实现消息队列

系列文章目录 文章目录 系列文章目录前言前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 Redis是一个key-value存储系统。和Memcached类似,它支持存储的value类型…

MySQL经验分享:Shell开发问题

背景 之前整理过Python连接使用MySQL的经验&#xff0c;链接如下&#xff1a; pymysql封装总结_pymysql封装类-CSDN博客 相比高级语言&#xff0c;Shell与MySQL开发使用相对会更麻烦一些&#xff1b;由于 shell是linux命令集的概称&#xff0c;是属于命令行的人机界面。Shel…

k8s 基础入门

1.namespace k8s中的namespace和docker中namespace是两码事&#xff0c;可以理解为k8s中的namespace是为了多租户&#xff0c;dockers中的namespace是为了网络、资源等隔离 2.deployment kubectl create #新建 kubectl aply #新建 更新 升级&#xff1a; 滚动升级&#x…

MS35774/MS35774A,低噪声 256 细分微步进电机驱动,可用在车灯随动,香氛机等领域

MS35774/MS35774A 是一款高精度、低噪声的两相步进 电机驱动芯片&#xff0c;芯片内置功率 MOSFET &#xff0c;长时间工作的平均电 流可以达到 1.4A &#xff0c;峰值电流 2A 。芯片集成了过温保护、欠压 保护、过流保护、短地保护、短电源保护功能。 主要特点 ◼ 2…