【AI知识点】批归一化(Batch Normalization)

批归一化(Batch Normalization,BN) 是一种用于加速神经网络训练并提高模型稳定性的方法,通过在每一层对神经网络中的激活值进行标准化,使得每一层的输入保持在一个稳定的分布中,从而缓解梯度消失和梯度爆炸的问题,并加快训练过程。


1. 为什么需要批归一化?

在神经网络训练过程中,尤其是深度神经网络,层与层之间的参数不断更新,这导致网络中的每一层的输入分布会发生变化。这种现象被称为内部协变量偏移(Internal Covariate Shift)。它会导致训练变得更加困难,因为每一层的输入分布不稳定,会使得模型需要不断适应新的数据分布,从而影响模型的训练速度。

为了解决这个问题,批归一化被引入。批归一化通过将每一层的激活值标准化为均值为 0、方差为 1 的分布,使得每一层的输入数据保持相对稳定的分布,从而使得网络可以更快地学习和收敛。


2. 批归一化的基本步骤

批归一化的过程主要包括以下几个步骤:

  1. 计算批次的均值
    对于每一层的输入(例如激活值) x x x,计算其在当前 mini-batch 中的均值:
    μ B = 1 m ∑ i = 1 m x i \mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i μB=m1i=1mxi
    其中 m m m 是 mini-batch 的样本数量, x i x_i xi 是第 i i i 个样本的输入。

  2. 计算批次的方差
    接下来计算 mini-batch 中输入的方差:
    σ B 2 = 1 m ∑ i = 1 m ( x i − μ B ) 2 \sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2 σB2=m1i=1m(xiμB)2

  3. 对输入进行标准化
    使用批次均值 μ B \mu_B μB 和方差 σ B 2 \sigma_B^2 σB2 对每个输入 x i x_i xi 进行标准化处理:
    x ^ i = x i − μ B σ B 2 + ϵ \hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} x^i=σB2+ϵ xiμB
    其中, ϵ \epsilon ϵ 是一个很小的常数,用来避免除以零的情况。

  4. 尺度变换和偏移
    为了保持网络的表达能力,批归一化还会引入可学习的参数 γ \gamma γ β \beta β,用于对标准化后的结果进行线性变换:
    y i = γ x ^ i + β y_i = \gamma \hat{x}_i + \beta yi=γx^i+β
    其中, γ \gamma γ 是缩放参数, β \beta β 是偏移参数。这一步保证了即使数据经过归一化后,网络仍然能够恢复原始的表示能力。


3. 批归一化的整体公式

结合上面几步,批归一化的整体公式可以表示为:
y i = γ ⋅ x i − μ B σ B 2 + ϵ + β y_i = \gamma \cdot \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta yi=γσB2+ϵ xiμB+β

在这个公式中, x i x_i xi 是神经网络层中的输入, μ B \mu_B μB σ B 2 \sigma_B^2 σB2 是当前 mini-batch 的均值和方差, γ \gamma γ β \beta β 是可学习的参数,而 ϵ \epsilon ϵ 是一个防止除零的小常数。


4. 批归一化的位置

下图展示了批归一化在神经网络中的位置

在这里插入图片描述
图片来源:https://medium.com/@abheerchrome/batch-normalization-explained-1e78f7eb1e8a

下图对比了批归一化网络(Batch Normalized Network)标准神经网络(Standard Network) 在前向传播过程中的区别。


图片来源:https://gradientscience.org/batchnorm/

  • x x x 是输入特征, W W W 是权重矩阵, y = W ⋅ x y = W \cdot x y=Wx 是通过神经网络层(隐藏层)计算得到的线性输出(也就是未经过激活函数的输出)。
  • 在标准网络中,这个 y y y 直接用于损失计算 L ^ \hat{\mathcal{L}} L^ (或者进入激活函数),但在批归一化网络中,会先对 y y y 进行标准化处理。批归一化后的值 z z z 再用于损失计算 L ^ \hat{\mathcal{L}} L^(或者进入激活函数)。

5. 训练与测试时的区别

