Pytorch学习:神经网络模块torch.nn.Module和torch.nn.Sequential

文章目录

    • 1. torch.nn.Module
      • 1.1 add_module(name,module)
      • 1.2 apply(fn)
      • 1.3 cpu()
      • 1.4 cuda(device=None)
      • 1.5 train()
      • 1.6 eval()
      • 1.7 state_dict()
    • 2. torch.nn.Sequential
      • 2.1 append
    • 3. torch.nn.functional.conv2d

1. torch.nn.Module

官方文档:torch.nn.Module
CLASS torch.nn.Module(*args, **kwargs)

  • 所有神经网络模块的基类。
  • 您的模型也应该对此类进行子类化。
  • 模块还可以包含其他模块,允许将它们嵌套在树结构中。您可以将子模块分配为常规属性:
  • training(bool)-布尔值表示此模块是处于训练模式还是评估模式。

定义一个模型

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))
  • 以这种方式分配的子模块将被注册,并且当您调用 to() 等时也将转换其参数。
    • to(device=None,dtype=None,non_blocking=False)
      device ( torch.device) – 该模块中参数和缓冲区所需的设备
    • to(dtype ,non_blocking=False)
      dtype ( torch.dtype) – 该模块中参数和缓冲区所需的浮点或复杂数据类型
    • to(tensor,non_blocking=False)
      张量( torch.Tensor ) – 张量,其数据类型和设备是该模块中所有参数和缓冲区所需的数据类型和设备

引用上面定义的模型,将模型转移到GPU上

# 创建模型
model = Model()

# 定义设备 gpu1
gpu1 = torch.device("cuda:1")
model = model.to(gpu1)

1.1 add_module(name,module)

将子模块添加到当前模块。
可以使用给定的名称作为属性访问模块。

add_module(name,module)
主要参数:

  • name(str)-子模块的名称。可以使用给定的名称从此模块访问子模块。
  • module(Module)-要添加到模块的子模块。

在这里插入图片描述
添加一个卷积层

model.add_module("conv3", nn.Conv2d(20, 20, 5))

在这里插入图片描述

1.2 apply(fn)

将 fn 递归地应用于每个子模块(由 .children() 返回)以及self。
典型的用法包括初始化模型的参数(另请参见torch.nn.init)。

apply(fn)
主要参数:

  • fn( Module -> None)-应用于每个子模块的函数

将所有线性层的权重置为1

import torch
from torch import nn


@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)


net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2,2))
net.apply(init_weights)

在这里插入图片描述

1.3 cpu()

将所有模型参数和缓冲区移动到CPU。

device = torch.device("cpu")
model = model.to(device)

1.4 cuda(device=None)

将所有模型参数和缓冲区移动到GPU。

这也使关联的参数和缓冲区成为不同的对象。因此,如果模块在优化时将驻留在GPU上,则应在构造优化器之前调用该函数。

cuda(device=None)
主要参数:

  • device(int,可选)-如果指定,所有参数将被复制到该设备

转移到GPU包括以下参数:

  1. 模型
  2. 损失函数
  3. 输入输出
# 创建模型
model = Model()

# 将模型转移到GPU上
model = model.cuda()

# 将损失函数转移到GPU上
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.cuda()

# 将输入输出转移到GPU上
imgs, targets = data
imgs = imgs.cuda()
targets = targets.cuda()

另一种表示形式(通过 to(device) 来表示)

# 创建模型
model = Model()

# 定义设备:如果有GPU,则在GPU上训练, 否则在CPU上训练
device = torch.device("cuda" if torch.cuda.is_available else "cpu")

# 将模型转移到GPU上
model = model.to(device)

# 将损失函数转移到GPU上
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# 将输入输出转移到GPU上
imgs, targets = data
imgs = imgs.to(device)
targets = targets.to(device)

1.5 train()

将模块设置为训练模式。

这只对某些模块有任何影响。如受影响,请参阅特定模块在培训/评估模式下的行为详情,例如: Dropout 、 BatchNorm 等。

train(mode=True)
主要参数:

  • mode(bool)-是否设置训练模式( True )或评估模式( False )。默认值: True 。

1.6 eval()

将模块设置为评估模式。

这只对某些模块有任何影响。如受影响,请参阅特定模块在培训/评估模式下的行为详情,例如: Dropout 、 BatchNorm 等。

在进行模型测试的时候会用到。

1.7 state_dict()

返回一个字典,其中包含对模块整个状态的引用。

返回模型的关键字典。

model = Model()
print(model.state_dict().keys())

在这里插入图片描述
在保存模型的时候我们也可以直接保存模型的 state_dict()

model = Model()

# 保存模型
# 另一种方式:torch.save(model, "model.pth")
torch.save(model.state_dict(), "model.pth")

# 加载模型
model.load_state_dict(torch.load("model.pth"))

2. torch.nn.Sequential

