深度学习pytorch——nn.Module(持续更新)

作为一个初学者,发现构建一个简单的线性模型都能看到nn.Module的身影,初学者疑惑了,nn.Module到底是干什么的,如此形影不离,了解之后,很牛。

1、nn.Module是所有层的父类,比如Linear、BatchNorm2d、Conv2d、ReLU、Sigmoid、ConvTranposed、Dropout等等这些都是它的儿子(子类),你可以直接拿来使用。

2、nn.Module还支持一个nn.Module嵌套另一个nn.Module。

3、并且可以自动完成forward,你只需要nn.Sequential()这个容器就可以了,代码示例如下:

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.net = nn.Sequential(BasicNet(),
                                 nn.ReLU(),
                                 nn.Linear(3, 2))

    def forward(self, x):
        return self.net(x)

4、深度学习中参数可谓是量产,如果手动进行参数管理将会是一个庞大的工程,导师问你在干什么就不是在训练模型了,而是处理那些无处安放的参数,而使用nn.Module就提供的parameters就可以秒出结果,代码示例如下:

    for name, t in net.named_parameters():
        print('parameters:', name, t.shape)
    # parameters: net.0.net.weight torch.Size([3, 4])
    # parameters: net.0.net.bias torch.Size([3])
    # parameters: net.2.weight torch.Size([2, 3])
    # parameters: net.2.bias torch.Size([2])

然后将参数直接传入优化器进行优化,代码示例如下:

optimizer=optim.SGD(net.parameters(),lr=1e-3)

5、有很多的孩子,并且你还可以很简单的知道他孩子长什么样,我先介绍一下他的孩子们:

我们可以通过 net.named_children()了解他的亲孩子(children),也就是直系亲属,net.named_modules()了解他所有的孩子(modules),直系亲属外亲都算,代码示例如下:

    for name, m in net.named_children():
        print('children:', name, m)
    # children: net Sequential(
    #     (0): BasicNet(
    #     (net): Linear(in_features=4, out_features=3, bias=True)
    # )
    # (1): ReLU()
    # (2): Linear(in_features=3, out_features=2, bias=True)
    # )

    for name, m in net.named_modules():
        print('modules:', name, m)

# modules:  Net(
#   (net): Sequential(
#     (0): BasicNet(
#       (net): Linear(in_features=4, out_features=3, bias=True)
#     )
#     (1): ReLU()
#     (2): Linear(in_features=3, out_features=2, bias=True)
#   )
# )
# modules: net Sequential(
#   (0): BasicNet(
#     (net): Linear(in_features=4, out_features=3, bias=True)
#   )
#   (1): ReLU()
#   (2): Linear(in_features=3, out_features=2, bias=True)
# )
# modules: net.0 BasicNet(
#   (net): Linear(in_features=4, out_features=3, bias=True)
# )
# modules: net.0.net Linear(in_features=4, out_features=3, bias=True)
# modules: net.1 ReLU()
# modules: net.2 Linear(in_features=3, out_features=2, bias=True)

6、可以非常方便的将网络运行在不同的设备,这里的设备是指cuda、gpu之类的,代码示例如下:

device = torch.device('cuda')
net = Net()
net.to(device)

7、方便对模型进行保存和加载,一个模型一般需要训练好久,但是我们并没有如此连续的时间,我们可以将现在训练好的模型进行保存,下次加载继续训练,代码示例如下:

# 加载
net.load_state_dict(torch.load('ckpt.mdl'))
# 保存
torch.save(net.state_dict(),'ckpt.mdl')

8、方便进行训练和测试状态的切换,代码示例如下:

# 训练
net.train()
# 测试
net.eval()

9、可以实现自己构建的模型,代码示例如下:

# 以下是一个展平的实现
class Flatten(nn.Module):

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, input):
        return input.view(input.size(0), -1)  #[b,打平],保留b,其他的全部打平



