MoE模型性能还能更上一层楼?一次QLoRA微调实践

Fine-Tuning Mixtral 8x7B with QLoRA:Enhancing Model Performance 🚀

编者按:最近,混合专家(Mixture of Experts,MoE)这种模型设计策略展现出了卓越的语言理解能力,如何在此基础上进一步提升 MoE 模型的性能成为业界热点。

本文作者使用一种名为 QLoRA 的方法,通过量化和 LoRA 技术对 MoE 模型 Mixtral-8x7B 进行微调,以期大幅提高其性能。

作者详细阐明这种方法的诸多优势,包括显著增强 MoE 模型的理解生成能力、计算效率更高等。文中还逐步介绍了使用 QLoRA 微调 Mixtral-8x7B 的全过程。

本文探索了使用 QLoRA 推动 MoE 模型的性能改进这一技术方案。期待未来更多关于 MoE 模型的性能改进方案出现!

一、简介

目前整个业界都希望经过优化的模型能够表现出卓越的性能,这一追求不断推动着自然语言理解(natural language understanding)的发展。Mixtral-8x7B Mixture of Experts(MoE)模型就是其中之一,该模型在各种基准测试(benchmarks)中表现出优于同类产品的性能,尤其是优于 Llama 2 70B。

本教程采用一种名为 QLoRA 的创新方法对 Mixtral-8x7B 模型进行微调,该方法结合了量化(quantization)和 LoRA(Local Representation Adaptation)技术。期望通过这两种技术的结合来进一步增强Mixtral-8x7B模型的能力。

image.png

Source: Mixtral[1]

二、相关定义

● Mixtral 8x7B:一种混合专家模型,因其架构设计在自然语言处理任务中表现出色而闻名。

● QLoRA:Quantization 和 LoRA 技术相结合的缩写。量化涉及降低模型权重的精度,从而优化内存使用并加快计算速度。LoRA 可调整模型中的局部表征,增强模型对特定上下文的理解。

三、优势

● 增强性能:使用 QLoRA 对 Mixtral 8x7B 进行微调,可提高其性能,从而更好地理解和生成各种领域的文本。

● 能效比高:量化的整合降低了内存需求和计算复杂度,使模型更节省资源。

● 针对垂直领域进行微调:通过微调,该模型可针对特定任务进行定制,从而提高其在特定领域的准确性和相关性。

四、代码实现说明

本教程在 Notebook 环境中(译者注:使用Jupyter notebook 或白海IDP自研notebook)使用 Python。整个过程包括使用 "bitsandbytes "库加载 4 位精度的大型 Mixtral 模型。随后,在训练阶段使用 Hugging Face 的 PEFT 库实现 LoRA。

4.1 步骤 1:安装相关库

# You only need to run this once per machine, even if you stop/restart it
!pip install --upgrade pip
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U datasets scipy ipywidgets matplotlib

4.2 步骤 2:设置 Accelerator

from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

4.3 步骤 3:使用Weights & Biases追踪性能指标

!pip install -q wandb -U

import wandb, os
wandb.login()

wandb_project = "viggo-finetune"
if len(wandb_project) > 0:
    os.environ["WANDB_PROJECT"] = wandb_project

4.4 步骤 4:加载数据集

from datasets import load_dataset

dataset_name = "databricks/databricks-dolly-15k"

train_dataset = load_dataset(dataset_name, split="train[0:800]")
eval_dataset = load_dataset(dataset_name, split="train[800:1000]")

4.5 步骤 5:加载基础模型

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mixtral-8x7B-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")

# Tokenization 
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token

def tokenize(prompt):
    result = tokenizer(prompt)
    result["labels"] = result["input_ids"].copy()
    return result

def generate_and_tokenize_prompt(data_point):
    full_prompt = f"""Given a question and some additional context, provide an answer

    ### Target sentence:
    Question: {data_point['instruction']}
    Additional Context: {f"Here is some context: {data_point['context']}" if len(data_point["context"]) > 0 else ""}
    Response: [/INST] {data_point['response']}</s>"""

    tokenized_prompt = tokenizer(full_prompt)
    return tokenized_prompt

tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)

untokenized_text = tokenizer.decode(tokenized_train_dataset[1]['input_ids']) 
print(untokenized_text)

# Output
<s> Given a question and some additional context, provide an answer

    ### Target sentence:
    Question: Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?
    Additional Context: 
    Response: [/INST] The name of the third daughter is Alice</s></s>

4.6 步骤 6:获取数据集中各个样本长度的分布情况

import matplotlib.pyplot as plt

def plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset):
    lengths = [len(x['input_ids']) for x in tokenized_train_dataset]
    lengths += [len(x['input_ids']) for x in tokenized_val_dataset]
    print(len(lengths))

    # Plotting the histogram
    plt.figure(figsize=(10, 6))
    plt.hist(lengths, bins=20, alpha=0.7, color='blue')
    plt.xlabel('Length of input_ids')
    plt.ylabel('Frequency')
    plt.title('Distribution of Lengths of input_ids')
    plt.show()

plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

image.png

Source: Image created by Author

4.7 步骤 7:在数据的左侧添加 padding ,以减少内存的使用

max_length = 320 # This was an appropriate max length for my dataset

# redefine the tokenize function and tokenizer

tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left",
    add_eos_token=True,  
    add_bos_token=True,  
)
tokenizer.pad_token = tokenizer.eos_token


def tokenize(prompt):
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result

tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)

untokenized_text = tokenizer.decode(tokenized_train_dataset[4]['input_ids']) 
print(untokenized_text)

# Output
<s> Given a target sentence construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.
    This function should describe the target string accurately and the function must be one of the following ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute'].
    The attributes must be one of the following: ['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating', 'genres', 'player_perspective', 'has_multiplayer', 'platforms', 'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier']

    ### Target sentence:
    When did Virgin Australia start operating?
    Here is some context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.
    [/INST] Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.</s></s>
plot_data_lengths(tokenized_train_dataset, tokenized_val_dataset)

image.png

Source: Image created by Author

4.8 步骤 8:设置 LoRA

from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "w1",
        "w2",
        "w3",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

# Apply the accelerator. You can comment this out to remove the accelerator.
model = accelerator.prepare_model(model)

# Output
trainable params: 120350720 || all params: 23602952192 || trainable%: 0.5098968934945001

4.9 步骤 9:进行训练

import transformers
from datetime import datetime

if torch.cuda.device_count() > 1: # If more than 1 GPU
    model.is_parallelizable = True
    model.model_parallel = True

project = "databricks-dolly-finetune"
base_model_name = "mixtral"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

tokenizer.pad_token = tokenizer.eos_token

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
    args=transformers.TrainingArguments(
        output_dir=output_dir,
        warmup_steps=5,
        per_device_train_batch_size=1,
        gradient_checkpointing=True,
        gradient_accumulation_steps=4,
        max_steps=500,
        learning_rate=2.5e-5, 
        logging_steps=25,
        fp16=True, 
        optim="paged_adamw_8bit",
        logging_dir="./logs",        # Directory for storing logs
        save_strategy="steps",       # Save the model checkpoint every logging step
        save_steps=50,                # Save checkpoints every 50 steps
        evaluation_strategy="steps", # Evaluate the model every logging step
        eval_steps=50,               # Evaluate and save checkpoints every 50 steps
        do_eval=True,                # Perform evaluation at the end of training
        report_to="wandb",           # Comment this out if you don't want to use weights & baises
        run_name=f"{run_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"          # Name of the W&B run (optional)
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

4.10 步骤 10:使用训练完毕的模型

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mixtral-8x7B-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,  # Mixtral, same as before
    quantization_config=bnb_config,  # Same quantization config as before
    device_map="auto",
    trust_remote_code=True,
    use_auth_token=True
)

eval_tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    trust_remote_code=True,
)
from peft import PeftModel

ft_model = PeftModel.from_pretrained(base_model, "mixtral-databricks-dolly-finetune/checkpoint-100")
eval_prompt = """Given a question and some additional context, provide an answer

### Target sentence:
Question: When was Tomoaki Komorida born?
Here is some context: Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. He played often as a defensive midfielder and center back until 2012 when he retired.

### Response:
"""

model_input = eval_tokenizer(eval_prompt, return_tensors="pt").to("cuda")

ft_model.eval()

with torch.no_grad():
    print(eval_tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=100)[0], skip_special_tokens=True))

Given a question and some additional context, provide an answer

### Target sentence:
Question: When was Tomoaki Komorida born?
Here is some context: Komorida was born in Kumamoto Prefecture on July 10, 1981. After graduating from high school, he joined the J1 League club Avispa Fukuoka in 2000. Although he debuted as a midfielder in 2001, he did not play much and the club was relegated to the J2 League at the end of the 2001 season. In 2002, he moved to the J2 club Oita Trinita. He became a regular player as a defensive midfielder and the club won the championship in 2002 and was promoted in 2003. He played many matches until 2005. In September 2005, he moved to the J2 club Montedio Yamagata. In 2006, he moved to the J2 club Vissel Kobe. Although he became a regular player as a defensive midfielder, his gradually was played less during the summer. In 2007, he moved to the Japan Football League club Rosso Kumamoto (later Roasso Kumamoto) based in his local region. He played as a regular player and the club was promoted to J2 in 2008. Although he did not play as much, he still played in many matches. In 2010, he moved to Indonesia and joined Persela Lamongan. In July 2010, he returned to Japan and joined the J2 club Giravanz Kitakyushu. He played often as a defensive midfielder and center back until 2012 when he retired.

