大数据机器学习-梯度下降:从技术到实战的全面指南

大数据机器学习-梯度下降:从技术到实战的全面指南

文章目录

  • 大数据机器学习-梯度下降:从技术到实战的全面指南
  • 一、简介
    • 什么是梯度下降?
    • 为什么梯度下降重要?
  • 二、梯度下降的数学原理
    • 代价函数(Cost Function)
    • 梯度(Gradient)
    • 更新规则
      • 代码示例:基础的梯度下降更新规则
  • 三、批量梯度下降(Batch Gradient Descent)
    • 基础算法
    • 代码示例
  • 四、随机梯度下降(Stochastic Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点
  • 五、小批量梯度下降(Mini-batch Gradient Descent)
    • 基础算法
    • 代码示例
    • 优缺点

本文全面深入地探讨了梯度下降及其变体——批量梯度下降、随机梯度下降和小批量梯度下降的原理和应用。通过数学表达式和基于PyTorch的代码示例,本文旨在为读者提供一种直观且实用的视角,以理解这些优化算法的工作原理和应用场景。

在这里插入图片描述

一、简介

梯度下降(Gradient Descent)是一种在机器学习和深度学习中广泛应用的优化算法。该算法的核心思想非常直观:找到一个函数的局部最小值(或最大值)通过不断地沿着该函数的梯度(gradient)方向更新参数。

什么是梯度下降?

简单地说,梯度下降是一个用于找到函数最小值的迭代算法。在机器学习中,这个“函数”通常是损失函数(Loss Function),该函数衡量模型预测与实际标签之间的误差。通过最小化这个损失函数,模型可以“学习”到从输入数据到输出标签之间的映射关系。

为什么梯度下降重要?

  1. 广泛应用:从简单的线性回归到复杂的深度神经网络,梯度下降都发挥着至关重要的作用。
  2. 解决不可解析问题:对于很多复杂的问题,我们往往无法找到解析解(analytical solution),而梯度下降提供了一种有效的数值方法。
  3. 扩展性:梯度下降算法可以很好地适应大规模数据集和高维参数空间。
  4. 灵活性与多样性:梯度下降有多种变体,如批量梯度下降(Batch Gradient Descent)、随机梯度下降(Stochastic Gradient Descent)和小批量梯度下降(Mini-batch Gradient Descent),各自有其优点和适用场景。

二、梯度下降的数学原理

在这里插入图片描述

在深入研究梯度下降的各种实现之前,了解其数学背景是非常有用的。这有助于更全面地理解算法的工作原理和如何选择合适的算法变体。

代价函数(Cost Function)

在机器学习中,代价函数(也称为损失函数,Loss Function)是一个用于衡量模型预测与实际标签(或目标)之间差异的函数。通常用 ( J(\theta) ) 来表示,其中 ( \theta ) 是模型的参数。

在这里插入图片描述

梯度(Gradient)

在这里插入图片描述

更新规则

在这里插入图片描述

代码示例:基础的梯度下降更新规则

import numpy as np

def gradient_descent_update(theta, grad, alpha):
    """
    Perform a single gradient descent update.
    
    Parameters:
    theta (ndarray): Current parameter values.
    grad (ndarray): Gradient of the cost function at current parameters.
    alpha (float): Learning rate.
    
    Returns:
    ndarray: Updated parameter values.
    """
    return theta - alpha * grad

# Initialize parameters
theta = np.array([1.0, 2.0])
# Hypothetical gradient (for demonstration)
grad = np.array([0.5, 1.0])
# Learning rate
alpha = 0.01

# Perform a single update
theta_new = gradient_descent_update(theta, grad, alpha)
print("Updated theta:", theta_new)

输出:

Updated theta: [0.995 1.99 ]

在接下来的部分,我们将探讨梯度下降的几种不同变体,包括批量梯度下降、随机梯度下降和小批量梯度下降,以及一些高级的优化技巧。通过这些内容,你将能更全面地理解梯度下降的应用和局限性。


三、批量梯度下降(Batch Gradient Descent)

在这里插入图片描述

批量梯度下降(Batch Gradient Descent)是梯度下降算法的一种基础形式。在这种方法中,我们使用整个数据集来计算梯度,并更新模型参数。

基础算法

批量梯度下降的基础算法可以概括为以下几个步骤:

在这里插入图片描述

代码示例

下面的Python代码使用PyTorch库演示了批量梯度下降的基础实现。

import torch

# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])

# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)

# Learning rate
alpha = 0.01

# Number of iterations
n_iter = 1000

# Cost function: Mean Squared Error
def cost_function(X, y, theta):
    m = len(y)
    predictions = X @ theta
    return (1 / (2 * m)) * torch.sum((predictions - y) ** 2)

# Gradient Descent
for i in range(n_iter):
    J = cost_function(X, y, theta)
    J.backward()
    with torch.no_grad():
        theta -= alpha * theta.grad
    theta.grad.zero_()

print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5780],
        [0.7721]], requires_grad=True)

