深入理解批量归一化(BN):原理、缺陷与跨小批量归一化(CBN)

在训练深度神经网络时,批量归一化(Batch Normalization,简称BN)是一种常用且有效的技术,它帮助解决了深度学习中训练过程中的梯度消失、梯度爆炸和训练不稳定等。然而,BN也有一些局限性,特别是在处理小批量数据推理阶段时。因此,跨小批量归一化(Cross-Batch Normalization,CBN)作为一种新的方法被提出,旨在克服BN的一些缺点。

本文将详细介绍BN的原理、其在小批量训练中的缺陷,并介绍如何通过CBN解决这些问题,帮助读者更好地理解这些技术。


目录

一、批量归一化(BN)是什么?

1.1 什么是批量归一化?

1.2 批量归一化在卷积神经网络中的应用

1.3 BN的计算步骤

1.3.1 计算均值和方差

1.3.3缩放和平移

1.4 BN的优点

二、批量归一化(BN)存在的缺陷

2.1 小批量训练时的问题

2.2 推理阶段的问题

2.3 对批量大小的敏感性

三、跨小批量归一化(CBN):解决BN缺陷的创新方法

3.1 BN vs CBN 的关键区别

3.2 CBN 的工作原理

3.3 CBN 的优缺点

优点:

缺点:

4. CBN 的实现(PyTorch 示例)

5. 总结


一、批量归一化(BN)是什么?

1.1 什么是批量归一化?

批量归一化(BN)是一种在神经网络的训练过程中对每一层输入进行标准化的技术。具体来说,BN对每一层的输入数据进行 均值为0、方差为1 的归一化处理,从而消除了数据分布的变化(即内部协变量偏移)。BN的核心目标是加速网络训练过程,并提高网络的稳定性。

简而言之,BN就是将每层的输入数据进行标准化处理,使其具有相同的尺度,这样可以避免某些层的输出值过大或过小,从而加速训练的收敛。

1.2 批量归一化在卷积神经网络中的应用

在卷积神经网络(CNN)中,BN通常应用于每一层卷积操作的输出,即特征图。卷积神经网络中的特征图是卷积层生成的二维或三维数据,BN会对这些数据进行标准化处理。

假设网络输入的是一个张量,形状为 \mathbf{X} \in \mathbb{R}^{N \times C \times H \times W},其中:

  • N 是批量大小(batch size),即一次训练中输入的样本数量,
  • C 是卷积层输出的通道数(channels),通常表示颜色通道(RGB)或者卷积层提取的特征数量,
  • H 和 W 是特征图的高度(height)和宽度(width)。

1.3 BN的计算步骤

BN的计算过程可以分为三个步骤:计算均值、计算方差、进行标准化。

1.3.1 计算均值和方差

对于每个通道(channel),BN会计算该通道下所有像素点的均值和方差。假设输入数据 \mathbf{X} 的形状为 N \times C \times H \times W,其中 N 为批量大小,C 为通道数,H 和 W 为特征图的高度和宽度。那么对每个通道 c,BN计算的是该通道内所有像素点的均值(\mu_c)和方差(\sigma_c^2)。

均值:对每个通道的所有像素计算均值

\mu_c = \frac{1}{N \times H \times W} \sum_{i=1}^{N} \sum_{j=1}^{H} \sum_{k=1}^{W} x_{i, c, j, k}

这里,x_{i, c, j, k} 是第 i 个样本在第 c 个通道上,位置 (j, k) 的像素值。

方差:对每个通道的所有像素计算方差(方差反映了像素值的离散程度)

\sigma_c^2 = \frac{1}{N \times H \times W} \sum_{i=1}^{N} \sum_{j=1}^{H} \sum_{k=1}^{W} (x_{i, c, j, k} - \mu_c)^2\\ =\frac{1}{N \times H \times W} \sum_{i=1}^{N} \sum_{j=1}^{H} \sum_{k=1}^{W}x_{i, c, j, k}^2- \mu_c^2

上诉推导由公式:\sigma^2 = Var(X)=E(X^2)-(E(X))^2的公式推导而来

1.3.2 标准化

