Batch Normalization学习笔记

文章目录

  • 一、为何引入 Batch Normalization
  • 二、具体步骤
    • 1、训练阶段
    • 2、预测阶段
  • 三、关键代码实现
  • 四、补充
  • 五、参考文献

一、为何引入 Batch Normalization

  现在主流的卷积神经网络几乎都使用了批量归一化(Batch Normalization,BN)1,它是一种逐层归一化方法,可以对神经网络中任意的中间层进行归一化操作。我们可以从不同角度来理解为什么要引入 Batch Normalization:

① 训练时的误差表面(error surface) 可能会十分崎岖,使得做优化时容易陷入局部最优值或鞍点等。通常我们会使用各种算法如Adam等进行优化,那么能不能直接改误差表面的地貌,“把山铲平”,让它变得比较好训练呢?Batch Normalization 就是其中一个“把山铲平”的想法2。另外一个好处是,误差表面变得没那么崎岖后,我们在训练时便可以增大学习率,使得网络更快收敛。

② 对于典型的多层感知机或卷积神经网络,在训练时中间层中的变量可能具有更广的变化范围。也就是说,随着训练时间的推移,每一层的模型参数分布范围变化莫测(比如一个深层网络,反向传播更新参数时,顶层与最底层数据范围差异会比较大,因为最底层相当于通过链式法则乘了一堆偏导数,导致数据范围非常大或小):

在这里插入图片描述

变量分布中的不规则的偏移可能会阻碍网络的收敛,因此为了使各层拥有适当的数据范围,通过 Batch Normalization“强制性”地调整数据分布使其约束到更小的范围(标准正态分布),这样便可以使得训练更加稳定,且对于初始值的设置没那么敏感。调整之后示意图如下:

在这里插入图片描述

③ 深层的网络很复杂,容易过拟合。而 Batch Normalization可以作为一种隐形的正则化方法,减轻过拟合(因此有时候使用BN后,dropout显得没那么必要使用)。由于Batch Normalization是基于一个 mini batch的,因此在训练时,神经网络对一个样本的预测不仅和该样本自身相关,也和同一批次中的其他样本相关,这种选取批次的随机性,使得神经网络不会“过拟合”到某个特定样本,从而提高网络的泛化能力。

总而言之,Batch Normalization 的优点如下3

  • 不那么依赖初始值(对于初始值不用那么神经质)。
  • 可以使学习快速进行(可以增大学习率)。
  • 抑制过拟合(降低Dropout等的必要性)。


二、具体步骤

1、训练阶段

  在训练时,Batch Normalization会逐步对每个mini-batch进行归一化。具体步骤如下:

设一个mini-batch中有 m m m 个输入数据,记为集合 B = { x 1 , x 2 , ⋯   , x m } B=\{x_1,x_2,\cdots,x_m\} B={x1,x2,,xm},对该集合求均值 μ B \mu_B μB 和方差 σ B 2 \sigma_B^2 σB2
μ B ← 1 m ∑ i = 1 m x i \begin{aligned}\mu_B\leftarrow\frac{1}{m}\sum_{i=1}^mx_i\end{aligned} μBm1i=1mxi

σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 \begin{aligned}\sigma_B^2\leftarrow\frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2\end{aligned} σB2m1i=1m(xiμB)2
接下来利用求得的均值和方差对输入数据进行归一化:
x ^ i ← x i − μ B σ B 2 + ε \hat{x}_i\leftarrow\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\varepsilon}} x^iσB2+ε xiμB
其中 ε \varepsilon ε 是一个微小值(如 10 e − 7 10e^{-7} 10e7 等),以防止出现除以0的情况。

于是便可以将输入数据转换为均值为0,方差为1的数据 { x ^ 1 , x ^ 2 , ⋯   , x ^ m } \left\{\hat{x}_1,\hat{x}_2,\cdots,\hat{x}_m\right\} {x^1,x^2,,x^m} 了。

  为了使得归一化不对网络的表示能力造成负面影响,再通过一个附加的缩放和平移变换改变新数据的取值区间(虽然归一化加快了训练速度和稳定性,但它改变了数据的原始分布。对于某些任务来说,直接使用归一化的数据可能会限制模型的表达能力,因此引入可以学习的超参数 γ \gamma γ β \beta β ,使得模型可以灵活地调整归一化后的数据分布,恢复其自由度):
