25. 深度学习进阶 - 权重初始化,梯度消失和梯度爆炸

文章目录

    • 权重初始化
    • 梯度消失与梯度爆炸

在这里插入图片描述

Hi,你好。我是茶桁。

咱们这节课会讲到权重初始化、梯度消失和梯度爆炸。咱们先来看看权重初始化的内容。

权重初始化

机器学习在我们使用的过程中的初始值非常的重要。就比如最简单的wx+b,现在要拟合成一个yhat,w如果初始的过大或者初始的过小其实都会比较有影响。

假设举个极端情况,就是w拟合的时候刚刚就拟合到了离x很近的地方,我们想象一下,这个时候是不是学习起来就会很快?所以对于深度学习模型权重的初始化是一个非常重要的事情,甚至有人就说把初始化做好了,其实绝大部分事情就已经解决了。

那么我们怎么样获得一个比较好的初始化的值?首先有这么几个原则

  • 我们的权重值不能设置为0。
  • 尽量将权重变成一个随机化的正态分布。而且有更大的X输入,那我们的权重就应该更小。

l o s s = ∑ ( y ^ − y i ) 2 = ∑ ( ∑ w i x i − y i ) 2 \begin{align*} loss & = \sum(\hat y - y_i)^2 \\ & = \sum(\sum w_ix_i - y_i)^2 \end{align*} loss=(y^yi)2=(wixiyi)2

我们看上面的式子,yhat就是w_i*x_i, 这个时候x_i可能是几百万,也可能是几百。我们w_i取值在(-n, n)之间,那当x_i维度特别大的时候,那yhat值算出来的也就会特别大。所以,x_i的维度特别大的时候,我们期望w_i值稍微小一些,否则加出来的yhat可能就会特别大,那最后求出来的loss也会特别大。

如果loss值特别大,可能就会得到一个非常的梯度。那我们知道,学习的梯度特别大的话,就会发生比较大的震荡。

所以有一个原则,就是当x的dimension很大的时候, 我们期望的它的权重越小。

那后来就有人提出来了一个比较重要的初始化方法,Xavier初始化。这个方法特别适用于sigmoid激活函数或反正切tanh激活函数,它会根据前一层和当前层的神经元数量来选择初始化的范围,以确保权重不会过大或过小。
均值为 0 和标准差的正态分布 : σ = 2 n i n p u t s + n o u t p u t s − r 和 + r 之间的均匀分布: r = 6 n i n p u t s + n o u t p u t s \begin{align*} 均值为0和标准差的正态分布: \sigma & = \sqrt{\frac{2}{n_{inputs}+n_{outputs}}} \\ -r和+r之间的均匀分布:r & = \sqrt{\frac{6}{n_{inputs}+n_{outputs}}} \end{align*} 均值为0和标准差的正态分布:σr+r之间的均匀分布:r=ninputs+noutputs2 =ninputs+noutputs6

然后W的均匀分布就会是这样:
W ∼ U ∣ − 6 n j + n j + 1 , 6 n j + n j + 1 W \sim U \Bigg \vert -\frac{\sqrt 6}{\sqrt{n_j + n_{j+1}}}, \frac{\sqrt 6}{\sqrt{n_j + n_{j+1}}} WU nj+nj+1 6 ,nj+nj+1 6

这个是一个比较有名的初始化方法,如果要做函数的初始化的话,PyTorch在init里面有一个方法:

torch.nn.init.xavier_uniform_(tensor, gain=1.0)

比如,我们看这样例子:

w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=nn.calculate_gain('relu'))

注意: init方法里还有其他的一些方法,大家可以查阅PyTorch的相关文档:https://pytorch.org/docs/stable/nn.init.html

梯度消失与梯度爆炸

当我们的模型层数特别多的时候

Alt text

就比如我们上节课用到的Sequential,我们可以在里面写如非常多的一个函数:

model = nn.Sequential(
    nn.Linear(in_features=10, out_features=5).double(),
    nn.Sigmoid(),
    nn.Linear(in_features=5, out_features=8).double(),
    nn.Sigmoid(),
    nn.Linear(in_features=8, out_features=8).double(),
    nn.Sigmoid(),
    ...
    nn.Linear(in_features=8, out_features=8).double(),
    nn.Softmax(),
)

Alt text

