基于pytorch的深度学习基础3——模型创建与nn.Module

三 模型创建与nn.Module

3.1 nn.Module

模型构建两要素:

  1. 构建子模块——__init()__
  2. 拼接子模块——forward()

一个module可以有多个module;

一个module相当于一个运算,都必须实现forward函数;

每一个module有8个字典管理属性。

self._parameters = OrderedDict()

self._buffers = OrderedDict()

self._backward_hooks = OrderedDict()

self._forward_hooks = OrderedDict()

self._forward_pre_hooks = OrderedDict()

self._state_dict_hooks = OrderedDict()

self._load_state_dict_pre_hooks = OrderedDict()

self._modules = OrderedDict()

3.2 网络容器

nn.Sequential()

是nn.Module()的一个容器,用于按照顺序包装一组网络层;

顺序性:网络层之间严格按照顺序构建;

自带forward():

各网络层之间严格按顺序执行,常用于block构建

class LeNetSequential(nn.Module):

    def __init__(self, classes):

        super(LeNetSequential, self).__init__()

        self.features = nn.Sequential(

            nn.Conv2d(3, 6, 5),

            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(6, 16, 5),

            nn.ReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),)

        self.classifier = nn.Sequential(

            nn.Linear(16*5*5, 120),

            nn.ReLU(),

            nn.Linear(120, 84),

            nn.ReLU(),

            nn.Linear(84, classes),)

    def forward(self, x):

        x = self.features(x)

        x = x.view(x.size()[0], -1)

        x = self.classifier(x)

        return x

nn.ModuleList()

是nn.Module的容器,用于包装网络层,以迭代方式调用网络层。

主要方法:

append():在ModuleList后面添加网络层;

extend():拼接两个ModuleList.

Insert():指定在ModuleList中插入网络层。

nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建

class ModuleList(nn.Module):

    def __init__(self):

        super(ModuleList, self).__init__()

        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])

    def forward(self, x):

        for i, linear in enumerate(self.linears):

            x = linear(x)

        return x

nn.ModuleDict()

以索引方式调用网络层

主要方法:

• clear():清空ModuleDict

• items():返回可迭代的键值对(key-value pairs)

• keys():返回字典的键(key)

• values():返回字典的值(value)

• pop():返回一对键值,并从字典中删除

n.ModuleDict:索引性,常用于可选择的网络层

class ModuleDict(nn.Module):

    def __init__(self):

        super(ModuleDict, self).__init__()

        self.choices = nn.ModuleDict({

            'conv': nn.Conv2d(10, 10, 3),

            'pool': nn.MaxPool2d(3)

        })

        self.activations = nn.ModuleDict({

            'relu': nn.ReLU(),

            'prelu': nn.PReLU()

        })

    def forward(self, x, choice, act):

        x = self.choices[choice](x)

        x = self.activations[act](x)

        return x

3.3卷积层

nn.ConV2d()