y i ← γ x ^ i + β y_i\leftarrow\gamma\hat{x}_i+\beta yiγx^i+β

最后把上述所有处理插入到激活函数的前面即可(整个过程相当于一个BatchNorm层),示意图如下:

在这里插入图片描述


示意图二(其中 W W W 是全连接层, L ^ \widehat{\mathcal{L}} L 是损失函数)4

在这里插入图片描述


2、预测阶段

  在训练过程中,我们无法得知整个数据集来估计平均值和方差,所以只能根据每个小批次(mini-batch)的平均值和方差不断训练模型。 而在预测模式下,一般使用整个预测数据集的均值和方差(因为这时候已经经过完整的训练了,因此可以得知全局信息)。为了节省存储资源,实际中大多采用**移动平均(moving average)**的方式来计算全局的均值和方差。移动平均的计算过程如下式所示:
μ t o t a l = λ ∗ μ t o t a l + ( 1 − λ ) ∗ μ B σ t o t a l 2 = λ ∗ σ t o t a l 2 + ( 1 − λ ) ∗ σ B 2 \begin{aligned}\mu_{total}&=\lambda*\mu_{total}+(1-\lambda)*\mu_{\mathcal{B}}\\\sigma_{total}^2&=\lambda*\sigma_{total}^2+(1-\lambda)*\sigma_{\mathcal{B}}^2\end{aligned} μtotalσtotal2=λμtotal+(1λ)μB=λσtotal2+(1λ)σB2



三、关键代码实现

以动手学深度学习第二版5的代码为例(Pytorch):

import torch
from torch import nn
from d2l import torch as d2l


def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    # 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
    if not torch.is_grad_enabled():
        # 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert len(X.shape) in (2, 4)
        if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # 训练模式下,用当前的均值和方差做标准化
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # 更新移动平均的均值和方差
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # 缩放和移位
    return Y, moving_mean.data, moving_var.data

解释几个可能的疑惑点:

  • 为什么分为全连接层和卷积层两种情况?

  全连接层和卷积层的批量规范化实现略有不同:当作用在全连接层时,实际上是作用在特征维;当作用在卷积层上时,实际上是作用在通道维(将通道维当成是卷积层的特征维)。

为什么作用在通道维?因为每个通道都有自己的拉伸参数偏移参数,并且都是标量。例如下图6所示:

在这里插入图片描述

上图各颜色通道中的像素值通常具有不同的分布和范围,这种不一致性可能会导致训练出错或网络不收敛等问题。因此需要通过Normalize操作,将每个通道的像素值标准化为均值为0、标准差为1的分布,使得所有通道的像素值范围和分布一致。



  • 为什么全连接层设置 dim=0,而卷积层设置 dim=(0,2,3)

  全连接层是二维的,即(batch_size, feature) ,计算全连接层时,计算的是特征维的均值和方差,而每个行代表一个样本,每列代表一个特征。

下图重量/甜度/颜色评分为苹果的特征维,我们来计算特征维的均值:

苹果编号重量(克)甜度(°Bx)颜色评分(1 - 10)
苹果 1200127
苹果 2180106
苹果 3220148

dim=0代表行,dim=1代表列,既然我们要求特征维的均值,那么需要让 dim=0 ,也就是把行“拍扁”。上图把行“拍扁”后得到的特征维的均值如下:

重量(克)甜度(°Bx)颜色评分(1 - 10)
(200+180+220)/3=7(12+10+14)/3=12(7+6+8)/3=7

那么卷积层 (batch_size, channels, height, width)设dim=(0,2,3)也很好理解了,我们需要得到通道维的均值,那么就得把其它几个方向都“拍扁”。



  • 为什么全连接层无需设置keepdim=True 而卷积层需设置keepdim=True

  由于pytorch的广播机制,只会从左边补1,换个说法即只会补齐最外层的维度,因此前者无需设置而后者需设置keepdim=True来保证广播机制的正常启动。

