pytorch中的归一化:BatchNorm、LayerNorm 和 GroupNorm

1 归一化概述

训练深度神经网络是一项具有挑战性的任务。 多年来,研究人员提出了不同的方法来加速和稳定学习过程。 归一化是一种被证明在这方面非常有效的技术。

1.1 为什么要归一化

数据的归一化操作是数据处理的一项基础性工作,在一些实际问题中,我们得到的样本数据都是多个维度的,即一个样本是用多个特征来表示的,数据样本的不同特征可能会有不同的尺度,这样的情况会影响到数据分析的结果。为了解决这个问题,需要进行数据归一化处理。原始数据经过数据归一化后,各特征处于同一数量级,适合进行综合对比评价。

例如,我们现在用两个特征构建一个简单的神经网络模型。 这两个特征一个是年龄:范围在 0 到 65 之间,另一个是工资:范围从 0 到 10 000。我们将这些特征提供给模型并计算梯度。

不同规模的输入导致不同的权重更新和优化器的步骤向最小值的方向不均衡。这也使损失函数的形状不成比例。在这种情况下,就需要使用较低的学习速率来避免过冲,这就意味着较慢的学习过程。

所以我们的解决方案是输入进行归一化,通过减去平均值(定心)并除以标准偏差来缩小特征。

此过程也称为“漂白”,处理后所有的值具有 0 均值和单位方差,这样可以提供更快的收敛和更稳定的训练。

1.2 归一化的作用

在深度学习中,数据归一化是一项关键的预处理步骤,用于优化神经网络模型的训练过程和性能。归一化技术有助于解决梯度消失和梯度爆炸问题,加快模型的收敛速度,并提高模型的鲁棒性和泛化能力。详细介绍如下:

  • 梯度消失和梯度爆炸问题:在深度神经网络中,梯度消失和梯度爆炸是常见的问题。数据归一化可以缓解这些问题,使得梯度在合理的范围内进行传播,有助于提高模型的训练效果。

  • 特征尺度不一致:深度学习模型对特征的尺度非常敏感。如果不同特征具有不同的尺度范围,某些特征可能会主导模型的训练过程,而其他特征的影响可能被忽略。通过数据归一化,可以将不同特征的尺度统一到相同的范围,使得模型能够平衡地对待所有特征,避免尺度不一致带来的偏差。

  • 模型收敛速度:数据归一化可以加快模型的收敛速度。当数据被归一化到一个较小的范围时,模型可以更快地找到合适的参数值,并减少训练过程中的震荡和不稳定性。这样可以节省训练时间,提高模型的效率。

  • 鲁棒性和泛化能力:通过数据归一化,模型可以更好地适应不同的数据分布和噪声情况。归一化可以增加模型的鲁棒性,使得模型对输入数据的变化和扰动具有更好的容忍度。同时,归一化还有助于提高模型的泛化能力,使得模型在未见过的数据上表现更好。

1.3 归一化的步骤

归一化通过对数据的特定维度上进行归一化操作来调整输入数据的分布,使其具有零均值和单位方差。一般通过以下步骤对输入进行归一化:

  • 对于给定的输入数据,在给定的维度上计算其均值和方差。

  • 使用计算得到的均值和方差对输入数据进行标准化,将其零均值化并使其具有单位方差。

  • 对标准化后的数据进行缩放和平移操作,通过可学习的参数进行调整,以恢复模型对数据的表达能力。

进一步地,在归一化中,通过缩放和平移操作,引入了可学习的参数,即缩放参数(scale)和平移参数(shift)。这些参数用于在标准化后的数据上进行线性变换,以恢复模型的表达能力。

具体而言,在每个特征维度上,假设归一化后的数据为\hat{x},则通过以下公式计算最终的输出:y=\gamma \hat{x}+\beta 。其中, y是最终的输出, \gamma是缩放参数(scale), \beta是平移参数(shift)。这两个参数是可学习的,它们可以通过反向传播和优化算法(如随机梯度下降)来进行更新。在训练过程中,模型会通过梯度下降的方式来调整这些参数,使得模型能够自适应地对不同的数据分布进行缩放和平移。这样,模型可以根据实际情况自由地调整每个特征的重要性和偏置,从而更好地适应不同的数据分布。

