Trl SFT: llama2-7b-hf使用QLora 4bit量化后ds zero3加上flash atten v2单机多卡训练(笔记)

目录

一、环境

  1.1、环境安装

  1.2、安装flash atten

二、代码

  2.1、bash脚本 

  2.2、utils.py 注释与优化

  2.3、train.py 注释与优化

  2.4、模型/参数相关

    2.4.1、量化后的模型

      2.4.1.1 量化后模型结构

      2.4.1.2 量化后模型layers

    2.4.2、参数

     2.4.2.1 training args

     2.4.2.2 peft args

     2.4.2.3 model args

三、Trl 库

  3.1、SFTTrainer

  3.2、其他的代码

    3.2.1、datasets.map 使用 load_from_cache_file = False 方便调试​​​​​​​​​​​​​​

四、小结

  4.1、在SFTTrainer初始化peft模型时,为什么 开启了 QLoRA + FSDP / DS-Zero3 后不使用prepare_model_for_kbit_training 和 peft_module_casting_to_bf16 ,prepare_model_for_kbit_training 和 peft_module_casting_to_bf16 做了什么?QLoRA + FSDP / DS-Zero3 未开启offload​​​​​​​​​​​​​​模型加载后model为什么在cpu上?

  4.2、bfloat16和float16的区别

五、Trl 其他Trainer注释笔记

  5.1、DPOTrainer笔记​​​​​​​​​​​​​​

 5.2、... 


  • 项目地址

peft/examples/sft at main · huggingface/peft · GitHub🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning. - peft/examples/sft at main · huggingface/pefticon-default.png?t=N7T8https://github.com/huggingface/peft/tree/main/examples/sft

  • 文档

https://huggingface.co/docs/peft/accelerate/deepspeedicon-default.png?t=N7T8https://huggingface.co/docs/peft/accelerate/deepspeed

一、环境

系统:ubuntu 
cuda版本:12.1
torch版本:2.2.0
python版本:3.10

conda 虚拟环境中 cuda版本
cuda:12.1  # 确保与"外界"cuda一致

  1.1、环境安装

pip install -r ...

    第一种

git+https://github.com/huggingface/transformers
git+https://github.com/huggingface/accelerate
git+https://github.com/huggingface/peft
git+https://github.com/huggingface/trl
git+https://github.com/huggingface/datatrove.git
unsloth[conda]@git+https://github.com/unslothai/unsloth.git
deepspeed
PyGithub
# flash-attn 单独安装
huggingface-hub
evaluate
datasets
bitsandbytes
einops
wandb
tensorboard
tiktoken
pandas
numpy
scipy
matplotlib
sentencepiece
nltk
xformers
hf_transfer

     第二种

absl-py==2.1.0
accelerate==0.30.0
aiohttp==3.9.4
aiosignal==1.3.1
annotated-types==0.6.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.43.1
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.1
cryptography==42.0.5
cycler==0.12.1
datasets==2.18.0
datatrove==0.0.1
deepspeed==0.14.0
Deprecated==1.2.14
dill==0.3.8
docker-pycreds==0.4.0
docstring_parser==0.16
einops==0.7.0
evaluate==0.4.1
filelock==3.13.4
# flash-attn==2.5.7
# flash-attn 需要手动安装, 安装之前需要先保证:
# 第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致
# 第二 安装好 c++ g++ ninja
# 第三 参考官方命令: https://github.com/Dao-AILab/flash-attention
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.43
grpcio==1.62.1
hf_transfer==0.1.6
hjson==3.1.0
huggingface-hub==0.22.2
humanize==4.9.0
idna==3.7
Jinja2==3.1.3
joblib==1.4.0
kiwisolver==1.4.5
loguru==0.7.2
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.4
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.3
ninja==1.11.1.1
nltk==3.8.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pandas==2.2.2
peft==0.10.1
pillow==10.3.0
pip==23.3.1
protobuf==3.20.3
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==15.0.2
pyarrow-hotfix==0.6
pycparser==2.22
pydantic==2.7.0
pydantic_core==2.18.1
PyGithub==2.3.0
Pygments==2.17.2
PyJWT==2.8.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
responses==0.18.0
rich==13.7.1
safetensors==0.4.2
scipy==1.13.0
sentencepiece==0.2.0
sentry-sdk==1.45.0
setproctitle==1.3.3
setuptools==68.2.2
shtab==1.7.1
six==1.16.0
smmap==5.0.1
sympy==1.12
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tiktoken==0.6.0
tokenizers==0.15.2
torch==2.2.2
tqdm==4.66.2
transformers==4.40.0
triton==2.2.0
trl==0.8.3
typing_extensions==4.11.0
tyro==0.8.3
tzdata==2024.1
unsloth==2024.4
urllib3==2.2.1
wandb==0.16.6
Werkzeug==3.0.2
wheel==0.43.0
wrapt==1.16.0
xformers==0.0.25.post1
xxhash==3.4.1
yarl==1.9.4

  1.2、安装flash atten

