pytorch:nn.ModuleList和nn.Sequential、list的用法以及区别

文章目录

在构建网络的时候,pytorch有一些基础概念很重要,比如nn.Module,nn.ModuleList,nn.Sequential,这些类我们称为为容器(containers),可参考containers。本文中我们主要学习nn.ModuleList和nn.Sequential,并判断在什么时候用哪一个比较合适。

1. nn.ModuleList和nn.Sequential简介

nn.ModuleList

nn.ModuleList,它是一个存储不同module,并自动将每个module的parameters添加到网络之中的容器。你可以把任意nn.Module的子类(如nn.Conv2d,nn.Linear等)加到这个list里面,方法和python自带的list一样,无非是extend,append等操作,但不同于一般的list,加入到nn.ModuleList里面的module是会自动注册到整个网络上的,同时module的parameters也会自动添加到整个网络中。若使用python的list,则会出问题。


 
 
  1. class net 1(nn.Module):
  2. def __init__( self):
  3. super(net 1, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear( 10,10) for i in range( 2)])
  5. def forward( self, x):
  6. for m in self.linears:
  7. x = m(x)
  8. return x
  9. net = net 1()
  10. print(net)
  11. # net 1(
  12. # (modules): ModuleList(
  13. # ( 0): Linear( in_features = 10, out_features = 10, bias = True)
  14. # ( 1): Linear( in_features = 10, out_features = 10, bias = True)
  15. # )
  16. # )
  17. for param in net.parameters():
  18. print( type(param. data), param. size())
  19. # < class 'torch.Tensor' > torch. Size([ 10, 10])
  20. # < class 'torch.Tensor' > torch. Size([ 10])
  21. # < class 'torch.Tensor' > torch. Size([ 10, 10])
  22. # < class 'torch.Tensor' > torch. Size([ 10])

可以看到,这个网络权重(weights)和偏置(bias)都在这个网络之内。而对于使用python自带list的例子如下:
 


 
 
  1. class net 2(nn.Module):
  2. def __init__( self):
  3. super(net 2, self).__init__()
  4. self.linears = [nn.Linear( 10,10) for i in range( 2)]
  5. def forward( self, x):
  6. for m in self.linears:
  7. x = m(x)
  8. return x
  9. net = net 2()
  10. print(net)
  11. # net 2()
  12. print(list(net.parameters()))
  13. # []

显然,使用python的list添加的卷积层和它们的parameters并没有自动注册到我们的网络中。当然,我们还是可以使用forward来计算输出结果。但是如果用其实例化的网络进行训练的时候,因为这些层的parameters不在整个网络之中,所以其网络参数也不会被更新,也就是无法训练。

 

但是,我们需要注意到,nn.ModuleList并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言,比如:


 
 
  1. class net 3(nn.Module):
  2. def __init__( self):
  3. super(net 3, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear( 10,20), nn.Linear( 20,30), nn.Linear( 5,10)])
  5. def forward( self, x):
  6. x = self.linears[ 2](x)
  7. x = self.linears[ 0](x)
  8. x = self.linears[ 1](x)
  9. return x
  10. net = net 3()
  11. print(net)
  12. # net 3(
  13. # (linears): ModuleList(
  14. # ( 0): Linear( in_features = 10, out_features = 20, bias = True)
  15. # ( 1): Linear( in_features = 20, out_features = 30, bias = True)
  16. # ( 2): Linear( in_features = 5, out_features = 10, bias = True)
  17. # )
  18. # )
  19. input = torch.randn( 32, 5)
  20. print(net( input).shape)
  21. # torch. Size([ 32, 30])

根据net3的结果,我们可以看出ModuleList里面的顺序并不能决定什么,网络的执行顺序是根据forward函数来决定的。但是一般设置ModuleList中的顺序和forward中保持一致,增强代码的可读性。