计算得到均值和方差后,我们将每个像素的值进行标准化处理,使得其符合零均值和单位方差:

\hat{x}_{i, c, j, k} = \frac{x_{i, c, j, k} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}

其中,\epsilon 是一个非常小的常数,防止除零错误。

1.3.3缩放和平移

为了让标准化后的输出数据保持其原本的分布,BN引入了可学习的参数 \gamma_c(缩放因子)和 \beta_c(平移因子):

y_{i, c, j, k} = \gamma_c \hat{x}_{i, c, j, k} + \beta_c

这里,\gamma_c\beta_c 是每个通道的可学习参数,用来恢复输出的表达能力。

1.4 BN的优点

  • 加速训练:通过减少内部协变量偏移,BN让网络训练更加平稳,加快了收敛速度。
  • 提高稳定性:BN通过规范化每一层的输入数据,使得梯度更新更加平滑,从而减少了梯度爆炸和梯度消失的风险。
  • 具有正则化效果:由于每一层的输入数据被归一化,BN本身也具有一定的正则化效果,有时能够减少过拟合。

二、批量归一化(BN)存在的缺陷

虽然BN在训练过程中提供了很多好处,但它也有一些限制,特别是在以下两个方面:

2.1 小批量训练时的问题

BN的性能依赖于小批量中的统计数据(均值和方差)。如果批量大小非常小(例如,批量大小为1或几),那么计算得到的均值和方差可能并不稳定,这会导致训练的不稳定性。在这种情况下,BN的效果往往不如预期,甚至会影响训练的收敛速度。

2.2 推理阶段的问题

在推理阶段,我们通常使用 训练阶段 得到的均值和方差来归一化数据,因为推理时无法获取多个样本的小批量。然而,这种方法存在问题:训练和推理阶段使用的均值和方差可能不一致,尤其当推理数据与训练数据的分布有所不同时。这会导致网络性能在推理阶段下降。

2.3 对批量大小的敏感性

BN对批量大小非常敏感。较小的批量会导致统计不准确,较大的批量则增加计算开销。因此,BN在面对不同批量大小时并不总是最优的解决方案。

三、跨小批量归一化(CBN):解决BN缺陷的创新方法

为了解决BN在小批量训练和推理阶段的缺陷,跨小批量归一化(CBN)应运而生。CBN的目标是通过 跨多个小批量 计算全局的均值和方差,从而避免BN在小批量训练时统计不稳定的问题。

为了理解CBN是如何实现这一点的,我们需要明确以下几个关键概念和步骤:

3.1 BN vs CBN 的关键区别

在标准的 批量归一化(BN) 中,我们通常对每个小批量(batch)内部的均值和方差进行计算,并在每个批次(即每个小批量)上进行归一化处理。这样,每个批次的均值和方差都可能不同。问题是,当批次较小时,计算得到的均值和方差会存在较大误差,导致模型训练不稳定。

而在 跨小批量归一化(CBN) 中,目标是跨多个小批量数据来计算全局的均值和方差,避免每个小批量独立计算统计量带来的波动。具体来说,CBN可以跨多个批次计算全局均值和方差,从而确保训练过程中的统计量更加稳定。

3.2 CBN 的工作原理

在训练过程中,CBN通过以下方式获取跨小批量的统计值。

3.2.1跨多个小批量的数据积累

在标准的BN中,每个小批量都有自己的均值和方差。CBN则会跨多个小批量(或者多个批次)对均值和方差进行积累和计算,逐渐形成一个全局的均值和方差。

具体而言,CBN会通过以下步骤积累统计值:

  • 全局均值计算:每次处理一个小批量时,CBN会将该小批量的均值加入全局均值的计算。
  • 全局方差计算:类似地,CBN会将每个小批量的方差也加入到全局方差的计算中。

3.3.2更新统计值的方式

CBN的统计量(均值和方差)通常使用滑动平均或累积的方式进行更新。具体来说,CBN会通过更新公式来平滑计算全局的均值和方差,避免每个批次计算出的统计量波动过大。

例如,对于均值和方差的更新,CBN可以使用如下公式:

  • 均值更新公式

