梯度下降的数学原理:用泰勒公式剖析梯度下降

梯度下降(Gradient Descent)是机器学习中非常核心的优化算法,通过不断调整模型参数,让损失函数(Loss Function)逐渐变小,从而提高模型的性能。损失函数是一个用来衡量预测值与真实值差距的函数,因此最小化损失函数是训练机器学习模型的目标。这篇文章用一个更数学化的视角来剖析梯度下降的原理,特别是通过泰勒公式来解释为什么沿着梯度反方向调整参数能够有效减少损失函数的值。

1. 损失函数与目标

1.1 什么是损失函数?

损失函数 L ( θ ) L(\theta) L(θ) 是模型参数 θ \theta θ 的函数,用于量化模型预测值 y ^ \hat{y} y^ 和真实值 y y y 的差距。以下是两种常见的损失函数:

  • 均方误差(MSE):
    L ( θ ) = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 L(\theta) = \frac{1}{n} \sum_{i=1}^n (y_i - \hat{y}_i)^2 L(θ)=n1i=1n(yiy^i)2 这里, y i y_i yi 是真实值, y ^ i \hat{y}_i y^i 是由模型参数 θ \theta θ 计算得到的预测值。
  • 交叉熵损失(Cross-Entropy Loss):
    L ( θ ) = − 1 n ∑ i = 1 n [ y i log ⁡ y ^ i + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] L(\theta) = -\frac{1}{n} \sum_{i=1}^n \left[ y_i \log \hat{y}_i + (1 - y_i) \log (1 - \hat{y}_i) \right] L(θ)=n1i=1n[yilogy^i+(1yi)log(1y^i)]
    这个损失函数常用于分类任务,特别是二分类问题。

1.2 目标

在机器学习的训练过程中,我们希望找到一组最优参数 θ \theta θ,让损失函数 L ( θ ) L(\theta) L(θ) 达到最小值:
min ⁡ θ L ( θ ) \min_\theta L(\theta) θminL(θ)
直观上,这意味着我们希望模型的预测尽可能接近真实值。

2. 用泰勒公式分析损失函数的变化

为了理解梯度下降背后的数学原理,我们需要借助泰勒公式,来研究参数 θ \theta θ 的调整如何影响损失函数 L ( θ ) L(\theta) L(θ) 的值。

2.1 一阶泰勒展开

假设当前模型参数是 θ t \theta_t θt,损失函数在这一点的值为 L ( θ t ) L(\theta_t) L(θt)。如果将参数从 θ t \theta_t θt 调整为 θ t + 1 \theta_{t+1} θt+1,损失函数的值可以用一阶泰勒展开近似:
L ( θ t + 1 ) ≈ L ( θ t ) + ∇ L ( θ t ) T ( θ t + 1 − θ t ) L(\theta_{t+1}) \approx L(\theta_t) + \nabla L(\theta_t)^T (\theta_{t+1} - \theta_t) L(θt+1)L(θt)+L(θt)T(θt+1θt)

  • L ( θ t ) L(\theta_t) L(θt):当前参数点的损失函数值。
  • ∇ L ( θ t ) \nabla L(\theta_t) L(θt):损失函数在 θ t \theta_t θt 处的梯度,表示损失函数在每个参数维度上的变化率。
  • θ t + 1 − θ t \theta_{t+1} - \theta_t θt+1θt:参数调整的变化量。

这个公式告诉我们,损失函数的变化主要由两部分决定:

  1. 当前的损失函数值 L ( θ t ) L(\theta_t) L(θt)
  2. 参数变化方向和大小对损失函数的贡献,即 ∇ L ( θ t ) T ( θ t + 1 − θ t ) \nabla L(\theta_t)^T (\theta_{t+1} - \theta_t) L(θt)T(θt+1θt)

2.2 梯度的含义梯度

∇ L ( θ t ) \nabla L(\theta_t) L(θt) 是一个向量,包含损失函数对每个参数的偏导数:
∇ L ( θ t ) = ( ∂ L ∂ θ 1 , ∂ L ∂ θ 2 , … , ∂ L ∂ θ n ) \nabla L(\theta_t) = \left( \frac{\partial L}{\partial \theta_1}, \frac{\partial L}{\partial \theta_2}, \dots, \frac{\partial L}{\partial \theta_n} \right) L(θt)=(θ1L,θ2L,,θnL)
几何上,梯度指向损失函数 L ( θ ) L(\theta) L(θ) 增加最快的方向。也就是说,如果我们沿着梯度的方向调整参数,损失函数的值会迅速增大。