我们再来考虑另一种情况,既然ModuleList可以根据序号来调用,那么一个模型可以在forward函数中被调用多次。但需要注意的是,被调用多次的模块,是使用同一组parameters的,也就是它们是参数共享的。


 
 
  1. class net 4(nn.Module):
  2. def __init__( self):
  3. super(net 4, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear( 5, 10), nn.Linear( 10, 10)])
  5. def forward( self, x):
  6. x = self.linears[ 0](x)
  7. x = self.linears[ 1](x)
  8. x = self.linears[ 1](x)
  9. return x
  10. net = net 4()
  11. print(net)
  12. # net 4(
  13. # (linears): ModuleList(
  14. # ( 0): Linear( in_features = 5, out_features = 10, bias = True)
  15. # ( 1): Linear( in_features = 10, out_features = 10, bias = True)
  16. # )
  17. # )
  18. for name, param in net.named_parameters():
  19. print(name, param. size())
  20. # linears. 0.weight torch. Size([ 10, 5])
  21. # linears. 0.bias torch. Size([ 10])
  22. # linears. 1.weight torch. Size([ 10, 10])
  23. # linears. 1.bias torch. Size([ 10])

 

nn.Sequential

不同于nn.ModuleList,nn.Sequential已经实现了内部的forward函数,而且里面的模块必须是按照顺序进行排列的,所以我们必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。


 
 
  1. class net 5(nn.Module):
  2. def __init__( self):
  3. super(net 5, self).__init__()
  4. self. block = nn. Sequential(nn.Conv 2d( 1,20,5),
  5. nn.ReLU(),
  6. nn.Conv 2d( 20,64,5),
  7. nn.ReLU())
  8. def forward( self, x):
  9. x = self. block(x)
  10. return x
  11. net = net 5()
  12. print(net)
  13. # net 5(
  14. # ( block): Sequential(
  15. # ( 0): Conv 2d( 1, 20, kernel_ size =( 5, 5), stride =( 1, 1))
  16. # ( 1): ReLU()
  17. # ( 2): Conv 2d( 20, 64, kernel_ size =( 5, 5), stride =( 1, 1))
  18. # ( 3): ReLU()
  19. # )
  20. # )

下面给出了两个nn.Sequential初始化的例子,在第二个初始化中我们用到了OrderedDict来指定每个module的名字


 
 
  1. # Example of using Sequential
  2. model 1 = nn. Sequential(
  3. nn.Conv 2d( 1,20,5),
  4. nn.ReLU(),
  5. nn.Conv 2d( 20,64,5),
  6. nn.ReLU()
  7. )
  8. print(model 1)
  9. # Sequential(
  10. # ( 0): Conv 2d( 1, 20, kernel_ size =( 5, 5), stride =( 1, 1))
  11. # ( 1): ReLU()
  12. # ( 2): Conv 2d( 20, 64, kernel_ size =( 5, 5), stride =( 1, 1))
  13. # ( 3): ReLU()
  14. # )
  15. # Example of using Sequential with OrderedDict
  16. import collections
  17. model 2 = nn. Sequential(collections.OrderedDict([
  18. ( 'conv1', nn.Conv 2d( 1,20,5)),
  19. ( 'relu1', nn.ReLU()),
  20. ( 'conv2', nn.Conv 2d( 20,64,5)),
  21. ( 'relu2', nn.ReLU())
  22. ]))
  23. print(model 2)
  24. # Sequential(
  25. # (conv 1): Conv 2d( 1, 20, kernel_ size =( 5, 5), stride =( 1, 1))
  26. # (relu 1): ReLU()
  27. # (conv 2): Conv 2d( 20, 64, kernel_ size =( 5, 5), stride =( 1, 1))
  28. # (relu 2): ReLU()
  29. # )

有同学可能发现了,诶,你这个 model1 和 从类 net5 实例化来的 net 有什么区别吗?是没有的。这两个网络是相同的,因为 nn.Sequential 就是一个 nn.Module 的子类,也就是 nn.Module 所有的方法 (method) 它都有。并且直接使用 nn.Sequential 不用写 forward 函数,因为它内部已经帮你写好了。

这时候有同学该说了,既然 nn.Sequential 这么好,我以后都直接用它了。如果你确定 nn.Sequential 里面的顺序是你想要的,而且不需要再添加一些其他处理的函数 (比如 nn.functional 里面的函数,nn 与 nn.functional 有什么区别? ),那么完全可以直接用 nn.Sequential。这么做的代价就是失去了部分灵活性,毕竟不能自己去定制 forward 函数里面的内容了。

一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。

 

2.nn.Sequential与nn.ModuleList的区别

不同点1:nn.Sequential内部实现了forward函数,因此可以不用写forward函数,而nn.ModuleList则没有实现内部forward函数。

不同点2:nn.Sequential可以使用OrderedDict对每层进行命名。

不同点3:nn.Sequential里面的模块按照顺序进行排列的,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。而nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。

不同点4:有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们,而不是一行一行地写,比如:

layers = [nn.Linear(10, 10) for i in range(5)]
 
 

那么这里我们使用ModuleList:


 
 
  1. class net 4(nn.Module):
  2. def __init__( self):
  3. super(net 4, self).__init__()
  4. layers = [nn.Linear( 10, 10) for i in range( 5)]
  5. self.linears = nn.ModuleList(layers)
  6. def forward( self, x):
  7. for layer in self.linears:
  8. x = layer(x)
  9. return x
  10. net = net 4()
  11. print(net)
  12. # net 4(
  13. # (linears): ModuleList(
  14. # ( 0): Linear( in_features = 10, out_features = 10, bias = True)
  15. # ( 1): Linear( in_features = 10, out_features = 10, bias = True)
  16. # ( 2): Linear( in_features = 10, out_features = 10, bias = True)
  17. # )
  18. # )

这个是比较一般的方法,但如果不想这么麻烦,我们也可以用 Sequential 来实现,如 net7 所示!注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素。但是,请注意这个 list 里面的模块必须是按照想要的顺序来进行排列的。在 场景一 中,我个人觉得使用 net7 这种方法比较方便和整洁。


 
 
  1. class net 7(nn.Module):
  2. def __init__( self):
  3. super(net 7, self).__init__()
  4. self.linear_list = [nn.Linear( 10, 10) for i in range( 3)]
  5. self.linears = nn. Sequential( * self.linears_list)
  6. def forward( self, x):
  7. self.x = self.linears(x)
  8. return x
  9. net = net 7()
  10. print(net)
  11. # net 7(
  12. # (linears): Sequential(
  13. # ( 0): Linear( in_features = 10, out_features = 10, bias = True)
  14. # ( 1): Linear( in_features = 10, out_features = 10, bias = True)
  15. # ( 2): Linear( in_features = 10, out_features = 10, bias = True)
  16. # )
  17. # )

下面我们考虑 场景二,当我们需要之前层的信息的时候,比如 ResNets 中的 shortcut 结构,或者是像 FCN 中用到的 skip architecture 之类的,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList 比较方便,一个非常简单的例子如下:


 
 
  1. class net 8(nn.Module):
  2. def __init__( self):
  3. super(net 8, self).__init__()
  4. self.linears = nn.ModuleList([nn.Linear( 10, 20), nn.Linear( 20, 30), nn.Linear( 30, 50)])
  5. self.trace = []
  6. def forward( self, x):
  7. for layer in self.linears:
  8. x = layer(x)
  9. self.trace.append(x)
  10. return x
  11. net = net 8()
  12. input = torch.randn( 32, 10) # input batch size: 32
  13. output = net( input)
  14. for each in net.trace:
  15. print(each.shape)
  16. # torch. Size([ 32, 20])
  17. # torch. Size([ 32, 30])
  18. # torch. Size([ 32, 50])

我们使用了一个 trace 的列表来储存网络每层的输出结果,这样如果以后的层要用的话,就可以很方便地调用了。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/23536.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Python】正则表达式应用

