深度学习-07-反向传播的自动化

深度学习-07-反向传播的自动化


本文是《深度学习入门2-自製框架》 的学习笔记,记录自己学习心得,以及对重点知识的理解。如果内容对你有帮助,请支持正版,去购买正版书籍,支持正版书籍不仅是尊重作者的辛勤劳动,也是鼓励更多优秀作品问世。

当前笔记内容主要为:步骤7 反向传播的自动化 章节的相关理解。

书籍总共分为5个阶段,每个阶段分很多步骤,最终是一步一步实现一个深度学习框架。例如前两个阶段为:

第 1 阶段共包括 10 个步骤 。 在这个阶段,将创建自动微分的机制
第 2 阶段,从步骤11-24,该阶段的主要目标是扩展当前的 DeZero ,使它能够执行更复杂的计算 ,使它能 够处理接收多个输入的函数和返回多个输出的函数


1.为反向传播的自动化创造条件

之前我们在实现反向传播的时候,我们是手动编写进行反向传播计算的代码,这意味着我们每次都要编写这些代码。例如下面的:

    A = Square()
    B = Exp()
    C = Square()

    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)

    y.grad = np.array(1.0)
    b.grad = C.backward(y.grad)
    a.grad = B.backward(b.grad)
    x.grad = A.backward(a.grad)
    print(x.grad)

如果计算图不一样,那么我们每次就需要白那些不一样的代码来计算反向传播。

图:不同计算图的例子。

那我们就开始思考如果自动化这块内容?书中讲解了一种机制:无论普通的计算流程(正向传播)中是什么样的计算,反向传播都能 自动进行 。这里引入一个概念:Define- by-Run。
Define-by-Run;是在深度学习中进行计算时 ,在计算之间建立"连接"的机 制 。 这种机制也称为动态计算图


计算图都是流水线式的计算 。 因此,只要以列表的形式记录函数的顺序,就可以通过反向回溯自动进行反向传播
    
    
在实现反向传播的自动化之前,我们先思考一下变量和函数之间的关系。解决方案主要从这里入手。我们需要考察变量和函数的关系并且用代码表示出来。

函数的变量包括"输入变量 "(input) 和"输出变量"(output)。函数是变量的 "父母" 如果没有父母,说明这个变量是用户的输入变量。

修改Variable 定义,函数和变量之间的"连接,让这个"连接"在执行普通计算(正向传播)的那一刻创建。

class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self,func):
        self.creator = func

上面代码,定义了熟悉 creator 标识是那个函数。并且定义了set 方法,来进行变量与函数之间的关联。

修改 Function 定义:

class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        output.set_creator(self)  # 输出者保存创造者对象
        self.input = input
        self.output = output  # 保存输出者。我是创造者的信息,这是动态建立 "连接"这 一 机制的核心
        return output

    def forward(self, x):
        raise NotImplementedError()  # 使用Function  这个方法forward 方法的人 , 这个方法应该通过继承采实现

    def backward(self, gy):
        raise NotImplementedError()


在函数执行时,就设置输出变量的 creator 为”自己“。并且将输出保存起来。

经过上面的操作后,变量和函数之间建立的联系,并且这两种联系的建立时代码执行过程中创建的。

有了连接,我们手动检查下,计算下反向遍历计算图  

    A = Square()
    B = Exp()
    C = Square()

    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)

    assert y.creator == C
    assert y.creator.input == b
    assert y.creator.input.creator == B
    assert y.creator.input.creator.input == a
    assert y.creator.input.creator.input.creator == A
    assert y.creator.input.creator.input.creator.input == x

没有抛出异常,这意味着 assert 语句的所有条件都得到了满足。


2.尝试反向传播

利用上面的变量与函数之间的关系。进行反向传播计算。

先 y ---> b 

    y.grad = np.array(1.0)

    C = y.creator  # 获取函数
    b = C.input     # 获取函数的输入
    b.grad = C.backward(y.grad)    # 调用函数的backward 方法

然后  b ---> a
    
 

    B = b.creator
    b = B.input
    a.grad = B.backward(b.grad)

最后  a----> x 

    A = a.creator
    x = A.input
    x.grad = A.backward(a.grad)
    print(x.grad)


    
其实每一步的规律是:
1 获取函数 
2 获取函数的输入 
3 调用函数的 backward 方法


最后执行的结果为  
3.29744

3.增加backward 方法

前面这些反向传播的代码可以看山,它们有着相同的处理流程准确 来说,是从一个变量到前一个变量的反向传播逻辑相同。

修改Variable 类,增加一个新的方法-backward 


class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self,func):
        self.creator = func

    def backward(self):
        f = self.creator  # 获取函数
        if f is not None:
            x = f.input # 获取函数的输入
            x.grad =f.backward(self.grad) # 调用函数的backward() 方法
            x.backward()  # 调用自己前面的那个变量的 backward()方法
4.项目代码
'''
step07.py
自动实现反向传播

'''

import numpy as np