3. 梯度下降:让损失函数减小

我们关心的是让损失函数逐步减小,因此需要反其道而行之,沿着梯度的反方向调整参数。

3.1 参数更新公式在梯度下降中,参数的更新公式是:

θ t + 1 = θ t − α ∇ L ( θ t ) \theta_{t+1} = \theta_t - \alpha \nabla L(\theta_t) θt+1=θtαL(θt)

  • α \alpha α 是学习率(Learning Rate),控制参数调整的步长。
  • − ∇ L ( θ t ) -\nabla L(\theta_t) L(θt) 是梯度的反方向,表示损失函数下降最快的方向。这个公式的核心思想很简单:在每一步迭代中,我们根据当前的梯度信息,沿着损失函数下降最快的方向移动一小步,从而逐步逼近损失函数的最小值。

3.2 损失函数的变化将参数更新公式代入泰勒展开式,我们可以进一步研究损失函数的变化:

L ( θ t + 1 ) ≈ L ( θ t ) + ∇ L ( θ t ) T ( − α ∇ L ( θ t ) ) L(\theta_{t+1}) \approx L(\theta_t) + \nabla L(\theta_t)^T (-\alpha \nabla L(\theta_t)) L(θt+1)L(θt)+L(θt)T(αL(θt))
化简后得到:
L ( θ t + 1 ) ≈ L ( θ t ) − α ∇ L ( θ t ) 2 L(\theta_{t+1}) \approx L(\theta_t) - \alpha \nabla L(\theta_t)^2 L(θt+1)L(θt)αL(θt)2

  • ∇ L ( θ t ) 2 \nabla L(\theta_t)^2 L(θt)2 是梯度的范数平方,表示梯度的大小。
  • 因为 ∇ L ( θ t ) 2 > 0 \nabla L(\theta_t)^2 > 0 L(θt)2>0 α > 0 \alpha > 0 α>0,可以确定 L ( θ t + 1 ) < L ( θ t ) L(\theta_{t+1}) < L(\theta_t) L(θt+1)<L(θt)
    这说明,每次更新参数后,损失函数的值都会减小。

4. 梯度下降的执行流程

基于上面的分析,梯度下降法的执行流程可以总结为以下几个步骤:

  1. 初始化参数: 随机初始化模型参数 θ \theta θ(或根据经验设置初值)。
  2. 计算梯度: 在当前参数点 θ t \theta_t θt,计算损失函数的梯度 ∇ L ( θ t ) \nabla L(\theta_t) L(θt)
  3. 更新参数: 根据梯度下降公式调整参数: θ t + 1 = θ t − α ∇ L ( θ t ) \theta_{t+1} = \theta_t - \alpha \nabla L(\theta_t) θt+1=θtαL(θt)
  4. 判断收敛条件: 如果梯度的范数 ∇ L ( θ t ) \nabla L(\theta_t) L(θt) 足够小,或者损失函数的变化量小于设定阈值,则停止迭代;否则回到第 2 步。

5. 梯度下降的核心思想

通过这个分析,我们可以直观理解梯度下降法的核心思想:

  1. 梯度的方向性: 梯度 ∇ L ( θ ) \nabla L(\theta) L(θ) 指向损失函数 L ( θ ) L(\theta) L(θ) 增加最快的方向,沿着反方向调整参数可以快速减少损失。
  2. 学习率的重要性: 学习率 α \alpha α 控制了每次调整的步长,步长过大可能导致不稳定,过小则会收敛缓慢。
  3. 迭代收敛: 通过逐步调整参数,梯度下降法让损失函数值逐步减小,最终逼近最优解。
  4. 为什么必须减去梯度的值? 可以理解为是泰勒公式本身的要求,因为展开的第二项就带有一个一阶导数

在最后会详细解释


梯度下降法作为机器学习中最基础的优化算法之一,其数学原理非常直观,但效果却极为强大。通过梯度下降,我们可以高效地找到损失函数的最优解,从而训练出性能优秀的机器学习模型。


详细解释:为什么要让 θ \theta θ 减去 α \alpha α 倍梯度?

从泰勒展开到梯度下降:为什么要让 θ \theta θ 减去 α \alpha α 倍梯度?

在机器学习的优化过程中,我们常常使用梯度下降法(Gradient Descent)来最小化损失函数 L ( θ ) L(\theta) L(θ)。下面通过泰勒展开的思路,解释为什么选择
θ ← θ − α ∇ L ( θ ) \theta\leftarrow \theta - \alpha \nabla L(\theta) θθαL(θ)
能够保证损失函数持续下降。