\mu_{\text{global}} = \frac{1}{t}\sum_{i=1}^{t}\mu_{i}

其中,\mu_{\text{global}} 是全局均值的当前值,\mu_{i} 是第 i 批量的均值,t 为当前批量的索引。

但是在实际运用中,我们会给上诉公司做简化处理:

\mu_{\text{global}} =\alpha \mu_{t-1} + (1-\alpha) \mu_{t}

其中α 是一个平滑因子(通常接近1,例如0.9或0.99),用于控制历史信息的影响。

  • 方差更新公式

\sigma_c^2 = \frac{1}{N \times H \times W} \sum_{i=1}^{N} \sum_{j=1}^{H} \sum_{k=1}^{W} (x_{i, c, j, k} - \mu_c)^2=\frac{1}{N \times H \times W} \sum_{i=1}^{N} \sum_{j=1}^{H} \sum_{k=1}^{W}x_{i, c, j, k}^2- \mu_c^2

上诉推导由公式:\sigma^2 = Var(X)=E(X^2)-(E(X))^2的公式推导而来

\sigma_{\text{global}}^{2} = \overline{E(X^2)}-\overline{(E(X))^2}=\frac{\sum_{i=1}^{N} \sum_{j=1}^{H} \sum_{k=1}^{W}\sum_{t=1}^{T}x_{i, c, j, k,t}^2}{N \times H \times W \times T} -\mu_{\text{global}}^{2}

这里的 \sigma_{\text{global}}^{2} 是全局方差的当前值,\mu_{\text{global}}^{2} 全局均值。

同理,我们在实际应用中简化如下公式:

\sigma_{\text{global}}^2 =\alpha \sigma _{t-1}^2 + (1-\alpha) \sigma_{t}^2'

其中α 是一个平滑因子(通常接近1,例如0.9或0.99),用于控制历史信息的影响。

3.2.3标准化使用全局统计量

训练过程中,每个小批量的输入都会使用 全局均值全局方差 来进行标准化,而不仅仅依赖当前小批量的统计量。具体而言,每次输入数据通过标准化公式:

\hat{x}_{i, c, j, k} = \frac{x_{i, c, j, k} - \mu_{\text{global}}}{\sqrt{\sigma_{\text{global}}^2 + \epsilon}}

其中,\mu_{\text{global}}\sigma_{\text{global}}^2 是跨多个小批量积累的全局均值和方差,\epsilon 是一个小常数,用于防止除零错误。

通过这种方式,CBN确保了所有小批量在训练过程中使用的是稳定的统计量。

3.3 CBN 的优缺点

优点:
  • 减少小批量训练的不稳定性:CBN通过跨多个小批量积累统计量,避免了单个小批量方差和均值的不准确,尤其在批量大小非常小的情况下,效果尤为明显。
  • 保持训练和推理阶段的一致性:CBN在训练阶段和推理阶段使用相同的全局均值和方差,从而避免了在推理时因为统计量差异而导致的性能下降。
缺点:
  • 计算开销增加:CBN需要跨多个小批量计算统计量,因此需要更多的内存和计算资源来保存历史统计值。
  • 需要更多的数据积累:为了准确地计算全局均值和方差,CBN通常需要积累较多的小批量数据,这可能会影响训练效率。

4. CBN 的实现(PyTorch 示例)

下面是一个简单的基于 PyTorch 实现的 CBN 类,它演示了如何跨多个批量计算均值和方差。

import torch
import torch.nn as nn