有点抽象,举例子说明:

# 构造一个形状为 (2, 3, 4, 5, 6) 的五维张量
A = torch.randn(2, 3, 4, 5, 6)

# 打印张量 A 的形状
print("张量 A 的形状:", A.shape)

# 构造一个形状为 (3, 4, 5, 6) 的四维张量
B = torch.randn(3, 4, 5, 6)
print("张量 B 的形状:", B.shape)

try:
    # 尝试执行 A + B
    A + B
    print("可以成功输出")
except Exception as e:
    # 如果发生异常,打印失败信息
    print("失败输出:", e)

输出结果为:

张量 A 的形状: torch.Size([2, 3, 4, 5, 6])
张量 B 的形状: torch.Size([3, 4, 5, 6])
可以成功输出

因为广播机制会让B的维度补齐成(1,3,4,5,6),也就是最左边补“1”,于是就可以执行 A+B操作了。

而如下情况,即仅仅稍微改变一下B的形状:

# 构造一个形状为 (2, 3, 4, 5, 6) 的五维张量
A = torch.randn(2, 3, 4, 5, 6)

# 打印张量 A 的形状
print("张量 A 的形状:", A.shape)

# 构造一个形状为 (2, 3, 4, 5) 的四维张量
B = torch.randn(2, 3, 4, 5)
print("张量 B 的形状:", B.shape)

try:
    # 尝试执行 A + B
    A + B
    print("可以成功输出")
except Exception as e:
    # 如果发生异常,打印失败信息
    print("失败输出:", e)

输出结果为:

张量 A 的形状: torch.Size([2, 3, 4, 5, 6])
张量 B 的形状: torch.Size([2, 3, 4, 5])
失败输出: The size of tensor a (6) must match the size of tensor b (5) at non-singleton dimension 4

因为广播机制只会往最左边补“1”,而这里B补“1”后形状变成(1,2,3,4,5),依旧和张量A的形状不一致,所以不能做相加操作。

回到 Batch-Normalization 的代码:

 if len(X.shape) == 2:
            # 使用全连接层的情况,计算特征维上的均值和方差
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
            # 这里我们需要保持X的形状以便后面可以做广播运算
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)

我们知道, dim等于哪个维,就是将那个维进行“拍扁”。

对于全连接层(batch_size, feature),设置 dim=0时,相当于将第 0 维“拍扁”,拍扁了相当于那个维直接“消失”了,此时meanvar的形状为(feature)。于是直接可以通过广播机制,在最左边补“1”,变成(1, feature),便可以和变量 X 一起计算了【X的形状(batch_size, feature)】。

而卷积层 (batch_size, channels, height, width)dim=(0,2,3)时,相当于将第 0,2,3 维“拍扁”,此时meanvar的形状为(channels),而 X 的形状是 (batch_size, channels, height, width),你得将meanvar的形状扩展到和 X 一致才可以进行计算,而广播机制只能往最左边补“1”,因此(channels)无法扩展成和X一致的形状,顶多扩展成(1, channels),所以无法和 X 进行计算,程序报错。

因此需要对卷积层使用 keepdim=True这个参数,这样meanvar的形状就可以扩展成 (1, channels, 1, 1),与X一致,才能进行接下来的计算。



  • if not torch.is_grad_enabled() 为什么可以判断是训练还是预测模式?

  反向传播时会涉及梯度的计算,而只有训练时才会进行反向传播,因此可以通过是否进行梯度的计算来判断训练模式还是预测模式。



