PyTorch 中使用自动求导计算梯度

使用 PyTorch 进行自动求导和梯度计算

在 PyTorch 中,张量的 requires_grad 属性决定了是否需要计算该张量的梯度。设置为 True 的张量会在计算过程中记录操作,以便在调用 .backward() 方法时自动计算梯度。通过构建计算图,PyTorch 能够有效地追踪和计算梯度。

1、梯度的定义

在数学中,梯度是一个向量,表示函数在某一点的变化率。在深度学习中,我们通常关心的是损失函数相对于模型参数的梯度。具体来说,假设我们有一个输出 out,我们计算的是损失函数对模型参数(如权重和偏置)的梯度,而不是直接对输出的梯度。

2、 简单例子

在我们接下来的例子中,我们将计算 out 相对于输入变量 x x x y y y的梯度,通常表示为 ( d out d x ) ( \frac{d \text{out}}{dx}) (dxdout) ( d out d y ) ( \frac{d \text{out}}{dy}) (dydout)

import torch

# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True)  # 输入变量 x
y = torch.tensor(3.0, requires_grad=True)  # 输入变量 y

# 2. 定义第一个函数 f(z) = z^2
def f(z):
    return z**2

# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):
    z = x + y  # 中间变量 z
    z_no_grad = z.detach()  # 创建不需要梯度的副本
    return f(z_no_grad) + y**3  # 输出 out = f(z_no_grad) + y^3

# 4. 计算输出
out = g(x, y)  # 计算输出

# 5. 反向传播以计算梯度
out.backward()  # 计算梯度

# 6. 打印梯度
print(f"dz/dx: {x.grad}")  # 输出 x 的梯度
print(f"dz/dy: {y.grad}")  # 输出 y 的梯度
dout/dx: None
dout/dy: 27.0
import torch

# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True)  # 输入变量 x
y = torch.tensor(3.0, requires_grad=True)  # 输入变量 y


# 2. 定义第一个函数 f(z) = z^2
def f(z):
    return z ** 2


# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):
    z = x + y  # 中间变量 z
    return f(z) + y ** 3  # 输出 out = f(z_no_grad) + y^3


# 4. 计算输出
out = g(x, y)  # 计算输出

# 5. 反向传播以计算梯度
out.backward()  # 计算梯度

# 6. 打印梯度
print(f"dout/dx: {x.grad}")  # 输出 x 的梯度
print(f"dout/dy: {y.grad}")  # 输出 y 的梯度
dout/dx: 10.0
dout/dy: 37.0

在这两个代码示例中,dout/dxdout/dy 的值存在显著差异,主要原因在于如何处理中间变量 ( z ) 以及其对最终输出 out 的影响。

结果分析

  1. 第一部分代码

    • g(x, y) 函数中,使用了 z . detach ( ) z.\text{detach}() z.detach() 创建了一个不需要梯度的副本 z no_grad z_{\text{no\_grad}} zno_grad。这意味着在计算 f ( z no_grad ) f(z_{\text{no\_grad}}) f(zno_grad) 时,PyTorch 不会将 z z z 的变化记录进计算图中。

    • 因此, z z z out \text{out} out 的影响被切断,导致
      d out d x = None \frac{d \text{out}}{d x} = \text{None} dxdout=None
      因为 x x x 的变化不会影响到 out \text{out} out 的计算。

    • 对于 y y y,计算得到的梯度为
      d out d y = 27.0 \frac{d \text{out}}{d y} = 27.0 dydout=27.0
      这是通过以下步骤得到的:

    • 输出为
      out = f ( z no_grad ) + y 3 \text{out} = f(z_{\text{no\_grad}}) + y^3 out=f(zno_grad)+y3

    • 使用链式法则:
      d out d y = 0 + 3 y 2 = 3 ( 3 2 ) = 27 \frac{d \text{out}}{d y} = 0 + 3y^2 = 3(3^2) = 27 dydout=0+3y2=3(32)=27

  2. 第二部分代码

    • g(x, y) 函数中,直接使用了 z z z 而没有使用 z . detach ( ) z.\text{detach}() z.detach()。这使得 z z z 的变化会被记录在计算图中。
    • 计算
      d out d x \frac{d \text{out}}{d x} dxdout
      时, z = x + y z = x + y z=x+y 的变化会影响到 out \text{out} out,因此计算得到的梯度为
      d out d x = 10.0 \frac{d \text{out}}{d x} = 10.0 dxdout=10.0
      这是因为:
    • f ( z ) = z 2 f(z) = z^2 f(z)=z2 的导数为
      d f ( z ) d z = 2 z \frac{d f(z)}{d z} = 2z dzdf(z)=2z
      z = 5 z = 5 z=5(当 x = 2 , y = 3 x=2, y=3 x=2,y=3 时),所以
      2 z = 10 2z = 10 2z=10
    • 对于 y y y,计算得到的梯度为
      d out d y = 37.0 \frac{d \text{out}}{d y} = 37.0 dydout=37.0
      这是因为
      d out d y = d ( f ( z ) + y 3 ) d y = 2 z ⋅ d z d y + 3 y 2 = 2 ( 5 ) ( 1 ) + 3 ( 3 2 ) = 10 + 27 = 37 \frac{d \text{out}}{d y} = \frac{d (f(z) + y^3)}{d y} = 2z \cdot \frac{d z}{d y} + 3y^2 = 2(5)(1) + 3(3^2) = 10 + 27 = 37 dydout=dyd(f(z)+y3)=2zdydz+3y2=2(5)(1)+3(32)=10+27=37

