【深度学习笔记】7_5 AdaGrad算法

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

7.5 AdaGrad算法

在之前介绍过的优化算法中,目标函数自变量的每一个元素在相同时间步都使用同一个学习率来自我迭代。举个例子,假设目标函数为 f f f,自变量为一个二维向量 [ x 1 , x 2 ] ⊤ [x_1, x_2]^\top [x1,x2],该向量中每一个元素在迭代时都使用相同的学习率。例如,在学习率为 η \eta η的梯度下降中,元素 x 1 x_1 x1 x 2 x_2 x2都使用相同的学习率 η \eta η来自我迭代:

x 1 ← x 1 − η ∂ f ∂ x 1 , x 2 ← x 2 − η ∂ f ∂ x 2 . x_1 \leftarrow x_1 - \eta \frac{\partial{f}}{\partial{x_1}}, \quad x_2 \leftarrow x_2 - \eta \frac{\partial{f}}{\partial{x_2}}. x1x1ηx1f,x2x2ηx2f.

在7.4节(动量法)里我们看到当 x 1 x_1 x1 x 2 x_2 x2的梯度值有较大差别时,需要选择足够小的学习率使得自变量在梯度值较大的维度上不发散。但这样会导致自变量在梯度值较小的维度上迭代过慢。动量法依赖指数加权移动平均使得自变量的更新方向更加一致,从而降低发散的可能。本节我们介绍AdaGrad算法,它根据自变量在每个维度的梯度值的大小来调整各个维度上的学习率,从而避免统一的学习率难以适应所有维度的问题[1]。

7.5.1 算法

AdaGrad算法会使用一个小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方的累加变量 s t \boldsymbol{s}_t st。在时间步0,AdaGrad将 s 0 \boldsymbol{s}_0 s0中每个元素初始化为0。在时间步 t t t,首先将小批量随机梯度 g t \boldsymbol{g}_t gt按元素平方后累加到变量 s t \boldsymbol{s}_t st

s t ← s t − 1 + g t ⊙ g t , \boldsymbol{s}_t \leftarrow \boldsymbol{s}_{t-1} + \boldsymbol{g}_t \odot \boldsymbol{g}_t, stst1+gtgt,

其中 ⊙ \odot 是按元素相乘。接着,我们将目标函数自变量中每个元素的学习率通过按元素运算重新调整一下:

x t ← x t − 1 − η s t + ϵ ⊙ g t , \boldsymbol{x}_t \leftarrow \boldsymbol{x}_{t-1} - \frac{\eta}{\sqrt{\boldsymbol{s}_t + \epsilon}} \odot \boldsymbol{g}_t, xtxt1st+ϵ ηgt,

其中 η \eta η是学习率, ϵ \epsilon ϵ是为了维持数值稳定性而添加的常数,如 1 0 − 6 10^{-6} 106。这里开方、除法和乘法的运算都是按元素运算的。这些按元素运算使得目标函数自变量中每个元素都分别拥有自己的学习率。

7.5.2 特点

需要强调的是,小批量随机梯度按元素平方的累加变量 s t \boldsymbol{s}_t st出现在学习率的分母项中。因此,如果目标函数有关自变量中某个元素的偏导数一直都较大,那么该元素的学习率将下降较快;反之,如果目标函数有关自变量中某个元素的偏导数一直都较小,那么该元素的学习率将下降较慢。然而,由于 s t \boldsymbol{s}_t st一直在累加按元素平方的梯度,自变量中每个元素的学习率在迭代过程中一直在降低(或不变)。所以,当学习率在迭代早期降得较快且当前解依然不佳时,AdaGrad算法在迭代后期由于学习率过小,可能较难找到一个有用的解

下面我们仍然以目标函数 f ( x ) = 0.1 x 1 2 + 2 x 2 2 f(\boldsymbol{x})=0.1x_1^2+2x_2^2 f(x)=0.1x12+2x22为例观察AdaGrad算法对自变量的迭代轨迹。我们实现AdaGrad算法并使用和上一节实验中相同的学习率0.4。可以看到,自变量的迭代轨迹较平滑。但由于 s t \boldsymbol{s}_t st的累加效果使学习率不断衰减,自变量在迭代后期的移动幅度较小。

%matplotlib inline
import math
import torch
import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l

def adagrad_2d(x1, x2, s1, s2):
    g1, g2, eps = 0.2 * x1, 4 * x2, 1e-6  # 前两项为自变量梯度
    s1 += g1 ** 2
    s2 += g2 ** 2
    x1 -= eta / math.sqrt(s1 + eps) * g1
    x2 -= eta / math.sqrt(s2 + eps) * g2
    return x1, x2, s1, s2

def f_2d(x1, x2):
    return 0.1 * x1 ** 2 + 2 * x2 ** 2

eta = 0.4
d2l.show_trace_2d(f_2d, d2l.train_2d(adagrad_2d))

输出:

epoch 20, x1 -2.382563, x2 -0.158591

在这里插入图片描述

下面将学习率增大到2。可以看到自变量更为迅速地逼近了最优解。

eta = 2
d2l.show_trace_2d(f_2d, d2l.train_2d(adagrad_2d))

输出:

epoch 20, x1 -0.002295, x2 -0.000000

在这里插入图片描述

7.5.3 从零开始实现

同动量法一样,AdaGrad算法需要对每个自变量维护同它一样形状的状态变量。我们根据AdaGrad算法中的公式实现该算法。

features, labels = d2l.get_data_ch7()

def init_adagrad_states():
    s_w = torch.zeros((features.shape[1], 1), dtype=torch.float32)
    s_b = torch.zeros(1, dtype=torch.float32)
    return (s_w, s_b)

def adagrad(params, states, hyperparams):
    eps = 1e-6
    for p, s in zip(params, states):
        s.data += (p.grad.data**2)
        p.data -= hyperparams['lr'] * p.grad.data / torch.sqrt(s + eps)

与7.3节(小批量随机梯度下降)中的实验相比,这里使用更大的学习率来训练模型。

d2l.train_ch7(adagrad, init_adagrad_states(), {'lr': 0.1}, features, labels)

输出:

loss: 0.243675, 0.049749 sec per epoch

在这里插入图片描述

7.5.4 简洁实现

通过名称为Adagrad的优化器方法,我们便可使用PyTorch提供的AdaGrad算法来训练模型。

d2l.train_pytorch_ch7(torch.optim.Adagrad, {'lr': 0.1}, features, labels)

输出:

loss: 0.243147, 0.040675 sec per epoch

在这里插入图片描述

小结

  • AdaGrad算法在迭代过程中不断调整学习率,并让目标函数自变量中每个元素都分别拥有自己的学习率。
  • 使用AdaGrad算法时,自变量中每个元素的学习率在迭代过程中一直在降低(或不变)。

参考文献

[1] Duchi, J., Hazan, E., & Singer, Y. (2011). Adaptive subgradient methods for online learning and stochastic optimization. Journal of Machine Learning Research, 12(Jul), 2121-2159.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/450502.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

tomcat架构

俗话说,站在巨人的肩膀上看世界,一般学习的时候也是先总览一下整体,然后逐个部分个个击破,最后形成思路,了解具体细节,Tomcat的结构很复杂,但是 Tomcat 非常的模块化,找到了 Tomcat最…

【Unity】时间戳与DateTime

介绍 在开发游戏的时候,和时间打交道是一个必不可少的部分。而时间戳就是用的最多的一个,尤其是和服务端通信。 那时间戳是什么? 时间戳就是从1970年1月1日0时0分0秒起到现在的总毫秒数,为什么时1970/1/1/00:00:00,因…

天梯赛的赛场安排(Python)

作者 陈越 单位 浙江大学 天梯赛使用 OMS 监考系统,需要将参赛队员安排到系统中的虚拟赛场里,并为每个赛场分配一位监考老师。每位监考老师需要联系自己赛场内队员对应的教练们,以便发放比赛账号。为了尽可能减少教练和监考的沟通负担&#…

8块硬盘故障的存储异常恢复案例一则

关键词 华为存储、硬盘域、LUN热备冗余、重构、预拷贝 oracle rac、多路径 There are many things that can not be broken! 如果觉得本文对你有帮助,欢迎点赞、收藏、评论! 一、问题现象 近期遇到的一个案例,现象是一套oracl…

Linux下Nginx配置多域名及SSL证书

接上一篇 《Linux 安装Nginx (Nginx-1.25.4)》 本文描述如何配置Nginx多域名及SSL证书。 假设Nginx安装在/usr/local/nginx目录下。Nginx的配置文件为:/usr/local/nginx/conf/nginx.conf,要实现配置域名和SSL证书,都是修改此配置文件。 1.…

docker部署多功能网络工具箱

功能 查看自己的IP:从多个 IPv4 和 IPv6 来源检测显示本机的IP 查看IP信息:显示所有 IP 的相关信息 可用性检测:检测一些网站的可用性 WebRTC 检测:查看使用 WebRTC 连接时使用的 IP DNS 泄露检测:查看 DNS 出口信息 …

前端Vue中自定义Popup弹框、按钮及内容的设计与实践

标题:前端Vue中自定义Popup弹框、按钮及内容的设计与实践 一、引言 在Web前端开发中,弹框(Popup)是一种常见的用户界面元素,用于向用户显示额外的信息或提供额外的功能。然而,标准的弹框往往不能满足所有…

分布式系统超详解析

目录 常见概念 基本概念 应用/系统 模块/组件 分布式 集群 主/从 中间件 评价指标 可用性 响应时长 吞吐量/并发量 架构演进 单机架构 应用数据分离架构 引入更多的应用服务器结点 读写分离架构 引入缓存--冷热分离的结构 垂直分库 业务拆分--微服务 为了更…

网页脚本 bilibili006:视频下载脚本修改+油猴脚本发布

视频下载脚本修改 原始脚本的下载的视频名称总是错的&#xff0c;调用的代码为 document.querySelector(.tag-txt).textContent &#xff0c;发现这是标签的名称 查找视频名称所在的类名称 <h1 title"任天堂告yuzu模拟器&#xff0c;龙神模拟器会被殃及池鱼吗"…

torch.cuda.is_available()=False

问题&#xff1a; 显示torch.cuda.is_available()False 解决办法&#xff1a;说明这个虚拟环境不可用&#xff0c;删除虚拟环境&#xff0c;重建一个新的虚拟环境 1、删除原来的虚拟环境&#xff0c;假如原虚拟环境为pytorch-old&#xff0c;输入以下命令&#xff0c;先退出当…

如何制作一个包含图文视频信息的二维码如何生成?办公多功能利器!

一个包含图片、文字、视频、PDF文件等多种内容的二维码——二维彩虹H5编辑二维码正在各行各业发挥着重要作用。 和普通的二维码不同&#xff0c;H5编辑二维码可以展示更多种类&#xff08;图文视频等&#xff09;、和数量的内容&#xff0c;被广泛应用在多种办公场景。你可以将…

2024年春招助学活动:一批FPGA高端项目让你轻松拿到大厂offer

这里写目录标题 1、前言2、FPGA行业现状3、简历怎么写4、FPGA高端项目4.1 图像类&#xff1a;FPGA图像缩放多路视频拼接4.2 通信类&#xff1a;千兆网UDP协议栈4.3 通信类&#xff1a;万兆网UDP协议栈4.4 图像通信综合&#xff1a;FPGA图像缩放UDP网络视频传输4.5 图像高速接口…

Java零基础入门到精通_Day 2

08-HelloWorld系例常见问题 4.1 BUG的解决 1:具备识别BUG的能力 多看 2:具备分析BUG的能力 多思考&#xff0c;多查阅资料 3:具备解决BUG的能力 多尝试&#xff0c;多总结 09-Notepad软件的安装和使用 略 10-注释 1.1 注释分类 单行注释 格式://注释信息 多行注释 格式:/*…

基于YOLOv8/YOLOv7/YOLOv6/YOLOv5的铁轨缺陷检测系统(Python+PySide6界面+训练代码)

摘要&#xff1a;开发铁轨缺陷检测系统对于物流行业、制造业具有重要作用。本篇博客详细介绍了如何运用深度学习构建一个铁轨缺陷检测系统&#xff0c;并提供了完整的实现代码。该系统基于强大的YOLOv8算法&#xff0c;并对比了YOLOv7、YOLOv6、YOLOv5&#xff0c;展示了不同模…

unity学习(51)——服务器三次注册限制以及数据库化角色信息6--完结

同一账号只写第一次&#xff0c;不同账号第一次爆炸 &#xff0c;就因为下面部分得到逻辑有问题 修改后的代码如下&#xff1a;1.成功完成角色注册信息的数据库化记录。2.每个账号上限3个角色。3.角色是可以重名的&#xff0c;但是角色的id不会重名。 internal class UserCach…

2024 年广东省职业院校技能大赛(高职组) “云计算应用”赛项样题⑤

2024 年广东省职业院校技能大赛&#xff08;高职组&#xff09; “云计算应用”赛项样题⑤ 模块一 私有云&#xff08;50 分&#xff09;任务 1 私有云服务搭建&#xff08;10 分&#xff09;任务 2 私有云服务运维&#xff08;25 分&#xff09;任务 3 私有云运维开发&#xf…

DC/DC高压模块直流升压可调稳压输出升压变换器5V12V24V48V转50V110V150V130V200V250V300V450V500V600V800V

特点 效率高达 80%以上1*2英寸标准封装单电压输出价格低稳压输出工作温度: -40℃~85℃阻燃封装&#xff0c;满足UL94-V0 要求温度特性好可直接焊在PCB 上 应用 HRB W2~40W 系列模块电源是一种DC-DC升压变换器。该模块电源的输入电压分为&#xff1a;4.5~9V、9~18V、及18~36V、…

bugreport中查看开发者选项动画时长缩放日志

首先打开开发者选项&#xff0c;抓取一份bugreport解压后找到bugreport-机型-时间点.zip文件&#xff0c;然后再解压此文件 解压后进入该文件&#xff0c;找到bugreport-机型-时间点.txt文件 打开此文件&#xff0c;搜索“animator_duration_scale”关键字&#xff0c;找到图片…

虚拟机(KVM)克隆

当需要批量部署虚拟机时&#xff0c;可以使用克隆虚拟机的方式来进行。 使用图形界面来克隆虚拟机。 [rootzhoujunru_node1 zhou]# virsh list --allId Name State ------------------------------ vm01 shut off- vm01-clone shut off克隆完成。

如何在Linux本地搭建Tale网站并实现无公网ip远程访问

文章目录 前言1. Tale网站搭建1.1 检查本地环境1.2 部署Tale个人博客系统1.3 启动Tale服务1.4 访问博客地址 2. Linux安装Cpolar内网穿透3. 创建Tale博客公网地址4. 使用公网地址访问Tale 前言 今天给大家带来一款基于 Java 语言的轻量级博客开源项目——Tale&#xff0c;Tale…