计算图 Compute Graph 和自动求导 Autograd | PyTorch 深度学习实战

前一篇文章,Tensor 基本操作5 device 管理,使用 GPU 设备 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

PyTorch 计算图和 Autograd

  • 微积分之于机器学习
  • Computational Graphs 计算图
  • Autograd 自动求导
  • 一个训练过程及 no_grad 的使用
    • 示例代码
    • 执行结果
      • 生成数据
      • 第一轮后
      • 第二轮后
      • 第十轮后
  • 更多计算图的知识
    • 更为复杂点的计算图的样子
    • 自动求导有关的参数
  • Links

微积分之于机器学习

机器学习的主要工作原理,就是万事万物存在规律,而我们使用机器来完成参数评估。参数评估的过程是随机梯度下降,也就是任意选择起点,然后使用微积分技术指导我们调优,找到一组最优参数值。

这就像我们爬山,面对众多的山峰,我们从不同的出发点出发,不断的朝着山顶前进,最终,我们即便起点不同,都可以达到山顶 - 通向山顶的路有多条。另外一方面,我们可能来到了不同的山顶。

在我们爬山的过程中,如何选择下一步呢?这时,就是微积分大显身手的时候了。

在机器学习中,对参数优化的过程,使用了大量微积分的运算,PyTorch 能成为通用性的机器学习框架,就在于不同的机器学习任务底层的数学原理是一致的,而 PyTorch 内置了这些标准化的数学运算,在 PyTorch 中,除了 Tensor 外,还有两个关键的概念:

  • 计算图
  • 自动求导

Computational Graphs 计算图

神经网络是由很多神经元组成的网络,最简单的神经网络就是只包含一个线性神经元的神经网络,理解这个最简单的神经网络,有助于理解任何复杂的神经网络。

z = x ∗ w + b z = x * w + b z=xw+b

注意:这里没有添加激活函数,这个神经元是一个简单的线性神经元。
在这里插入图片描述
计算过程:

  1. 加权输出 z 与理想输出 y 之间,使用交叉熵(CE)计算出损失(loss)
  2. 然后基于 loss 计算梯度 grad
  3. 基于梯度更新 w 和 b

这个计算过程,可以用一张图表达,一个图就是由节点以及边组成,边上定义操作符。同时,这个计算过程会在训练中发生多次,因为梯度下降算法是 SGD 迭代运算。

PyTorch 为了让每次运算可以更灵活,比如使用 Dropout 随机丢弃一些神经元,PyTorch 实现了每次运算动态的生成这张图 - 动态计算图1。也就是说,对于每次运算,PyTorch 会生成一个计算图并附着计算状态。

Autograd 自动求导

附着状态,最主要的目的就是实现自动求导。因为每个节点都是一个变量,变量和变量之间通过操作符相互依赖,而操作符和变量构成的函数式,就可以实现求导,根据链式法则,实现计算图中,每个变量的导数的计算。

在上图,只有一个线性神经元的情况下,PyTorch 的自动求导是如何工作的呢?参考下面的代码。

import torch

# 定义输入和理想输出
x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output

# 定义参数
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

# 定义模型,并进行一次运算
z = torch.matmul(x, w)+b

# 定义损失函数,并得到单次的损失
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

# 进行反向传播,并得到梯度
loss.backward()
print(w.grad)
print(b.grad)

如此一来,参数更新将变得非常简单。计算图允许每次迭代传入不同的操作符等,实现训练过程更灵活的配置。计算图保留了运算过程中的 Tensor、操作符、操作符对应的导函数。当 loss.backward() 调用时,顺序的调用自动求导变量的导函数,得到 .grad 梯度值。

一个训练过程及 no_grad 的使用

现在我们看一个例子,通过一个简单的模型,了解训练中,自动求导机制是如何工作的。

示例代码


'''
autograd
'''
import plotly.graph_objects as go
import plotly.express as px
from torch import nn
import numpy as np
import torch
import math

# 输入变量 x,理想输出 yt(生成 y 的函数就是要拟合的模型) 
X  = torch.tensor(np.linspace(-10, 10, 1000))
y  = 1.5 * torch.sin(X) + 1.2 * torch.cos(X/4) # 真实的模型
yt = y + np.random.normal(0, 1, 1000)

# vis
def plotter(X, y, yhat=None, title=None):
    with torch.no_grad():
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=X, y=y, mode='lines',    name='y'))
        fig.add_trace(go.Scatter(x=X, y=yt, mode='markers', marker=dict(size=4), name='yt'))
        if yhat is not None: fig.add_trace(go.Scatter(x=X, y=yhat, mode='lines', name='yhat'))
        fig.update_layout(template='none', title=title)
        fig.show()

plotter(X, y, title='Data Generating Process')

# 计算模型的实际输出,这里前提是假设知道变量 X 和函数 sin|cos, 而不知道参数 theta
def fit_model(theta:torch.tensor=torch.rand(3, requires_grad=True)):
    return theta[0] * X + theta[1] * torch.sin(X) + theta[2] * torch.cos(X/4)