3、线性拟合及梯度计算

在深度学习中,线性回归是最基本的模型之一。通过线性回归,我们可以找到输入特征与输出之间的线性关系。在本文中,我们将使用 PyTorch 实现一个简单的线性拟合模型,定义模型为 y = a x + b x + c + d y = ax + bx + c + d y=ax+bx+c+d,并展示如何计算梯度,同时控制某些参数(如 b b b d d d)不更新梯度。
在这个模型中,我们将定义以下参数:

  • a a a:斜率,表示输入 x x x 对输出 y y y 的影响。
  • b b b:另一个斜率,表示输入 x x x 对输出 y y y 的影响,但在训练过程中不更新。
  • c c c:截距,表示当 x = 0 x=0 x=0 时的输出值。
  • d d d:一个常数项,在训练过程中不更新。

3.1、完整代码

下面是实现线性拟合的完整代码:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# 1. 创建数据
# 假设我们有一些样本数据
x_data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = torch.tensor([3.0, 5.0, 7.0, 9.0, 11.0])  # 目标值

# 2. 定义线性模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.a = nn.Parameter(torch.tensor(1.0))  # 需要更新的参数
        self.b = nn.Parameter(torch.tensor(0.5), requires_grad=False)  # 不需要更新的参数
        self.c = nn.Parameter(torch.tensor(0.0))  # 需要更新的参数
        self.d = nn.Parameter(torch.tensor(0.5), requires_grad=False)  # 不需要更新的参数

    def forward(self, x):
        return self.a * x + self.b * x + self.c + self.d

# 3. 实例化模型
model = LinearModel()

# 4. 定义损失函数和优化器
criterion = nn.MSELoss()  # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.005)  # 随机梯度下降优化器

# 5. 训练模型
for epoch in range(5000):
    model.train()  # 设置模型为训练模式

    # 计算模型输出
    y_pred = model(x_data)

    # 计算损失
    loss = criterion(y_pred, y_data)

    # 反向传播
    optimizer.zero_grad()  # 清零梯度
    loss.backward()  # 计算梯度
    optimizer.step()  # 更新参数

    # 每10个epoch打印一次loss和参数值
    if (epoch + 1) % 500 == 0:
        print(f'Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}, a: {model.a.item():.4f}, b: {model.b.item():.4f}, c: {model.c.item():.4f}, d: {model.d.item():.4f}')

# 6. 打印最终参数
print(f'Final parameters: a = {model.a.item()}, b = {model.b.item()}, c = {model.c.item()}, d = {model.d.item()}')

# 7. 绘制拟合结果
with torch.no_grad():
    # 生成用于绘图的 x 值
    x_fit = torch.linspace(0, 6, 100)  # 从 0 到 6 生成 100 个点
    y_fit = model(x_fit)  # 计算对应的 y 值

# 绘制真实数据点
plt.scatter(x_data.numpy(), y_data.numpy(), color='red', label='True Data')
# 绘制拟合曲线
plt.plot(x_fit.numpy(), y_fit.numpy(), color='blue', label='Fitted Curve')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Linear Fit Result')
plt.legend()
plt.grid()
plt.show()

3.2、梯度计算过程