nn.Conv2d(in_channels, out_channels,kernel_size, stride=1,padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros')

in_channels:输入通道数,比如RGB图像是3,而后续的网络层的输入通道数为前一卷积层的输出通道数;

out_channels:输出通道数,等价于卷积核个数

kernel_size:卷积核尺寸

stride:步

padding:填充个数

dilation:空洞卷积大小

groups:分组卷积设置

bias:偏置

    conv_layer = nn.Conv2d(3, 1, 3)   # input:(i, o, size) weights:(o, i , h, w)

    nn.init.xavier_normal_(conv_layer.weight.data)

    # calculation

    img_conv = conv_layer(img_tensor)

这里使用 input*channel 为 3,output_channel 为 1 ,卷积核大小为 3×3 的卷积核nn.Conv2d(3, 1, 3),使用nn.init.xavier_normal*()方法初始化网络的权值。

我们通过`conv_layer.weight.shape`查看卷积核的 shape 是`(1, 3, 3, 3)`,对应是`(output_channel, input_channel, kernel_size, kernel_size)`。所以第一个维度对应的是卷积核的个数,每个卷积核都是`(3,3,3)`。虽然每个卷积核都是 3 维的,执行的却是 2 维卷积。

转置卷积nn.ConvTranspose2d

转置卷积又称为反卷积(Deconvolution)和部分跨越卷积(Fractionally-stridedConvolution) ,用于对图像进行上采样(UpSample)

为什么称为转置卷积?

假设图像尺寸为4*4,卷积核为3*3,padding=0,stride=1

正常卷积:

转置卷积:

假设图像尺寸为2*2,卷积核为3*3,padding=0,stride=1

nn.ConvTranspose2d(in_channels, out_channels,

kernel_size,

stride=1,

padding=0,

output_padding=0,

groups=1,

bias=True,

dilation=1, padding_mode='zeros')

输出尺寸计算:

# flag = 1

flag = 0

if flag:

    conv_layer = nn.ConvTranspose2d(3, 1, 3, stride=2)   # input:(i, o, size)

    nn.init.xavier_normal_(conv_layer.weight.data)

    # calculation

    img_conv = conv_layer(img_tensor)

print("卷积前尺寸:{}\n卷积后尺寸:{}".format(img_tensor.shape, img_conv.shape))

img_conv = transform_invert(img_conv[0, 0:1, ...], img_transform)

img_raw = transform_invert(img_tensor.squeeze(), img_transform)

plt.subplot(122).imshow(img_conv, cmap='gray')

plt.subplot(121).imshow(img_raw)

plt.show()

3.4池化层nn.MaxPool2d && nn.AvgPool2d

池化运算:对信号进行 “收集”并 “总结”,类似水池收集水资源,因而

得名池化层

“收集”:多变少

“总结”:最大值/平均值

nn.MaxPool2d

nn.MaxPool2d(kernel_size, stride=None,

padding=0, dilation=1,

return_indices=False,

ceil_mode=False)

主要参数:

• kernel_size:池化核尺寸

• stride:步长

• padding :填充个数

• dilation:池化核间隔大小

• ceil_mode:尺寸向上取整

• return_indices:记录池化像素索引

# flag = 1

flag = 0

if flag:

    maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2))   # input:(i, o, size) weights:(o, i , h, w)

    img_pool = maxpool_layer(img_tensor)

nn.AvgPool2d

nn.AvgPool2d(kernel_size,

stride=None,

padding=0,

ceil_mode=False,

count_include_pad=True,

divisor_override=None)

主要参数:

• kernel_size:池化核尺寸

• stride:步长

• padding :填充个数

• ceil_mode:尺寸向上取整

• count_include_pad:填充值用于计算

• divisor_override :除法因子

    avgpoollayer = nn.AvgPool2d((2, 2), stride=(2, 2))   # input:(i, o, size) weights:(o, i , h, w)

    img_pool = avgpoollayer(img_tensor)

    img_tensor = torch.ones((1, 1, 4, 4))

    avgpool_layer = nn.AvgPool2d((2, 2), stride=(2, 2), divisor_override=3)

    img_pool = avgpool_layer(img_tensor)

    print("raw_img:\n{}\npooling_img:\n{}".format(img_tensor, img_pool))

nn.MaxUnpool2d

功能:对二维信号(图像)进行最大值池化

上采样

主要参数:

• kernel_size:池化核尺寸

• stride:步长

• padding :填充个数

    # pooling

    img_tensor = torch.randint(high=5, size=(1, 1, 4, 4), dtype=torch.float)

    maxpool_layer = nn.MaxPool2d((2, 2), stride=(2, 2), return_indices=True)

    img_pool, indices = maxpool_layer(img_tensor)

    # unpooling

    img_reconstruct = torch.randn_like(img_pool, dtype=torch.float)

    maxunpool_layer = nn.MaxUnpool2d((2, 2), stride=(2, 2))

    img_unpool = maxunpool_layer(img_reconstruct, indices)

    print("raw_img:\n{}\nimg_pool:\n{}".format(img_tensor, img_pool))

    print("img_reconstruct:\n{}\nimg_unpool:\n{}".format(img_reconstruct, img_unpool))

3.5线性层

nn.Linear(in_features, out_features, bias=True)

功能:对一维信号(向量)进行线性组合

主要参数:

• in_features:输入结点数