安装 flash atten 和 deepspeed 前,需要保证:

  • 第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致
  • 第二 安装好 c++ g++ ninja (c++ g++ Ninjia 安装版本过低后续安装可能会失败)
  • 第三 参考官方命令: GitHub - Dao-AILab/flash-attention: Fast and memory-efficient exact attentionFast and memory-efficient exact attention. Contribute to Dao-AILab/flash-attention development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/Dao-AILab/flash-attention
1. 安装 c++ g++
sudo apt-get update
sudo apt-get install build-essential

2. 安装 Ninja
sudo apt-get install ninja-build

3. 安装flash atten
    参考上面官方命令:
    pip install packaging
    pip install flash-attn --no-build-isolation          ----- flash atten 编译过程需要一定的时间,需要等待

二、代码

peft/examples/sft at main · huggingface/peft · GitHub🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning. - peft/examples/sft at main · huggingface/pefticon-default.png?t=N7T8https://github.com/huggingface/peft/tree/main/examples/sft

  2.1、bash脚本 

PYTHONPATH=$PWD
export PYTHONPATH
echo "当前bash执行目录: $PWD, 已经将PYTHONPATH设置为: $PYTHONPATH"


# --resume_from_checkpoint dir   表示trainer从dir恢复ckpt
# 注释掉: 与wandb 不能共存
# 2>&1 | tee -a examples/sft/qlora_ds_zero3_log.out
accelerate launch --config_file "examples/sft/configs/deepspeed_config_z3_qlora.yaml"  examples/sft/train.py \
    --seed 100 \
    --model_name_or_path "/workspace/Llama-2-7b-chat-hf" \
    --dataset_name "smangrul/ultrachat-10k-chatml" \
    --chat_template_format "chatml" \
    --add_special_tokens False \
    --append_concat_token False \
    --splits "train,test" \
    --max_seq_len 2048 \
    --num_train_epochs 2 \
    --logging_steps 5 \
    --log_level "info" \
    --logging_strategy "steps" \
    --evaluation_strategy "epoch" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 10 \
    --bf16 True \
    --packing True \
    --learning_rate 1e-4 \
    --lr_scheduler_type "cosine" \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --max_grad_norm 1.0 \
    --output_dir "/workspace/output/llama-sft-qlora-dsz3" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --use_flash_attn True \
    --gradient_checkpointing True \
    --use_reentrant True \
    --dataset_text_field "content" \
    --use_peft_lora True \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --lora_target_modules "all-linear" \
    --use_4bit_quantization True \
    --use_nested_quant True \
    --bnb_4bit_compute_dtype "bfloat16" \
    --bnb_4bit_quant_storage_dtype "bfloat16" \
    --resume_from_checkpoint /workspace/output/llama-sft-qlora-dsz3/checkpoint-100 \
    2>&1 | tee -a examples/sft/qlora_ds_zero3_log.out

    # 上传至 hub 的参数
    # --push_to_hub \
    # --hub_private_repo True \
    # --hub_strategy "every_save" \

  2.2、utils.py 注释与优化

import os
from enum import Enum

import torch
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoraConfig

# DEFAULT_CHATML_CHAT_TEMPLATE是一个用于格式化聊天消息的jinja2模板字符串
# jinja2是一种流行的Python模板引擎,它允许在模板中嵌入Python代码,使模板更加动态和可编程
# 在这个模板中,{% for message in messages %} 是一个jinja2的for循环语句,用于遍历messages列表中的每个消息
# {
   {'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
# 这一部分定义了每条消息的格式化方式,包括:
#   1. <|im_start|>: 一个特殊标记,表示消息角色(如user、system或assistant)的开始
#   2. message['role']: 当前消息的角色,如user、system或assistant
#   3. \n: 换行符,用于在角色和消息内容之间添加新行
#   4. message['content']: 当前消息的实际内容
#   5. <|im_end|>: 一个特殊标记,表示消息内容的结束
#   6. \n: 换行符,用于在每条消息之后添加新行
# {% if loop.last and add_generation_prompt %}{
   {'<|im_start|>assistant\n' }}{% endif %}
