【深度学习】Pytorch教程(十三):PyTorch数据结构:5、张量的梯度计算:变量(Variable)、自动微分、计算图及其可视化

文章目录

  • 一、前言
  • 二、实验环境
  • 三、PyTorch数据结构
    • 1、Tensor(张量)
      • 1. 维度(Dimensions)
      • 2. 数据类型(Data Types)
      • 3. GPU加速(GPU Acceleration)
    • 2、张量的数学运算
      • 1. 向量运算
      • 2. 矩阵运算
      • 3. 向量范数、矩阵范数、与谱半径详解
      • 4. 一维卷积运算
      • 5. 二维卷积运算
      • 6. 高维张量
    • 3、张量的统计计算
    • 4、张量操作
      • 1. 张量变形
      • 2. 索引
      • 3. 切片
      • 4. 张量修改
    • 5、张量的梯度计算
      • 0. 变量(Variable)
      • 1. 自动微分
        • a. 单参数
        • b. 多参数
      • 2. 计算图
      • 3. 神经网络模型
      • 4. 可视化计算图结构

一、前言

  本文将介绍张量的梯度计算,包括变量(Variable)、自动微分、计算图及其可视化等

二、实验环境

  本系列实验使用如下环境

conda create -n DL python==3.11
conda activate DL
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
conda install pydot

三、PyTorch数据结构

1、Tensor(张量)

  Tensor(张量)是PyTorch中用于表示多维数据的主要数据结构,类似于多维数组,可以存储和操作数字数据。

1. 维度(Dimensions)

  Tensor(张量)的维度(Dimensions)是指张量的轴数或阶数。在PyTorch中,可以使用size()方法获取张量的维度信息,使用dim()方法获取张量的轴数。

在这里插入图片描述

2. 数据类型(Data Types)

  PyTorch中的张量可以具有不同的数据类型:

  • torch.float32或torch.float:32位浮点数张量。
  • torch.float64或torch.double:64位浮点数张量。
  • torch.float16或torch.half:16位浮点数张量。
  • torch.int8:8位整数张量。
  • torch.int16或torch.short:16位整数张量。
  • torch.int32或torch.int:32位整数张量。
  • torch.int64或torch.long:64位整数张量。
  • torch.bool:布尔张量,存储True或False。

【深度学习】Pytorch 系列教程(一):PyTorch数据结构:1、Tensor(张量)及其维度(Dimensions)、数据类型(Data Types)

3. GPU加速(GPU Acceleration)

【深度学习】Pytorch 系列教程(二):PyTorch数据结构:1、Tensor(张量): GPU加速(GPU Acceleration)

2、张量的数学运算

  PyTorch提供了丰富的操作函数,用于对Tensor进行各种操作,如数学运算、统计计算、张量变形、索引和切片等。这些操作函数能够高效地利用GPU进行并行计算,加速模型训练过程。

1. 向量运算

【深度学习】Pytorch 系列教程(三):PyTorch数据结构:2、张量的数学运算(1):向量运算(加减乘除、数乘、内积、外积、范数、广播机制)

2. 矩阵运算

【深度学习】Pytorch 系列教程(四):PyTorch数据结构:2、张量的数学运算(2):矩阵运算及其数学原理(基础运算、转置、行列式、迹、伴随矩阵、逆、特征值和特征向量)

3. 向量范数、矩阵范数、与谱半径详解

【深度学习】Pytorch 系列教程(五):PyTorch数据结构:2、张量的数学运算(3):向量范数(0、1、2、p、无穷)、矩阵范数(弗罗贝尼乌斯、列和、行和、谱范数、核范数)与谱半径详解

4. 一维卷积运算

【深度学习】Pytorch 系列教程(六):PyTorch数据结构:2、张量的数学运算(4):一维卷积及其数学原理(步长stride、零填充pad;宽卷积、窄卷积、等宽卷积;卷积运算与互相关运算)

5. 二维卷积运算

【深度学习】Pytorch 系列教程(七):PyTorch数据结构:2、张量的数学运算(5):二维卷积及其数学原理

6. 高维张量

3、张量的统计计算

【深度学习】Pytorch教程(九):PyTorch数据结构:3、张量的统计计算详解

4、张量操作