在这个例子中,我们使用了 PyTorch 的自动求导功能来计算梯度。以下是对每个参数的梯度计算过程的解释:

  1. 参数定义

    • a a a c c c 是需要更新的参数,因此它们的 requires_grad 属性默认为 True
    • b b b d d d 是不需要更新的参数,设置了 requires_grad=False,因此它们的梯度不会被计算。
  2. 损失计算

    • 在每个训练周期中,我们计算模型的预测值 y pred y_{\text{pred}} ypred,并与真实值 y data y_{\text{data}} ydata 计算均方误差损失:
      loss = 1 n ∑ i = 1 n ( y pred , i − y i ) 2 \text{loss} = \frac{1}{n} \sum_{i=1}^{n} (y_{\text{pred},i} - y_{i})^2 loss=n1i=1n(ypred,iyi)2
  3. 反向传播

    • 调用 loss.backward() 计算所有参数的梯度。由于 b b b d d drequires_grad 被设置为 False,因此它们的梯度不会被计算和更新。
  4. 参数更新

    • 使用优化器 optimizer.step() 更新参数。只有 a a a c c c 会被更新。
Epoch [500/100], Loss: 0.0038, a: 1.5399, b: 0.5000, c: 0.3559, d: 0.5000
Epoch [1000/100], Loss: 0.0007, a: 1.5171, b: 0.5000, c: 0.4382, d: 0.5000
Epoch [1500/100], Loss: 0.0001, a: 1.5073, b: 0.5000, c: 0.4735, d: 0.5000
Epoch [2000/100], Loss: 0.0000, a: 1.5032, b: 0.5000, c: 0.4886, d: 0.5000
Epoch [2500/100], Loss: 0.0000, a: 1.5014, b: 0.5000, c: 0.4951, d: 0.5000
Epoch [3000/100], Loss: 0.0000, a: 1.5006, b: 0.5000, c: 0.4979, d: 0.5000
Epoch [3500/100], Loss: 0.0000, a: 1.5002, b: 0.5000, c: 0.4991, d: 0.5000
Epoch [4000/100], Loss: 0.0000, a: 1.5001, b: 0.5000, c: 0.4996, d: 0.5000
Epoch [4500/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4998, d: 0.5000
Epoch [5000/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4999, d: 0.5000
Final parameters: a = 1.5000202655792236, b = 0.5, c = 0.4999275505542755, d = 0.5

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/918375.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

安装pytest失败ModuleNotFoundError: No module named ‘distutils‘

下载一下即可解决 pip install setuptools 下载完成后,再进行下载 pip install pytest

数据结构树和二叉树知识点和递归序列

二叉树知识点 一.树的概念1.1关于树的名词解释 二.二叉树的概念1. 二叉树性质: 三.满二叉树与完全二叉树递归前序遍历递归中序遍历递归后续遍历 一.树的概念 树是一种非线性数据结构,它是由n个或大于n个的结点来组成具有层次关系的一个集合(…

【汇编语言】数据处理的两个基本问题(二) —— 解密汇编语言:数据长度与寻址方式的综合应用

文章目录 前言1. 指令要处理的数据有多长?1.1 通过寄存器指明数据的尺寸1.1.1 字操作1.1.2 字节操作 1.2 用操作符X ptr指明内存单元的长度1.2.1 访问字单元1.2.2 访问字节单元1.2.3 为什么要用操作符X ptr指明 1.3 其他方法 2. 寻址方式的综合应用2.1 问题背景&…

【ArcGIS微课1000例】0130:图层组详解与使用

文章目录 一、图层组概述二、创建图层组三、在图层组中管理图层四、对话框中图层组的列表一、图层组概述 图层组包含其他图层。图层组有助于对地图中相关类型的图层进行组织,并且可用于定义高级绘制选项。例如,假设在地图上有两个图层分别用于表示铁路和高速公路。您可将这些…

Cyberchef配合Wireshark提取并解析TCP/FTP流量数据包中的文件

前一篇文章中讲述了如何使用cyberchef提取HTTP/TLS数据包中的文件,详见《Cyberchef配合Wireshark提取并解析HTTP/TLS流量数据包中的文件》,链接这里,本文讲述下如何使用cyberchef提取FTP/TCP数据包中的文件。 FTP 是最为常见的文件传输协议,和HTTP协议不同的是FTP协议传输…

SpringBoot多环境配置的实现

前言 开发过程中必然使用到的多环境案例,通过简单的案例分析多环境配置的实现过程。 一、案例 1.1主配置文件 spring:profiles:active: prod server:port: 80801.2多环境配置文件 开发环境 blog:domain: http://localhost:8080测试环境 blog:domain: https:/…

本草纲目数字化:Spring Boot在中药实验管理中的应用

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理中药实验管理系统的相关信息成为必然。开发…

linux文件与重定向

目录 一、共识原理 二、回顾C语言文件函数 1.fopen 2.fwrite 3.fclose 三、文件系统调用 1.open 2.write 3.访问文件的本质 4.stdin&&stdout&&stderror 5.文件的引用计数 四、重定向 1.文件描述符的分配规则 2. 输出重定向 3.重定向系统调用 4.…

【微服务】SpringBoot 整合ELK使用详解

目录 一、前言 二、为什么需要ELK 三、ELK介绍 3.1 什么是elk 3.2 elk工作原理 四、ELK搭建 4.1 搭建es环境 4.1.1 获取es镜像 4.1.2 启动es容器 4.1.3 配置es参数 4.1.4 重启es容器并访问 4.2 搭建kibana 4.2.1 拉取kibana镜像 4.2.2 启动kibana容器 4.2.3 修改…

基于YOLOv8深度学习的汽车车身车损检测系统研究与实现(PyQt5界面+数据集+训练代码)

本文研究并实现了一种基于YOLOV8深度学习模型的汽车车身车损检测系统,旨在解决传统车损检测中效率低、精度不高的问题。该系统利用YOLOV8的目标检测能力,在单张图像上实现了车身损坏区域的精确识别和分类,尤其是在车身凹痕、车身裂纹和车身划…

ui->tableView升序

亮点 //设置可排序ui->tableView->setSortingEnabled(true);ui->tableView->sortByColumn(0,Qt::AscendingOrder); //排序void Widget::initTable() {//设置焦点策略:ui->tableView->setFocusPolicy(Qt::NoFocus);//显示网格线:ui->tableView->se…

Android Framework AMS(16)进程管理

该系列文章总纲链接:专题总纲目录 Android Framework 总纲 本章关键点总结 & 说明: 说明:本章节主要解读AMS 进程方面的知识。关注思维导图中左上侧部分即可。 我们本章节主要是对Android进程管理相关知识有一个基本的了解。先来了解下L…

QT_CONFIG宏使用

时常在Qt代码中看到QT_CONFIG宏,之前以为和#define、DEFINES 差不多,看了定义才发现不是那么回事,定义如下: 看注释就知道了QT_CONFIG宏,其实是:实现了一个在编译时期安全检查,检查指定的Qt特性…

Spring Boot教程之Spring Boot简介

Spring Boot 简介 接下来一段时间,我会持续发布并完成Spring Boot教程 Spring 被广泛用于创建可扩展的应用程序。对于 Web 应用程序,Spring 提供了 Spring MVC,它是 Spring 的一个广泛使用的模块,用于创建可扩展的 Web 应用程序。…

Vue2教程001:初识Vue

文章目录 1、初识Vue1.1、Vue2前言1.2、创建Vue实例1.3、插值表达式1.4 Vue响应式特性 1、初识Vue 1.1、Vue2前言 Vue是什么? 概念:Vue是一个用于构建用户界面的渐进式框架。 Vue的两种使用方式: Vue核心包开发 场景:局部模块…

Linux手动安装nginx

本次以安装nginx-1.12.2为例 1、首先说明一下,安装nginx之前需要安装如下素材: 2、开始安装 第一步,安装依赖 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-devel第二步,下载并安装nginx安装包(n…

Element-ui Select选择器自定义搜索方法

效果图 具体实现 <template><div class"home"><el-selectref"currencySelect"v-model"currency"filterable:spellcheck"false"placeholder"请选择":filter-method"handleCurrencyFilter"change&q…

leetcode-44-通配符匹配

题解&#xff1a; 代码&#xff1a; 参考&#xff1a; (1)牛客华为机试HJ71字符串通配符 (2)leetcode-10-正则表达式匹配

C/C++中使用MYSQL

首先要保证下载好mysql的库和头文件&#xff0c;头文件在/usr/include/mysql/目录下&#xff0c;库在/usr/lib64/mysql/目录下&#xff1a; 一般情况下&#xff0c;在我们安装mysql的时候&#xff0c;这些都提前配置好了&#xff0c;如果没有就重装一下mysql。如果重装mysql还是…

Tryhackme练习-Wonderland

基本信息 由于tryhackme是在线靶场&#xff0c;所以这里的IP均为对方的内网IP 攻击机器&#xff1a;10.10.242.186 靶机&#xff1a;10.10.173.3 目标&#xff1a;获取2个flagroot权限 具体流程 信息收集 首先我们使用fscan进行端口扫描&#xff0c;fscan -h 10.10.173.…