1.4.1机器学习——梯度下降+α学习率大小判定

1.4.1梯度下降

4.1、梯度下降的概念

※【总结一句话】:系统通过自动的调节参数w和b的值,得到最小的损失函数值J。

如下:是梯度下降的概念图。

  • 我们有一个损失函数 J(w,b),包含两个参数w和b(你可以想象成J(w,b) = w*x + b) ,我们想要找到最合适的w和b,尝试最小化损失函数 J(w,b)的值

  • ”梯度下降“:梯度下降(gradient descent)在机器学习中应用十分的广泛,不论是在线性回归还是Logistic回归中,它的主要目的是通过迭代找到目标函数的最小值,或者收敛到最小值

  • 它还适用于有多个参数的(如图:w1~ wn,b) :梯度下降的任务是调节w1 ~ wn 和 参数 b 的值,最小化损失函数的值

  • 梯度下降法是用来计算损失函数最小值的。它的思路很简单,想象在山顶放了一个球,一松手它就会顺着山坡最陡峭的地方滚落到谷底:

=

4.2、梯度下降概述

在这里插入图片描述

  • J(w,b) = wx + b : 我们开始的时候w和b可以设置为任意值。(这里我们设置 w=0 , b = 0)
  • “梯度下降”要做的就是:通过迭代不断的调整 w和b 的值,去尽量的降低损失函数J(w,b)
  • 直到我们的损失函数J(w,b) 达到/接近 “谷底”/最小值
  • 【※注意】 :损失函数可能不仅仅是一个如右图的抛物线,也可能是“高尔夫球场图”,这样的话minimum最小值的个数就不仅仅是一个了

在这里插入图片描述
“高尔夫球场图”

  • 图中XYZ轴分别代表:W,b , 损失函数 J(w,b)的值

  • 他不是一个 “平方误差损失函数”(线性回归使用 平方误差损失函数)图,因为 “平方误差损失函数”常常以抛物线 / 吊床 的形状结束

  • 这个小人站在哪的起始点取决于你选择的参数w 和 b的起始值

  • 两个谷底最低点都是局部最优解

1.4.2 理解梯度下降+梯度下降偏导的意义+α学习率

  • 图解定义:

在这里插入图片描述

  • 对梯度下降公式算法的解读:
    • 其中参数 w和b 是一个同步进行 修改/微调 的。(同步进行,就是计算w和b的更新的两个公式是同步计算 的)
    • 重复以上两个公式,直到达到结果收敛为止
    • 在收敛重复这个公式的过程中,在计算tmp_w 和tmp_b中 的w是上一次迭代的w的值,只有tmp_w 和tmp_b都计算完成才会进行更新
  • 你可能存在疑问:为什么不要用中间值 tmp_w 和wmp_b,直接就是 w = w - α…;b= b - α…不行吗???
    • 答:你提问的很好。但是在这个梯度下降w,b迭代执行过程中,w和b同时参与了w,b的计算。
    • 简单的说如下图,以W为例,加入在第i次迭代的情况下:
      • 在第i次迭代的情况下执行完第一步w = w -0.2 后(假设这时候 J(w,b) = 0.2),
      • 执行第二步 b = b - J(w,b) 的时候,你想放 执行w = w -0.2 更新之前的值 or 更新之后的值???
      • 每一次迭代,在这第i次中,w 和 b对应的是第i -1次的值,在第i次迭代结束时才一起更新
      • 如果不借用中间值 tmp_w 和 tmp_b,那么w 和 b 的最终值 就会受到影响
      • 在这里插入图片描述

