神经网络的工程基础(二)——随机梯度下降法|文末送书

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
本文涉及到的代码链接如下:regression2chatgpt/ch06_optimizer/stochastic_gradient_descent.ipynb

本文将讨论利用PyTorch实现随机梯度下降法的细节。

关于大语言模型的内容,推荐参考这个专栏。

内容大纲

  • 相关说明
  • 一、随机梯度下降法:更优化的算法
  • 二、算法细节
  • 三、代码实现
  • 四、粉丝福利

一、随机梯度下降法:更优化的算法

梯度下降法虽然在理论上很美好,但在实际应用中常常会碰到瓶颈。为了说明这个问题,令 L i L_i Li表示模型在 i i i点的损失,即 L i = ( y i − a x i − b ) 2 L_i = (y_i - ax_i - b)^2 Li=(yiaxib)2,对所有数据点的损失求和后,可以得到整体损失函数: L = 1 ⁄ n ∑ i L i L = 1⁄n ∑_iL_i L=1⁄niLi 。即模型的损失函数实际上是各个数据点损失的平均值,这一观点适用于大多数模型 1

计算整体损失函数 L L L的梯度可得, ∇ L = 1 ⁄ n ∑ i L i ∇L = 1⁄n \sum_i L_i L=1⁄niLi。也就是说,损失函数的梯度等于所有数据点处梯度的平均值。但是在实际应用中,通常会使用大型数据集计算所有数据点的梯度和,这需要相当长的时间。为了加速这个计算过程,可以考虑使用随机梯度下降法(Stochastic Gradient Descent,SGD)。

二、算法细节

随机梯度下降法的核心思想是:每次迭代时只随机选择小批量的数据点来计算梯度,然后用这个小批量数据点的梯度平均值来代替整体损失函数的梯度2

为了使算法的细节更加准确,引入一个超参数,称为批量大小(Batch Size),记作m。每次随机选取m个数据,记为 I 1 , I 2 , ⋯ , I m I_1,I_2,⋯,I_m I1,I2,,Im。使用这些数据点的梯度平均值来近似代替整体损失函数的梯度: ∇ L = 1 ⁄ n ∑ i ∇ L i ≈ 1 ⁄ m ∑ j = 1 m ∇ L I j ∇L = 1⁄n ∑_i∇L_i ≈ 1⁄m ∑_{j = 1}^m∇L_{I_j } L=1⁄niLi1⁄mj=1mLIj 。由此得到新的参数迭代公式:
a k + 1 = a k − η / m ∑ j = 1 m ∂ L I j / ∂ a b k + 1 = b k − η / m ∑ j = 1 m ∂ L I j / ∂ b (1) a_{k + 1} = a_k - η/m ∑_{j = 1}^m∂L_{I_j }/∂a \\ b_{k + 1} = b_k -η/m ∑_{j = 1}^m∂L_{I_j }/∂b \tag{1} ak+1=akη/mj=1mLIj/abk+1=bkη/mj=1mLIj/b(1)

在随机梯度下降法中,所有数据点都使用了一遍,称为模型训练了一轮。由此在实际应用中常使用另一个超参数——训练轮次(Epoch),表示所有数据将被用几遍,用于控制随机梯度下降法的循环次数。换句话说,就是公式(1)被迭代运算多少次。

在一些机器学习书籍和学术文献中,还对随机梯度下降法(当m=1时)和小批量梯度下降法(当m>1时)进行了进一步的区分。然而,这两种方法之间的区别并不大,其核心思想都是基于随机采样来近似计算梯度,从而高效地更新参数、优化模型。在实际应用中,会根据问题的性质和数据规模选择合适的批次大小,以获得最佳的训练效果。因此,本书将统一使用随机梯度下降法来代表这一类方法,以保持概念清晰和简洁。

与梯度下降法相比,随机梯度下降法更高效,这是因为小批量梯度计算比整体梯度计算快得多。尽管在随机梯度下降法中,采用小批量数据估计梯度可能会引入一些噪声,但实践证明这些噪声对整个优化过程有好处,有助于模型克服局部最优的“陷阱”,逐步逼近全局最优参数。

三、代码实现

随机梯度下降法的实现与梯度下降法类似,不同之处在于,每次计算梯度时需要“随机”选取一部分数据,具体的实现步骤可以参考程序清单1(完整代码)。

  1. 在程序清单1的第2行,引入一个名为batch_size的超参数,用于控制每个批次中的数据量大小。选择合适的batch_size对算法的运行效率和稳定性至关重要。如果参数设置过大,可能会导致算法运行效率下降;而过小的参数可能使算法变得过于随机,影响收敛的稳定性。选择合适的参数需要结合具体的模型和应用场景,结合相关领域的经验进行决策。
  2. 在程序清单1的第11—13行,展示了一种随机选取批次数据的实现方式。这也是随机梯度下降法与普通梯度下降法的主要区别之一。实现随机性的方式有很多种,比如引入随机数等。这里仅呈现一种经典方法:将数据按顺序划分成批次。