# 这一部分是一个jinja2的条件语句,当循环遍历到最后一条消息时,如果add_generation_prompt为True,
# 则会在最后一条消息后添加'<|im_start|>assistant\n'作为提示,表示需要模型生成助手的回复
# 这种模板格式化方式的目的是将原始的聊天记录转换为适合语言模型输入的格式,以便进行对话生成任务
DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{
   {'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{
   {'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"


# DEFAULT_ZEPHYR_CHAT_TEMPLATE与DEFAULT_CHATML_CHAT_TEMPLATE类似,也是一个用于格式化聊天消息的jinja2模板
# 不同之处在于格式化方式和使用的特殊标记
# {% for message in messages %} 同样是一个用于遍历消息列表的for循环
# {% if message['role'] == 'user' %} 是一个条件语句,用于判断当前消息的角色是否为user
# 如果是user,则使用{
   { '<|user|>\n' + message['content'] + eos_token }}将消息格式化为:
#   1. <|user|>: 用户角色的特殊标记
#   2. \n: 换行符
#   3. message['content']: 消息内容
#   4. eos_token: 句尾标记,如</s>
# {% elif message['role'] == 'system' %} 是另一个条件分支,用于判断当前消息的角色是否为system
# 如果是system,则使用{
   { '<|system|>\n' + message['content'] + eos_token }}进行格式化
# {% elif message['role'] == 'assistant' %} 是第三个条件分支,用于判断当前消息的角色是否为assistant
# 如果是assistant,则使用{
   { '<|assistant|>\n'  + message['content'] + eos_token }}进行格式化
# {% if loop.last and add_generation_prompt %}\n{
   { '<|assistant|>' }}\n{% endif %}