我们先从泰勒公式出发,通过分析损失函数在当前参数附近的局部线性变化,发现如果让参数在梯度反方向移动,就能够在每一步迭代里最大程度地让损失函数减小。然后,才自然地得到 θ \theta θ的更新规则是「减去 α \alpha α倍的梯度」。


1. 从泰勒公式出发

在参数 θ \theta θ 附近,对损失函数 L ( θ ) L(\theta) L(θ) 做一阶泰勒展开:
L ( θ + Δ θ ) ≈ L ( θ ) + ∇ L ( θ ) T Δ θ . L(\theta + \Delta\theta) \approx L(\theta) + \nabla L(\theta)^T \Delta\theta. L(θ+Δθ)L(θ)+L(θ)TΔθ.

  • ∇ L ( θ ) \nabla L(\theta) L(θ) 是损失函数在 θ \theta θ 处的梯度向量。
  • Δ θ \Delta\theta Δθ 是参数改变的量。我们的目标是让 L ( θ + Δ θ ) < L ( θ ) L(\theta + \Delta\theta) < L(\theta) L(θ+Δθ)<L(θ),因此需要
    ∇ L ( θ ) T   Δ θ < 0. \nabla L(\theta)^T \,\Delta\theta < 0. L(θ)TΔθ<0.

2. 梯度反方向:最快下降的方向

  1. 梯度方向: ∇ L ( θ ) \nabla L(\theta) L(θ) 指向 L ( θ ) L(\theta) L(θ) 增长(上升)最快的方向。
  2. 反梯度方向: − ∇ L ( θ ) -\nabla L(\theta) L(θ) 则是 L ( θ ) L(\theta) L(θ) 减少(下降)最快的方向。如果令
    Δ θ = − α   ∇ L ( θ ) , \Delta\theta = -\alpha \,\nabla L(\theta), Δθ=αL(θ),
    那么
    ∇ L ( θ ) T   Δ θ = ∇ L ( θ ) T ( − α ∇ L ( θ ) ) = −   α   ∇ L ( θ ) 2 , \nabla L(\theta)^T \,\Delta\theta = \nabla L(\theta)^T \bigl(-\alpha \nabla L(\theta)\bigr)= -\,\alpha\,\nabla L(\theta)^2, L(θ)TΔθ=L(θ)T(αL(θ))=αL(θ)2,
    由于 ∇ L ( θ ) 2 > 0 \nabla L(\theta)^2 > 0 L(θ)2>0 α > 0 \alpha > 0 α>0,所以该结果为负值,即
    L ( θ + Δ θ ) ≈ L ( θ ) − α   ∇ L ( θ ) 2 < L ( θ ) . L(\theta + \Delta\theta) \approx L(\theta) - \alpha \,\nabla L(\theta)^2 < L(\theta). L(θ+Δθ)L(θ)αL(θ)2<L(θ).
    这表示我们在当前点处让损失函数确实向下移动了一步。

3. 推导更新公式

为了让损失函数一步步地下降,我们可以把上述“在反梯度方向移动”写成一个迭代式子:
Δ θ = − α   ∇ L ( θ ) , \Delta\theta = -\alpha \,\nabla L(\theta), Δθ=αL(θ),
Δ θ = θ t + 1 − θ t \Delta\theta = \theta_{t+1} - \theta_t Δθ=θt+1θt
从而有:
θ t + 1 = θ t − α   ∇ L ( θ t ) . \theta_{t+1} = \theta_t - \alpha \,\nabla L(\theta_t). θt+1=θtαL(θt).

  • θ t \theta_t θt 表示在第 t t t 步的参数值;
  • α \alpha α 是学习率(Learning Rate),用于控制移动的步长。在每一步迭代中,我们根据当前参数 θ t \theta_t θt 的梯度 ∇ L ( θ t ) \nabla L(\theta_t) L(θt),沿着反梯度方向走一小步,使 L ( θ ) L(\theta) L(θ) 得到下降。

