BatchNormalization和Layer Normalization解析

Batch Normalization

是google团队2015年提出的,能够加速网络的收敛并提升准确率

1.Batch Normalization原理

图像预处理过程中通常会对图像进行标准化处理,能够加速网络的收敛,如下图所示,对于Conv1来说输入的就是满足某一分布的特征矩阵,但对于Conv2而言输入的feature map就不一定满足某一分布规律了(注意这里所说满足某一分布规律并不是指某一个feature map的数据要满足分布规律,理论上是指整个训练样本集所对应的feature map的数据要满足分布规律)。而我们BN的目的就是使feature map满足均值为0,方差为1的分布规律。

对于一个拥有d维的输入x,我们将对它的每一个维度进行标准化处理。假设我们输入的x是RGB三通道的彩色图像,那么这里的d就是输入图像的channels即d=3,其中x^1就代表我们的R通道所对应的特征矩阵,依次类推。标准化处理也就是分别对R通道,G通道,B通道进行处理。

让feature map满足某一分布规律,理论上是指整个训练样本集所对应feature map的数据要满足分布规律,也就是说要计算出整个训练集的feature map然后再进行标准化处理,对于一个大型的数据集明显是不可能的,所以论文中说的BN,也就是计算一个Batch数据的feature map然后进行标准化(batch越大越接近整个数据集的分布,效果越好)。

上图展示了一个batch size为2(两张图片)的Batch Normalization的计算过程,假设feature1、feature2分别是由image1、image2经过一系列卷积池化后得到的特征矩阵,feature的channel为2,那么x^1代表batch的所有的feature的channel1的数据。然后分别计算x^1和x^2的均值和方差。然后再根据标准差计算公式分别计算每个channel 的值(\varepsilon是很小的常量,放置分母为0的情况)。在训练过程中要去不断地计算每个batch的均值和方差,并使用移动平均(moving average)的方法记录统计的均值和方差,在训练完后我们可以近似认为所统计的均值和方差就等于整个训练集的均值和方差。然后再我们的验证以及预测过程中,就使用统计得到的均值和方差进行标准化处理。

\gamma是用来调整数值分布的方差大小,默认为1,\beta是用来调节数值均值的位置,默认值为0。这两个参数实在反向传播过程中学习到的。

2.使用Pytorch进行实验

在训练过程中,均值和方差是同通过计算当前批次数据得到的记录为\mu _{now},\delta_{now} ^{2},而我们的验证以及预测过程中使用的均值方差是一个统计量为\mu _{statistic},\delta _{statistic}^{2}。具体更新策略如下,其中momentum默认取0.1:

\mu _{statistic+1} = 0.9*\mu _{statistic}+0.1*\mu _{now}\\ \delta _{statistic+1}^{2} = 0.9*\delta _{statistic}^{2}+0.1*\delta _{now}^{2}

(1)bn_process函数是自定义的bn处理方法验证是否和使用官方bn处理方法结果一致。在bn_process中计算输入batch数据的每个维度(这里的维度是channel维度)的均值和标准差(标准差等于方差开平方),然后通过计算得到的均值和总体标准差对feature每个维度进行标准化,然后使用均值和样本标准差更新统计均值和标准差。

(2)初始化统计均值是一个元素为0的向量,元素个数等于channel深度;初始化统计方差是一个元素为1的向量,元素个数等于channel深度,初始化\beta=0,\gamma=1。

import numpy as np
import torch.nn as nn
import torch

def bn_process(feature, mean, var):
    feature_shape = feature.shape
    for i in range(feature_shape[1]):
        # [batch,channel, height, weight]
        feature_t = feature[:, i, :, :]
        mean_t = feature_t.mean()
        #总体标准差
        std_t1 = feature_t.std()
        #样本标准差
        std_t2 = feature_t.std(ddof = 1)

        #bn process
        #这里记得加上eps和pytorch保持一致
        feature[:, i, :, :] = (feature[:, i, :, :] - mean_t) / np.sqrt(std_t1 ** 2+ 1e-5)
        #更新计算均值
        mean[i]  = mean[i]*0.9 + mean_t * 0.1
        var[i] = var[i] * 0.9 + (std_t2 ** 2) * 0.1
    print(feature)

