26. 深度学习进阶 - 深度学习的优化方法

在这里插入图片描述

Hi, 你好。我是茶桁。

上一节课中我们预告了,本节课是一个难点,同时也是一个重点,大家要理解清楚。

我们在做机器学习的时候,会用不同的优化方法。

SGD

Alt text

上图中左边就是Batch Gradient Descent,中间是Mini-Batch Gradient Descent, 最右边则是Stochastic Gradient Descent。

我们还是直接上代码写一下来看。首先我们先定义两个随机值,一个x,一个ytrue:

import numpy as np
x = np.random.random(size=(100, 8))
ytrue = torch.from_numpy(np.random.uniform(0, 5, size=(100, 1)))

x是一个1008的随机值,ytrue是1001的随机值,在0到5之间,这100个x对应着这100个ytrue的输入。

然后我们来定义一个Sequential, 在里面按顺序放一个线性函数,一个Sigmoid激活函数,然后再来一个线性函数,别忘了咱们上节课所讲的,要注意x的维度大小。

linear = torch.nn.Linear(in_features=8, out_features=1)
sigmoid = torch.nn.Sigmoid()
linear2 = torch.nn.Linear(in_features=1, out_features=1)

train_x = torch.from_numpy(x)

model = torch.nn.Sequential(linear, sigmoid, linear2).double()

我们先来看一下训练x和ytrue值的大小:

print(model(train_x).shape)
print(ytrue.shape)

---
torch.Size([100, 1])
torch.Size([100, 1])

然后我们就可以来求loss了,先拿到预测值,然后将预测值和真实值一起放进去求值。

loss_fn = torch.nn.MSELoss()
yhat = model(train_x)
loss = loss_fn(yhat, ytrue)
print(loss)

---
36.4703

我们现在可以定义一个optimer, 来尝试进行优化,我们来将之前的所做的循环个100次,在其中我们加上反向传播:

optimer = torch.optim.SGD(model.parameters(), lr=1e-3)

for e in range(100):
    yhat = model(train_x)
    loss = loss_fn(yhat, ytrue)
    loss.backward()
    print(loss)
    optimer.step()

---
tensor(194.9302, dtype=torch.float64, grad_fn=<MseLossBackward0>)
...
tensor(1.9384, dtype=torch.float64, grad_fn=<MseLossBackward0>)

可以看到,loss会一直降低。从194一直降低到了2左右。

在求解loss的时候,我们用到了所有的train_x,那这种方式就叫做Batch gradient Descent,批量梯度下降。

它会对整个数据集计算损失函数相对于模型参数的梯度。梯度是一个矢量,包含了每个参数相对与损失函数的变化率。

这个方法会使用计算得到的梯度来更新模型的参数。更新规则通常是按照一下方式进行:

w t + 1 = w t − η ▽ w t \begin{align*} w_{t+1} = w_t - \eta \triangledown w_t \end{align*} wt+1=wtηwt

w t + 1 w_{t+1} wt+1是模型参数, η \eta η是学习率, ▽ w t \triangledown w_t wt是损失函数相对于参数的梯度。

但是在实际的情况下这个方法可能会有一个问题,比如说,我们在随机x的时候参数不是100,而是10^8,维度还是8维。假如它的维度很大,那么会出现的情况就是把x给加载到模型里面运算的时候,消耗的内存就会非常非常大,所需要的运算空间就非常大。

这也就是这个方法的一个缺点,计算成本非常高,由于需要计算整个训练数据集的梯度,因此在大规模数据集上的计算成本较高。而且可能会卡在局部最小值,难以逃离。说实话,我上面演示的数据也是尝试了几次之后拿到一次满意的,也遇到了在底部震荡的情况。

在这里可以有一个很简单的方法,我们规定每次就取20个:

for e in range(100):
    for b in range(100 // 20):
        batch_index = np.random.choice(range(len(train_x)), size=20)

        yhat = model(train_x[batch_index])
        loss = loss_fn(yhat, ytrue[batch_index])
        loss.backward()
        print(loss)
        optimer.step()

这样做loss也是可以下降的,那这种方法就叫做Mini Batch。

还有一种方法很极端,就是Stochhastic Gradient Descent,就是每次只取一个个数字:

for e in range(100):
    for b in range(100 // 1):
        ...

这种方法很极端,但是可以每次都可以运行下去。那大家就知道,有这三种不同的优化方式。

Alt text

这样的话,我们来看一下,上图中的蓝色,绿色和紫色,分别对应哪种训练方式?

紫色的是Stochastic Gradient Descent,因为它每次只取一个点,所以它的loss变化会很大,随机性会很强。换句话说,这一次取得数据好,可能loss会下降,如果数据取得不好,它的这个抖动会很大。

绿色就是Mini-Batch, 我们刚才20个、20个的输入进去,是有的时候涨,有的时候下降。

最后蓝色的就是Batch Gradient Descent, 因为它x最多,所以下降的最稳定。

但是因为每次x特别多内存,那有可能就满了。内存如果满了,机器就没有时间去运行程序,就会变得特别的慢。

MOMENTUM

我们上面讲到的了这个式子:

w t + 1 = w t − η ▽ w t \begin{align*} w_{t+1} = w_t - \eta \triangledown w_t \end{align*} wt+1=wtηwt

这个是最原始的Grady descent, 我们会发现一个问题,就是本来在等高线上进行梯度下降的时候,它找到的不是最快的下降的那条线,在实际情况中,数据量会很多,数量会很大。比方说做图片什么的,动辄几兆几十兆,如果要再加载几百个这个进去,那就会很慢。这个梯度往往可能会变的抖动会很大。

那有人就想了一个办法去减少抖动。就是我们每一次在计算梯度下降方向的时候,连带着上一次的方向一起考虑,然后取一个比例改变了原本的方向。那这样的话,整个梯度下降的线就会平缓了,抖动也就没有那么大,这个就叫做Momentum, 动量。

v t = γ ⋅ v t − 1 + η ▽ w t w t + 1 = w t − v t \begin{align*} v_t & = \gamma \cdot v_{t-1} + \eta \triangledown w_t \\ w_{t+1} & = w_t - v_t \end{align*} vtwt+1=γvt1+ηwt=wtvt

动量在物理学中就是物体沿某个方向运动的能量。

之前我们每次的wt是直接去减去学习率乘以梯度,现在还考虑了v{t-1}的值,乘上一个gamma,这个值就是我们刚才说的取了一个比例。

Alt text

就像这个图一样,原来是红色,加了动量之后就变成蓝色,可以看到更平稳一些。

RMS-PROP

除了动量法之外呢,还有一个RMS-PROP方法,全称为Root mean square prop。

Alt text

S ∂ l o s s ∂ w = β S ∂ l o s s ∂ w + ( 1 − β ) ∣ ∣ ∂ l o s s ∂ w ∣ ∣ 2 S ∂ l o s s ∂ b = β S ∂ l o s s ∂ b + ( 1 − β ) ∣ ∣ ∂ l o s s ∂ b ∣ ∣ 2 w = w − α ∂ l o s s ∂ w S ∂ l o s s ∂ w b = b − α ∂ l o s s ∂ b S ∂ l o s s ∂ b \begin{align*} S_{\frac{\partial loss}{\partial w}} & = \beta S_{\frac{\partial loss}{\partial w}} + (1 - \beta)||\frac{\partial loss}{\partial w} ||^2 \\ S_{\frac{\partial loss}{\partial b}} & = \beta S_{\frac{\partial loss}{\partial b}} + (1 - \beta)||\frac{\partial loss}{\partial b} ||^2 \\ w & = w - \alpha \frac{\frac{\partial loss}{\partial w}}{\sqrt{S_{\frac{\partial loss}{\partial w}}}} \\ b & = b - \alpha \frac{\frac{\partial loss}{\partial b}}{\sqrt{S_{\frac{\partial loss}{\partial b}}}} \end{align*} SwlossSblosswb=βSwloss+(1β)∣∣wloss2=βSbloss+(1β)∣∣bloss2=wαSwloss wloss=bαSbloss bloss

这个方法看似复杂,其实也是非常简单。这些方法在PyTorch里其实都有包含,我们可以直接调用。我们在这里还是要理解一下它的原理,之后做事的时候也并不需要真的取从头写这些玩意。

在讲它之前,我们再回头来说一下刚刚求解的动量法,动量法其实已经做的比较好了,但是还是有一个问题,它每次的rate是人工定义的。也就是我们上述公式中的 γ \gamma γ, 这个比例是人工定义的,那在RMS-PROP中就写了一个动态的调整方法。

这个动态的调整方法就是我们每一次在进行调整w或者b的时候,都会除以一个根号下的 S ∂ l o s s ∂ w S_{\frac{\partial loss}{\partial w}} Swloss,我们往上看,如果 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss比较大的话,那么 S ∂ l o s s ∂ w S_{\frac{\partial loss}{\partial w}} Swloss也就将会比较大,那放在下面的式子中,根号下,也就是 S ∂ l o s s ∂ w \sqrt{S_{\frac{\partial loss}{\partial w}}} Swloss 在分母上,那么w就会更小,反之则会更大。

所以说,当这一次的梯度很大的时候,这样一个方法就让 ∂ l o s s ∂ w \frac{\partial loss}{\partial w} wloss其实变小了,对b来说也是一样的情况。

也就说,如果上一次的方向变化的很严重,那么这一次就会稍微的收敛一点,就会动态的有个缩放。那么如果上一次变化的很小,那为了加速它,这个值反而就会变大一些。

所以说他是实现了一个动态的学习率的变化,当然它前面还有一个初始值,这个 γ \gamma γ需要人为设置,但是在这个 γ \gamma γ基础上它实现了动态的学习速率的变化。

动态的学习速率考察两个值,一个是前一时刻的变化的快慢,另一个就是它此时此刻变化的快慢。这个就叫做RMS。

ADAM

那我们在这里,其实还有一个方法:ADAM。
V d w = β 1 V d w + ( 1 − β 1 ) d w V d b = β 1 V d b + ( 1 − β 1 ) d b S d w = β 2 S d w + ( 1 − β 2 ) ∣ ∣ d w ∣ ∣ 2 S d b = β 2 S d b + ( 1 − β 2 ) ∣ ∣ d b ∣ ∣ 2 V d w c o r r e c t e d = V d w 1 − β 1 t V d b c o r r e c t e d = V d b 1 − β 1 t S d w c o r r e c t e d = S d w 1 − β 2 t S d b c o r r e c t e d = S d b 1 − β 2 t w = w − α V d b c o r r e c t e d S d w c o r r e c t e d + ε b = b − α V d b c o r r e c t e d S d b c o r r e c t e d + ε \begin{align*} V_{dw} & = \beta_1V_{dw} + (1-\beta_1)dw \\ V_{db} & = \beta_1V_{db} + (1-\beta_1)db \\ S_{dw} & = \beta_2S_{dw} + (1-\beta_2)||dw||^2 \\ S_{db} & = \beta_2S_{db} + (1-\beta_2)||db||^2 \\ & V_{dw}^{corrected} = \frac{V_{dw}}{1-\beta_1^t} \\ & V_{db}^{corrected} = \frac{V_{db}}{1-\beta_1^t} \\ & S_{dw}^{corrected} = \frac{S_{dw}}{1-\beta_2^t} \\ & S_{db}^{corrected} = \frac{S_{db}}{1-\beta_2^t} \\ w & = w - \alpha\frac{V_{db}^{corrected}}{\sqrt{S_{dw}^{corrected}}+\varepsilon} \\ b & = b - \alpha\frac{V_{db}^{corrected}}{\sqrt{S_{db}^{corrected}}+\varepsilon} \\ \end{align*} VdwVdbSdwSdbwb=β1Vdw+(1β1)dw=β1Vdb+(1β1)db=β2Sdw+(1β2)∣∣dw2=β2Sdb+(1β2)∣∣db2Vdwcorrected=1β1tVdwVdbcorrected=1β1tVdbSdwcorrected=1β2tSdwSdbcorrected=1β2tSdb=wαSdwcorrected +εVdbcorrected=bαSdbcorrected +εVdbcorrected

刚刚讲过的RMS特点其实是动态的调整了我们的学习率,之前讲Momentum其实还保持了上一时刻的方向,RMS就没有解决这个问题,RMS把上一时刻的方向给弄没了。

RMS,它的定义其实就没有考虑上次的方向,它只考虑上次变化的大小。而现在提出来这个ADAM,这个ADAM的意思就是Adaptive Momentum, 还记不记得咱们讲随机森林和Adaboost那一节,我们讲过Adaboost就是Adaptive Boosting,这里的Adaptive其实就是一个意思,就是自适应动量,也叫动态变化动量。

ADAM就结合了RMS和动量的两个优点。第一个是他在分母上也加了一个根号下的数,也就做了RMS做的事,然后在分子上还有一个数,这个数就保留了上一时刻的数,比如 V d w c o r r e c t e d V_{dw}^{corrected} Vdwcorrected, 就保留了上一时刻的V,就保留了上一时刻的方向。

所以ADAM既是动态的调整了学习率,又保留了上一时刻的方向。

那除此之外,其实还有一个AdaGrad和L-BFGS方法,不过常用的方法也就是上面详细讲的这几种。

到此为止,我们进阶神经网络的基础知识就都差不多具备了,接下来我们就该来讲解下卷机和序列,比如说LSTM和RNN、CNN的东西。在这些结束之后,我们还会有Attention机制,Transformer机制,YOLO机制,Segmentation机制,还有强化深度学习其实都是基于这些东西。

那我们下节课,就先从RNN来说开去。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/211740.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】Ubuntu添加root用户

在Ubuntu中&#xff0c;默认情况下是禁用了root用户的登录。如果仍然想要启用root用户&#xff0c;并设置root用户的密码&#xff0c;应按照以下步骤进行操作&#xff1a; 一、输入sudo passwd root设置root用户密码 二、切换root用户 sudo -i su root 这两条命令均可却换至…

python+pytest接口自动化(6)-请求参数格式的确定

我们在做接口测试之前&#xff0c;先需要根据接口文档或抓包接口数据&#xff0c;搞清楚被测接口的详细内容&#xff0c;其中就包含请求参数的编码格式&#xff0c;从而使用对应的参数格式发送请求。例如某个接口规定的请求主体的编码方式为 application/json&#xff0c;那么在…

在linux服上部署vue+springboot+nginx项目

一、环境准备 1、安装winscp便于可视化操作linux&#xff1a;winscp安装及关联putty使用_putty.exe没有找到_cherishSpring的博客-CSDN博客 2、安装jdk&#xff1a;linux系统安装jdk-CSDN博客 3、安装mysql&#xff1a;Linux7安装mysql数据库以及navicat远程连接mysql-CSDN博…

K8S客户端二 使用Rancher部署服务

Rancher容器云管理平台 本博客中使用了四台服务器&#xff0c;如下 rancher服务器k8s-masterk8s-worker01k8s-worker02 一、主机硬件说明 序号硬件操作及内核1CPU 4 Memory 4G Disk 100GCentOS72CPU 4 Memory 4G Disk 100GCentOS73CPU 4 Memory 4G Disk 100GCentOS74CPU 4 …

剑指 Offer(第2版)面试题 15:二进制中1的个数

剑指 Offer&#xff08;第2版&#xff09;面试题 15&#xff1a;二进制中1的个数 剑指 Offer&#xff08;第2版&#xff09;面试题 15&#xff1a;二进制中1的个数解法1&#xff1a;位运算解法2&#xff1a;n & (n - 1)相关题目 剑指 Offer&#xff08;第2版&#xff09;面…

Java数据结构之《希尔排序》题目

一、前言&#xff1a; 这是怀化学院的&#xff1a;Java数据结构中的一道难度中等的一道编程题(此方法为博主自己研究&#xff0c;问题基本解决&#xff0c;若有bug欢迎下方评论提出意见&#xff0c;我会第一时间改进代码&#xff0c;谢谢&#xff01;) 后面其他编程题只要我写完…

最小生成树算法

文章目录 最小生成树概述 P r i m Prim Prim 算法 - 稠密图 - O ( n 2 ) O(n^2) O(n2)思路概述时间复杂度分析AcWing 858. Prim算法求最小生成树CODE K r u s k a l Kruskal Kruskal 算法 - 稀疏图 - O ( m l o g m ) O(mlogm) O(mlogm)思路解析时间复杂度分析AcWing 859. Kr…

Raft 算法

Raft 算法 1 背景 当今的数据中心和应用程序在高度动态的环境中运行&#xff0c;为了应对高度动态的环境&#xff0c;它们通过额外的服务器进行横向扩展&#xff0c;并且根据需求进行扩展和收缩。同时&#xff0c;服务器和网络故障也很常见。 因此&#xff0c;系统必须在正常…

Flask使用线程异步执行耗时任务

1 问题说明 1.1 任务简述 在开发Flask应用中一定会遇到执行耗时任务&#xff0c;但是Flask是轻量级的同步框架&#xff0c;即在单个请求时服务会阻被塞&#xff0c;直到任务完成&#xff08;注意&#xff1a;当前请求被阻塞不会影响到其他请求&#xff09;。 解决异步问题有…

已解决AttributeError: module ‘gradio‘ has no attribute ‘outputs‘

问题描述 Traceback (most recent call last): File "/media/visionx/monica/project/ResShift/app.py", line 118, in <module> gr.outputs.File(label"Download the output")AttributeError: module gradio has no attribute outputs 解决办…

vscode 调试jlink

文章目录 软件使用说明1、启动GDB Server2、下载gdb3、vscode配置4、调试 软件 vscodejlink - (JLinkGDBServer.exe)gcc-arm-none-eabi-10-2020-q4-major (arm-none-eabi-gdb.exe) 使用说明 vscode通过TCP端口调用JLinkGDBServer通过jlink连接和操作设备&#xff0c;vscode不…

创建腾讯云存储桶---上传图片--使用cos-sdk完成上传

创建腾讯云存储桶—上传图片 注册腾讯云账号https://cloud.tencent.com/login 登录成功&#xff0c;选择右边的控制台 点击云产品&#xff0c;选择对象存储 创建存储桶 填写名称&#xff0c;选择公有读&#xff0c;私有写一直下一步&#xff0c;到创建 选择安全管理&#…

无人机助力电力设备螺母缺销智能检测识别,python基于YOLOv7开发构建电力设备螺母缺销小目标检测识别系统

传统作业场景下电力设备的运维和维护都是人工来完成的&#xff0c;随着现代技术科技手段的不断发展&#xff0c;基于无人机航拍飞行的自动智能化电力设备问题检测成为了一种可行的手段&#xff0c;本文的核心内容就是基于YOLOv7来开发构建电力设备螺母缺销检测识别系统&#xf…

智能优化算法应用:基于黄金正弦算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于黄金正弦算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于黄金正弦算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.黄金正弦算法4.实验参数设定5.算法结果6.参考…

熬夜会秃头——beta冲刺Day3

这个作业属于哪个课程2301-计算机学院-软件工程社区-CSDN社区云这个作业要求在哪里团队作业—beta冲刺事后诸葛亮-CSDN社区这个作业的目标记录beta冲刺Day3团队名称熬夜会秃头团队置顶集合随笔链接熬夜会秃头——Beta冲刺置顶随笔-CSDN社区 目录 一、团队成员会议总结 1、成员…

【UE】UEC++获取屏幕颜色GetPixelFromCursorPosition()

目录 【UE】UE C 获取屏幕颜色GetPixelFromCursorPosition() 一、函数声明与定义 二、函数的调用 三、运行结果 【UE】UE C 获取屏幕颜色GetPixelFromCursorPosition() 一、函数声明与定义 创建一个蓝图方法库方法 GetPixelFromCursorPosition()&#xff0c;并给他指定UF…

使用 STM32 微控制器读取光电传感器数据的实现方法

本文介绍了如何使用 STM32 微控制器读取光电传感器数据的实现方法。通过配置和使用STM32的GPIO和ADC功能&#xff0c;可以实时读取光电传感器的模拟信号并进行数字化处理。本文将介绍硬件连接和配置&#xff0c;以及示例代码&#xff0c;帮助开发者完成光电传感器数据的读取。 …

算法工程师面试八股(搜广推方向)

文章目录 机器学习线性和逻辑回归模型逻辑回归二分类和多分类的损失函数二分类为什么用交叉熵损失而不用MSE损失&#xff1f;偏差与方差Layer Normalization 和 Batch NormalizationSVM数据不均衡特征选择排序模型树模型进行特征工程的原因GBDTLR和GBDTRF和GBDTXGBoost二阶泰勒…

MATLAB R2022b 安装

文章用于学习记录 文章目录 前言下载解压安装包总结 前言 下载解压安装包 MATLAB R2022b —— A9z3 装载(Mount) MATLAB_R2022b_Win64.iso 打开装载好的 DVD 驱动器并找到 setup&#xff0c;单击鼠标右键以管理员身份运行&#xff1a; 点击窗口右上角的 高级选项下拉框&#…

Docker 镜像及其命令

文章目录 镜像Docker 镜像加载原理联合文件系统bootfs和rootfs镜像分层 镜像分层的优势容器层常用命令 镜像 镜像是一种轻量级、可执行的独立软件包&#xff0c;它包含运行某个软件所需的所有内容&#xff0c;我们把应用程序和配置依赖打包好形成一个可交付的运行环境&#xff…