4. 结论

  1. 先有泰勒展开的分析:
    通过局部线性化,明确看到 − ∇ L ( θ ) -\nabla L(\theta) L(θ) 能带来最大的负变化,让损失快速下降。
  2. 再有梯度下降的更新公式:
    使用 θ ← θ − α ∇ L ( θ ) \theta \leftarrow \theta - \alpha \nabla L(\theta) θθαL(θ),逐步向使损失函数减小的方向移动,直到达到最优或近似最优解。因此,正是因为在泰勒公式的分析中我们发现“减去 α \alpha α 倍的梯度”能够让损失在每一步减少最多,才推导出参数更新规则中要让 θ \theta θ 减去 α \alpha α 倍的梯度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/940525.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java性能调优 - JVM性能监测及调优

JVM 内存模型概述 堆 堆是JVM内存中最大的一块内存空间&#xff0c;该内存被所有线程共享&#xff0c;几乎所有对象和数组都被分配到了堆内存中。堆被划分为新生代和老年代&#xff0c;新生代又被进一步划分为Eden和Survivor区&#xff0c;最后Survivor由From Survivor和To Su…

RK3588 , mpp硬编码yuv, 保存MP4视频文件.

RK3588 , mpp硬编码yuv, 保存MP4视频文件. ⚡️ 传送 ➡️ Ubuntu x64 架构, 交叉编译aarch64 FFmpeg mppRK3588, FFmpeg 拉流 RTSP, mpp 硬解码转RGBRk3588 FFmpeg 拉流 RTSP, 硬解码转RGBRK3588 , mpp硬编码yuv, 保存MP4视频文件.

【计算机网络2】计算机网络的性能能指标

目录 一 、计算机网络的性能指标 二、具体介绍 1、速 率 2、带 宽 3、吞 吐 量 4、时 延 5、时延带宽积 6、往 返 时 延 7、信道利用率 一 、计算机网络的性能指标 计算机网络的性能指标就是从不同方面度量计算机网络的性能&#xff0c;有如下7个指标&#xff1a; 速…

OpenAI 12天发布会(12 Days of OpenAI)总结

在OpenAI的“12 Days of OpenAI”活动中&#xff0c;每一天都会发布新的功能或技术&#xff0c;展示公司在AI领域的最新进展。首先展示下全部功能发布完成后&#xff0c;现在ChatGPT的界面&#xff1a; 以下是每一天的简要概述及其意义&#xff1a; 第1天 - 完整版O1模型 今天…

android:sharedUserId 应用进程声明介绍

背景 adb install 安装系统软件报错,原因是签名不一致,进程改变。 代码分析 AndroidManifest.xml 定义的 android:sharedUserId 应用归属进程不同,从phone切换到system。 初始配置 <manifest xmlns:android="http://schemas.android.com/apk/res/android"c…

Spark优化----Spark 性能调优

目录 常规性能调优 常规性能调优一&#xff1a;最优资源配置 常规性能调优二&#xff1a;RDD 优化 RDD 复用 RDD 持久化 RDD 尽可能早的 filter 操作 常规性能调优三&#xff1a;并行度调节 常规性能调优四&#xff1a;广播大变量 常规性能调优五&#xff1a;Kryo 序列化 常规性…

Android Room 数据库使用详解

一、Room介绍 Android Room 是 Google 提供的一个 Android 数据持久化库&#xff0c;是 Android Jetpack 组成部分之一。它提供了一个抽象层&#xff0c;使得 SQLite 数据库的使用更为便捷。通过 Room&#xff0c;开发者可以轻松地操作数据库&#xff0c;不需要直接编写繁琐的…

数据结构十大排序之(冒泡,快排,并归)

接上期&#xff1a; 数据结十大排序之&#xff08;选排&#xff0c;希尔&#xff0c;插排&#xff0c;堆排&#xff09;-CSDN博客 前言&#xff1a; 在计算机科学中&#xff0c;排序算法是最基础且最重要的算法之一。无论是大规模数据处理还是日常的小型程序开发&#xff0c;…

【原生js案例】让你的移动页面实现自定义的上拉加载和下拉刷新

目前很多前端UI都是自带有上拉加载和下拉刷新功能,按照官网配置去实现即可,比如原生小程序,vantUI等UI框架,都替我们实现了内部功能。 那如何自己来实现一个上拉加载和下拉刷新的功能? 实现效果 不用浏览器的css滚动条,自定义实现滚动效果 自定义实现滚动,添加上拉加载…

【D3.js in Action 3 精译_046】DIY 实战:在 Observable 平台利用饼图布局函数实现 D3 多个环形图的绘制

当前内容所在位置&#xff1a; 第五章 饼图布局与堆叠布局 ✔️ 5.1 饼图和环形图的创建 ✔️ 5.1.1 准备阶段&#xff08;一&#xff09;5.1.2 饼图布局生成器&#xff08;二&#xff09;5.1.3 圆弧的绘制&#xff08;三&#xff09;5.1.4 数据标签的添加&#xff08;四&#…

