大语言模型的工程技巧(二)——混合精度训练

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
混合精度训练的示例请参考如下链接:regression2chatgpt/ch11_llm/gpt2_lora_optimum.ipynb

本文将讨论如何利用混合精度训练(Mixed Precision Training)来减少内存的开销,特别是GPU内存的开销。这在大语言模型的训练当中是非常重要的。关于GPU的计算可以参考

  • 大语言模型的工程技巧(一)——GPU计算

关于大语言模型的讨论请参考:

  • 理解大语言模型(二)——从零开始实现GPT-2

内容大纲

  • 相关说明
  • 一、概述
  • 二、什么是混合精度训练?
  • 三、算法细节
  • 四、代码实现

一、概述

在人工智能领域,反向传播算法(计算参数梯度的算法)是非常重要的。而在进行反向传播计算时,必须将经过膨胀的计算图存储在内存中(如果使用GPU运算,那么将存储在GPU的专用内存中)。然而,这种存储量相当庞大,在整个计算图的存储结构中,数值存储占据了最大的比例。这些数值包括各个节点的计算结果(来自向前传播的输出),以及相应的梯度(这些梯度是来自反向传播的结果)。虽然梯度累积技术可以通过分解计算图来限制计算图的膨胀,从而降低内存的使用,但面对庞大的模型时,即便是单个数据点的计算图,其所需的内存都是巨大的。例如,大语言模型的参数数量可能高达数十亿甚至上百亿。

二、什么是混合精度训练?

为了解决这个具有挑战性的问题,需要采取额外的优化策略来降低内存的使用。在深入探讨这些策略之前,我们需要更详细地了解数字在计算机中的存储方式。一般而言,数值计算结果使用32位浮点数(需要4字节来存储,使用32位的二进制的方式表示)存储。这种存储方式被称为单精度浮点数。那么,如果使用16位二进制数表示一个数值,会产生什么影响呢?

这种方法的好处之一是能够立即减少所需的存储空间,同时提升计算速度。然而,这种方法也存在一个明显的缺陷,即能够表示的数值范围受限。为了便于讨论,下面以能够表示的最小正数为例。使用16位浮点数,能够表示的最小正数是 2 − 24 2^{-24} 224(相比之下,32位浮点数能够表示的最小正数为 2 − 149 2^{-149} 2149)。当实际的数值小于这个阈值时,计算机会错误地将其视作0,这就是浮点数下溢(Underflow)。

为了尽可能地减少这类错误的发生,可以混合精度训练(Mixed Precision Training)算法,顾名思义,它是指在模型训练过程中使用不同的数值精度来处理不同部分的计算。

三、算法细节

这一算法包含两个主要部分。

  1. 精度分层处理:在这种训练中,模型本身(模型参数)依然使用32位浮点数进行存储,参数更新过程也使用32位浮点数。在模型的向前传播和反向传播过程中,转而使用16位浮点数进行计算。具体情况如图1所示。

图1

图1

  1. 引入比例因子(Scale Factor):在数学上,要防止浮点数下溢是相当容易的,只需要将模型损失乘以一个较大的常数n,该常数也被称为比例因子。根据链式法则,这将导致所有节点的梯度都增大n倍。这种方法确保了梯度落入16位浮点数表示的范围,从而解决浮点数下溢问题。在使用这些梯度进行参数更新时,需要将引入的缩放移除,也就是将梯度除以n。将这个过程与精度分层处理相结合,如图2所示。

图2

图2

混合精度训练方法的优势在于,在保持适当的模型表示能力的同时,显著降低了内存开锁。通过将高精度的32位浮点数与16位浮点数的计算相结合,在不牺牲模型性能的前提下,显著减少内存需求,使计算机能够处理更大规模的模型和数据集。

四、代码实现

在实际应用中,PyTorch已经提供了相应的封装函数,分别是torch.cuda.amp.autocast和torch.cuda.amp.GradScaler。其中autocast实现的是第一部分——精度分层处理;GradScaler实现的是第二部分——引入比例因子。借助这两个工具,在优化算法中使用混合精度训练就变得很容易了。示意代码如下:

# 常规的模型训练实现
for epoch in range(0): 
    for input, target in zip(data, targets):
        # 启动混合精度训练
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # 在触发反向传播之前,启动缩放因子
        scaler.scale(loss).backward()

        # 更新模型参数
        scaler.step(opt)
        scaler.update()

        opt.zero_grad()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635126.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue.js状态管理和服务端渲染

状态管理 vuejs状态管理的几种方式 组件内管理状态:通过data,computed等属性管理组件内部状态 父子组件通信:通过props和自定义事件实现父子组件状态的通信和传递 事件总线eventBus:通过new Vue()实例,实现跨组件通…

个人博客网站开发笔记2

文章目录 前言p2 hexo安装与使用安装 Nodejs安装 GitGit Bash的使用,代码克隆Clone p3 写作一级标题二级标题三级标题四级标题五级标题六级标题 前言 现在继续看教程 p2 hexo安装与使用 link 啊有点难受,开幕就是需要自己先安装Nodejs和Git&#xff…

git使用介绍

一、为什么做版本控制(git是版本控制工具) 为了保留之前所以的版本,以便回滚和修改 二、点击安装 三、基础操作 1、初步认识 想要让git对一个目录进行版本控制需要以下步骤: 进入要管理的文件夹进行初始化命令 git init管理…

el-table 组件实现 “合并单元格 + N行数据小计” 功能

目录 需求 - 要实现的效果初始代码代码升级(可供多个表格使用)CommonTable.vue 子组件 使用子组件1 - 父组件 - 图1~图3使用效果展示 使用子组件2 - 父组件 - 图4使用效果展示 注意【代码优化 - 解决bug】 需求 - 要实现的效果 父组件中 info 数据示例 …

