pytorch(二)梯度下降算法

文章目录

    • 优化问题
    • 梯度下降
    • 随机梯度下降

在线性模型训练的时候,一开始并不知道w的最优值是什么,可以使用一个随机值来作为w的初始值,使用一定的算法来对w进行更新

优化问题

寻找使得目标函数最优的权重组合的问题就是优化问题

梯度下降

通俗的讲,梯度下降就是使得梯度往下降的方向,也就是负方向走。一般来说,梯度往正方向走,表示梯度大于0,,表示函数是往递增方向走,而这里需要的是找最低点,最低点一定是在往下走,所以这里的梯度要取负号。梯度下降更新权重的公式如下(注意是减),α表示学习率:
在这里插入图片描述
梯度下降算法属于贪心算法的一种,它的权重的更新,每一次都是朝着梯度下降最快的方向进行更新,当梯度为0的时候,算法收敛,权重不再更新。梯度下降可能得到的是一个局部最优解(非凸函数)。

在深度学习中,尽管梯度下降算法会陷入局部最优,但是在深度学习中梯度下降算法依旧广泛应用:在之前大家认为,深度学习的目标函数会出现很多的局部最优解,但是实际上,其损失函数并没有很多的局部最优解。但是深度学习的损失函数会存在很多的鞍点(也就某一点上梯度为0,从一个切面上看是最小值,从另一个切面看是最大值的点,如下图 ),导致权重无法继续迭代,可以使用动量法来解决鞍点问题。

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码实现:

  • 要求:模拟梯度下降算法,计算在x_data、y_data数据集下,y=w*x模型找到合适的w的值。

  • 和第二课不同的是,第一课的w是我们认为设定的,通过一个for循环使得w迭代,这一次需要的是通过模型找到适合的w

import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]

w=1.0
# 求预测值
def forward(x):
    return x*w

# 损失函数
def cost(xs,ys):
    costs=0
    # 用zip打包成元祖,并返回元祖组成的列表
    for x,y in zip(xs,ys):
        y_pred=forward(x)
        costs+=(y_pred-y)**2       
    return costs/len(xs)

# 计算梯度
def gradient(xs,ys):
    grad=0
    for x,y in zip(xs,ys):
        grad+=2*x*(x*w-y)
    return grad/len(xs)

cost_list=[]
epoch_list=[]
print('predict before training',forward(4))

for epoch in range(200):
    cost_val=cost(x_data,y_data)
    grad_val=gradient(x_data,y_data)
    w-=0.01*grad_val
    
    epoch_list.append(epoch)
    cost_list.append(cost_val)
    
    print('epoch:',epoch,'w=',w,'loss=',cost_val)

print('predict after training:',forward(4))

