PyTorch翻译官网教程6-AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

官网链接

Automatic Differentiation with torch.autograd — PyTorch Tutorials 2.0.1+cu117 documentation

使用TORCH.AUTOGRAD 自动微分

当训练神经网络时,最常用的算法是方向传播算法。在该算法中,根据损失函数与给定参数的梯度来调整模型参数(权重)。

为了计算这些梯度,PyTorch有一个内置的微分引擎,名为torch.autograd。它支持任何计算图的梯度自动计算。

考虑最简单的单层神经网络,输入x,参数w和b,以及一些损失函数。它可以在PyTorch中以以下方式定义:

import torch

x = torch.ones(5)  # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

张量、函数与计算图

这段代码定义了以下计算图:

在这个网络中,w和b是我们需要优化的参数。因此,我们需要能够计算损失函数相对于这些变量的梯度。为了做到这一点,我们设置了这些张量的requires_grad属性。

注意:

您可以在创建张量时设置requires_grad的值,或者稍后使用x.requires_grad_(True)方法设置。

我们使用张量来构造计算图的函数实际上是Function类的对象,该对象知道如何在正向方向上计算函数,以及如何在反向传播步骤中计算其导数。反向传播函数的引用存储在张量的grad_fn 属性中,你可以在文档中找到Function 的更多信息。Automatic differentiation package - torch.autograd — PyTorch 2.0 documentation

print(f"Gradient function for z = {z.grad_fn}")
print(f"Gradient function for loss = {loss.grad_fn}")

输出

Gradient function for z = <AddBackward0 object at 0x114113f70>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x114113f70>


计算梯度

为了优化神经网络中参数的权重,我们需要计算损失函数对于参数的导数,即我们需要\frac{∂loss}{∂w} 和 \frac{∂loss}{∂b} 在x和y的固定值下。为了计算这些导数,我们调用loss.backward(),然后从w.gradb.grad中检索值:

loss.backward()
print(w.grad)
print(b.grad)

输出

tensor([[0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530],
        [0.3313, 0.0626, 0.2530]])
tensor([0.3313, 0.0626, 0.2530])

注意

  • 我们只能获得计算图的叶子节点的grad属性,当它的requires_grad属性设置为True时。对于图中的所有其他节点,梯度将不可用。
  • 出于性能原因,我们只能在给定的图上使用一次backward梯度计算。如果需要对同一个图进行多次backward调用,则需要将retain_graph=True 传递给backward调用。

禁用梯度跟踪

默认情况下,所有的requires_grad=True 的张量会自动跟踪它们的计算历史并支持梯度计算。然而,在一些情况下我们不需要这样做,例如,当我们完成了模型的训练,只想将其应用于一些测试数据时,即我们只想通过网络进行前向计算。我们可以通过使用torch.no_grad() 块包围我们的计算代码来停止跟踪计算:

z = torch.matmul(x, w)+b
print(z.requires_grad)

with torch.no_grad():
    z = torch.matmul(x, w)+b
print(z.requires_grad)

输出

True
False

实现相同结果的另一种方法是在张量上使用detach() 方法:

z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)

输出

False

你可能想要禁用梯度跟踪的原因如下:

  • 将神经网络中的一些参数标记为冻结参数
  • 当你只做正向传递时,为了加快计算速度,在不跟踪梯度的张量上的计算会更有效率。

更多关于计算图的知识

从概念上讲,autograd在由Function 的大小应该等于原始张量的大小,为了计算其乘积。

inp = torch.eye(4, 5, requires_grad=True)
out = (inp+1).pow(2).t()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"First call\n{inp.grad}")
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nSecond call\n{inp.grad}")
inp.grad.zero_()
out.backward(torch.ones_like(out), retain_graph=True)
print(f"\nCall after zeroing gradients\n{inp.grad}")
 

输出

First call
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

Second call
tensor([[8., 4., 4., 4., 4.],
        [4., 8., 4., 4., 4.],
        [4., 4., 8., 4., 4.],
        [4., 4., 4., 8., 4.]])

Call after zeroing gradients
tensor([[4., 2., 2., 2., 2.],
        [2., 4., 2., 2., 2.],
        [2., 2., 4., 2., 2.],
        [2., 2., 2., 4., 2.]])