批归一化的行为在训练阶段测试阶段是不同的:

  • 训练阶段:每个 mini-batch 内的数据被标准化,使用 mini-batch 的均值和方差进行归一化。
  • 测试阶段:由于测试时无法使用 mini-batch 的均值和方差(因为测试是单独进行的),因此在训练过程中,通常会维护一个全局的均值和方差(通过对所有 mini-batch 的均值和方差进行指数加权平均计算得出)。在测试时,使用这个全局均值和方差进行归一化,而不是使用单个 mini-batch 的均值和方差。

6. 批归一化的优点

  • 加快收敛速度:批归一化能够稳定输入分布,从而加快模型的收敛速度。在实际应用中,批归一化常常使得模型在更少的迭代次数内达到同样甚至更好的效果。
  • 缓解梯度消失/爆炸问题:通过将数据标准化,批归一化可以有效防止梯度消失或梯度爆炸的问题,这对于训练深层神经网络尤其重要。
  • 允许使用更大的学习率:在梯度下降过程中,批归一化减少了权重更新的波动,因此可以使用更大的学习率,从而进一步加速模型的训练。
  • 一定的正则化效果:在一定程度上,批归一化对每个 mini-batch 的操作引入了噪声,这种噪声类似于 Dropout 的作用,能够减少模型过拟合。

7. 批归一化的缺点

  • 对小批量数据效果较差:批归一化依赖于 mini-batch 内的均值和方差。当 mini-batch 的大小较小时,均值和方差可能无法很好地代表整体数据分布,从而影响归一化效果。
  • 引入额外的计算开销:批归一化会增加额外的计算量,特别是在进行大量卷积操作或多层网络时,这可能会对训练时间造成一定影响。
  • 在某些模型中的表现不稳定:批归一化虽然通常提高了模型的稳定性,但在某些极端情况下,特别是序列模型(如 RNN)中,其表现可能不如其他正则化技术(如 Layer Normalization 和 Group Normalization)。

8. 批归一化和其变体的比较

图示

由于批归一化对小批量数据和序列模型效果不佳,一些变体技术被提出,下面这张图形象的解释了几种归一化方法的差别:

在这里插入图片描述

图片来源:https://arxiv.org/abs/1803.08494

这个图展示了四种不同的归一化方法在特征图张量上的操作方式。每个子图展示了一个三维的特征图张量,其中:

  • N 代表 batch 维度(样本数量);
  • C 代表通道(channel)维度,每个通道代表一个特征;
  • H, W 代表空间维度(即图像的高度和宽度)。

蓝色区域代表在归一化过程中使用相同均值和方差的像素点或区域,不同的归一化方法在归一化时对不同维度的数据进行标准化。

Batch Norm 批归一化

  • 批次内所有样本的单个通道上进行归一化。
  • 适用场景:适合大批量数据,适用于大多数神经网络模型,特别是在卷积神经网络(CNN)和全连接网络(FCN)中广泛使用。

Layer Norm 层归一化

  • 单个样本的所有通道上归一化。
  • 适用场景:适合变长序列模型。主要用于循环神经网络(RNN)、自注意力模型(如 Transformer)等序列模型。

Instance Norm 实例归一化

  • 单个样本的单个通道上归一化。
  • 适用场景:适合单样本输入的场景。主要用于图像生成任务,特别是在图像风格转换任务中效果较好。

Group Norm 组归一化

  • 单个样本的多个通道(按照通道分组)上归一化。
  • 适用场景:适合小批量数据。适用于卷积神经网络(CNN)中的小批量训练,以及 mini-batch 太小无法使用批归一化的情况。特别是在计算资源有限的情况下表现出色。

9. 批归一化的实际应用

在深度学习模型中,批归一化几乎已经成为标准组件之一。特别是在卷积神经网络(CNN)和全连接神经网络(FCNN)中,批归一化的使用能显著提高模型训练速度和性能。

