【刘二大人】pytorch深度学习实践(三):如何实现线性模型的反向传播+代码实现详解(Tensor、backward函数)

目录

  • 参考资料
  • 一、反向传播流程
    • 1.1 问题
    • 1.2 方法
    • 1.3 步骤
    • 1.4 例题
  • 二、Pytorch中前向传播和反馈的计算
    • 2.1 tensor数据类型
    • 2.2 定义线性模型并且计算损失
      • 2.2.1 torch.tensor.item()
      • 2.2.2 代码
    • 2.3 反向传播
      • 2.3.1 torch.tensor.backward()
      • 2.3.2 tensor.zero_( )
      • 2.3.3 代码实现
  • 三、代码实现

参考资料

  • 学习视频: 反向传播-刘二大人
  • 上一节学习笔记: 【刘二大人】pytorch深度学习实践(二):梯度下降算法详解和代码实现(梯度下降、随机梯度下降、小批量梯度下降的对比)
  • pytorch官方文档: pytorch官方文档

一、反向传播流程

1.1 问题

求loss函数对于w和x的偏导。

1.2 方法

基于导数的链式法则,依次求导。
在这里插入图片描述

1.3 步骤

  1. 首先根据前向传播,可以得到 x = 2 , w = 3 , z = f ( x , w ) = x ∗ w = 6 , x =2 ,w = 3, z= f(x,w) = x*w = 6, x=2,w=3,z=f(x,w)=xw=6, 那么就可以求得z关于w和z的导数: ∂ z ∂ w = x = 2 , ∂ z ∂ x = w = 3 \frac{\partial z}{\partial w} =x=2,\frac{\partial z}{\partial x} =w=3 wz=x=2,xz=w=3
  2. 继续前向传播, l o s s = ( y − y p r e d ) 2 loss = (y-y_{pred})^2 loss=(yypred)2,直到计算出loss函数
  3. 根据反向传播,程序可以计算出 ∂ L o s s ∂ z = 5 \frac{\partial Loss}{\partial z} =5 zLoss=5
  4. 根据链式法则,我们知道 ∂ L o s s ∂ w = ∂ L o s s ∂ z ∗ ∂ z ∂ w \frac{\partial Loss}{\partial w}=\frac{\partial Loss}{\partial z}*\frac{\partial z}{\partial w} wLoss=zLosswz, 而我们已经计算出 ∂ z ∂ w = x = 2 \frac{\partial z}{\partial w} =x=2 wz=x=2,所以 ∂ L o s s ∂ w = 2 ∗ 5 = 10 \frac{\partial Loss}{\partial w}=2*5=10 wLoss=25=10,同理可以计算出 ∂ L o s s ∂ x = 3 ∗ 5 = 15 \frac{\partial Loss}{\partial x}=3*5=15 xLoss=35=15
  5. 由此我们便完成了 ∂ z ∂ w , ∂ z ∂ x \frac{\partial z}{\partial w},\frac{\partial z}{\partial x} wz,xz的计算。
    在这里插入图片描述

1.4 例题

求loss函数关于w的偏导数
在这里插入图片描述

前向传播求出局部梯度,再反向传播求得最终梯度

二、Pytorch中前向传播和反馈的计算

2.1 tensor数据类型

pytorch官方文档 - tensor

Tensor中有两个重要的数据变量

  • data:该节点的数据值,为Tensor类
  • grad:该节点的梯度值,为Tensor类
    在这里插入图片描述

对w使用Tensor数据类型进行定义:设置requires_grad = True表明在计算过程中需要保留该值的梯度;

import torch
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w = torch.tensor([1.0])
w.requires_grad = True

2.2 定义线性模型并且计算损失

y ′ = w ∗ x y' =w*x y=wx