注意,当我们使用相同的参数第二次调用backward时,梯度的值是不同的。这是因为在进行backward传播时,PyTorch会累积梯度的,即计算梯度的值被添加到计算图的所有叶节点的grad属性中。如果你想计算合适的梯度,你需要在此之前将grad属性归零。在现实训练中,优化器可以帮助我们做到这一点

注意

以前我们调用没有参数的backward()函数。这基本上相当于调用backward(torch.tensor(1.0)),这是在标量值函数的情况下计算梯度的有用方法,例如神经网络训练期间的损失。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/38288.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python程序设计基础:列表与元组(二)

文章目录 一、数值列表的生成1、通过input()函数输入创建列表2、通过list()函数转换3、列表生成式4、数值列表的几种统计计算 二、元组1、元组的定义2、元组的操作3、元组作为列表元素 三、转换函数1、元组和列表之间的转换2、字符串和列表之间的转换3、split()方法 一、数值列…

2024考研408-操作系统 第二章-进程与线程 学习笔记

文章目录 前言一、进程1.1、进程的概念、组成与特征1.1.1、进程的概念1.1.2、进程的组成认识PCB认识程序段与数据段&#xff08;包含进程实体概念&#xff09; 1.1.3、进程的特征知识回顾与重要考点 1.2、进程的状态、状态间的转换和组织方式1.2.1、进程的状态进程的五种状态详…

RandLA-Net 复现

GPU3090 CUDA12 1、代码 [github地址] git clone --depth1 https://github.com/QingyongHu/RandLA-Net && cd RandLA-Net 2、虚拟环境中配置&#xff1a; 在跑代码的时候出现错误&#xff1a;open3d.so文件中函数报错。查看open3d版本发现不是要求的0.3版本&#xff…

设计模式 ~ 工厂模式

工厂模式 工厂模式是一种设计模式&#xff0c;指在通过使用工厂函数或构造函数来创建对象&#xff1b; 它提供了一种灵活的方式来创建对象实例&#xff0c;而无需直接调用构造函数或使用new关键字&#xff1b; 可以分类&#xff0c;解耦&#xff1b; 可以扩展多个类&#xff0…

【Spring——Spring的基础与创建】

目录 &#x1f367;1. 什么是 Spring &#xff1f; &#x1fad6;1.1 容器 &#x1f359;1.2 IoC &#x1f97d;1.3 汽车类——传统写法 &#x1f358;1.4 汽车类——IoC 写法 &#x1f32d;2. 配置 maven 国内源 &#x1f32e;2.1 在设置中勾选文件 &#x1f364;2.2 在…

【网站建设】HTTP/HTTPS 是什么?有什么区别?

&#x1f680;欢迎来到本文&#x1f680; &#x1f349;个人简介&#xff1a;陈童学哦&#xff0c;目前学习C/C、算法、Java等方向&#xff0c;一个正在慢慢前行的普通人。 &#x1f3c0;系列专栏&#xff1a;陈童学的日记 &#x1f4a1;其他专栏&#xff1a;CSTL&#xff0c;感…

自定义程序包不存在的解决方法

方案一&#xff1a; 在pom文件中加入以下代码 <plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-surefire-plugin</artifactId><version>2.4.2</version><configuration><skipTests>true</sk…

openResty的Redis模块踩坑记录

OpenResty提供了操作Redis的模块&#xff0c;我们只要引入该模块就能直接使用。说是这样说&#xff0c;但是实践起来好像并不太顺利。 1.设置了密码的redis&#xff0c;lua业务逻辑中需要添加身份认证代码 网上很多资料、文章似乎都是没有设置redis密码&#xff0c;说来也奇怪…

fileinclude

前提知识&#xff1a; filter伪协议 include函数用php://filter伪协议来绕过 题目&#xff1a; 打开题目 页面显示如图&#xff0c;可以知道flag在flag.php中&#xff0c;还知道当前页面的绝对路径 先查看源代码 15行$lan用$_cookie传参&#xff0c;可以修改cookie值从而控制…

探索非洲专线物流的新时代_国际物流供应链管理平台_箱讯科技

