Pytorch反向传播算法(Back Propagation)

一:revise

我们在最开始提出一个线性模型。

x为我们的输入,w为权重。相乘的结果是我们对y的预测值。

那我们在训练时就是对这个权重w进行更新,就需要用到上一章提到的梯度下降算法,不断更新w。但是此时注意不是用y的预测值对w进行求导,应该是使用loss损失值对w权重进行求导,因为我们需要得到最小的loss。

对于简单的模型我们可以使用解析式去解决,但是对于复杂的模型的w会很难算。

最左边的5个⚪代表的是5个输入,右边的5个⚪代表的是5个输出,中间的每个⚪都是隐藏的值设为H。中间的4列我们如果用向量表示,分别都是一个六维的向量,而我们想用输入的五维向量得到六维向量,就需要使用输入的五维向量乘上6x5的矩阵才能得到这个六维的向量,这就意味着我们需要30个不同的w,其实也就对应着我们图片上的线,每条线都代表需要一个w。

所以此时如果要是写解析式就是一件非常复杂的事情,因此我们希望做一种算法把我们的网络看成一个图,在图上进行传播,根据链式法则把梯度求出来。这个就是我们想要完成的bp(back propagation)

二:forward

先来一个简单的两层神经网络:

我们现在一层一层分析,其实可以看出两层的操作都是一样的。首先第一层计算的是w1*x+b1,假如说我们的输入x是一个n维的列向量,结果是一个m维的列向量,MM是矩阵相乘,那我们需要的w1是一个m*n的矩阵,相乘得到的结果是一个m维的列向量,需要b1也是一个m维的列向量,ADD表示相加,得到的结果可以看成这个层的输出,但其实这个值还需要放入到下一层进行第二层的运算,而两个的运算过程都差不多,大家可以自己看一下。

ok,现在知道每一层的运算了,但是有一个问题出现了。

大家看,在一个线性的运算中,其中不管有多少层,w1,w2都是可以通过计算放在一起的,那最后得到的结果也可以看出来,又是一个新的线性运算。这样就意味着,无论我们经过多少层的运算,最后得到的还是一个线性的运算。

为什么说这样不行,因为我们不希望化简,这样会导致我们的那些增加的权重没有意义,所以我们需要对每一层最终的输出加上一个非线性的变化函数。如下图所示:

三:BP

3.1 链式法则

链式求导第一步就是需要创建计算图。

接下来就是一个前馈forword,其实就是先有x,w通过f函数计算出z,最后得到loss的值。

现在我们如果想知道loss对于x或者w进行求导数,就是需要我们的链式法则,这个过程也就是bp(back propagation)。过程就是如下图

 ok,现在举一个具体的例子1:设x=2,w=3,f(x,w)=x*w。求z的值和求z对w和x求导的结果。大家可以自己计算一下,结果看文末。

3.2整体流程

现在大家目光向下:整体的过一遍流程,先前馈forward,后backward。

这个例子中给出的y_head的计算公式,就类似于我们上面提到的f(x,w)函数,和loss的计算公式。给出了w=1,x=1,y=2,其中r为y_head 减去y。首先计算出y_head为1,随后计算出r为-1,最后算出loss为1,以上为forward过程。接下来就是backforward,通过链式法则的知识,先通过loss和r的函数关系,用loss对r进行求导,接着r对y_head求导,最后y_head对w求导,几个结果相乘最终得到的就是loss对w求导的结果。

上面的计算大家也学会了,现在加上一个偏置量,大家计算一下loss值,loss对b和w的导数。此为例2,结果在文末。

3.3 tensor

在pytorch里最重要的数据成员就是tensor,存我们上面提到的一些数值,数据可以是标量,矩阵或者高阶的tensor,其中有两个比较重要的成员,一个是data(用于存放w本身的值),一个是grad(用于存放loss对于w的梯度值)。在链式法则部分我们提到,链式求导第一步就是需要创建计算图,这个就是使用tensor创建的。

第一部分代码,输入的相关参数:

import torch

#为举例子,自己设置的值
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]

w = torch.tensor([1.0])
w.requires_grad = True  #默认是不进行梯度计算的,我们让他为true就是进行梯度计算