2 pytorch中的归一化

BatchNorm、LayerNorm 和 GroupNorm 都是深度学习中常用的归一化方式。它们通过将输入归一化到均值为 0 和方差为 1 的分布中,来防止梯度消失和爆炸,并提高模型的泛化能力。

2.1 BatchNorm

一般 CNN 中,卷积层后面会跟一个 BatchNorm 层,减少梯度消失和爆炸,提高模型的稳定性。

在PyTorch中,可以使用torch.nn.BatchNorm1dtorch.nn.BatchNorm2dtorch.nn.BatchNorm3d等批归一化层来实现批归一化。

实例代码:

import torch
import torch.nn as nn
import numpy as np

feature_array = np.array([[[[1, 0],  [0, 2]],
                           [[3, 4],  [1, 2]],
                           [[-2, 9], [7, 5]],
                           [[2, 3],  [4, 2]]],

                          [[[1, 2],  [-1, 0]],
                            [[1, 2], [3, 5]],
                            [[4, 7], [-6, 4]],
                            [[1, 4], [1, 5]]]], dtype=np.float32)

feature_tensor = torch.tensor(feature_array.copy(), dtype=torch.float32)
bn_out = nn.BatchNorm2d(num_features=4, eps=1e-5)(feature_tensor)
print(bn_out)

for i in range(feature_array.shape[1]):
    channel = feature_array[:, i, :, :]
    mean = feature_array[:, i, :, :].mean()
    var = feature_array[:, i, :, :].var()
    print(mean)
    print(var)

    feature_array[:, i, :, :] = (feature_array[:, i, :, :] - mean) / np.sqrt(var + 1e-5)
print(feature_array)

运行结果显示:

tensor([[[[ 0.3780, -0.6299],
          [-0.6299,  1.3859]],

         [[ 0.2847,  1.0441],
          [-1.2339, -0.4746]],

         [[-1.1660,  1.1660],
          [ 0.7420,  0.3180]],

         [[-0.5388,  0.1796],
          [ 0.8980, -0.5388]]],


        [[[ 0.3780,  1.3859],
          [-1.6378, -0.6299]],

         [[-1.2339, -0.4746],
          [ 0.2847,  1.8034]],

         [[ 0.1060,  0.7420],
          [-2.0140,  0.1060]],

         [[-1.2572,  0.8980],
          [-1.2572,  1.6164]]]], grad_fn=<NativeBatchNormBackward0>)
0.625
0.984375
2.625
1.734375
3.5
22.25
2.75
1.9375
[[[[ 0.37796253 -0.6299376 ]
   [-0.6299376   1.3858627 ]]

  [[ 0.28474656  1.0440707 ]
   [-1.2339017  -0.4745776 ]]

  [[-1.1659975   1.1659975 ]
   [ 0.7419984   0.3179993 ]]

  [[-0.53881454  0.17960484]
   [ 0.8980242  -0.53881454]]]


 [[[ 0.37796253  1.3858627 ]
   [-1.6378376  -0.6299376 ]]

  [[-1.2339017  -0.4745776 ]
   [ 0.28474656  1.8033949 ]]

  [[ 0.10599977  0.7419984 ]
   [-2.0139956   0.10599977]]

  [[-1.2572339   0.8980242 ]
   [-1.2572339   1.6164436 ]]]]

2.2 LayerNorm

Transformer block 中会使用到 LayerNorm , 一般输入尺寸形为 :(batch_size, token_num, dim),会在最后一个维度做 归一化: nn.LayerNorm(dim)