# 这一部分与DEFAULT_CHATML_CHAT_TEMPLATE类似,当遍历到最后一条消息时,如果add_generation_prompt为True,
# 则会添加'<|assistant|>\n'作为提示,表示需要模型生成助手的回复
# 总的来说,这种格式化方式将原始聊天记录转换为适合语言模型输入的形式,但使用了不同的特殊标记
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{
   { '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{
   { '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{
   { '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{
   { '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

# ZephyrSpecialTokens是一个继承自str和Enum的枚举类
# 它定义了Zephyr聊天格式中使用的各种特殊标记,如用户标记、助手标记、系统标记等
# 枚举类的好处是可以将一组相关的常量组织在一起,并提供更好的可读性和类型安全性
# 每个特殊标记都被定义为一个类属性,其值为对应的字符串形式
# 例如,user = "<|user|>"表示用户标记的字符串形式为"<|user|>"
class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"      # 句尾标记,表示一个句子或序列的结束
    bos_token = "<s>"       # 句首标记,表示一个句子或序列的开始
    pad_token = "<pad>"     # 填充标记,用于将序列填充至指定长度

    # list方法是一个类方法,它返回一个列表,包含了该枚举类中所有特殊标记的字符串形式
    # 这个方法常用于初始化分词器(tokenizer)时,将这些特殊标记添加到词表中
    @classmethod
    def list(cls):
        return [c.value for c in cls]

# ChatmlSpecialTokens与ZephyrSpecialTokens类似,也是一个定义了Chatml聊天格式中使用的特殊标记的枚举类
# 不同之处在于具体的特殊标记字符串形式
# 例如,user标记在Chatml格式中为"<|im_start|>user",而在Zephyr格式中为"<|user|>"
class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]

# create_datasets函数用于创建训练和测试数据集
# 参数包括:
#   tokenizer: 用于对文本进行分词(tokenization)和编码(encoding)的分词器对象
#   data_args: 包含数据相关配置的参数对象,如数据集名称、切分方式等
#   training_args: 包含训练相关配置的参数对象
#   apply_chat_template (bool): 是否应用聊天模板对数据进行预处理,默认为False
def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    # preprocess是一个内部函数,用于对数据样本进行预处理
    # 它接受一个字典样本作为输入,其中"messages"键对应一个列表,列表中的每个元素都是一个对话(conversation)
    def preprocess(samples):
        batch = []     # 初始化一个空列表,用于存储预处理后的对话
        # TODO 修改源码
        batch_tokens = []
        # 遍历样本中的每个对话
        for conversation in samples["messages"]:
            # 对每个对话应用tokenizer.apply_chat_template方法进行预处理
            # tokenize=False表示不执行分词操作,只进行格式化
            # https://huggingface.co/docs/transformers/main/zh/chat_templating
            # TODO 对源码进行修改
            chat_tmp = tokenizer.apply_chat_template(conversation, tokenize=False)
            batch.append(chat_tmp)
            chat_tmp_tokens = tokenizer.tokenize(chat_tmp)
            batch_tokens.append(chat_tmp_tokens)
        # 返回一个字典,其中"content"键对应预处理后的对话列表
        return {"content": batch, "content_tokens":batch_tokens}

    raw_datasets = DatasetDict()   # 初始化一个空的DatasetDict对象,用于存储数据集
    # 遍历data_args.splits指定的数据集切分(如train、test等)
    for split in data_args.splits.split(","):
        try:
            # Try first if dataset on a Hub repo, 首先尝试从Hugging Face Hub上加载指定的数据集
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # If not, check local dataset, 如果从Hub上加载失败,则尝试从本地磁盘加载数据集
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        # 根据切分类型,将数据集存入raw_datasets的对应键值中
        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    # 如果apply_chat_template为True,则对数据集应用preprocess函数进行预处理
    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,         # 表示对样本进行批处理,提高效率
            remove_columns=raw_datasets["train"].column_names,
            # TODO 新增代码, 取消缓存, 用于调试
            load_from_cache_file = False
        )

    train_data = raw_datasets["train"]  # 获取训练数据集
    valid_data = raw_datasets["test"]   # 获取测试数据集

    # TODO 只有主进程打印
    if training_args.local_rank == 0 or training_args.local_rank == -1:
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")  # 打印数据集大小
        print(f"A sample of train dataset: {train_data[0]}")  # 打印训练数据集的第一个样本

    return train_data, valid_data


# create_and_prepare_model函数用于创建和准备模型
# 参数包括:
#   args: 包含模型相关配置的参数对象,如模型名称、是否使用量化等
#   data_args: 包含数据相关配置的参数对象,如最大序列长度等
#   training_args: 包含训练相关配置的参数对象,如是否使用梯度检查点等
def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        # 如果使用Unsloth库(一种用于加速语言模型的库),则导入FastLanguageModel类
        from unsloth import FastLanguageModel
    bnb_config = None    # 初始化BitsAndBytesConfig为None,用于量化配置
    quant_storage_dtype = None   # 初始化量化存储数据类型为None

    # 检查是否为分布式训练且使用Unsloth库,如果是则抛出NotImplementedError
    # 因为当前版本的Unsloth不支持分布式训练
    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    # 如果使用4位量化,则设置计算数据类型和量化存储数据类型
    if args.use_4bit_quantization:
        # 获取指定的计算数据类型, getattr 会将字符串 bfloat16 ---> torch.bfloat16
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        # 获取指定的量化存储数据类型, getattr 会将字符串 bfloat16 ---> torch.bfloat16
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        # 创建BitsAndBytesConfig对象,用于配置量化相关参数
        # BitsAndBytesConfig是一个用于管理量化配置的类,可以指定量化类型、计算数据类型、存储数据类型等
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,          # 是否使用4位量化
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,     # 4位量化的类型, 如 nf4
            bnb_4bit_compute_dtype=compute_dtype,             # 计算数据类型
            bnb_4bit_use_double_quant=args.use_nested_quant,  # 是否使用双量化
            # TODO Qlora + zero3 修改的代码
            bnb_4bit_quant_storage=quant_storage_dtype,       # 量化存储数据类型
        )

        # 如果计算数据类型为float16且使用4位量化,则打印GPU是否支持bfloat16的提示
        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print("=" * 80)
        # 如果使用8位量化,则创建相应的BitsAndBytesConfig对象
        elif args.use_8bit_quantization:
            bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

    # 如果使用Unsloth库
    if args.use_unsloth:
        # Load model, 使用FastLanguageModel.from_pretrained方法加载模型, 传入模型名称路径、最大序列长度、是否使用4位量化等参数
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=data_args.max_seq_length,
            dtype=None,
            load_in_4bit=args.use_4bit_quantization,
        )
    else: # 如果不使用Unsloth库,则使用AutoModelForCausalLM.from_pretrained方法加载模型
        # TODO Qlora + zero3 修改的代码
        # 如果指定了quant_storage_dtype且是浮点数类型,则使用quant_storage_dtype, 否则使用默认的torch.float32
        torch_dtype = (
            quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32
        )
        # 使用AutoModelForCausalLM.from_pretrained方法加载语言模型, 传入模型路径、量化配置、是否信任远程代码、注意力实现方式和数据类型等参数
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
            trust_remote_code=True,
            # 注意力实现方式,flash_attention_2或eager
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            # TODO Qlora + zero3 修改的代码
                # 注意 torch_dtype 对于 AutoModelForCausalLM 与 bnb_4bit_quant_storage 数据类型相同。就是这样。其他所有事情都由 Trainer 和 TRL 处理。
            torch_dtype=torch_dtype,
        )

    peft_config = None      # 初始化PEFT配置为None
    chat_template = None    # 初始化聊天模板为None
    # 如果使用PEFT LoRA且不使用Unsloth库,则创建LoraConfig对象
    # PEFT (Parameter-Efficient Fine-Tuning)是一种模型微调技术,可以在保持大部分模型参数不变的情况下,只微调一小部分参数
    # LoRA (Low-Rank Adaptation)是PEFT的一种实现,通过添加低秩矩阵来适应新任务
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,         # LoRA的alpha参数,控制LoRA层的重要性
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",                       # 是否对偏置项应用LoRA
            task_type="CAUSAL_LM",             # 任务类型,这里是因果语言模型
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )

    special_tokens = None   # 初始化特殊标记为None
    chat_template = None    # 初始化聊天模板为None
    # 根据args.chat_template_format参数,设置特殊标记和聊天模板
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens              # 使用Chatml格式的特殊标记
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE      # 使用Chatml聊天模板
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens            # 使用Zephyr格式的特殊标记
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE    # 使用Zephyr聊天模板

    # 如果特殊标记不为None
    if special_tokens is not None:
        # 使用AutoTokenizer.from_pretrained方法加载分词器
        # 设置填充标记、句首标记、句尾标记和其他特殊标记
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,     # 填充标记
            bos_token=special_tokens.bos_token.value,     # 句首标记
            eos_token=special_tokens.eos_token.value,     # 句尾标记
            additional_special_tokens=special_tokens.list(),  # 其他特殊标记
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template           # 设置聊天模板
        # make embedding resizing configurable?
        # 调整tokenizer的嵌入大小,使其能够容纳新增的特殊标记
        # pad_to_multiple_of=8用于对齐,提高GPU计算效率
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        # 如果特殊标记为None,则直接加载分词器
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token     # 设置填充标记为句尾标记


    # 如果使用Unsloth库
    if args.use_unsloth:
        # Do model patching and add fast LoRA weights
        # 使用FastLanguageModel.get_peft_model方法对模型进行修补,并添加快速LoRA权重
        # 传入LoRA相关参数,如alpha、dropout、rank等,以及是否使用梯度检查点、随机种子和最大序列长度
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=data_args.max_seq_length,
        )

    return model, peft_config, tokenizer       # 返回模型、PEFT配置和分词器

  2.3、train.py 注释与优化