• out_features:输出结点数

• bias :是否需要偏置

计算公式:y = 𝒙𝑾𝑻 + 𝒃𝒊𝒂s

    inputs = torch.tensor([[1., 2, 3]])

    linear_layer = nn.Linear(3, 4)

    linear_layer.weight.data = torch.tensor([[1., 1., 1.],

                                             [2., 2., 2.],

                                             [3., 3., 3.],

                                             [4., 4., 4.]])

    linear_layer.bias.data.fill_(0.5)

    output = linear_layer(inputs)

    print(inputs, inputs.shape)

    print(linear_layer.weight.data, linear_layer.weight.data.shape)

    print(output, output.shape)

3.6 激活函数层

nn.Sigmoid

nn.tanh:

nn.ReLU

nn.LeakyReLU

negative_slope: 负半轴斜率

nn.PReLU

init: 可学习斜率

nn.RReLU

lower: 均匀分布下限

upper:均匀分布上限

参考资料

深度之眼课程

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/943232.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智慧农业物联网传感器:开启农业新时代

在当今科技飞速发展的时代,农业领域正经历着一场前所未有的变革,而智慧农业物联网传感器无疑是这场变革中的关键利器。它宛如农业的 “智慧大脑”,悄然渗透到农业生产的各个环节,为传统农业注入了全新的活力,让农业生产…

OpenLayers实现渐变透明填充和光效边界

之前在cesium中做过多边形的填充使用渐变透明的效果,那个时候使用的是着色器,利用距离中心点的距离去写shader函数,距离中心越远颜色透明度越高,那么本文我们在openlayers中来实现这一过程。老规矩还是先来看一下效果: 好接下来开始讲述原理,首先关于边界发光的原理我在O…

低代码开发中 DDD 领域驱动的页面权限控制

在低代码开发的领域中,应用安全与灵活性是两大关键考量因素。领域驱动设计(DDD)作为一种在软件设计领域广泛应用且颇具影响力的方法论,正逐渐在低代码开发的页面权限控制方面展现出其独特的价值与潜力。本文旨在客观地探讨如何借助…

B端UI设计规范是什么?

一、B端UI设计规范是什么? B端UI设计规范是一套针对企业级应用界面设计的全面规则和标准,旨在确保产品界面的一致性、可用性和用户体验。 二、B端UI设计规范要素说明 B端UI设计的基本要素包括设计原则、主题、布局、颜色、字体、图标、按钮和控件、交互…

GitLab 服务变更提醒:中国大陆、澳门和香港用户停止提供服务(GitLab 服务停止)

目录 前言 一. 变更详情 1. 停止服务区域 2. 邮件通知 3. 新的服务提供商 4. 关键日期 5. 行动建议 二. 迁移指南 三. 注意事项 四. 相关推荐 前言 近期,许多位于中国大陆、澳门和香港的 GitLab 用户收到了一封来自 GitLab 官方的重要通知。根据这封邮件…

nginx Rewrite 相关功能

一、Nginx Rewrite 概述 定义 Nginx 的 Rewrite 模块允许对请求的 URI 进行重写操作。它可以基于一定的规则修改请求的 URL 路径,然后将请求定向到新的 URL 地址,这在很多场景下都非常有用,比如实现 URL 美化、网站重构后的 URL 跳转等。主要…

适用于Synology NAS的在线办公套件:ONLYOFFICE安装指南

