深度学习-06-手动进行反向传播

深度学习-06-手动进行反向传播


本文是《深度学习入门2-自製框架》 的学习笔记,记录自己学习心得,以及对重点知识的理解。如果内容对你有帮助,请支持正版,去购买正版书籍,支持正版书籍不仅是尊重作者的辛勤劳动,也是鼓励更多优秀作品问世。

当前笔记内容主要为:步骤6 手动进行反向传播 章节的相关理解。

书籍总共分为5个阶段,每个阶段分很多步骤,最终是一步一步实现一个深度学习框架。例如前两个阶段为:

第 1 阶段共包括 10 个步骤 。 在这个阶段,将创建自动微分的机制
第 2 阶段,从步骤11-24,该阶段的主要目标是扩展当前的 DeZero ,使它能够执行更复杂的计算 ,使它能 够处理接收多个输入的函数和返回多个输出的函数


1.Variable 类的功能

Variable 类在原有的基础上,除了data普通值外,应该增加一个grad 的值,用于记录导数值。grad 也是numpy ndarray 类型。为啥是ndarray ,我们上一节讲的是对x 一维的,但是我们后面可能会设计到多维度,每个维度都需要进行求导,所以这里定义的是一个ndarray ,这样就可以表示向量或者矩阵。

class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None   #存储梯度,导数值

2.Function 类的功能


Function 类我们在原有的定义前向计算,表示复合函数的基础上定义一个反向传播方法,并且函数定义一个数据域input, 用以记录每次执行是的输入变量。有了input变量后,每次调用backward 方法后,向函数输入的 Variable 实例就可以作为 self.input使用。

class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        self.input = input  # input 保存每次输入的变量
        return output

    def forward(self, x):
        raise NotImplementedError()  # 使用Function  这个方法forward 方法的人 , 这个方法应该通过继承采实现

    def backward(self, gy):
        raise NotImplementedError()
3.Square 类和Exp 类的功能扩展


按照最新的Function类,为Square 和Exp 类增加反向传播功能

class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x= self.input.data
        gx = 2 * x * gy     #方法的参数 gy 是 一个 ndarray 实例 , 它是从输出传播而来的导数 。
        return gx


class Exp(Function):
    def forward(self, x):
        y = np.exp(x)
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy
        return gx
4.反向传播的实现

尝试对函数进行推演,正向计算,图:    
    

    A = Square()
    B = Exp()
    C = Square()

    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)


    
    
反向计算导数图:


    

   

    y.grad = np.array(1.0)
    b.grad = C.backward(y.grad)
    a.grad = B.backward(b.grad)
    x.grad = A.backward(a.grad)
    print(x.grad)
5.项目代码


项目全部代码如下:

'''
step06.py
手动实现反向传播

'''

import numpy as np

class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None

class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        self.input = input
        return output

    def forward(self, x):
        raise NotImplementedError()  # 使用Function  这个方法forward 方法的人 , 这个方法应该通过继承采实现

    def backward(self, gy):
        raise NotImplementedError()

class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x= self.input.data
        gx = 2 * x * gy     #方法的参数 gy 是 一个 ndarray 实例 , 它是从输出传播而来的导数 。
        return gx


class Exp(Function):
    def forward(self, x):
        y = np.exp(x)
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy
        return gx

if __name__ == '__main__':
    A = Square()
    B = Exp()
    C = Square()

    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)

    y.grad = np.array(1.0)
    b.grad = C.backward(y.grad)
    a.grad = B.backward(b.grad)
    x.grad = A.backward(a.grad)
    print(x.grad)

执行结果:3.297442541400256


问题分析:代码里面的 为啥计算梯度的时候,乘以gy?

如果没有理解,可以看下链式法则,这里只有乘以 gy 才把其他的连乘结果带到下一个微分计算里面。
深入理解下链式法则,就懂了这里为啥每个backward 函数乘以一个入参gy 