这样,在做偏导的时候我们其中几个值特别小,那两个一乘就会乘出来一个特别特别小的数字。最后可能会导致一个结果, ∂ l o s s ∂ w i \frac{\partial loss}{\partial wi} wiloss的值就会极小,它的更新就会特别的慢。我们把这种东西就叫做梯度消失,也有人叫梯度弥散。

以Sigmoid函数为例,其导数为

σ ′ ( x ) = σ ( x ) ( 1 − σ ( x ) ) \begin{align*} \sigma '(x) = \sigma(x)(1-\sigma(x)) \end{align*} σ(x)=σ(x)(1σ(x))

在x趋近正无穷或者负无穷时,导数接近0。当这种小梯度在多层网络中相乘的时候,梯度会迅速减小,导致梯度消失。

除此之外还有一种情况叫梯度爆炸,剃度爆炸类似,当模型的层很多的时候,如果其中某两个值很大,例如两个102,当这两个乘起来就会变成104。乘下来整个loss很大,又会产生一个结果,我们来看这样一个场景:

Alt text

假如说对于上图中这个函数来说,横轴为x, 竖轴为loss,对于这个xi来说,这个地方 ∂ l o s s ∂ x i \frac{\partial loss}{\partial xi} xiloss已经是一个特别大的数字了。

假设咱们举个极端的情况(忽略图中竖轴上的数字),我们现在loss等于x^4: l o s s = x 4 loss=x^4 loss=x4,然后现在 ∂ l o s s ∂ x 4 \frac{\partial loss}{\partial x^4} x4loss就等于 4 x 3 4x^3 4x3,我们假设x在A点,当x=10的时候,那 4 × x 3 = 4000 4\times x^3 = 4000 4×x3=4000, 那我们计算新的xi,就是 x i = x i − α ⋅ ∂ l o s s ∂ x i x_i = x_i - \alpha \cdot \frac{\partial loss}{\partial x_i} xi=xiαxiloss,现在给alpha一个比较小的数,我们假设是0.1,那式子就变成 10 − 0.1 × 4000 10 - 0.1 \times 4000 100.1×4000,结果就是-390。

我们把它变到-390之后,本来我们本来做梯度下降更新完,xi期望的是loss要下降,但是我们结合图像来看,xi=-390的时候,loss就变得极其的巨大了,然后我们在继续,(-390)^4, 这个loss就已经爆炸了。

再继续的时候,会发现会在极值上跳来跳去,loss就无法进行收敛了。所以我们也要拒绝这种情况的发生。

那梯度消失和梯度爆炸这两个问题该如何解决呢?我们来看第一种解决方法: Batch normalization,批量归一化。

那这个方法的核心思想是对神经网络的每一层的输入进行归一化,使其具有零均值和单位方差。

那么首先,对于每个mini-batch中的输入数据,计算均值和方差。 B = { x 1 . . . m } B = \{x_1...m\} B={x1...m}; 要学习的参数: γ , β \gamma,\beta γ,β

μ B = 1 m ∑ i = 1 m x i σ B 2 = 1 m ∑ i = 1 m ( x i − μ B ) 2 μ 为均值 m e a n , σ 为方差 \begin{align*} \mu_B & = \frac{1}{m}\sum^m_{i=1}x_i \\ \sigma ^2_B & = \frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2 \\ & \mu 为均值mean, \sigma为方差 \end{align*} μBσB2=m1i=1mxi=m1i=1m(xiμB)2μ为均值meanσ为方差

这里和咱们之前讲x做normalization的时候其实是特别相似,基本上就是一件事。

然后我们使用均值和方差对输入进行归一化,使得其零均值和单位方差,即将输入标准化为xhat。

x ^ i = x i − μ B σ B 2 + ε \begin{align*} \hat x_i = \frac{x_i - \mu_B}{\sqrt{\sigma ^2_B + \varepsilon}} \end{align*} x^i=σB2+ε xiμB

接着我们对归一化后的输入应用缩放和平移操作,以允许网络学习最佳的变换。

y i = γ x ^ i + β ≡ B N γ , β ( x i ) \begin{align*} y_i = \gamma \hat x_i + \beta \equiv BN_{\gamma,\beta}(x_i) \end{align*} yi=γx^i+βBNγ,β(xi)

输出为 { y i = B N γ , β ( x i ) } \{y_i = BN_{\gamma,\beta}(x_i)\} {yi=BNγ,β(xi)}