使用 Synology NAS 上的 ONLYOFFICE 文档,您能在私有云中直接编辑文本文档、电子表格、演示文稿和 PDF,确保工作流程既安全又高效。本指南将分步介绍如何在 Synology 上安装 ONLYOFFICE 文档。 关于 Synology Synology NAS(网络附加存储&…

[按键精灵IOS安卓版][脚本基础知识]按键post基本写法

这一期我们来讲按键post的写法,希望通过本期的学习,实现常见的post提交都能编写。 下面开始讲解: 一、使用的命令:url.httppost 选用这个命令的理由是它的参数比较全。 二、post请求都有哪些参数(可能用到&#xf…

如何检查交叉编译器gcc工具链里是否有某个库(以zlib库和libpng库为例)

freetype 依赖于 libpng,libpng 又依赖于 zlib,所以我们应该:先编译 安装 zlib,再编译安装 libpng,最后编译安装 freetype。 但是,有些交叉编译器工具链里已经有 zlib库和freetype,所以我们需要…

3D几何建模引擎Parasolid功能解析

一、什么是Parasolid? Parasolid是由Siemens PLM Software开发的高精度精密几何建模引擎。它全面评估CAD(计算机辅助设计)、CAM(计算机辅助制造)、CAE(计算机辅助工程)、PLM(产品生…

基于STM32单片机矿井矿工作业安全监测设计

基于STM32单片机矿井矿工作业安全监测设计 目录 项目开发背景设计实现的功能项目硬件模块组成设计思路系统功能总结使用的模块技术详情介绍总结 1. 项目开发背景 随着矿井矿工作业环境的复杂性和危险性逐渐增加,矿井作业安全问题引起了社会各界的广泛关注。传统的…

linux-22 目录管理(二)rmdir命令,删除目录

那接下来我们来看看我们如何去删除目录?那接下来我们来看看我们如何去删除目录?叫remove,remove表示移除的意思,remove directory叫移除目录。所以简写为rmdir,但需要注意,它只能删除空目录,只能…

计算机考研选西电还是成电?

谢邀~先来个总结:电子科技大学计算机综合实力优于西安电子科技大学,但是,二者计算机学硕考研难度没有太大差距,而且考试难度也同属于一个水平,成电性价比更高一些!推荐同学优先报考作为985的电子科技大学&a…

基于YOLOV5+Flask安全帽RTSP视频流实时目标检测

1、背景 在现代工业和建筑行业中,安全始终是首要考虑的因素之一。特别是在施工现场,工人佩戴安全帽是确保人身安全的基本要求。然而,人工监督难免会有疏漏,尤其是在大型工地或复杂环境中,确保每个人都佩戴安全帽变得非…

oscp学习之路,Kioptix Level2靶场通关教程

oscp学习之路,Kioptix Level2靶场通关教程 靶场下载:Kioptrix Level 2.zip 链接: https://pan.baidu.com/s/1gxVRhrzLW1oI_MhcfWPn0w?pwd1111 提取码: 1111 搭建好靶场之后输入ip a看一下攻击机的IP。 确定好本机IP后,使用nmap扫描网段&…

yii2 手动添加 phpoffice\phpexcel

1.下载地址:https://github.com/PHPOffice/PHPExcel 2.解压并修改文件名为phpexcel 在yii项目的vendor目录下创建一个文件夹命名为phpoffice 把phpexcel目录放到phpoffic文件夹下 查看vendor\phpoffice\phpexcel目录下会看到这些文件 3.到vendor\composer目录下…

弹性盒子(display: flex)布局超全讲解|Flex 布局教程

文章目录 弹性盒子flex什么是弹性布局?弹性布局的特点?justify-contentalign-itemflex-direction (主轴的方向:水平或者垂直)flex-wrapflex-flowalign-contentflex-grow 属性flex-shrink 属性flex-basis 属性flex 属性align-self 属性 弹性盒…

基于c语言的union、字符串、格式化输入输出

结构体之共用体union 共用体也叫联合体,其关键词为union 与结构体不同的是,共用体所开辟的存储空间仅仅为其中最长类型变量的存储空间而不是全部变量的存储空间,由于同一内存单元在同一时间内只能存放其中一种的数据类型,因此在每…

【全栈开发】----用pymysql库连接MySQL,批量存入

本文基于前面的MySQL基础语句使用,还不会的宝子可以先回去看看: 全栈开发----Mysql基本配置与使用-CSDN博客 仅仅用控制台命令对数据库进行操作,虽然大部分操作都很简单,但对于大量数据的存入,存储数据将会变得很繁琐&…

PyQt实战——使用python提取JSON数据(十)

系类往期文章: PyQt5实战——多脚本集合包,前言与环境配置(一) PyQt5实战——多脚本集合包,UI以及工程布局(二) PyQt5实战——多脚本集合包,程序入口QMainWindow(三&…