class Variable:
    def __init__(self, data):
        self.data = data
        self.grad = None
        self.creator = None

    def set_creator(self,func):
        self.creator = func

    def backward(self):
        f = self.creator  # 获取函数
        if f is not None:
            x = f.input # 获取函数的输入
            x.grad =f.backward(self.grad) # 调用函数的backward() 方法
            x.backward()  # 调用自己前面的那个变量的 backward()方法

class Function:
    def __call__(self, input):
        x = input.data
        y = self.forward(x)
        output = Variable(y)
        output.set_creator(self)  # 输出者保存创造者对象
        self.input = input
        self.output = output  # 保存输出者。我是创造者的信息,这是动态建立 "连接"这 一 机制的核心
        return output

    def forward(self, x):
        raise NotImplementedError()  # 使用Function  这个方法forward 方法的人 , 这个方法应该通过继承采实现

    def backward(self, gy):
        raise NotImplementedError()

class Square(Function):
    def forward(self, x):
        y = x ** 2
        return y

    def backward(self, gy):
        x= self.input.data
        gx = 2 * x * gy     #方法的参数 gy 是 一个 ndarray 实例 , 它是从输出传播而来的导数 。
        return gx


class Exp(Function):
    def forward(self, x):
        y = np.exp(x)
        return y

    def backward(self, gy):
        x = self.input.data
        gx = np.exp(x) * gy
        return gx

if __name__ == '__main__':
    A = Square()
    B = Exp()
    C = Square()

    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)

    assert y.creator == C
    assert y.creator.input == b
    assert y.creator.input.creator == B
    assert y.creator.input.creator.input == a
    assert y.creator.input.creator.input.creator == A
    assert y.creator.input.creator.input.creator.input == x

    # 反向传播
    y.grad = np.array(1.0)

    C = y.creator  # 获取函数
    b = C.input     # 获取函数的输入
    b.grad = C.backward(y.grad)    # 调用函数的backward 方法

    B = b.creator
    b = B.input
    a.grad = B.backward(b.grad)

    A = a.creator
    x = A.input
    x.grad = A.backward(a.grad)
    print(x.grad)

    # 自动反向传播
    x = Variable(np.array(0.5))
    a = A(x)
    b = B(a)
    y = C(b)
    y.grad = np.array(1.0)
    y.back()
    print(x.grad)
5.总结

经过此小节的代码丰富,目前我们可以执行自动反向传播计算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/679527.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

解决 clickhouse jdbc 偶现 failed to respond 问题

背景 Clickhouse集群版本为 Github Clickhouse 22.3.5.5, clickhouse-jdbc 版本为 0.2.4。 问题表现 随着业务需求的扩展,基于Clickhouse 需要支持更多任务在期望的时效内完成,于是将业务系统和Clickhouse交互的部分都提交给可动态调整核心…

Python中如何打开网页

幸好思念无声,可惜思念无声 ——24.6.4 Python打开前端网页 1.导入webbrowser库 用webbrowser.open(传入网址),打开网页 import webbrowser webbrowser.open("Index.html") 2.用flask框架 from wsgiref.simple_server import make_serve…

九、从0开始卷出一个新项目之瑞萨RZN2L生产烧录固件(jflash擦写读外挂flash)

目录 七、生产烧录固件(jflash擦/写/读外挂flash) 7.1 flash母片读写 7.2 jflash擦/写/读外挂flash 九、从0开始卷出一个新项目之瑞萨RZN2L 七、生产烧录固件(jflash擦写读外挂flash) 七、生产烧录固件(jflash擦/写/读外挂flash) 7.1 flash母片读写 略 7.2 jflash擦/写/读…

文件上传漏洞之upload-labs

前提: 本文中的以xshell命名的均为以密码为admin的一句话木马,而shell命名的则是由冰蝎工具生成的木马。 pass-01:js前端验证 测试性的上传一个一句话木马,发现被拦截了,而且根据推测大概率是前端检测,于…

uniapp小程序开发 | 从零实现一款影视类app (横向滚动和下拉刷新的实现)

uniapp小程序开发实战系列,完整介绍从零实现一款影视类小程序。包含小程序前端和后台接口的全部完整实现。系列连载中,喜欢的可以点击收藏。 这里介绍下我的电影小程序的完整实现过程。这个系列将会详细讲解每个步骤,包括接口设计、数据结构优…

720云「3D空间漫游」功能爆发!户型标注、自动导览、切换视图…

一、新增 [开场封面] 支持图片、视频开场 作品第一印象必须惊艳!使用频率极高的功能,终于给3D漫游安排上啦~你可以自定义上传一张图片或一段视频,支持对桌面端、移动端分别进行设置并预览,完美适配不同终端。 二、升级模型交互体验…

Docker安装MySQL8.0报错记录

Linux已知有docker MySQL5.6版本,再安装MySQL8.0,报错信息记录如下 Docker安装MySQL8.0报错记录 Linux已知有docker MySQL5.6版本,再安装MySQL8.0,报错信息记录如下 问题1 :ls: cannot access ‘/docker-entrypoint…