class TestNet(nn.Module):

    def __init__(self):
        super(TestNet, self).__init__()

        self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),
                                 nn.MaxPool2d(2, 2),
                                 Flatten(),
                                 nn.Linear(1*14*14, 10))

    def forward(self, x):
        return self.net(x)
# 构建自己的线性层
class MyLinear(nn.Module):

    def __init__(self, inp, outp):
        super(MyLinear, self).__init__()
        # 这里使用nn.Parameter代表了可以回传给nn.model,进行更新,所以不用写requires_grad = True
        # requires_grad = True
        self.w = nn.Parameter(torch.randn(outp, inp))
        self.b = nn.Parameter(torch.randn(outp))

    def forward(self, x):
        x = x @ self.w.t() + self.b
        return x

如果你还有什么更好的idea,欢迎分享!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/501605.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Bun安装与使用

Bun安装与使用。 它目前无法在windows上直接安装使用,必须通过虚拟机安装。 在win10虚拟机中安装 # 查看内核版本 $ uname -srm Linux 6.1.0-10-amd64 x86_64# 安装unzip解压工具 $ sudo apt install unzip# 下载安装脚本并开始安装 curl -fsSL https://bun.sh/ins…

C#使用SQLite(含加密)保姆级教程

C#使用SQLite 文章目录 C#使用SQLite涉及框架及库复制runtimes创建加密SQLite文件生成连接字串执行SQL生成表SQLiteConnectionFactory.cs 代码结构最后 涉及框架及库 自己在NuGet管理器里面安装即可 Chloe.SQLite:ORM框架Microsoft.Data.Sqlite.Core:驱…

3.Python数据分析—数据分析入门知识图谱索引(知识体系中篇)

3.Python数据分析—数据分析入门知识图谱&索引-知识体系中篇 一个人简介二数据获取和处理2.1 数据来源:2.2 数据清洗:2.2.1 缺失值处理:2.2.2 异常值处理: 2.3 数据转换:2.3.1 数据类型转换:2.3.2 数据…

【unity】解决unity编译器安装中文汉化包失败

如果有的同学中文包安装失败,我们找到相应的编译器版本,点击在资源管理器中显示按钮, 我们点击当前目录的上一级,进入编译器目录。 找到modules.json文件双击打开 我们找到简体中文,复制downloadUrl后面的值到浏览…

Spark SQL— Catalyst 优化器

Spark SQL— Catalyst 优化器 1. 目的 本文的目标是描述Spark SQL 优化框架以及它如何允许开发人员用很少的代码行表达复杂的查询转换。我们还将描述Spark SQL如何通过大幅提高其查询优化能力来提高查询的执行时间。在本教程中,我们还将介绍什么是优化、为什么使用…

K8S安装和部署(kubeadmin安装1主2从)

这里用kubeadmin方式进行安装部署 1. 准备三台服务器 服务器地址 节点名称 192.168.190.200 master 主 192.168.190.201 node1 从 192.168.190.202 node2 从 2. 主机初始化(所有主机) 2.1根据规划设置主机名 #切换到192.168.190.200 hostnamectl…

【C++的奇迹之旅】C++关键字命名空间使用的三种方式C++输入输出命名空间std的使用惯例

文章目录 📝前言🌠 C关键字(C98)🌉 命名空间🌠命名空间定义🌉命名空间使用 🌠命名空间的使用有三种方式:🌉加命名空间名称及作用域限定符🌠使用using将命名空间中某个成员…

【Frida】【Android】06_夜神模拟器中间人抓包

🛫 系列文章导航 【Frida】【Android】01_手把手教你环境搭建 https://blog.csdn.net/kinghzking/article/details/136986950【Frida】【Android】02_JAVA层HOOK https://blog.csdn.net/kinghzking/article/details/137008446【Frida】【Android】03_RPC https://bl…

如何更新STEAM税务信息

回复邮件 Here are three attachments:. Figure 1: My personal tax information file in the government system, including my TIN, permanent address and mailing address Figure 2. My tax payment certificate in China in 2002 was issued by the tax bureau, Figure 3:…