import os
import sys
import torch
from dataclasses import dataclass, field
from typing import Optional

import torch.distributed
from transformers import HfArgumentParser, TrainingArguments, set_seed, Seq2SeqTrainingArguments
from trl import SFTTrainer    # SFTTrainer用于序列到序列(Sequence-to-Sequence)的语言模型微调训练
from utils import create_and_prepare_model, create_datasets  # 自定义的实用函数,用于创建和准备模型、数据集

# TODO 新增代码, wandb 与 bash 重定向 log.out 冲突, 关闭掉
os.environ["WANDB_DISABLED"] = "true" # 关闭 wandb

# Define and parse arguments. 定义ModelArguments数据类,用于指定模型相关参数
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    # 指定预训练语言模型的路径或在Hugging Face模型库中的标识符, 这允许您使用您选择的任何预训练模型,如GPT-2、GPT-3、BERT等
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # 指定聊天数据的格式,有以下选项:
    # 1) chatml: 使用Anthropic的chatml格式,例如: <human>: 你好 \n<assistant>: 你好,很高兴与你交谈。
    # 2) zephyr: 使用Pretrained.AI的zephyr格式,例如: Human: 你好 \nAssistant: 你好,很高兴与你交谈。 
    # 3) none: 如果数据集已经格式化为聊天模板,则设置为none
    # 这个参数可以帮助您灵活地处理不同格式的聊天数据
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)    # lora_alpha控制LoRA层的重要性,典型值为16或32
    lora_dropout: Optional[float] = field(default=0.1)  # lora_dropout设置LoRA层的dropout率,用于防止过拟合
    # lora_r指定LoRA低秩矩阵的秩(rank),较低的秩可以进一步减少参数量,但可能会影响性能, 秩越低,模型越压缩,但可能会导致性能下降
    lora_r: Optional[int] = field(default=64)
    # lora_target_modules指定应用LoRA的模块列表
    # 默认值包括注意力层的线性投影(q_proj, k_proj, v_proj, o_proj)和前馈神经网络层(down_proj, up_proj, gate_proj)
    # 也可以设置为"all-linear"以应用LoRA到所有线性层
    # 通过选择性地应用LoRA,可以在性能和参数量之间进行权衡
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )
    # use_nested_quant指定是否启用嵌套量化(nested quantization), 嵌套量化可以将4位量化模型进一步量化,从而进一步减小模型大小和内存占用,但可能会影响精度
    # 即 双量化
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    # bnb_4bit_compute_dtype指定4位量化模型的计算数据类型,例如float16或bfloat16, 使用较低的计算精度可以提高计算速度,但可能会影响模型精度
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    # bnb_4bit_quant_storage_dtype指定4位量化模型的量化存储数据类型,如uint8或float16或bf16, 使用较低的存储精度可以减小模型大小,但可能会影响模型精度
    # 您需要权衡模型大小和精度的平衡
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="uint8",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    # bnb_4bit_quant_type指定4位量化类型,包括fp4或nf4(normal float量化,一种新型的数据格式),信息论中表示nf4的效果可能会更好
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    # use_flash_attn指定是否启用Flash注意力(Flash attention)
    # Flash注意力是一种高效的注意力实现,可以通过内存优化和并行计算提高训练速度
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    # use_peft_lora指定是否启用PEFT (Parameter-Efficient Fine-Tuning) LoRA
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    # use_8bit_quantization指定是否将模型加载为8位量化版本
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    # use_4bit_quantization指定是否将模型加载为4位量化版本, 4位量化可以将模型大小减小到原始大小的1/4,从而进一步节省内存和加快计算,但可能会显著影响精度
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    # use_reentrant是梯度检查点(Gradient Checkpointing)的一个参数, 梯度检查点可以通过重新计算激活值来节省内存,但会增加一些计算开销
    # use_reentrant指定是否使用可重入(reentrant)的梯度检查点实现,可能会进一步节省内存, 这个参数可以帮助在内存占用和计算开销之间进行权衡
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient Checkpointing param. Refer the related docs"},
    )
    # use_unsloth指定是否使用Unsloth库进行训练
    # Unsloth是一个优化库,可以通过内存优化和并行计算加速PEFT LoRA的训练过程
    # 这个参数可以帮助您进一步提高训练效率
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables UnSloth for training."},
    )


