7.3 详解NiN模型--首次使用多层感知机(1x1卷积核)替换掉全连接层的模型

一.前提知识

多层感知机:由一个输入层,一个或多个隐藏层和一个输出层组成。(至少有一个隐藏层,即至少3层)

全连接层:是MLP的一种特殊情况,每个节点都与前一层的所有节点连接,全连接层可以解决线性可分问题,无法学习到非线性特征。(只有输入和输出层)

二.NiN模型特点

NiN与过去模型的区别:AlexNet和VGG对LeNet的改进在于如何扩大加深这两个模块。他们都使用了全连接层,使用全连接层就可能完全放弃表征的空间结构。
NiN放弃了使用全连接层,而是使用两个1x1卷积层(将空间维度中的每个像素视为单个样本,将通道维度视为不同特征。),相当于在每个像素的通道上分别使用多层感知机

优点:NiN去除了全连接层,可以减少过拟合,同时显著减少NiN的参数数量

三.模型架构

在这里插入图片描述

四.代码

import torch
from torch import nn
from d2l import torch as d2l
import time
def nin_block(in_channels,out_channels,kernel_size,strides,padding):
    return nn.Sequential(
        # 卷积层
        nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),
        nn.ReLU(),
        # 两个带有ReLU激活函数的 1x1卷积层
        nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),
        nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU()
    )