第二部分代码,确定计算的一些步骤:

def forward(x):
    return x * w

def loss(x,y):
    y_pred = forward(x)
    return (y_pred - y)**2

 此时有一个需要注意的点,我们在第一步的时候设置的w是一个tensor值,当它遇到*时间,,此时的*已经被重载了,现在进行的是tensor于tensor的数乘。但是此时x并不是一个tensor类型,会自动转化为tensor。此时就构建出类似于这样的计算图

 并且由于我们最后需要对w计算梯度,所以求出的z也需要计算梯度。

同理定义的loss函数也会建立出一个计算图。

第三步就是计算过程。

print('predict (before training)',4,forward(4).item())

for epoch in range(100):
    for x,y in zip (x_data,y_data):
        l = loss(x,y) #这一步是前馈的过程
        l.backward() #这一步是bp的过程,注意bp完会消除所有的计算图
        print('\tgrad:',x,y,w.grad.item())
        w.data = w.data - 0.01 *w.grad.data #此时注意一定要.data 因为w是一个tensor,而我们需要的是tensor里面的data
        
        w.grad.data.zero_() #在上一步的更新完,导数还存在,所以我们需要将其清零。
    print('progress',epoch,l.item())

print('predict(after training)',4,forward(4).item)

现在大家应该知道整体的流程和代码了,现在大家可以自己尝试去写一下下面这个流程。关于x_data于y_data的值与上面的值相同,大家可以尝试一下。

四:answer

例子1:z的结果为6,z对w和x求导的结果分别为10和15。

例子2:z的结果是1,z对w和x求导结果分别为2和2。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/667011.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端Vue自定义支付密码输入框键盘与设置弹框组件的设计与实现

摘要 随着信息技术的不断发展,前端开发的复杂性日益加剧。传统的开发方式,即将整个系统构建为一个庞大的整体应用,往往会导致开发效率低下和维护成本高昂。任何微小的改动或新功能的增加都可能引发对整个应用逻辑的广泛影响,这种…

Mybatis-plus 更新或新增时设置某些字段值为空

方式一 在实体中设置某个字段为的注解中 TableField(updateStrategy FieldStrategy.IGNORED)private Date xxxxxxTime;通过这种方式会指定更新时该字段的策略,通常情况下updateById这种会根据字段更新,通常都会判断null 以及空值 指定 updateStrategy …

学习Java的日子 Day51 数据库,DDL

Day51 MySQL 1.数据库 数据库(database)就是一个存储数据的仓库。为了方便数据的存储和管理,它将数据按照特定的规律存储在磁盘上。通过数据库管理系统,可以有效地组织和管理存储在数据库中的数据 MySQL就是数据库管理系统&#…

[ubuntu18.04]搭建mptcp测试环境说明

MPTCP介绍 Multipath TCP — Multipath TCP -- documentation 2022 documentation 安装ubuntu18.04,可以使用虚拟机安装 点击安装VMware Tool 桌面会出现如下图标 双击打开VMware Tools,复制如下图所示的文件到Home目录 打开终端,切换到管…

安卓启动 性能提升 20-30% ,基准配置 入门教程

1.先从官方下载demohttps://github.com/android/codelab-android-performance/archive/refs/heads/main.zip 2.先用Android studio打开里面的baseline-profiles项目 3.运行一遍app,这里建议用模拟器,(Pixel 6 API 34)设备运行&a…

[Algorithm][动态规划][子序列问题][最长递增子序列的个数][最长数对链]详细讲解

目录 1.最长递增子序列的个数1.题目链接2.算法原理详解3.代码实现 2.最长数对链1.题目链接2.算法原理详解3.代码实现 1.最长递增子序列的个数 1.题目链接 最长递增子序列的个数 2.算法原理详解 注意:本题思路和思维方式及用到的方法很值得考究,个人感…

GPT4o还没用上?落后一个月!