# 定义DataTrainingArguments数据类,用于指定数据集和数据处理相关参数
@dataclass
class DataTrainingArguments:
    # 指定要使用的数据集名称或路径,默认为OpenAssistant Guanaco数据集
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )

    # packing指定是否使用数据集打包(packing)
    # 数据集打包可以将多个样本打包为一个更长的序列,从而提高训练效率, 这个参数可以帮助您在训练速度和内存占用之间进行权衡
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing dataset creating."},
    )

    # dataset_text_field指定数据集中作为input文本的字段名, 这个参数可以帮助您灵活地处理不同格式的数据集
    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
    # max_seq_length指定输入序列的最大长度,超出部分将被截断, 这个参数可以帮助您在训练速度、内存占用和模型性能之间进行权衡
    max_seq_length: Optional[int] = field(default=512)

    # append_concat_token指定在打包数据集时,是否在每个样本的末尾追加一个连接标记(如<eos>),这个参数可以帮助您控制数据集的格式,从而影响模型的输出
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )

    # add_special_tokens指定在打包数据集时,是否由分词器(tokenizer)添加特殊标记(如<bos>和<eos>), 这个参数可以帮助您控制数据集的格式,从而影响模型的输出
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    # splits指定要从数据集中使用的数据分割,如train、test或val,多个分割用逗号分隔, 这个参数可以帮助您灵活地使用数据集的不同部分进行训练和评估
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma separate list of the splits to use from the dataset."},
    )