net = nn.Sequential(
    nin_block(1,96,kernel_size=11,strides=4,padding=0),
    nn.MaxPool2d(3,stride=2),
    nin_block(96,256,kernel_size=5,strides=1,padding=2),
    nn.MaxPool2d(3,stride=2),
    nin_block(256,384,kernel_size=3,strides=1,padding=1),
    nn.MaxPool2d(3,stride=2),
    nn.Dropout(0.5),
    # 标签类别是10
    nin_block(384,10,kernel_size=3,strides=1,padding=1),
    # 二维自适应平均池化,不用指定池化窗口大小
    nn.AdaptiveAvgPool2d((1,1)),
    # 将(样本,通道,w,h) = (批量,10,1,1),四维的输出转成2维的输出,其形状为(批量大小,10)
    nn.Flatten()
)
X = torch.rand(size=(1,1,224,224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

六.不同参数训练结果

学习率是0.1的情况

# 训练模型
lr,num_epochs,batch_size = 0.1,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

学习率是0.05的情况(提升了6个点)

'''开始计时'''
start_time = time.time()
# 训练模型
lr,num_epochs,batch_size = 0.05,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
print(f'{round(run_time,2)}s')

在这里插入图片描述

学习率为0.01,批次等于30的情况(反而下降了)

在这里插入图片描述

思考

为什么NiN块中有两个1x1卷积层?

从NiN替换掉全连接层,使用多层感知机角度来说:
因为1个1x1卷基层相当于全连接层,两个1x1卷积层使输入和输出层中间有了隐藏层,才相当于多层感知机。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/73411.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

7.5.tensorRT高级(2)-RAII接口模式下的生产者消费者多batch实现

目录 前言1. RAII接口模式封装生产者消费者2. 问答环节总结 前言 杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。 本次课程学习 tensorRT 高级-RAI…

前后端分离------后端创建笔记(05)用户列表查询接口(上)

本文章转载于【SpringBootVue】全网最简单但实用的前后端分离项目实战笔记 - 前端_大菜007的博客-CSDN博客 仅用于学习和讨论,如有侵权请联系 源码:https://gitee.com/green_vegetables/x-admin-project.git 素材:https://pan.baidu.com/s/…

City Walk带动茶饮品牌售1200万,媒介盒子带你探究奥秘

年轻人生活趋势又出现了一个新鲜词——City Walk,简单来说,City Walk就是没有目的地,没有目标,只是出行,填充自己的生活。 其实这个词源于gap year,而这个说法一直是国外的一些毕业生,大多会在…

plt取消坐标轴刻度、自定义取消绘图边框(或坐标轴)、白边处理、自定义颜色图谱、设置坐标轴刻度朝向

目录 1、取消坐标轴刻度 2、自定义取消绘图边框(或坐标轴) 3、去掉图片周边白边 4、自定义颜色图谱 5、设置坐标轴刻度朝向 import matplotlib.pyplot as plt 1、取消坐标轴刻度 ax plt.subplot() ax.set_xticks([]) ax.set_yticks([]) 2、自定…

自定义 视频/音频 进度条

复制代码根据自己需求改动就可以了 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><metaname"viewport"conten…

模仿火星科技 基于cesium+水平面积测量+可编辑

​ 当您进入Cesium的编辑水平积测量世界&#xff0c;下面是一个详细的操作过程&#xff0c;帮助您顺利使用这些功能&#xff1a; 1. 创建提示窗&#xff1a; 启动Cesium应用&#xff0c;地图场景将打开&#xff0c;欢迎您进入编辑模式。 在屏幕的一角&#xff0c;一个友好的提…

计算机网络:网络通信相关概念入门

目录 一、网络发展背景二、理解网络通信三、理解IP地址1.简述IP地址2.IP地址的版本3.提高地址利用率的技术 四、理解端口1.简述端口2.使用端口的原因 五、理解网络通信协议 一、网络发展背景 网络发展背景&#xff1a; 最初的计算机是单机&#xff0c;那么单机是这样传输数据的…

【金融量化】Python实现根据收益率计算累计收益率并可视化

1 理论 理财产品&#xff08;本金100元&#xff09; 第1天&#xff1a;3% &#xff1a;&#xff08;13%&#xff09; ✖ 100 103 第2天&#xff1a;2% &#xff1a;&#xff08;12%&#xff09;✖ 以上 103 2.06 第3天&#xff1a;5% : &#xff08;15%&#xff09;✖ 以上…

企业直播MR虚拟直播(MR混合现实直播技术)视频介绍

到底什么是企业直播MR虚拟直播&#xff08;MR混合现实直播技术&#xff09;&#xff1f; 企业直播MR虚拟直播新玩法&#xff08;MR混合现实直播技术&#xff09; 我的文章推荐&#xff1a; [视频图文] 线上研讨会是什么&#xff0c;企业对内对外培训可以用线上研讨会吗&#x…

2023年新学期12306高铁火车学生票如何在线核验享受优惠?

2023学年优惠资质核验已开始&#xff0c;完成学生优惠资质核验后&#xff0c;您可以在线购买2022年10月1日至2023年9月30日的学生优惠票。&#xff08;注&#xff1a;非该时间段需要重新核验&#xff0c;可享受学生优惠票&#xff09;&#xff1b; 『扩展阅读』 1、美团外卖红…

QT:UI控件(按设计师界面导航界面排序)

基础部分 创建新项目&#xff1a;QWidget&#xff0c;QMainWindow&#xff0c;QDialog QMainWindow继承自QWidget&#xff0c;多了菜单栏; QDialog继承自QWidget&#xff0c;多了对话框 QMainWindow 菜单栏和工具栏&#xff1a; Bar: 菜单栏&#xff1a;QMenuBar&#xff0…

【Sklearn】基于随机森林算法的数据分类预测(Excel可直接替换数据)

【Sklearn】基于随机森林算法的数据分类预测&#xff08;Excel可直接替换数据&#xff09; 1.模型原理1.1 模型原理1.2 数学模型 2.模型参数3.文件结构4.Excel数据5.下载地址6.完整代码7.运行结果 1.模型原理 随机森林&#xff08;Random Forest&#xff09;是一种集成学习方法…

【深度学习】PyTorch快速入门

【深度学习】学习PyTorch基础 介绍PyTorch 深度学习框架是一种软件工具&#xff0c;旨在简化和加速构建、训练和部署深度学习模型的过程。深度学习框架提供了一系列的函数、类和工具&#xff0c;用于定义、优化和执行各种深度神经网络模型。这些框架帮助研究人员和开发人员专注…

RabbitMQ 安装教程

RabbitMQ 安装教程 特殊说明 因为RabbitMQ基于Erlang开发&#xff0c;所以安装时需要先安装Erlang RabbitMQ和Erlang版本对应关系 查看地址&#xff1a;www.rabbitmq.com/which-erlan… 环境选择 Erlang: 23.3及以上 RabbitMQ: 3.10.1Windows 安装 1. 安装Erlang 下载地…

Spark 学习记录

基础 SparkContext是什么&#xff1f;有什么作用&#xff1f; https://blog.csdn.net/Shockang/article/details/118344357 SparkContext 是什么&#xff1f; SparkContext 是通往 Spark 集群的唯一入口&#xff0c;可以用来在 Spark 集群中创建 RDDs 、累加和广播变量( Br…

【css】渐变

渐变是设置一种颜色或者多种颜色之间的过度变化。 两种渐变类型&#xff1a; 线性渐变&#xff08;向下/向上/向左/向右/对角线&#xff09; 径向渐变&#xff08;由其中心定义&#xff09; 1、线性渐变 语法&#xff1a;background-image: linear-gradient(direction, co…

黑客必备的操作系统——kali linux安装

大家经常会在电视里面看到各种炫酷的黑客操作&#xff0c;那么黑客一般用什么操作系统呢&#xff1f;今天小训带大家来安装黑客必备的kali linux-2022操作系统&#xff0c;有兴趣的一起来学习下吧&#xff01; 1、安装前准备 1.1 VMware下载 VMware官网下载&#xff1a; ht…

开源力量再现,国产操作系统商业化的全新探索

文章目录 1. 开源运动的兴起2. 开源力量的推动3. 国产操作系统的崭露头角3.1 国产操作系统有哪些 4.国产操作系统的商业化探索5.开源力量对国产操作系统商业化的推动 操作系统作为连接硬件、中间件、数据库、应用软件的纽带&#xff0c;被认为是软件技术体系中最核心的基础软件…

linux4.0新增32位ARM的系统调用

修改内核源码 Linux系统为每一个系统调用赋予一个系统调用号。当应用程序执行一个系统调用时&#xff0c;应用程序就可以知道执行和调用到哪个系统调用了&#xff0c;从而不会造成混乱。系统调用号一旦分配之后就不会有任何变更&#xff0c;否则已经编译好的应用程序就不能运行…