知识目录 一、写在前面✨二、姓名检查三、解析电影排行榜四、总结撒花&#x1f60a; 一、写在前面✨ 大家好&#xff01;我是初心&#xff0c;希望我们一路走来能坚守初心&#xff01; 今天跟大家分享的文章是 正则表达式的应用 &#xff0c;希望能帮助到大家&#xff01;本篇…

把字节大佬花3个月时间整理的软件测试面经偷偷给室友,差点被他开除了···

写在前面 “这份软件测试面经看起来不错&#xff0c;等会一起发给他吧”&#xff0c;我看着面前的面试笔记自言自语道。 就在这时&#xff0c;背后传来了leder“阴森森”的声音&#xff1a;“不错吧&#xff0c;我可是足足花了三个月整理的” 始末 刚入职字节的我收到了大学室…

Windows 10 X64 内核对象句柄表解析

fweWindows 很多API函数都会创建和使用句柄(传入参数)&#xff0c;句柄代表一个内核对象的内存地址&#xff0c;每个进程都有一个句柄表&#xff0c;它保存着进程拥有的句柄&#xff0c;内核也有一个句柄表 PspCidTable&#xff0c;它保存着整个系统的句柄。 ExpLookupHandleTa…

DNS风险分析及安全防护研究(一):DNS自身风险分析(中科三方)

作为互联网上的一项基础服务&#xff0c;DNS在网站运行中起到了至关重要的作用&#xff0c;然而其安全性在很长一段时间内都没有得到足够的重视。DNS采用不可靠的UDP协议&#xff0c;安全性具有较大的漏洞&#xff0c;攻击者很容易利用这些漏洞发动攻击&#xff0c;从而引起一些…

华为设备这14个广域网命令,值得每位做广域网业务的网工收藏!

你好&#xff0c;这里是网络技术联盟站。 华为设备广域网命令是网络管理员在运维过程中常用的一类命令。该命令集涵盖了DCC配置命令、PPP配置命令、MP配置命令、PPPoE命令、ATM配置命令、帧中继配置命令、HDLC配置命令、LAPB配置命令、X.25配置命令、IP-Trunk配置命令、ISDN配…

Java 与数据结构(6):快速排序

ChatGPT 中文指南(大全) 内容包含&#xff1a;如何开通chatgpt、chatgpt的同类站点、prompts 、AI绘图、ChatGPT 工具、相关报告论文、ChatGPT应用项目等 链接&#xff1a;ChatGPT 中文指南(大全) 指令指南&#xff0c;精选资源清单&#xff0c;更好的使用 chatGPT 让你的生产力…

详解如何使用LAMP架构搭建论坛

文章目录 1.LAMP概述2.编译安装Apache httpd服务1.关闭防火墙&#xff0c;将安装Apache所需软件包传到/opt目录下2.安装环境依赖包 3.配置软件模块4.编译及安装5.优化配置文件路径&#xff0c;并把httpd服务的可执行程序文件放入路径环境变量的目录中便于系统识别6.添加httpd系…

复杂的C++继承

文章目录 什么是继承继承方式赋值规则继承中的作用域&#xff08;隐藏&#xff09;子类中的默认成员函数需要自己写默认成员函数的情况 继承与友元及静态成员多继承菱形继承菱形继承的问题菱形虚拟继承 继承和组合 面向对象三大特性&#xff1a;封装继承和多态。封装在类和对象…

(四)调整PID控制器参数的指南

一、控制系统设计快速入门和环境 首先确定一下控制任务。快速、精准地控制&#xff0c;必要的稳定性&#xff0c;时域&#xff08;上升时间、超调等&#xff09;&#xff0c;频域&#xff08;带宽、阻尼比&#xff09;然后明白控制系统特点。类积分器&#xff1f;开环稳定性、高…

注解实现自动装配

要使用注解须知&#xff1a; 1.导入约束 context约束 2.配置注解的支持 官方配置文件 <?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:xsi"http://www.w3.org/2001/…

详解知识蒸馏原理和代码