# 随机初始化参数,开启自动求导
theta = torch.randn(3, requires_grad=True)

# 损失函数和优化器
loss_fn  = nn.MSELoss()                         # MSE loss
optimizer = torch.optim.SGD([theta], lr=0.01)   # build optimizer 

# 迭代训练
epochs = 500
for i in range(epochs):
    yhat = fit_model(theta)  # 计算实际输出
    loss = loss_fn(y, yhat)  # 将实际输出和理想输出传入损失函数,得到损失 loss
    loss.backward()          # 反向传播,完成 .grad 梯度的计算
    optimizer.step()         # 基于梯度完成参数更新 
    optimizer.zero_grad()    # 本轮计算完成,将梯度值归零,否则下次计算损失并调用 backward 导致梯度累计 
    if i % (epochs/10) == 0: # 验证及输出调试信息 
        msg = f"loss: {loss.item():>7f} theta: {theta.detach().numpy()}"
        yhat = fit_model(theta)
        plotter(X, y, yhat.detach(), title=f"loss: {loss.item():>7f} theta: {theta.detach().numpy().round(3)}")

执行结果

生成数据

创建了一个假数据:

  • 分布在象限中的点就是 x,y
  • 象限中的曲线,就是符合设想的模型,我们看最终的机器学习的模型,能否拟合这条曲线
    在这里插入图片描述

第一轮后

初始化后,实际模型和理想模型差距很大。注意,此时 theta 和目标参数差距很大。
在这里插入图片描述

第二轮后

经过两次迭代,差距在缩小。

在这里插入图片描述

第十轮后

又经过了几轮训练,此时,我们发现图中已经分辨不出来,但是从 theta 的值,我们还可以看到一点差距,这已经证明,机器学习拟合上了目标空间。
在这里插入图片描述

更多计算图的知识

更为复杂点的计算图的样子

在训练中,生成的 DAG 类似如下。
在这里插入图片描述

自动求导有关的参数

# 做一个计算图
x = torch.rand(1)
b = torch.rand(1, requires_grad=True)
w = torch.rand(1, requires_grad=True)
y = w * x  # y 是一个新的 tensor

# 检查 y 是否是叶子节点,这里 y 是输出,也就是 root 节点而不是 leaf 节点
print(y.is_leaf)

# 反向传播 
y.backward(retain_graph=True)  # retain_graph=True,保留计算图中的状态,https://discuss.pytorch.org/t/use-of-retain-graph-true/179658
print(w.grad) # 查看梯度

Links

  • How Computational Graphs are Constructed in PyTorch
  • How Computational Graphs are Executed in PyTorch
  • PyTorch’s Dynamic Graphs (Autograd)
  • Automatic Differentiation with torch.autograd
  • Autograd mechanics

  1. PyTorch 使用 DAG 有向无环图这种格式存储计算图,其中输入的 Tensor 称为叶子节点(leaves),输出的 Tensor 称为根节点(roots)。 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/964432.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C++11详解(一) -- 列表初始化,右值引用和移动语义

文章目录 1.列表初始化1.1 C98传统的{}1.2 C11中的{}1.3 C11中的std::initializer_list 2.右值引用和移动语义2.1左值和右值2.2左值引用和右值引用2.3 引用延长生命周期2.4左值和右值的参数匹配问题2.5右值引用和移动语义的使用场景2.5.1左值引用主要使用场景2.5.2移动构造和移…

Spring Boot常用注解深度解析:从入门到精通

今天,这篇文章带你将深入理解Spring Boot中30常用注解,通过代码示例和关系图,帮助你彻底掌握Spring核心注解的使用场景和内在联系。 一、启动类与核心注解 1.1 SpringBootApplication 组合注解: SpringBootApplication Confi…

生成式AI安全最佳实践 - 抵御OWASP Top 10攻击 (下)

今天小李哥将开启全新的技术分享系列,为大家介绍生成式AI的安全解决方案设计方法和最佳实践。近年来生成式 AI 安全市场正迅速发展。据IDC预测,到2025年全球 AI 安全解决方案市场规模将突破200亿美元,年复合增长率超过30%,而Gartn…

git:恢复纯版本库

初级代码游戏的专栏介绍与文章目录-CSDN博客 我的github:codetoys,所有代码都将会位于ctfc库中。已经放入库中我会指出在库中的位置。 这些代码大部分以Linux为目标但部分代码是纯C的,可以在任何平台上使用。 源码指引:github源…

蓝桥杯python基础算法(2-1)——排序

目录 一、排序 二、例题 P3225——宝藏排序Ⅰ 三、各种排序比较 四、例题 P3226——宝藏排序Ⅱ 一、排序 (一)冒泡排序 基本思想:比较相邻的元素,如果顺序错误就把它们交换过来。 (二)选择排序 基本思想…

python学opencv|读取图像(五十四)使用cv2.blur()函数实现图像像素均值处理