批量梯度下降的主要优点是它的稳定性和准确性,但缺点是当数据集非常大时,计算整体梯度可能非常耗时。接下来的章节中,我们将探索一些用于解决这一问题的变体和优化方法。


四、随机梯度下降(Stochastic Gradient Descent)

在这里插入图片描述

随机梯度下降(Stochastic Gradient Descent,简称SGD)是梯度下降的一种变体,主要用于解决批量梯度下降在大数据集上的计算瓶颈问题。与批量梯度下降使用整个数据集计算梯度不同,SGD每次只使用一个随机选择的样本来进行梯度计算和参数更新。

基础算法

随机梯度下降的基本步骤如下:

在这里插入图片描述

代码示例

下面的Python代码使用PyTorch库演示了SGD的基础实现。

import torch
import random

# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0]])

# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)

# Learning rate
alpha = 0.01

# Number of iterations
n_iter = 1000

# Stochastic Gradient Descent
for i in range(n_iter):
    # Randomly sample a data point
    idx = random.randint(0, len(y) - 1)
    x_i = X[idx]
    y_i = y[idx]

    # Compute cost for the sampled point
    J = (1 / 2) * torch.sum((x_i @ theta - y_i) ** 2)
    
    # Compute gradient
    J.backward()

    # Update parameters
    with torch.no_grad():
        theta -= alpha * theta.grad

    # Reset gradients
    theta.grad.zero_()

print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.5931],
        [0.7819]], requires_grad=True)

优缺点

SGD虽然解决了批量梯度下降在大数据集上的计算问题,但因为每次只使用一个样本来更新模型,所以其路径通常比较“嘈杂”或“不稳定”。这既是优点也是缺点:不稳定性可能帮助算法跳出局部最优解,但也可能使得收敛速度减慢。

在接下来的部分,我们将介绍一种折衷方案——小批量梯度下降,它试图结合批量梯度下降和随机梯度下降的优点。


五、小批量梯度下降(Mini-batch Gradient Descent)

在这里插入图片描述

小批量梯度下降(Mini-batch Gradient Descent)是批量梯度下降和随机梯度下降(SGD)之间的一种折衷方法。在这种方法中,我们不是使用整个数据集,也不是使用单个样本,而是使用一个小批量(mini-batch)的样本来进行梯度的计算和参数更新。

基础算法

小批量梯度下降的基本算法步骤如下:

在这里插入图片描述

代码示例

下面的Python代码使用PyTorch库演示了小批量梯度下降的基础实现。

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data (features and labels)
X = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]], requires_grad=True)
y = torch.tensor([[1.0], [2.0], [3.0], [4.0]])

# Initialize parameters
theta = torch.tensor([[0.0], [0.0]], requires_grad=True)

# Learning rate and batch size
alpha = 0.01
batch_size = 2

# Prepare DataLoader
dataset = TensorDataset(X, y)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Mini-batch Gradient Descent
for epoch in range(100):
    for X_batch, y_batch in data_loader:
        J = (1 / (2 * batch_size)) * torch.sum((X_batch @ theta - y_batch) ** 2)
        J.backward()
        with torch.no_grad():
            theta -= alpha * theta.grad
        theta.grad.zero_()

print("Optimized theta:", theta)