### Response:
Tomoaki Komorida was born on July 10, 1981.

五、结论

利用 QLoRA 对 Mixtral-8x7B 模型进行微调是自然语言处理 (NLP) 领域的一个重要进展,它将模型性能提升到了新的高度。这一缜密的过程融合了量化和 LoRA 等前沿技术,为超越基准(benchmarks)提供了一条稳健的途径,甚至在各种评估指标上超越了强大的 Llama 2 70B 模型。

本教程的核心在于使用QLoRA进行微调,利用bitsandbytes以4位精度实例化模型,并运用Hugging Face 🤗的PEFT库。该指南不仅概述了微调方法,还揭示了实践过程中可能遇到的问题,如OutOfMemory errors,为用户提供了精确的解决途径。

从本质上讲,该教程并非是一个技术指南,更像一个倡导模型微调最佳实践的指引。它倡导协作式微调,请邀请其他研究人员和从业者一同踏上推动语言理解模型发展的旅程。

前沿技术、详细的指导以及合作共赢的态度使得该教程对于NLP社区来说是一个非常重要且不可或缺的资源,期望能够引导 NLP 社区进一步提高模型性能,丰富理解能力。

Resources:

● Mixtral-8x7b[2]

● Thanks to Harper Carroll[2]

文中链接

[1]https://mistral.ai/news/mixtral-of-experts/

[2]https://huggingface.co/blog/mixtral

[3]https://twitter.com/HarperSCarroll

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/313270.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SPDK中常用的性能测试工具

本文主要介绍磁盘性能评估的方法&#xff0c;针对用户态驱动Kernel与SPDK中各种IO测试工具的使用方法做出总结。其中fio是一个常用的IO测试工具&#xff0c;可以运行在Linux、Windows等多种系统之上&#xff0c;可以用来测试本地磁盘、网络存储等的性能。为了和SPDK的fio工具相…

【RabbitMQ】RabbitMQ高级:如何保证消息可靠性

目录 概述异常捕获机制事务机制持久化存储机制发送端确认机制概述开启发布确认的方法单个发布确认批量发布确认异步发布确认 消费端确认机制消息限流消息幂等性处理 概述 前面学习了如何简单使用RabbitMQ&#xff0c;在实际使用RabbitMQ时&#xff0c;我们还需要考虑很多&…

Vue.js设计与实现阅读-2

Vue.js设计与实现阅读-2 1、前言2、框架设计的核心要素2、1 提升用户体验2、2 控制代码体积2、3 Tree-Shaking2、4 特性开关2、5 错误处理 1、前言 上一篇我们了解到了 命令式和声明式的区别&#xff0c;前者关注过程&#xff0c;后者关注结果了解了虚拟dom存在的意义&#x…

数据库SELECT语句

文章目录 一、检索数据二、排序检索三、过滤数据四、数据过滤4.1 组合WHERE子句1. AND操作符2. OR操作符3. 计算次序 4.2 IN操作符4.3 NOT操作符 五、用通配符过滤LIKE操作符1. 百分号&#xff08;%&#xff09;通配符2. 下划线&#xff08;_&#xff09;通配符 使用通配符的技…

盈利之道:下单前的必问之问

投资者在过去的交易经历中&#xff0c;通常都会面临所谓的“交易低谷”。交易低谷是指在交易过程中难以实现盈利或可能导致进一步亏损的阶段。这种面临损失或没有盈利的时期可能发生在任何人身上&#xff0c;无论是由于市场变化、投资者策略调整还是其他原因。为了应对这种情况…

CSS基础方法——引入方式、属性、基础选择器

CSS 主要用于设置 HTML 页面中的文本样式&#xff08;字体、大小、颜色、对齐方式……&#xff09;、图片样式&#xff08;宽高、边框样式、边距……&#xff09;以及版面的布局和外观显示样式。 1、CSS引入方式 行内样式 写在标签中&#xff0c;通常不使用&#xff0c;只做…

优惠券兑换码生成需求——事务同步回调问题分析

前段时间收到一个优惠券兑换码的需求&#xff1a;管理后台针对一个优惠券发起批量生成兑换码&#xff0c;这些兑换码可以导出分发到各个合作渠道&#xff08;比如&#xff1a;抖音、京东等&#xff09;&#xff0c;用户通过这些渠道获取到兑换码之后&#xff0c;再登录到我司研…

提升测试效率,轻松并行运行测试——探秘Pytest插件pytest-xdist

在软件开发中&#xff0c;测试是确保代码质量的重要一环。然而&#xff0c;随着项目规模的增大&#xff0c;测试用例的数量也随之增多&#xff0c;测试的执行时间可能成为一个瓶颈。为了解决这个问题&#xff0c;Pytest提供了丰富的插件生态系统&#xff0c;其中 pytest-xdist …

