深入探讨梯度下降:优化机器学习的关键步骤(一)

文章目录

  • 🍀引言
  • 🍀什么是梯度下降?
  • 🍀损失函数
  • 🍀梯度(gradient)
  • 🍀梯度下降的工作原理
  • 🍀梯度下降的变种
    • 🍀随机梯度下降(SGD)
    • 🍀批量梯度下降(BGD)
    • 🍀小批量梯度下降(Mini-Batch GD)
  • 🍀如何选择学习率?
  • 🍀梯度下降的相关数学公式
  • 🍀梯度下降的实现(代码)
  • 🍀总结

🍀引言

在机器学习领域,梯度下降是一种核心的优化算法,它被广泛应用于训练神经网络、线性回归和其他机器学习模型中。本文将深入探讨梯度下降的工作原理,并且进行简单的代码实现


🍀什么是梯度下降?

梯度下降是一种迭代优化算法,旨在寻找函数的局部最小值(或最大值)以最小化(或最大化)一个损失函数。在机器学习中,我们通常使用梯度下降来最小化模型的损失函数,以便训练模型的参数。
这里顺便提一嘴,与梯度下降齐名的梯度上升算法目的是使效用函数最大。


🍀损失函数

在使用梯度下降之前,我们首先需要定义一个损失函数。损失函数是一个用于衡量模型预测值与实际观测值之间差异的函数。通常,我们使用均方误差(MSE)作为回归问题的损失函数,使用交叉熵作为分类问题的损失函数。


🍀梯度(gradient)

梯度是损失函数相对于模型参数的偏导数。它告诉我们如果稍微调整模型参数,损失函数会如何变化。梯度下降算法利用梯度的信息来不断调整参数,以减小损失函数的值。

🍀梯度下降的工作原理

梯度下降的核心思想是沿着损失函数的负梯度方向调整参数,直到达到损失函数的局部最小值。具体来说,梯度下降的步骤如下:

  • 初始化模型参数:首先,随机初始化模型参数或使用某种启发式方法。

  • 计算损失和梯度:使用当前模型参数计算损失函数的值,并计算损失函数相对于参数的梯度。

  • 参数更新:根据梯度的方向和学习率(learning rate)本文我称其为eta,更新模型参数。学习率是一个控制步长大小的超参数,它决定了每次迭代中参数更新的大小。

  • 重复迭代:重复步骤2和3,直到损失函数的值收敛到一个稳定的值,或达到预定的迭代次数。

🍀梯度下降的变种

在梯度下降的基础上,发展出了多种变种算法,以应对不同的问题和挑战。其中一些常见的包括

🍀随机梯度下降(SGD)

随机梯度下降每次只使用一个随机样本来估计梯度,从而加速收敛速度。它特别适用于大规模数据集和在线学习。

🍀批量梯度下降(BGD)

批量梯度下降在每次迭代中使用整个训练数据集来计算梯度。尽管计算开销较大,但通常能够更稳定地收敛到全局最小值。

🍀小批量梯度下降(Mini-Batch GD)

小批量梯度下降综合了SGD和BGD的优点,它使用一个小批量样本来估计梯度,平衡了计算效率和收敛性能。

🍀如何选择学习率?

学习率是梯度下降的关键超参数之一。选择合适的学习率可以加速收敛,但过大的学习率可能导致不稳定的训练过程。通常,我们可以采用以下方法选择学习率:

  • 网格搜索:尝试不同的学习率值,通过验证集的性能来选择最佳值。

  • 学习率衰减:开始时使用较大的学习率,随着训练的进行逐渐减小学习率。

  • 自适应学习率:使用自适应学习率算法,如Adam、Adagrad或RMSprop,它们可以自动调整学习率以适应梯度的变化。

🍀梯度下降的相关数学公式

本人数学不好,这里有说的不清楚的地方还请见谅,谢谢佬~
首先我们通过图像认识一下损失函数
在这里插入图片描述
这里的步长指的是,可能有些人会好奇为啥有一个负号呢?因为对称轴左侧的导数都是负值,这里加一个负号不就正了嘛
在这里插入图片描述

具体推导过程请查看相关佬的文章(哭~)

🍀梯度下降的实现(代码)

首先我们导入我们需要的库

import numpy as np
import matplotlib.pyplot as plt

之后我们需要举一个例子,这里我们采用numpy里面的一个分割函数linspace,同时我们举一个函数的例子

plt_x = np.linspace(-1,6,141)
plt_y = (plt_x-2.5)**2-1

之后我们使用show进行展示一下图像

plt.plot(plt_x,ply_y)
plt.show()

运行结果如下
在这里插入图片描述

上图看起来就是一个普通的曲线,方便我们进行理解

接下来我们需要两个函数,一个为了返回导数,一个为了返回对应的y值

def dj(thera):
    return 2*(thera-2.5) # 求导