与批归一化不同,层归一化的计算是针对单个样本的特征维度进行的,其目的是将每个样本在单个层内的特征维度进行归一化,以增强特征之间的独立性,并提供更稳定的特征表示。具有以下优点:

  • 适用于处理单个样本:相比于批归一化,层归一化的计算是基于单个样本在单个层内的特征维度进行的,而不依赖于小批量样本的统计信息。这使得层归一化适用于处理单个样本的情况,例如在循环神经网络(RNN)中,每个时间步的输入可以被看作是一个单独的样本。

  • 适用于动态计算图和序列数据:由于层归一化的计算不依赖于小批量样本的统计信息,它更适合在动态计算图和序列数据上进行计算。在处理变长序列数据或使用动态计算图的场景下,层归一化可以提供更好的性能和效果。

另外,在Transformer模型中使用层归一化(Layer Normalization)主要是因为其独立的特征维度归一化的性质:Transformer模型的核心是自注意力机制,它在每个位置对输入序列中的所有位置进行关注。由于每个位置的特征维度可以看作是独立的,对每个位置进行层归一化可以提供更稳定的特征表示,减少特征之间的耦合,这有助于模型更好地学习位置之间的依赖关系,提高模型的表示能力。并且,减少了特征之间的内部协变量转移,这有助于缓解深度神经网络中常见的梯度消失和梯度爆炸问题,提高模型的训练效果和收敛速度。

层归一化(Layer Normalization)一般来说在激活函数之前应用:在层归一化之后应用激活函数可以使得激活函数的输入保持在归一化后的范围内,避免激活函数的输入过大或过小。这种方式与批标准化的应用顺序相似。

例如:在传统的RNN中,通常是将输入序列经过线性变换后再应用激活函数(如tanh或ReLU)进行非线性变换。然而,这样的操作可能会导致梯度消失或梯度爆炸问题,并且不同时间步的输入之间可能存在较大的变化。通过在激活函数之前应用层归一化,可以解决上述问题。

示例代码:

import torch
import torch.nn as nn
import numpy as np

feature_array = np.array([[[[1, 0],  [0, 2]],
                           [[3, 4],  [1, 2]],
                           [[2, 3],  [4, 2]]],

                          [[[1, 2],  [-1, 0]],
                            [[1, 2], [3, 5]],
                            [[1, 4], [1, 5]]]], dtype=np.float32)


feature_array = feature_array.reshape((2, 3, -1)).transpose(0, 2, 1)
feature_tensor = torch.tensor(feature_array.copy(), dtype=torch.float32)

ln_out = nn.LayerNorm(normalized_shape=3)(feature_tensor)
print(ln_out)

b, token_num, dim = feature_array.shape
feature_array = feature_array.reshape((-1, dim))
for i in range(b*token_num):
    mean = feature_array[i, :].mean()
    var = feature_array[i, :].var()
    print(mean)
    print(var)

    feature_array[i, :] = (feature_array[i, :] - mean) / np.sqrt(var + 1e-5)
print(feature_array.reshape(b, token_num, dim))

代码运行显示:

tensor([[[-1.2247,  1.2247,  0.0000],
         [-1.3728,  0.9806,  0.3922],
         [-0.9806, -0.3922,  1.3728],
         [ 0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000],
         [-0.7071, -0.7071,  1.4142],
         [-1.2247,  1.2247,  0.0000],
         [-1.4142,  0.7071,  0.7071]]], grad_fn=<NativeLayerNormBackward0>)
2.0
0.6666667
2.3333333
2.888889
1.6666666
2.888889
2.0
0.0
1.0
0.0
2.6666667
0.88888884
1.0
2.6666667
3.3333333
5.555556
[[[-1.2247357   1.2247357   0.        ]
  [-1.3728105   0.980579    0.3922316 ]
  [-0.98057896 -0.39223155  1.3728106 ]
  [ 0.          0.          0.        ]]

 [[ 0.          0.          0.        ]
  [-0.70710295 -0.70710295  1.4142056 ]
  [-1.2247427   1.2247427   0.        ]
  [-1.4142123   0.7071062   0.7071062 ]]]

