神经网络中BN层简介及位置分析

1. 简介

Batch Normalization是深度学习中常用的技巧,Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift (Ioffe and Szegedy, 2015) 第一次介绍了这个方法。

这个方法的命名,明明是Standardization, 非要叫Normalization, 把本来就混用、意义不明的两个词更加搅得一团糟。那standardization 和 Normalization有什么区别呢?

一般是下面这样(X是输入数据集):

  • normalization(也叫 min-max scaling),一般译做 “归一化”:

  • standardization,一般译做 “标准化”:

Batch-Norm 是一个网络层,对中间结果作上面说的 standardization 操作。实际上 standardization 也可以叫做 Z-score normalization。所以可以这样理解,standardization 是一种特殊的 normalization。normalization 作为一个 scaling 的大类,包括 min-max scaling,standardization 等。

2. BatchNorm

对输入进行标准化的时候,计算每个特征在样本集合中的均值、方差;然后将每个样本的每个特征减去该特征的均值,并除以它的方差。用数学公式表示,即:

而所谓的BatchNorm, 就是神经网络中间,在一小撮batch样本中进行标准化。具体如下(B: batch size)

注意,BatchNorm作为神经网络的一层,是有两个参数\left ( \gamma, \beta \right )要训练的,分别称为拉伸和偏移参数。可能你会有疑问,既然已经对 u_{b} 作了标准化得到了\hat{u_{b}}  ,为什么还要用 \gamma, \beta将它“还原”呢?

实际上,设置这两个参数是为了给神经网络足够的自由度。如果经过训练\gamma \approx \hat{\sigma} _{batch},\beta \approx \hat{\mu} _{batch}, 说明神经网络认为,不需要进行批标准化即可使loss function最小化,我们也充分“尊重”它的选择。

3. BN 的特点

  • 使用 BatchNorm,我们可以尝试更大的学习率,从而加速收敛,但一般不会改变模型的精度;
  • BatchNorm 的效果依赖于 Batch size;一般需要较大的 Batch size(>16)才能有好的效果
  • 和 Dropout 一样,BatchNorm 在训练和推理时有不同的行为:训练时,它基于每个 batch 计算均值和方差,因此 batch size 必须足够大才能较好反映统计性质;推理时,BatchNorm 则直接用训练集整体的均值和方差进行标准化

训练集整体的均值和方差如何得到?——在每个batch的均值和方差计算中,通过移动平均估算得到。

4. BN 的位置

BatchNorm 究竟应该放在哪,现在还存在争议。很多人说应该放在激活函数之前,但也有声音说应该放在激活函数之后。思考一下,两种说法都有道理。举个简单的例子。

前一种说法是要对 \omega x+b作BatchNorm,这样可以保证 \omega x+b 在0附近, {\sigma}'(\omega x+b) 不至于太小;后一种说法 BatchNorm 的作用对象则直接是x ,这样可以控制梯度 \frac{\partial y}{\partial w} 在合理的范围内,不会因为 x 的极端取值而波动过大。

但现在看来,前一种声音是占上风的:将 BatchNorm 作用在全连接层和卷积层的输出上,激活函数之前。在全连接网络中,顺序是:线性组合+BatchNorm+Activation

对于全连接层,BatchNorm 作用在特征维上。假设输入矩阵大小是 m×n —— m 等于 batch size,即这个小批量中的样本数, n 表示特征数。我们要在每个特征上计算 m 个样本的均值和方差,也就是对每一列做计算。

在卷积神经网络中,顺序是:卷积层+BatchNorm+Activation+池化+全连接。要注意一点是,如果卷积层有K个卷积核(即K个通道),要对每个通道的输出分别做批标准化,且每个通道都拥有独立的拉伸和偏移参数。

对于卷积层,BatchNorm 作用在通道维上。我们先考虑一个 1×1 的卷积层,通道数为 k 。它其实就等价于神经元个数为 k 的全连接层。图片中每个像素点都由一个 k 维的向量表示,可以看作是像素点的 k 个特征。同一批量各个图片的各个像素点,就是不同的样本,共有 m×p×q 个样本, m,p,q 分别为 batch size、高、宽。