顺序容器。模块将按照它们在构造函数中传递的顺序添加到它。

Sequential 的 forward() 方法接受任何输入并将其转发到它包含的第一个模块。然后,它将输出“链接”到每个后续模块的输入,最后返回最后一个模块的输出。

官方文档:torch.nn.Sequential
CLASS torch.nn.Sequential(*args: Module)

import torch
from torch import nn

# 使用 Sequential 创建一个小型模型。运行 `model` 时、
# 输入将首先传递给 `Conv2d(1,20,5)`。输出
# `Conv2d(1,20,5)`的输出将作为第一个
# 第一个 `ReLU` 的输出将成为 `Conv2d(1,20,5)` 的输入。
# `Conv2d(20,64,5)` 的输入。最后
# `Conv2d(20,64,5)` 的输出将作为第二个 `ReLU` 的输入
model = nn.Sequential(
            nn.Conv2d(1, 20, 5),
            nn.ReLU(),
            nn.Conv2d(20, 64, 5),
            nn.ReLU()
        )

在这里插入图片描述

2.1 append

append 在末尾追加给定块。

  • append(module)
    在末尾追加给定模块。
    在这里插入图片描述
def append(self, module):
    self.add_module(str(len(self)), module)
    return self


append(model, nn.Conv2d(64, 64, 5))
append(model, nn.ReLU())
print(model)

在这里插入图片描述

3. torch.nn.functional.conv2d

对由多个输入平面组成的输入图像应用2D卷积。
卷积神经网络详解:csdn链接

官方文档:torch.nn.functional.conv2d
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
主要参数:

  • input:形状的输入张量,(minibatch, inchannels, iH, iW)。
  • weigh:卷积核权重,形状为 (out_channels, inchannels / groups, kH, kW)

默认参数:

  • bias:偏置,默认值: None。
  • stride:步幅,默认值:1。
  • padding:填充,默认值:0。
  • dilation :内核元素之间的间距。默认值:1。
  • groups:将输入拆分为组,in_channels 应被组数整除。默认值:1。

在这里插入图片描述
对上图卷积操作进行代码实现

import torch.nn.functional as F

input = torch.tensor([[0, 1, 2],
                      [3, 4, 5],
                      [6, 7, 8]], dtype=float32)
kernel = torch.tensor([[0, 1],
                       [2, 3]], dtype=float32)


# F.conv2d 输入维数为4维
# torch.reshape(input, shape)
# reshape(样本数,通道数,高度,宽度)
input = torch.reshape(input, (1, 1, 3, 3))
kernel = torch.reshape(kernel, (1, 1, 2, 2))

output = F.conv2d(input, kernel, stride=1)
print(input.shape)
print(kernel.shape)
print(input)
print(kernel)
print(output)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/97452.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

环保数字化,让污染无处遁形

环保一直以来都是我国大力推崇的举措,“保护环境、人人有责”的标语深入人心,但是环保绝不是某一天某一年就能做好的事情,而在于一朝一夕坚持不懈,下文将针对环保的场景介绍一下数字孪生技术在环保领域的应用。 一、环保背景 新中…

几个nlp的小项目(文本分类)

几个nlp的小项目(文本分类) 导入加载数据类、评测类查看数据集精确展示数据测评方法设置参数tokenizer,token化的解释对数据集进行预处理加载预训练模型进行训练设置训练模型的参数一个根据任务名获取,测评方法的函数创建预训练模型开始训练本项目的工作完成了什么任务?导…

CNN 02(CNN原理)

一、卷积神经网络(CNN)原理 1.1 卷积神经网络的组成 定义 卷积神经网络由一个或多个卷积层、池化层以及全连接层等组成。与其他深度学习结构相比,卷积神经网络在图像等方面能够给出更好的结果。这一模型也可以使用反向传播算法进行训练。相比较其他浅层或深度神经…

景联文科技数据标注:人体关键点标注用途及各点的位置定义

人体关键点标注是一种计算机视觉任务,指通过人工的方式,在指定位置标注上关键点,例如人脸特征点、人体骨骼连接点等,常用来训练面部识别模型以及统计模型。这些关键点可以表示图像的各个方面,例如角、边或特定特征。在…

unity 之参数类型之引用类型

文章目录 引用类型引用类型与值类型的差异 引用类型 在Unity中,引用类型是指那些在内存中存储对象引用的数据类型。以下是在Unity中常见的引用类型的介绍: 节点(GameObject): 在Unity中,游戏对象&#xff…

day28 异常