最后将缩放和平移后的数据传递给激活函数进行非线性变换。

它会输入一个小批量的x值,

经过反复的梯度下降,会得到一个gamma和beta,能够知道在这一步x要怎么样进行缩放,在缩放之前会经历刚开始的时候那个normalization一样,把把过小值会变大,把过大值会变小。

我们在之前的课程中演示过,没看过和忘掉的同学可以往前翻看一下。

然后在经过这两个可学习的参数进行一个变化,这样它可以做到在每一层x变化不会极度的增大或者极度的缩小,可以让我们的权值保持的比较稳定。

那除了Batch normalization之外,还有一个方法叫Gradient clipping, 它是可以直接将过大的梯度值变小。

Alt text

它其实很简单,也叫做梯度减脂。

如果我们求解出来 ∂ l o s s ∂ w i \frac{\partial loss}{\partial w_i} wiloss很大,假设原来等于400,我们定义了一个100,那超过100的部分,就全部设置成100。

train_loss.backward()
pt.nn.units.clip_grad_value_(model.parameters(), 100)
optimizer.step()

简单粗暴。那其实梯度爆炸还是比较容易解决的,比较复杂的其实是梯度消失的问题。

梯度爆炸为什么比较容易解决?梯度爆炸起码是有导数的,只要把这个导数给它放的特别小就行了,有导数起码保证wi可以更新。

假设alpha,我们的learning_rate等于0.01,乘上一个100,可以保证每次可以有个变化。但是每次这个梯度特别小,假如都快接近于0了,那么1e-10, 就算乘上100倍,最后还是一个特别小的数字。所以相较而言,梯度爆炸就更好解决一些,方法更粗暴一些。

补充一个知识点,这个虽然现在已经用不到了,但是对我们的理解还是有帮助的。方法比较古老。

就是当我们发现梯度有问题的时候, 大概在10年前,那个时候神经网络的模块也不太丰富,很多新出的model,做神经网络的人,一些导数,传播什么的都需要自己写,就我们前几节课写那个神经网络框架的时候做的事。

有的时候导数写错了,就有一种方法叫做gradient checking,梯度检查。

这个使用场景非常的少,当你自己发明了一个新的模块,加到这个模型里面的时候会遇到。

其实很简单,就是把最终的 ∂ l o s s ∂ w i \frac{\partial loss}{\partial w_i} wiloss,求解出来的偏导总是不收敛,可能是这个偏导有问题,那么有可能求导的函数写错了。

那在这个时候就可以做个简单的变化:

∂ l o s s ( θ + ε ) − ∂ l o s s ( θ − ε ) 2 ε \begin{align*} \frac{\partial loss(\theta+\varepsilon)-\partial loss(\theta - \varepsilon)}{2\varepsilon} \end{align*} 2εloss(θ+ε)loss(θε)

这其中 ∂ l o s s ( θ + ε ) \partial loss(\theta + \varepsilon) loss(θ+ε) ∂ l o s s ( θ − ε ) \partial loss(\theta - \varepsilon) loss(θε)是在参数 θ \theta θ, 其实也就是我们的wi上添加和减去微小扰动theta后的损失函数值。

然后我们计算数值梯度和反向传播计算得到的梯度之间的差异。通常这是通过计算它们之间的差异来完成,然后将其与一个小的阈值,比如1e-7进行比较。如果差异非常小(小于阈值),则可以认为梯度计算是正确的,否则可能就需要从新写一下偏导函数了。

这个比较难,但不是一个重点,当且仅当自己要发明一个模型的时候。

那接下来我们来看一下关于Learning_rate和Early Stopping的问题。

理论上,如果深度学习效果不好,那么我们可以将learning rate调小,可以让所有模型效果变得更好,它可以让所有的loss下降。

Alt text

但是如果你的learning rate变得特别小,假如说是1e-9,那这样的结果就是w的变化会非常的慢,训练时间就变得很长。为了解决这个问题,就有一些比较简单的方法。

第一个,我们可以把learning rate和loss设置成一个相关的函数,例如说loss越小的时候,Learning rate越小,或者随着epoch的增大,loss越小。这个就叫learning rate的decay。

将learning rate或者训练次数和loss设置成一个相关的函数,那么越到后面效果越好的时候,learning rate就会越小。

还有,我们可能会发现loss连续k次不下降,那我们就可以提前结束训练过程,这个就是Early Stopping。