文章目录 一.Share官方网站:以一半的价格享受官网服务1.1 网址1.2 一些介绍和教学实战:1.3 主界面(支持4o):1.4 GPTS(上千个工具箱任你选择):1.5 快速的文件数据分析(以数学建模为例…

CPU/GPU/FPSGO,负载调试/设置命令开关

CPU/GPU/FPSGO,负载调试/设置命令开关 首先,进入: adb shell cat sys/kernel/ged/hal/gpu_utilization 查看GPU的负载情况。输出三个数字,第1个表示使用率,第3个表示空闲率。 echo 0 /sys/kernel/fpsgo/common/force…

Tableau创建数据提取

Tableau创建数据提取通过与原始数据集分离可有效减少总体数据量。以下通过示例-超市数据进行演示: 需求:提取华北及东北地区家具销售利润低于5000的数据 1) 连接到数据并在“数据源”页面上设置数据源后,请在右上角选择“数据提…

Python 机器学习 基础 之 处理文本数据 【处理文本数据/用字符串表示数据类型/将文本数据表示为词袋】的简单说明

Python 机器学习 基础 之 处理文本数据 【处理文本数据/用字符串表示数据类型/将文本数据表示为词袋】的简单说明 目录 Python 机器学习 基础 之 处理文本数据 【处理文本数据/用字符串表示数据类型/将文本数据表示为词袋】的简单说明 一、简单介绍 二、处理文本数据 三、用…

Java中的软引用,你了解吗?

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

关系数据库:关系运算

文章目录 关系运算并(Union)差(Difference)交(Intersection)笛卡尔积(Extended Cartesian Product)投影(projection)选择(Selection)除…

翻译《The Old New Thing》- What a drag: Dragging a virtual file (IStream edition)

What a drag: Dragging a virtual file (IStream edition) - The Old New Thing (microsoft.com)https://devblogs.microsoft.com/oldnewthing/20080319-00/?p23073 Raymond Chen 2008年03月19日 拖拽虚拟文件(IStream 版本) 上一次,我们看…

Scikit-Learn 基础教程

目录 🐋Scikit-Learn 基础教程 🐋Scikit-Learn 简介 🐋 数据预处理 🦈数据集导入 🦈数据清洗 🦈特征选择 🦈特征标准化 🐋 模型选择 🦈分类模型 🦈回…

【 0 基础 Docker 极速入门】镜像、容器、常用命令总结

Docker Images(镜像)生命周期 Docker 是一个用于创建、部署和运行应用容器的平台。为了更好地理解 Docker 的生命周期,以下是相关概念的介绍,并说明它们如何相互关联: Docker: Docker 是一个开源平台&#…

HTML旋转照片盒子

效果图 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><meta http-equiv"X-UA-Compatible" content…

docker私有镜像仓库的搭建及认证

简介&#xff1a; docker私有镜像仓库的搭建及认证 前言 在生产上使用的 Docker 镜像可能包含我们的代码、配置信息等&#xff0c;不想被外部人员获取&#xff0c;只允许内 网的开发人员下载。 Docker 官方提供了一个叫做 registry 的镜像用于搭建本地私有仓库使用。在内部网…

C 基础 - 预处理命令和基本语法详解

#include <stdio.h> //预处理指令int main() //函数 {printf("Hello, World!"); //输出语句return 0; //返回语句 } 目录 一.预处理指令 1.#define #ifdef #ifndef #if #else #elif #endif 2.#inlcude a.新增一个文件 b.#include c.运行结果 d.扩…

Liunx中使用他人身份来执行命令或新建文件

前言 在一些情况下。我们想要借助某个用户的身份来执行命令或者新建文件&#xff0c; 比如某个用户的bash是 nologin 或者 false。 该怎么做呢&#xff1f;&#xff1f; 答&#xff1a;使用 sudo -u 即可。 例如&#xff1a; sudo -u ygz1 touch temp1.txt哈哈哈&#xff0…

【FPGA】Verilog语言从零到精通

接触fpga一段时间&#xff0c;也能写点跑点吧……试试系统地康康呢~这个需要耐心但是回报巨大的工作。正原子&&小梅哥 15_语法篇&#xff1a;Verilog高级知识点_哔哩哔哩_bilibili 1Verilog基础 Verilog程序框架&#xff1a;模块的结构 类比&#xff1a;c语言的基础…