Python 梯度下降法(六):Nadam Optimize

文章目录

  • Python 梯度下降法(六):Nadam Optimize
    • 一、数学原理
      • 1.1 介绍
      • 1.2 符号定义
      • 1.3 实现流程
    • 二、代码实现
      • 2.1 函数代码
      • 2.2 总代码
    • 三、优缺点
      • 3.1 优点
      • 3.2 缺点
    • 四、相关链接

Python 梯度下降法(六):Nadam Optimize

一、数学原理

1.1 介绍

Nadam(Nesterov-accelerated Adaptive Moment Estimation)优化算法是 Adam 优化算法的改进版本,结合了 Nesterov 动量(Nesterov Momentum)和 Adam 算法的优点。

Nadam 在 Adam 算法的基础上引入了 Nesterov 动量的思想。Adam 算法通过计算梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差)来自适应地调整每个参数的学习率。而 Nesterov 动量则是在计算梯度时,考虑了参数在动量作用下未来可能到达的位置的梯度,从而让优化过程更具前瞻性。

1.2 符号定义

设置一下超参数:

参数说明
η \eta η学习率,控制参数更新的步长
m m m一阶矩估计,梯度均值
β 1 \beta_{1} β1一阶矩指数衰减率,通常取 0.9 0.9 0.9
v v v二阶矩估计,梯度未中心化方差
β 2 \beta_{2} β2二阶矩指数衰减率,通常取 0.999 0.999 0.999
ϵ \epsilon ϵ无穷小量,用于避免分母为零, 1 0 − 8 10^{-8} 108
g t g_{t} gt t t t时刻位置的梯度
θ \theta θ需要进行拟合的参数

1.3 实现流程

  1. 初始化参数: θ n × 1 \theta_{n\times 1} θn×1 m 0 ⃗ n × 1 = 0 \vec{m_{0}}_{n\times 1}=0 m0 n×1=0 v 0 ⃗ n × 1 = 0 \vec{v_{0}}_{n\times 1}=0 v0 n×1=0
  2. 更新一阶矩估计 m t m_{t} mt m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t} mt=β1mt1+(1β1)gt
  3. 更新二阶矩估计 v t v_{t} vt v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} vt=β2vt1+(1β2)gt2
  4. 偏差修正:由于 m 0 , v 0 = 0 m_{0},v_{0}=0 m0,v0=0,在训练初期会存在偏差,需要进行修正: m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_{t}=\frac{m_{t}}{1-\beta_{1}^{t}},\hat{v}_{t}=\frac{v_{t}}{1-\beta_{2}^{t}} m^t=1β1tmt,v^t=1β2tvt
  5. 计算预估一阶矩: m ~ t = β 1 m ^ t + ( 1 − β 1 ) g t 1 − β 1 t \widetilde{m}_{t}=\beta_{1}\hat{m}_{t}+\frac{(1-\beta_{1})g_{t}}{1-\beta_{1}^{t}} m t=β1m^t+1β1t(1β1)gt
  6. 更新模型参数 θ t \theta_{t} θt θ t = θ t − 1 − η v t ^ + ϵ ⊙ m ~ t \theta_{t}=\theta_{t-1}-\frac{\eta}{\sqrt{ \hat{v_{t}} }+\epsilon}\odot\widetilde{m}_{t} θt=θt1vt^ +ϵηm t

二、代码实现

2.1 函数代码

# 定义 Nadam 函数
def nadam_optimizer(X, y, eta, num_iter=1000, beta1=0.9, beta2=0.999, epsilon=1e-8, threshold=1e-8):
    """
    X: 数据 x  mxn,可以在传入数据之前进行数据的归一化
    y: 数据 y  mx1
    eta: 学习率
    num_iter: 迭代次数
    beta: 衰减率
    epsilon: 无穷小
    threshold: 阈值
    """
    m, n = X.shape
    theta, mt, vt = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1))  # 初始化数据
    loss_ = []
    
    for t in range(1, num_iter + 1):
        
        # 计算梯度
        h = X.dot(theta)
        err = h - y
        loss_.append(np.mean(err ** 2) / 2)
        g = (1 / m) * X.T.dot(err)
                
        # 一阶矩估计
        mt = beta1 * mt + (1 - beta1) * g
        # 二阶矩估计
        vt = beta2 * vt + (1 - beta2) * g ** 2

        # 先计算偏差修正,后面需要使用到,并且去除负数
        m_hat, v_hat = mt / (1 - pow(beta1, t)), np.maximum(vt / (1 - pow(beta2, t)), 0)

        # 计算预估一阶矩
        m_pre = beta1 * m_hat + (1 - beta1) * g / (1 - pow(beta1, t))
        
        # 更新参数
        theta = theta - np.multiply((eta / (np.sqrt(v_hat) + epsilon)), m_pre)

        # 检查是否收敛
        if t > 1 and abs(loss_[-1] - loss_[-2]) < threshold:
            print(f"Converged at iteration {t}")
            break

    return theta.flatten(), loss_

2.2 总代码

import numpy as np
import matplotlib.pyplot as plt

# 设置 matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False

# 定义 Nadam 函数
def nadam_optimizer(X, y, eta, num_iter=1000, beta1=0.9, beta2=0.999, epsilon=1e-8, threshold=1e-8):
    """
    X: 数据 x  mxn,可以在传入数据之前进行数据的归一化
    y: 数据 y  mx1
    eta: 学习率
    num_iter: 迭代次数
    beta: 衰减率
    epsilon: 无穷小
    threshold: 阈值
    """
    m, n = X.shape
    theta, mt, vt = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1))  # 初始化数据
    loss_ = []
    
    for t in range(1, num_iter + 1):
        
        # 计算梯度
        h = X.dot(theta)
        err = h - y
        loss_.append(np.mean(err ** 2) / 2)
        g = (1 / m) * X.T.dot(err)
                
        # 一阶矩估计
        mt = beta1 * mt + (1 - beta1) * g
        # 二阶矩估计
        vt = beta2 * vt + (1 - beta2) * g ** 2

        # 先计算偏差修正,后面需要使用到,并且去除负数
        m_hat, v_hat = mt / (1 - pow(beta1, t)), np.maximum(vt / (1 - pow(beta2, t)), 0)

        # 计算预估一阶矩
        m_pre = beta1 * m_hat + (1 - beta1) * g / (1 - pow(beta1, t))
        
        # 更新参数
        theta = theta - np.multiply((eta / (np.sqrt(v_hat) + epsilon)), m_pre)

        # 检查是否收敛
        if t > 1 and abs(loss_[-1] - loss_[-2]) < threshold:
            print(f"Converged at iteration {t}")
            break

    return theta.flatten(), loss_


# 生成一些示例数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]

# 超参数
eta = 0.1

# 运行 Nadam 优化器
theta, loss_ = nadam_optimizer(X_b, y, eta)
print("最优参数 theta:")
print(theta)

# 绘制损失函数图像
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.legend()  # 显示图例
plt.grid(True)  # 显示网格线
plt.show()

1738389513_xunsrs0jxa.png1738389512232.png

三、优缺点

3.1 优点

自适应学习率:NAdam 继承了 Adam 的自适应学习率特性,能够根据梯度的一阶矩(均值)和二阶矩(方差)动态调整每个参数的学习率。这使得 NAdam 在处理不同尺度的参数时更加高效,尤其适合稀疏梯度问题。

Nesterov 动量:NAdam 引入了 Nesterov 动量,能够在更新参数时先根据当前动量预测参数的未来位置,再计算梯度。这种“前瞻性”的更新方式使得 NAdam 能够更准确地调整参数,从而加速收敛。

快速收敛:由于结合了 Adam 的自适应学习率和 Nesterov 动量的前瞻性更新,NAdam 在大多数优化问题中能够比 Adam 和传统梯度下降法更快地收敛。特别是在非凸优化问题中,NAdam 的表现通常优于其他优化算法。

鲁棒性:NAdam 对超参数的选择相对鲁棒,尤其是在学习率和动量参数的选择上。这使得 NAdam 在实际应用中更容易调参。

适合大规模数据:NAdam 能够高效处理大规模数据集和高维参数空间,适合深度学习中的大规模优化问题。

3.2 缺点

计算复杂度较高:由于 NAdam 需要同时维护一阶矩和二阶矩估计,并计算 Nesterov 动量,其计算复杂度略高于传统的梯度下降法。虽然现代深度学习框架(如 PyTorch、TensorFlow)已经对 NAdam 进行了高效实现,但在某些资源受限的场景下,计算开销仍然是一个问题。

对初始学习率敏感:尽管 NAdam 对超参数的选择相对鲁棒,但初始学习率的选择仍然对性能有较大影响。如果初始学习率设置不当,可能会导致收敛速度变慢或无法收敛。

可能陷入局部最优:在某些复杂的非凸优化问题中,NAdam 可能会陷入局部最优解,尤其是在损失函数存在大量鞍点或平坦区域时。

内存占用较高:NAdam 需要存储一阶矩和二阶矩估计,这会增加内存占用。对于非常大的模型(如 GPT-3 等),内存占用可能成为一个瓶颈。

理论分析较少:相比于 Adam 和传统的梯度下降法,NAdam 的理论分析相对较少。虽然实验结果表明 NAdam 在大多数任务中表现优异,但其理论性质仍需进一步研究。

四、相关链接

Python 梯度下降法合集:

  • Python 梯度下降法(一):Gradient Descent-CSDN博客
  • Python 梯度下降法(二):RMSProp Optimize-CSDN博客
  • Python 梯度下降法(三):Adagrad Optimize-CSDN博客
  • Python 梯度下降法(四):Adadelta Optimize-CSDN博客
  • Python 梯度下降法(五):Adam Optimize-CSDN博客
  • Python 梯度下降法(六):Nadam Optimize-CSDN博客
  • Python 梯度下降法(七):Summary-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/965190.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

《Kettle保姆级教学-界面介绍》

目录 一、Kettle介绍二、界面介绍1.界面构成2、菜单栏详细介绍2.1 【文件F】2.2 【编辑】2.3 【视图】2.4 【执行】2.5 【工具】2.6 【帮助】 3、转换界面介绍4、作业界面介绍5、执行结果 一、Kettle介绍 Kettle 是一个开源的 ETL&#xff08;Extract, Transform, Load&#x…

Spring Boot篇

为什么要用Spring Boot Spring Boot 优点非常多&#xff0c;如&#xff1a; 独立运行 Spring Boot 而且内嵌了各种 servlet 容器&#xff0c;Tomcat、Jetty 等&#xff0c;现在不再需要打成 war 包部署到 容器 中&#xff0c;Spring Boot 只要打成一个可执行的 jar 包就能独…

C# 中记录(Record)详解

从C#9.0开始&#xff0c;我们有了一个有趣的语法糖&#xff1a;记录(record)   为什么提供记录&#xff1f; 开发过程中&#xff0c;我们往往会创建一些简单的实体&#xff0c;它们仅仅拥有一些简单的属性&#xff0c;可能还有几个简单的方法&#xff0c;比如DTO等等&#xf…

Page Assist - 本地Deepseek模型 Web UI 的安装和使用

Page Assist Page Assist是一个开源的Chrome扩展程序&#xff0c;为本地AI模型提供一个直观的交互界面。通过它可以在任何网页上打开侧边栏或Web UI&#xff0c;与自己的AI模型进行对话&#xff0c;获取智能辅助。这种设计不仅方便了用户随时调用AI的能力&#xff0c;还保护了…

UE求职Demo开发日志#21 背包-仓库-装备栏移动物品

1 创建一个枚举记录来源位置 UENUM(BlueprintType) enum class EMyItemLocation : uint8 {None0,Bag UMETA(DisplayName "Bag"),Armed UMETA(DisplayName "Armed"),WareHouse UMETA(DisplayName "WareHouse"), }; 2 创建一个BagPad和WarePa…

Django框架丨从零开始的Django入门学习

Django 是一个用于构建 Web 应用程序的高级 Python Web 框架&#xff0c;Django是一个高度模块化的框架&#xff0c;使用 Django&#xff0c;只要很少的代码&#xff0c;Python 的程序开发人员就可以轻松地完成一个正式网站所需要的大部分内容&#xff0c;并进一步开发出全功能…

企业四要素如何用PHP进行调用

一、什么是企业四要素&#xff1f; 企业四要素接口是在企业三要素&#xff08;企业名称、统一社会信用代码、法定代表人姓名&#xff09;的基础上&#xff0c;增加了一个关键要素&#xff0c;通常是企业注册号或企业银行账户信息。这种接口主要用于更全面的企业信息验证&#x…

JVM监控和管理工具

基础故障处理工具 jps jps(JVM Process Status Tool)&#xff1a;Java虚拟机进程状态工具 功能 1&#xff1a;列出正在运行的虚拟机进程 2&#xff1a;显示虚拟机执行主类(main()方法所在的类) 3&#xff1a;显示进程ID(PID&#xff0c;Process Identifier) 命令格式 jps […

Java 大视界 -- Java 大数据在智慧文旅中的应用与体验优化(74)

&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎来到 青云交的博客&#xff01;能与诸位在此相逢&#xff0c;我倍感荣幸。在这飞速更迭的时代&#xff0c;我们都渴望一方心灵净土&#xff0c;而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识&#xff0c;也…

ASP.NET Core中间件Markdown转换器

目录 需求 文本编码检测 Markdown→HTML 注意 实现 需求 Markdown是一种文本格式&#xff1b;不被浏览器支持&#xff1b;编写一个在服务器端把Markdown转换为HTML的中间件。我们开发的中间件是构建在ASP.NET Core内置的StaticFiles中间件之上&#xff0c;并且在它之前运…

2025游戏行业的趋势预测

一、市场现状 从总产值的角度来看&#xff0c;游戏总产值的增长率已经放缓&#xff0c;由增量市场转化为存量市场&#xff0c;整体的竞争强度将会加大&#xff0c;技术水平不强&#xff08;开发技术弱、产品品质低、开发效率低&#xff09;的公司将会面临更大的生存的困难。 从…

C++的 I/O 流

本文把复杂的基类和派生类的作用和关系捋出来&#xff0c;具体的接口请参考相关文档 C的 I/O 流相关的类&#xff0c;继承关系如下图所示 https://zh.cppreference.com/w/cpp/io I / O 的概念&#xff1a;内存和外设进行数据交互称为 I / O &#xff0c;例如&#xff1a;把数…

在https下引用IC卡读卡器web插件

HTTPS &#xff08;全称&#xff1a;Hypertext Transfer Protocol Secure &#xff09;&#xff0c;是以安全为目标的 HTTP 通道&#xff0c;在HTTP的基础上通过传输加密和身份认证保证了传输过程的安全性 。HTTPS 在HTTP 的基础下加入SSL&#xff0c;HTTPS 的安全基础是 SSL&a…

堆的实现——堆的应用(堆排序)

文章目录 1.堆的实现2.堆的应用--堆排序 大家在学堆的时候&#xff0c;需要有二叉树的基础知识&#xff0c;大家可以看我的二叉树文章&#xff1a;二叉树 1.堆的实现 如果有⼀个关键码的集合 K {k0 , k1 , k2 , …&#xff0c;kn−1 } &#xff0c;把它的所有元素按完全⼆叉树…

基于单片机的智能安全插座(论文+源码)

1 系统整体方案设计 本课题基于单片机的智能安全插座设计&#xff0c;以STM32嵌入式单片机为主体&#xff0c;将计算机技术和检测技术有机结合&#xff0c;设计一款电量参数采集装置&#xff0c;实现电压、电流信号的数据采集任务&#xff0c;电压、电流和功率在上位机的显示任…

【网络】3.HTTP(讲解HTTP协议和写HTTP服务)

目录 1 认识URL1.1 URI的格式 2 HTTP协议2.1 请求报文2.2 响应报文 3 模拟HTTP3.1 Socket.hpp3.2 HttpServer.hpp3.2.1 start()3.2.2 ThreadRun()3.2.3 HandlerHttp&#xff08;&#xff09; 总结 1 认识URL 什么是URI&#xff1f; URI 是 Uniform Resource Identifier的缩写&…

导入了fastjson2的依赖,但却无法使用相关API的解决方案

今天遇到了一个特别奇怪的问题&#xff0c;跟着视频敲代码&#xff0c;视频中用到了一个将JSON字符串转为对象的 API&#xff0c;需要引入alibaba的fastjson2相关依赖&#xff0c;我引入的依赖跟视频一样。 <!--视频中给的相关依赖 --> <dependency><groupId&g…

DeepSeek R1 简单指南:架构、训练、本地部署和硬件要求

DeepSeek 的 LLM 推理新方法 DeepSeek 推出了一种创新方法&#xff0c;通过强化学习 (RL) 来提高大型语言模型 (LLM) 的推理能力&#xff0c;其最新论文 DeepSeek-R1 对此进行了详细介绍。这项研究代表了我们如何通过纯强化学习来增强 LLM 解决复杂问题的能力&#xff0c;而无…

Vue Dom截图插件,截图转Base64 html2canvas

安装插件 npm install html2canvas --save插件使用 <template><div style"padding: 10px;"><div ref"imageTofile" class"box">发生什么事了</div><button click"toImage" style"margin: 10px;&quo…

Flink2支持提交StreamGraph到Flink集群

最近研究Flink源码的时候&#xff0c;发现Flink已经支持提交StreamGraph到集群了&#xff0c;替换掉了原来的提交JobGraph。 新增ExecutionPlan接口&#xff0c;将JobGraph和StreamGraph作为实现。 Flink集群Dispatcher也进行了修改&#xff0c;从JobGraph改成了接口Executio…