对于llama3.1 8B模型,FP32和BF16混合精度训练,用的是AdamW优化器,模型训练时占用显存分析

目录

为什么先不考虑激活值的显存占用

1. 模型参数

含义

计算

2. 梯度参数

含义

3. 优化器参数

含义

4. 较固定总显存占用

计算

详细解释

5. 激活值计算:

计算公式

插入数值

计算步骤

结论


显存主要被用在四个模块上:

  • 模型权重本身

  • 梯度

  • 优化器

  • 激活值

其中,

  • 静态:优化器状态、模型参数

  • 动态:激活值、梯度值

也就是说,我们其实没法特别准确的计算出我们实际运行时候的显存大小,如果在面试的时候,就可以忽略掉激活值的计算,梯度当做静态计算就好。

为什么先不考虑激活值的显存占用

在计算显存占用时,我们通常会区分模型参数、梯度参数和优化器状态的显存占用,以及激活值的显存占用。以下是具体原因:

  1. 模型参数、梯度参数和优化器状态

    • 这些部分的显存占用是相对固定的,取决于模型的大小和优化器的选择。
    • 在混合精度训练中,我们可以明确计算这些部分的显存占用。
  2. 激活值

    • 激活值的显存占用与批量大小(batch size)和序列长度(sequence length)密切相关,且在不同的训练任务和配置下变化较大。
    • 激活值的显存占用往往是动态的,取决于具体的训练过程和数据流动。

因此,在讨论显存占用时,我们通常会先计算固定部分(模型参数、梯度参数和优化器状态)的显存占用,而不考虑激活值的显存占用。这是因为激活值的显存占用是高度可变的,需要根据具体的训练配置进行动态调整。具体变化在最后简单介绍一下

在模型训练中,显存占用主要包括模型参数、梯度参数和优化器状态。对于LLaMA 3.1 8B模型,使用混合精度训练(FP32和BF16)和AdamW优化器时,显存占用的计算如下:

1. 模型参数

含义

模型参数是神经网络的权重和偏置等参数。对于8B参数的模型:

  • BF16(Brain Floating Point 16-bit):每个参数占用16位(2字节)
  • FP32(Floating Point 32-bit):每个参数占用32位(4字节)
计算

假设模型的所有参数都存储为BF16和FP32两种格式:

  • BF16:8B参数 * 2字节 = 16GB
  • FP32:8B参数 * 4字节 = 32GB

总的模型参数显存占用为: 16𝐺𝐵+32𝐺𝐵=48𝐺𝐵

2. 梯度参数

含义

梯度参数是用于反向传播更新模型参数的梯度值。在混合精度训练中,梯度通常以BF16格式存储:

  • BF16:8B参数 * 2字节 = 16GB

总的梯度参数显存占用为: 16𝐺𝐵

3. 优化器参数

含义

AdamW优化器需要存储额外的状态参数,包括一阶动量(momentum)和二阶动量(variance)。这些参数通常以FP32格式存储:

  • 一阶动量(FP32):8B参数 * 4字节 = 32GB
  • 二阶动量(FP32):8B参数 * 4字节 = 32GB

总的优化器参数显存占用为: 32𝐺𝐵+32𝐺𝐵=64𝐺𝐵

4. 较固定总显存占用

计算

不考虑激活值的情况下,总显存占用为: 48𝐺𝐵(模型参数)+16𝐺𝐵(梯度参数)+64𝐺𝐵(优化器参数)=128𝐺𝐵

详细解释
  1. 模型参数(48GB)

    • BF16:模型的所有参数以16位格式存储,占用16GB显存。
    • FP32:模型的所有参数以32位格式存储,占用32GB显存。
  2. 梯度参数(16GB)

    • BF16:用于反向传播的梯度参数以16位格式存储,占用16GB显存。
  3. 优化器参数(64GB)

    • 一阶动量(32GB):AdamW优化器的一阶动量参数以32位格式存储,占用32GB显存。
    • 二阶动量(32GB):AdamW优化器的二阶动量参数以32位格式存储,占用32GB显存。

总结来说,在LLaMA 3.1 8B模型的混合精度训练中,模型参数、梯度参数和优化器参数的显存占用分别为48GB、16GB和64GB,总计128GB,不考虑激活值的情况下。

5. 激活值计算:

要计算LLaMA 3.1 8B模型的激活值显存占用,我们需要知道以下信息:

  1. 批量大小(Batch Size, B)
  2. 序列长度(Sequence Length, S)
  3. 每层的输出维度(Hidden Size, d)
  4. 模型的层数(Number of Layers, L)
  5. 每个激活值元素的大小(Element Size, BF16为2字节,FP32为4字节)