HarmonyOS鸿蒙应用开发——ArkUI组件封装最佳实践

文章目录 背景与案例描述静态注册属性-封装UI组件动态注册属性-封装UI组件总结 背景与案例描述 在应用开发中,对一些频繁使用的业务UI组件常常会进行一层封装,提取到公共基础库中实现组件的复用,避免类似的逻辑重复编写,减少代码…

Excel中高级筛选多个条件怎么做?

高级筛选关键点就在条件设置,筛选条件可以设置多行多列,同一行之间的条件是“并且”的关系,同一列之间的条件是“或者”的关系。 我们以筛选厂家通用、大众,在北京、上海、成都,1月的数据为例来演示条件设置 一、按字…

React - 实现走马灯组件

一、实现效果 二、源码分析 import {useRef, useState} from "react";export const Carousel () > {const images [{id: 3, url: https://sslstage3.sephorastatic.cn/products/2/4/6/8/1/6/1_n_new03504_100x100.jpg}, {id: 1, url: https://sslstage2.sephor…

一个月飙升 9k star!打破常规的 git 客户端

作为一名程序员,想必大家每天都要使用 git 来管理自己的代码吧。有些大佬喜欢使用命令行来进行 git 的操作,有些新入门的小白程序员则比较喜欢使用各种 git 客户端来可视化的管理代码,而有些程序员则喜欢使用 IDE 中集成的 git 功能来做代码的…

【机器学习】Softmax回归探索

从零开始探索Softmax回归:深度学习的入门之旅 一、Softmax回归的原理与关键步骤二、研究准备:GPU环境下的PyTorch安装与配置三、研究内容:使用PyTorch实现Softmax回归 随着人工智能和机器学习的迅猛发展,深度学习技术逐渐成为了科…

Codeforces Round 950 (Div. 3)(A~D题)

A. Problem Generator 思路:暴力模拟,对于每个字母&#xff0c;如果不足m mm&#xff0c;就加入最终答案. 实现代码: #include<bits/stdc.h> using namespace std; #define N 2000005 #define mod 100003 typedef long long ll; ll n, m, t, cnt, ans, sum1,sum2, maxx…

史上Z快ST

一支股票ST要多久&#xff0c;锦州港告诉你&#xff0c;3个交易日足矣。上周五发出发公告&#xff0c;今天停牌1天&#xff0c;明天复牌就变ST。对锦州港的九万股民来说&#xff0c;好消息是今天不会跌&#xff0c;坏消息是复牌ST以后可以开始每天数地板了。 又是一个浓眉大眼的…

前端 CSS 经典:3D Hover Effect 效果

前言&#xff1a;有趣的 3D Hover Effect 效果&#xff0c;通过 js 监听鼠标移动&#xff0c;动态赋值 rotateX&#xff0c;rotateY 的旋转度来实现。 效果图&#xff1a; 代码实现&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta …

Python脚手架系列-PyQt5

记录PyQt模块使用中的一些常常复用的代码 其他 导入界面 import sysfrom PyQt5.QtGui import QIcon from PyQt5.QtWidgets import QApplication, QMainWindow from UI.MainWindow import Ui_MainWindow # 导入UI界面的类以供继承class MyApp(QMainWindow, Ui_MainWindow):de…

Scala环境的搭建

要搭建Scala&#xff0c;我们必须先下载java&#xff0c;由于我的电脑已经搭建好了环境&#xff0c;因此我这里用截图来教大家搭建环境。 可以从网上搜索安装包对其进行安装 IntelliJ IDEA – 领先的 Java 和 Kotlin IDE 不建议下载最新版的&#xff0c;大家下载的版本可以下…

CyberDAO引领Web3新时代,共创去中心化未来

Web3的新时代 Web3是互联网的下一代版本&#xff0c;基于区块链技术&#xff0c;实现了去中心化、透明和安全的网络体验。与微信、淘宝等传统中心化平台不同&#xff0c;Web3赋予用户更多的控制权和数据所有权。用户行为数据将由用户自己拥有并分布式管理&#xff0c;不再集中…

PySpark特征工程(III)--特征选择

有这么一句话在业界广泛流传&#xff1a;数据和特征决定了机器学习的上限&#xff0c;而模型和算法只是逼近这个上限而已。由此可见&#xff0c;特征工程在机器学习中占有相当重要的地位。在实际应用当中&#xff0c;可以说特征工程是机器学习成功的关键。 特征工程是数据分析…

LeetCode刷题之最大子数组

今天打算多做一题。 1、题目描述 2、逻辑分析 哈哈&#xff0c;这题我前两天在小红书刷到了&#xff0c;博主答不上来&#xff0c;一样的是&#xff0c;我也不知道怎么做。当时只看到评论说什么dp解法&#xff0c;看看题解怎么说。现在才反应过来dp dynamic programming &am…