def j(thera)
    return (thera-2.5)**2-1  # 求对应的值

接下来是梯度下降的关键位置了,这里我们需要初始化两个参数以及一个范围参数,同时设置一个while循环,将前一个thera保存在last_thera中,后一个thera是前一个thera和步长的差值,这里的步长就是梯度个参数eta的乘积,最后使用if函数来终结循环,最终我们将最小值点的值、导数、以及自变量打印出来

eta = 0.1
theta =0.0
epsilon = 1e-8
while True:
    gradient = dj(theta)
    last_theta = theta
    theta = theta-gradient*eta 
    if np.abs(j(theta)-j(last_theta))<epsilon:
        break
        
print(theta)
print(dj(theta))
print(j(theta))

运行结果如下
在这里插入图片描述
这里我们也可以使用列表来看看到底进行了多少次thera的循环

eta = 0.1
theta =0.0
epsilon = 1e-8
theta_history = [theta]
while True:
    gradient = dj(theta)
    last_theta = theta
    theta = theta-gradient*eta 
    theta_history.append(theta)
    if np.abs(j(theta)-j(last_theta))<epsilon:
        break
        
print(theta)
print(dj(theta))
print(j(theta))

len(theta_history)

运行结果如下

在这里插入图片描述
还可以绘制图像进行直观查看

plt.plot(plt_x,plt_y)
plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='*')
plt.show()

运行结果如下
在这里插入图片描述
这样的话就很直观了吧~

🍀总结

本节只介绍梯度下降的简单实现,下节继续学习此法中eta参数的调节

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/101458.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java“牵手”京东商品列表数据,关键词搜索京东商品数据接口,京东API申请指南

京东商城是一个网上购物平台&#xff0c;售卖各类商品&#xff0c;包括服装、鞋类、家居用品、美妆产品、电子产品等。要获取京东商品列表和商品详情页面数据&#xff0c;您可以通过开放平台的接口或者直接访问京东商城的网页来获取商品详情信息。以下是两种常用方法的介绍&…

Linux 学习笔记(1)——系统基本配置与开关机命令

目录 0、起步 0-1&#xff09;命令使用指引 0-2&#xff09;查看历史的命令记录 0-3&#xff09;清空窗口内容 0-4&#xff09;获取本机的内网 IP 地址 0-5&#xff09;获取本机的公网ip地址 0-6&#xff09;在window的命令行窗口中远程连接linux 0-7&#xff09;修改系…

docker安装jenkins

运行jenkins docker run -d \--name jenkins \ --hostname jenkins \-u root \-p 29090:8080 \--restart always \-v D:\springcloud\学习\jekins\jenkins\jks_home:/var/jenkins_home \ jenkins/jenkins获取root登录密码 密码在jekins_home/secrets/initalAdminPassword文件…

设计模式—原型模式(Prototype)

目录 一、什么是原型模式&#xff1f; 二、原型模式具有什么优缺点吗&#xff1f; 三、有什么缺点&#xff1f; 四、什么时候用原型模式&#xff1f; 五、代码展示 ①、简历代码初步实现 ②、原型模式 ③、简历的原型实现 ④、深复制 ⑤、浅复制 一、什么是原型模式&…

Ubuntu学习---跟着绍发学linux课程记录(第二部分)

文章目录 7 文件权限7.1 文件的权限7.2 修改文件权限7.3 修改文件的属主 8、可执行脚本8.2Shell脚本8.3python脚本的创建 9Shell9.1Shell中的变量9.2 环境变量9.3用户环境变量 学习链接: Ubuntu 21.04乌班图 Linux使用教程_60集Linux课程 所有资料在 http://afanihao.cn/java …

单调递增的数字【贪心算法】

单调递增的数字 当且仅当每个相邻位数上的数字 x 和 y 满足 x < y 时&#xff0c;我们称这个整数是单调递增的。 给定一个整数 n &#xff0c;返回 小于或等于 n 的最大数字&#xff0c;且数字呈 单调递增 。 public class Solution {public int monotoneIncreasingDigits…

stm32---用外部中断实现红外接收器

一、红外遥控的原理 红外遥控是一种无线、非接触控制技术&#xff0c;具有抗干扰能力强&#xff0c;信息传 输可靠&#xff0c;功耗低&#xff0c;成本低&#xff0c;易实现等显著优点&#xff0c;被诸多电子设备特别是 家用电器广泛采用&#xff0c;并越来越多的应用到计算机系…

【USRP】调制解调系列6:16APSK、32APSK 、基于labview的实现

APSK APSK是&#xff0c;与传统方型星座QAM&#xff08;如16QAM、64QAM&#xff09;相比&#xff0c;其分布呈中心向外沿半径发散&#xff0c;所以又名星型QAM。与QAM相比&#xff0c;APSK便于实现变速率调制&#xff0c;因而很适合目前根据信道及业务需要分级传输的情况。当然…