假设以下典型配置(请注意,实际配置可能有所不同):

  • 批量大小 𝐵=32
  • 序列长度 𝑆=512
  • 每层的输出维度 𝑑=4096
  • 模型的层数 𝐿=80(假设LLaMA 3.1 8B有80层)
  • 使用BF16格式(每个元素2字节)
计算公式

激活值的总显存占用可以表示为: 显存占用=𝐵×𝑆×∑𝑖=1𝐿(𝑑𝑖×size_of_element)

对于具有相同输出维度 𝑑 的所有层,这个公式简化为: 显存占用=𝐵×𝑆×𝐿×𝑑×size_of_element

插入数值
  • 批量大小 𝐵=32
  • 序列长度 𝑆=512
  • 层数 𝐿=80
  • 每层输出维度 𝑑=4096
  • 每个元素大小(BF16) size_of_element=2字节

计算显存占用: 显存占用=32×512×80×4096×2字节

计算步骤
  1. 计算批量大小和序列长度的乘积: 32×512=16384

  2. 计算层数和输出维度的乘积: 80×4096=327680

  3. 将上述结果相乘并乘以元素大小: 16384×327680×2=10737418240字节

  4. 转换为GB: 10737418240字节=10GB

结论

在假设批量大小为32,序列长度为512,每层输出维度为4096,使用BF16格式的情况下,LLaMA 3.1 8B模型的激活值显存占用大约为10GB。

请注意,这只是一个估算,实际显存占用可能会因为其他因素(如模型具体架构、额外的缓存和中间结果存储等)有所不同。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/871875.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言基础(十一)

1、指针: C语言中的指针是一种非常重要的数据类型,可以直接访问和操作内存地址。指针存储变量的内存地址,而不是变量的值本身。通过使用指针,可以灵活地控制数据的存储和访问,实现复杂的数据结构如链表、树。 定义指…

Redis (day 3)

一、通过jedis连接数据库 1.首先导入依赖 <!-- https://mvnrepository.com/artifact/redis.clients/jedis --><dependency><groupId>redis.clients</groupId><artifactId>jedis</artifactId><version>5.1.0</version></de…

Mac系统安装Homebrew【已成功】

1、正常安装失败原因 1.1命令行安装失败 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)" 原因 没挂&#x1fa9c;&#xff0c;不过我挂了梯子安装很多次也还是失败&#xff0c;所以可能是网站原因 1.2、网…

MyBatis进阶-1-面向接口编程

通过 MyBatis 底层自动创建接口实现类&#xff0c;我们可以直接对接口的方法进行编程 若简单的 sql 语句可以使用注解的方式进行&#xff0c;复杂的查询建议使用 xml 文件编写语句 注解使用时直接在接口的方法上加上对应语句的注解即可&#xff0c;而使用 xml 需要在文件中的…

ES6解构赋值详解;全面掌握:JavaScript解构赋值的终极指南

目录 全面掌握&#xff1a;JavaScript解构赋值的终极指南 一、数组解构赋值 1、基本用法 2、跳过元素 3、剩余元素 4、默认值 二、对象解构赋值 1、基本用法 2、变量重命名 3、默认值 4、嵌套解构 三、复杂的嵌套结构解构 四、函数参数解构赋值 1、对象解构作为函…

Jenkins汉化配置详解

Window安装构建神器Jenkins Window安装构建神器Jenkins详细教程-CSDN博客DevOps&#xff0c;CI&#xff0c;CD&#xff0c;自动化简单介绍选择其他需要和Jenkins一起安装的服务&#xff0c;点击Next。https://blog.csdn.net/qq_37237487/article/details/141299623 登录进入J…

【机器学习】CNN的基本架构模块

&#x1f308;个人主页: 鑫宝Code &#x1f525;热门专栏: 闲话杂谈&#xff5c; 炫酷HTML | JavaScript基础 ​&#x1f4ab;个人格言: "如无必要&#xff0c;勿增实体" 文章目录 CNN的基本架构模块1. 引言2. 卷积层2.1 基本原理2.2 卷积层的特性2.3 卷积层的超…

SQL,解析 json

Google BigQuery数据库的data表存储了若干多层的Json串&#xff0c;其中一条形如&#xff1a; [{"active":true,"key":"key1","values":[{"active":true,"value":"value1"}]},{"active":tru…

Java巅峰之路---进阶篇---面向对象(二)

Java巅峰之路---进阶篇---面向对象&#xff08;二&#xff09; 多态介绍多态调用成员的特点多态的优势、弊端以及解决方案综合练习 包和final包的介绍使用其他类的规则&#xff08;导包&#xff09;final关键字final的用途常量 权限修饰符和代码块权限修饰符的介绍四个权限修饰…

Halo个人博客Docker部署结合内网穿透为本地站点配置公网地址远程访问

