BatchNorm LayerNorm

0. Abstract

很早以前就遇到了 BatchNormLayerNorm, 当时只是粗略地知道它们是对数据进行了标准化: x = x − μ σ \bm{x} = \frac{\bm{x} - \bm{\mu}}{\bm{\sigma}} x=σxμ 这当然很简单, 但实际的数据是比较复杂的. 对于 CV 任务的数据 image 而言, 一个 batch 的数据通常是这样的: [B, C, H, W]; 对于 NLP 任务的数据 sentence 而言, 一个 batch 的数据通常是这样的: [B, T, D]. 那 XXNorm 是怎么算的? 肯定不是一下算整个 batch 的所有数的 μ \bm{\mu} μ σ \bm{\sigma} σ, 且它们是有区别的.

于是查阅了很多资料, 真是应了导师的话: 别去看那些乱七八糟的博客, 他们水平还不如你呢! 还是老外厉害, 仅仅是在评论区回答问题就把问题讲得清清楚楚: 《Why do transformers use layer norm instead of batch norm?》

1. 如何计算?

本节从理论和代码两个角度详细说明 BatchNormLayerNorm 的计算方式.

1.1 先从理论上说一下, 询问通义千问

>>>>> 通义千问 >>>>>
: layer norm 和 batch norm 的区别? 给出公式解释.

: Layer Normalization (LayerNorm)Batch Normalization (BatchNorm) 都是用于加速神经网络训练的技术, 通过规范化层的输入或激活值, 使网络的学习更加稳定和高效. 但是, 它们在实施方式适用场景上有所不同.

Batch Normalization (BatchNorm)
BatchNorm 主要在训练期间对一个 mini-batch 的数据进行规范化. 它的主要目标减少内部协变量移位(internal covariate shift), 即在网络的前向传播过程中, 中间层的输入分布随训练迭代而变化的现象. BatchNorm 通过计算当前 mini-batch 的均值和方差来规范化输入, 公式如下:
x \bm{x} x 是一个 mini-batch 内的输入向量, μ B \bm{\mu}_B μB σ B 2 \bm{\sigma}^2_B σB2 分别是该 mini-batch 的均值和方差, 则 BatchNorm 的输出 y \bm{y} y 可以表示为: x ^ = x − μ B μ B 2 + ϵ y ^ = γ x ^ + β \begin{aligned} \hat{\bm{x}} &= \frac{\bm{x} - \bm{\mu}_B}{\sqrt{\bm{\mu}_B^2 + \epsilon}} \\ \hat{\bm{y}} &= \gamma\hat{\bm{x}} + \beta \end{aligned} x^y^=μB2+ϵ xμB=γx^+β 其中, γ \gamma γ β \beta β 是可学习的参数, ϵ \epsilon ϵ 是一个很小的常数, 用于数值稳定性.

Layer Normalization (LayerNorm)
LayerNorm 目的是在每个样本级别上进行规范化, 而不依赖于 mini-batch 的统计信息. LayerNorm 更适用于序列模型(如 LSTMTransformer)等, 其中样本间可能存在依赖关系mini-batch 大小可能不是固定的. LayerNorm 计算每个样本特征的均值和方差, 公式如下:
x \bm{x} x 是单个样本的特征向量, μ L \bm{\mu}_L μL σ L 2 \bm{\sigma}^2_L σL2 分别是该样本特征的均值和方差, 则 LayerNorm 的输出 y \bm{y} y 可以表示为: x ^ = x − μ L μ L 2 + ϵ y ^ = γ x ^ + β \begin{aligned} \hat{\bm{x}} &= \frac{\bm{x} - \bm{\mu}_L}{\sqrt{\bm{\mu}_L^2 + \epsilon}} \\ \hat{\bm{y}} &= \gamma\hat{\bm{x}} + \beta \end{aligned} x^y^=μL2+ϵ xμL=γx^+β 这里同样, γ \gamma γ β \beta β 是可学习的参数, ϵ \epsilon ϵ 是一个很小的常数.