程序清单1 随机梯度下降法
 1 |  # 定义每批次用到的数据量
 2 |  batch_size = 20
 3 |  # 定义模型
 4 |  model = Linear()
 5 |  # 确定最优化算法
 6 |  learning_rate = 0.1
 7 |  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
 8 |  
 9 |  for t in range(20):
10 |      # 选取当前批次的数据,用于训练模型
11 |      ix = (t * batch_size) % len(x)
12 |      xx = x[ix: ix + batch_size]
13 |      yy = y[ix: ix + batch_size]
14 |      yy_pred = model(xx)
15 |      # 计算当前批次数据的损失
16 |      loss = (yy - yy_pred).pow(2).mean()
17 |      # 将上一次的梯度清零
18 |      optimizer.zero_grad()
19 |      # 计算损失函数的梯度
20 |      loss.backward()
21 |      # 迭代更新模型参数的估计值
22 |      optimizer.step()
23 |      # 注意!loss记录的是模型在当前批次数据上的损失,该数值的波动较大
24 |      print(f'Step {t + 1}, Loss: {loss: .2f}; Result: {model.string()}')

在随机梯度下降法的执行过程中,通常使用模型的整体损失作为指标来监测算法的运行情况。但要注意的是,程序清单1中第16行定义的loss表示模型在小批量数据上的损失,这个值仅依赖于少量数据,迭代过程中会表现出极大的不稳定性,因此并不适合作为评估算法运行情况的主要标志。

如果希望更准确地监测算法的运行情况,需要在更大的数据集上估计模型的整体损失,例如在全部训练数据上计算损失,如图1所示。这种评估方式更稳定,能够更全面地反映模型的训练进展。

图1

图1

四、粉丝福利

参与方式:关注博主、点赞、收藏、评论区评论“解构大语言模型”(切记要点赞+收藏,否则抽奖无效,每个人最多评论三次!)
本次送书数量不少于3本,【阅读量越多,送得越多】
活动结束后,会私信中奖粉丝,请各位注意查看私信哦~

活动截止时间:2024-05-24 24:00:00


  1. 对于解决回归问题的模型,这个结论显然成立。对于解决分类问题的模型(比如逻辑回归模型),只需对模型的似然函数做简单的数学变换(先求对数,再求相反数),就可以得到同样的结论。 ↩︎

  2. 这在数学上是完全合理的。从统计的角度来看,用所有数据点求平均值,并不比随机抽样的方法高明很多。与线性回归参数估计值类似,两个结果都是随机变量:它们都以真实梯度为期望,只是前者的置信区间更小。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/669128.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux下CPU1000%记一次挖矿病毒清理流程

今天top后发现一个进程CPU高1795%,判断是病毒 查找进程ps -elf|grep 进程idpid和ppid查找到sleep进程 ps -ef|grep 4277 查看具体进程内容,ll /proc/进程idpid ll /proc/4277 ls -l /proc/{pid号} ls -l /proc/{pid号}/exe kill掉病毒进程 排查病毒…

Android开机动画的结束过程BootAnimation(基于Android10.0.0-r41)

文章目录 Android 开机动画的结束过程BootAnimation(基于Android10.0.0-r41) Android 开机动画的结束过程BootAnimation(基于Android10.0.0-r41) 路径frameworks/base/cmds/bootanimation/bootanimation_main.cpp init进程把我们的BootAnimation的二进制文件拉起来了&#xf…

如何备份RDK X3(旭日X3派)的 SD卡镜像

该方法可以在Ubuntu的开发机上生成一个img镜像,后续可以直接使用rufus软件烧录备份好的镜像。 Step 1:在Ubuntu的开发机上安装gparted软件 如果安装失败则需要为您的Ubuntu开发机换源,这里推荐阿里源:https://developer.aliyun.…

用户购物性别模型标签(USG)之决策树模型

一、USG模型引入: 首先了解一下,如何通过大数据来确定用户的真实性别, 经常谈论的用户精细化运营,到底是什么? 简单来讲,就是将网站的每个用户标签化,制作一个属于用户自己的网络身份证。然后,运营人员 通…

【算法】贪心算法简介