#随机生成一个batch为2,channel为2,height=width=2的特征向量
#[batch, channel, height, width]
feature1 = torch.randn(2, 2, 2, 2)
#初始化统计均值和方差
calculate_mean = [0.0, 0.0]
calculate_var = [1.0, 1.0]
#print(feature1.numpy())

#注意要使用copy()深拷贝
bn_process(feature1.numpy().copy(), calculate_mean, calculate_var)

bn = nn.BatchNorm2d(2, eps =  1e-5)
output = bn(feature1)
print(output)

 

3.使用BN时需要注意的问题

(1)训练时要将training采纳数设置为True,在验证时将training参数设置为False。在Pytorch中了可以通过创建模型的model.train()和model.eval()方法控制。

(2)batch size尽可能设置大点,设置小后表现很糟糕,设置的越大求的均值和方差越接近整个训练集的均值和方差。

(3)建议将bn层放在卷积层和激活层之间,且卷积层不要使用偏置bias,因为没有用,参考下图推理,及时使用了偏置bias求出的结果也是一样的。

 


Layer Normalization

Layer Normalization针对自然语言处理提出的,为什么不用BN呢,因为在RNN这类时序网络中,时序的长度并不是一个定值(网络深度不一定相同),比如每句话的长短都不一定相同,所以很难去使用BN,所以作者提出了Layer Normalization(图像处理领域BN比LN更有效),但现在很多人将自然语言领域的模型用来处理图像,比如Vision Transformer,此时会涉及到LN。

直接看Pytorch 官方给出的关于LayerNorm 的介绍。不同的是,BN是对一个batch数据的每个channel进行Norm处理,一个for循环,但LN是对单个数据的制定维度进行Norm处理与batch无关而且BN中训练时是需要累计moving_mean和moving_var两个变量的(所以BN中有4个参数moving_mean,moving_var,\beta ,\gamma),但LN不需要累计只有\beta ,\gamma两个参数。

在Pytorch的LayerNorm类中有个normalized_shape参数,可以指定要Norm的维度(注意,函数说明中the last certain number of dimensions,指定的维度必须是从最后一维开始)。比如我们的数据shape是[4,2,3],那么normalized_shape可以是[3](最后一维进行Norm处理),也可以是[2,3](Norm最后两个维度),也可以是整个维度[4,2,3],但不能是[2]或者[4,2],否则会报错。

y = \frac{x-E[X]}{\sqrt{Var[x]+\varepsilon}}*\gamma +\beta

import torch
import torch.nn as nn

def layer_norm_process(feature:torch.Tensor, beta=0.,gamma = 1.,eps=1e-5):
    var_mean = torch.var_mean(feature, dim = -1, unbiased = False)
    #均值
    mean = var_mean[1]
    #方差
    var = var_mean[0]

    #layer norm process
    feature  = (feature - mean[..., None]) / torch.sqrt(var[..., None] + eps)
    feature = feature*gamma+beta

    return feature

def main():
    t = torch.randn(4, 2, 3)
    print(t)
    #仅在最后一个维度上做norm处理
    norm = nn.LayerNorm(normalized_shape= t.shape[-1], eps = 1e-5)
    #官方layer norm处理
    t1 = norm(t)
    #自己实现的layer norm处理
    t2 = layer_norm_process(t, eps = 1e-5)
    print("t1:\n",t1)
    print("t2:\n",t2)

if __name__ == '__main__':
    main()
tensor([[[ 0.8512,  0.4201, -0.3457],
         [ 0.4701, -0.0647,  0.0733]],

        [[-0.9950, -0.4634,  0.0540],
         [ 0.4096,  0.4037, -0.0914]],

        [[-2.3165,  1.3059,  0.3183],
         [-0.9716,  0.4956,  0.4524]],

        [[-0.6209, -0.5958,  0.3212],
         [-0.8762,  0.3176, -0.5427]]])
t1:
 tensor([[[ 1.0963,  0.2254, -1.3218],
         [ 1.3697, -0.9893, -0.3804]],

        [[-1.2302,  0.0110,  1.2192],
         [ 0.7198,  0.6942, -1.4140]],

        [[-1.3642,  1.0050,  0.3591],
         [-1.4137,  0.7385,  0.6752]],

        [[-0.7355, -0.6783,  1.4138],
         [-1.0123,  1.3614, -0.3490]]], grad_fn=<NativeLayerNormBackward0>)