class CrossBatchNorm(nn.Module):
    def __init__(self, num_features, momentum=0.1):
        super(CrossBatchNorm, self).__init__()
        self.num_features = num_features
        self.momentum = momentum
        # 初始化全局均值和方差
        self.running_mean = torch.zeros(num_features)
        self.running_var = torch.ones(num_features)
    
    def forward(self, x):
        # 计算当前小批量的均值和方差
        mean = x.mean([0, 2, 3])  # 跨批量、行、列计算均值
        var = x.var([0, 2, 3], unbiased=False)  # 跨批量、行、列计算方差
        
        # 更新全局均值和方差
        self.running_mean = self.running_mean * self.momentum + mean * (1 - self.momentum)
        self.running_var = self.running_var * self.momentum + var * (1 - self.momentum)
        
        # 使用全局均值和方差进行标准化
        x_hat = (x - self.running_mean[None, :, None, None]) / torch.sqrt(self.running_var[None, :, None, None] + 1e-5)
        
        # 可学习的缩放和平移
        gamma = self.gamma if hasattr(self, 'gamma') else torch.ones_like(mean)
        beta = self.beta if hasattr(self, 'beta') else torch.zeros_like(mean)
        
        return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

5. 总结

跨小批量归一化(CBN) 通过跨多个小批量数据计算全局均值和方差,从而避免了单个小批量的统计量可能存在的误差。这种方法在处理小批量数据时特别有效,能够提供更稳定的训练过程,并保持训练和推理阶段的一致性。虽然这种方法增加了计算和内存开销,但它可以显著提高深度学习模型在特定情况下的表现,特别是在处理小批量数据时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/941787.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

iptables交叉编译(Hisiav300平台)

参考文章:https://blog.csdn.net/Bgm_Nilbb/article/details/135714738 https://bbs.archlinux.org/viewtopic.php?pid1701065 1、libmnl 交叉编译 tar xvf libmnl-1.0.5.tar.bz2 sudo chmod 777 -R libmnl-1.0.5 cd libmnl-1.0.5 mkdir _install //host和CC需要修…

redis数据类型:list

数据结构 源码版本:7.2.2路径:src/adlist.h 关于list的 头文件中涉及到的这三个结构体如下 /* Node, List, and Iterator are the only data structures used currently. */ # 节点 typedef struct listNode {struct listNode *prev; # 前元素的指针s…

达梦8数据库备份与还原

通过命令找到达梦数据库进程所在位置 ps -ef | grep dm 得到达梦相关进程 pwd 进程ID得到进程目录 [rootdmdb01 bin]# pwd /data/dmdbms/bin [rootdmdb01 bin]# ps -ef | grep dm root 1183 2 0 Nov04 ? 00:00:33 [kworker/8:1H-xfs-log/dm-0] root …

电气设计 | 低压接地系统:TN-C 、TN-S、TN-C-S、TT适用哪些场所?

电气设计 | 低压接地系统:TN-C 、TN-S、TN-C-S、TT适用哪些场所? 1、低压配电系统简介2、各种低压配电系统介绍2.1、TN-C系统2.2、TN-S系统2.3、TN-C-S 系统2.4、TT 系统2.5、IT 系统 1、低压配电系统简介 低压配电系统有TN-C、TN-S、TN-C-S、TT和IT五种…

重温设计模式--组合模式

文章目录 1 、组合模式(Composite Pattern)概述2. 组合模式的结构3. C 代码示例4. C示例代码25 .应用场景 1 、组合模式(Composite Pattern)概述 定义:组合模式是一种结构型设计模式,它允许你将对象组合成…

漏洞检测工具:Swagger UI敏感信息泄露