区别

  • 依赖于数据: BatchNorm 依赖于 mini-batch 的统计信息, 而 LayerNorm 依赖于单个样本的特征统计信息;
  • 应用场景: BatchNorm 更适用于计算机视觉任务, 而 LayerNorm 更适用于自然语言处理任务, 尤其是在处理序列数据时;
  • 计算效率: BatchNorm 在并行计算上可能更高效, 因为它可以利用 mini-batch 内的并行性, 而 LayerNorm 在每个样本上的计算独立, 更适合处理不同长度的序列数据.

这两种方法各有优劣, 选择哪种方法通常取决于具体的应用场景和数据特性.
<<<<< 通义千问 <<<<<

清楚了吗? 迷迷糊糊, 什么叫做 “当前 mini-batch 的均值和方差”? 什么叫做 “每个样本特征的均值和方差”? 还是老外回答的好:

先搞清基本概念:

  • instance: 一个样本, image for CV, sentence for NLP;
  • element: 样本的组件, pixel for image, word for sentence;
  • feature: element 的特征, RGB 值 for pexel, embedding for word.

体现到数据上就是:

  • [B, C, H, W] 是一个 batch, [C, H, W] 是一个 instance, [C] 是一个 pixel, 包含了 C 个 feature.
  • [B, T, L] 是一个 batch, [T, L] 是一个 instance, [L] 是一个 word, 包含了 L 个 feature.

如下图:

从 Batch Dimension 那一侧看, 每个小方格往背后延申代表一个 element, 如左图中的紫色长条, 一个 pixel 的 RGB 特征, 或者一个词向量. LayerNorm每个 element 计算均值和方差, 可得 BxL 个均值和方差(或BxHxW个). 然后各 element 独立地进行标准化. 右图中的紫色片是一个 feature, 批次中所有 word 的第一个 feature. 每一个这样的片是一个特征, BatchNorm每个 feature 计算均值和方差, 可得 L 个均值和方差(或C个). 然后各 feature 独立地进行标准化.


需要注意的是, Transformer 中并不是按上面所说的 LayerNorm 计算的, 而是给每个 instance 计算均值和方差, 可得 B 个均值和方差, 然后各 instance 独立地进行标准化. 确切说是下图的样子:

1.2 PyTorch 中的 BatchNormLayerNorm
1.2.1 BatchNorm

在 PyTorch 中, BatchNormnn.BatchNorm1d, nn.BatchNorm2dnn.BatchNorm3d, 分别针对不同维度的数据:

  • nn.BatchNorm1d: (N, C) or (N, C, L)
  • nn.BatchNorm2d: (N, C, H, W)
  • nn.BatchNorm3d: (N, C, D, H, W)

查看源码:

class BatchNorm1d(_BatchNorm):
	r"""
	Args:
		num_features: number of features or channels `C` of the input

	Shape:
		- Input: `(N, C)` or `(N, C, L)`, where `N` is the batch size,
		  `C` is the number of features or channels, and `L` is the sequence length
		- Output: `(N, C)` or `(N, C, L)` (same shape as input)
	"""
	def _check_input_dim(self, input):
		if input.dim() != 2 and input.dim() != 3:
			raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")

Examples:

>>> m = nn.BatchNorm1d(100)  # C=100	   # With Learnable Parameters
>>> m = nn.BatchNorm1d(100, affine=False)  # Without Learnable Parameters
>>> input = torch.randn(20, 100)  # (N, C)
>>> output = m(input)
>>> # 或者
>>> input = torch.randn(20, 100, 30)  # (N, C, L)
>>> output = m(input)

γ , β \bm{\gamma}, \bm{\beta} γ,β 是可学习的参数, 且 shape=(C,), 参数名是 .weight.bias:

>>> m = nn.BatchNorm1d(100)
>>> m.weight
Parameter containing:
tensor([1., 1., ..., 1.], requires_grad=True) 
>>> m.weight.shape
torch.Size([100])
>>> m.bias
Parameter containing:
tensor([0., 0., ..., 0.], requires_grad=True)

BatchNorm2dBatchNorm3d 是一样的, 不同之处在于 _check_input_dim(input):