t2:
 tensor([[[ 1.0963,  0.2254, -1.3218],
         [ 1.3697, -0.9893, -0.3804]],

        [[-1.2302,  0.0110,  1.2192],
         [ 0.7198,  0.6942, -1.4140]],

        [[-1.3642,  1.0050,  0.3591],
         [-1.4137,  0.7385,  0.6752]],

        [[-0.7355, -0.6783,  1.4138],
         [-1.0123,  1.3614, -0.3490]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/717187.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Dart 弱引用进阶

前言 村里的老人说&#xff1a;“真正的强者&#xff0c;都是扮猪吃老虎。” 日常开发中经常需要用到弱引用&#xff0c;Dart 语言里也有提供弱引用的接口 WeakReference&#xff0c;我们可以基于它开发更强大的复杂结构。 在前面的文章中&#xff0c;我们用到了一个以弱引用…

无需配置MySQL,Navicat也有在线版了?

前言 随着互联网技术的飞速发展&#xff0c;远程办公和在线协作成为了新的趋势。为了满足这一需求&#xff0c;TitanIDE模板市场近日上线了Navicat模板&#xff0c;使得数据库管理变得更加便捷、高效。现在&#xff0c;用户只需在浏览器打开TitanIDE&#xff0c;即可轻松使用N…

2024年上网行为审计软件排名,推荐这五款上网行为管理软件

上网行为审计软件是企业IT管理中不可或缺的一部分&#xff0c;它们旨在帮助组织监控、管理、审计员工的互联网使用情况&#xff0c;确保网络资源的合理利用&#xff0c;提高工作效率&#xff0c;同时维护企业信息安全。下面将介绍几款市场上知名的上网行为审计软件&#xff0c;…

通用大模型VS垂直大模型 难兄难弟?

在互联网&#x1f30f;背景下的快速发展与人工智能AI的崛起是21世纪科技进步的重要标志&#x1f3c5;&#xff0c; 近年来&#xff0c;随着计算能力的显著提升&#x1f680;、海量数据的积累以及算法创新&#xff0c;尤其是深度学习技术的突破&#xff0c;人工智能领域迎来了…

Windows 与 Java 环境下的 Redis 利用分析

1 前言 在最近的一次攻防演练中&#xff0c;遇到了两个未授权访问的 Redis 实例。起初以为可以直接利用&#xff0c;但后来发现竟然是Windows Java (Tomcat)。因为网上没有看到相关的利用文章&#xff0c;所以在经过摸索&#xff0c;成功解决之后决定简单写一写。 本文介绍了…

树莓派pico入坑笔记,快捷键键盘制作

使用usb_hid功能制作快捷键小键盘&#xff0c;定义了6个键&#xff0c;分别是 ctrlz ctrlv ctrlc ctrla ctrlw ctrln 对应引脚 board.GP4, board.GP8, board.GP13 board.GP28, board.GP20, board.GP17 需要用到的库&#xff0c;记得复制进单片机存储里面 然后是main主程…

【leetcode刷题】面试经典150题 88.合并两个有序数组

leetcode刷题 面试经典150 88. 合并两个有序数组 难度&#xff1a;简单 文章目录 一、题目内容二、自己实现代码2.1 实现思路2.2 实现代码2.3 结果分析 三、 官方解法3.1 直接合并后排序3.1.1 算法实现3.1.2 代码实现3.1.3 代码分析 3.2 双指针3.2.1 算法实现3.2.2 代码实现3.2…

列表(list)(Python)

文章目录 一、定义二、列表常用操作 一、定义 list ["张三", "李四", "王五", "赵六"]二、列表常用操作 分类关键字/函数/方法说明增加列表.append(值)在列表末尾追加值列表.insert(索引&#xff0c; 值)在指定位置插入值&#xff…

从11个视角看全球Rust程序员1/4:深度解读JetBrains最新报告

讲动人的故事,写懂人的代码 五个月前,编程界的大佬JetBrains发布了他们的全球开发者年度报告。 小吾从这份报告中找出了下面11个关于全球程序员如何使用Rust的有趣的趋势,让你学习和使用Rust更轻松。 1 这两年有多少程序员在工作中使用了Rust? 2 全球程序员使用Rust有多…

2024年数字媒体、新闻与管理国际会议(DMJM 2024)

2024年数字媒体、新闻与管理国际会议&#xff08;DMJM 2024&#xff09; 2024 International Conference on Digital Media, Journalism, and Management 【重要信息】 大会地点&#xff1a;长沙 大会官网&#xff1a;http://www.cdmjm.com 投稿邮箱&#xff1a;cdmjmsub-conf…

colab挂载googledrive云盘

参考&#xff1a; Google Colab简易\入门\常规\常用操作和命令_colab快捷键-CSDN博客 首先新建一个或者打开一个笔记本。 等待连接成功。 点击这个图标&#xff0c;变为如下这样: 挂载成功。 这里我是用现有的ipynb文件挂载&#xff1a; 他让我运行代码: 他会提示这个运行这…

相约北京“信通院数据智能大会”

推动企业数智化转型发展&#xff0c;凝聚产业共识&#xff0c;引领行业发展方向&#xff0c;摩斯将参与信通院首届“数据智能大会”&#xff08;6月19-20日&#xff0c;北京&#xff09;。 本次大会设置多个主题论坛&#xff0c;将发布多项研究成果&#xff0c;分享产业最新实…

微信核销通知地址设置返回:请开通回调通知产品权限

1.背景 微信代金券设置核销通知地址时返回: {"code":"REQUEST_BLOCKED","message":"请开通回调通知产品权限\n"} 2.解决方法 登录对应的微信商户号,然后访问如下链接: 微信支付 - 中国领先的第三方支付平台 &#xff5c; 微信支付提…

从11个视角看全球Rust程序员2/4:深度解读JetBrains最新报告

讲动人的故事,写懂人的代码 5 Rust代码最常使用什么协议与其他代码交互? REST API: 2022年:51%2023年:51%看上去REST API的使用比例挺稳定的,没啥变化。语言互操作性(Language Interop): 2022年:53%2023年:43%语言互操作性的比例在2023年下来了一些,掉了10个百分点…

编译器优化入门(基于ESP32)

主要参考资料&#xff1a; kimi: https://kimi.moonshot.cn/ ESP-IDF 支持多种编译器&#xff0c;但默认情况下&#xff0c;它使用的是乐鑫官方提供的 Xtensa 编译器&#xff0c;这是一个针对 ESP32 芯片架构&#xff08;Tensilica Xtensa LX6 微处理器&#xff09;优化的交叉编…

springboot应用启动太慢排查 半天才打印日志

springboot应用启动太慢排查 半天才打印日志 解决办法 hostnamectl 命令查看主机名 vim /etc/hosts 加上主机名配置 127.0.0.1 hostname

【2024最新华为OD-C/D卷试题汇总】[支持在线评测] 火星字符串(100分) - 三语言AC题解(Python/Java/Cpp)

&#x1f36d; 大家好这里是清隆学长 &#xff0c;一枚热爱算法的程序员 ✨ 本系列打算持续跟新华为OD-C/D卷的三语言AC题解 &#x1f4bb; ACM银牌&#x1f948;| 多次AK大厂笔试 &#xff5c; 编程一对一辅导 &#x1f44f; 感谢大家的订阅➕ 和 喜欢&#x1f497; &#x1f…

Elixir学习笔记——Erlang 库

Elixir 提供了与 Erlang 库的出色互操作性。事实上&#xff0c;Elixir 不鼓励简单地包装 Erlang 库&#xff0c;而是直接与 Erlang 代码交互。在本节中&#xff0c;我们将介绍一些 Elixir 中没有的最常见和最有用的 Erlang 功能。 Erlang 模块的命名约定与 Elixir 不同&#x…

电商风控指南 | 直播间里的藏匿的“羊毛党”,普通消费者看不到

目录 直播间里的羊毛党 电商要针对性进行防范 随着618网购节的开启&#xff0c;各大电商平台的直播间再次成为消费者关注的焦点。在5月20日的一场酒水电商直播中&#xff0c;主播仅用43分钟便实现了成交额破亿&#xff0c;售出3万瓶白酒。然而&#xff0c;这些“秒杀”特价商品…

Excel加密怎么设置?这5个方法不容错过!(2024总结)

Excel加密怎么设置&#xff1f;如何不让别人未经允许查看我的excel文件&#xff1f;如果您也有这些疑问&#xff0c;那么千万不要错过本篇文章了。今天小编将向大家分享excel加密的5个简单方法&#xff0c;保证任何人都可以轻松掌握&#xff01;毫无疑问的是&#xff0c;为Exce…