也就是当你发现loss连续k次不下降,或者甚至于在上升,那么这个时候,就可以将最优的这个值给它记录下来。

咱们可能会经常出现的情况就是值在那里震荡,本来呢已经快接近于最优点了,可是震荡了几次之后,还可能震荡出去了,loss变大了。或者就一直在这个震荡里边出不去,这个时候多学习也没有用,所以就可以早点停止,这个就是Early Stopping,中文有人称呼它为早停方法。

好,下节课,咱们要讲一个重点,也是一个难点。就是咱们做机器学习的时候,不同的优化方法。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/202236.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

TZOJ 1373 求多项式的和

答案&#xff1a; #include <stdio.h> int main() {int m 0;scanf("%d", &m); // 读取测试实例的个数 while (m--) //循环m次{int n 0, i 0;scanf("%d", &n); // 读取求和项数n double sum 0.0;for (i 1; i < n; i) //分…

JenKins快速安装与使用

一、JenKins 0.准备&#xff0c;配置好环境 1&#xff09;Git&#xff08;yum安装&#xff09; 2&#xff09;JDK&#xff08;自行下载&#xff09; 3&#xff09;Jenkins&#xff08;自行下载&#xff09; 1.下载安装包 进官网&#xff0c;点Download下方即可下载。要下…

linux之下安装 nacos

1 下载地址 也可使用在线下载wget https://github.com/alibaba/nacos/releases/download/1.4.6/nacos-server-1.4.6.tar.gzTags alibaba/nacos GitHuban easy-to-use dynamic service discovery, configuration and service management platform for building cloud nativ…

一次Apollo Client升级导致的生产404 Not Found问题排查记录

概述 本文记录一次升级Apollo Client组件到1.7.0后遇到的重大生产事故。只想看结论的&#xff0c;可直接快进到文末。实际上&#xff0c;第一句话就是一个结论。 另&#xff0c;本文行文思路事后看起来可行略显思路清晰&#xff0c;实际上排查生产问题时如无头苍蝇&#xff0…

使用STM32微控制器实现烟雾传感器的接口和数据处理

烟雾传感器是常见的安全检测装置&#xff0c;通过检测空气中的烟雾浓度来提醒用户有潜在的火灾风险。本文将介绍如何使用STM32微控制器来实现烟雾传感器的接口和数据处理。包括硬件连接、采集模拟信号、数字信号处理和报警策略等方面。同时&#xff0c;给出相应的代码示例。 一…

【Android知识笔记】架构专题(一)

什么是 MVC 其实我们日常开发中的Activity,Fragment和XML界面就相当于是一个MVC的架构模式,但往往Activity中需要处理绑定UI,用户交互,以及数据处理。 这种开发方式的缺点就是业务量复杂的时候一个Activity过于臃肿。但是页面结构不复杂的情况下使用这种方式就会显得很简…

基于Java SSM框架实现KTV点歌系统项目【项目源码+论文说明】

基于java的SSM框架实现KTV点歌系统演示 摘要 本论文主要论述了如何使用JAVA语言开发一个KTV点歌系统&#xff0c;本系统将严格按照软件开发流程进行各个阶段的工作&#xff0c;采用B/S架构&#xff0c;面向对象编程思想进行项目开发。在引言中&#xff0c;作者将论述KTV点歌系…

Linux基础项目开发1:量产工具——输入系统(三)

前言&#xff1a; 前面我们已经实现了显示系统&#xff0c;现在我们来实现输入系统&#xff0c;与显示系统类似&#xff0c;下面让我们一起来对输入系统进行学习搭建吧 目录 一、数据结构抽象 1. 数据本身 2. 设备本身&#xff1a; 3. input_manager.h 二、触摸屏编程 to…

云时空社会化商业 ERP 系统 gpy 文件上传漏洞复现

0x01 产品简介 时空云社会化商业ERP&#xff08;简称时空云ERP&#xff09; &#xff0c;该产品采用JAVA语言和Oracle数据库&#xff0c; 融合用友软件的先进管理理念&#xff0c;汇集各医药企业特色管理需求&#xff0c;通过规范各个流通环节从而提高企业竞争力、降低人员成本…

VSCODE 在新窗口中打开