class BatchNorm2d(_BatchNorm):
	r"""
	Args:
		num_features: `C` from an expected input of size `(N, C, H, W)`
	Shape:
		- Input: :math:`(N, C, H, W)`
		- Output: :math:`(N, C, H, W)` (same shape as input)
	"""
	def _check_input_dim(self, input):
		if input.dim() != 4:
			raise ValueError(f"expected 4D input (got {input.dim()}D input)")

Examples:

>>> m = nn.BatchNorm2d(100)
>>> input = torch.randn(20, 100, 35, 45)
>>> output = m(input)
class BatchNorm3d(_BatchNorm):
	r"""
	Args:
		num_features: `C` from an expected input of size `(N, C, D, H, W)`
	Shape:
		- Input: :math:`(N, C, D, H, W)`
		- Output: :math:`(N, C, D, H, W)` (same shape as input)
	"""
	def _check_input_dim(self, input):
		if input.dim() != 5:
			raise ValueError(f"expected 5D input (got {input.dim()}D input)")

Examples:

>>> m = nn.BatchNorm3d(100)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)
1.2.2 LayerNorm

不同于 BatchNorm(num_features), LayerNorm(normalized_shape) 的参数是 input.shape 的后 xdim, 如 [B, T, L] 的后两维 [T, L], 则每个句子会被独立地标准化; 若 L[L], 则每个词向量被独立地标准化.

NLP Example

>>> batch, sentence_length, embedding_dim = 20, 5, 10
>>> embedding = torch.randn(batch, sentence_length, embedding_dim)
>>> layer_norm = nn.LayerNorm(embedding_dim)
>>> layer_norm(embedding)  # Activate module

Image Example

>>> N, C, H, W = 20, 5, 10, 10
>>> input = torch.randn(N, C, H, W)
>>> # Normalize over the last three dimensions (i.e. the channel and spatial dimensions)
>>> layer_norm = nn.LayerNorm([C, H, W])
>>> output = layer_norm(input)

也就是说, 它不仅仅包含了上面理论计算所说的 “各 element 独立地进行标准化” 和 “各 instance 独立地进行标准化”, 而且可以计算任何的 normalize over the last x dimensions.

1.3 计算过程考察
import torch
from torch import nn

# >>> 手动计算 BatchNorm2d >>>
weight = torch.ones([1, 3, 1, 1])
bias = torch.zeros([1, 3, 1, 1])

x = 10 * torch.randn(2, 3, 4, 4) + 100
mean = x.mean(dim=[0, 2, 3], keepdim=True)
std = x.std(dim=[0, 2, 3], keepdim=True, unbiased=False)
print(x)
print(mean)
print(std)

y = (x - mean) / std
y = y * weight + bias
print(y)
# <<< 手动计算 BatchNorm2d <<<

# >>> nn.BatchNorm2d >>>
bnm2 = nn.BatchNorm2d(3)
z = bnm2(x)
print(z)
# <<< nn.BatchNorm2d <<<
print(torch.norm(z - y, p=1))

会发现手动计算和 nn.BatchNorm 计算的几乎完全一致, 可能是有一些 ϵ \epsilon ϵ 的影响吧. 注意, 这里的 unbiased=False 有讲究, 官方文档有说明:

"""
At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
equivalent to `torch.var(input, unbiased=False)`.
However, the value stored in the moving average of the standard-deviation is calculated via
the unbiased  estimator, equivalent to `torch.var(input, unbiased=True)`.

Also by default, during training this layer keeps running estimates of its computed mean and
variance, which are then used for normalization during evaluation.
"""

这里只是想验证计算过程, 不重点关注 unbiased. 简单地提一下:

  • training 阶段计算的是方差的有偏估计, 而存在方差的 moving average 中的方差是无偏估计;
  • during training, 会保存 mean 和 var 的滑动平均, 然后用于 testing 阶段.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/793566.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

linux系统操作/基本命令/vim/权限修改/用户建立

