为什么要梯度累积

文章目录

    • 梯度累积
      • 什么是梯度累积
      • 如何理解理解梯度累积
        • 梯度累积的工作原理
      • 梯度累积的数学原理
        • 梯度累积过程
        • 如何实现梯度累积
      • 梯度累积的可视化

梯度累积

什么是梯度累积

随着深度学习模型变得越来越复杂,模型的训练通常需要更多的计算资源,特别是在训练期间需要更多的内存。在训练深度学习模型时,在硬件资源有限的情况下,很难使用大批量数据进行有效学习。大批量数据通常可以带来更好的梯度估计,但同时也需要大量的内存。

梯度累积是一种巧妙的技术,它允许在不增加内存需求的情况下,有效地使用更大的批量数据来训练深度学习模型。

如何理解理解梯度累积

梯度累积本质上涉及将大批量划分为较小的子批量,并在这些子批量上累积计算出的梯度。这一过程模拟了使用较大批量训练的情况。

梯度累积的工作原理

以下是梯度累积过程的逐步分解:

  1. 分而治之:将你的硬件无法处理的大批量划分为更小的、可管理的子批量。
  2. 累积梯度:不是在处理每个子批量后更新模型参数,而是在几个子批量上累积梯度。
  3. 参数更新:在处理了预定义数量的子批量后,使用累积的梯度来更新模型参数。

这种方法使得模型能够利用大批量的稳定性和收敛性,而不必提高内存成本。

梯度累积的数学原理

在这里插入图片描述

梯度累积过程

在深度学习模型中,一个完整的前向和反向传播过程如下:

  • 前向传播:数据通过神经网络,层层处理后得到预测结果。

  • 损失计算:使用损失函数计算预测结果与实际值之间的差异。以平方误差损失函数为例:

    L ( θ ) = 1 2 ( h ( x k ) − y k ) 2 L(\theta) = \frac{1}{2} (h(x_k) - y_k)^2 L(θ)=21(h(xk)yk)2

    这里 L ( θ ) L(\theta) L(θ) 表示损失函数, θ \theta θ 代表模型参数, h ( x k ) h(x_k) h(xk) 是对输入 x k x_k xk 的预测输出, y k y_k yk 是对应的真实输出。

  • 反向传播:计算损失函数相对于模型参数的梯度(对上式求导):

    ∇ θ L ( θ ) = ( h ( x k ) − y k ) ⋅ ∇ θ h ( x k ) \nabla_\theta L(\theta) = (h(x_k) - y_k) \cdot \nabla_\theta h(x_k) θL(θ)=(h(xk)yk)θh(xk)

  • 梯度累积:在传统的训练过程中,每完成一个批次的数据处理后就会更新模型参数。而在梯度累积中,梯度不是立即用来更新参数,而是累加多个小批次的梯度:

    G = ∑ i = 1 n ∇ θ L i ( θ ) G = \sum_{i=1}^{n} \nabla_{\theta} L_i(\theta) G=i=1nθLi(θ)

    这里 G G G 是累积梯度, L i ( θ ) L_i(\theta) Li(θ) 是第 i i i 个batch的损失函数。

  • 参数更新:累积足够的梯度后,使用以下公式更新参数:

    θ = θ − η ⋅ G \theta = \theta - \eta \cdot G θ=θηG
    其中 l r lr lr 是学习率,用于控制更新的步长。

如何实现梯度累积

以下是在 PyTorch 中实现梯度累积的示例:

# 模型定义
model = ...
optimizer = ...

# 累积步骤数
accumulation_steps = 4

for epoch in range(num_epochs):
    optimizer.zero_grad()
    for i, (inputs, labels) in enumerate(dataloader):
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()

        # 只有在处理足够数量的子批量后才更新参数
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # 如果批量大小不是累积步数的倍数,确保在每个epoch结束时更新
    if (i + 1) % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

这个例子中,accumulation_steps 定义了在参数更新前需要累积的batch数量。

梯度累积的可视化

为了更好地理解梯度累积的影响,可视化可以非常有帮助。以下是一个例子,说明如何在神经网络中可视化梯度流,以监控梯度是如何被累积和应用的:

import matplotlib.pyplot as plt

# 绘制梯度流动的函数
def plot_grad_flow(named_parameters):
    ave_grads = []
    layers = []
    for n, p in named_parameters:
        if (p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
    plt.plot(ave_grads, alpha=0.3, color="b")
    plt.hlines(0, 0, len(ave_grads)+1, linewidth=1, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(xmin=0, xmax=len(ave_grads))
    plt.xlabel("层")
    plt.ylabel("平均梯度")
    plt.title("网络中的梯度流")
    plt.grid(True)
    plt.show()

# 在训练过程中或训练后调用此函数以可视化梯度流
plot_grad_flow(model.named_parameters())

参考资料:

  1. Gradient Accumulation Algorithm

  2. Performing gradient accumulation with 🤗 Accelerate

  3. 梯度累加(Gradient Accumulation)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/604555.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深度学习笔记_10YOLOv8系列自定义数据集实验

1、mydaya.yaml 配置方法 # 这里分别指向你训练、验证、测试的文件地址,只需要指向图片的文件夹即可。但是要注意图片和labels名称要对应 # 训练集、测试集、验证机文件路径,可以是分类好的TXT文件,也可以直接是图片文件夹路径 train: # t…

Litedram仿真验证(四):AXI接口完成板级DDR3读写测试(FPGA-Artix7)

目录 日常唠嗑一、仿真中遗留的问题二、板级测试三、工程获取及交流 日常唠嗑 接上一篇Litedram仿真验证(三):AXI接口完成仿真(FPGA/Modelsim)之后,本篇对仿真后的工程进行板级验证。 本次板级验证用到的开…

学成在线 - 第3章任务补偿机制实现 + 分块文件清理

7.9 额外实现 7.9.1 任务补偿机制 问题:如果有线程抢占了某个视频的处理任务,如果线程处理过程中挂掉了,该视频的状态将会一直是处理中,其它线程将无法处理,这个问题需要用补偿机制。 单独启动一个任务找到待处理任…

Layer1 公链竞争破局者:Sui 生态的全面创新之路

随着 Sui 生态逐渐在全球范围内树立起声望,并通过与 Revolut 等前沿金融科技平台合作,推广区块链教育与应用,Sui 生态的未来发展方向已成为业界瞩目的焦点。如今,Sui 的总锁定价值已攀升至 5.93 亿美元,充分展示了其在…

分布式架构的演技进过程

最近看了一篇文章,觉得讲的挺不错,就借机给大家分享一下。 早期应用:早期的应用比较简单,访问人数有限,大部分的开发单机就能完成。 分离模型:在业务发展后,用户数量逐步上升,服务器的性能出现瓶颈;就需要将应用和数据分开存储,避免相互抢占资源。 缓存模式:随着系…

历代著名画家作品赏析-东晋顾恺之

中国历史朝代顺序为:夏朝、商朝、西周、东周、秦朝、西楚、西汉、新朝、玄汉、东汉、三国、曹魏、蜀汉、孙吴、西晋、东晋、十六国、南朝、刘宋、南齐、南梁、南陈、北朝、北魏、东魏、北齐、西魏、北周、隋,唐宋元明清,近代。 一、东晋著名…

现身说法暑期三下乡社会实践团一个好的投稿方法胜似千军万马

作为一名在校大学生,去年夏天我有幸参与了学院组织的暑期大学生三下乡社会实践活动,这段经历不仅让我深入基层,体验了不一样的生活,更是在新闻投稿的实践中,经历了一次从传统到智能的跨越。回忆起那段时光,从最初的邮箱投稿困境,到后来智慧软文发布系统的高效运用,每一步都刻印…

顺序栈的操作

归纳编程学习的感悟, 记录奋斗路上的点滴, 希望能帮到一样刻苦的你! 如有不足欢迎指正! 共同学习交流! 🌎欢迎各位→点赞 👍 收藏⭐ 留言​📝既然选择了远方,当不负青春…

什么是DDoS攻击?DDoS攻击的原理是什么?

一、DDoS攻击概念 DDoS攻击又叫“分布式拒绝服务”(Distributed DenialofService)攻击,它是一种通过控制大量计算机、物联网终端或网络僵尸(Zombie)来向目标网站发送大量请求,从而耗尽其服务器资源,导致正常用户无法访…

WEB基础--JDBC操作数据库

使用JDBC操作数据库 使用JDBC查询数据 五部曲:建立驱动,建立连接,获取SQL语句,执行SQL语句,释放资源 建立驱动 //1.加载驱动Class.forName("com.mysql.cj.jdbc.Driver"); 建立连接 //2.连接数据库 Stri…

【3dmax笔记】026:挤出和壳修改器的使用

文章目录 一、修改器二、挤出三、壳 一、修改器 3ds Max中的修改器是一种强大的工具,用于创建和修改复杂的几何形状。这些修改器可以改变对象的形状、大小、方向和位置,以生成所需的效果。以下是一些常见的3ds Max修改器及其功能: 挤出修改…

Google Earth Engine谷歌地球引擎计算遥感影像在每个8天间隔内的多年平均值

本文介绍在谷歌地球引擎(Google Earth Engine,GEE)中,求取多年时间中,遥感影像在每1个8天时间间隔内的多年平均值的方法。 本文是谷歌地球引擎(Google Earth Engine,GEE)系列教学文章…

神经网络案例实战

🔎我们通过一个案例详细使用PyTorch实战 ,案例背景:你创办了一家手机公司,不知道如何估算手机产品的价格。为了解决这个问题,收集了多家公司的手机销售数据:这些数据维度可以包括RAM、存储容量、屏幕尺寸、…

JavaScript数字分隔符

● 如果现在我们用一个很大的数字,例如2300000000,这样真的不便于我们进行阅读,我们希望用千位分隔符来隔开它,例如230,000,000; ● 下面我们使用_当作分隔符来尝试一下 const diameter 287_266_000_000; console.log(diameter)…

论文分享[cvpr2018]Non-local Neural Networks非局部神经网络

论文 https://arxiv.org/abs/1711.07971 代码https://github.com/facebookresearch/video-nonlocal-net 非局部神经网络 motivation:受计算机视觉中经典的非局部均值方法[4]的启发,非局部操作将位置的响应计算为所有位置的特征的加权和。 非局部均值方法 NLM&#…

Python实现Chiikawa

写在前面 哈?呀哈!本期小编给大家素描版Chiikawa! 主人公当然是我们可爱的吉伊、小八以及乌萨奇啦~ Chiikawa小小可爱 《Chiikawa》是一部来自日本的超萌治愈系漫画与动画作品,由作者秋田祯信创作。"Chiikawa"这个名字…

【Kolmogorov-Arnold网络 替代多层感知机MLPs】KAN: Kolmogorov-Arnold Networks

KAN: Kolmogorov-Arnold Networks 论文地址 代码地址 知乎上的讨论(看一下评论区更正) Abstract Inspired by the Kolmogorov-Arnold representation theorem, we propose Kolmogorov-Arnold Networks (KANs) as promising alternatives to Multi-Layer…

支持LLM的Markdown笔记;ComfyUI-HiDiffusion图片生成和对图像进行高质量编辑

✨ 1: ComfyUI-HiDiffusion ComfyUI-HiDiffusion是一个为HiDiffusion技术使用而定制的节点。HiDiffusion技术是专门用于在计算机视觉和图像处理中生成和改进图片质量的先进算法。该技术通常应用于图像的超分辨率、去噪、风格转换等方面。 ComfyUI-HiDiffusion的主要特点包含提…

Julia 语言环境安装与使用

1、Julia 语言环境安装 安装教程:https://www.runoob.com/julia/julia-environment.html Julia 安装包下载地址为:https://julialang.org/downloads/。 安装步骤:注意(勾选 Add Julia To PATH 自动将 Julia 添加到环境变量&…

(五)JSP教程——response对象

response对象主要用于动态响应客户端请求(request),然后将JSP处理后的结果返回给客户端浏览器。JSP容器根据客户端的请求建立一个默认的response对象,然后使用response对象动态地创建Web页面、改变HTTP标头、返回服务器端地状态码…