PyTorch项目源码学习(3)——Module类初步学习

torch.nn.Module

Module类是用户使用torch来自定义网络模型的基础,Module的设计要求包括低耦合性,高模块化等等。一般来说,计算图上所有的子图都可以是Module的子类,包括卷积,激活函数,损失函数节点以及相邻节点组成的集合等等,注意这里的关键词是“节点”,Module族类在计算图中主要起到搭建结构的作用,而不涉及运算逻辑的具体实现。

要注意的是,Module类对象的children所指向的其他Module类对象,并不等同于计算图中的子节点。如果我们展开Module网络,得到的一般是树形结构而非DAG,Module网络需要经过其他工作才能转化为计算图。

源代码分析

成员分析

首先直接从前端入手,找到torch/nn/module目录,可以看到这个目录下主要存放Module及其子类的定义,如。我们首先找到module.py内Module的定义

阅读__init__ 函数,可以看到Module基类的主要私有成员,其中包括

指向本Module内带梯度的可学习参数的parameter
指向本Module内不需要学习的模型状态参数的buffer
其他临时参数
前向与反向过程的hook函数,这些函数在运行backward与forward时允许自定义其它额外工作
state_dict相关函数,state_dict保存了模型的状态,是模型写入磁盘与加载的主要方式
modules指向该模块内部的所有子模块

方法分析
结构相关
  • 子模块生成

首先从我们日常使用pytorch搭建网络的用法可以想到,应该先去__setattr__函数寻找建立子节点的入口。

下图折叠了几个分支,可以看到当我们运行self.c1 = Conv2d(…)时,将会进入1202行的分支内,并且判断新成员是否是Module类型,如果是则将其放入本对象的子模块字典内。

__setattr__ 内主要对Parameter,Module,特定name的Tensor(也就是buffer)等参数做特判,其他情况则调用object的属性设置流程。事实上,其他的类似方法(如getattr等)也是同样的流程。

  • 内部参数访问

对于存储于私有成员_module内的子模块,一般使用children方法进行调用

我们在外部所使用的xx.modules()方法,就是通过调用children方法实现的。

另外,nn.Module实现了许多对参数转化的方法,比如CPU(将内部参数转移到内存中),CUDA(将内部参数转移到显存中)以及type(将参数转化为指定类型),而这些是通过调用内部的_apply方法实现的

可以看到,_apply接受一个函数指针参数,并对所有的子模块递归地调用自己。然后对本Module内所有的Parameter与buffer应用该函数。

问题来了,既然每个节点都进行函数应用,那么如何避免对同一参数重复应用fn?这个问题的关键在于内部的Parameter到底是如何存储的。

印象里,我们在外部使用xx.parameters()时,得到的是xx模块的所有参数,看起来和上述代码里的_parameters并非直接取用的关系,我们可以看一下parameters()的实现

注意到默认参数recurse=True,相信大部分人已经明白原因了,我们继续看到named_parameters()

对_named_member方法传入了获取子模块_parameters字典键值对的匿名函数,继续看到_named_members()

可以看到具体流程是先递归或者不递归地获取该模块下的所有用户希望获得的东西(具体定义在第一个函数参数中),然后返回迭代器

这里1489行体现递归调用,原因是named_modules方法本身就是一个递归函数

在这里插入图片描述

事实上,named_parameters, named_buffers均是通过named_members进而调用named_modules方法实现的,_module成员体现网络结构的特殊性在这里可以窥见一二。另外可以看到,上述方法内都存在memo集合进行去重,确保不会返回相同的指针对象。

  • 简要流程图

在这里插入图片描述

参考文章

https://zhuanlan.zhihu.com/p/340453841

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/329578.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

几何_直线方程 Ax + By + C = 0 的系数A,B,C几何含义是?