# TODO 新增代码, 打印模型的是否参与训练的参数名和数据类型
def print_model_allarguments_name_dtype(model):
    for n,v in model.named_parameters():
        if v.requires_grad:
            print(f"trainable model arguments: {n} - {v.dtype} - {v.shape} - {v.device}")
        else:
            print(f"not trainable model arguments: {n} - {v.dtype} - {v.shape} - {v.device}")


def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/548445.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【随笔】Git 基础篇 -- 拉取数据 git pull(二十八)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…

GPT的使用

个人笔记&#xff08;整理不易&#xff0c;有帮助点个赞&#xff09; 笔记目录&#xff1a;学习笔记目录_pytest和unittest、airtest_weixin_42717928的博客-CSDN博客 个人随笔&#xff1a;工作总结随笔_8、以前工作中都接触过哪些类型的测试文档-CSDN博客 网站sms-activate.or…

详解电源测试系统自定义报告模板功能:如何轻松实现数据导出

在NSAT-8000电源测试系统内&#xff0c;数据一般分为三级架构&#xff1a;原始数据、数据报告和数据分析。数据报告可以直接展示出电源模块的各项测试数据和测试结果&#xff0c;帮助用户评估电源性能&#xff0c;为电源的优化提升提供数据支持。 系统的记录报告板块展示着历史…

RocketMQ 10 面试题FAQ

RocketMQ 面试FAQ 说说你们公司线上生产环境用的是什么消息中间件? 为什么要使用MQ&#xff1f; 因为项目比较大&#xff0c;做了分布式系统&#xff0c;所有远程服务调用请求都是同步执行经常出问题&#xff0c;所以引入了mq 解耦 系统耦合度降低&#xff0c;没有强依赖…

DeiT:训练ImageNet仅用4卡不到3天的平民ViT | ICML 2021

论文基于改进训练配置以及一种新颖的蒸馏方式&#xff0c;提出了仅用ImageNet就能训练出来的Transformer网络DeiT。在蒸馏学习时&#xff0c;DeiT以卷积网络作为teacher&#xff0c;能够结合当前主流的数据增强和训练策略来进一步提高性能。从实验结果来看&#xff0c;效果很不…

基于FMC接口的Kintex-7 XC7K325T PCIeX4 3U PXIe接口卡

基于FMC接口的Kintex-7 XC7K325T PCIeX4 3U PXIe接口卡 一、板卡概述 本板卡基于Xilinx公司的FPGAXC7K325T-2FFG900 芯片&#xff0c;pin_to_pin兼容FPGAXC7K410T-2FFG900 &#xff0c;支持PCIeX8、64bit DDR3容量2GByte&#xff0c;HPC的FMC连接器&#xff0c;板卡支持PXI…

html基础——CSS

在HTML中&#xff0c;CSS的作用是用于控制网页的样式&#xff0c;包括字体、颜色、背景、布局等方面的设计。通过一个样例来说明CSS的作用&#xff1a; 如下是一个名为global.css的CSS文件&#xff1a; .C1{font-size: 10px;color: blue;border:1px solid red;height: 200px;…

【Redis 神秘大陆】009 案例实践进阶

九、案例实践&进阶方案 9.1 本地缓存组件选型 使用缓存组件时需要重点关注集群方式、集群、缓存命中率。 需要关注集群组建方式、缓存统计&#xff1b;还需要考虑缓存开发语言对缓存的影响&#xff0c;如对于JAVA开发的缓存需要考虑GC的影响&#xff1b;最后还要特别关注…

vue快速入门(二十六)生命周期钩子函数

注释很详细&#xff0c;直接上代码 上一篇 新增内容 生命周期钩子函数的解析生命周期函数效果演示 源码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevic…

