梯度下降(批量梯度下降、随机梯度下降、小批量梯度下降)

在上一篇中我们推导了损失函数 J ( θ ) = 1 2 m ∑ i = 1 m ( y i − h θ ( x i ) ) 2 J(\theta) = \frac{1}{2m} \sum_{i=1}^{m} (y^{i} - h_{\theta}(x^{i}))^2 J(θ)=2m1i=1m(yihθ(xi))2的由来,结尾讲到最小化这个损失函数来找到最优的参数 θ \theta θ,通常是使用梯度下降实现的。

梯度下降广泛用于机器学习和统计建模中的参数估计,特别是在训练线性回归模型时。它的目标是最小化一个损失函数(目标函数),这个函数量化了模型预测和真实数据之间的误差。梯度下降通过迭代地调整模型的参数来减少成本函数的值。

梯度下降的过程类似于一个“盲人下山”的过程:

起始点:就像盲人随机地在山上的某个位置开始,梯度下降算法通常从一个随机的参数值开始。这个起始点可能远离最小值,也就是最低点。

目标值(误差最低值):算法的目标是找到成本函数的最小值,这就好比盲人想要下山到达山谷底部,那里是海拔最低的地方。

获取移动方向:盲人通过脚下的坡度来判断下山的方向,而梯度下降算法通过计算成本函数的梯度来确定下降的方向。梯度告诉我们成本函数上升最快的方向,我们需要往相反方向移动来降低成本函数的值。

控制移动距离(步长):盲人下山时的每一步都不会太大,以免跌倒;同理,梯度下降算法中的学习率决定了每一步下降的距离。学习率太大可能会越过最小值,太小则下降过程会非常缓慢。

递归移动(迭代更新):盲人会一步一步地移动,每走一步都基于当前位置的坡度来决定下一步的方向。梯度下降算法通过迭代地更新参数,每次迭代都基于当前参数的梯度来更新,直到找到最小值或者达到其他停止条件。
梯度下降
在这个下山(梯度下降)概念中,又细分出几种算法,其中如上述的普通梯度下降被命名为:批量梯度下降,除了批量梯度下降外还有随机梯度下降和小批量梯度下降。

批量梯度下降:

批量梯度下降(Batch Gradient Descent)和上述的方法一样,小步多次逐步找到最终的目标值,在每次迭代中使用全部的训练数据来计算损失函数的梯度。因为要用到全部训练数据,所以内存占用高、性能差、速度慢、准确度高。
批量梯度下降的详细步骤:

  1. 初始化参数:
    在开始迭代之前,首先随机选择一组参数 θ \theta θ或者从一个零向量开始。
  2. 计算梯度:
    在每次迭代中,先计算损失函数对于每个参数 θ j \theta_j θj的梯度。这涉及到对整个训练集的计算,如下所示: ∇ θ J ( θ ) = − α 1 m ∑ i = 1 m ( y i − h θ ( x i ) ) x i \nabla_{\theta} J(\theta) = -\alpha \frac{1}{m} \sum_{i=1}^{m} (y^{i}-h_{\theta}(x^{i}) )x^{i} θJ(θ)=αm1i=1m(yihθ(xi))xi

其中:

  • ∇ θ J ( θ ) \nabla_{\theta} J(\theta) θJ(θ)表示损失函数 J ( θ ) J(\theta) J(θ)关于参数向量 θ \theta θ的梯度
  • m m m是训练样本的数量
  • x i x^{i} xi是第 i i i个训练样本的特征
  • y i y^{i} yi是对应的目标值
  • y i − h θ ( x i ) y^{i} - h_{\theta}(x^{i}) yihθ(xi)是预测误差,也就是模型对第 i i i个训练样本的预测值 h θ ∗ ( x i ) h_{\theta}* (x^{i}) hθ(xi)与实际值 y i y^{i} yi之间的差异
  • x i x^{i} xi(第 i i i的特征向量)与预测误差相乘,表示梯度是如何随着特征 x i x^{i} xi的变化而变化
  • α \alpha α表示步长
  1. 更新参数:
    计算出梯度后,更新参数 θ : = θ − α ∇ θ J ( θ ) \theta := \theta - \alpha \nabla_{\theta} J(\theta) θ:=θαθJ(θ)
    α 是学习率,决定了在参数空间中移动的步长。
    迭代直至收敛:

重复步骤2和步骤3直到损失函数的值不再显著变化,或者达到一定的迭代次数。
所以相对于 θ j \theta_j θj的下一个位置 θ j ′ \theta_j^{\prime} θj就可以表示为 θ j \theta_j θj减去 ( − 1 m ∑ i = 1 m ( y i − h θ ( x i ) ) x i ) (-\frac{1}{m} \sum_{i=1}^{m} (y^{i}-h_{\theta}(x^{i}) )x^{i}) (m1i=1m(yihθ(xi))xi),即:

θ j ′ = θ j + α 1 m ∑ i = 1 m ( y i − h θ ( x i ) ) x j i \theta_j^{\prime} = \theta_j + \alpha \frac{1}{m} \sum_{i=1}^{m} \left(y^{i} - h_{\theta}(x^{i}) \right) x_j^{i} θj=θj+αm1i=1m(yihθ(xi))xji

随机梯度下降:

随机梯度下降是批量梯度下降的一个优化版本,每次只找一个样本,迭代速度快,但不一定每次都朝着收敛方向。对于每个参数 θ j ′ \theta_j^{\prime} θj,更新规则如下:

θ j ′ = θ j + ( y i − h θ ( x i ) ) x j i \theta_j^{\prime}= \theta_j + \left({y^{i} - h_ \theta}(x^{i}) \right) x_j^{i} θj=θj+(yihθ(xi))xji

小批量梯度下降:

小批量梯度下降在每次迭代中使用一个小批量(batch)的样本来计算梯度和更新参数。相当于取少量的数据牺牲一些每次移动的准确性,从而极大提高运算速度,因此也是最常用的梯度下降方法。对于每个参数 θ j ′ \theta_j^{\prime} θj,更新规则如下:
θ j ′ = θ j − α 1 B ∑ i = k k + B − 1 ( h θ ( x i ) − y i ) x j i \theta_j^{\prime}= \theta_j - \alpha \frac{1}{B} \sum_{i=k}^{k+B-1} \left( h_{\theta}(x^{i}) - y^{i} \right) x_j^{i} θj=θjαB1i=kk+B1(hθ(xi)yi)xji

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/224036.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++】简单工厂模式

2023年12月6日,周三下午 今天又学习了一次简单工厂模式 每多学习一次,都会加深对设计模式的理解 目录 什么是简单工厂模式简单工厂模式的优缺点举例说明 什么是简单工厂模式 简单工厂模式(Simple Factory Pattern)是一种创建型…

uni-app 微信小程序之加载行政区图

文章目录 1. 实现效果2. 实现步骤 1. 实现效果 2. 实现步骤 使用三方组件 ucharts echarts 高性能跨全端图表组件页面导入引入的三方组件 组件demo代码 <template><view class"qiun-columns"><view class"cu-bar bg-white margin-top-xs"…

管理和监控CentOS上的HTTP服务

CentOS作为一款稳定的开源服务器操作系统&#xff0c;为各种网络服务提供了优秀的支持。其中&#xff0c;HTTP服务是互联网上最常用的服务之一&#xff0c;它为人们提供了便捷的信息访问和交互方式。在CentOS上管理和监控HTTP服务是一项重要的任务&#xff0c;下面我们将介绍一…

华为配置风暴控制示例

组网需求 如下图所示&#xff0c;SwitchA作为二层网络到三层路由器的衔接点&#xff0c;需要防止二层网络转发的广播、未知组播或未知单播报文产生广播风 配置思路 用如下的思路配置风暴控制。 通过在GE0/0/1接口视图下配置风暴控制功能&#xff0c;实现防止二层网络转发的…

轻快小miniconda3在linux下的安装配置-centos9stream-Miniconda3 Linux 64-bit

miniconda与anaconda的区别&#xff1a; Miniconda 和 Anaconda 是用于管理环境和安装软件包的 Python 发行版。它们之间的主要区别在于以下几点&#xff1a; 1. 安装内容和大小&#xff1a; Anaconda&#xff1a; Anaconda 是一个完整的 Python 数据科学平台&#xff0c;包含…

10-tornado项目部署

1. python3的安装和配置 1.1 安装系统依赖包 sudo dnf install wget yum-utils make gcc openssl-devel bzip2-devel libffi-devel zlib-devel -y1.2 下载Python wget https://www.python.org/ftp/python/3.9.5/Python-3.9.5.tgz1.3 解压 tar xzf Python-3.9.5.tgz 1.4 安装…

Java实现TCP一对一通信,UDP协议实现群聊

实现服务端对话框&#xff1a; 其中可自行更改对话框大小样式等配置。 package com.ex.controller;import javax.servlet.http.HttpServletRequest; import javax.swing.*; import java.awt.*; import java.awt.event.*; import java.io.*; import java.net.*; public class …

Centos服务器上根据端口号查询jar包,根据jar包查端口号

在开发springboot服务器时&#xff0c;经常会遇到其他人部署的java服务&#xff0c;需要自己维护&#xff0c;留下的信息又非常少。经常面临找不到jar包位置&#xff0c;或者不知道占用端口&#xff0c;不知道启动命令的问题。这里记录一下常用的centos服务器上的命令&#xff…

【桑基图】绘制桑基图