输出:

Optimized theta: tensor([[0.6101],
        [0.7929]], requires_grad=True)

优缺点

小批量梯度下降结合了批量梯度下降和SGD的优点:它比SGD更稳定,同时比批量梯度下降更快。这种方法广泛应用于深度学习和其他机器学习算法中。

小批量梯度下降不是没有缺点的。选择合适的批量大小可能是一个挑战,而且有时需要通过实验来确定。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/256894.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

VueStu01-Vue是什么

1.概念 Vue 是一个 用于构建用户界面 的 渐进式 框架 。 2.构建用户界面 基于数据渲染出用户看到的页面。 3.渐进式 Vue的学习是循序渐进的,可以学一点用一点,不必全部学完才能用。哪怕你只学了 声明式渲染 这一个小部分的内容,你也可以完成…

商业加速终极指南:创业家的10个关键步骤

多年来,前沿的初创企业一直在塑造整个行业,引入新技术、更精简的业务模式以及深受客户喜爱的开箱即用的产品。这不仅仅是一种短暂的趋势,而是我们的商业格局发生了根本性的转变,在这种情况下,率先抓住新的增长机遇可以…

ubuntu无 root 权限安装 screen

网上的方法主要是如下图的方法,源码安装,但是我一直 make install失败显示没有权限 然后选择放弃,然后随便试了一下方法 2,成功 方法 1 方法 2 pip3 install screen结果:

【软件工程】软件工程复习题库2023

🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 SpringCloud MybatisPlus JVM 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 软件工程复习题库 一、选择题二、填空题三、判断题四…

AtCoder Beginner Contest 333 A~D(E,F更新中...)

A.Three Threes&#xff08;循环&#xff09; 题意&#xff1a; 给出一个正整数 N N N&#xff0c;要求输出 N N N个 N N N 分析&#xff1a; 按要求输出即可 代码&#xff1a; #include <bits/stdc.h> using namespace std;void solve() {int n;cin >> n;fo…

【C语言】自定义类型:结构体深入解析(一)

&#x1f308;write in front :&#x1f50d;个人主页 &#xff1a; 啊森要自信的主页 ✏️真正相信奇迹的家伙&#xff0c;本身和奇迹一样了不起啊&#xff01; 欢迎大家关注&#x1f50d;点赞&#x1f44d;收藏⭐️留言&#x1f4dd;>希望看完我的文章对你有小小的帮助&am…

springboot 学生信息管理

介绍 一个学生信息管理后台&#xff0c;适用于大作业&#xff0c;课设等 软件架构 springbootmybatisthymeleaf &#xff08;前后端未分离&#xff09; 安装教程 注&#xff1a;mysql数据库要8.0以上&#xff0c;&#xff0c;本地mysql新建一个名为 student 的空数据库&am…

open3d bug:pcd转txt前后位姿发生改变

1、open3d bug&#xff1a;pcd转txt前后位姿发生改变 open3d会对原有结果进行一个微小位姿变换 import open3d as o3d import numpy as np# 读取PCD点云文件 pcd o3d.io.read_point_cloud(/newdisk/darren_pty/zoom_centered_s2.pcd)# 获取点云坐标 points pcd.points# 指定…

4.1 媒资管理模块 - Nacos与Gateway搭建

文章目录 媒资管理模块 - 媒资项目搭建一、需求分析1.1 介绍1.2 数据模型1.3 分析网关 二、 搭建Nacos2.1 服务发现中心2.2.1 Maven2.2.2 配置Nacos 2.2 配置中心2.2.1 介绍2.2.2 Maven 坐标2.2.3 配置 content-api 工程2.2.4 配置 content-service 工程2.2.5 配置 system-api …

如何安装运行Wagtail并结合cpolar内网穿透实现公网访问网站界面

文章目录 前言1. 安装并运行Wagtail1.1 创建并激活虚拟环境 2. 安装cpolar内网穿透工具3. 实现Wagtail公网访问4. 固定的Wagtail公网地址 前言 Wagtail是一个用Python编写的开源CMS&#xff0c;建立在Django Web框架上。Wagtail 是一个基于 Django 的开源内容管理系统&#xf…

