混合精度训练说明

什么是混合精度训练?混合精度训练有什么用?
这里总结一下。
本文总结自kapathy的build gpt2

通常在训练过程中,model里面的数据默认都是torch.float32类型,
也就是用32bit的float型数据表示变量。
比如特征提取中提取的特征,描述子,网络的参数等,显示是torch.float32.

32bit的float带来的影响是比如拿它和int8比,占用内存会比较多,计算量会比较大。
所以训练时也就会比较慢,我们可以考虑丢弃一些精度来达到性能的提升。

看一下GPU支持的类型,比如nv a100吧,

在这里插入图片描述
可以看到实际上GPU是支持到float64的,但实际上模型训练时不需要那么高的精度,也会导致训练速度变慢,所以默认一般是float32.

精度下调的性能提升

float32最大可达到19.5TFLOPS,意味着可达到每秒19.5 trillion次float操作。
那么如果我们降低一些精度,比如降到TF32(见上图),每秒就可以达到156TFLOPS的性能,差不多达到了8倍性能提升。
如果进一步降低精度,到BFLOAT16, 就可以达到16倍性能提升。
注:右边*的性能指的是用稀疏化。

int8一般用于推理而不是训练。因为int8是均匀的空间,而训练时activation, weight都是正态分布的, 训练时要用float。

另一方面, 精度降低之后占的内存少,易于搬运,这涉及到memory bandwidth和模型的memory. 可以参考图中的GPU memory bandwidth.
解释一下memory bandwidth, 一般情况下data需要搬到GPU上然后运算,运算完再搬回去(涉及到GPU内存),但是受限于这个bandwidth, 明明有多余算力,但是data还没搬进来,就需要等,导致利用度不够高。然而如果你降低了精度,数据占的内存就会变小,一次就可以搬运更多,那么每次参与计算的就会更多,在这个搬运上也会提升性能。

小结一下,适当降低精度会让计算量减少从而提升性能,另一方面,会让数据占的内存减少,每次搬运的数据更多,从而在有限的memory bandwidth上达到性能提升。

图上有个名词叫tensor core, 现在介绍一下什么是tensor core.
它是a100中的一个instruction, 它做的事情是4x4的矩阵乘法。
在这里插入图片描述

矩阵的乘法会broke up成这些4x4的矩阵乘。
比如在transformer中很多linear layer就需要矩阵乘法,特别是最后一层的classify layer是一个大矩阵乘。
矩阵乘就通过tensor core来加速。

不同精度的数据

TF32
在这里插入图片描述
通过这个图可以看到TF32和FP32表示的范围是一样的,最左边的sign是符号位,中间8位range表示数字可表达的范围,它们是一样的,区别就是TF32舍弃了一些小数点位。整个只有19bit, 而不是32bit。
这些是在硬件上完成的,pytorch代码上是不可见的。

为什么叫混合精度呢,是因为你的input还是fp32, output还是fp32, 但是在内部计算上,后面的bit被舍弃了,为提升性能降低了一些精度。所以结果会类似是一个近似结果,但你会几乎看不出差别。
虽然说明书上写的用TF32会有8x性能提升,但实际上那只是矩阵乘的时候用了TF32, 其他部分仍然是FP32, 另外还受限于memory bandwidth, 所以实际上大概率是达不到的。

但是注意一点,这是a100支持的TF32,有的GPU可能不支持。

说了这么多,到底怎么用TF32训练呢。
只需要一行code. 用到了torch.set_float32_matmul_precision, 有兴趣的可以查看官方文档。它有一些参数,“highest”, ”high“,”medium“等, 其中"high"就是TF32. 默认是"highest", 也就是float32.

torch.set_float32_matmul_precision('high')

#model = GPT(..)
#model.to(device)
#training process

BFLOAT16
每个float只有16bit.
它和FP16有什么区别呢,还是看上面的图,它的range和float32是一样的,比FP16表示的范围要宽,只不过精度进一步被cut.
另外当你用FP16训练时,由于它表示的范围比FP32要小,所以你还要做gradient scaling操作。用BF16不需要做gradient scaling.

具体怎么用BF16呢,你可以参考torch.autocast, 但你不需要考虑gradient scaling的部分。
官方文档上有说,autocasting时,不要在model或input处call half() or bfloat16()。
你只能在forward和计算loss处用BF16.
具体如下, 只需要加一句torch.autocast:

