使用DPO微调大模型Qwen2详解

简介

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。但传统的RLHF比较复杂,且还需要奖励模型,故DPO方法被提出,其将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。
且huggingface的trl库已经集成了dpo,使用起来非常方便。

本次以QWEN2(蹭热点),为例进行训练,特别注意一点,qwen没有bos_token,要设置一下,不然dpo train时会报错。
分别介绍单轮对话的DPO多轮对话的DPO,对应的数据集分别如下(均在huggingface):

  • 单轮:lvwerra/stack-exchange-paired
  • 多轮:trl-internal-testing/hh-rlhf-helpful-base-trl-style

通过DPO微调模型大概可以简单的分为两个步骤:
1、将数据处理成所需格式。
2、使用DPOTrainer进行训练

两种形式的dpo代码已集成至github上的大模型训练框架,支持框架中的deepspeed(多卡,单卡)或者python(单卡)启动模式,相应的lora、qlora也支持。并做了详细的使用解释及代码位置说明,可见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

项目包括一个每个人都可以以此为基础构建自己的开源大模型训练框架流程、支持主流模型使用deepspeed进行Lora、Qlora、DPO等训练、主流模型的chat template模版、以及一些tricks的从零实现模块。欢迎大家star 共同学习!:

单轮对话构建DpoDataset

标准的DpoDataset数据集,最终的数据集对象应包含这3个条目。条目应命名为:

  • prompt
  • chosen
  • rejected

官方示例

单轮官方示例如下:

dpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}

多轮示例为上述提到的数据集,大家可以大概看一下是长这个样子:
在这里插入图片描述

从头开始构建

比较简单的方式是套用官方给的示例,如下所示,只需要将数据集映射为上述我们提到的prompt、chosen、rejected格式,此时传递给DPOTrainer的数据是未编码之前的,DPOTrainer中会自动的给我们进行编码。注意下面并没有添加对应模型的chat template,根据不同模型的template可以在return_prompt_and_responses中自行添加即可。

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"], # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)


dpo_trainer = DPOTrainer(
    model, # 经 SFT 的基础模型
    model_ref, # 一般为经 SFT 的基础模型的一个拷贝
    beta=0.1, # DPO 的温度超参
    train_dataset=dataset, # 上文准备好的数据集
    tokenizer=tokenizer, # 分词器
    args=training_args, # 训练参数,如: batch size, 学习率等
)

为了便于我们理解数据处理细节及进行一些魔改操作,我们可以从头自己构建一个DpoDataset。
首先,深入DPOTrainer源码可以看到其数据处理操作主要是在tokenize_row函数,如下所示,
在这里插入图片描述
最终返回的是一个batch字典字段,代码部分如下所示:
在这里插入图片描述
在这里插入图片描述
最终返回的字段为:

dict(
            prompt_input_ids,
            prompt_attention_mask,
            chosen_input_ids,
            chosen_attention_mask,
            chosen_labels,
            rejected_input_ids,
            rejected_attention_mask,
            rejected_labels,
        )

主要的__getitem__代码如下所示:

    def __getitem__(self, item):
        data = self.data_list[item]
        data = json.loads(data)  # 将json格式转换为python字典
        prompt =  data['prompt']
        chosen = data['chosen']
        rejected = data['rejected']
        # 对prompt进行编码
        prompt = self.user_format.format(content=prompt, stop_token=self.tokenizer.eos_token)
        if self.system_format is not None:
            system = self.system
            if system is not None:
                system_text = self.system_format.format(content=system)
                input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)
                prompt_input_ids = input_ids + self.tokenizer.encode(prompt, add_special_tokens=False)
        else:
            prompt_input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)



        # 进行回答的input id编码
        chosen = self.assistant_format.format(content=chosen, stop_token=self.tokenizer.eos_token)
        rejected = self.assistant_format.format(content=rejected, stop_token=self.tokenizer.eos_token)

        chosen_input_ids = self.tokenizer.encode(chosen, add_special_tokens=False)
        rejected_input_ids = self.tokenizer.encode(rejected, add_special_tokens=False)

        # 对最大长度进行截断
        longer_response_length = max(len(chosen_input_ids), len(rejected_input_ids))
        # keep end 对prompt截断
        if len(prompt_input_ids) + longer_response_length > self.max_seq_length:
            max_prompt_length = max(self.max_prompt_length, self.max_seq_length - longer_response_length)
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        # 如果还不符合则回答截断
        if len(prompt_input_ids) + longer_response_length > self.max_seq_length:
            chosen_input_ids = chosen_input_ids[: self.max_seq_length - len(prompt_input_ids)]
            rejected_input_ids = rejected_input_ids[: self.max_seq_length - len(prompt_input_ids)]

        chosen_labels = [-100] * len(prompt_input_ids) + chosen_input_ids
        chosen_input_ids = prompt_input_ids + chosen_input_ids
        rejected_labels = [-100] * len(prompt_input_ids) + rejected_input_ids
        rejected_input_ids = prompt_input_ids + rejected_input_ids
        assert len(chosen_labels) == len(chosen_input_ids)
        assert len(rejected_labels) == len(rejected_input_ids)

        inputs = dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=[1] * len(prompt_input_ids),
            chosen_input_ids=chosen_input_ids,
            chosen_attention_mask=[1] * len(chosen_input_ids),
            chosen_labels=chosen_labels,
            rejected_input_ids=rejected_input_ids,
            rejected_attention_mask=[1] * len(rejected_input_ids),
            rejected_labels=rejected_labels,
        )
        return inputs