Linux的目录结构&#xff1a; 一&#xff1a;在Linux系统中&#xff0c;路径之间的层级关系&#xff0c;使用:/来表示 注意:1、开头的/表示根目录 2、后面的/表示层级关系 二&#xff1a;在windows系统中&#xff0c;路径之间的层级关系&#xff0c;使用:\来表示 注意:1、D:表示…

【技术追踪】HiDiff:医学图像分割的混合扩散框架(TMI-2024)

传统分割方法与扩散分割方法结合&#xff0c;做大做强~ HiDiff&#xff1a;一种用于医学图像分割的新型混合扩散框架&#xff0c;它可以协同现有判别分割模型和新型生成扩散模型的优势&#xff0c;在腹部器官、脑肿瘤、息肉和视网膜血管分割数据集上性能表现 SOTA &#xff01;…

【eNSP模拟实验】三层交换机实现VLAN通信

实验需求 让PC1和PC2能够互相通讯&#xff0c;其中PC1在vlan10中&#xff0c;PC2在vlan20中。 实验操作 首先把PC1和PC2都配置好ip&#xff0c;配置好之后&#xff0c;点击右下角的应用 然后&#xff0c;在S2交换机&#xff08;S3700&#xff09;上做如下配置 #进入系统 <…

Java基础-组件及事件处理(下)

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 面板组件 说明 常见组件 JScrollPane常用构造方法 JScrollPane设置面板滚动策略的方法 JScrollPane滚…

进程调度篇

在操作系统的广阔领域中&#xff0c;进程调度是其中一个至关重要的环节。它如同操作系统的“交通警察”&#xff0c;负责在多个等待CPU执行的进程间进行高效、公平的分配。本文将带您了解进程调度的基本概念、重要性、常用算法…… 1. 进程调度的基本概念 1.1 进程调度的定义 …

HTAP 数据库在国有大行反洗钱场景的应用

导读 在金融领域&#xff0c;随着数字化服务的深入和监管要求的提高&#xff0c;反洗钱工作变得尤为关键。洗钱活动不仅威胁金融安全&#xff0c;也对社会秩序构成挑战。本文深入探讨了国产 HTAP 分布式数据库 TiDB 在某国有大行反洗钱系统中的应用实践。 依托 TiDB 构建的新…

c++初阶知识——类和对象(1)

目录 1.类和对象 1.1 类的定义 1.2 访问限定符 1.3 类域 2.实例化 2.1 实例化概念 2.2 对象大小 内存对齐规则 3.this指针 1.类和对象 1.1 类的定义 &#xff08;1&#xff09;class为定义类的关键字&#xff0c;Stack为类的名字&#xff0c;{}中为类的主体&#xf…

python怎么调用cmd命令

关于python调用cmd命令&#xff1a; 1、python的OS模块 OS模块调用CMD命令有两种方式&#xff1a;os.popen()、os.system()都是用当前进程来调用。 OS.system是无法获取返回值的。当运行结束后接着往下面执行程序。用法如&#xff1a;OS.system("ipconfig"). OS.…

前台线程和后台线程(了解篇)

在多线程编程中&#xff0c;理解线程的不同类型对于编写高效、稳定的程序至关重要。特别地&#xff0c;前台线程&#xff08;Foreground Threads&#xff09;与后台线程&#xff08;Background Threads&#xff09;在行为上有着根本的区别&#xff0c;这些区别直接影响到程序的…

【Linux 线程】线程的基本概念、LWP的理解

文章目录 一、ps -L 指令&#x1f34e;二、线程控制 一、ps -L 指令&#x1f34e; &#x1f427; 使用 ps -L 命令查看轻量级进程信息&#xff1b;&#x1f427; pthread_self() 用于获取用户态线程的 tid&#xff0c;而并非轻量级进程ID&#xff1b;&#x1f427; getpid() 用…

(CVPR-2024)SwiftBrush:具有变分分数蒸馏的单步文本到图像扩散模型