model = GPT(..)
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(iterations):
     x, y = train_loader.next_batch()
     x, y = x.to(device), y.to(device)
     optimizer.zero_grad()
     #用BF16
     with torch.autocast(device_type=device, dtype=torch.bfloat16):
         logits, loss = model(x, y)
     loss.backward()
     optimizer.step()

但是你看transformer里面的embedding table的weight, 仍然是float32, 具体哪些模块能cast到BF16,你可以参考官方文档里面的CUDA Ops that can autocast to float16. 需要矩阵乘的时候会cast, 很多操作仍然会保持float32不变,比如norm。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/940871.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

draw.io 导出svg图片插入word后模糊(不清晰 )的解决办法

通常我们将图片从draw.io导出为svg格式后插入word, 会发现字体不清晰,特别是使用宋体时,折腾了半天,得到如下办法: 方法1: 在draw.io中导出pdf文件,使用 PDF转SVG转换器 - SVGConverter 将其转换为svg, 完美呈现。 …

ARM学习(38)多进程多线程之间的通信方式

ARM学习(38)ARM学习(38)多进程多线程之间的通信方式 一、问题背景 笔者在调试模拟器的时候,碰到进程间通信的问题,一个进程在等另外一个进程ready的时候,迟迟等不到,然后通过调试发现,另外一个进程变量已经变化了,但是当前进程变量没变化,需要了解进程间通信的方式…

【动手学运动规划】 5.2 数值优化基础:梯度下降法,牛顿法

朕四季常服, 不过八套. — 大明王朝1566 道长 🏰代码及环境配置:请参考 环境配置和代码运行! 上一节我们介绍了数值优化的基本概念, 让大家对最优化问题有了基本的理解. 那么对于一个具体的问题, 我们应该如何求解呢? 这一节我们将介绍几个基本的求解…

24-12-22 pytorch学习 基础知识 帝乡明日到,犹自梦渔樵。

文章目录 pytorch学习 基础知识pytorch学习(1) Tensors1.1 初始化Tensor1.2 Tensor 的属性1.3 Tensors 的操作1.4 与 NumPy 的桥梁1.4.1 Tensor 到 NumPy 数组1.4.2 NumPy 数组 到 Tensor pytorch学习(2) 数据集和数据加载器2.1 加载一个数据集2.2 迭代和可视化数据集2.3 为你的…

Linux网络功能 - 服务和客户端程序CS架构和简单web服务示例

By: fulinux E-mail: fulinux@sina.com Blog: https://blog.csdn.net/fulinus 喜欢的盆友欢迎点赞和订阅! 你的喜欢就是我写作的动力! 目录 概述准备工作扫描服务端有那些开放端口创建客户端-服务器设置启动服务器和客户端进程双向发送数据保持服务器进程处于活动状态设置最小…

M3D: 基于多模态大模型的新型3D医学影像分析框架,将3D医学图像分析从“看图片“提升到“理解空间“的层次,支持检索、报告生成、问答、定位和分割等8类任务

M3D: 基于多模态大模型的新型3D医学影像分析框架,将3D医学图像分析从“看图片“提升到“理解空间“的层次,支持检索、报告生成、问答、定位和分割等8类任务 论文大纲理解1. 确认目标2. 分析过程(目标-手段分析)核心问题拆解 3. 实…

【102. 二叉树的层序遍历 中等】

题目: 给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。 示例 1: 输入:root [3,9,20,null,null,15,7] 输出:[[3],[9,20],[15,7]] 示例…

第四届电气工程与控制科学

重要信息 官网:www.ic2ecs.com 时间:2024年12月27-29日 简介 第四届电气工程与控制科学定于2024年12月27-29日在中国南京召开。主要围绕“电气工程“、”控制科学“、”机械工程“、”自动化”等主题展开,旨在为从电…

监控易在汽车制造行业信息化运维中的应用案例

引言 随着汽车制造行业的数字化转型不断深入,信息化类IT软硬件设备的运行状态监控、故障告警、报表报告以及网络运行状态监控等成为了企业运维管理的关键环节。监控易作为一款全面、高效的信息化运维管理工具,在汽车制造行业中发挥着重要作用。本文将结合…

大模型+安全实践之春天何时到来?