6.总结

本节实现了反向传播,计算微分的计算,当前是手动实现的,后面改造成自动化计算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/676279.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

O2O : Finetuning Offline World Models in the Real World

CoRL 2023 Oral paper code Intro 算法基于TD-MPC,利用离线数据训练世界模型,然后在线融合基于集成Q的不确定性估计实现Planning。得到的在线数据将联合离线数据共同训练目标策略。 Method TD-MPC TD-MPC由五部分构成: 状态特征提取 z h θ ( s ) …

Amazon Q Developer 实战:从新代码生成到遗留代码优化(下)

简述 本文是使用 Amazon Q Developer 探索如何在 Visual Studio Code 集成编程环境(IDE),从新代码生成到遗留代码优化的续集。在上一篇博客《Amazon Q Developer 实战:从新代码生成到遗留代码优化(上)》中…

java基础篇(1)

JDK是什么?有哪些内容组成?JDK是Java开发工具包 JVM虚拟机: Java程序运行的地方 核心类库: Java已经写好的东西,我们可以直接用开发工具: javac、java、jdb、jhat.. JRE是什么?有哪些内容组成? JRE是Java运行环境 JVM、核心类库、运行工具 JDK,JRE&…

Linux网络编程:传输层协议|UDP|TCP

知识引入: 端口号: 当应用层获得一个传输过来的报文时,这时数据包需要知道,自己应该送往哪一个应用层的服务,这时就引入了“端口号”,通过区分同一台主机不同应用程序的端口号,来保证数据传输…

Java1.8基于BS版 vue+ uniapp+ springboot专业团队自主研发的一套上门家政APP系统成品源码,支持商用(后台端介绍)

Java1.8基于BS版 vue uniapp springboot专业团队自主研发的一套上门家政APP系统成品源码,支持商用(后台端介绍) 家政服务后台端 家政服务后台端是一个专为家政服务行业设计的管理系统,用于处理业务运营、用户端管理、师傅端调度、…

Spring boot 随笔 1 DatasourceInitializer

0. 为啥感觉升级了 win11 之后,电脑像是刚买回来的,很快 这篇加餐完全是一个意外:时隔两年半,再看 Springboot-quartz-starter 集成实现的时候,不知道为啥我的h2 在应用启动的时候,不能自动创建quartz相关…

FL Studio怎么给钢琴加延音 FL Studio怎么用钢琴做伴奏

在使用钢琴音色进行音乐创作的时候,可以对钢琴进行延音处理,这样处理的音色给人的感觉会更加的饱满丰富,同时,给钢琴加了延音之后,钢琴的声音时值也会相应的变长,听起来更加的柔和。今天就和大家讲一讲&…

STM32使用HAL库UART接收不定长数据-1

使用STM32的HAL库实现UART串口不定长数据的接收 使用STM32的UART接收数据的时候,经常会遇到接收长度不固定的数据,比如一帧数据可能是10个字节,也可能是12个字节。这种数据称为不定长数据。 现有的很多通信协议是不定长的,比如mo…

vue3_组件间通信方式

目录 一、父子通信 1.父传子( defineProps) 2.父传子(useAttrs) 3.子传父(ref,defineExpose ) 4.子传父(defineEmits) 5.子传父(v-model) …

数据库 mysql 的彻底卸载

MySQL卸载步骤如下: (1)按 winr 快捷键,在弹出的窗口输入 services.msc,打开服务列表。 (2)在服务列表中, 找到 mysql 开头的所有服务, 右键停止,终止对应的…

【问题随记】tightvnc 连接后灰屏

问题描述 刚刚入手了官方发的 OrangePi AI Pro,想用 tight vnc 来连接开发板,就不用连接屏幕那么麻烦了。结果连接后,没能显示 OrangePi AI Pro 桌面。 问题解决 看一下现有的桌面环境。 apt list --installed | grep desktop从中可以看到…