类比全连接层 BatchNorm 作用在特征维上,要在每个通道(即每个特征)上计算 m×p×q 个样本的均值和方差

设小批量中有m个样本。在单个通道上,假设卷积计算输出的高和宽分别为p和q。我们需要对该通道中m×p×q个元素同时标准化:对这些元素做标准化计算时,我们使用相同的均值和方差,即该通道中m×p×q个元素的均值和方差。——卷积神经网络之Batch Normalization(一)

5. BN的理解与延伸

BN 效果好是因为 BN 的存在会引入 mini-batch 内其他样本的信息,就会导致预测一个独立样本时,其他样本信息相当于正则项,使得 loss 曲面变得更加平滑,更容易找到最优解。相当于一次独立样本预测可以看多个样本,学到的特征泛化性更强,更加 general

Conv+BN+Relu 是卷积网络的一个常见组合。在模型推理时,BN 层的参数已经固定下来,本质就是一个线性变换。我们可以把 Conv+BN+Relu 进行算子融合,以加速模型推理

除了BN层,还有GN(Group Normalization)、LN(Layer Normalization、IN(Instance Normalization)这些个标准化方法,每个标注化方法都适用于不同的任务。

这个图很好地说明了BatchNorm、LayerNorm、InstanceNorm、GroupNorm的区别。N代表batch size;C代表卷积核个数(通道个数);H,W代表卷积结果的高和宽。

BatchNorm: 计算均值和方差时,考虑N * H * W 个元素;对每个通道分别做标准化

LayerNorm:计算均值和方差时,考虑C * H * W 个元素;对batch中的每个instance分别做标准化

InstanceNorm:计算均值和方差时,考虑H * W 个元素;对每个通道、batch中的每个instance分别做标准化

GroupNorm:介于LayerNorm和InstanceNorm二者之间,将C个通道分组,然后进行标准化。

直觉上来讲,GroupNorm把提取到类似特征的不同卷积核分到同一个group中。对这些卷积核进行标准化,确实make sense. 而且GroupNorm摆脱了对batch size的依赖。

GN在训练集上表现最好,在测试集上稍逊于BN(引自 Group Normalization (Yuxin & Kaiming, 2018))

6. BN vs LN

Transformer模型中用到了LayerNorm,着重对比一下LayerNorm和BatchNorm。

对于一个输入序列 (x1,x2,...,xn) ,每一个 xi 都是 d 维的向量。譬如输入序列是一个句子,每个单词 xi 都用一个 d 维的向量表示。

X轴是序列长度(n),Y轴是特征个数(d),Z轴是Batch size

此时BatchNorm是对图中蓝色框作标准化处理,就像我们上面说的——对每个特征分别做标准化;而LayerNorm针对每一个输入序列,对图中黄色框作标准化处理。总结来说,BatchNorm盯住每一个特征;而LayerNorm盯住的是每一个样本。

那么为什么Transformer模型要用LayerNorm而不是BatchNorm呢?

实际上,序列模型的背景下,BatchNorm有一个天然的硬伤,这使得它在所有序列模型中都不吃香:输入序列的长度(n)可能不一致。一般来说,我们会规定一个最长的序列长度,长度不够的序列用0填充。譬如下图这样,Batch中的序列长短不一。

如果用BatchNorm,以一个feature为例,它的标准化有效范围是蓝色的图,其余用0填充;如果是LayerNorm,对于4个序列,它们的标准化有效范围是黄色的图。

直觉上来说,对于BatchNorm的计算方法,当Batch中序列长度差距过大时,均值和方差的波动也会很大

但这个问题对于LayerNorm来说并不存在,因为它是在每一个序列内部计算均值和方差的。

这样,我们可以直观地理解,为什么BatchNorm对于序列模型并不好用;为什么Transformer要采用LayerNorm

7. BN代码实现

我们翻一翻常见的backbone的结构。可以看到在官方Pytorch的resnet.pyclass BasicBlock中,forward时的基本结构是Conv+BN+Relu:

# 省略了一些地方
class BasicBlock(nn.Module):
    def __init__(self,...) -> None:
        ...
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        # 常见的Conv+BN+Relu
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # 又是Conv+BN+relu
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)

        return out