to{}catch{} try{}catch{}的流传输 try {fis new FileInputStream("file-APP\\fos.txt");fos new FileOutputStream("fos.txt");int a ;while ((a fis.read())! -1){fos.write(a);}System.out.println(a); } catch (IOException e) {e.printStackTrace()…

在编辑器中使用正则

正则是一种文本处理工具,常见的功能有文本验证、文本提取、文本替换、文本切割等。有一些地方说的正则匹配,其实是包括了校验和提取两个功能。 校验常用于验证整个文本的组成是不是符合规则,比如密码规则校验。提取则是从大段的文本中抽取出…

0基础学习VR全景平台篇 第92篇:智慧景区-智慧景区常见问题

Q:怎么编辑景区里面各个景点的介绍和推荐该景点A:在下方素材栏中该景点(素材)的右上角选择【编辑场景】里面就可以在场景介绍中编辑该场景的介绍并且在该选项中可以将此场景设置为推荐景点。 Q:景区项目可不可以离线浏…

数字孪生智慧工厂:电缆厂 3D 可视化管控系统

近年来,我国各类器材制造业已经开始向数字化生产转型,使得生产流程变得更加精准高效。通过应用智能设备、物联网和大数据分析等技术,企业可以更好地监控生产线上的运行和质量情况,及时发现和解决问题,从而提高生产效率…

【大虾送书第七期】深入浅出SSD:固态存储核心技术、原理与实战

目录 ✨写在前面 ✨内容简介 ✨作者简介 ✨名人推荐 ✨文末福利 🦐博客主页:大虾好吃吗的博客 🦐专栏地址:免费送书活动专栏地址 写在前面 近年来国家大力支持半导体行业,鼓励自主创新,中国SSD技术和产业…

ChromeOS 的 Linux 操作系统和 Chrome 浏览器分离

导读科技媒体 Ars Technica 报道称,谷歌正在将 ChromeOS 的浏览器从操作系统中分离出来 —— 让它变得更像 Linux。虽然目前还没有任何官方消息,但这项变化可能会在本月的版本更新中推出。 据介绍,谷歌将该项目命名为 "Lacros"——…

数据驱动的生活:探索未来七天生活指数API的应用

前言 随着科技的不断发展,数据已经成为我们生活中不可或缺的一部分。从社交媒体上的点赞和分享,到电子邮件和搜索引擎的历史记录,数据正在以前所未有的速度积累。而这些数据的利用不仅仅停留在社交媒体或商业领域,它们还可以为我…

机器学习:争取被遗忘的权利

随着越来越多的人意识到他们通过他们经常访问的无数应用程序和网站共享了多少个人信息,数据保护和隐私一直在不断讨论。看到您与朋友谈论的产品或您在 Google 上搜索的音乐会迅速作为广告出现在您的社交媒体提要中,这不再那么令人惊讶。这让很多人感到担…

【炼气境】Java集合框架篇

【炼气境】Java集合框架篇 文章目录 【炼气境】Java集合框架篇概述接口Collection接口List接口ArrayList类LinkedList类 Set接口HashSet类LinkedHashSet类TreeSet类 Queue接口LinkedList类PriorityQueue类ArrayDeque Map接口HashMap类LinkedHashMap类TreeMap类 常用方法特性适用…

软件工程(八) UML之类图与对象图

1、类图与对象图 1.1、类图与对象图的概念 类图(class diagram)描述一组类、接口、协作和它们之间的关系 对象图(object diagram)描述一组对象及它们之间的关系、对象图描述了在类图中所建立的事物实例的静态快照。 1.2、类图与对象图的区别 类图和对象图基本上是一样…

大数据平台与数据仓库的五大区别

随着大数据的快速发展,很多人难以区分大数据平台与数据仓库的区别,两者傻傻分不清楚。今天我们小编就给大家汇总了大数据平台与数据仓库的五大区别,希望有用哦!仅供参考! 大数据平台与数据仓库的五大区别 一、概念不同…

如何把本地项目上传github

一、在gitHub上创建新项目 【1】点击添加()-->New repository 【2】填写新项目的配置项 Repository name:项目名称 Description :项目的描述 Choose a license:license 【3】点击确定,项目已在githu…

火狐浏览器使用scss嵌套编写css无法识别问题

火狐浏览器使用scss嵌套编写css无法识别问题 版本: “node-sass”: “^4.14.1”, “sass-loader”: “^7.3.1”,vue版本: v2问题描述: 我的文件目录是这样的: 而在scss文件中我是这样书写的 .vue文件中 在火狐浏览器中 在谷…

几个nlp的小任务(生成任务(摘要生成))

几个nlp的小任务生成任务——摘要生成 安装库选择模型加载数据集展示数据集数据预处理 tokenizer注意特殊的 token处理组成预处理函数调用map,对数据集进行预处理微调模型,设置参数设置数据收集器,将处理好的数据喂给模型封装测评方法将参数传给 trainer,开始训练安装库 选…

【git】Idea撤回本地分支、或远程分支提交记录的各种实际场景操作步骤

文章目录 idea撤回本地分支、远程分支场景操作集合场景1:要撤回最后一次本地分支的提交实现效果:操作步骤: 场景2:要撤回最后一次远程分支的提交有撤销记录的:实现效果:操作步骤: 无撤销记录的&…