1. 张量变形

【深度学习】Pytorch教程(十):PyTorch数据结构:4、张量操作(1):张量变形操作

2. 索引

3. 切片

【深度学习】Pytorch 教程(十一):PyTorch数据结构:4、张量操作(2):索引和切片操作

4. 张量修改

【深度学习】Pytorch 教程(十二):PyTorch数据结构:4、张量操作(3):张量修改操作(拆分、拓展、修改)

5、张量的梯度计算

0. 变量(Variable)

  Variable(变量)是早期版本中的一种概念,用于自动求导(autograd)。从PyTorch 0.4.0版本开始,Variable已经被弃用,自动求导功能直接集成在张量(Tensor)中,因此不再需要显式地使用Variable。
  在早期版本的PyTorch中,Variable是一种包装张量的方式,它包含了张量的数据、梯度和其他与自动求导相关的信息。可以对Variable进行各种操作,就像操作张量一样,而且它会自动记录梯度信息。然后,通过调用.backward()方法,可以对Variable进行反向传播,计算梯度,并将梯度传播到相关的变量。

import torch
from torch.autograd import Variable
 
# 创建一个Variable
x = Variable(torch.tensor([2.0]), requires_grad=True)
 
# 定义一个计算图
y = x ** 2 + 3 * x + 1
 
# 进行反向传播
y.backward()
 
# 获取梯度
gradient = x.grad
print("梯度:", gradient)  # 输出: tensor([7.])

1. 自动微分

  PyTorch 使用自动微分机制来计算梯度,当定义一个 Tensor 对象时,可以通过设置 requires_grad=True 来告诉 PyTorch 跟踪相关的计算,并使用 backward() 方法来计算梯度:

a. 单参数
import torch


x = torch.tensor([2.0], requires_grad=True)


def f(x):
    return 5 * x ** 2 + 2 * x - 1


# 计算函数的输出
y = f(x)

# 使用backward()方法计算梯度
y.backward()

# 获取梯度值
gradient = x.grad
print(gradient)

输出:

tensor([22.])

y ′ = 10 x + 2 22 = 10 ∗ 2 + 2 y'=10x+2\\22=10*2+2 y=10x+222=102+2

b. 多参数
import torch


def f(x, w, b):
    return 1 / (torch.exp(-(w * x + b)) + 1)


# 设置输入参数
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 计算函数f(x;w,b)的值
result = f(x, w, b)

# 使用自动微分计算梯度
result.backward()
print(x.grad)  # 求关于x的梯度
print(w.grad)  # 求关于w的梯度
print(b.grad)  # 求关于b的梯度

在这里插入图片描述

2. 计算图

  计算图是一种用来表示数学运算过程的图形化结构,它将数学计算表达为节点和边的关系,提供了一种直观的方式来理解和推导复杂的数学运算过程。在深度学习中,计算图帮助我们理解模型的训练过程,直观地把握损失函数对模型参数的影响,同时为反向传播算法提供了理论基础。现代深度学习框架如 PyTorch 和 TensorFlow 都是基于计算图的理论基础构建出来的。
f ( x ; w , b ) = 1 e − ( w x + b ) + 1 f(x;w,b)=\frac{1}{e^{-(wx+b)}+1} f(x;w,b)=e(wx+b)+11 x = 1 , w = 0 , b = 0 x=1,w=0,b=0 x=1,w=0,b=0
在这里插入图片描述
在这里插入图片描述
  计算图包括两种节点:计算节点(Compute Node)和数据节点(Data Node)。

  • 数据节点:表示输入数据、参数或中间变量,在计算图中通常用圆形结点表示。数据节点始终是叶节点,它们没有任何输入,仅表示数据。

  • 计算节点:表示数学运算过程,它将输入的数据节点进行数学运算后输出结果。在计算图中通常用方形结点表示。计算节点可以有多个输入和一个输出。反向传播算法中的梯度计算正是通过计算节点来实现的。

一个完整的计算图可以分为正向传播和反向传播两个阶段:

  • 正向传播(Forward Propagation):输入数据经过计算节点逐层传播,最终得到输出结果

  • 反向传播(Backward Propagation):首先根据损失函数计算输出结果与真实标签之间的误差,然后利用链式法则,逐个计算每个计算节点对应的输入的梯度,最终得到参数的梯度信息