常见的应用场景包括:

  • 图像分类:批归一化常用于卷积层之后,以保证卷积输出的稳定性,避免梯度爆炸。这种应用显著提高了像 ResNet、VGG 等经典图像分类模型的训练速度和准确性。
  • 生成对抗网络(GANs):GAN 模型中的 Generator 和 Discriminator 都需要稳定训练,批归一化能帮助平衡两者的训练。
  • 深度神经网络中的每一层:现代的神经网络模型几乎在每一层都使用批归一化。通常,批归一化会被放置在全连接层或卷积层之后,非线性激活函数之前。这种放置方式能最大程度地稳定激活值,防止模型在深度训练中失去学习能力。

10. 批归一化的总结

批归一化(Batch Normalization) 是一种重要的神经网络正则化方法,它通过标准化每一层的输入来加速神经网络的训练过程并提高模型的稳定性。其主要优势包括减少梯度消失和梯度爆炸、加快收敛速度,并提供一定的正则化效果,降低过拟合风险。批归一化在卷积神经网络和全连接神经网络中非常流行,几乎是现代深度学习模型中的标准组件。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/888429.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Chromium 搜索引擎功能浅析c++

地址栏输入:chrome://settings/searchEngines 可以看到 有百度等数据源,那么如何调整其顺序呢,此数据又存储在哪里呢? 1、浏览器初始化搜索引擎数据来源在 components\search_engines\prepopulated_engines.json // Copyright …

机器学习-支撑向量机SVM

Support Vector Machine 离分类样本尽可能远 Soft Margin SVM scikit-learn中的SVM 和kNN一样,要做数据标准化处理! 涉及距离! 加载数据集 import numpy as np import matplotlib.pyplot as plt from sklearn import datasetsiris datas…

Debezium日常分享系列之:Debezium 3.0.0.Final发布

Debezium日常分享系列之:Debezium 3.0.0.Final发布 Debezium 核心的变化需要 Java 17基于Kafka 3.8 构建废弃的增量信号字段的删除每个表的详细指标 MariaDB连接器的更改版本 11.4.3 支持 MongoDB连接器的更改MongoDB sink connector MySQL连接器的改变MySQL 9MySQL…

【图论】迪杰特斯拉算法

文章目录 迪杰特斯拉算法主要特点基本思想算法步骤示例 实现迪杰斯特拉算法基本步骤算法思路 总结 迪杰特斯拉算法 迪杰特斯拉算法是由荷兰计算机科学家艾兹赫尔迪杰特斯拉(Edsger W. Dijkstra)在1956年提出的,用于解决单源最短路径问题的经…

命令行py脚本——Linux下方便快捷地运行*.py脚本