# 画图
plt.plot(epoch_list,cost_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

部分结果截图
在这里插入图片描述

随机梯度下降

在上面的梯度下降中,求损失的时候用的是全部数据的平均损失作为更新的依据。而随机梯度下降是在全部的数据中随机选择一个作为更新的依据。使用随机梯度下降可以有效的避开鞍点

在这里插入图片描述

import matplotlib.pyplot as plt

x_data=[1.0,2.0,3.0]
y_data=[2.0,4.0,6.0]

w=1.0
# 求预测值
def forward(x):
    return x*w

# 损失函数
def cost(x,y):
    y_pred=forward(x)
    return (y_pred-y)**2

# 计算梯度
def gradient(x,y):
    return 2*x*(x*w-y)

cost_list=[]
epoch_list=[]
print('predict before training',forward(4))

for epoch in range(200):
    for x,y in zip(x_data,y_data):
        grad_val=gradient(x,y)
        print('\tgrad:',x,y,w)
        w-=0.01*grad_val
        print('\tgrad:',x,y,w,'\n')
        cost_val=cost(x,y)
    
    epoch_list.append(epoch)
    cost_list.append(cost_val)
    print('epoch:',epoch,'w=',w,'loss=',cost_val)

print('predict after training:',forward(4))

# 画图
plt.plot(epoch_list,cost_list)
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

结果截图:
在这里插入图片描述

在第一个梯度下降中,样本x与样本x+1之间没有时序关系,我们计算的事他们的总的损失,这些运行时可以并行运行的。但是在第二个随机梯度下降中,我们是先计算了x再计算的x+1,数据之间存在先后的关系,有以来关系,不能用并行计算。所以梯度下降可以有效提高运算的效率,而随机梯度下降可以获得一个优异的w。把以上两种方法折中,就产生了小批量随机梯度下降,取了一种性能与时间复杂度之间的折中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/336059.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【极问系列】springBoot集成elasticsearch出现Unable to parse response body for Response

【极问系列】 springBoot集成elasticsearch出现Unable to parse response body for Response 如何解决? 一.问题 #springboot集成elasticsearch组件,进行增删改操作的时候报异常Unable to parse response body for Response{requestLineDELETE /aurora-20240120/…

编译和链接(翻译环境:预编译+编译+汇编+链接​、运行环境)

一、翻译环境和运行环境​ 在ANSI C的任何一种实现中,存在两个不同的环境。​ 第1种是翻译环境,在这个环境中源代码被转换为可执行的机器指令。​ 第2种是执行环境,它用于实际执行代码。​ VS中编译器:cl.exe ;Linux中…

基于无人机的消防灭火系统设计

摘要:人类社会的进步,使火灾变得更加频繁且越来越复杂,随着这些年无人机技术的发展,将无人机技术融入消防灭火逐渐变成必然。消防救援采用无人机主要有以下几点原因:一、对火场及火场周围环境信息十分匮乏,…

LaTeX-OCR安装教程

一. 通用安装步骤 1.前置应用 安装LaTeX-OCR首先需要安装Python。在系统自带的应用商店Microsoft Store搜索Python,点击最新版本Python 3.12下载即可。 2.运行powershell Win11按底部状态栏windows徽标在搜索框内搜索 powershell 或者按快捷键 “win 键 R” &am…

Hack The Box-Sherlocks-Tracer

靶场介绍 A junior SOC analyst on duty has reported multiple alerts indicating the presence of PsExec on a workstation. They verified the alerts and escalated the alerts to tier II. As an Incident responder you triaged the endpoint for artefacts of interest…

毫米波雷达4D点云生成(基于实测数据)

本期文章分享TI毫米波雷达实测4D点云生成的代码,包含距离、速度、水平角度、俯仰角度,可用于日常学习。 处理流程包含:数据读取和解析、MTI、距离估计、速度估计、非相干累积、2D-CFAR、水平角估计、俯仰角估计、点云生成、坐标转换等内容。…

用MATLAB函数在图表中建立模型

本节介绍如何使用Stateflow图表创建模型,该图表调用两个MATLAB函数meanstats和stdevstats。meanstats计算平均值,stdevstats计算vals中值的标准偏差,并将它们分别输出到Stateflow数据平均值和stdev。 请遵循以下步骤: 1.使用以下…

04 SpringBoot整合Druid/MyBatis/事务/AOP+打包项目

整合Druid 项目结构&#xff1a; 引入依赖&#xff1a; <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaL…

Docker 部署考核

Docker安装 安装必要的系统工具 yum install -y yum-utils device-mapper-persistent-data lvm2 添加docker-ce安装源&#xff1a; yum-config-manager --add-repo https://download.docker.com/linux/centos/docker-ce.repo 配置阿里云Docker Yum源: yum-config-manager --ad…

4个值得使用的免费爬虫工具

在信息时代&#xff0c;数据的获取对于各行业都至关重要。而在数据采集的众多工具中&#xff0c;免费的爬虫软件成为许多用户的首选。本文将专心分享四款免费爬虫工具&#xff0c;突出介绍其中之一——147采集软件&#xff0c;为您揭示这些工具的优势和应用&#xff0c;助您在数…

使用 crypto-js 进行 AES 加解密操作

在前端开发中&#xff0c;数据的加密和解密是为了保障用户隐私和数据的安全性而常见的任务。AES&#xff08;Advanced Encryption Standard&#xff09;是一种对称密钥加密算法&#xff0c;被广泛用于保护敏感信息的传输和存储。本文将介绍 AES 加解密的基本原理&#xff0c;并…

超过GPT3.5?Mixtral 8*7B 模型结构分析

Datawhale干货 作者&#xff1a;宋志学&#xff0c;Datawhale成员 前言 2023年12月11日&#xff0c;Mistral AI团队发布了一款高质量的稀疏专家混合模型Mixtral 8x7B。 Mistral AI继续致力于向开发者社区提供最优秀的开放模型。在人工智能领域向前发展&#xff0c;需要采取超越…

Node.js 使用 cors 中间件解决跨域问题

CORS 跨域资源共享 什么是 CORS cors 是 Express 的一个第三方中间件。通过安装和配置 cors 中间件&#xff0c;可以很方便地解决跨域问题。 CORS &#xff08;Cross-Origin Resource Sharing&#xff0c;跨域资源共享&#xff09;由一系列 HTTP 响应头组成&#xff0c;这些…

统计学-R语言-4.6

文章目录 前言列联表条形图及其变种---单式条形图条形图及其变种---帕累托图条形图及其变种---复式条形图条形图及其变种---脊形图条形图及其变种---马赛克图饼图及其变种---饼图饼图及其变种---扇形图直方图茎叶图箱线图小提琴图气泡图总结 前言 本篇文章是对数据可视化的补充…

Vulnhub-TECH_SUPP0RT: 1渗透

文章目录 一、前言1、靶机ip配置2、渗透目标3、渗透概括 开始实战一、信息获取二、使用smb服务获取信息三、密码破解四、获取webshell五、反弹shell六、web配置文件获取信息七、提权 一、前言 由于在做靶机的时候&#xff0c;涉及到的渗透思路是非常的广泛&#xff0c;所以在写…

手把手教你如何快速定位bug,如何编写测试用例,快来观摩......

手把手教你如何快速定位bug,如何编写测试用例,快来观摩......手把手教你如何快速定位bug,如何编写测试用例,快来观摩......作为一名测试人员如果连常见的系统问题都不知道如何分析&#xff0c;频繁将前端人员问题指派给后端人员&#xff0c;后端人员问题指派给前端人员&#xf…

Rust - 可变引用和悬垂引用

可变引用 在上一篇文章中&#xff0c;我们提到了借用的概念&#xff0c;将获取引用作为函数参数称为 借用&#xff08;borrowing&#xff09;&#xff0c;通常情况下&#xff0c;我们无法修改借来的变量&#xff0c;但是可以通过可变引用实现修改借来的变量。代码示例如下&…

超详细的 pytest 钩子函数 —— 之初始钩子和引导钩子来啦!

前几篇文章介绍了 pytest 点的基本使用&#xff0c;学完前面几篇的内容基本上就可以满足工作中编写用例和进行自动化测试的需求。从这篇文章开始会陆续给大家介绍 pytest 中的钩子函数&#xff0c;插件开发等等。 仔细去看过 pytest 文档的小伙伴&#xff0c;应该都有发现 pyt…

爬虫接口获取外汇数据(汇率,外汇储备,贸易顺差,美国CPI,M2,国债利率)

akshare是一个很好用的财经数据api接口&#xff0c;完全免费&#xff01;&#xff01;和Tushare不一样。 除了我标题显示的数据外&#xff0c;他还提供各种股票数据&#xff0c;债券数据&#xff0c;外汇&#xff0c;期货&#xff0c;宏观经济&#xff0c;基金&#xff0c;银行…

木塑地板行业分析:市场正展现出自身良好的发展势头

木塑地板是一种新型环保型木塑复合材料产品&#xff0c;在生产中、高密度纤维板过程中所产生的木酚&#xff0c;加入再生塑料经过造粒设备做成木塑复合材料&#xff0c;然后进行挤出生产组做成木塑地板。 木塑地板能够跃居行业主流&#xff0c;得到众企业和消费者的追捧&#x…