【原创教程】海为PLC与RS-WS-ETH-6传感器的MUDBUS_TCP通讯

一、关于RS-WS-ETH-6传感器的准备工作 要完成MODBUS_TCP通讯,我们必须要知道设备的IP地址如何分配,只有PLC和设备的IP在同一网段上,才能建立通讯。然后还要选择TCP的工作模式,来建立设备端和PC端的端口号。接下来了解设备的报文格式,方便之后发送报文完成数据交互。 1、…

【Altium Designer 20 笔记】PCB层

Top Overlay & Bottom Overlay (顶部丝印层和底部丝印层)&#xff1a; 用于标记元件、连接和其他重要信息。丝印层是 PCB 表面的一层&#xff0c;上面印上文字、图标或标记。 Top Solder & Bottom Solder (顶部阻焊层和底部阻焊层)&#xff1a; 阻焊层、开窗层、绿油层…

【电控笔记2.3】速度回路+系统延迟

2.3.1速度回路pi控制器设计 pi伯德图近似设计(不考虑延时理想情况下) Tl:负载转矩 PI控制器的转折频率:Ki/Kp

金融数字化能力成熟度指引

1 范围 本文件提出了金融数字化能力成熟度模型、成熟度计算方法&#xff0c;明确了不同维度金融数字化转型能力 相应的分档要求。 本文件适用于金融机构衡量金融科技应用和数字化转型发展水平&#xff0c;检视自身数字化发展优势与短板&#xff0c; 加快数字化转型&#xff0c…

Fatal error in launcher: Unable to create process using【解决方案】

拷贝python 项目到其他电脑以后&#xff0c;执行pip list 命令时报如下错误&#xff1a; Fatal error in launcher: Unable to create process using ‘“d:\python37\python.exe” “C:\Python\Scripts\pip.exe” list’: ??? 解决方法&#xff1a; 先试这条&#xff1a; …

什么是One-Class SVM

1. 简介 单类支持向量机&#xff0c;简称One-Class SVM(One-Class Support Vector Machine)&#xff0c;用于异常检测和离群点检测(无监督学习&#xff0c;其他svm属于有监督的)&#xff0c;可以在没有大量异常样本的情况下有效地检测异常。其目标是通过仅使用正常数据来建模&a…

Gin框架小结

Gin 简介 Gin是一个轻量级的Web框架&#xff0c;用于构建高性能的Go语言Web应用程序。提供了路由管理、中间件支持、参数绑定和验证、错误处理、静态文件服务等功能。 Gin框架解决了什么问题和痛点 1.golang http 标准库本身提供了比较简单的路由注册能力&#xff0c;只支持…

企业内部知识库:帮助你提高工作效率的好帮手

在现代企业中&#xff0c;知识和信息是一种无形资产&#xff0c;对企业的成长至关重要。员工之间有效地共享知识&#xff0c;可以大幅提高工作效率和团队的整体执行力。为了实现这一点&#xff0c;越来越多的企业开始构建自己的内部知识库&#xff0c;为员工提供一个集中的信息…

华为服务Fellow、首席项目管理专家,华为H5M项目管理标准制定主导者孙虎受邀为PMO大会演讲嘉宾

全国PMO专业人士年度盛会 华为服务Fellow、首席项目管理专家&#xff0c;华为H5M项目管理标准制定主导者孙虎先生受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾&#xff0c;演讲议题为“落地项目管理标准&#xff0c;打赢班长的战争”。大会将于5月25-26日在北京举办&am…

液晶触摸屏中应用的电容式触摸芯片

随着多媒体信息查询的与日俱增&#xff0c;人们越来越多地谈到触摸屏&#xff0c;因为触摸屏不仅适用于中国多媒体信息查询的国情&#xff0c;而且触摸屏具有坚固耐用、反应速度快、节省空间、易于交流等许多优点。利用这种技术&#xff0c;用户只要用手指轻轻地碰计算机显示屏…

PCL 高斯滤波(C++详细过程版)

目录 一、概述二、代码实现三、结果展示1、滤波前2、滤波后3、对比PCL 高斯滤波(C++详细过程版)由CSDN点云侠原创,爬虫自重。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫。 一、概述 高斯滤波在PCL里有现成的调用函数,具体算法原理和实现代码见: