Python 梯度下降法(五):Adam Optimize

文章目录

  • Python 梯度下降法(五):Adam Optimize
    • 一、数学原理
      • 1.1 介绍
      • 1.2 符号说明
      • 1.3 实现流程
    • 二、代码实现
      • 2.1 函数代码
      • 2.2 总代码
      • 2.3 遇到的问题
      • 2.4 算法优化
    • 三、优缺点
      • 3.1 优点
      • 3.2 缺点

Python 梯度下降法(五):Adam Optimize

相关链接:
Python 梯度下降法(一):Gradient Descent-CSDN博客
Python 梯度下降法(二):RMSProp Optimize-CSDN博客
Python 梯度下降法(三):Adagrad Optimize-CSDN博客
Python 梯度下降法(四):Adadelta Optimize-CSDN博客

一、数学原理

1.1 介绍

Adam 算法结合了 Adagrad 和 RMSProp 算法的优点。Adagrad 算法会根据每个参数的历史梯度信息来调整学习率,对于出现频率较低的参数会给予较大的学习率,而对于出现频率较高的参数则给予较小的学习率。RMSProp 算法则是对 Adagrad 算法的改进,它通过使用移动平均的方式来计算梯度的平方,从而避免了 Adagrad 算法中学习率单调下降的问题。

1.2 符号说明

参数意义
g t = ∇ θ J ( θ t ) g_{t}=\nabla_{\theta}J(\theta_{t}) gt=θJ(θt) t t t时刻的梯度
m t m_{t} mt梯度的一阶矩(均值)
β 1 \beta_{1} β1一阶矩衰减率,一般取0.9
v t v_{t} vt梯度的二阶矩(未中心化的方差)
β 2 \beta_{2} β2二阶矩衰减率,一般取0.99
θ \theta θ线性拟合参数
η \eta η学习率
ϵ \epsilon ϵ无穷小量,一般取 1 0 − 8 10^{-8} 108

1.3 实现流程

  1. 初始化: θ \theta θ η \eta η m 0 ⃗ = 0 \vec{m_{0}}=0 m0 =0 v 0 ⃗ = 0 \vec{v_{0}}=0 v0 =0
  2. 计算梯度: g t = ∇ θ J ( θ t ) = 1 m X T L g_{t}=\nabla_{\theta}J(\theta_{t})=\frac{1}{m}X^{T}L gt=θJ(θt)=m1XTL
  3. 梯度的一阶矩估计(均值): m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_{t}=\beta_{1}m_{t-1}+(1-\beta_{1})g_{t} mt=β1mt1+(1β1)gt
  4. 梯度的二阶矩估计(未中心化的方差): v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_{t}=\beta_{2}v_{t-1}+(1-\beta_{2})g_{t}^{2} vt=β2vt1+(1β2)gt2
  5. 偏差修正: m t ^ = m t 1 − β 1 t 、 v t ^ = v t 1 − β 2 t \hat{m_{t}}=\frac{m_{t}}{1-\beta_{1}^{t}}、\hat{v_{t}}=\frac{v_{t}}{1-\beta_{2}^{t}} mt^=1β1tmtvt^=1β2tvt
  6. 更新参数: θ t = θ t − 1 − η m t ^ v t ^ + ϵ \theta_{t}=\theta_{t-1}-\frac{\eta \hat{m_{t}}}{\sqrt{ \hat{v_{t}} }+\epsilon} θt=θt1vt^ +ϵηmt^

二、代码实现

2.1 函数代码

# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):
    """
    X: 数据 x  mxn,可以在传入数据之前进行数据的归一化
    y: 数据 y  mx1
    eta: 学习率
    num_iter: 迭代次数
    beta: 衰减率
    epsilon: 无穷小
    threshold: 阈值
    """
    m, n = X.shape
    theta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), []  # 初始化数据
    for iter in range(num_iter):
        h = X.dot(theta)
        err = h - y
        loss_.append(np.mean((err ** 2) / 2))
        g = (1 / m ) * X.T.dot(err)
        # 一阶矩估计
        mt = beta1 * mt + (1 - beta1) * g
        # 二阶矩估计
        vt = beta2 * vt + (1 - beta2) * g ** 2
        # 偏差修正
        mt_ = mt / (1 - pow(beta1, (iter + 1)))  # 得 + 1 不然在 iter = 0 时,分母为零
        vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))
        # 更新参数
        theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)
    
        # 检查是否收敛
        if iter > 1 and abs(loss_[-1] - loss_[-2]) < threshold:
            print(f"Converged at iteration {iter + 1}")
            break

    return theta.flatten(), loss_

2.2 总代码

import numpy as np
import matplotlib.pyplot as plt

# 设置 matplotlib 支持中文
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.rcParams['axes.unicode_minus'] = False

# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):
    """
    X: 数据 x  mxn,可以在传入数据之前进行数据的归一化
    y: 数据 y  mx1
    eta: 学习率
    num_iter: 迭代次数
    beta: 衰减率
    epsilon: 无穷小
    threshold: 阈值
    """
    m, n = X.shape
    theta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), []  # 初始化数据
    for iter in range(num_iter):
        h = X.dot(theta)
        err = h - y
        loss_.append(np.mean((err ** 2) / 2))
        g = (1 / m ) * X.T.dot(err)
        # 一阶矩估计
        mt = beta1 * mt + (1 - beta1) * g
        # 二阶矩估计
        vt = beta2 * vt + (1 - beta2) * g ** 2
        # 偏差修正
        mt_ = mt / (1 - pow(beta1, (iter + 1)))  # 得 + 1 不然在 iter = 0 时,分母为零
        vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))
        # 更新参数
        theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)
    
        # 检查是否收敛
        if iter > 1 and abs(loss_[-1] - loss_[-2]) < threshold:
            print(f"Converged at iteration {iter + 1}")
            break

    return theta.flatten(), loss_

# 生成一些示例数据
np.random.seed(42)
X = 2 * np.random.rand(100, 1)
y = 4 + 3 * X + np.random.randn(100, 1)
# 添加偏置项
X_b = np.c_[np.ones((100, 1)), X]

# 超参数
eta = 0.01

# 运行 Adam 优化器
theta, loss_ = adam_optimizer(X_b, y, eta)

print("最优参数 theta:")
print(theta)

# 绘制损失函数图像
plt.plot(range(len(loss_)), loss_, label="损失函数图像")
plt.title("损失函数图像")
plt.xlabel("迭代次数")
plt.ylabel("损失值")
plt.legend()  # 显示图例
plt.grid(True)  # 显示网格线
plt.show()

1738332803_d7btmbrnt5.png1738332802724.png

2.3 遇到的问题

当偏差修正为以下算法时,出现报错:

        # 偏差修正
        mt_ = mt / (1 - pow(beta1, (iter)))
        vt_ = np.abs(vt / (1 - pow(beta2, (iter))))

1738332890_bekam4jjvm.png1738332889461.png

进行检验时,我们发现:

1738333012_ujmv5g46fw.png1738333011494.png

mt_,vt_ \text{mt\_,vt\_} mt_,vt_为无穷量,因此考虑分母为零的情况,而当 iter = 0 \text{iter}=0 iter=0时, 1 − β iter = 0 1- \beta^{\text{iter}}=0 1βiter=0,故说明索引不能从0开始,而应该从1开始,因此引入 iter + 1 \text{iter}+1 iter+1,防止分母的无穷大引入。

2.4 算法优化

由于算法过程中,如果数据量太多会引起资源的严重浪费,因此我们引入小批量梯度下降法的类似方法,批量截取数据来进行拟合。

# 定义 Adam 函数
def adam_optimizer(X, y, eta, num_iter=1000, batch_size=32, beta1=0.8, beta2=0.8, epsilon=1e-8, threshold=1e-8):
    """
    X: 数据 x  mxn,可以在传入数据之前进行数据的归一化
    y: 数据 y  mx1
    eta: 学习率
    num_iter: 迭代次数
    batch_size: 小批量分支法的批量数
    beta: 衰减率
    epsilon: 无穷小
    threshold: 阈值
    """
    m, n = X.shape
    theta, mt, vt, loss_ = np.random.randn(n, 1), np.zeros((n, 1)), np.zeros((n, 1)), []  # 初始化数据
    num_batchs = m // batch_size
    for _ in range(num_iter):
        range_shuffle = np.random.permutation(m)
        X_shuffled = X[range_shuffle]
        y_shuffled = y[range_shuffle]
        loss_temp = []
        for iter in range(num_batchs):
            start_index = batch_size * iter
            end_index = start_index + batch_size
            xi = X_shuffled[start_index:end_index]
            yi = y_shuffled[start_index:end_index]
            h = xi.dot(theta)
            err = h - yi
            loss_temp.append(np.mean((err ** 2) / 2))
            g = (1 / m ) * xi.T.dot(err)
            # 一阶矩估计
            mt = beta1 * mt + (1 - beta1) * g
            # 二阶矩估计
            vt = beta2 * vt + (1 - beta2) * g ** 2
            # 偏差修正
            mt_ = mt / (1 - pow(beta1, (iter + 1)))
            vt_ = np.abs(vt / (1 - pow(beta2, (iter + 1))))
            # 更新参数
            theta = theta - (eta * mt_) / (np.sqrt(vt_) + epsilon)

        loss_.append(np.mean(loss_temp))
        # 检查是否收敛
        if _ > 1 and abs(loss_[-1] - loss_[-2]) < threshold:
            print(f"Converged at iteration {iter + 1}")
            break

    return theta.flatten(), loss_

1738333762_rdxih0p4h8.png1738333761148.png

使用小批量进行Adam优化,可以大大节省系统的资源。

三、优缺点

3.1 优点

对不同参数调整学习率:Adam 能够为模型的每个参数自适应地调整学习率。它会根据参数的梯度历史信息,对出现频率较低的参数给予较大的学习率,对出现频率较高的参数给予较小的学习率。这使得模型在训练过程中能够更好地处理不同尺度和变化频率的参数,加速收敛过程。

无需手动精细调整:在很多情况下,Adam 算法提供的默认超参数就能取得不错的效果,不需要像传统优化算法那样进行大量的手动调参,节省了时间和精力。

低内存需求:Adam 只需要存储梯度的一阶矩估计(均值)和二阶矩估计(未中心化的方差),不需要像一些二阶优化方法那样存储复杂的海森矩阵(Hessian matrix),因此内存占用相对较小,适合处理大规模数据集和深度神经网络。

快速收敛:通过结合梯度的一阶矩和二阶矩信息,Adam 能够更准确地估计梯度的方向和大小,从而在大多数情况下比传统的随机梯度下降(SGD)算法更快地收敛到最优解。

利用稀疏信息:在处理稀疏数据(如自然语言处理中的词向量)时,Adam 能够根据数据的稀疏性调整学习率。对于那些很少出现的特征,算法会给予较大的学习率,使得模型能够更有效地学习这些特征,避免因数据稀疏而导致的学习困难

偏差修正机制:Adam 算法引入了偏差修正机制,用于修正一阶矩和二阶矩估计在训练初期的偏差。这使得算法在训练的早期阶段更加稳定,能够避免因初始估计不准确而导致的训练波动或不收敛问题。

3.2 缺点

自适应特性的局限性:虽然 Adam 能够自适应地调整学习率,但在某些情况下,这种自适应特性可能会导致算法陷入局部最优解。由于学习率会随着训练过程自动调整,可能会在接近局部最优解时过早地降低学习率,使得算法难以跳出局部最优区域,从而无法找到全局最优解。

需要一定的调参经验:尽管 Adam 提供了默认的超参数,但在某些复杂的任务或数据集上,这些默认参数可能不是最优的。例如, β \beta β ϵ \epsilon ϵ的取值会影响算法的性能,如果选择不当,可能会导致收敛速度变慢、模型性能下降等问题。因此,在实际应用中,可能仍然需要进行一定的超参数调优。

过度适应训练数据:由于 Adam 算法在训练过程中过于关注梯度的历史信息和自适应调整学习率,可能会导致模型过度适应训练数据,从而降低模型的泛化能力。在某些情况下,使用 Adam 训练的模型在测试集上的表现可能不如使用其他优化算法训练的模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962355.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

labelme_json_to_dataset ValueError: path is on mount ‘D:‘,start on C

这是你的labelme运行时label照片的盘和保存目的地址的盘不同都值得报错 labelme_json_to_dataset ValueError: path is on mount D:,start on C 只需要放一个盘但可以不放一个目录

中间件安全

一.中间件概述 1.中间件定义 介绍&#xff1a;中间件&#xff08;Middleware&#xff09;作为一种软件组件&#xff0c;在不同系统、应用程序或服务间扮演着数据与消息传递的关键角色。它常处于应用程序和操作系统之间&#xff0c;就像一座桥梁&#xff0c;负责不同应用程序间…

玩转大语言模型——配置图数据库Neo4j(含apoc插件)并导入GraphRAG生成的知识图谱

系列文章目录 玩转大语言模型——使用langchain和Ollama本地部署大语言模型 玩转大语言模型——ollama导入huggingface下载的模型 玩转大语言模型——langchain调用ollama视觉多模态语言模型 玩转大语言模型——使用GraphRAGOllama构建知识图谱 玩转大语言模型——完美解决Gra…

sizeof和strlen的对比与一些杂记

1.sizeof和strlen的对比 1.1sizeof &#xff08;1&#xff09;sizeof是一种操作符 &#xff08;2&#xff09;sizeof计算的是类型或变量所占空间的大小&#xff0c;单位是字节 注意事项&#xff1a; &#xff08;1&#xff09;sizeof 返回的值类型是 size_t&#xff0c;这是一…

书生大模型实战营6

文章目录 L1——基础岛玩转书生「多模态对话」与「AI搜索」产品MindSearch 开源的 AI 搜索引擎书生浦语 InternLM 开源模型官方的对话类产品书生万象 InternVL 开源的视觉语言模型官方的对话产品在知乎上的提交 L1——基础岛 玩转书生「多模态对话」与「AI搜索」产品 MindSea…

three.js+WebGL踩坑经验合集(6.1):负缩放,负定矩阵和行列式的关系(2D版本)

春节忙完一轮&#xff0c;总算可以继续来写博客了。希望在春节假期结束之前能多更新几篇。 这一篇会偏理论多一点。笔者本没打算在这一系列里面重点讲理论&#xff0c;所以像相机矩阵推导这种网上已经很多优质文章的内容&#xff0c;笔者就一笔带过。 然而关于负缩放&#xf…

Baklib解析内容中台与人工智能技术带来的价值与机遇

内容概要 在数字化转型的浪潮中&#xff0c;内容中台与人工智能技术的结合为企业提供了前所未有的发展机遇。内容中台作为一种新的内容管理和生产模式&#xff0c;通过统一管理和协调各种内容资源&#xff0c;帮助企业更高效地整合内外部数据。而人工智能技术则以其强大的数据…

Learning Vue 读书笔记 Chapter 4

4.1 Vue中的嵌套组件和数据流 我们将嵌套的组件称为子组件&#xff0c;而包含它们的组件则称为它们的父组件。 父组件可以通过 props向子组件传递数据&#xff0c;而子组件则可以通过自定义事件&#xff08;emits&#xff09;向父组件发送事件。 4.1.1 使用Props向子组件传递…

小程序电商运营内容真实性增强策略及开源链动2+1模式AI智能名片S2B2C商城系统源码的应用探索

摘要&#xff1a;随着互联网技术的不断发展&#xff0c;小程序电商已成为现代商业的重要组成部分。然而&#xff0c;如何在竞争激烈的市场中增强小程序内容的真实性&#xff0c;提高用户信任度&#xff0c;成为电商运营者面临的一大挑战。本文首先探讨了通过图片、视频等方式增…

【游戏设计原理】96 - 成就感

成就感是玩家体验的核心&#xff0c;它来自完成一件让自己满意的任务&#xff0c;而这种任务通常需要一定的努力和挑战。游戏设计师的目标是通过合理设计任务&#xff0c;不断为玩家提供成就感&#xff0c;保持他们的参与热情。 ARCS行为模式&#xff08;注意力、关联性、自信…

Linux系统上安装与配置 MySQL( CentOS 7 )

目录 1. 下载并安装 MySQL 官方 Yum Repository 2. 启动 MySQL 并查看运行状态 3. 找到 root 用户的初始密码 4. 修改 root 用户密码 5. 设置允许远程登录 6. 在云服务器配置 MySQL 端口 7. 关闭防火墙 8. 解决密码错误的问题 前言 在 Linux 服务器上安装并配置 MySQL …

读书笔记-《Redis设计与实现》(一)数据结构与对象(下)

各位朋友新年快乐~ 今天我们来继续学习 Redis 。 01 整数集合 当集合仅包含整数值&#xff0c;并且元素数量不多时&#xff0c;Redis 就会采用整数集合来作为集合键的底层实现。 typedef struct intset {// 编码方式uint32_t encoding;// 元素数量uint32_t length;// 数组in…

IP服务模型

1. IP数据报 IP数据报中除了包含需要传输的数据外&#xff0c;还包括目标终端的IP地址和发送终端的IP地址。 数据报通过网络从一台路由器跳到另一台路由器&#xff0c;一路从IP源地址传递到IP目标地址。每个路由器都包含一个转发表&#xff0c;该表告诉它在匹配到特定目标地址…

上海亚商投顾:沪指冲高回落 大金融板块全天强势 上海亚商投

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 一&#xff0e;市场情绪 市场全天冲高回落&#xff0c;深成指、创业板指午后翻绿。大金融板块全天强势&#xff0c;天茂集团…

数据分析系列--④RapidMiner进行关联分析(案例)

一、核心概念 1.项集&#xff08;Itemset&#xff09; 2.规则&#xff08;Rule&#xff09; 3.支持度&#xff08;Support&#xff09; 3.1 支持度的定义 3.2 支持度的意义 3.3 支持度的应用 3.4 支持度的示例 3.5 支持度的调整 3.6 支持度与其他指标的关系 4.置信度&#xff0…

HTB靶场Adminstrator

文章目录 靶机信息域环境初步信息收集与权限验证FTP 登录尝试SMB 枚举尝试WinRM 登录olivia域用户枚举 获取Michael权限BloodHound 提取域信息GenericAll 获取Benjamin权限ForceChangePasswordftp登录benjamin 获取Emily权限pwsafehashcat 获取Ethan权限获取管理员(Administrat…

C语言指针专题三 -- 指针数组

目录 1. 指针数组的核心原理 2. 指针数组与二维数组的区别 3. 编程实例 4. 常见陷阱与防御 5. 总结 1. 指针数组的核心原理 指针数组是一种特殊数组&#xff0c;其所有元素均为指针类型。每个元素存储一个内存地址&#xff0c;可指向不同类型的数据&#xff08;通常指向同…

Spring Boot - 数据库集成06 - 集成ElasticSearch

Spring boot 集成 ElasticSearch 文章目录 Spring boot 集成 ElasticSearch一&#xff1a;前置工作1&#xff1a;项目搭建和依赖导入2&#xff1a;客户端连接相关构建3&#xff1a;实体类相关注解配置说明 二&#xff1a;客户端client相关操作说明1&#xff1a;检索流程1.1&…

深入理解MySQL 的 索引

索引是一种用来快速检索数据的一种结构, 索引使用的好不好关系到对应的数据库性能方面, 这篇文章我们就来详细的介绍一下数据库的索引。 1. 页面的大小: B 树索引是一种 Key-Value 结构&#xff0c;通过 Key 可以快速查找到对应的 Value。B 树索引由根页面&#xff08;Root&am…

vue之pinia组件的使用

1、搭建pinia环境 cnpm i pinia #安装pinia的组件 cnpm i nanoid #唯一id&#xff0c;相当于uuid cnpm install axios #网络请求组件 2、存储读取数据 存储数据 >> Count.ts文件import {defineStore} from piniaexport const useCountStore defineStore(count,{// a…