【深度学习笔记】深度学习训练技巧

深度学习训练技巧

1 优化器

  1. 随机梯度下降及动量

    • 随机梯度下降算法对每批数据 ( X ( i ) , t ( i ) ) (X^{(i)},t^{(i)}) (X(i),t(i)) 进行优化
      g = ∇ θ J ( θ ; x ( i ) , t ( i ) ) θ = θ − η g g=\nabla_\theta J(\theta;x^{(i)},t^{(i)})\\ \theta = \theta -\eta g g=θJ(θ;x(i),t(i))θ=θηg
      随机梯度下降算法的基本思想是,在每次迭代中,随机选择一个样本 i i i,计算该样本的梯度 g = ∇ θ J ( θ ; x ( i ) , t ( i ) ) g=\nabla_\theta J(\theta;x^{(i)},t^{(i)}) g=θJ(θ;x(i),t(i)),然后按照梯度的反方向更新参数 θ \theta θ,即 θ = θ − η g \theta = \theta -\eta g θ=θηg,其中 η \eta η 是学习率,控制更新的步长。

    • 基于动量的更新过程

      为了改善随机梯度下降算法的收敛性,可以引入动量(momentum)的概念,即在更新参数时,考虑之前的更新方向和幅度,使得参数沿着一个平滑的轨迹移动。
      v = γ v − η g θ = θ + v v=\gamma v-\eta g\\ \theta =\theta+v v=γvηgθ=θ+v
      其中 v v v 是动量变量,初始为零向量, γ \gamma γ 是动量系数,控制之前更新的影响程度,一般取 0.9 0.9 0.9 左右的值。

  2. Adagrad

    Adagrad 是一种自适应学习率的梯度下降算法,它可以根据不同的参数调整不同的学习率,使得目标函数更快地收敛。

    • 梯度记为 g = ∇ θ J ( θ ; x ( i ) , t ( i ) ) g=\nabla_\theta J(\theta;x^{(i)},t^{(i)}) g=θJ(θ;x(i),t(i))

    • 更新过程

      对于每个参数 θ i \theta_i θi,维护一个累积变量 c i c_i ci,初始为 0,然后每次将该参数的梯度平方 g i 2 g_i^2 gi2 累加到 c i c_i ci 上,即 c i = c i + g i 2 c_i=c_i+g_i^2 ci=ci+gi2
      c i = c i + g i 2 θ i = θ i − η c i + ϵ g i c_i=c_i+g_i^2\\ \theta_i=\theta_i-\frac{\eta}{\sqrt{c_i+\epsilon}}g_i ci=ci+gi2θi=θici+ϵ ηgi
      在更新该参数时,使用一个自适应的学习率 η c i + ϵ \frac{\eta}{\sqrt{c_i+\epsilon}} ci+ϵ η,其中 η \eta η 是全局学习率, ϵ \epsilon ϵ 是一个小常数,用于防止除零错误。

    • 工作原理

      如果一个参数的梯度一直很大,那么它的 c i c_i ci 也会很大,从而降低它的学习率,防止过度更新;

      如果一个参数的梯度一直很小,那么它的 c i c_i ci 也会很小,从而增加它的学习率,加快更新速度。这就实现了自适应的学习率调整,有利于加速收敛和避免震荡。

    • 优点

      计算简单,不需要手动调整学习率,适合处理稀疏数据和特征。

    • 缺点

      累积变量 c i c_i ci 会随着迭代次数增加而不断增大,导致学习率过小,甚至接近于零,使得后期训练缓慢或停滞。

  3. RMSProp

    RMSProp算法在Adagrad的基础上提出改进,以解决学习率单调下降的问题

    • 基本思想

      引入一个遗忘因子 γ \gamma γ 。对于每个参数 θ \theta θ,维护一个累积变量 c c c,初始为 0,然后每次将该参数的梯度平方 g 2 g^2 g2 乘以一个衰减系数 ( 1 − γ ) (1-\gamma) (1γ),再加到 c c c 上,即 c = γ c + ( 1 − γ ) g 2 c=\gamma c+(1-\gamma)g^2 c=γc+(1γ)g2
      c = γ c + ( 1 − γ ) g 2 θ = θ − η c + ϵ g c=\gamma c+(1-\gamma)g^2\\ \theta=\theta-\frac{\eta}{\sqrt{c+\epsilon}}g c=γc+(1γ)g2θ=θc+ϵ ηg
      γ \gamma γ 是一个介于 0 和 1 之间的常数,通常为0.9、0.99、0.999,用于控制历史信息的影响程度。

    • 与Adagrad对比

      c c c 不会随着迭代次数增加而无限增大,而是保持在一个合理的范围内,从而使得学习率下降得更加平稳。

  4. Adam

    Adam 算法是一种自适应学习率的梯度下降算法,它结合了 Momentum 和 RMSProp 的优点

    • 简单形式
      KaTeX parse error: Expected 'EOF', got '&' at position 26: …m+(1-\beta_1)g &̲ \text {(积攒历史梯度…
      其中 β 1 \beta_1 β1 β 2 \beta_2 β2 是两个介于 0 和 1 之间的常数,用于控制历史信息的影响程度。一般取 β 1 = 0.9 \beta_1=0.9 β1=0.9 β 2 = 0.999 \beta_2=0.999 β2=0.999 ϵ = 1 0 − 8 \epsilon=10^{-8} ϵ=108

    • 完整形式
      m = β 1 m + ( 1 − β 1 ) g   , m t = m 1 − β 1 t c = β 2 c + ( 1 − β 2 ) g 2   , c t = c 1 − β 2 t θ = θ − η c t + ϵ m t m=\beta_1m+(1-\beta_1)g\ ,m_t=\frac{m}{1-\beta_1^t} \\ c=\beta_2c+(1-\beta_2)g^2\ ,c_t=\frac{c}{1-\beta_2^t} \\ \theta=\theta-\frac{\eta}{\sqrt{c_t+\epsilon}}m_t m=β1m+(1β1)g ,mt=1β1tmc=β2c+(1β2)g2 ,ct=1β2tcθ=θct+ϵ ηmt
      为了消除偏差,还需要对 m m m c c c 进行偏差修正,即除以 1 − β 1 t 1-\beta_1^t 1β1t 1 − β 2 t 1-\beta_2^t 1β2t,其中 t t t 是迭代次数。

    Adam 算法可以利用一阶矩和二阶矩的信息,实现自适应的学习率调整,使得参数在梯度方向上加速,而在垂直梯度方向上减速,从而避免参数在最优值附近的震荡,加快收敛速度。

关于优化器的更多细节:http://cs231n.github.io/neural-networks-3

2 处理过拟合

  1. 过拟合

    • 定义:对训练集拟合得很好,但在验证集表现较差

      神经网络 通常含有大量参数 (数百万甚至数十亿), 容易过拟合

    • 处理策略:参数正则化、早停、随机失活、数据增强

  2. 早停

    image-20240218152006374

    当发现训练损失逐渐下降,但验证集损失逐渐上升时,及时停止优化

  3. 随机失活

    • 训练过程

      在训练迭代过程中,以 p p p(通常为0.5)的概率随机舍弃掉每个隐含层神经元(输出置零

      image-20240218152407979

      这些被置零的输出,将用于在反向传播中计算梯度

    • 优点:

      • 一个隐含层神经元不能依赖于其它存在的神经元,因此可以防止神经元出现复杂的相互协同(co-adaptations)

      • 相当于在合理的时间内训练了大量不同的网络,并将其结果平均

    • 测试过程

      使用"平均网络(mean network)”,包含所有隐含层神经元

      image-20240218152916977

      需要调整神经元输出的权重,用来弥补训练中只有一部分被激活的现象

      • p=0.5时,将权重减半
      • p=0.1,在权重上乘1-p,即0.9

      实践中,p在低层设得较小,例如0.2,但在高层设得更大,例如0.5

      这样得到的结果与在大量网络上做平均得到的结果类似

  4. 数据增强

    数据增强(Data Augmentation)是一种用于优化深度学习模型的方法,它可以通过从现有数据生成新的训练数据来扩展原数据集,从而提高模型的泛化能力和防止过拟合。

    数据增强的工具可以对数据进行各种操作和转换,如旋转、缩放、裁剪、翻转、调整亮度、对比度、颜色等,以生成新的、多样的、有代表性的样本。

    • 随机翻转

      通常只用左右翻转

      image-20240218153939733

    • 随机缩放和裁剪

      image-20240218154027438

      • 将测试图像缩放到模型要求的输入大小
      • 剪裁一个区域(通常是中心区域)并输入到模型
      • 剪裁多个并输入到模型,对输出作平均
    • 随机擦除

      image-20240218154118949

      测试时可以输入整张图像

3 批归一化

  1. 内部协变量偏移(Internal covariate shift)

    当使用SGD时,不同迭代次数时输入到神经网络的数据不同,可能导致某些层输出的分布在不同迭代次数时不同。

    ICS:训练中,深度神经网络中间节点分布的变化。可能增加优化难度

  2. 通过归一化来减少ICS

    对每个标量形式的特征单独进行归一化,使其均值为0,方差为1。

    对于d维激活 x = ( x 1 , … , x d ) x=(x_1,…,x_d) x=(x1,,xd)​,作如下归一化
    x ^ i = x i − E [ x i ] Var [ x i ] \hat x_i=\frac{x_i-E[x_i]}{\sqrt{\text{Var}[x_i]}} x^i=Var[xi] xiE[xi]
    保持该层的表达能力
    y i = γ i x ^ i + β i y_i=\gamma_i\hat x_i+\beta_i yi=γix^i+βi
    γ i = Var [ x i ] \gamma_i=\sqrt{\text{Var}[x_i]} γi=Var[xi] β i = E [ x i ] \beta_i=E[x_i] βi=E[xi] ,恢复到原来的激活值。

  3. 批归一化BN(Batch Normalization)

    批归一化(Batch Normalization,BatchNorm)是一种用于优化深度神经网络的方法,它可以通过对每一层的输入数据进行标准化处理,使其均值为0,方差为1,从而减少每一层输入数据分布的变化,加快网络的收敛速度,提高网络的泛化能力和鲁棒性。

    • 基本思想:

      在每一层的输入数据上进行如下的变换:
      x ~ i = x i − μ B σ B 2 + ϵ (归一化) y i = γ x ~ i + β (尺度变换和偏移) \tilde{x}_i=\frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}} \quad \text{(归一化)}\\ y_i=\gamma\tilde{x}_i+\beta \quad \text{(尺度变换和偏移)} x~i=σB2+ϵ xiμB(归一化)yi=γx~i+β(尺度变换和偏移)
      其中, x i x_i xi 是第 i i i 个神经元的输入, μ B \mu_B μB σ B 2 \sigma_B^2 σB2 是该层输入数据的均值和方差, ϵ \epsilon ϵ 是一个小常数,用于防止除零错误, x ~ i \tilde{x}_i x~i 是归一化后的输入, γ \gamma γ β \beta β 是可学习的参数,用于调整数据的尺度和偏移, y i y_i yi 是最终的输出。

    • 优点

      • 可以选择较大的初始学习率,加快网络的收敛。
      • 可以减少正则化参数的选择问题,如 Dropout、L2 正则项等。
      • 可以把训练数据彻底打乱,防止每批训练的时候,某一个样本经常被挑选到。
      • 可以缓解梯度消失或梯度爆炸的问题,使得网络可以使用更深的结构和更多的非线性激活函数。
    • 缺点

      • 增加了网络的计算量和内存消耗。
      • 对于小批量的数据,可能会导致不稳定的结果。
      • 对于某些任务,可能会降低网络的表达能力或性能。
  4. 其他归一化技巧

    image-20240218161027678

    • 批归一化(Batch norm)用于CNN
    • 层归一化(Layer norm)用于RNN
    • 实例归一化(Instance norm)用于图像风格化
    • 群归一化(Group norm)用于CNN处理批较小的情况

4 超参数选取

  1. 超参数

    超参数: 控制算法行为,且不会被算法本身所更新,通常决定了一个模型的能力

    对于一个深度学习模型, 超参数包括

    • 层数,每层的神经元数目
    • 正则化系数
    • 学习率
    • 参数衰减率(Weight decay rate)
    • 动量项(Momentum rate)
  2. 如何选择深度学习模型的架构

    • 熟悉数据集
    • 与之前见过的数据/任务比较
      • 样本数目
      • 图像大小, 视频长度, 输入复杂度…
    • 最好从在类似数据集或任务上表现良好的模型开始
  3. 如何选择其它超参数

    由Fei-Fei Li & Justin Johnson & Serena Yeung(CS231n 2019,Stanford University)给出的建议

    • 第1步:观察初始损失

      • 确保损失的计算是正确的
      • 将权重衰减设为零
    • 第2步:在一组小样本上过拟合

      • 在一组少量样本的训练集上训练,尝试达到100%训练准确率
      • 如果训练损失没有下降,说明是不好的初始化,学习率太小,模型太小
      • 如果训练损失变成Inf或NaN,说明是不好的初始化,学习率太大
    • 第3步:找到使损失下降的学习率

      • 在全部数据上训练模型,并找到使损失值能够快速下降的学习率

        当损失值下降较慢时,将学习率缩小10倍

        使用较小的参数衰减

    • 第4步:粗粒度改变学习率,训练1-5轮

      • 在上一步的基础上,尝试一些比较接近的学习率和衰减率
      • 常用的参数衰减率:1e-4,1e-5,0
    • 第5步:细粒度改变学习率,训练更长时间

      • 使用上一步找到的最好的学习率,并训练更长时间 (10-20 轮),期间不改变学习率
    • 第6步:观察损失曲线

      • 训练损失通常用滑动平均绘制,否则会有很多点聚集在一起

        image-20240218161857392

        有问题的损失曲线:

        image-20240218162126133

延伸阅读:

  • 各种深度学习模型的优化方法:http://cs231n.github.io/neural-networks-3/
  • 关于不同正则化方法的讨论:https://www.cnblogs.com/LXP-
  • Never/p/11566064.html Li,Chen,Hu,Yang,2019
    Understanding the Disharmony Between Dropout and Batch Normalization by Variance Shift CVPR

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/411134.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

纯国产轻量化数字孪生:智慧城市、智慧工厂、智慧校园、智慧社区。。。

AMRT 3D数字孪生引擎介绍 AMRT3D引擎是一款融合了眸瑞科技的AMRT格式与轻量化处理技术为基础,以降本增效为目标,支持多端发布的一站式纯国产自研的CS架构项目开发引擎。 引擎包括场景搭建、UI拼搭、零代码交互事件、光影特效组件、GIS/BIM组件、实时数据…

【JavaEE】_前端使用GET请求的queryString向后端传参

目录 1. GET请求的query string 2. 关于query string的urlencode 1. GET请求的query string 1. 在HttpServletRequest请求中,getParameter方法用于在服务器这边获取到请求中的参数,主要在query string中; query string中的键值对都是程序…

【黑马程序员】STL之vector常用操作

文章目录 vectorvector 基本概念功能vector与普通数组区别vector动态扩展 在这里插入图片描述函数原型代码示例 vector容器的赋值操作函数原型代码示例 vector容量和大小函数原型代码示例 vector插入删除函数原型代码示例 vector容器数据存取函数原型代码示例 swap使用代码示例…

当面试问你接口测试时,不要再说不会了!

很多人会谈论接口测试。到底什么是接口测试?如何进行接口测试?这篇文章会帮到你。 01 前端和后端 在谈论接口测试之前,让我们先明确前端和后端这两个概念。 前端是我们在网页或移动应用程序中看到的页面,它由 HTML 和 CSS 编写…

【Golang数组】

数组 数组的引入内存分析数组遍历数组的初始化方式注意事项二维数组二维数组的遍历 数组的引入 【1】练习引入: package main import "fmt" func main(){//实现的功能:给出五个学生的成绩,求出成绩的总和,平均数&…

邮件系统国产化,U-Mail助推企业数字化建设

在当今数字化时代,企业管理和办公效率的提升已成为企业发展的关键。随着信息技术的迅速发展,邮件系统成为许多企业提高办公效率和管理水平的重要工具。然而,长期以来,国内企业在邮件系统方面主要依赖于国外产品,这不仅…

Another Redis Desktop Manager工具连接集群

背景:使用Another Redis Desktop Manager连接redsi集群 win10安装 使用 下载 某盘: 链接:https://pan.baidu.com/s/1dg9kPm9Av8-bbpDfDg9DsA 提取码:t1sm 使用

【前端素材】推荐优质后台管理系统Skydash平台模板(附源码)

一、需求分析 后台管理系统(或称作管理后台、管理系统、后台管理平台)是一种专门用于管理网站、应用程序或系统后台运营的软件系统。它通常由一系列功能模块组成,为管理员提供了管理、监控和控制网站或应用程序的各个方面的工具和界面。以下…

防御保护----内容安全

八.内容安全--------------------------。 IAE引擎: IAE引擎里面的技术:DFI和DPI技术--- 深度检测技术 DPI --- 深度包检测技术--- 主要针对完整的数据包(数据包分片,分段需要重组),之后对 数据包的内容进行…

C/C++有序数组中插入元素

一、不利用指针 代码&#xff1a; int i; void insert(int ,int , int ); int main() {int a[100];int n, m;cout<<"输入数组元素个数\n";cin >> n;cout << "输入数组元素\n";for (i 0; i < n; i) {cin >> a[i];}cout <&…

Less预处理器教程

学习源码可以看我的个人前端学习笔记 (github.com):qdxzw/frontlearningNotes 觉得有帮助的同学&#xff0c;可以点心心支持一下哈 一、Less介绍 less官方文档 lesscss.org/ less中文文档 less.bootcss.com/ less是一种css预处理器&#xff0c;它扩展了css语言&#xff0c…

【Java程序设计】【C00267】基于Springboot的在线考试系统(有论文)

基于Springboot的在线考试系统&#xff08;有论文&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 本系统是基于Springboot的在线考试系统&#xff1b;本系统主要分为管理员、教师和学生三种角色&#xff1b; 管理员登录系统后&#xff0c;可以对首页&#x…

广联达协同办公系统GetAllData接口存在敏感信息泄露漏洞 附POC软件

@[toc] 广联达协同办公系统GetAllData接口存在敏感信息泄露漏洞 附POC软件 免责声明:请勿利用文章内的相关技术从事非法测试,由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后果及损失,均由使用者本人负责,所产生的一切不良后果与文章作者无关。该文…

LeetCode69. x 的平方根(C++)

LeetCode69. x 的平方根 题目链接代码 题目链接 https://leetcode.cn/problems/sqrtx/description/ 代码 class Solution { public:int mySqrt(int x) {int right x, left 0, ans -1;while(left < right){long long mid left (right - left) / 2;if(mid * mid <…

Python算法题集_全排列

Python算法题集_全排列 题46&#xff1a;全排列1. 示例说明2. 题目解析- 题意分解- 优化思路- 测量工具 3. 代码展开1) 标准求解【标记数组递归】2) 改进版一【指针递归】3) 改进版二【高效迭代模块】4) 改进版三【高效迭代模块极简代码】 4. 最优算法5. 相关资源 本文为Python…

猫头虎分享已解决Bug || Vue中的TypeError: Cannot read property ‘name‘ of undefined 错误

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

<Elon Musk>里面的思考

从Elon Musk身上学到的三个重要人生课程 Introduction 在这篇博文中&#xff0c;我将分享我从Elon Musk身上学到的三个重要人生课程。结合了一些个人的经历&#xff0c;以及视频内容中与三位朋友的交流&#xff0c;希望能给大家带来一些启发和帮助。 第一课&#xff1a;如何…

基于springboot+vue的二手图书交易平台(源码+论文)

文章目录 目录 文章目录 前言 一、功能设计 二、功能实现 前台系统功能模块分为 后台系统功能模块分为 三、库表设计 四、论文 前言 在互联网上所有产品的分类信息中&#xff0c;电子类的产品信息无疑是最丰富的&#xff0c;一大批电子资讯类网站从中国互联网诞生初期就开始为…

算法--贪心

这里写目录标题 区间问题区间选点引入算法思想例题代码 最大不相交区间的数量算法思想例题代码 一级目录二级目录二级目录二级目录 一级目录二级目录二级目录二级目录 区间问题 区间选点 引入 区间问题会给定几个区间&#xff0c;之后要求我们在数轴上选取尽量少的点&#xf…

【MySQL】表的约束 -- 详解

表中一定要有各种约束&#xff0c;通过约束让我们在未来插入数据库表中的数据是符合预期的。约束本质是通过技术手段倒逼程序员插入正确的数据&#xff0c;反过来站在 MySQL 的角度&#xff0c;凡是插入进来的数据都是符合数据约束的。约束的最终目标&#xff1a;保证数据的完整…