resnet作为我们常见的万年青backbone不是没有理由的,效果好速度快方便部署。当然还有很多其他优秀的backbone,这些backbone的内部结构也多为Conv+BN+Relu或者Conv+BN的结构。

参考资料:BatchNorm and its variants - 知乎normalization 和 standardization 到底什么区别?_为什么batch normalization使用standardization而不是normaliz-CSDN博客不论是训练还是部署都会让你踩坑的Batch Normalization - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/172434.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

springcloud医院挂号预约系统源码

开发技术: jdk1.8,mysql5.7,nodejs,idea,vscode springcloud springboot mybatis vue elementui 功能介绍: 用户端: 登录注册 首页显示医生列表,医院简介,点击医生…

Stable Diffusion专场公开课

从SD原理、本地部署到其二次开发 分享时间:11月25日14:00-17:00 分享大纲 从扩散模型DDPM起步理解SD背后原理 SD的本地部署:在自己电脑上快速搭建、快速出图如何基于SD快速做二次开发(以七月的AIGC模特生成系统为例) 分享人简介 July&#…

手把手设计C语言版循环队列(力扣622:设计循环队列)

文章目录 前言描述分析力扣AC代码 力扣: 622.设计循环队列 前言 队列会出现“假溢出”现象,即队列的空间有限,队列是在头和尾进行操作的,当元素个数已经达到最大个数时,队尾已经在空间的最后面了,但是对头…

北邮22级信通院数电:Verilog-FPGA(0)怎么使用modelsim进行仿真?modelsim仿真教程一份请签收~

北邮22信通一枚~ 跟随课程进度更新北邮信通院数字系统设计的笔记、代码和文章 持续关注作者 迎接数电实验学习~ 获取更多文章,请访问专栏: 北邮22级信通院数电实验_青山如墨雨如画的博客-CSDN博客 最近很多uu问我怎么用quartus连接的modelsim软件进…

C#使用MaxMind.GeoIP2数据库查询当前ip地址

GeoLite2-City.mmdb下载 因为比较简单,直接上代码,代码展示获取ip地址的国家和城市信息 using MaxMind.GeoIP2; using MaxMind.GeoIP2.Model; using System; using System.Collections; using System.Collections.Generic; using System.Linq; using Sy…

事关Django的静态资源目录设置与静态资源文件引用(Django的setting.py中的三句静态资源(static)目录设置语句分别是什么作用?)

在Django的setting.py中常见的三句静态资源(static)目录设置语句如下: STATICFILES_DIRS [os.path.join(BASE_DIR, static_list)] # 注意这是一个列表,即可以有多个目录的路径 STATIC_ROOT os.path.join(BASE_DIR, static_root) STATIC_URL /static-url/本文介…

解决开着代理情况下pip或魔搭下载失败