参考: 直线方程 Ax By C 0 的系数A,B,C有什么几何含义?_设直线 l 的方程为axbyc0 怎么理解-CSDN博客 1. A B的含义:组成一个与直线垂直的向量 我们先来看A和B有什么含义。 在直线上取任意两点 P1:(x1…

基于SSM的小儿疾病中医护理系统

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用JSP技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…

自动驾驶中的坐标系

自动驾驶中的坐标系 自动驾驶中的坐标系 0.引言1.相机传感器坐标系2.激光雷达坐标系3.车体坐标系4.世界坐标系4.1.地理坐标系4.2.投影坐标系4.2.1.投影方式4.2.2.墨卡托(Mercator)投影4.2.3.高斯-克吕格(Gauss-Kruger)投影4.2.4.通用横轴墨卡托UTM(UniversalTransve…

Ceph分布式文件系统

目录 1.存储基础 单机DAS存储设备 DAS NAS SAN 单机存储的问题 分布式存储 分布式存储的类型 2.Ceph简介 Ceph优势 Ceph架构 Ceph组件 OSD(ceph-osd) Monitor(ceph-mon) Manager(ceph-mgr) MDS(ceph-mds) PG(归置组) Pool(存储池/资源池/数据池) OSD 存储后…

What is `addArgumentResolvers` does in `WebMvcConfigurer` ?

addArgumentResolvers 在SpringMVC框架中,主要用于向Spring容器注册自定义的参数解析器。在处理HTTP请求时,SpringMVC会使用这些参数解析器将请求中的数据(如查询参数、路径变量、表单数据等)转换并注入到控制器方法的参数中。 使…

【Docker】网络模式详解及容器间网络通信

目录 一、概述 二、默认网络 三、网络模式及应用 1. Bridge模式 2. Host模式 3. none网络模式 四、自定义网络应用 1. 网络相连 2. 自定义通讯 3. 自定义IP 每篇一获 一、概述 在 Docker 中,网络设置包括 IP 地址、子网掩码、网关、DNS 和端口号等关键组…

离线数据仓库-关于增量和全量

数据同步策略 数据仓库同步策略概述一、数据的全量同步二、数据的增量同步三、数据同步策略的选择 数据仓库同步策略概述 应用系统所产生的业务数据是数据仓库的重要数据来源,我们需要每日定时从业务数据库中抽取数据,传输到数据仓库中,之后…

Jxls 实现动态导出功能

目录 引言前端页面后端代码excel模板导出效果 引言 在实际做项目的过程中,导出报表时需要根据每个人所关注的点不一样,所需导出的字段也不一样,这时后端就需要根据每个所选的字段去相应的报表,这就是本文要讲的动态导出报表。 前端…

Vue Mixin 代码重用与逻辑共享

✨ 专栏介绍 在当今Web开发领域中,构建交互性强、可复用且易于维护的用户界面是至关重要的。而Vue.js作为一款现代化且流行的JavaScript框架,正是为了满足这些需求而诞生。它采用了MVVM架构模式,并通过数据驱动和组件化的方式,使…

海外云手机:跨境养号的新趋势

近年来,市场综合数据显示,利用海外云手机进行跨境养号已经成为跨境电商发展的新潮流。特别是在社交电商营销和短视频引流领域,海外云手机不仅能够提高流量的质量,还能让商家实现业务翻倍增长。接下来,本文将简要阐述海…

软件测试|使用matplotlib绘制气泡图

简介 气泡图(Bubble Chart)是一种数据可视化工具,通常用于展示三维数据的分布情况,其中数据点以气泡的形式显示在二维平面上,每个气泡的位置表示两个变量的值,气泡的大小表示第三个变量的值。在Python中&a…

Java21 + SpringBoot3集成Spring Data JPA

Java21 SpringBoot3集成Spring Data JPA 文章目录 Java21 SpringBoot3集成Spring Data JPA前言相关技术简介ORM(Object-Relational Mapping,对象-关系映射)JPA(Java Persistence API,Java持久层API)Hiber…

复现PointNet++(语义分割网络):Windows + PyTorch + S3DIS语义分割 + 代码

一、平台 Windows 10 GPU RTX 3090 CUDA 11.1 cudnn 8.9.6 Python 3.9 Torch 1.9.1 cu111 所用的原始代码:https://github.com/yanx27/Pointnet_Pointnet2_pytorch 二、数据 Stanford3dDataset_v1.2_Aligned_Version 三、代码 分享给有需要的人&#xf…

2.6、云负载均衡产品详述

一、定义 弹性负载均衡(Elastic Load Balance,简称ELB)可将来自公网的访问流量分发到后端云主机,可选多种负载均衡策略,并支持自动检测云主机健康状况,消除单点故障,保障应用系统的高可用。 二、产品架构 1&am…

GEE:随机森林分类器投票方法的优化与修改

作者:CSDN @ _养乐多_ 在随机森林中,每棵决策树都对输入数据进行分类或回归,并产生一个输出。对于分类问题,这个输出通常是一个类别标签,而对于回归问题,输出通常是一个连续的数值。在分类问题中,每棵决策树会为每个样本投票,然后采用众数来确定最终的类别。例如,如果…

P3952 [NOIP2017 提高组] 时间复杂度————C++

目录 [NOIP2017 提高组] 时间复杂度题目背景题目描述输入格式输出格式样例 #1样例输入 #1样例输出 #1 提示 解题思路Code运行结果 [NOIP2017 提高组] 时间复杂度 题目背景 NOIP2017 提高组 D1T2 题目描述 小明正在学习一种新的编程语言 A,刚学会循环语句的他激动…

ilqr 算法说明

1 Introduction 希望能用比较简单的方式将ilqr算法进行整理和总结。 2 HJB方程 假定我们现在需要完成一个从A点到B点的任务,执行这段任务的时候,每一步都需要消耗能量,可以用下面这个图表示。 我们在执行这个A点到B点的任务的时候&#xf…

项目架构之Zabbix部署

1 项目架构 1.1 项目架构的组成 业务架构:客户端 → 防火墙 → 负载均衡(四层、七层) → web缓存/应用 → 业务逻辑(动态应用) → 数据缓存 → 数据持久层 运维架构:运维客户端 → 跳板机/堡垒机&#x…

2023年第十四届蓝桥杯软件赛省赛总评

报名明年4月蓝桥杯软件赛的同学们,如果你是大一零基础,目前懵懂中,不知该怎么办,可以看看本博客系列:备赛20周合集 20周的完整安排请点击:20周计划 每周发1个博客,共20周。 在QQ群上交流答疑&am…

VUE工程化--vue组件注册

组件注册的两种方式: 1. 局部注册:只能在注册的组件内使用 2. 全局注册:所有组件内都能使用 局部注册步骤: 1、导入 import MyHeader from "./components/myHeader.vue"; import MyMain from "./components/myMa…