命令行参数传递,shell批指令和命令别名。 (笔记模板由python脚本于2024年10月08日 12:25:54创建,本篇笔记适合喜欢python和Linux的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣…

Docker:快速部署

docker安装: ​‌​‬​⁠​‍‬​‍‬‬‌​‬‬‬​⁠​‍​​‌‬‌​​​​​​‌​​​​⁠​‍⁠‌安装Docker - 飞书云文档 (feishu.cn) docker命令解读 docker run -d \ > --name mysql \ > -p 33…

【bug】finalshell向远程主机拖动windows快捷方式导致卡死

finalshell向远程主机拖动windows快捷方式导致卡死 问题描述 如题,作死把桌面的快捷方式拖到了finalshell连接的服务器面板中,导致finalshell没有响应(小概率事件,有时会触发) 解决 打开任务管理器查看finalshell进…

SpringBoot Jar 包加密防止反编译

今天看到了一个说明jar包加密的实现方式,特意试了下效果,并下载了插件源码及实现源码查看了下子,感兴趣的可以在最后得到gitee地址。 SpringBoot 程序 Jar 包加密的方式,通过代码加密可以实现无法反编译。应用场景就是当需要把公司…

RK3568笔记六十四:SG90驱动测试

若该文为原创文章,转载请注明原文出处。 前面有测试过PWM驱动,现在使用两种方式来产生PWM驱动SG90,实现舵机旋转任意角度 方法一:使用硬件PWM 方法二:使用高精度定时器,GPIO模拟PWM. 一、PWM子系统框架 二、SG90控制方法 舵机的控制需要MCU产生一个周期为20ms的脉冲信号…

(Linux驱动学习 - 8).信号异步通知

一.异步通知简介 1.信号简介 信号类似于我们硬件上使用的“中断”,只不过信号是软件层次上的。算是在软件层次上对中断的一种模拟,驱动可以通过主动向应用程序发送信号的方式来报告自己可以访问了,应用程序获取到信号以后就可以从驱动设备中…

【JavaEE】【多线程】Thread类讲解

目录 Thread构造方法Thread 的常见属性创建一个线程获取当前线程引用终止一个线程使用标志位使用自带的标志位 等待一个线程线程休眠线程状态线程安全线程不安全原因总结解决由先前线程不安全问题例子 Thread构造方法 方法说明Thread()创建线程对象Thread(Runnable target)使用…

Web3 游戏周报(9.22 - 9.28)

回顾上周的区块链游戏概况,查看 Footprint Analytics 与 ABGA 最新发布的数据报告。 【9.22-9.28】Web3 游戏行业动态: Axie Infinity 将 Fortune Slips 的冷却时间缩短至 24 小时,从而提高玩家的收入。 Web3 游戏开发商 Darkbright Studios…

使用sponge+dtm快速搭建一个高性能的电商系统,秒杀抢购和订单架构的设计与实现

本文将展示如何使用 Sponge 框架快速创建一个简易版高性能电商系统,主要实现秒杀抢购和订单功能,并通过分布式事务管理器 DTM 来确保数据一致性。电商系统的架构图如下: 这是源码示例eshop,目录下包括了两个一样的代码示例&#x…

kafka-windows集群部署

kafka-windows集群部署目录 文章目录 kafka-windows集群部署目录前言一、复制出来四个kafka文件夹二、修改集群每个kafka的配置文件四、启动zookeeper,kafka集群 前言 部署本文步骤可以先阅读这一篇博客,这篇是关于单机kafka部署测试的。本文用到的文件…

Android 电源管理各个版本的变动和限制

由于Android设备的电池容量有限,而用户在使用过程中会进行各种高耗电操作,如网络连接、屏幕亮度调节、后台程序运行等,因此需要通过各种省电措施来优化电池使用‌,延长电池续航时间,提高用户体验,并减少因电…

基于LORA的一主多从监测系统_AHT20温湿度传感器

1)AHT20温湿度传感器 这个传感器,网上能找到的资料还是比较多的,我们使用的是HAL硬件i2c,相比于模拟i2c,我们不需要过于关注时序问题,我们只需要关心如何获取数据以及数据如何处理,下面以数据手…

Prometheus之Pushgateway使用

Pushgateway属于整个架构图的这一部分 The Pushgateway is an intermediary service which allows you to push metrics from jobs which cannot be scraped. The Prometheus Pushgateway exists to allow ephemeral and batch jobs to expose their metrics to Prometheus. S…

MATLAB APPdesigner中的日期选择器怎样实时显示时间

文章目录 1.问题描述2.代码设置代码示例解释 1.问题描述 我们在做MATLAB的时候,一般需要在APP界面中加上时间显示,像下图中的右上角,在组件中有日期选择器,但是这个并不是实时显示的,我们还需要自己进行设置。 2.代码…

(11)MATLAB莱斯(Rician)衰落信道仿真2

文章目录 前言一、莱斯衰落信道仿真模型二、仿真代码与结果1.仿真代码2.仿真结果画图 三、后续:四、参考文献: 前言 首先给出莱斯衰落信道仿真模型,该模型由直射路径分量和反射路径分量组成,其中反射路径分量由瑞利衰落信道模型构…

力扣之603.连续空余座位

文章目录 1. 603.连续空余座位1.1 题干1.2 准备数据1.3 思路分析1.4 解法1.5 结果截图 1. 603.连续空余座位 1.1 题干 表: Cinema ----------------- | Column Name | Type | ----------------- | seat_id | int | | free | bool | ----------------- Seat_id 是该表的自动递…