2.3 GroupNorm

batch size 过大或过小都不适合使用 BN,而是使用 GN。

(1)当 batch size 过大时,BN 会将所有数据归一化到相同的均值和方差。这可能会导致模型在训练时变得非常不稳定,并且很难收敛。

(2)当 batch size 过小时,BN 可能无法有效地学习数据的统计信息。

比如,Deformable DETR 中,就用到了 GroupNorm

示例代码:

import torch
import torch.nn as nn
import numpy as np

feature_array = np.array([[[[1, 0],  [0, 2]],
                           [[3, 4],  [1, 2]],
                           [[-2, 9], [7, 5]],
                           [[2, 3],  [4, 2]]],

                          [[[1, 2],  [-1, 0]],
                            [[1, 2], [3, 5]],
                            [[4, 7], [-6, 4]],
                            [[1, 4], [1, 5]]]], dtype=np.float32)

feature_tensor = torch.tensor(feature_array.copy(), dtype=torch.float32)
gn_out = nn.GroupNorm(num_groups=2, num_channels=4)(feature_tensor)
print(gn_out)

feature_array = feature_array.reshape((2, 2, 2, 2, 2)).reshape((4, 2, 2, 2))

for i in range(feature_array.shape[0]):
    channel = feature_array[i, :, :, :]
    mean = feature_array[i, :, :, :].mean()
    var = feature_array[i, :, :, :].var()
    print(mean)
    print(var)

    feature_array[i, :, :, :] = (feature_array[i, :, :, :] - mean) / np.sqrt(var + 1e-5)
feature_array = feature_array.reshape((2, 2, 2, 2, 2)).reshape((2, 4, 2, 2))
print(feature_array)

运行结果显示:

tensor([[[[-0.4746, -1.2339],
          [-1.2339,  0.2847]],

         [[ 1.0441,  1.8034],
          [-0.4746,  0.2847]],

         [[-1.8240,  1.6654],
          [ 1.0310,  0.3965]],

         [[-0.5551, -0.2379],
          [ 0.0793, -0.5551]]],


        [[[-0.3618,  0.2171],
          [-1.5195, -0.9406]],

         [[-0.3618,  0.2171],
          [ 0.7959,  1.9536]],

         [[ 0.4045,  1.2136],
          [-2.2923,  0.4045]],

         [[-0.4045,  0.4045],
          [-0.4045,  0.6742]]]], grad_fn=<NativeGroupNormBackward0>)