游戏找不到d3dcompiler43.dll怎么办,分享5种有效的解决方法

在计算机使用过程中,我们经常会遇到一些错误提示,其中之一就是找不到某个文件。其中,找不到d3dcompiler43.dll是一个常见的问题。这个问题通常出现在运行某些游戏或应用程序时,由于缺少了d3dcompiler43.dll文件,导致程…

【PTA】7-3 拯救007(C++)代码实现 易错点反思

题目见下: 输入样例 14 20 25 -15 -25 28 8 49 29 15 -35 -2 5 28 27 -29 -8 -28 -20 -35 -25 -20 -13 29 -30 15 -35 40 12 12 //输入上述数据后输出“Yes” AC代码如下: #include<bits/stdc++.h> using namespace std; #define sz 100 typedef struct node{int …

基于javacv ffmpeg 使用原生ffmpeg命令

基于javacv ffmpeg 使用原生ffmpeg命令 1. ffmpeg2. ffprobe 相关阅读&#xff1a; javacv ffmpeg使用笔记 测试过程中&#xff0c;发现ffmpeg-6.0-1.5.9-linux-x86_64.jar 存在问题&#xff08;ffmpeg原生命令执行失败&#xff09;&#xff0c;降级到ffmpeg-5.1.2-1.5.8-linux…

Mixly 开启WIFI AP UDP收发数据

一、开发环境 软件&#xff1a;Mixly 2.0在线版 硬件&#xff1a;ESP32-C3&#xff08;立创实战派&#xff09; 固件&#xff1a;ESP32C3 Generic(UART) 测试工工具&#xff1a;NetAssist V5.0.1 二、实现功能 ESP32开启WIFI AP&#xff0c;打印接入点IP地址&#xff0c;允许…

直播预告|手把手教你玩转 Milvus Lite !

Milvus Lite&#xff08;https://milvus.io/docs/milvus_lite.md&#xff09;是一个轻量级向量数据库&#xff0c;支持本地运行&#xff0c;可用于搭建 Python 应用&#xff0c;由 Zilliz 基于全球最受欢迎的开源向量数据库 Milvus&#xff08;https://milvus.io/intro&#xf…

Pandas读取文本文件为多列

要使用Pandas将文本文件读取为多列数据&#xff0c;你可以使用pandas.read_csv()函数&#xff0c;并通过指定适当的分隔符来确保正确解析文件中的数据并将其分隔到多个列中。 假设你有一个以逗号分隔的文本文件&#xff08;CSV格式&#xff09;&#xff0c;每一行包含多个值&a…

二分查找算法介绍(边界值、循环条件、值的变化、二分查找的原理、异常处理)

一、二分查找法原理介绍 二分查找是经典的查找算法之一&#xff0c;其原理也非常简单。 对于已排序的数组&#xff08;假设是整型&#xff0c;如果非整型&#xff0c;如果有排序和大小比较的定义&#xff0c;也可以使用二分查找&#xff09;&#xff0c;我们每次判断中间值与目…

单位企业是如何禁用USB接口的(公司禁止USB接口的三大方法)

在当前信息化时代&#xff0c;单位企业对于数据安全的重视程度日益增加&#xff0c;尤其是防止通过USB接口导致的数据泄露和恶意软件传播。 为了构建一个更加安全的办公环境&#xff0c;许多企业采取措施禁用USB接口。 以下是公司禁止USB接口的三大常用方法&#xff1a; 1. 部…

AI绘画Stable Diffusion最新整合包,开源免费 AI 绘图工具神器,解压即用!

写在前面 众所周知现在的AI绘画可谓是热火朝天&#xff0c;前有国外的Midjourney&#xff0c;后有国内各大平台推出的 各种AI工具等&#xff0c;但是目前的这些线上的AI绘画都会有生成次数、时长等限制&#xff0c;有时候还得排队等待出图&#xff0c;所以免费开源的 Stable D…