归一化技术比较研究:Batch Norm, Layer Norm, Group Norm

归一化层是深度神经网络体系结构中的关键,在训练过程中确保各层的输入分布一致,这对于高效和稳定的学习至关重要。归一化技术的选择(Batch, Layer, GroupNormalization)会显著影响训练动态和最终的模型性能。每种技术的相对优势并不总是明确的,随着网络体系结构、批处理大小和特定任务的不同而变化。

本文将使用合成数据集对三种归一化技术进行比较,并在每种配置下分别训练模型。记录训练损失,并比较模型的性能。

神经网络中的归一化层是用于标准化网络中某一层的输入的技术。这有助于加速训练过程并获得更好的表现。有几种类型的规范化层,其中 Batch Normalization, Layer Normalization, Group Normalization是最常见的。

常见的归一化技术

BatchNorm

BN应用于一批数据中的单个特征,通过计算批处理上特征的均值和方差来独立地归一化每个特征。它允许更高的学习率,并降低对网络初始化的敏感性。

这种规范化发生在每个特征通道上,并应用于整个批处理维度,它在大型批处理中最有效,因为统计数据是在批处理中计算的。

LayerNorm

与BN不同,LN计算用于归一化单个数据样本中所有特征的均值和方差。它应用于每一层的输出,独立地规范化每个样本的输入,因此不依赖于批大小。

LN有利于循环神经网络(rnn)以及批处理规模较小或动态的情况。

GroupNorm

GN将信道分成若干组,并计算每组内归一化的均值和方差。这对于通道数量可能很大的卷积神经网络很有用,将它们分成组有助于稳定训练。

GN不依赖于批大小,因此适用于小批大小的任务或批大小可以变化的任务。

每种规范化方法都有其优点,并且根据网络体系结构、批处理大小和训练过程的特定需求适合不同的场景:

BN对于具有稳定和大批大小的网络非常有效,LN对于序列模型和小批大小是首选,而GN提供了对批大小变化不太敏感的中间选项。

代码示例

我们演示了使用PyTorch在神经网络中使用三种规范化技术的代码,并且绘制运行的结果图。

首先是生成数据

 importtorch
 importtorch.nnasnn
 importtorch.optimasoptim
 importnumpyasnp
 importmatplotlib.pyplotasplt
 fromtorch.utils.dataimportDataLoader, TensorDataset
 
 # Generate a synthetic dataset
 np.random.seed(42)
 X=np.random.rand(1000, 10)
 y= (X.sum(axis=1) >5).astype(int)  # simple threshold sum function
 X_train, y_train=torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)
 
 # Create a DataLoader
 dataset=TensorDataset(X_train, y_train)
 loader=DataLoader(dataset, batch_size=64, shuffle=True)

然后是创建模型,这里将三种方法写在一个模型中,初始化时只要传递不同的参数就可以使用不同的归一化方法

 # Define a model with Batch Normalization, Layer Normalization, and Group Normalization
 classNormalizationModel(nn.Module):
     def__init__(self, norm_type="batch"):
         super(NormalizationModel, self).__init__()
         self.fc1=nn.Linear(10, 50)
         
         ifnorm_type=="batch":
             self.norm=nn.BatchNorm1d(50)
         elifnorm_type=="layer":
             self.norm=nn.LayerNorm(50)
         elifnorm_type=="group":
             self.norm=nn.GroupNorm(5, 50)  # 5 groups
             
         self.fc2=nn.Linear(50, 2)
 
     defforward(self, x):
         x=self.fc1(x)
         x=self.norm(x)
         x=nn.ReLU()(x)
         x=self.fc2(x)
         returnx

然后是训练的代码,我们也简单的封装下,方便后面调用

 # Training function
 deftrain_model(norm_type):
     model=NormalizationModel(norm_type=norm_type)
     criterion=nn.CrossEntropyLoss()
     optimizer=optim.Adam(model.parameters(), lr=0.001)
     num_epochs=50
     losses= []
 
     forepochinrange(num_epochs):
         forinputs, targetsinloader:
             optimizer.zero_grad()
             outputs=model(inputs)
             loss=criterion(outputs, targets)
             loss.backward()
             optimizer.step()
             losses.append(loss.item())
     
     returnlosses