贪心算法概述 目录 1.贪心算法概念2.贪心算法特点3.贪心算法学习 1.贪心算法概念 贪心算法是一种 “思想” ,即解决问题时从 “局部最优” 从而达到 “全局最优” 的效果。 ①把解决问题的过程分为若干步②解决每一步时候,都选择当前最优解(不关注全局…

基于SSM前后端分离版本的论坛系统-自动化测试

目录 前言 一、测试环境 二、环境部署 三、测试用例 四、执行测试 4.1、公共类设计 创建浏览器驱动对象 测试套件 释放驱动类 4.2、功能测试 注册页面 登录页面 版块 帖子 用户个人中心页 站内信 4.3、界面测试 注册页面 登录页面 版块 帖子 用户个人中心页…

数据整理操作及众所周知【数据分析】

各位大佬好 ,这里是阿川的博客,祝您变得更强 个人主页:在线OJ的阿川 大佬的支持和鼓励,将是我成长路上最大的动力 阿川水平有限,如有错误,欢迎大佬指正 Python 初阶 Python–语言基础与由来介绍 Python–…

进程与线程(二)

进程与线程(二) exec函数族守护进程守护进程的概念Linux守护进程的编写步骤创建子进程,父进程退出在子进程中创建新会话改变当前目录为根目录重设文件权限掩码关闭文件描述符守护进程案例编写 线程线程的概念Linux下进程和线程的区别 Linux下…

C++基础编程100题-002 OpenJudge-1.1-04 输出保留3位小数的浮点数

更多资源请关注纽扣编程微信公众号 002 OpenJudge-1.1-04 输出保留3位小数的浮点数 http://noi.openjudge.cn/ch0101/04/ 描述 读入一个单精度浮点数,保留3位小数输出这个浮点数。 输入 只有一行,一个单精度浮点数。 输出 也只有一行,…

java —— 集合

一、集合的概念 集合可以看做是一个存储对象的容器,与数组不同的是集合可以存储不同类型的对象,但开发中一般不这样做。集合不能存储基本类型的对象,如果存储则需要将其转化为对应的包装类。 二、集合的分类 集合分为 Collection 和 Map 两…

2024抖音流量认知课:掌握流量底层逻辑,明白应该选择什么赛道 (43节课)

课程下载:https://download.csdn.net/download/m0_66047725/89360865 更多资源下载:关注我。 课程目录 01序言:拍前请看.mp4 02抖音建模逻辑1.mp4 03抖音标签逻辑2.mp4 04抖音推流逻辑3.mp4 05抖音起号逻辑4.mp4 06养号的意义.mp4 0…

算法解析——单身狗问题

欢迎来到博主的专栏:算法解析 博主ID代码小豪 文章目录 什么是单身狗问题leetcode_136——只出现一次的数字I使用位运算解决单身狗问题。 leetcode_137——只出现一次的数字II统计二进制数解决单身狗问题leetcode_260 只出现一次数字III分区域按位异或解决问题。 总…

iperf3带宽压测工具使用

iperf3带宽压测工具使用 安装下载地址:[下载入口](https://iperf.fr/iperf-download.php)测试结果:时长测试(压测使用):并行测试反向测试UDP 带宽测试 iPerf3 是用于主动测试 IP 网络上最大可用带宽的工具 安装 下载地址&#x…

SAP 生产订单批量报工(代码分享)

最近公司一直在对成本这块的业务进行梳理,影响比较大的就是生产这块的报工,经常会要求要批量的冲销报工,然后在继续报工,来调整生产订单的实际工时,前面的博客中已经给大家分享了批量冲销生产订单的代码, 下面给大家分享一下生产订单批量报工的代码 首先流程制造和离散制…

如何解决研发数据传输层面安全可控、可追溯的共性需求?

研发数据在企业内部跨网文件交换,是相对较为普遍而频繁的文件流转需求,基于国家法律法规要求及自身安全管理需要,许多企业进行内部网络隔离。不同企业隔离方案各不相同,比如银行内部将网络隔离为生产网、办公网、DMZ区&#xff0c…

如何用ChatGPT上热门:完整使用教程与写作技巧

1. ChatGPT概述修订 ChatGPT是一款基于深度神经网络的语言生成技术,能够协助用户创造出各类高品质的文字材料,适宜广泛的应用场景,如编撰文章、文学创作及社交媒体内容生成。 2. 利用ChatGPT生成热门内容的基本步骤 为了有效利用ChatGPT创作…

代码随想录算法训练营第三十五 | ● 860.柠檬水找零 ● 406.根据身高重建队列 ● 452. 用最少数量的箭引爆气球

860.柠檬水找零 讲解链接:https://programmercarl.com/0860.%E6%9F%A0%E6%AA%AC%E6%B0%B4%E6%89%BE%E9%9B%B6.html 本题只有5元10元20元,只需要考虑收到5、10、20这三种情况; 收到5元,五块的个数; 收到10,找…

Appium安装及配置(Windows环境)

在做app相关自动化测试,需要使用appium来做中转操作,下面来介绍一下appium的环境安装配置 appium官方文档:欢迎 - Appium Documentation 一、下载appium 下载地址:https://github.com/appium/appium-desktop/releases?page3 通…

Java多线程--volatile关键字

并发编程的三大特性 可见性有序性原子性 可见性 为什么会有可见性问题? 多核CPU 为了提升CPU效率,设计了L1,L2,L3三级缓存,如图。 如果俩个核几乎同时操作同一块内存,cpu1修改完,当下是对c…

TDR原理的介绍

目录 简介 简单定义 TDR测试原理 简介 时域和频域就像孪生兄弟一样,经常在测试测量领域同时出现,可谓是工程师们分析问题和解决问题的两大法宝。所以,在某些测试场景中,如果有时域信息的护法,咱们就能从时频域两个维…