npm ERR! errno CERT_HAS_EXPIRED

1 问题描述 使用npm命令安装相关依赖报错:npm ERR! code CERT_HAS_EXPIRED npm ERR! errno CERT_HAS_EXPIRED npm ERR! request to https://registry.npm.taobao.org/vue%2fcli failed, reason: certificate has expired报错示例图如下所示: 2原因分析…

C语言循环结构的程序设计

在C语言中,循环结构是一种重要的控制结构,用于重复执行特定的代码块,直到满足特定的条件为止。循环结构使得程序可以更加灵活和高效地处理重复性的任务,从而提高了程序的可读性和可维护性。本文将深入介绍C语言中循环结构的程序设…

小型分布式文件存储系统GoFastDfs应用简介

前言 最近稍微留意了一下各个文件存储系统的协议,发现minio是LGPLV3, 而fastdfs 是GPL3,这些协议其实对于商业应用是一个大坑。故而寻找一些代替品。 go-fastdfs就是其中之一,官网在: go-fastdfs 具体应用 其实可以直接查看官网教程的。 下…

Jenkins详细安装配置部署

目录 简介一、安装jdk二、安装jenkins这里如果熟悉 Jenkins ,可以【选择插件来安装】,如果不熟悉,还是按照推荐来吧。注意: 三、插件安装如果上面插件安装,选择的不是【安装推荐的插件】,而是【选择插件来安…

学习Fast-LIO系列代码中相关概念理解

目录 一、流形和流形空间(姿态) 1.1 定义 1.2 为什么要有流形? 1.3 流形要满足什么性质? (1) 拓扑同胚 (2) 可微结构 1.4 欧式空间和流形空间的区别和联系? (1) 区别: (2) 联系: 1.5 将姿态定义在流形上比…

基于java+springboot+vue实现的二手闲置物品置换系统(文末源码+Lw+ppt)23-375

摘 要 大学生二手闲置物品置换交易管理系统设计的目的是为用户提供免费物品、积分物品等功能。 与其它应用程序相比,大学生二手闲置物品置换交易的设计主要面向于学校,旨在为管理员和卖家、用户提供一个大学生二手闲置物品置换交易管理系统。用户可以…

Java项目:80 springboot师生健康信息管理系统

作者主页:源码空间codegym 简介:Java领域优质创作者、Java项目、学习资料、技术互助 文中获取源码 项目介绍 系统的角色:管理员、宿管、学生 管理员管理宿管员,管理学生,修改密码,维护个人信息。 宿管员…

LLM推理入门指南②:深入解析KV缓存

在本系列文章《LLM推理入门指南①:文本生成的初始化与解码阶段》中,作者对Transformer解码器的文本生成算法进行了高层次概述,着重介绍了两个阶段:单步初始化阶段,即提示的处理阶段,和逐个生成补全词元的多…

数组类模板(类模拟实现静态数组)

目录 介绍: 案例描述: 思路: 对要求分别分析实现: 创建对应的类: 1.定义一个数组类 2.类中属性有:数组, 容量, 大小 3.数组函数有: 构造函数(容量&am…

Oracle EBS AR接口和OM销售订单单价为空数据修复

最近,用户使用客制化Web ADI 批量导入销售订单行功能,把销售订单行的单价更新成空值,直到发运确认以后,财务与客户对帐才发现大量销售订单的单价空,同时我们检查AR接口发现销售订单的单价和金额均为空。 前提条件 采用PAC成本方式具体问题症状 销售订单行的单价为空 Path:…

车载以太网AVB交换机 gPTP透明时钟 6口 百兆车载以太网交换机

SW100TE百兆车载以太网交换机 一、产品简要分析 6端口百兆车载以太网交换机,其中包含5通道100BASE-T1泰科MATEnet接口和1个通道100BASE-TX标准以太网(RJ45接口),可以实现车载以太网多通道交换,车载以太网数据采集和模拟,Bypass数…