【1】引言 前序学习进程中,对图像的操作均基于各个像素点上的BGR值不同而展开。 对于彩色图像,每个像素点上的BGR值为三个整数,因为是三通道图像;对于灰度图像,各个像素上的BGR值是一个整数,因为这是单通…

Slint的学习

Slint是什么 Slint是一个跨平台的UI工具包,支持windows,linux,android,ios,web,可以用它来构建申明式UI,后端代码支持rust,c,python,nodejs等语言。 开源地址:https://github.com/slint-ui/slint 镜像地址:https://kkgithub.com/…

惰性函数【Ⅱ】《事件绑定的自我修养:从青铜到王者的进化之路》

【Ⅱ】《事件绑定的自我修养:从青铜到王者的进化之路》 1. 代码功能大白话(给室友讲明白版) // 青铜写法:每次都要问浏览器"你行不行?" function addEvent青铜版(element, type, handler) {if (window.add…

Unity飞行代码 超仿真 保姆级教程

本文使用Rigidbody控制飞机,基本不会穿模。 效果 飞行效果 这是一条优雅的广告 如果你也在开发飞机大战等类型的飞行游戏,欢迎在主页搜索博文并参考。 搜索词:Unity游戏(Assault空对地打击)开发。 脚本编写 首先是完整代码。 using System.Co…

基于微信小程序的私家车位共享系统设计与实现(LW+源码+讲解)

专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…

C++编程语言:抽象机制:模板(Bjarne Stroustrup)

目录 23.1 引言和概观(Introduction and Overview) 23.2 一个简单的字符串模板(A Simple String Template) 23.2.1 模板的定义(Defining a Template) 23.2.2 模板实例化(Template Instantiation) 23.3 类型检查(Type Checking) 23.3.1 类型等价(Type Equivalence) …

多线程的常用方法

getName和setName方法 注意点 setName方法最好放在线程启动之前 最好在线程启动之前修改名字,因为线程启动之后,如果执行过快的话,那么在调用 setName() 之前线程可能就已经结束了 MyThread t1 new MyThread("haha"); t1.setNa…

C++继承的基本意义

文章目录 一、继承的本质和原理二、重载、隐藏和覆盖三、基类与派生类的转换 一、继承的本质和原理 继承的本质:a. 代码的复用 b. 类和类之间的关系: 组合:a part of… 一部分的关系 继承:a kind of… 一种的关系 总结&#xff…

简单易懂的倒排索引详解

文章目录 简单易懂的倒排索引详解一、引言 简单易懂的倒排索引详解二、倒排索引的基本结构三、倒排索引的构建过程四、使用示例1、Mapper函数2、Reducer函数 五、总结 简单易懂的倒排索引详解 一、引言 倒排索引是一种广泛应用于搜索引擎和大数据处理中的数据结构,…

FinRobot:一个使用大型语言模型的金融应用开源AI代理平台

“FinRobot: An Open-Source AI Agent Platform for Financial Applications using Large Language Models” 论文地址:https://arxiv.org/pdf/2405.14767 Github地址:https://github.com/AI4Finance-Foundation/FinRobot 摘要 在金融领域与AI社区间&a…

Docker使用指南(一)——镜像相关操作详解(实战案例教学,适合小白跟学)

目录 1.镜像名的组成 2.镜像操作相关命令 镜像常用命令总结: 1. docker images 2. docker rmi 3. docker pull 4. docker push 5. docker save 6. docker load 7. docker tag 8. docker build 9. docker history 10. docker inspect 11. docker prune…

Qt跨屏窗口的一个Bug及解决方案

如果我们希望一个窗口覆盖用户的整个桌面,此时就要考虑用户有多个屏幕的场景(此窗口要横跨多个屏幕),由于每个屏幕的分辨率和缩放比例可能是不同的,Qt底层在为此窗口设置缩放比例(DevicePixelRatio&#xf…

Linux 传输层协议 UDP 和 TCP

UDP 协议 UDP 协议端格式 16 位 UDP 长度, 表示整个数据报(UDP 首部UDP 数据)的最大长度如果校验和出错, 就会直接丢弃 UDP 的特点 UDP 传输的过程类似于寄信 . 无连接: 知道对端的 IP 和端口号就直接进行传输, 不需要建立连接不可靠: 没有确认机制, 没有重传机制; 如果因…

安全实验作业

一 拓扑图 二 要求 1、R4为ISP,其上只能配置IP地址;R4与其他所有直连设备间均使用共有IP 2、R3-R5-R6-R7为MGRE环境,R3为中心站点; 3、整个OSPF环境IP基于172.16.0.0/16划分; 4、所有设备均可访问R4的环回&#x…

防御保护:安全策略配置

目录 一、实验拓扑 二、实验要求 ​编辑 三、要求分析 四、实验配置 前置配置 1.配置vlan与access、truck接口 2.进入web界面进行配置 3.安全策略的配置 3.1实现实验需求2(办公区PC在工作日时间(周一至周五,早8晚6)可以正常访问OA Server,其他时间不允许) 新建地址…