Swagger UI敏感信息泄露 漏洞定义 Swagger UI是一个交互式的、可视化的RESTful API文档工具,它允许开发人员快速浏览、测试API接口。Swagger UI通过读取由Swagger(也称为OpenAPI)规范定义的API描述文件(如swagger.json或swagger…

Linux下学【MySQL】表中插入和查询的进阶操作(配实操图和SQL语句通俗易懂)

绪论​ 每日激励:挫折是会让我们变得越来越强大的重点是我们敢于积极的面对它。—Jack叔叔 绪论​: 本章是表操作的进阶篇章(没看过入门的这里是传送门,本章将带你进阶的去学习表的插入insert和查找select,本质也就是…

JavaScript 标准内置对象——Object

1、构造函数 2、静态方法 // 将源对象中所有可枚举的自有属性复制到目标对象,,并返回修改后的目标对象 Object.assign(target, ...sources) Object.create(proto, propertiesObject) // 以一个现有对象作为原型,创建一个新对象Object.defineP…

Robot Framework搭建自动化测试框架

1.配置环境 需要安装jdk8,andrid sdk(安装adb),pycharm编译环境以及软件 安装Robot Framework 首先,你需要安装Robot Framework,可以使用 pip 进行安装: pip install robotframework安装所需的…

fastjson诡异报错

1、环境以及报错描述 1.1 环境 操作系统为中标麒麟、cpu 为国产鲲鹏服务器。 jdk为openjdk version 1.8.0._242 1.2 错误 com.alibaba.fastjson2.JSONException: syntax error : f at com.alibaba.fastjson2.JSONReaderUTF16.readBoolValue(JSONReaderUTF16.java:6424) at c…

Unity3d 基于UGUI和VideoPlayer 实现一个多功能视频播放器功能(含源码)

前言 随着Unity3d引擎在数字沙盘、智慧工厂、数字孪生等场景的广泛应用,视频已成为系统程序中展示时,不可或缺的一部分。在 Unity3d 中,我们可以通过强大的 VideoPlayer 组件和灵活的 UGUI 系统,将视频播放功能无缝集成到用户界面…

蓝牙协议——音乐启停控制

手机播放音乐 手机暂停音乐 耳机播放音乐 耳机暂停音乐

【EthIf-13】EthIfGeneral容器配置-01

1.EthIfGeneral类图结构 下面是EthIfGeneral配置参数的类图,比较重要的参数就是配置: 接收中断是否打开发送确认中断是否打开EthIf轮询周期 1.EthIfGeneral参数的含义

如何看待2024年诺贝尔物理学奖颁给了机器学习与神经网络?

成长路上不孤单😊😊😊😊😊😊 【14后😊///C爱好者😊///持续分享所学😊///如有需要欢迎收藏转发///😊】 今日分享关于2024年诺贝尔物理学奖颁给了机器学习与神…

有没有检测吸烟的软件 ai视频检测分析厂区抽烟报警#Python

在现代厂区管理中,安全与规范是重中之重,而吸烟行为的管控则是其中关键一环。传统的禁烟管理方式往往依赖人工巡逻,效率低且存在监管死角,难以满足当下复杂多变的厂区环境需求。此时,AI视频检测技术应运而生&#xff0…

VSCode搭建Java开发环境 2024保姆级安装教程(Java环境搭建+VSCode安装+运行测试+背景图设置)

名人说:一点浩然气,千里快哉风。—— 苏轼《水调歌头》 创作者:Code_流苏(CSDN) 目录 一、Java开发环境搭建二、VScode下载及安装三、VSCode配置Java环境四、运行测试五、背景图设置 很高兴你打开了这篇博客,更多详细的安装教程&…

二手车交易平台开发:安全与效率的双重挑战

3.1系统体系结构 系统的体系结构非常重要,往往决定了系统的质量和生命周期。针对不同的系统可以采用不同的系统体系结构。本系统为二手车交易平台系统,属于开放式的平台,所以在体系结构中采用B/s。B/s结构抛弃了固定客户端要求,采…

共享无人系统,从出行到生活全面覆盖

共享无人系统已经覆盖到我们生活中的方方面面,出行上,比如共享自行车小程序、共享自行车;生活中,比如说棋牌室、茶室。我们以棋牌室举例。 通过开发使用共享无人系统,可以极大地降低人力成本,共享无人棋牌室…

FPGA学习(基于小梅哥Xilinx FPGA)学习笔记

文章目录 一、整个工程的流程二、基于Vivado的FPGA开发流程实践(二选一多路器)什么是二选一多路器用verilog语言,Vivado软件进行该电路实现1、设计输入:Design Sources中的代码2、分析和综合:分析设计输入中是否有错误…

四相机设计实现全向视觉感知的开源空中机器人无人机

开源空中机器人 基于深度学习的OmniNxt全向视觉算法OAK-4p-New 全景硬件同步相机 机器人的纯视觉避障定位建图一直是个难题: 系统实现复杂 纯视觉稳定性不高 很难选到实用的视觉传感器 为此多数厂家还是采用激光雷达的定位方案。 OAK-4p-New 为了弥合这一差距…