目录 知识蒸馏原理概念技巧举例说明KL 散度及损失 KD训练代码导入包网络架构teacher网络student网络 teacher网络训练定义基本函数训练主函数 student网络训练&#xff08;重点&#xff09;理论部分定义kd的loss定义基本函数训练主函数 绘制结果teacher网络的暗知识softmax_t推…

C4d Octane渲染器内存满、卡顿、崩溃、缓慢、updating解决办法

最近碰到Octane渲染动画序列&#xff0c;总是会渲染一段时间后卡在某一张图片上&#xff0c;图片查看器左下角一直显示updating。 偶然发现在C4D界面点击octane工具栏的设置&#xff0c;它又会开始渲染&#xff0c;但渲染一些序列帧后又会卡在一张图上显示updating 点击octane工…

【Netty】 工作原理详解(十一)

文章目录 前言一、Netty 模型二、代码示例2.1、引入Maven依赖2.2、服务端的管道处理器2.3、服务端主程序2.4、客户端管道处理器2.5、客户端主程序2.6、测试运行 总结 前言 回顾Netty系列文章&#xff1a; Netty 概述&#xff08;一&#xff09;Netty 架构设计&#xff08;二&…

【Python]】地图热力图如何绘制?(含源代码)

文章目录 一、问题引入 & 使用地图的说明1.1 问题的引入1.2 使用地图的说明 二、方法1三、方法2 一、问题引入 & 使用地图的说明 1.1 问题的引入 我们有一个中国各省份的数据集&#xff0c;要求绘制地图热力图&#xff0c;该怎么实现呢&#xff1f; 部分数据集如下&…

tcp套接字的应用

tcp服务端流程 tcp客户端流程 客户端代码 tcpClient.hpp #include<iostream> #include<string> #include<cstring> #include<stdlib.h> #include<unistd.h> #include<sys/types.h> #include<sys/socket.h> #include<netinet/in…

2172. 最大公约数

Powered by:NEFU AB-IN Link 文章目录 2172. 最大公约数题意思路代码 2022年第十三届决赛真题 2172. 最大公约数 题意 给定一个数组, 每次操作可以选择数组中任意两个相邻的元素 x , y x, yx,y 并将其 中的一个元素替换为 gcd ⁡ ( x , y ) \operatorname{gcd}(x, y)gcd(x,y),…

117.【微信小程序】

微信小程序 (一)、微信小程序概括1.微信小程序简介(1).小程序与普通网页开发的区别 2.注册微信小程序账号(1).注册小程序账号(2).获取小程序的AppID 3.安装微信开发者工具(1).微信开发者工具的简介:(2).微信开发者工具的下载 4.创建第一个小程序(1).创建小程序步骤(2).开发者工…

新入职一个00后卷王,每天加班到2点,太让人崩溃了····

在程序员职场上&#xff0c;什么样的人最让人反感呢? 是技术不好的人吗?并不是。技术不好的同事&#xff0c;我们可以帮他。 是技术太强的人吗?也不是。技术很强的同事&#xff0c;可遇不可求&#xff0c;向他学习还来不及呢。 真正让人反感的&#xff0c;是技术平平&…

Java企业工程项目管理系统+spring cloud 系统管理+java 系统设置+二次开发

工程项目各模块及其功能点清单 一、系统管理 1、数据字典&#xff1a;实现对数据字典标签的增删改查操作 2、编码管理&#xff1a;实现对系统编码的增删改查操作 3、用户管理&#xff1a;管理和查看用户角色 4、菜单管理&#xff1a;实现对系统菜单的增删改查操…

【C++】-string的介绍以及使用(迭代器的介绍和使用)

&#x1f496;作者&#xff1a;小树苗渴望变成参天大树 ❤️‍&#x1fa79;作者宣言&#xff1a;认真写好每一篇博客 &#x1f4a8;作者gitee:gitee &#x1f49e;作者专栏&#xff1a;C语言,数据结构初阶,Linux,C 如 果 你 喜 欢 作 者 的 文 章 &#xff0c;就 给 作 者 点…