适配DPOTrainer

构建完dataset后要适配DPOTrainer,可以看到其需要使用dataset进行一个map操作,这也就是DPOTrainer自动给我们处理数据的入口。
在这里插入图片描述
在我们自建的Dataset类中添加一个map函数映射会self即可

    def map(self, func, **kwargs):
        return self

多轮对话构建DpoDataset

多轮对话构建我们这里就不自己去写了,直接采用DPOTrainer中自带的数据处理即可。
部分代码如下所示:

        if tokenizer.chat_template is None:
            tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
        train_dataset = load_dataset(data_files=args.train_data_path, path='json')

        def process(row):
            row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
            row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
            return row

        train_dataset = train_dataset.map(process)
        train_dataset = train_dataset['train']
        return train_dataset

完整代码集成至github项目中,具体可参见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

开始Qwen2-8B 多轮和单轮DPO训练

使用DPOTrainer即可开始训练

trainer = DPOTrainer(
            model,
            ref_model=None,
            args=train_args,
            train_dataset=train_dataset,
            tokenizer=tokenizer,
            peft_config=peft_config
        )
dpo_trainer.train()
dpo_trainer.save_model()

总结

两种形式的dpo代码已集成至github上的大模型训练框架,并做了详细的使用解释及代码位置说明,支持框架中的deepspeed(多卡,单卡)或者python(单卡)启动模式,相应的lora、qlora也支持。可见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/697524.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OSPF LSA头部详解

LSA概述 LSA是OSPF的本质 , 对于网工来说能否完成OSPF的排错就是基于OSPF的LSDB掌握程度 . 其中1/2类LAS是负责区域内部的 类似于设备的直连路由 . 加上对端的设备信息 3 类LSA是区域间的 指的是Area0和其他Area的区域间关系 , 设计多区域的初衷就是避免大型OSPF环境LSA太多…

AMD在行动:揭示应用程序跟踪和性能分析的力量

AMD in Action: Unveiling the Power of Application Tracing and Profiling — ROCm Blogs 导言 Rocprof是一款强大的工具,设计用于分析和优化基于AMD ROCm平台上运行的HIP程序的性能,帮助开发者找出并解决性能瓶颈。Rocprof提供了多种性能数据&#x…

生成树协议(思科)

#交换设备 生成树协议(STP) 目的 1.理解生成树的原理 理解STP的选举过程 2.会配置STP 为什么只有交换机0的f0/1接口变成了阻塞状态? 在环形的交换网络中,如果所有的接口都通畅,会形成闭回路,造成网路风暴 一、STP…

【优选算法】字符串

一、相关编程题 1.1 最长公共前缀 题目链接 14. 最长公共前缀 - 力扣&#xff08;LeetCode&#xff09; 题目描述 算法原理 编写代码 // 解法一&#xff1a;两两比较 class Solution { public:string longestCommonPrefix(vector<string>& strs) {int k strs[0…

《QT实用小工具·七十》openssl+qt开发的P2P文件加密传输工具

1、概述 源码放在文章末尾 该项目实现了P2P的文件加密传输功能&#xff0c;具体包含如下功能&#xff1a; 1、 多文件多线程传输 2、rsaaes文件传输加密 3、秘钥随机生成 4、断点续传 5、跨域传输引导服务器 项目界面如下所示&#xff1a; 接收界面 发送界面 RSA秘钥生成…

CTF-PWN-kernel-UAF

文章目录 参考slub 分配器kmem_cache_cpukmem_cache_node[ ]冻结和解冻分配释放 fork绑核Kmalloc flag和slub隔离CISCN - 2017 - babydriver检查babtdriver_initstruct cdevalloc_chrdev_regioncdev_initownercdev_add_class_createdevice_create babyopenbabyreleasebabyreadb…

CleanMyMac2024最新免费电脑Mac系统优化工具