【案例80】麒麟操作系统无法使用Uclient访问NC65

问题现象 麒麟操作系统&#xff0c;安装Uclient&#xff0c;添加应用后无法看到登录界面&#xff0c;一直在转圈。 问题分析 进入到Uclient的工作目录 发现在工作目录下&#xff0c;无相关app.log生成。 查看Uclient的main.log发现&#xff0c;有大量的报错与Uclient下的sha…

阿里云OSS批量导出下载地址 OSS批量导出 OSS导出清单

阿里云官方提供的客户端不能批量导出下载地址&#xff0c;阿里云OSS批量导出下载地址 OSS批量导出 OSS导出清单 1、参数配置&#xff1a;填写阿里云的AccessKeyID和AccessKeySecret&#xff0c;在阿里右上角的用户获取 2、选择地域&#xff1a;就是你OSS开的是哪个地方的&…

【CSS in Depth 2 精译_084】第 14 章:CSS 蒙版、形状与剪切概述 + 14.1:CSS 滤镜

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第四部分 视觉增强技术 ✔️【第 14 章 蒙版、形状与剪切】 ✔️ 14.1 滤镜 ✔️ 14.1.1 滤镜的类型 ✔️14.1.2 背景滤镜 ✔️ 14.2 蒙版 文章目录 第 14 章 蒙版、形状与剪切 Masks, shapes, and…

如何高效调试复杂布局?Layout Inspector 的 Toggle Deep Inspect 完全解析

Layout Inspector 是 Android Studio 提供的一个强大工具&#xff0c;用于分析和调试 Android 应用的 UI 布局。前置条件是app是debug调试状态&#xff0c;它允许你在应用运行时实时查看布局层次结构、UI 元素的属性&#xff0c;并帮助你诊断 UI 渲染问题。 Toggle Deep Inspe…

wxpython 基础代码

wxpython 基础代码 import wxclass MyFrame(wx.Frame):def __init__(self):super().__init__(parentNone, title计算器, size(450, 250))panel wx.Panel(self)# panel.SetBackgroundColour(wx.GREEN)self.icon1 wx.Icon(name"test.ico", typewx.BITMAP_TYPE_PNG)se…

2.5 io_uring

io_uring的相关函数接口介绍 io_uring 是 Linux 内核中一种高效的异步 I/O 接口&#xff0c;最早引入于 **Linux 内核 5.1** 版本。它是由 Jens Axboe 开发的&#xff0c;目的是提供更高效的异步 I/O 操作&#xff0c;尤其是相比 epoll 和 aio&#xff0c;io_uring 减少了系统…

服务器数据恢复—V7000存储中多块磁盘出现故障导致业务中断的数据恢复案例

服务器存储数据恢复环境&#xff1a; 一台V7000存储上共12块SAS机械硬盘&#xff08;其中1块是热备盘&#xff09;&#xff0c;组建了2组Mdisk&#xff0c;创建了一个pool。挂载在小型机上作为逻辑盘使用&#xff0c;小型机上安装的AIXSybase。 服务器存储故障&#xff1a; V7…

python 读取win7 win10本机ipv6 地址转发到电邮(备份)

python 版本&#xff1a; 3.8.10 用于外网查询SMB服务器ipv6 地址。服务器定时查询本机ipv6地址&#xff0c;如地址变动则用电邮发送新地址。 import smtplib from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart import ssl import socket…

多个JAVA环境变量安装配置

在做java代码审计时&#xff0c;为了要成功运行目标环境&#xff0c;时长要对于jdk版进行切换&#xff0c;且在装多个jdk时还时长会遇到安装配置后环境变量不生效的情况&#xff0c;下文介绍&#xff1b; 1、为什么安装了新的jdk&#xff0c;有的时候环境变量中的jdk版本确还是…

数字经济下的 AR 眼镜

目录 1. &#x1f4c2; AR 眼镜发展历史 1.1 AR 眼镜相关概念 1.2 市面主流 XR 眼镜 1.3 AR 眼镜大事记 1.4 国内外 XR 眼镜 1.5 国内 AR 眼镜四小龙 2. &#x1f531; 关键技术 2.1 AR 眼镜近眼显示原理 2.2 AR 眼镜关键技术 2.3 AR 眼镜技术难点 3. &#x1f4a…