简易机器学习笔记(六)不同优化算法器

前言

我们之前不是说了有关梯度下降公式的事嘛,就是那个
在这里插入图片描述

这样梯度下降公式涉及两个问题,一是梯度下降的策略,二是涉及到参数的选择,如果我们选择固定步长的时候,就会发现我们求的值一直在最小值左右震荡,很难选择到我们期望的值。

假设上图中,x0为我们期望的极小值,yB = xA - yA’xA的时候,xB实际上蹿到了极小值的右侧去了,当我们yC = xB - yB’xB的时候,我们求的yC又跑到极小值的左边去了,反正就是一直会在极小值x0附近震荡。

这时候我们有几种常见的优化算法,这也是我们在之前的代码中看到的SGD(随机梯度下降)优化器的由来。

来自Paddlepaddle的官方文档,图示如下:

  1. SGD(随机梯度下降)
    在这里插入图片描述
  2. Momentum
    引入物理“动量”的概念,累积速度,减少震荡,使参数更新的方向更稳定。

在这里插入图片描述

每个批次的数据含有抽样误差,导致梯度更新的方向波动较大。如果我们引入物理动量的概念,给梯度下降的过程加入一定的“惯性”累积,就可以减少更新路径上的震荡,即每次更新的梯度由“历史多次梯度的累积方向”和“当次梯度”加权相加得到。历史多次梯度的累积方向往往是从全局视角更正确的方向,这与“惯性”的物理概念很像,也是为何其起名为“Momentum”的原因。类似不同品牌和材质的篮球有一定的重量差别,街头篮球队中的投手(擅长中远距离投篮)喜欢稍重篮球的比例较高。一个很重要的原因是,重的篮球惯性大,更不容易受到手势的小幅变形或风吹的影响。

  1. AdaGrad
    根据不同参数距离最优解的远近,动态调整学习率。学习率逐渐下降,依据各参数变化大小调整学习率。

在这里插入图片描述
通过调整学习率的实验可以发现:当某个参数的现值距离最优解较远时(表现为梯度的绝对值较大),我们期望参数更新的步长大一些,以便更快收敛到最优解。当某个参数的现值距离最优解较近时(表现为梯度的绝对值较小),我们期望参数的更新步长小一些,以便更精细的逼近最优解。类似于打高尔夫球,专业运动员第一杆开球时,通常会大力打一个远球,让球尽量落在洞口附近。当第二杆面对离洞口较近的球时,他会更轻柔而细致的推杆,避免将球打飞。与此类似,参数更新的步长应该随着优化过程逐渐减少,减少的程度与当前梯度的大小有关。根据这个思想编写的优化算法称为“AdaGrad”,Ada是Adaptive的缩写,表示“适应环境而变化”的意思。RMSProp是在AdaGrad基础上的改进,学习率随着梯度变化而适应,解决AdaGrad学习率急剧下降的问题。

  1. Adam
    在这里插入图片描述

由于动量和自适应学习率两个优化思路是正交的,因此可以将两个思路结合起来,这就是当前广泛应用的算法。

很显然,除开我们常用的随机梯度下降方法,每种优化的方式都有自己的优势。这主要看目标场景和数据集类型和数值的好坏。

实际代码

实际上在开发中,在Paddlepaddle库只需要修改优化器的选型即可,库本身已经做好了这些方案的分装

#仅优化算法的设置有所差别
def train(model):
    model.train()
    #调用加载数据的函数
    train_loader = load_data('train')
    
    #四种优化算法的设置方案,可以逐一尝试效果
    opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9, parameters=model.parameters())
    # opt = paddle.optimizer.Adagrad(learning_rate=0.01, parameters=model.parameters())
    # opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())
    
    EPOCH_NUM = 3
    for epoch_id in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            #准备数据
            images, labels = data
            images = paddle.to_tensor(images)
            labels = paddle.to_tensor(labels)
            
            #前向计算的过程
            predicts = model(images)
            
            #计算损失,取一个批次样本损失的平均值
            loss = F.cross_entropy(predicts, labels)
            avg_loss = paddle.mean(loss)
            
            #每训练了100批次的数据,打印下当前Loss的情况
            if batch_id % 200 == 0:
                print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
            
            #后向传播,更新参数的过程
            avg_loss.backward()
            # 最小化loss,更新参数
            opt.step()
            # 清除梯度
            opt.clear_grad()

    #保存模型参数
    paddle.save(model.state_dict(), 'mnist.pdparams')
    