Redis篇 浅谈分布式系统

分布式系统 一. 单机架构二.分布式系统引入三.引入更多的应用服务器四.读写分离五.引入缓存服务器六. 将数据库服务器拆分七.微服务架构 一. 单机架构 单机架构,就是用一台服务器,完成所有的工作. 这时候就需要我们引入分布式系统了. 分布式系统是什么含义呢?就是由一台主机服…

MySQL实战——主从异步复制搭建(一主一从)

一、搭建前的准备 主库 192.168.1.76 从库 192.168.1.77 二、搭建 1、编辑配置文件 vi /etc/my.cnf 主库 [mysqld] log-binmysql-bin server-id1 从库 [mysqld] server-id2 2、在主库创建复制用户 create user repl192.168.1.77 identified by repl123; grant replic…

9、QT—SQLite使用小记

前言 开发平台:Win10 64位 开发环境:Qt Creator 13.0.0 构建环境:Qt 5.15.2 MSVC2019 64位 sqlite版本:sqlite3 文章目录 一、Sqlite是什么二、sqlite使用步骤2.1 下载2.2 安装2.3 使用 三、Qt集成sqlite33.1 关键问题3.2 封装sql…

C#, PCANBasicd.dll库读写CAN设备数据

PCAN-Basic是一个简单的 PCAN 系统编程接口。 通过 PCAN-Basic Dll,可以将自己的应用程序连接到设备驱动程序和 PCAN 硬件,以与 CAN 总线进行通信。支持C、C++、C#、Delphi、JAVA、VB、Python等语言。 PCAN-Basic库和驱动下载地址 ​ ​https://www.peak-system.com/filead…

【C#】未能加载文件或程序集“CefSharp.Core.Runtime.dll”或它的某一个依赖项。找不到指定的模块。

欢迎来到《小5讲堂》 这是《C#》系列文章,每篇文章将以博主理解的角度展开讲解。 温馨提示:博主能力有限,理解水平有限,若有不对之处望指正! 目录 背景错误提示分析原因解决方法Chromium知识点相关文章 背景 最近在使…

LeetCode 131题详解:高效分割回文串的递归与动态规划方法

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容,和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣! 推荐:数据分析螺丝钉的首页 格物致知 终身学习 期待您的关注 导航: LeetCode解锁100…

Shell编程之条件判断语句

目录 一、条件判断 1、test命令 2、文件测试 3、整数值比较 4、字符串判断 5、逻辑测试 二、if语句 1、if单分支语句 2、双分支语句 3、多分之语句 4、case 分支语句 一、条件判断 Shell环境根据命令执行后的返回状态值(echo $?)来判断是否执行成…

力扣刷题---1748.唯一元素的和【简单】

题目描述 给你一个整数数组 nums 。数组中唯一元素是那些只出现 恰好一次 的元素。 请你返回 nums 中唯一元素的 和 。 示例 1: 输入:nums [1,2,3,2] 输出:4 解释:唯一元素为 [1,3] ,和为 4 。 示例 2:…

基于BERT的医学影像报告语料库构建

大模型时代,任何行业,任何企业的数据治理未来将会以“语料库”的自动化构建为基石。因此这一系列精选的论文还是围绕在语料库的建设以及自动化的构建。 通读该系列的文章,犹如八仙过海,百花齐放。非结构的提取无外乎关注于非结构…

电路笔记 :元器件焊接相关 酒精灯松香浴加热取芯片

记录一下只使用松香和小火源加热(如酒精灯、小蜡烛)从电路板中取芯片。 过程 多放松香 让松香淹没芯片尽量均匀加热,等芯片旁边的松香开始从芯片里冒细小的“泡泡”,就差不多了 注:这种方法也可以用于焊接&#xff0…

UBUNTU22.04无法安装nvidia-driver-550 依赖于 nvidia-dkms-550 (<= 550.54.15-1)

类似的报错信息,就是卡在了nvidia-dkms-550无法安装 Loading new nvidia-550.40.07 DKMS files… Building for 6.5.0-15-generic Building for architecture x86_64 Building initial module for 6.5.0-15-generic ERROR: Cannot create report: [Errno 17] File e…

VLAN创建及配置

V-- 虚拟 LAN ---局域网 ---地理覆盖范围较小的网络 MAN ---城域网 WAN ---广域网 VLAN ---虚拟局域网 --- 交换机和路由器协同工作后,将原先的一个广播域,逻辑上切分为多个 第一步:创建VLAN [Huawei]display vlan---查看VLAN信息 VID -- VLAN ID ----…

DNS域名解析与智能选路

要开始访问公网了!! 你在访问百度的时候,你也不知道百度的IP地址是啥,你只知道他的域名是baidu AD这台设备可以做入站的负载平衡,AD来选择你访问的时候是用联通网还是电信网,避免卡顿 pc并不会域名解析&…

[算法] 优先算法(二): 双指针算法(下)

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏:🍕 Collection与数据结构 (91平均质量分)https://blog.csdn.net/2301_80050796/category_12621348.html?spm1001.2014.3001.5482 🧀Java …

Python TCP编程简单实例

客户端:创建TCP链接时,主动发起连接的叫做客户端 服务端:接收客户端的连接 连接其他服务器 可以通过tcp连接其他服务器。 示例: import socket# 1.创建一个socket # 参数1:指定协议 AF_INET(ipv4&#…

ftp是什么,ftp能做什么,ftp有什么用 -----在Windows搭建ftp服务器

大家好,我是风屿,今天教大家如何从零开始搭建一台属于自己的ftp,本期教大家搭建Windows客户端的,后面是linux的 首先第一步要有一台联网的Windows电脑 1打开控制面板,找到程序,点击打开或关闭Windows功能…