最后就是训练,经过上面的封装,我们直接循环调用即可

 # Train and plot results for each normalization
 norm_types= ["batch", "layer", "group"]
 results= {}
 
 fornorm_typeinnorm_types:
     losses=train_model(norm_type)
     results[norm_type] =losses
     plt.plot(losses, label=f"{norm_type} norm")
 
 plt.xlabel("Iteration")
 plt.ylabel("Loss")
 plt.title("Normalization Techniques Comparison")
 plt.legend()
 plt.show()

生成的图表将显示每种归一化技术如何影响有关减少损失的训练过程。我们可以解释哪种归一化技术对这个特定的合成数据集和训练设置更有效。我们的评判标准是通过适当的归一化实现更平滑和更快的收敛。

BN(蓝色)、LN(橙色)和GN(绿色)。

所有三种归一化方法都以相对较高的损失开始,并迅速减小。

可以看到BN的初始收敛速度非常的快,但是到了最后,损失出现了大幅度的波动,这可能是因为学习率、数据集或小批量选择的随机性质决定的,或者是模型遇到具有不同曲率的参数空间区域。我们的batch_size=64,如果加大这个参数,应该会减少波动。

LN和GN的下降平稳,并且收敛速度和表现都很类似,通过观察能够看到LN的方差更大一些,表明在这种情况下可能不太稳定

最后所有归一化技术都显著减少了损失,但是因为我们使用的是生成的数据,所以不确定否都完全收敛了。不过虽然该图表明,最终的损失值很接近,但是GN的表现可能更好一些。

总结

在这些规范化技术的实际应用中,必须考虑任务的具体要求和约束。BatchNorm在大规模批处理可行且需要稳定性时更可取。LayerNorm在rnn和具有动态或小批量大小的任务的背景下可以发挥作用。GroupNorm提供了一个中间选项,在不同的批处理大小上提供一致的性能,在cnn中特别有用。

归一化层是现代神经网络设计的基石,通过了解BatchNorm、LayerNorm和GroupNorm的操作特征和实际含义,根据任务需求选择特定的技术,可以在深度学习中实现最佳性能。

https://avoid.overfit.cn/post/e8ec905659e5446e84fb9617feb86e95

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/521905.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CSS - 你实现过宽高自适应的正方形吗

难度 难度级别:中高级及以上 提问概率:80% 宽高自适应的需求并不少见,尤其是在当今流行的大屏系统开发中更是随处可见,很显然已经超越了我们日常将div写死100px这样的范畴,那么如何实现一个宽高自适应的正方形呢?这里提出两种实现方案。…

【Linux】进程初步理解

个人主页 : zxctscl 如有转载请先通知 文章目录 1. 冯诺依曼体系结构1.1 认识冯诺依曼体系结构1.2 存储金字塔 2. 操作系统2.1 概念2.2 结构2.3 操作系统的管理 3. 进程3.1 进程描述3.2 Linux下的PCB 4. task_struct本身内部属性4.1 启动4.2 进程的创建方式4.2.1 父…

JAVA:探索Apache POI 处理利器

请关注微信公众号:拾荒的小海螺 1、简述 Apache POI是Apache软件基金会的顶级项目之一,它允许Java开发人员读取和写入Microsoft Office格式的文档,包括Excel、Word和PowerPoint文件。通过POI,开发人员可以创建、修改和读取Excel…

面试(04)————JavaWeb

1、网络通讯部分 1.1、 TCP 与 UDP 区别? 1.2、什么是 HTTP 协议? 1.3、TCP 的三次握手,为什么? 1.4、HTTP 中重定向和请求转发的区别? 1.5、 Get 和 Post 的区别? 2、cookie 和 session 的区别&am…

加入酷开会员 酷开系统带你一起开启看电视的美好时光!

看电视对孩子和大人来说,都是有好处的。英国的《星期日泰晤士报》曾刊登报道:“看电视可以让小孩增长见闻,学习各种良好的社交和学习技巧,从而为他们今后的学习打下良好的基础。”而对于成年人来说,看电视也是一种娱乐…

linux 安装 pptp 协议

注意:目前iOS已不支持该协议 yum -y install ppp wget https://download-ib01.fedoraproject.org/pub/epel/7/x86_64/Packages/p/pptpd-1.4.0-2.el7.x86_64.rpm yum -y install pptpd-1.4.0-2.el7.x86_64.rpm vi /etc/pptpd.conf 去除 localip 和 remoteip的注释 …