黑群晖6.x 7.x ABB Active Backup for Business 套件激活方法

注意事项&#xff1a; 要先下载安装好Active Backup for Business套件再操作。SN码在【控制面板】 - 【信息中心】 -【产品序列号】。建议复制到记事本内修改内容。群晖的https是默认的5001端口&#xff0c;如果你的https端口号换过请自行修改&#xff1a;5001 为当前的端口号…

spacedesk 变成黑白的分析

测试发现只要调整时间到2024 就会出现黑白而且是建立连接是才检测的&#xff0c;那么应该存在于R3部分的可能性大 IDA分析找到2024

[论文阅读]4DRadarSLAM: A 4D Imaging Radar SLAM System for Large-scale Environments

目录 1.摘要和引言&#xff1a; 2. 系统框架&#xff1a; 2.1 前端&#xff1a; 2.2 回环检测&#xff1a; 2.3 后端&#xff1a; 3.实验和分析&#xff1a; 4.结论 1.摘要和引言&#xff1a; 这篇论文介绍了一种名为“4DRadarSLAM”的新型4D成像雷达SLAM系统&#xff0…

RT-DETR优化:UNetv2多层次特征融合模块结合DualConv、GSConv

🚀🚀🚀本文改进:多层次特征融合(SDI)结合DualConv、GSConv模块等实现二次创新 🚀🚀🚀SDI 亲测在多个数据集能够实现涨点,同样适用于小目标检测 🚀🚀🚀RT-DETR改进创新专栏:http://t.csdnimg.cn/vuQTz 学姐带你学习YOLOv8,从入门到创新,轻轻松松搞定…

Vue、uniApp、微信小程序、Html5等实现数缓存

此文章带你实现前端缓存&#xff0c;利用时间戳封装一个类似于Redis可以添加过期时间的缓存工具 不仅可以实现对缓存数据设置过期时间&#xff0c;还可以自定义是否需要对缓存数据进行加密处理 工具介绍说明 对缓存数据进行非对称加密处理 对必要数据进行缓存&#xff0c;并…

太平洋产险海南分公司:春季爱车保养,就看这几点!

一年之计在于春&#xff0c;春天不仅是万物复苏的好时节&#xff0c;也是一年中非常适合汽车养护的季节。 刚刚过去的春节&#xff0c;汽车的使用频率大大增加&#xff0c;很多车主都准备对爱车进行一次全面保养。加上立春过后&#xff0c;天气渐暖&#xff0c;许多车主也计划开…

答题小程序源码系统:自带流量主广告位+视频激励广告 带完整的代码安装包以及搭建教程

随着互联网的迅速发展&#xff0c;各种应用程序层出不穷&#xff0c;而答题类小程序由于其独特的互动性和吸引力&#xff0c;成为了当前最热门的应用之一。答题小程序源码系统是一款基于微信小程序开发的源代码系统&#xff0c;它具有丰富的功能和灵活的定制性&#xff0c;可以…

搭建算法日志自检小系统

&#x1f952; 前言 目前演示的是一个工具&#xff0c;但如此&#xff0c;未来完成有潜力可以演变为一整套系统。 &#x1f451;现场人员自检失败表计点位教程V2.0 NOTE: 如果没有“logfiles-meter-tool“目录的请联系我们进行提供&#xff01; &#x1f447; 进入<dist>…

使用AutoDL云计算平台训练并测试Pytorch版本NeRF代码

文章目录 前言一、数据集及代码获取二、租用并设置服务器三、Pycharm远程开发四、训练并测试代码 前言 因为第一次在云服务器上跑代码&#xff0c;所以在这里记录一下。 一、数据集及代码获取 nerf-pytorch项目是 NeRF 的忠实 PyTorch 实现&#xff0c;它在运行速度提高 1.3 倍…

docker 利用特权模式逃逸并拿下主机

docker 利用特权模式逃逸并拿下主机 在溯源反制过程中&#xff0c;会经常遇到一些有趣的玩法&#xff0c;这里给大家分享一种docker在特权模式下逃逸&#xff0c;并拿下主机权限的玩法。 前言 在一次溯源反制过程中&#xff0c;发现了一个主机&#xff0c;经过资产收集之后&…

SSL证书与HTTPS的关系

SSL证书是一种数字证书&#xff0c;由权威的证书颁发机构颁发。它包含了一个公钥和有关证书所有者的一些信息&#xff0c;如名称、组织、邮箱等。SSL证书的主要作用是实现数据加密和身份验证&#xff0c;确保数据在传输过程中的安全性和完整性。 HTTPS是一种基于HTTP协议的安全…