解决开着代理情况下pip或魔搭下载失败 一、前言 最近由于经常配环境导致,老是要来回切clash关掉代理,非常的不方便 如下面的,魔搭模型下载失败 ValueError: invalid model repo path HTTPSConnectionPool(host‘www.modelscope.cn’, port4…

Ubuntu22.04 交叉编译GCC13.2.0 for Rv1126

一、安装Ubuntu22.04 sudo apt install vim net-tools openssh-server 二、安装必要项 sudo apt update sudo apt upgrade sudo apt install build-essential gawk git texinfo bison flex 三、下载必备软件包 1.glibc https://ftp.gnu.org/gnu/glibc/glibc-2.38.tar.gz…

【github】初学者使用指南

作者:20岁爱吃必胜客(坤制作人),近十年开发经验, 跨域学习者,目前于新西兰奥克兰大学攻读IT硕士学位。荣誉:阿里云博客专家认证、腾讯开发者社区优质创作者,在CTF省赛校赛多次取得好成绩。跨领域…

关于2023年11月25日PMI认证考试准考信下载及考场规定等事项通知

各位考生:为保证参加2023年11月25日PMI项目管理资格认证考试的每位考生都能顺利进入考场参加考试,请完整阅读本通知内容。 一、关于准考信下载为确保您顺利进入考场参加11月份考试,请及时登录本网站个人系统下载并打印准考信,准考…

CKD TransBTS:用于脑肿瘤分割的具有模态相关交叉注意的临床知识驱动混合转换器

CKD-TransBTS: Clinical Knowledge-Driven Hybrid Transformer With Modality-Correlated Cross-Attention for Brain Tumor Segmentation CKD TransBTS:用于脑肿瘤分割的具有模态相关交叉注意的临床知识驱动混合转换器背景贡献实验方法how radiologists diagnose b…

风丘电动汽车热管理方案 为您的汽车研发保驾护航

热管理技术作为汽车节能、提高经济性和保障安全性的重要措施,在汽车研发过程中具有重要作用。传统燃油汽车的热管理系统主要包括发动机、变速器散热系统和汽车空调,而电动汽车的热管理系统在燃油汽车热管理架构的基础之上,又增加了电机电控热…

策略模式实践

目录 前言 五个部分 名词解释 代码 controller层 HelloService接口 实现类 自定义注解 上下文 策略工厂 Java SPI配置 验证 前言 五个部分 接口、实现类、自定义注解、上下文、策略工厂 名词解释 自定义注解(方便后期增加实现类后灵活控制策略) 上下文(初始化…

深入了解Java 8 新特性:Stream流的实践应用(二)

阅读建议 嗨,伙计!刷到这篇文章咱们就是有缘人,在阅读这篇文章前我有一些建议: 本篇文章大概8000多字,预计阅读时间长需要10分钟(不要害怕字数过多,其中有一大部分是示例代码,读起…

怎么批量提取文件名字到Excel中?

怎么批量提取文件名字到Excel中?Excel是由微软公司开发的一种电子表格软件,它是Microsoft Office办公套件的一部分。Excel提供了强大的数据处理和分析功能,用户可以使用Excel创建、编辑和管理电子表格,进行各种计算、数据分析、图…

2024测试工程师必学的Jmeter:利用jmeter插件收集性能测试结果汇总报告和聚合报告

利用jmeter插件收集性能测试结果 汇总报告(Summary Report ) 用来收集性能测试过程中的请求以及事务各项指标。通过监听器--汇总报告 可以添加该元件。界面如下图所示 汇总报告界面介绍: 所有数据写入一个文件:保存测试结果到本地…

通过AppLink把拼多多热门榜单商品同步至小红书

上篇说到AppLink当中定时调度方式如何配置,这次来演示一下,如何把热门榜单信息同步至小红书 1.拉取一个定时器作为触发动作,通过配置定时器调度时间将定时策略配置为每天执行一次 2.触发动作完成后通过好单库获取拼多多每日热门榜单&#xf…

steamui.dll找不到指定模块,要怎么修复steamui.dll文件

当我们使用Steam进行游戏时,有时可能会面对一些令人无奈的技术问题。一种常见的问题是“找不到指定模块steamui.dll”,这可能是由于缺少文件、文件损坏或软件冲突等原因导致。但别担心,这篇文章将提供几种解决此问题的方法,并针对…

“KeyarchOS:国产Linux新星的崛起与创新之路“

简介 KeyarchOS是一款由浪潮信息自主研发的服务器操作系统。它因为几个特点而受到我的青睐和一些用户的关注。 首先,KeyarchOS注重安全性和稳定性。它有一些防护和隔离功能,来帮助系统稳定运行,而且是中文语言更接地气。 其次,Ke…

基于Android校园交流uniAPP+vue 微信小程序v7e1

本系统结合现今XX校园交流APP的功能模块以及设计方式进行分析,使用Android平台和Ssm框架进行开发设计,具体研究内容如下: (1) 系统管理员主要对用户管理、类型管理、娱乐天地管理、投诉举报管理、学习平台、我的收藏管理、系统管理等功能进…