3. 神经网络模型

略~详见后文模块(Module) 部分

import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x, w, b):
        return 1 / (torch.exp(-(w * x + b)) + 1)


model = MyModel()

# 设置输入参数
x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# Forward pass
output = model(x, w, b)

# Backward pass
output.backward()

# Access the gradients
print(x.grad)  # Gradient with respect to x
print(w.grad)  # Gradient with respect to w
print(b.grad)  # Gradient with respect to b

4. 可视化计算图结构

  以下内容完全由GPT生成:

import torch
import onnx
from onnx.tools.net_drawer import GetPydotGraph
import pydot


class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x, w, b):
        return 1 / (torch.exp(-(w * x + b)) + 1)


model = MyModel()

x = torch.tensor(1.0, requires_grad=True)
w = torch.tensor(0.0, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# 导出为ONNX格式
torch.onnx.export(model, (x, w, b), "f.onnx", verbose=True, input_names=["x", "w", "b"], output_names=["output"])

# 可视化计算图结构
model = onnx.load("f.onnx")
pydot_graph = GetPydotGraph(model.graph, name=model.graph.name, rankdir="TB"
                            # rankdir="LR"
                            )  # Set rankdir to "LR" for left to right orientation
pydot_graph.write_dot("f_computation_graph.dot")

# 将dot文件转换为PNG格式的图像
(graph,) = pydot.graph_from_dot_file('f_computation_graph.dot')
graph.write_png('f_computation_graph.png')

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/412072.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Netty NIO 非阻塞模式

1.概要 1.1 说明 使用非阻塞的模式,就可以用一个现场,处理多个客户端的请求了 1.2 要点 ssc.configureBlocking(false);if(sc!null){ sc.configureBlocking(false); channels.add(sc); }if(len>0){ byteBuffer.flip(); 2.代码 2.1 服务端代码 …

mini-spring|定义标记类型Aware接口,实现感知容器对象

**前言:**如果我们想获得 Spring 框架提供的 BeanFactory、ApplicationContext、BeanClassLoader等这些能力做一些扩展框架的使用时该怎么操作呢。所以我们本章节希望在 Spring 框架中提供一种能感知容器操作的接口,如果谁实现了这样的一个接口&#xff…

智慧城市与数字孪生:共创未来城市新篇章

一、引言 随着科技的飞速发展,智慧城市与数字孪生已成为现代城市建设的核心议题。智慧城市注重利用先进的信息通信技术,提升城市治理水平,改善市民生活品质。而数字孪生则通过建立物理城市与数字模型之间的连接,为城市管理、规划…

docker容器技术(2)

docker容器数据卷 什么是数据卷? 在Docker中,数据卷(Data Volumes)是一种特殊的目录,可以在容器和主机之间共享数据。它允许容器内的文件持久存在,并且可以被多个容器共享和访问。 数据卷的主要作用如下&am…

jdk21本地执行flink出现不兼容问题

环境说明:换电脑尝尝鲜,jdk,flink都是最新的,千辛万苦把之前的项目编译通过,跑一下之前的flink项目发现启动失败,啥都不说了上异常 Exception in thread "main" java.lang.IllegalAccessError: …

贝叶斯核机器回归拓展R包:bkmrhat

1.摘要 bkmrhat包是用于扩展bkmr包的贝叶斯核机器回归(Bayesian Kernel Machine Regression, BKMR)分析工具,支持多链推断和诊断。该包利用future, rstan, 和coda包的功能,提供了在贝叶斯半参数广义线性模型下进行identity链接和 …

【kubernetes】二进制部署k8s集群之cni网络插件flannel和calico工作原理

k8s集群的三种接口 k8s集群有三大接口: CRI:容器进行时接口,连接容器引擎--docker、containerd、cri-o、podman CNI:容器网络接口,用于连接网络插件如:flannel、calico、cilium CSI:容器存储…

SPI技术实现对比Java SPI、Spring SPI、Dubbo SPI

概念 SPI机制,全称为Service Provider Interface,是一种服务提供发现机制。 SPI的核心思想是面向接口编程,它允许程序员定义接口,并由第三方实现这些接口。在运行时,SPI机制能够发现并加载所有可用的实现&#xff0c…

文献速递:深度学习--深度学习方法用于帕金森病的脑电图诊断

文献速递:深度学习–深度学习方法用于帕金森病的脑电图诊断 01 文献速递介绍 人类大脑在出生时含有最多的神经细胞,也称为神经元。这些神经细胞无法像我们身体的其他细胞那样自我修复。随着年龄的增长,神经元逐渐死亡,因此变得…

docker小知识:linux环境安装docker

安装必要软件包,执行如下命令 yum install -y yum-utils device-mapper-persistent-data lvm2目的是确保在安装 Docker 之前,系统已经安装了必要的软件包和服务,以支持 Docker 的正常运行。设置yum源,添加Docker官方的CentOS存储…

Open CASCADE学习|GC_MakeArcOfCircle构造圆弧

目录 1、通过圆及圆的两个参数创建圆弧,参数为弧度角 2、通过圆及圆上的一点、圆的1个参数创建圆弧,参数为弧度角,Sense决定方向 3、通过圆及圆上的两个点创建圆弧,Sense决定方向 4、通过三点创建圆弧,最后一点应安…

Mysql 常用数据类型

数值型(整数)的基本使用 如何定义一个无符号的整数 数值型(bit)的使用 数值型(小数)的基本使用 字符串的基本使用 字符串使用细节 日期类型的基本使用

用html编写的小广告板

用html编写的小广告板 相关代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</tit…

【练习——打印每一位数】

打印一个数的每一位 举个例子&#xff1a;我们现在要求打印出123的每一位数字。我们需要去想123%10等于3&#xff0c;就可以把3单独打印出来了&#xff0c;然后再将123/10可以得到12&#xff0c;将12%10就可以打印出2&#xff0c;而我们最后想打印出1&#xff0c;只需要1%10就…

国内大型语言模型(LLM)的研发及突破性应用

随着人工智能技术的迅猛发展&#xff0c;大型语言模型&#xff08;LLM&#xff09;在国内外科技领域成为了热点话题。这些模型因其在文本生成、理解和处理方面的卓越能力&#xff0c;被广泛应用于各种行业和场景中。 在中国&#xff0c;一批人工智能公司在LLM的研发与应用方面…

科技云报道:黑马Groq单挑英伟达,AI芯片要变天?

科技云报道原创。 近一周来&#xff0c;大模型领域重磅产品接连推出&#xff1a;OpenAI发布“文字生视频”大模型Sora&#xff1b;Meta发布视频预测大模型 V-JEPA&#xff1b;谷歌发布大模型 Gemini 1.5 Pro&#xff0c;更毫无预兆地发布了开源模型Gemma… 难怪网友们感叹&am…

数据结构之栈的链表实现

数据结构之栈的链表实现 代码&#xff1a; #include<stdio.h> #include<stdbool.h> #include<stdlib.h> //链表节点定义 typedef struct node {int value;struct node* next; }Node; //入栈操作 bool push(Node** head, int val) {if (*head NULL){*head …

如何学习Arduino单片机

&#xff08;本文为简单介绍&#xff0c;内容源于网络&#xff09; 学习Arduino相关的网址和开源社区&#xff1a; Arduino官方文档: Arduino - HomeArduino Forum: Arduino ForumArduino Playground: Arduino Playground - HomePageGitHub: GitHub: Let’s build from here …

第十三天-mysql交互

目录 1.安装MySQL connector 方式1&#xff1a;直接安装 方式2&#xff1a;下载 2.创建链接 3.游标Cursor 4.事务控制 5. 数据库连接池 1. 使用 6.循环执行SQL语句 不了解mysql的可以先了解mysql基础 1.安装MySQL connector 1. MySQL connector 是MySQL官方驱动模块…

接口测试 —— Jmeter读取数据库数据作测试参数

1、添加Jdbc Request 2、添加ForEach控制器(右键线程组->逻辑控制器->ForEach控制器) ①输入变量的前缀&#xff1a;mobilephone&#xff1b; 从jdbc request设置的变量得知&#xff0c;我们要取的值为mobilephone_1、mobilephone_2、mobilephone_3......所以这里输入m…