1.625
1.734375
3.75
9.9375
1.625
2.984375
2.5
13.75
[[[[-0.4745776  -1.2339017 ]
   [-1.2339017   0.28474656]]

  [[ 1.0440707   1.8033949 ]
   [-0.4745776   0.28474656]]

  [[-1.8240178   1.6654075 ]
   [ 1.0309665   0.3965256 ]]

  [[-0.55513585 -0.23791535]
   [ 0.07930512 -0.55513585]]]


 [[[-0.3617867   0.21707201]
   [-1.5195041  -0.9406454 ]]

  [[-0.3617867   0.21707201]
   [ 0.79593074  1.9536481 ]]

  [[ 0.40451977  1.2135593 ]
   [-2.2922788   0.40451977]]

  [[-0.40451977  0.40451977]
   [-0.40451977  0.67419964]]]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/243654.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

六.聚合函数

聚合函数 1.什么是聚合函数1.1AVG和SUM函数1.2MIN和MAX函数1.3COUNT函数 2.GROUP BY2.1基本使用2.2使用多个列分组2.3GROUP BY中使用WITH ROLLUP 3.HAVING3.1基本使用3.2WHERE和HAVING的区别 4.SELECT的执行过程4.1查询的结构4.2SELECT执行顺序4.3SQL执行原理 1.什么是聚合函数…

JOSEF约瑟快速跳闸继电器RXMS1RK216063 DC220V

系列型号 RXMS1 RK 216 437快速跳闸继电器&#xff1b;RXMS1 RK 216 237快速跳闸继电器&#xff1b; RXMS1 RK 216 449快速跳闸继电器&#xff1b;RXMS1 RK 216 249快速跳闸继电器&#xff1b; RXMS1 RK 216 450快速跳闸继电器&#xff1b;RXMS1 RK 216 250快速跳闸继电器&a…

HarmonyOS使用Web组件

Web组件的使用 1 概述 相信大家都遇到过这样的场景&#xff0c;有时候我们点击应用的页面&#xff0c;会跳转到一个类似浏览器加载的页面&#xff0c;加载完成后&#xff0c;才显示这个页面的具体内容&#xff0c;这个加载和显示网页的过程通常都是浏览器的任务。 ArkUI为我…

C语言实现贪吃蛇【完整版】

贪吃蛇 文章目录 贪吃蛇使用到的WIN32一些接口简单介绍控制台窗口大小隐藏光标控制光标的位置获取键盘的值的情况字符问题 游戏逻辑开始游戏打印地图初始化贪吃蛇创建食物 运行游戏控制蛇的移动 运行结束 贪吃蛇实现出来的效果如下&#xff1a; 贪吃蛇小游戏录屏 完整代码&…

【Spark精讲】Spark存储原理

目录 类比HDFS的存储架构 Spark的存储架构 存储级别 RDD的持久化机制 RDD缓存的过程 Block淘汰和落盘 类比HDFS的存储架构 HDFS集群有两类节点以管理节点-工作节点模式运行&#xff0c;即一个NameNode(管理节点)和多个DataNode(工作节点)。 Namenode管理文件系统的命名空…

销售技巧培训之如何提升金融销售技巧

销售技巧培训之如何提升金融销售技巧 在金融行业&#xff0c;销售技巧是决定业绩成败的关键因素之一。无论是销售保险、股票、债券&#xff0c;还是提供投资咨询服务&#xff0c;都需要掌握一定的销售技巧。本文将探讨如何提升金融销售技巧&#xff0c;通过案例分析&#xff0…

如何提升网络安全技术【蓝队】?在职学长告诉你

网络安全的防守技术是网络安全工程师必备技能&#xff0c;只有攻防兼备的白帽子&#xff0c;才算是真正的网安精英。 网络安全的攻击技术在前面我已经讲过了&#xff0c;感兴趣的可以去看看&#xff1a; 90%的人都不算会网络安全&#xff0c;这才是真正的白帽子技术【红队】 . …

GitHub帐户管理更改电子邮件

登录到您的 GitHub 帐户&#xff1a; 前往 GitHub 网站并使用您的凭据登录。 访问个人设置&#xff1a; 单击右上角的您的头像&#xff0c;然后选择“Settings”&#xff08;设置&#xff09;。 选择电子邮件选项卡&#xff1a; 在左侧边栏中选择“Emails”&#xff08;电子邮…

单片稳压集成电路78LXX系列——固定的电压输出,适用于需100mA电源供给的应用场合(网络产品,声卡和电脑主板等产品)

78LXX系列是一款单片稳压集成电路&#xff0c;它们有一系列固定的电压输出&#xff0c;适用于需100mA电源供给的应用场合。78LXX系列采用T0-92和SOT-89-3L的封装形式。 主要特点&#xff1a; ● 最大输出电流为100mA ● 输出电压为3.3V. 5V. 6V. 8V、9V、10V、 12V和15V ● 热…

深入理解 Go Channel:解密并发编程中的通信机制

一、Channel管道 1、Channel说明 共享内存交互数据弊端 单纯地将函数并发执行是没有意义的。函数与函数间需要交互数据才能体现编发执行函数的意义虽然可以使用共享内存进行数据交换&#xff0c;但是共享内存在不同的goroutine中容易发送静态问题为了保证数据交换的正确性&am…

HTTP、HTTPS、SSL协议以及相关报文讲解

目录 HTTP/HTTPS介绍 HTTP/HTTPS基本信息 HTTP如何实现有状态 HTTP请求与应答报文 HTTP请求报文 HTTP响应报文 SSL协议 SSL单向认证 SSL双向认证 HTTP连接建立与传输步骤 HTTP访问全过程相关报文&#xff08;以访问www.download.cucdccom为例子&#xff09; DNS报文…

什么是SEO优化

什么是SEO&#xff0c;百度其实就有答案&#xff0c;只是回答的很基础&#xff0c;说的都是基础概念&#xff0c;没有具体的体现在里面&#xff0c;SEO除了基础概念&#xff0c;还要有相应的构架&#xff0c;不然怎么弄都是一场空而已。 关于什么是SEO的文章导读&#xff1f; 1…

vite+vue3+electron搭建项目

编辑器使用vscode&#xff0c;打开一个空文件夹 第一步 初始化vite项目 初始化vite项目&#xff0c;命令 npm init vite 第二步 下载依赖 进入新建的项目&#xff0c;下载依赖&#xff0c;命令 cd vite-projec npm i第三步 使用cnpm下载 electron依赖 新建一个终端&#…

springboot应用,cpu高、内存高问题排查

前几天&#xff0c;排查了2个生产问题。一个cpu高&#xff0c;一个内存高。今天把解决过程整理一下 文章目录 1、cpu高问题排查1.1、获取栈日志1.2、分析栈日志 2、内存高问题排查2.1、dump日志分析2.2、堆内存使用情况2.3、解决方案2.4、arthas trace解决问题2.5、总结 1、cp…

【玩转TableAgent数据智能分析】会话式数据分析,所需即所得!

目录 1 TableAgent介绍 2 TableAgent五大优点 3 体验TableAgent 3.1 登录TableAgent平台 3.2 会话式数据分析 4 总结 【优化改善】 【对比TableAgent与文心一言- E言易图】 1 TableAgent介绍 TableAgent是一款数据集成和分析平台&#xff0c;它可以帮助用户从多个数据源中…

Maven的安装配置流程

步骤一&#xff1a;下载Maven 打开Maven官方网站&#xff0c;进入"Download"页面。我这里有下好的&#xff0c;网盘链接在文末&#xff01;&#xff01; 在"Download"页面中找到最新版本的Maven&#xff0c;选择一个稳定的版本。通常&#xff0c;你会看到…

面相对象开发的原则

1、开闭原则 对修改关闭&#xff0c;对扩展打开。 2、里氏替换原则 子类继承父类的方法时&#xff0c;不可重写父类的方法。 如果重写了父类的方法会导致整个继承体系比较差&#xff0c;特别是运用多态比较平凡时&#xff0c;程序运行出错概率较大。 如果出现了违背“里氏替换…

LIN总线信号串行译码

我们用虹科Pico汽车示波器捕捉了LIN总线信号 &#xff0c;如果想看它对应的报文数据&#xff0c;我们可以应用PicoScope Automotive软件的串行译码功能来对它破译。 使用指导如下&#xff1a; 点击“串行译码”&#xff0c;选择对应的协议&#xff0c;如LIN。 在下面对话框&…

04.HTML其他知识

HTML其他知识 1.HTML实体 介绍 在 HTML 中我们可以用一种特殊的形式的内容&#xff0c;来表示某个符号&#xff0c;这种特殊形式的内容&#xff0c;就是 HTML 实体。比如小于号 < 用于定义 HTML 标签的开始。如果我们希望浏览器正确地显示这些字符&#xff0c;我们必须在…

西南科技大学C++程序设计实验十(函数模板与类模板)

一、实验目的 1. 掌握函数模板与类模板; 2. 掌握数组类、链表类等线性群体数据类型定义与使用; 二、实验任务 1. 分析完善以下程序,理解模板类的使用: (1)补充类模板声明语句。 (2)创建不同类型的类对象,使用时明确其数据类型? _template<typename T>__…