四、补充

  原论文中提出Batch-Normalization的优点是减少了内部协变量转移(internal covariate shift,简单来说就是变量值的分布在训练过程中会发生变化,但是这种解释在后续论文被证实比较不严谨,发现它并没有减少内部协变量的转移 [Santurkar et al.,2018]。



五、参考文献


  1. Ioffe S. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015. ↩︎

  2. 王琦, 杨毅远, 江季, 深度学习详解, 北京:人民邮电出版社, 2024 ↩︎

  3. (日)斋藤康毅著, 陆宇杰译, 深度学习入门基于Python的理论与实现, 北京:人民邮电出版社, 2018.07 ↩︎

  4. Santurkar S, Tsipras D, Ilyas A, et al. How does batch normalization help optimization?[J]. Advances in neural information processing systems, 2018, 31. ↩︎

  5. 阿斯顿·张(Aston Zhang), 李沐(Mu Li), [美] 扎卡里·C. 立顿(Zachary C. Lipton), 等. 动手学深度学习(PyTorch版)[M]. 第二版. 人民邮电出版社, 2023-2. ↩︎

  6. 【Batch Normalization】 https://www.bilibili.com/video/BV11s4y1c7pg/?share_source=copy_web&vd_source=199a3f4e3a9db6061e1523e94505165a ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/958850.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

JavaSec系列 | 动态加载字节码

视频教程在我主页简介或专栏里 目录: 动态加载字节码 字节码 加载远程/本地文件 利用defineClass()直接加载字节码 利用TemplatesImpl加载字节码 动态加载字节码 字节码 Java字节码指的是JVM执行使用的一类指令,通常被存储在.class文件中。 加载远程…

第十四讲 JDBC数据库

1. 什么是JDBC JDBC(Java Database Connectivity,Java数据库连接),它是一套用于执行SQL语句的Java API。应用程序可通过这套API连接到关系型数据库,并使用SQL语句来完成对数据库中数据的查询、新增、更新和删除等操作…

JVM面试题解,垃圾回收之“分代回收理论”剖析

一、什么是分代回收 我们会把堆内存中的对象间隔一段时间做一次GC(即垃圾回收),但是堆内存很大一块,内存布局分为新生代和老年代、其对象的特点不一样,所以回收的策略也应该各不相同 对于“刚出生”的新对象&#xf…

电脑如何访问手机文件?

手机和电脑已经深深融入了我们的日常生活,无时无刻不在为我们提供服务。除了电脑远程操控电脑外,我们还可以在电脑上轻松地访问Android或iPhone手机上的文件。那么,如何使用电脑远程访问手机上的文件呢? 如何使用电脑访问手机文件…

ThinkPHP 8模型与数据的插入、更新、删除

【图书介绍】《ThinkPHP 8高效构建Web应用》-CSDN博客 《2025新书 ThinkPHP 8高效构建Web应用 编程与应用开发丛书 夏磊 清华大学出版社教材书籍 9787302678236 ThinkPHP 8高效构建Web应用》【摘要 书评 试读】- 京东图书 使用VS Code开发ThinkPHP项目-CSDN博客 编程与应用开…

【MySQL】数据库基础知识

欢迎拜访:雾里看山-CSDN博客 本篇主题:【MySQL】数据库基础知识 发布时间:2025.1.21 隶属专栏:MySQL 目录 什么是数据库为什么要有数据库数据库的概念 主流数据库mysql的安装mysql登录使用一下mysql显示数据库内容创建一个数据库创…

【线性代数】基础版本的高斯消元法

[精确算法] 高斯消元法求线性方程组 线性方程组 考虑线性方程组, 已知 A ∈ R n , n , b ∈ R n A\in \mathbb{R}^{n,n},b\in \mathbb{R}^n A∈Rn,n,b∈Rn, 求未知 x ∈ R n x\in \mathbb{R}^n x∈Rn A 1 , 1 x 1 A 1 , 2 x 2 ⋯ A 1 , n x n b 1…

高等数学学习笔记 ☞ 微分方程

1. 微分方程的基本概念 1. 微分方程的基本概念: (1)微分方程:含有未知函数及其导数或微分的方程。 举例说明微分方程:;。 (2)微分方程的阶:指微分方程中未知函数的导数…

HarmonyOS基于ArkTS卡片服务

卡片服务 前言 Form Kit(卡片开发框架)提供了一种在桌面、锁屏等系统入口嵌入显示应用信息的开发框架和API,可以将应用内用户关注的重要信息或常用操作抽取到服务卡片(以下简称“卡片”)上,通过将卡片添加…

Java复习第四天

一、代码题 1.相同的树 (1)题目 给你两棵二叉树的根节点p和q,编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同,并且节点具有相同的值,则认为它们是相同的。 示例 1: 输入:p[1,2,3],q[1,2,3] 输出:true示例 2: 输…

全面了解 Web3 AIGC 和 AI Agent 的创新先锋 MelodAI

不管是在传统领域还是 Crypto,AI 都是公认的最有前景的赛道。随着数字内容需求的爆炸式增长和技术的快速迭代,Web3 AIGC(AI生成内容)和 AI Agent(人工智能代理)正成为两大关键赛道。 AIGC 通过 AI 技术生成…

新能源汽车充电桩选型以及安装应用

摘要:随着当前经济的不断发展,国家的科技也有了飞速的进步,传统的燃油汽车已经不能适应当前社会的发展,不仅对能源造成巨大的消耗,还对环境造成了污染,当前一种新型的交通运输工具正在占领汽车市场。在环境问题和能源问题愈发严重的当今社会,节能减排已经成为全世界的共同课题,…

一个vue项目npm install失败的问题解决方案

vue的项目一直是史上最难的最烦的问题,今天给别人做毕设单子想在gitee上拉项目二开的时候,由于很久没写过vue项目已经生疏了,在拿到项目之后我还是例行完成最常见的步骤: 1、npm init -y 初始化 2、npm install 用npm把这个项目…

计算机网络 (55)流失存储音频/视频

一、定义与特点 定义:流式存储音频/视频是指经过压缩并存储在服务器上的多媒体文件,客户端可以通过互联网边下载边播放这些文件,也称为音频/视频点播。 特点: 边下载边播放:用户无需等待整个文件下载完成即可开始播放…

UE求职Demo开发日志#6 测试用强化页面UI搭建

1 反向实现思路设计 先看最终效果: 先做了一个大致的分区,右侧的上半部分用来显示数据,下半部分用来强化和显示需要的材料,至于这个背景设定上强化应该叫什么,。。。。,还没定,反正应该不叫强…

python学opencv|读取图像(四十一 )使用cv2.add()函数实现各个像素点BGR叠加

【1】引言 前序已经学习了直接在画布上使用掩模,会获得彩色图像的多种叠加效果,相关文章链接为: python学opencv|读取图像(四十)掩模:三通道图像的局部覆盖-CSDN博客 这时候如果更进一步,直接…

SpringCloudAlibaba 服务保护 Sentinel 项目集成实践

目录 一、简介1.1、服务保护的基本概念1.1.1、服务限流/熔断1.1.2、服务降级1.1.3、服务的雪崩效应1.1.4、服务的隔离的机制 1.2、Sentinel的主要特性1.3、Sentinel整体架构1.4、Sentinel 与 Hystrix 对比 二、Sentinel控制台部署3.1、版本选择和适配3.2、本文使用各组件版本3.…

窥探QCC518x-308x系列与手机之间的蓝牙HCI记录与分析 - 耳机篇

上一篇是介绍如何窥探手机端Bluetooth的HCI log, 本次介绍是如何窥探Bluetooth的HCI log-耳机篇. 这次跟QCC518x/QCC308x测试的手机是Samsung S23 Ultra. QCC518x/QCC308x透过HCI界面取得Log教学. 步骤1: 开启QMDE -> 选择ADK r1102 QCC3083 Headset workspace.步骤2: 点…

C++ list 容器用法

C list 容器用法 C 标准库提供了丰富的功能&#xff0c;其中 <list> 是一个非常重要的容器类&#xff0c;用于存储元素集合&#xff0c;支持双向迭代器。<list> 是 C 标准模板库&#xff08;STL&#xff09;中的一个序列容器&#xff0c;它允许在容器的任意位置快速…

SpringSecurity实现自定义用户认证方案

Spring Security 实现自定义用户认证方案可以根据具体需求和业务场景进行设计和实施&#xff0c;满足不同的安全需求和业务需求。这种灵活性使得认证机制能够更好地适应各种复杂的环境和变化‌。通过自定义认证方案&#xff0c;可以更好地控制和管理用户的访问权限&#xff0c;…