音频——I2S DSP 模式(五)

I2S 基本概念飞利浦(I2S)标准模式左(MSB)对齐标准模式右(LSB)对齐标准模式DSP 模式TDM 模式 文章目录 DSP formatDSP A时序图逻辑分析仪抓包 DSP B时序图逻辑分析仪抓包 DSP format DSP/PCMmode 分为 Mode-A 和 Mode-B 共 2 种模式。不同芯⽚有的称为 PCM mode 有的称为 DSP m…

Qt —UDP通信QUdpSocket 简介 +案例

1. UDP通信概述 UDP是无连接、不可靠、面向数据报&#xff08;datagram&#xff09;的协议&#xff0c;可以应用于对可靠性要求不高的场合。与TCP通信不同&#xff0c;UDP通信无需预先建立持久的socket连接&#xff0c;UDP每次发送数据报都需要指定目标地址和端口。 QUdpSocket…

在访问一个网页时弹出的浏览器窗口,如何用selenium 网页自动化解决?

相信大家在使用selenium做网页自动化时&#xff0c;会遇到如下这样的一个场景&#xff1a; 在你使用get访问某一个网址时&#xff0c;会在页面中弹出如上图所示的弹出框。 首先想到是利用Alert类来处理它。 然而&#xff0c;很不幸&#xff0c;Alert类处理的结果就是没有结果…

flink on yarn with kerberos 边缘提交

flink on yarn 带kerberos 远程提交 实现 flink kerberos 配置 先使用ugi进行一次认证正常提交 import com.google.common.io.Files; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.flink.client.cli.CliFrontend; import o…

Matlab(数值微积分)

目录 1.多项式微分与积分 1.1 微分 1.2 多项式微分 1.3 如何正确的使用Matlab? 1.3.1 Matlab表达多项式 1.3.2 polyval() 多项式求值 1.3.3 polyder()多项式微分 1.4 多项式积分 1.4.1 如何正确表达 1.4.2 polyint() 多项式积分 2.数值的微分与积分 2.1 数值微分 2…

django.core.exceptions.AppRegistryNotReady: Apps aren‘t loaded yet.

运行django测试用例报错django.core.exceptions.AppRegistryNotReady: Apps arent loaded yet. 解决&#xff1a;在测试文件上方加上 django.setup() django.setup()是Django框架中的一个函数。它用于在非Django环境下使用Django的各种功能、模型和设置。 在常规的Django应用…

如何中mac上安装多版本python并配置PATH

摘要 mac 默认安装的python是 python3&#xff0c;但是如果我们需要其他python版本时&#xff0c;该怎么办呢&#xff1f; 例如&#xff1a;需要python2 版本&#xff0c;如果使用homebrew安装会提示没有python2。同时使用python --version 会发现commond not found。 所以本…

POI-TL制作word

本文相当于笔记&#xff0c;主要根据官方文档Poi-tl Documentation和poi-tl的使用&#xff08;最全详解&#xff09;_JavaSupeMan的博客-CSDN博客文章进行学习&#xff08;上班够用&#xff09; Data AllArgsConstructor NoArgsConstructor ToString EqualsAndHashCode public …

[杂谈]-2023年实现M2M的技术有哪些?

2023年实现M2M的技术有哪些&#xff1f; 文章目录 2023年实现M2M的技术有哪些&#xff1f;1、寻找连接2、M2M与IoT3、流行的 M2M 协议 在当今的数字世界中&#xff0c;机器对机器 (M2M) 正在迅速成为标准。 M2M 包括使联网设备能够交换数据或信息的任何技术。 它可以是有线或无…

ESLint 中的“ space-before-function-paren ”相关报错及其解决方案

ESLint 中的“ space-before-function-paren ”相关报错及其解决方案 出现的问题及其报错&#xff1a; 在 VScode 中&#xff0c;在使用带有 ESLint 工具的项目中&#xff0c;保存会发现报错&#xff0c;并且修改好代码格式后&#xff0c;保存会发现代码格式依然出现问题&…

HTTP介绍:一文了解什么是HTTP

前言&#xff1a; 在当今数字时代&#xff0c;互联网已经成为人们生活中不可或缺的一部分。无论是浏览网页、发送电子邮件还是在线购物&#xff0c;我们都离不开超文本传输协议&#xff08;HTTP&#xff09;。HTTP作为一种通信协议&#xff0c;扮演着连接客户端和服务器的重要角…

OpenCV(十四):ROI区域截取

在OpenCV中&#xff0c;你可以使用Rect对象或cv::Range来截取图像的感兴趣区域&#xff08;Region of Interest&#xff0c;ROI&#xff09;。 方法一&#xff1a;使用Rect对象截取图像 Rect_(_Tp _x&#xff0c; _Tp _y&#xff0c; _Tp _width,_Tp _height) Tp:数据类型&…