l o s s = ( y ′ − y ) 2 = ( x ∗ w − y ) 2 loss = (y'-y)^2 =(x*w-y)^2 loss=(yy)2=(xwy)2

这段代码是在构建如下的计算图,前向传播并且求出loss值
在这里插入图片描述

此处的 l l l是一个张量(因为w是一个张量),所以后续需要 l l l的值时要使用 l . i t e m ( ) l.item() l.item()的方法进行取值

2.2.1 torch.tensor.item()

item()是将一个张量的值以一个python数字形式返回;
在这里插入图片描述
使用item()将Tensor张量转换为数字

在这里插入图片描述

2.2.2 代码

def forward(x):
    return x*w

def loss(x,y):
    return (y-forward(x))**2

2.3 反向传播

2.3.1 torch.tensor.backward()

该函数计算当前张量相对于计算图中所有叶子节点的梯度
在这里插入图片描述

2.3.2 tensor.zero_( )

把Tensor的数值清零。
在这里插入图片描述

2.3.3 代码实现

  1. 使用for循环设置训练10个epoch
  2. 使用loss函数构建计算图,计算损失值
  3. 调用backward函数计算计算图上叶子节点的梯度值
  4. 根据w的梯度值更新w( w − = w ∗ 学习率 w-=w*学习率 w=w学习率
  5. 清空w的梯度,准备下一轮计算
print("predicted(before training)",4,forward(4).item())
# 训练10个epoch
for epoch in range(10):
    for x,y in zip(x_data,y_data):
    # 计算损失值
        l = loss(x,y)
        # 反向传播
        l.backward()
        print("\tgrad:",x,y,w.grad.item())
        # 更新w
        w.data=w.data-0.01*w.grad
# 清空w的梯度
        w.grad.zero_()
    print("progress:",epoch,l.item())
print("predict(after training)",4,forward(4).item())

三、代码实现

import torch
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
w = torch.tensor([1.0])
w.requires_grad = True

def forward(x):
    return x*w

def loss(x,y):
    return (y-forward(x))**2

print("predicted(before training)",4,forward(4).item())

for epoch in range(10):
    for x,y in zip(x_data,y_data):
        l = loss(x,y)
        l.backward()
        print("\tgrad:",x,y,w.grad.item())
        w.data=w.data-0.01*w.grad

        w.grad.zero_()
    print("progress:",epoch,l.item())
print("predict(after training)",4,forward(4).item())


在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/222359.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

企业机密无忧!好用的文件加密系统大揭秘,尽在这里!

由于众多企业内部都存储着大量机密数据,以电子文档形式存在,且传播手段多样,文件泄密问题容易发生。员工通过网络泄密重要文件,或黑客入侵窃取机密数据等情况,都可能导致企业业务和声誉受到严重损害。因此,…

C++断言assert

2023年12月6日&#xff0c;周三上午 在C中&#xff0c;assert 是一个宏定义&#xff0c;用于在程序运行期间检查一些条件是否满足。如果条件不满足&#xff0c;则 assert 会终止程序并输出一条错误消息。 assert 宏定义的语法如下&#xff1a; #include <cassert>asser…

【ESP8266】ESP8266集成开发环境对比

当涉及到ESP8266开发环境的选择时&#xff0c;有几个常见的选择可供开发人员使用。在本篇文章中&#xff0c;我们将对比一些目前最流行的ESP8266集成开发环境&#xff08;IDE&#xff09;&#xff0c;以帮助您选择最适合您的需求的开发环境。 总结&#xff1a;Arduino IDE和Pl…

学校图书管理系统的开发

目 录 摘要 1 Abstract. 1 1 引言 2 1.1 图书管理的现状 2 1.2 现有图书管理系统的概述 3 1.3 选题的目的、意义 3 1.4 图书管理系统的可行性分析 4 1.5 系统开发运行环境 4 2 图书管理系统开发相关技术的介绍 5 2.1 Asp.net的介绍 5 2.1.1 Asp.net的优势介绍 5 2.1.2 Asp.net…

codeforces 题目 Fadi and LCM

目录 题目&#xff1a; 题目描述&#xff1a; 思路&#xff1a; AC代码&#xff1a; 题目&#xff1a; 题目描述&#xff1a; 给你一个长整型 X ①你需要找到一对 a 和 b &#xff0c;使得 LCM&#xff08;a&#xff0c;b&#xff09; X ②你需要保证 max(a&#xff…

智能优化算法应用:基于水基湍流算法无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于水基湍流算法无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于水基湍流算法无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.水基湍流算法4.实验参数设定5.算法结果6.参考…

css 字体添加外轮廓

color: #ffeb3b; -webkit-text-stroke: 10px transparent; background: linear-gradient(90deg,#5d3d02f5,#5d3d02f5,#5d3d02f5,#5d3d02f5,#5d3d02f5,#5d3d02f5,#5d3d02f5) top left / 100% 100%; -webkit-background-clip: text;

「Verilog学习笔记」无占空比要求的奇数分频

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 timescale 1ns/1nsmodule odd_div ( input wire rst ,input wire clk_in,output wire clk_out5 ); //*************code***********//reg [1:0] data ;reg […

class035 数据结构设计高频题【算法】

class035 数据结构设计高频题【算法】 算法讲解035【必备】数据结构设计高频题 code1 设计有setAll功能的哈希表 // setAll功能的哈希表 // 测试链接 : https://www.nowcoder.com/practice/7c4559f138e74ceb9ba57d76fd169967 // 请同学们务必参考如下代码中关于输入、输出…

class050 双指针技巧与相关题目【算法】

class050 双指针技巧与相关题目【算法】 算法讲解050【必备】双指针技巧与相关题目 code1 922. 按奇偶排序数组 II // 按奇偶排序数组II // 给定一个非负整数数组 nums。nums 中一半整数是奇数 &#xff0c;一半整数是偶数 // 对数组进行排序&#xff0c;以便当 nums[i] 为…

Isaac Sim教程04 Isaac Sim的高级使用

Isaac Sim 高级使用 版权信息 Copyright 2023 Herman YeAuromix. All rights reserved.This course and all of its associated content, including but not limited to text, images, videos, and any other materials, are protected by copyright law. The author holds…

EI论文复现:考虑源荷不确定性的含风电-电力系统低碳调度程序代码!

本程序参考论文《考虑源荷不确定性的含风电-电力系统低碳调度》&#xff0c;程序中考虑了源荷的不确定性&#xff0c;引入模糊机会约束规划来求解不确定性模型&#xff0c;对做相关研究方向的小伙伴非常有帮助&#xff0c;程序算例丰富、注释清晰、干货满满&#xff0c;下面对文…

JAVA刷题之数组的总结和思路分享

꒰˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好&#xff0c;我是xiaoxie.希望你看完之后,有不足之处请多多谅解&#xff0c;让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN …

Amazon Code Whisperer 的正式使用,全新 AI 代码工具等你发现!(内附详细安装步骤图解)

文章作者&#xff1a;稚始稚终 关于 Code Whisperer Code Whisperer&#xff0c;亚马逊推出的实时 AI 编程助手&#xff0c;是一项基于机器学习的服务&#xff0c;它可以分析开发者在集成开发环境&#xff08;IDE&#xff09;中的注释和代码&#xff0c;并根据其内容生成多种代…

【LeetCode:2646. 最小化旅行的价格总和 | DFS + DP】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

【开发问题解决方法记录】04.dian 权限表单优化

权限表单优化方向&#xff1a; 父级权限从晶点权限表获取做成列表下拉选中 权限名称和编码一行两列 页面id从 select * from APEX_APPLICATION_PAGES where APPLICATION_ID304; 中获取 【遇到的问题1】 DG可以获取到页面信息&#xff0c;但是表和应用程序无法获取到 【问…

机器学习-逻辑回归

一、引言 逻辑回归&#xff08;Logistic Regression&#xff09;是一种广泛应用于分类问题的监督学习算法。尽管名字中含有“回归”二字&#xff0c;但这并不意味着它用于解决回归问题。相反&#xff0c;逻辑回归专注于解决二元或多元分类问题&#xff0c;如邮件是垃圾邮件还是…

TSMaster添加注释

当我们在回放报文的时候&#xff0c;会遇到一些需要添加注释&#xff0c;有以下几种办法进行注释 报文运行时手动注释 在图形窗口回放报文&#xff0c;正在抓取报文或者进行报文回放。工具栏选择添加实时注释&#xff0c;这种办法需要手速快&#xff0c;而且时间对的不是很准…

App内存优化

一、内存优化介绍 1.背景介绍 内存是大问题但缺乏关注压实骆驼的最后一个稻草&#xff08;堆栈溢出&#xff09; 2.内存问题 内存抖动&#xff1a;锯齿状、GC导致卡顿内存泄露&#xff1a;可用内存减少、频繁GC内存溢出&#xff1a;OOM&#xff0c;程序异常 二、优化工具选…

jvs智能bi新增:数据集添加sql自定义节点、添加websocket任务进度动态展示等等

智能bi更新功能 新增: 1.数据集添加sql自定义输入节点&#xff0c;支持mysql Oracle数据源&#xff1b; 用户可以从这些数据源中获取数据&#xff0c;并通过SQL语句对数据进行自定义处理和分析。可以帮助用户更加灵活地处理和分析数据&#xff0c;满足各种个性化的需求。 2.…