grafana基本使用

一、安装grafana 1.下载 官网下载地址&#xff1a; https://grafana.com/grafana/download官网包的下载地址&#xff1a; yum install -y https://dl.grafana.com/enterprise/release/grafana-enterprise-10.2.2-1.x86_64.rpm官网下载速度非常慢&#xff0c;这里选择清华大…

Zotero 7 安装并彻底解决“无法安装插件。它可能无法与该版本的 Zotero 兼容“。以及解决“此翻译引擎不可用,可能是密钥错误“的问题

Zotero 7 安装并彻底解决"无法安装插件。它可能无法与该版本的 Zotero 兼容"。以及解决"此翻译引擎不可用&#xff0c;可能是密钥错误"的问题 &#xff01;&#xff01;&#xff01;不要直接在Zotero 6上安装翻译插件&#xff0c;将会版本不兼容&#xff0…

【python】飞机大战

import pygame import random# 敌机出现事件 ENEMY_EVENT pygame.USEREVENT # 发射子弹事件 FIRE pygame.USEREVENT 1class GameSprite(pygame.sprite.Sprite):def __init__(self, image_name, speed1):super().__init__()# 定义对象的属性self.image pygame.image.load(im…

使用IDEA创建springboot依赖下载很慢,解决方法

显示一直在resolving dependencies&#xff0c;速度很慢 原因&#xff1a;maven会使用远程仓库来加载依赖&#xff0c;是一个国外的网站&#xff0c;所以会很慢。应该使用阿里云的镜像&#xff0c;这样速度会提升很多。 步骤&#xff1a;1.右击pom.xml&#xff0c;选择"m…

MFC 窗口创建过程与消息处理

目录 钩子简介 代码编写 窗口创建过程分析 消息处理 钩子简介 介绍几个钩子函数&#xff0c;因为它们与窗口创建工程有关 安装钩子函数 HHOOK SetWindowsHookExA([in] int idHook,[in] HOOKPROC lpfn,[in] HINSTANCE hmod,[in] DWORD dwThreadId ); 参数说明…

前端常见面试题之html和css篇

文章目录 一、html1. 如何理解html语义化2. 说说块级元素和内联元素的区别 二、css1. 盒模型的宽度offsetWidth如何计算2. box-sizing:border-box有什么用3. margin的纵向重叠问题4. 谈谈你对BFC的理解和应用5. 清除浮动有哪些方式6. 使用flex布局实现骰子37.position的absolut…

喜报!巨蟹数科荣获国家“高新技术企业”认定!

根据《高新技术企业认定管理办法》&#xff08;国科发火〔2016〕32 号&#xff09;和《高新技术企业认定管理工作指引》&#xff08;国科发火〔 2016〕195号&#xff09;有关规定&#xff0c;经省高新技术企业认定管理机构组织企业申请、专家评审等程序&#xff0c;并经全国高新…

Linux汇编语言编程-机器语言

机器语言是处理器看到的语言。 在获取-执行周期【fetch-execute cycle 】中获取的字节是机器码的字节。汇编语言可以定义为一种使程序员能够控制机器码的语言。汇编语言指定机器码。如果不熟悉机器语言&#xff0c;汇编语言的这一特性是不明显的。本章介绍 x86 机器码的主要特征…

一句话木马是什么?代码实例及绕过方法

一句话木马是指一种短小的、通常只有一行代码的恶意软件&#xff0c;它被用来在目标系统中执行攻击者的命令或代码。这种类型的木马通常通过各种途径被注入到目标系统中&#xff0c;一旦成功运行&#xff0c;攻击者就可以远程控制受感染的系统。一句话木马的目的包括窃取敏感信…

RPC(3):HttpClient实现RPC之GET请求

1HttpClient简介 在JDK中java.net包下提供了用户HTTP访问的基本功能&#xff0c;但是它缺少灵活性或许多应用所需要的功能。 HttpClient起初是Apache Jakarta Common 的子项目。用来提供高效的、最新的、功能丰富的支持 HTTP 协议的客户端编程工具包&#xff0c;并且它支持 H…