文章目录 前言1. Docker部署Halo1.1 检查Docker版本如果未安装Docker可参考已安装Docker步骤&#xff1a;1.2 在Docker中部署Halo 2. Linux安装Cpolar2.1 打开服务器防火墙2.2 安装cpolar内网穿透 3. 配置Halo个人博客公网地址4. 固定Halo公网地址 前言 本文主要介绍如何在Cen…

C#学习第二节课 ,伤害计算

伤害计算 我一直好奇游戏的伤害计算是怎么计算并输出的,这第二节课利用学过的初级语法,Console.WriteLine,Console.ReadLine(),以及基础变量,int,string 和if 判断 组合,来实现打印一下伤害计算吧! 老规矩 先上结果图 代码区域 namespace hello01 {internal class Program …

望繁信科技荣膺上海市浦东新区博士后创新实践基地称号

近日&#xff0c;上海望繁信科技有限公司&#xff08;简称“望繁信科技”&#xff09;凭借在大数据流程智能领域的卓越表现&#xff0c;成功入选上海市浦东新区博士后创新实践基地。这一荣誉不仅是对望繁信科技创新能力和技术实力的高度认可&#xff0c;也标志着公司在推动产学…

EasyCVR视频汇聚平台构建远程安防监控:5大亮点解析,助力安防无死角

随着科技的飞速发展&#xff0c;远程安防监控系统已经成为现代社会中不可或缺的一部分&#xff0c;无论是在小区、公共场所还是工业领域&#xff0c;安防监控都发挥着至关重要的作用。而EasyCVR作为一款功能强大的视频监控综合管理平台&#xff0c;其在构建远程安防监控系统方面…

Qt 学习第六天:页面布局

如何设计页面&#xff1f; 有个类似沙盒模式的玩法&#xff0c;Qt Widget Designer可以更好的帮助我们设计页面 点击.ui文件进入 右上方可以看到四种常见的布局&#xff1a; 四种布局 &#xff08;一&#xff09;水平布局horizontalLayout&#xff1a;QHBoxLayout H 是 hori…

算法之工程化内容(3)—— Docker常用命令

目录 1. 配置docker镜像加速 2. 创建镜像docker-name 3. 查看正在运行的镜像 4. 拉取镜像 5. 运行镜像 6. 停止/启动指定 id 的容器 7. 删除指定 id 的镜像/容器 8. docker发布和部署 (推荐教程&#xff1a;&#x1f69a; 发布和部署 - Docker 快速入门) 1. 配置docke…

【蓝桥杯集训100题】scratch时间计算 蓝桥杯scratch比赛专项预测编程题 集训模拟练习题第26题

目录 scratch时间计算 一、题目要求 编程实现 二、案例分析 1、角色分析 2、背景分析 3、前期准备 三、解题思路 1、思路分析 2、详细过程 四、程序编写 五、考点分析 六、推荐资料 1、入门基础 2、蓝桥杯比赛 3、考级资料 4、视频课程 5、python资料 scratc…

【网络】UDP和TCP之间的差别和回显服务器

文章目录 UDP 和 TCP 之间的差别有连接/无连接可靠传输/不可靠传输面向字节流/面向数据报全双工/半双工 UDP/TCP API 的使用UDP APIDatagramSocket构造方法方法 DatagramPacket构造方法方法 回显服务器&#xff08;Echo Server&#xff09;1. 接收请求2. 根据请求计算响应3. 将…

AdMob聚合平台

Google Admob产品介绍 Google给开发者提供了3款用于流量变现的产品&#xff0c;分别是AdMob&#xff0c;通过应用内广告帮助App开发者变现&#xff1b;AdSense&#xff0c;通过网站广告帮助所有者变现&#xff1b;Google Ads Manager&#xff0c;通过全面管理和优化广告资源&a…

[ICLR-24] LRM: Large Reconstruction Model for Single Image to 3D

[pdf | proj | code] 本文首次提出大型重建模型&#xff08;Large Reconstruction Model, LRM&#xff09;&#xff0c;实现5s内对单图物体的3D重建。在128张A100&#xff08;40G&#xff09;上训练30 epochs。 LRM包含三个部分&#xff0c;具体框架如下&#xff1a; 图片编码…

亚马逊测评号生存法则:如何抵御亚马逊封号风波?

距离黑五购物狂欢节还剩99天&#xff0c;相信各位商家都在紧锣密鼓的筹备相关事宜&#xff0c;然而&#xff0c;亚马逊的封号风波再次席卷而来。那如何在这场风暴中让亚马逊矩阵测评号安全航行亦或是脱颖而出呢&#xff1f;本文将给你一个答案&#xff0c;并帮助你的亚马逊店铺…