使用VS习惯了&#xff0c;经常在新窗口中打开查看 但是VSCODE&#xff0c;无法拖动标签到一个新窗口中&#xff0c;一直以为没这个功能 后来发现 使用快捷健 ctlk,o 可以将标签页在新窗口中打开&#xff0c;虽然不如vsstudio方便&#xff0c;不过也可实现在新窗口打开的功能…

【动态规划】LeetCode2552:优化了6版的1324模式

本文涉及的基础知识点 C算法&#xff1a;前缀和、前缀乘积、前缀异或的原理、源码及测试用例 包括课程视频 动态规划 本题其它解法 C前缀和算法的应用&#xff1a;统计上升四元组 类似题解法 包括题目及代码C二分查找算法&#xff1a;132 模式解法一枚举3C二分查找算法&am…

CSAPP bomb_lab:phase_5

phase_5的汇编代码 0x0000000000401062 <0>: push %rbx0x0000000000401063 <1>: sub $0x20,%rsp0x0000000000401067 <5>: mov %rdi,%rbx0x000000000040106a <8>: mov %fs:0x28,%rax0x0000000000401073 <17>: mov …

hql面试题之字符串使用split分割,并选择其中的一部分字段的问题

版本&#xff1a;20231109 1.题目&#xff1a; 有两张表,a表有id和abstringr两个字段&#xff0c;b表也有id和bstr两个字段&#xff0c;具体如下 A表&#xff1a; 1abc,bcd,cdf2123,456,789 B表: 1acddef2123456 在a表的abstring字段中用‘,’分割&#xff0c;并取出前两…

sqli-labs靶场详解(less32-less37)

宽字节注入 原理在下方 目录 less-32 less-33 less-34 less-35 less-36 less-37 less-32 正常页面 ?id1 下面有提示 获取到了Hint: The Query String you input is escaped as : 1\ ?id1 看来是把参数中的非法字符就加上了转义 从而在数据库中只能把单引号当成普通的字…

Go 从编译到执行

一、Go运行编译简介 Go语言&#xff08;也称为Golang&#xff09;自从2009年由Google发布以来&#xff0c;已成为现代软件开发中不可或缺的一部分。设计者Rob Pike, Ken Thompson和Robert Griesemer致力于解决多核处理器、网络系统和大型代码库所引发的现实世界编程问题。我们…

TA-Lib学习研究笔记——Overlap Studies(二)上

TA-Lib学习研究笔记——Overlap Studies&#xff08;二&#xff09; 1. Overlap Studies 指标 [BBANDS, DEMA, EMA, HT_TRENDLINE, KAMA, MA, MAMA, MAVP, MIDPOINT, MIDPRICE, SAR, SAREXT, SMA, T3, TEMA, TRIMA, WMA]2.数据准备 get_data函数参数&#xff08;代码&#x…

msyql迁移到mongodb

关系型数据库迁移到mongodb的理由 高并发需求&#xff0c;关系型数据库不容易扩展 快速迭代 灵活的json模式 大数据量需求 应用迁移难度&#xff1a; 关系型到关系 oracle-》mysql oracle -》 postgresql 关系到文档- oracle -》 mongodb 需要考虑&#xff1a; 总体架构&#…

阿里云Windows server2016 安装Docker

阿里云Windows server2016 安装Docker 1 软件环境介绍2 下载更新2.1 windowsR 输入sconfig2.2 下载最新版的安装包&#xff0c;安装并重启2.3 下载并安装更新2.4 以管理员方式运行powershell2.5 将Tls修改成二级2.6 安装NuGet服务2.7 安装docker模块2.7 安装 docker包 32.8 查看…

Ajax的使用方法

1,什么是Ajax&#xff1f; Ajax&#xff08;异步Javascript和XML&#xff09;&#xff0c;是指一种创建交互式网页应用的网页开发技术。 2&#xff0c;Ajax的作用 Ajax可以使网页实现异步更新----即在不更新整个页面的情况下实现对某一部分进行更新。 简单来说Ajax就是用于连接…

顶级大厂Quora如何优化数据库性能?

Quora 的流量涉及大量阅读而非写入&#xff0c;一直致力于优化读和数据量而非写。 0 数据库负载的主要部分 读取数据量写入 1 优化读取 1.1 不同类型的读需要不同优化 ① 复杂查询&#xff0c;如连接、聚合等 在查询计数已成为问题的情况下&#xff0c;它们在另一个表中构…