绘制桑基图 一、绘制桑基图&#xff08;1&#xff09;方法一&#xff1a;去在线网站直接绘制&#xff08;2&#xff09;方法二&#xff1a;写html之后在vscode上运行 二、遇到的问题&#xff08;1&#xff09;当导入一些excel的时候&#xff0c;无法绘制出桑基图 一、绘制桑基图…

ROS-ROS通信机制-参数服务器

文章目录 一、基础理论知识二、C实现三、Python实现 一、基础理论知识 参数服务器在ROS中主要用于实现不同节点之间的数据共享。参数服务器相当于是独立于所有节点的一个公共容器&#xff0c;可以将数据存储在该容器中&#xff0c;被不同的节点调用&#xff0c;当然不同的节点…

【LeetCode刷题】-- 79.单词搜索

79.单词搜索 方法&#xff1a;使用回溯 使用dfs函数表示判断以网格的(i.j)位置出发&#xff0c;能否搜索到word(k)&#xff0c;其中word(k)表示字符串word从第k个字符开始的后缀子串&#xff0c;如果能搜索到&#xff0c;返回true,反之返回false 如果board[i][j]≠word[k]&am…

李宏毅gpt个人记录

参考&#xff1a; 李宏毅机器学习--self-supervised&#xff1a;BERT、GPT、Auto-encoder-CSDN博客 用无标注资料的任务训练完模型以后&#xff0c;它本身没有什么用&#xff0c;GPT 1只能够把一句话补完&#xff0c;可以把 Self-Supervised Learning 的 Model做微微的调整&am…

分享一个Python网络爬虫数据采集利器

前言 你是否曾为获取重要数据而感到困扰&#xff1f;是否因为数据封锁而无法获取所需信息&#xff1f;是否因为数据格式混乱而头疼&#xff1f;现在&#xff0c;所有这些问题都可以迎刃而解。让我为大家介绍一款强大的数据收集平台——亮数据Bright Data。 作为世界领先的数据…

【Qt】QLineEdit显示输入十六进制,位数不足时按照规则填充显示及每两个字符以空格填充

问题 在实际开发中&#xff0c;有时候需要对输入进行限制&#xff0c;一是更加合理&#xff0c;二是防止出现误操作。 比如&#xff1a; 使用Qt进行应用程序开发时&#xff0c;对单行编辑框QLineEdit控件&#xff0c;设置只可输入十六进制。 限制输入的方式常用且经典的是使用…

限流算法,基于go的gRPC 实现的

目录 一、单机限流 1、令牌桶算法 3、固定窗口限流算法 4、滑动窗口 二、集群限流 1、分布式固定窗口 &#xff08;基于redis&#xff09; 2、分布式滑动窗口 一、单机限流 1、令牌桶算法 令牌桶算法是当流量进入系统前需要获取令牌&#xff0c;没有令牌那么就要进行限…

Docker+jenkins+gitlab实现持续集成

1.安装环境 服务器ip虚拟机版本192.168.5.132centos7.6192.168.5.152centos7.6 2. 安装docker 安装必要的一些系统工具 yum install -y yum-utils device-mapper-persistent-data lvm2添加软件源信息&#xff0c;要确保centos7能上外网 yum-config-manager --add-repo http:…

什么是数据清洗、特征工程、数据可视化、数据挖掘与建模?

1.1什么是数据清洗、特征工程、数据可视化、数据挖掘与建模&#xff1f; 视频为《Python数据科学应用从入门到精通》张甜 杨维忠 清华大学出版社一书的随书赠送视频讲解1.1节内容。本书已正式出版上市&#xff0c;当当、京东、淘宝等平台热销中&#xff0c;搜索书名即可。内容涵…

智篆商业的电商运营课程:超越传统,引领未来

智篆商业一直以来都是电商运营课程领域的领军者&#xff0c;其成功之处在于超越传统&#xff0c;不断引领未来。探究智篆商业的成功之道&#xff0c;我们将发现其可信度和培训专业性是学员信赖的根本原因。 首先&#xff0c;智篆商业注重构建强大的导师团队。导师们不仅拥有丰富…

持续集成交付CICD:Jenkins使用GitLab共享库实现自动更新前后端项目质量配置

目录 一、实验 1.Jenkins使用GitLab共享库实现自动更新后端项目质量配置 2.Jenkins使用GitLab共享库实现自动更新前端项目质量配置 二、问题 1.Sonarqube如何添加自定义质量阈 一、实验 1.Jenkins使用GitLab共享库实现自动更新后端项目质量配置 (1)修改GitLab的Sonar.gr…

HBase-架构与设计

HBase架构与设计 一、背景二、HBase概述1.设计特点2.适用场景2.1 海量数据2.2 稀疏数据2.3 多版本数据2.4 半结构或者非结构化数据 三、数据模型1.RowKey2.Column Family3.TimeStamp 四、HBase架构图1.Client2.Zookeeper3.HMaster4.HRegionServer5.HRegion6.Store7.StoreFile8.…