引子:距《在大模型实践旅途中摸了下上帝的脚指头》一文发布近一年,2024年笔者继续全情投入在大模型+安全上,深度参与了一些应用实践,包括安全大模型首次大规模应用在国家级攻防演习、部分项目的POC直到项目落地,也推动了一些场景安全大模型应用从0到3的孵化上市。这一年也…

大小端存储的问题

请你用C语言写一个简单的程序&#xff0c;判断你使用的主机是大端存储还是小端存储 #include <stdio.h> int main(){int x 0x11223344;char *p (char *)&x;if(0x44 *p){printf("小端\n");}else if(0x11 *p){printf("大端\n");}return 0; }

山景BP1048增加AT指令,实现单片机串口控制播放音乐(一)

1、设计目的 山景提供的SDK是蓝牙音箱demo&#xff0c;用户使用ADC按键或者IR遥控器&#xff0c;进行人机交互。然而现实很多场景&#xff0c;需要和单片机通信&#xff0c;不管是ADC按键或者IR接口都不适合和单片机通信。这里设计个AT指令用来和BP1048通信。AT指令如下图所示…

EMC VMAX/DMX 健康检查方法

近期连续遇到2个由于对VMAX存储系统没有做及时的健康检查&#xff0c;出现SPS电池故障没有及时处理&#xff0c;然后同一pair就是同一对的另外一个SPS电池再次出现故障&#xff0c;然后存储系统保护性宕机vault&#xff0c;然后业务系统挂掉的案例。 开始之前&#xff0c;先纠…

51c大模型~合集94

我自己的原文哦~ https://blog.51cto.com/whaosoft/12897659 #D(R,O) Grasp 重塑跨智能体灵巧手抓取&#xff0c;NUS邵林团队提出全新交互式表征&#xff0c;斩获CoRL Workshop最佳机器人论文奖 本文的作者均来自新加坡国立大学 LinS Lab。本文的共同第一作者为上海交通大…

移动魔百盒中的 OpenWrt作为旁路由 安装Tailscale并配置子网路由实现在外面通过家里的局域网ip访问内网设备

移动魔百盒中的 OpenWrt作为旁路由 安装Tailscale并配置子网路由实现在外面通过家里的局域网ip访问内网设备 一、前提条件 确保路由器硬件支持&#xff1a; OpenWrt 路由器需要足够的存储空间和 CPU 性能来运行 Tailscale。确保设备架构支持 Tailscale 二进制文件&#xff0c;例…

Webpack学习笔记(4)

1.缓存 可以通过命中缓存降低网络流量&#xff0c;是网站加站速度更快。 然而在部署新版本时&#xff0c;不更改资源的文件名&#xff0c;浏览器可能认为你没有更新&#xff0c;所以会使用缓存版本。 由于缓存存在&#xff0c;获取新的代码成为问题。 接下来将配置webpack使…

java抽奖系统(八)

9. 抽奖模块 9.1 抽奖设计 抽奖过程是抽奖系统中最重要的核⼼环节&#xff0c;它需要确保公平、透明且⾼效。以下是详细的抽奖过程设计&#xff1a; 对于前端来说&#xff0c;负责控制抽奖的流程&#xff0c;确定中奖的人员 对于后端来说&#xff1a; 接口1&#xff1a;查询完…

VulnHub靶场渗透之:Gigachad

环境搭建 VulnHub是一个丰富的实战靶场集合&#xff0c;里面有许多有趣的实战靶机。 本次靶机介绍&#xff1a;http://www.vulnhub.com/entry/gigachad-1,657/ 下载靶机ova文件&#xff0c;导入虚拟机&#xff0c;启动环境&#xff0c;便可以开始进行靶机实战。 虚拟机无法分…

解决Apache/2.4.39 (Win64) PHP/7.2.18 Server at localhost Port 80问题

配置一下apache里面的配置文件&#xff1a;httpd.conf 和 httpd.vhosts.conf httpd.conf httpd-vhosts.conf 重启服务 展示&#xff1a; 浏览器中中文乱码问题&#xff1a;

.NET重点

B/S C/S什么语言 B/S&#xff1a; 浏览器端&#xff1a;JavaScript&#xff0c;HTML&#xff0c;CSS 服务器端&#xff1a;ASP&#xff08;.NET&#xff09;PHP/JSP 优势&#xff1a;维护方便&#xff0c;易于升级和扩展 劣势&#xff1a;服务器负担沉重 C/S java/.NET/…