【.Net】Polly

文章目录 概述服务熔断、服务降级、服务限流、流量削峰、错峰、服务雪崩Polly的基本使用超时策略悲观策略乐观策略 重试策略请求异常响应异常 降级策略熔断策略与策略包裹(多种策略组合) 参考 概述 Polly是一个被.NET基金会支持认可的框架,同…

SAP-MM 新增公司代码 激活物料分类账

1、OMX1 - 激活物料分类账(配置环境) 2、CKMSTART - 物料分类账的生产开始(生产机运行) 不激活创建物料时会报错:估价范围还没有生产式的物料账簿 执行后结果: 以上~~

creo扫描杯子学习笔记

creo扫描杯子学习笔记 扫描2要素: 轨迹, 截面。 多用于曲线扫描,区别于拉伸命令。 大小自定 旋转扫描 抽壳 草绘把手 扫描把手 复制曲面 实例化切除 成型

Web爬虫

📑前言 本文主要是【Web爬虫】——简单使用的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 🌄每日一句&#…

PHP实现网站微信扫码关注公众号后自动注册登陆实现方法及代码【关注收藏】

在网站注册登陆这环节,增加微信扫码注册登陆,普通的方法需要开通微信开发者平台,生成二维码扫码后才能获取用户的uinonid或openid,实现注册登陆,但这样比较麻烦还要企业认证交费开发者平台,而且没有和公众号…

区域自动气象站讲解

TH-QC10当我们每天查看天气预报,安排出行计划,或是在户外活动时关注天气变化,很少有人会想到这一切背后默默付出的“英雄”——区域自动气象站。这些看似不起眼的气象监测设备,却在我们日常生活中扮演着至关重要的角色。今天&…

【话题】程序员35岁会失业吗?

大家好,我是全栈小5,欢迎阅读小5的系列文章,这是《话题》系列文章 目录 背景招聘分析一、技术更新换代的挑战二、经验与技术的双重优势三、职业发展的多元化选择四、个人成长与职业规划的平衡五、结语文章推荐 背景 35岁被认为是程序员职业生…

【OJ】stack刷题

个人主页 : zxctscl 如有转载请先通知 题目 1. 155. 最小栈1.1 分析1.2 代码 2. JZ31 栈的压入、弹出序列2.1 分析2.2 代码 3. 150. 逆波兰表达式求值3.1 分析3.2 代码 1. 155. 最小栈 1.1 分析 利用两个栈,一个栈a负责入数据和出数据,另一个…

分类预测 | Matlab实现DRN深度残差网络数据分类预测

分类预测 | Matlab实现DRN深度残差网络数据分类预测 目录 分类预测 | Matlab实现DRN深度残差网络数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 1.Matlab实现DRN深度残差网络数据分类预测(完整源码和数据),运行环境为Matl…

合宙开发板Core_Air780E测试AT指令

一、官方资料 CORE-AIR780E 开发板是合宙通信推出的基于 Air780E 模组所开发的,包含电源,SIM 卡,USB,天线,音频等必要功能的最小硬件系统。以方便用户在设计前期对 Air780E 模块 进行性能评估,功能调试&…

CUDA10的安装

1、因为要用到tensorflow1.15.5的GPU版本,所以想安装cuda10来进行加速,通过nvidia-smi检查本机上的CUDA版本 2、下载的cuda10版本,cuda_10.0.130_411.31_win10.exe 下载的cudnn版本,cudnn-10.0-windows10-x64-v7.6.4.38.zip 然后…

mathtype如何嵌入到word中?mathtype 7永久激活码密钥及2024最新序列号附安装教程

将MathType嵌入到Word中的方法主要有三种,分别是: 通过加载项嵌入MathType。首先,在Word中点击“文件”按钮,选择“选项”,然后选择“加载项”一栏,找到MathType相关的加载项并勾选,点击“确定…

20240404这个数字有什么特点吗?

今天是2024年的清明节,20240404这个数字让我提出了一个疑问,它是否有什么含义或者特点呢? 首先,如果把它拆分为两个整数的平方和,会怎么样呢? 于是,我一顿操作猛如虎,搞出了这么个…

html--烟花3

html <!DOCTYPE html> <html> <head> <meta charset"UTF-8"> <title>Canvas烟花粒子</title> <meta name"keywords" content"canvas烟花"/> <meta name"description" content"can…