SwiftBrush&#xff1a;具有变分分数蒸馏的单步文本到图像扩散模型 Paper Title&#xff1a;SwiftBrush: One-Step Text-to-Image Diffusion Model with Variational Score Distillation Paper 是 VinAI Research 发表在 CVPR 24 的工作 Paper地址 Code:地址 Abstract 尽管文本…

前端工程化(01):10款自动化构建工具初识。

前端工程化自动化构建工具是用于简化前端开发流程、提高开发效率和优化项目质量的工具。市面上的工具多种多样&#xff0c;贝格前端工场先介绍一下什么是前端工程化&#xff0c;为什么要前端工程化&#xff0c;以及常用工具&#xff0c;后面会对各种工具逐一介绍。 一、什么是…

【数据结构】一文了解七大排序算法

文章目录 前言一.直接插入排序插入排序思想插入排序代码实现插入排序总结 二.希尔排序希尔排序思想希尔排序代码实现希尔排序总结 三.选择排序选择排序思想选择排序代码实现选择排序总结 四.堆排序堆排序思想堆排序代码实现堆排序总结 五、冒泡排序冒泡排序思想冒泡排序代码实现…

深化信创存储 ,XEDP 与 飞腾腾云 S5000C 完成兼容性认证

近日&#xff0c;XSKY星辰天合的统一数据平台 XEDP 与飞腾信息技术有限公司的高性能服务器 CPU 飞腾腾云 S5000C 完成兼容性互认证。 经过严格的测试与评估&#xff0c;双方产品在技术上兼容良好&#xff0c;运行稳定且性能优异&#xff0c;融合双方优势构筑的软件定义存储系统…

SpringBoot实战:轻松实现接口数据脱敏

一、接口数据脱敏概述 1.1 接口数据脱敏的定义 接口数据脱敏是Web应用程序中一种保护敏感信息不被泄露的关键措施。在API接口向客户端返回数据时&#xff0c;系统会对包含敏感信息&#xff08;如个人身份信息、财务数据等&#xff09;的字段进行特殊处理。这种处理通过应用特…

Go-知识测试-模糊测试

Go-知识测试-模糊测试 1. 定义2. 例子3. 数据结构4. tesing.F.Add5. 模糊测试的执行6. testing.InternalFuzzTarget7. testing.runFuzzing8. testing.fRunner9. FuzzXyz10. RunFuzzWorker11. CoordinateFuzzing12. 总结 建议先看&#xff1a;https://blog.csdn.net/a1879272183…

智能家居开发新进展:乐鑫 ESP-ZeroCode 与亚马逊 ACK for Matter 实现集成

日前&#xff0c;乐鑫 ESP-ZeroCode 与亚马逊 Alexa Connect Kit (ACK) for Matter 实现了集成。这对智能家居设备制造商来说是一项重大进展。开发人员无需编写固件或开发移动应用程序&#xff0c;即可轻松设计符合 Matter 标准的产品。不仅如此&#xff0c;开发者还可以在短短…

Python(四)---序列

文章目录 前言1.列表1.1.列表简介1.2.列表的创建1.2.1.基本方式[]1.2.2.list()方法1.2.3.range()创建整数列表1.2.4.推导式生成列表 1.3. 列表各种函数的使用1.3.1.增加元素1.3.2.删除元素1.3.3.元素的访问和计数1.3.4.切片1.3.5.列表的排序 1.4.二维列表 2.元组2.1.元组的简介…

内网安全:域内信息探测

1.域内基本信息收集 2.NET命令详解 3.内网主要使用的域收集方法 4.查找域控制器的方法 5.查询域内用户的基本信息 6.定位域管 7.powershell命令和定位敏感信息 1.域内基本信息收集&#xff1a; 四种情况&#xff1a; 1.本地用户&#xff1a;user 2.本地管理员用户&#x…

短链接day4

短链接管理 创建短链接数据库表 URI、URL和URN区别 : URI 指的是一个资源 &#xff1b;URL 用地址定位一个资源&#xff1b; URN 用名称定位一个资源。 举个例子&#xff1a; 去寻找一个具体的人&#xff08;URI&#xff09;&#xff1b;如果用地址&#xff1a;XX省XX市XX区…