随着全球化的发展&#xff0c;非洲作为一个充满机遇和挑战的大陆&#xff0c;吸引着越来越多的企业和投资者。然而&#xff0c;由于非洲的地理复杂性和基础设施不完善&#xff0c;物流问题一直是制约非洲发展的瓶颈之一。为了解决这一问题&#xff0c;非洲专线物流应运而生。本…

网络原理之传输层与网络层重点协议

目录 传输层重点协议 TCP协议 TCP协议段格式 TCP原理 确认应答机制&#xff08;安全机制&#xff09; 超时重传机制&#xff08;安全机制&#xff09; 连接管理机制&#xff08;安全机制&#xff09; 滑动窗口&#xff08;效率机制&#xff09; 流量控制&#xff08;安…

数字IC笔试面试常考问题及答案汇总(内含各岗位大厂题目)

经历了无数的笔试面试之后&#xff0c;不知道大家有没有发现数字IC的笔试面试还是有很多共通之处和规律可循的。所以一定要掌握笔试面试常考的问题。 数字IC笔试面试常考问题及答案汇总&#xff08;文末可领全部哦~&#xff09; 验证方向&#xff08;部分题目&#xff09; Q1…

数据可视化分析,2023结婚全品类消费趋势洞察报告

结婚消费与人们的关系密切相关。结婚是一个重要的人生事件&#xff0c;往往伴随着大量的消费。人们倾向于在婚礼仪式、婚纱摄影、宴会等方面进行豪华的投资&#xff0c;以展示社会地位和个人品味。此外&#xff0c;结婚还涉及到婚戒、婚庆、蜜月旅行等费用。然而&#xff0c;随…

Jenkins+Robot 接口自动化测试

目录 前言&#xff1a; 设计目标 项目说明 目录结构 配置 jenkins 1.安装插件 2.配置项目 前言&#xff1a; JenkinsRobot是一种常见的接口自动化测试方案&#xff0c;可以实现自动化的接口测试和持续集成。Jenkins是一个流行的持续集成工具&#xff0c;而Robot Framew…

【5G PHY】5G 调制与编码策略(MCS)介绍

博主未授权任何人或组织机构转载博主任何原创文章&#xff0c;感谢各位对原创的支持&#xff01; 博主链接 本人就职于国际知名终端厂商&#xff0c;负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作&#xff0c;目前牵头6G算力网络技术标准研究。 博客…

Cisco学习笔记(CCNA)——Internetworking

Internetworking Internetworking Basics 什么是网络&#xff1f; 计算机网络&#xff1a;具有独立功能的多台计算机及其外部设备&#xff0c;通过通信线路连接起来 网络设备 Hub&#xff08;集线器&#xff09; 优点&#xff1a;便宜、操作简单 缺点&#xff1a;共享型、…

Kubernetes(K8s)常用命令大全:熟练编排更完美

&#x1f337;&#x1f341; 博主 libin9iOak带您 Go to New World.✨&#x1f341; &#x1f984; 个人主页——libin9iOak的博客&#x1f390; &#x1f433; 《面试题大全》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33…

react 升级

1、查看react版本 当前开发项目的react版本从哪里看呢&#xff1f;其实就在package.json文件中&#xff0c;搜索"react"&#xff0c;即可看到版本号 2、输入命令npm info react查看最新的react版本 3、执行命令 npm install --save react18.2.0 react-dom18.2.0 4…

Docker 部署 Jenkins (一)

Docker 部署 Jenkins (一) 一. 安装 jenkins $ mkdir -p /home/tester/data/docker/jenkins $ vim jenkins:lts-jdk11.sh./jenkins:lts-jdk11.sh 内容 #! /bin/bash mkdir -p /home/tester/data/docker/jenkins/jenkins_homesudo chown -R 1000:1000 /home/tester/data/dock…

2023年Java最新面试题

由【后端面试题宝典】提供 和 equals 的区别是什么&#xff1f; 对于基本类型&#xff0c;比较的是值&#xff1b;对于引用类型&#xff0c;比较的是地址&#xff1b;equals不能用于基本类型的比较&#xff1b;如果没有重写equals&#xff0c;equals就相当于&#xff1b;如果重…