大家好&#xff0c;我是你们的好朋友——软件评测专家&#xff0c;同时也是一名技术博主。今天我要给大家种草一个超级实用的Mac优化工具——CleanMyMac&#xff01; 作为一个长期使用macOS的用户&#xff0c;我深知系统运行时间长了&#xff0c;缓存文件、日志、临时文件等都会…

数据库管理-第200期 身边的数据库从业者(20240610)

数据库管理200期 2024-06-10 数据库管理-第200期 身边的数据库从业者&#xff08;20240610&#xff09;首席-薛晓刚院长-施嘉伟邦德-王丁丁强哥-徐小强会长-吴洋灿神-熊灿灿所长-严少安探长-张震总结活动预告 数据库管理-第200期 身边的数据库从业者&#xff08;20240610&#…

**《Linux/Unix系统编程手册》读书笔记24章**

D 24章 进程的创建 425 24.1 fork()、exit()、wait()以及execve()的简介 425 . 系统调用fork()允许父进程创建子进程 . 库函数exit(status)终止进程&#xff0c;将进程占用的所有资源归还内核&#xff0c;交其进行再次分配。库函数exit()位于系统调用_exit()之上。在调用fo…

HTML开发的最主要的三种框架及Python实现

一、介绍 HTML本身是一种标记语言&#xff0c;用于构建网页的结构。然而&#xff0c;当谈到HTML开发框架时&#xff0c;通常指的是那些提供了额外的功能和工具&#xff0c;以帮助开发者更高效地构建网页和应用程序的框架。有三种流行的HTML开发框架&#xff1a; Bootstrap 简介…

基于JSP技术的网络视频播放器

你好呀&#xff0c;我是计算机学长猫哥&#xff01;如果有相关需求&#xff0c;文末可以找到我的联系方式。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;IDEA/Eclipse、Navicat、Maven 系统展示 首页 管理员界面 用户界…

网络分析(ArcPy)

一.前言 GIS中的网络分析最重要的便是纠正拓扑关系&#xff0c;建立矫正好的网络数据集&#xff0c;再进行网络分析&#xff0c;一般大家都是鼠标在arcgis上点点点&#xff0c;今天说一下Arcpy来解决的方案&#xff0c;对python的要求并不高,具体api参数查询arcgis帮助文档即可…

渗透测试模拟实战(二)-BlueCMS平台

渗透测试 渗透测试是维护网络安全的重要组成部分&#xff0c;可以帮助组织识别并修复潜在的安全漏洞&#xff0c;减少被恶意攻击的风险。然而&#xff0c;进行渗透测试时必须遵守法律和道德规范&#xff0c;确保所有活动都在授权范围内进行。 环境部署&#xff1a; study2016、…

逆序队专题

逆序对的定义是&#xff0c;在一个数组中&#xff0c;对于下标 ( i ) 和 ( j )&#xff08;其中 ( i < j )&#xff09;&#xff0c;如果 ( a[i] > a[j] )&#xff0c;则称 ((a[i], a[j])) 为数组的一个逆序对。 换句话说&#xff0c;逆序对就是在数组中前面的元素大于后…

分布式事务AP控制方案(上)

分布式事务控制方案 本篇文章给出一种要求高可用性&#xff08;AP思想&#xff09;的分布式事务控制方案 下篇新鲜出炉&#xff1a;点我查看 分布式事务控制方案1、业务背景2、本地消息表的设计3、对消息表的操作4、任务调度5、任务流程控制的抽象类6、课程发布的实现类7、总…

【C++】C++ QT实现Huffman编码器与解码器(源码+课程论文+文件)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

Vue17-条件渲染

一、使用v-show属性做条件渲染 控制元素的显示和隐藏 v-show里面也能是表达式&#xff0c;只要表达式的值是boolean就行。 或者 当时结构还在&#xff1a; 二、使用v-if属性做条件渲染 结构也不在了 三、示例 方式一&#xff1a; 方式二&#xff1a; 当元素有很高的切换频率&am…

机器学习实验----支持向量机(SVM)实现二分类

目录 一、介绍 (1)解释算法 (2)数据集解释 二、算法实现和代码介绍 1.超平面 2.分类判别模型 3.点到超平面的距离 4.margin 间隔 5.拉格朗日乘数法KKT不等式 (1)介绍 (2)对偶问题 (3)惩罚参数 (4)求解 6.核函数解决非线性问题 7.SMO (1)更新w (2)更新b 三、代…

我在得物的这两年

写在前面 这篇文章非常简单&#xff0c;和大家简单聊聊我在得物的这两年&#xff0c;也是从学生到社会人的这两年。 我是2022年的6月加入得物实习&#xff0c;负责某个业务中台的后端研发&#xff0c;那一年我21岁&#xff0c;还在读大三&#xff0c;还在迷茫未来是读研还是工…

nw.js 如何调用activeX控件 (控件是C++编写的dll文件)

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…