#创建模型    
model = MNIST()
#启动训练过程
train(model)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/286768.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

java 开源中文的繁简体转换工具 opencc4j-01-overview

Opencc4j Opencc4j 支持中文繁简体转换,考虑到词组级别。 拓展阅读 pinyin 汉字转拼音 pinyin2hanzi 拼音转汉字 segment 高性能中文分词 opencc4j 中文繁简体转换 nlp-hanzi-similar 汉字相似度 word-checker 拼写检测 sensitive-word 敏感词 Features 特…

视频融合云平台/智慧监控平台EassyCVR告警警告出错是什么原因?该如何解决?

视频集中存储/云存储/视频监控管理平台EasyCVR能在复杂的网络环境中,将分散的各类视频资源进行统一汇聚、整合、集中管理,实现视频资源的鉴权管理、按需调阅、全网分发、智能分析等。AI智能/大数据视频分析EasyCVR平台已经广泛应用在工地、工厂、园区、楼…

快速打通 Vue 3(一):基本介绍与组合式 API

很激动进入了 Vue 3 的学习,作为一个已经上线了三年多的框架,很多项目都开始使用 Vue 3 来编写了 这一组文章主要聚焦于 Vue 3 的新技术和新特性 如果想要学习基础的 Vue 语法可以看我专栏中的其他博客 Vue(一):Vue 入…

数据库攻防学习之Redis

Redis 0x01 redis学习 在渗透测试面试或者网络安全面试中可能会常问redis未授权等一些知识,那么什么是redis?redis就是个数据库,常见端口为6379,常见漏洞为未授权访问。 0x02 环境搭建 这里可以自己搭建一个redis环境&#xf…

VS2022 Android NativeActivity 开发指南

几年前最初使用VS时,记得是有Android NativeActivity的,今天更新到了2022最新版,发现找不到这个创建选项。 然后确保安装了C 跨平台开发工具后,开始排查原因。 Visual Studio 2022 中没有“本机活动应用程序” - android - SO中…

AndroidStudio导入程序、项目(教程)

目录 1. 首先解压压缩包,转为文件夹 2. 打开解压好的项目文件夹,删除.gradle和.idea这两个文件 3. 修改bulid.gradle文件,将gradle的版本型号改成自己的 (1) 传统结构 (2) 简洁结构 4. 打开android stdio软件,导入已经修改好…

STM32 学习(二)GPIO

目录 一、GPIO 简介 1.1 GPIO 基本结构 1.2 GPIO 位结构 1.3 GPIO 工作模式 二、GPIO 输出 三、GPIO 输入 1.1 传感器模块 1.2 开关 一、GPIO 简介 GPIO(General Purpose Input Output)即通用输入输出口。 1.1 GPIO 基本结构 如下图&#xff0…

Docker(八)Python+旧版本chrome+selenium+oss2+fastapi镜像制作

目录 一、背景二、能力三、核心流程图四、制作镜像1.资源清单2.Dockerfile3.制作镜像 五、启动测试 一、背景 近几年我们线下的创业团队已从零到一开发过好几个小程序项目,都是和体育相关。其中生成海报分享图片好像都是不可或缺的功能。之前的项目老板给的时间都比…

STM32MP157D-DK1 Qt程序交叉编译与运行测试

上篇文章介绍了STM32MP157D-DK1开发板Qt镜像的构建,通过在Ubuntu中重新编译带有Qt功能的系统来实现。 本篇在上篇的基础上,继续搭建Qt的交叉编译环境,实现Qt程序在Ubuntu中编译,在STM32MP157板子中运行。 1 编译安装SDK 在上篇…

dyld: Library not loaded: /usr/lib/swift/libswiftCoreGraphics.dylib

更新Xcode14后低版本iPhone调试报错 dyld: Library not loaded: /usr/lib/swift/libswiftCoreGraphics.dylib Referenced from: /var/containers/Bundle/Application/…/….app/… Reason: image not found 这是缺少libswiftCoreGraphics库 直接导入libswiftCoreGraphics库即…

数据库云平台新数科技完成B轮融资,打造全链路智能化数据库云平台

数据库云平台软件厂商「北京新数科技有限公司」(以下简称「新数科技」)已于2023年完成B1轮和B2轮融资,分别由渤海创富和彬复资本投资;义柏资本担任本轮融资独家财务顾问。 新数科技成立于2014年,当前产品矩阵包括数据库…

经典卷积神经网络-VGGNet

经典卷积神经网络-VGGNet 一、背景介绍 VGG是Oxford的Visual Geometry Group的组提出的。该网络是在ILSVRC 2014上的相关工作,主要工作是证明了增加网络的深度能够在一定程度上影响网络最终的性能。VGG有两种结构,分别是VGG16和VGG19,两者并…

@EnableXXX注解+@Import轻松实现SpringBoot的模块装配

文章目录 前言原生手动装配模块装配概述模块装配的四种方式准备工作声明自定义注解 导入普通类导入配置类导入ImportSelector导入ImportBeanDefinitionRegistrar 总结TODO后续--条件装配 前言 最早我们开始学习或接触过 SSH 或者 SSM 的框架整合,大家应该还记得那些…

RabbitMQ基础知识

一.什么是RabbitMQ RabbitMQ是一个开源的、高性能的消息队列系统,用于在应用程序之间实现异步通信。它实现了AMQP(Advanced Message Queuing Protocol)协议,可以在分布式系统中传递和存储消息。 消息队列是一种将消息发送者和接收…

微服务雪崩问题及解决方案

雪崩问题 微服务中,服务间调用关系错综复杂,一个微服务往往依赖于多个其它微服务。 微服务之间相互调用,因为调用链中的一个服务故障,引起整个链路都无法访问的情况。 如果服务提供者A发生了故障,当前的应用的部分业务…

剑指offer

1、排序算法 0、排序算法分类 1、直接插入排序 基本思想 直接插入排序的基本思想是:将数组中的所有元素依次跟前面已经排好的元素相比较,如果选择的元素比已排序的 元素小,则交换,直到全部元素都比较过为止。 算法描述 1、从…

ubuntu系统没有网络图标的解决办法

参考文章:https://blog.csdn.net/qq_56922632/article/details/132309643 1. 执行关闭网络服务的命令,关闭网络服务sudo service NetworkManager stop2. 删除网络的状态文件sudo rm /var/lib/NetworkManager/NetworkManager.state3. 修改网络的配置文件sudo vi /etc…

每日一题(LeetCode)----二叉树--二叉树的层平均值

每日一题(LeetCode)----二叉树–二叉树的层平均值 1.题目(637. 二叉树的层平均值) 给定一个非空二叉树的根节点 root , 以数组的形式返回每一层节点的平均值。与实际答案相差 10-5 以内的答案可以被接受。 示例 1: 输入:root […

【Math】重要性采样 Importance sample推导【附带Python实现】

【Math】重要性采样 Importance sample推导【附带Python实现】 文章目录 【Math】重要性采样 Importance sample推导【附带Python实现】1. Why need importance sample?2. Derivation of Discrete Distribution3. Derivation of Continuous Distribution3. An Example 笔者在学…

k8s的声明式资源管理(yaml文件)

1、声明式管理的特点 (1)适合对资源的修改操作 (2)声明式管理依赖于yaml文件,所有的内容都在yaml文件当中 (3)编辑好的yaml文件,还是要依靠陈述式的命令发布到k8s集群当中 kubect…