为了方便理解α(学习率),我们这里把b设置为0,只考虑有w的一维情况

  • 其中α:称为学习率(粉色框内容),通常这个值在 0~1 之间,是控制w和b的调参步长

    • 合适:在这里插入图片描述

    过程解释:

    红框中:是求w的偏导,这里求完偏导知道是>0

    α永远是 大于0 小于1 的数

    然后w = w - α(正数) ====> w变小

    所以就实现了微调

    在这里插入图片描述

    当W的偏导< 0 时:
    红框中:是求w的偏导,这里求完偏导知道是<0

    α永远是 大于0 小于1 的数

    然后w = w - α(负数) ====> w变大

    所以就实现了微调.

    在这里插入图片描述

    • 如果α非常大:那么这个是一个非常激进的梯度下降的过程,步长会非常大,反而会越过谷底,不断上升:+

    • 动图地址
      图片: ![Alt](https://img-blog.csdnimg.cn/img_convert/b8a2549083eae7aa3a33cc80c6cee5dd.webp?x-oss-process=image/format,png)

    吴恩达老师的解释:

    ​ 如果α过大/过大的步长:

    ​ 会导致超过 , 并且从来不会达到最小值。

    ​ 不能直线收敛,甚至是发散

在这里插入图片描述

  • 如果α非常小:比如 α = 0.0001 就过于小了,迭代 20 次后离谷底还很远,实际上 100 次后都无法到达谷底:

    梯度下降会起作用,但是需要很长的时间

  • 动图地址
    在这里插入图片描述

吴恩达的解释:

​ 如果α 步长太小,会实现收敛,但是这个收敛的过程会很慢很慢

  • **总结:**不同的步长α ,随着迭代次数的增加,会导致被优化函数f(x) 的值有不同的变化:

img

关于α选择以及判定的 详细内容看: 2.2.3机器学习—— 判定梯度下降是否收敛 + α学习率的选择

1.4.3 用于线性回归的梯度下降

在这里插入图片描述

  • 公式推到来源:

    在这里插入图片描述

  • 那么 线性回归梯度下降就是:

重复对w 和 b 执行更新直到收敛

在这里插入图片描述

  • 当使用线性回归的平方误差损失函数时,全局只有一个最低点

在这里插入图片描述

  • 当使用非平方误差函数时,就是非线性回归梯度下降的时候就会出现 >=1 的局部最优解

在这里插入图片描述

1.4.4 线性回归的梯度下降的应用

  • 等高线最中心那个圈是 损失函数值最小的点,Cost值越小,说明线性回归的拟合越好,直到我们达到全局最小值
  • 比如当 Size in feet 是1250 ,对应的回归预测值是 250 K

在这里插入图片描述

  • “批量梯度下降”:每一次的梯度下降使用的是全部的训练的数据:
  • 当计算 w 和 b 的偏导时,我们从 1 ~ m 所有的数据都计算上,然后相加求平均

在这里插入图片描述

※1.4.5 线性回归的梯度下降函数的代码实现

1、求损失函数

求损失函数代码如下:

def compute_cost(x, y, w, b): 
    """
    Computes the cost function for linear regression.
    
    Args:
      x (ndarray (m,)): Data, m examples 
      y (ndarray (m,)): target values
      w,b (scalar)    : model parameters  
    
    Returns
        total_cost (float): The cost of using w,b as the parameters for linear regression
               to fit the data points in x and y
    """
    # number of training examples
    m = x.shape[0] 
    
    cost_sum = 0 
    for i in range(m): 
        f_wb = w * x[i] + b   
        cost = (f_wb - y[i]) ** 2  
        cost_sum = cost_sum + cost  
    total_cost = (1 / (2 * m)) * cost_sum  

    return total_cost

2、求偏导 / 梯度

※求偏导代码如下:

def compute_gradient(x, y, w, b): 
    """
    Computes the gradient for linear regression 
    Args:
      x (ndarray (m,)): Data, m examples 
      y (ndarray (m,)): target values
      w,b (scalar)    : model parameters  
    Returns
      dj_dw (scalar): The gradient of the cost w.r.t. the parameters w
      dj_db (scalar): The gradient of the cost w.r.t. the parameter b     
     """
    
    # Number of training examples
    m = x.shape[0]    
    dj_dw = 0
    dj_db = 0
    
    for i in range(m):  
        f_wb = w * x[i] + b 
        dj_dw_i = (f_wb - y[i]) * x[i] 
        dj_db_i = f_wb - y[i] 
        dj_db += dj_db_i
        dj_dw += dj_dw_i 
    dj_dw = dj_dw / m 
    dj_db = dj_db / m 
        
    return dj_dw, dj_db

3、梯度下降函数

def gradient_descent(x, y, w_in, b_in, alpha, num_iters, cost_function, gradient_function): 
    """
    Performs gradient descent to fit w,b. Updates w,b by taking 
    num_iters gradient steps with learning rate alpha
    
    Args:
      x (ndarray (m,))  : Data, m examples 
      y (ndarray (m,))  : target values
      w_in,b_in (scalar): initial values of model parameters  
      alpha (float):     Learning rate
      num_iters (int):   number of iterations to run gradient descent
      cost_function:     function to call to produce cost
      gradient_function: function to call to produce gradient
      
    Returns:
      w (scalar): Updated value of parameter after running gradient descent
      b (scalar): Updated value of parameter after running gradient descent
      J_history (List): History of cost values
      p_history (list): History of parameters [w,b] 
      """
    
    w = copy.deepcopy(w_in) # avoid modifying global w_in
    # An array to store cost J and w's at each iteration primarily for graphing later
    J_history = []
    p_history = []
    b = b_in
    w = w_in
    
    for i in range(num_iters):
        # Calculate the gradient and update the parameters using gradient_function
        dj_dw, dj_db = gradient_function(x, y, w , b)     

        # Update Parameters using equation (3) above
        b = b - alpha * dj_db                            
        w = w - alpha * dj_dw                            

        # Save cost J at each iteration
        if i<100000:      # prevent resource exhaustion 
            J_history.append( cost_function(x, y, w , b))
            p_history.append([w,b])
        # Print cost every at intervals 10 times or as many iterations if < 10
        if i% math.ceil(num_iters/10) == 0:
            print(f"Iteration {i:4}: Cost {J_history[-1]:0.2e} ",
                  f"dj_dw: {dj_dw: 0.3e}, dj_db: {dj_db: 0.3e}  ",
                  f"w: {w: 0.3e}, b:{b: 0.5e}")
 
    return w, b, J_history, p_history #return w and J,w history for graphing
※※关于α选择以及判定的 详细内容看: 2.2.3机器学习—— 判定梯度下降是否收敛 + α学习率的选择

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/306985.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

竞赛保研 基于深度学习的动物识别 - 卷积神经网络 机器视觉 图像识别

文章目录 0 前言1 背景2 算法原理2.1 动物识别方法概况2.2 常用的网络模型2.2.1 B-CNN2.2.2 SSD 3 SSD动物目标检测流程4 实现效果5 部分相关代码5.1 数据预处理5.2 构建卷积神经网络5.3 tensorflow计算图可视化5.4 网络模型训练5.5 对猫狗图像进行2分类 6 最后 0 前言 &#…

GCF:在线市场异质治疗效果估计的广义因果森林

英文题目&#xff1a;GCF: Generalized Causal Forest for Heterogeneous Treatment Effects Estimation in Online Marketplace 中文题目&#xff1a;GCF&#xff1a;在线市场异质治疗效果估计的广义因果森林 单位&#xff1a;滴滴&美团 时间&#xff1a;2022 论文链接…

postgresql迁移到mysql

1.工具方法&#xff1a;Navicat Premium16 2. 手工方法&#xff1a; 迁移流程 下面是将 Postgresql 数据库迁移到 MySQL 的步骤流程&#xff1a; 步骤描述1. 创建MySQL表结构在MySQL中创建与Postgresql中的表结构相同的表2. 导出Postgresql数据将Postgresql中的数据导出为SQ…

【学术会议】第三届神经计算青年研讨会 学习笔记

第三届神经计算青年研讨会 学习笔记 会议时间&#xff1a;2024-1-6至2024-1-7 会议地点&#xff1a;电子科技大学 会议介绍&#xff1a; 为提升我国神经计算⻘年研究队伍的学术⽔平和国际影响⼒&#xff0c;研讨会主题涵盖&#xff1a;神经系统建模与模拟、脑机接⼝与类脑智能、…

Rabbitmq 消息可靠性保证

1、简介 消息的可靠性投递就是要保证消息投递过程中每一个环节都要成功&#xff0c;本文详细介绍两个环节的消息可靠性传递方式&#xff1a;1&#xff09;、消息传递到交换机的 confirm 模式&#xff1b;2&#xff09;、消息传递到队列的 Return 模式。 消息从 producer 到 ex…

SD-WAN跨境专线:优化企业网络效率与安全性

企业网络通信的革新已经成为跨境业务发展中的重要一环。新加坡作为国际商业中心&#xff0c;吸引着众多企业在此开设分支机构。然而&#xff0c;传统的跨境网络连接方式常常存在着诸多问题&#xff0c;例如网络延迟高、丢包率大等。在这些挑战面前&#xff0c;SD-WAN&#xff0…

【leetcode】力扣算法之两数相加【中等难度】

题目描述 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外&#xff0c;这两个数都…

助力企业出海,Ogcloud提供一站式网络解决方案

随着全球市场的开放和跨境电商的蓬勃发展&#xff0c;越来越多企业开始在海外拓展业务。但在这过程中&#xff0c;各种各样的网络问题成为企业出海的阻碍。Ogcloud凭借其卓越的技术实力和丰富的经验&#xff0c;为全球业务的公司提供全面的网络解决方案&#xff0c;包括SD-WAN、…

C语言如何提高程序的可读性?

一、问题 可读性是评价程序质量的一个重要标准&#xff0c;直接影响到程序的修改和后期维护&#xff0c;那么如何提高程序的可读性呢? 二、解答 提高程序可读性可以从以下几方面来进行。 &#xff08;1&#xff09;C程序整体由函数构成的。 程序中&#xff0c;main()就是其中…

C#,数值计算,高斯消元法与列主元消元法的源代码及数据动态可视化

高斯消元法&#xff01; 一、高斯消元法 Gaussian Elimination 高斯消元法&#xff08;或译&#xff1a;高斯消去法&#xff09;&#xff0c;是线性代数中的一个常用算法&#xff0c;常用于求解线性方程组和矩阵的逆。 本程序的运行效果&#xff1a; 1、高斯消元法的动画演示…

Linux学习19 在Ubuntu命令行下使用新硬盘

Linux学习19 在Ubuntu命令行下使用新硬盘 一、准备环境二、检测硬盘三、对新硬盘格式化1. 创建分区2. 格式化 三、挂载操作1. 创建挂载点2. 挂载硬盘3. 验证挂载 四、实现永久挂载&#xff08;可选&#xff09;1. 文件结构与内容&#xff1a;2. /etc/fstab 的重要性与作用3. 修…

运算符详解

1、定义 定义&#xff1a;运算符是一种特殊的符号&#xff0c;用于表示数据的运算、赋值和比较等 运算符的分类 1&#xff09;按功能分类&#xff1a; 1&#xff09;算术运算符&#xff08;7个&#xff09;​​​​​​​ 、-、*、/、%、、--​​​​​​​ 2&#xff0…

自动化测试框架pytest系列之基础概念介绍(一)

如果你要打算学习自动化测试 &#xff0c;无论是web自动化、app自动化还是接口自动化 &#xff0c;在学习的道路上&#xff0c;你几乎会遇到pytest这个测试框架&#xff0c;因为自动化编写没有测试框架&#xff0c;根本玩不了 。 如果你已经是一位自动化测试人员 &#xff0c;…

工程监测领域振弦采集仪的数据处理与分析方法探讨

工程监测领域振弦采集仪的数据处理与分析方法探讨 在工程监测领域&#xff0c;振弦采集仪是常用的一种设备&#xff0c;用于测量和记录结构物的振动数据。数据处理和分析是使用振弦采集仪得到的数据的重要环节&#xff0c;可以帮助工程师了解结构物的振动特性&#xff0c;评估…

FRPS配置服务端(腾讯云)、客户端(PC电脑Windows、树莓派Debian)并设置虚拟域名

1.服务端&#xff08;腾讯云&#xff09;&#xff1a;frps.ini [common] bind_port 7000 vhost_http_port8080 vhost_https_port44344 dashboard_port 7500 privilege_token your_password subdomain_host example.com use_encryption true encryption_method tls dashb…

面试宝典进阶之关系型数据库面试题

D1、【初级】你都使用过哪些数据库&#xff1f; &#xff08;1&#xff09;MySQL&#xff1a;开源数据库&#xff0c;被Oracle公司收购 &#xff08;2&#xff09;Oracle&#xff1a;Oracle公司 &#xff08;3&#xff09;SQL Server&#xff1a;微软公司 &#xff08;4&#…

RabbitMQ发布确认

1.单个确认 单个确认发布是一种同步确认发布方式&#xff0c;也就是发布一个消息后只有它被确认发布&#xff0c;后续的消息才能继续发布。 缺点:发布速度特别慢,因为若是没有确认发布的消息会阻塞所有后续消息的发布 package com.hong.rabbitmq5;import com.hong.utils.Rabb…

微信Windows版如何从旧电脑迁移聊天记录到新电脑

我们都知道&#xff0c;换手机的话&#xff0c;如果是同品牌&#xff0c;可以用该品牌的换机助手将微信资料传输给新手机&#xff0c;或者用微信PC端的迁移与备份功能来实现 那么换电脑或者重装系统呢&#xff1f;我们可以通过转移文件夹的方式进行 1、登录PC微信&#xff0c;…

选择最适合您的10个在线PS类型工具

Adobe Photoshop 多年来&#xff0c;Photoshop一直是设计师的首选。PS的功能无疑是非常强大的。设计师可以使用它来制作从简单的网页到复杂的移动应用程序设计。学习PS的基本知识很容易&#xff0c;但学习PS的所有技能都需要大量的时间和精力。当然&#xff0c;您也可以选择体…

哈希表-散列表数据结构

1、什么是哈希表&#xff1f; 哈希表也叫散列表&#xff0c;哈希表是根据关键码值(key value)来直接访问的一种数据结构&#xff0c;也就是将关键码值(key value)通过一种映射关系映射到表中的一个位置来加快查找的速度&#xff0c;这种映射关系称之为哈希函数或者散列函数&…