【深度学习笔记】5_8 网络中的网络NiN

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.8 网络中的网络(NiN)

前几节介绍的LeNet、AlexNet和VGG在设计上的共同之处是:先以由卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。其中,AlexNet和VGG对LeNet的改进主要在于如何对这两个模块加宽(增加通道数)和加深。本节我们介绍网络中的网络(NiN)[1]。它提出了另外一个思路,即串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络。

5.8.1 NiN块

我们知道,卷积层的输入和输出通常是四维数组(样本,通道,高,宽),而全连接层的输入和输出则通常是二维数组(样本,特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变换为四维。回忆在5.3节(多输入通道和多输出通道)里介绍的 1 × 1 1\times 1 1×1卷积层。它可以看成全连接层,其中空间维度(高和宽)上的每个元素相当于样本,通道相当于特征。因此,NiN使用 1 × 1 1\times 1 1×1卷积层来替代全连接层,从而使空间信息能够自然传递到后面的层中去。图5.7对比了NiN同AlexNet和VGG等网络在结构上的主要区别。

在这里插入图片描述

图5.7 左图是AlexNet和VGG的网络结构局部,右图是NiN的网络结构局部

NiN块是NiN中的基础块。它由一个卷积层加两个充当全连接层的 1 × 1 1\times 1 1×1卷积层串联而成。其中第一个卷积层的超参数可以自行设置,而第二和第三个卷积层的超参数一般是固定的。

import time
import torch
from torch import nn, optim

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
                        nn.ReLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size=1),
                        nn.ReLU(),
                        nn.Conv2d(out_channels, out_channels, kernel_size=1),
                        nn.ReLU())
    return blk

5.8.2 NiN模型

NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之处。NiN使用卷积窗口形状分别为 11 × 11 11\times 11 11×11 5 × 5 5\times 5 5×5 3 × 3 3\times 3 3×3的卷积层,相应的输出通道数也与AlexNet中的一致。每个NiN块后接一个步幅为2、窗口形状为 3 × 3 3\times 3 3×3的最大池化层。

除使用NiN块以外,NiN还有一个设计与AlexNet显著不同:NiN去掉了AlexNet最后的3个全连接层,取而代之地,NiN使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。

# 已保存在d2lzh_pytorch
import torch.nn.functional as F
class GlobalAvgPool2d(nn.Module):
    # 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现
    def __init__(self):
        super(GlobalAvgPool2d, self).__init__()
    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=x.size()[2:])

net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2), 
    nn.Dropout(0.5),
    # 标签类别数是10
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    GlobalAvgPool2d(), 
    # 将四维的输出转成二维的输出,其形状为(批量大小, 10)
    d2l.FlattenLayer())

我们构建一个数据样本来查看每一层的输出形状。

X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children(): 
    X = blk(X)
    print(name, 'output shape: ', X.shape)

输出:

0 output shape:  torch.Size([1, 96, 54, 54])
1 output shape:  torch.Size([1, 96, 26, 26])
2 output shape:  torch.Size([1, 256, 26, 26])
3 output shape:  torch.Size([1, 256, 12, 12])
4 output shape:  torch.Size([1, 384, 12, 12])
5 output shape:  torch.Size([1, 384, 5, 5])
6 output shape:  torch.Size([1, 384, 5, 5])
7 output shape:  torch.Size([1, 10, 5, 5])
8 output shape:  torch.Size([1, 10, 1, 1])
9 output shape:  torch.Size([1, 10])

5.8.3 获取数据和训练模型

我们依然使用Fashion-MNIST数据集来训练模型。NiN的训练与AlexNet和VGG的类似,但这里使用的学习率更大。

batch_size = 128
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)

lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0101, train acc 0.513, test acc 0.734, time 260.9 sec
epoch 2, loss 0.0050, train acc 0.763, test acc 0.754, time 175.1 sec
epoch 3, loss 0.0041, train acc 0.808, test acc 0.826, time 151.0 sec
epoch 4, loss 0.0037, train acc 0.828, test acc 0.827, time 151.0 sec
epoch 5, loss 0.0034, train acc 0.839, test acc 0.831, time 151.0 sec

小结

  • NiN重复使用由卷积层和代替全连接层的 1 × 1 1\times 1 1×1卷积层构成的NiN块来构建深层网络。
  • NiN去除了容易造成过拟合的全连接输出层,而是将其替换成输出通道数等于标签类别数的NiN块和全局平均池化层。
  • NiN的以上设计思想影响了后面一系列卷积神经网络的设计。

参考文献

[1] Lin, M., Chen, Q., & Yan, S. (2013). Network in network. arXiv preprint arXiv:1312.4400.


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/434019.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vue.js+SpringBoot开发森林火灾预警系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 系统基础模块2.3 烟雾传感器模块2.4 温度传感器模块2.5 历史记录模块2.6 园区数据模块 三、系统设计3.1 用例设计3.1.1 森林园区基础系统用例设计3.1.2 森林预警数据用例设计 3.2 数据库设计3.2.1 烟雾…

[PyQt5]PyQt5连接MYSQL时显示Driver not loaded解决方案

在第一次用PyQt5的 QSqlDatabase.addDatabase 连接mysql的时候,可能会出现Driver not loaded的情况,如下: from PyQt5.QtSql import QSqlQuery, QSqlDatabase from PyQt5.QtWidgets import QApplication import sysapp QApplication(sys.ar…

2024.3.6

#include<myhead.h>int do_add(sqlite3* ppDb) {int numb;char name[10];int salary;printf("请输入员工的信息&#xff1a;");scanf("%d %s %d",&numb, name, &salary);//将员工的信息存储到数组char sql_add[128] "";sprintf(s…

【K哥爬虫普法】二十五岁 人大本硕 腾讯在职 爬虫被捕!

我国目前并未出台专门针对网络爬虫技术的法律规范&#xff0c;但在司法实践中&#xff0c;相关判决已屡见不鲜&#xff0c;K 哥特设了“K哥爬虫普法”专栏&#xff0c;本栏目通过对真实案例的分析&#xff0c;旨在提高广大爬虫工程师的法律意识&#xff0c;知晓如何合法合规利用…

你知道katalon studio 如何完成 get/post 请求发送吗?

katalon studio作为目前最火的自动化测试工具之一&#xff0c;不仅仅只能完成webUI自动化&#xff0c;更是能完成api、app以及桌面应用程序的自动化测试。 本文将讲解一下katalon studio是如果完成接口测试的。 请求发送 get请求 1、先在object repository里new一个请求 2、…

CSS盒子模型笔记

尚硅谷学习视频链接&#xff1a;117_CSS_盒子模型的组成部分_哔哩哔哩_bilibili 1、盒子组成 盒子组成 content内容 padding border &#xff08;margin不包含在盒子内&#xff09; 2、div样式width、height 当css3属性box-sizingcontent-box&#xff08;默认&#xff0…

64位Office API声明语句第116讲

跟我学VBA&#xff0c;我这里专注VBA, 授人以渔。我98年开始&#xff0c;从源码接触VBA已经20余年了&#xff0c;随着年龄的增长&#xff0c;越来越觉得有必要把这项技能传递给需要这项技术的职场人员。希望职场和数据打交道的朋友&#xff0c;都来学习VBA,利用VBA,起码可以提高…

逻辑代数基础(二)(卡诺图)

目录 逻辑图表示 卡诺图表示 卡诺图的标准格式 二变量卡诺图 三变量卡诺图 四变量卡诺图 卡诺图表示逻辑函数 从逻辑表达式到卡诺图 逻辑代数的三个规则 代入规则 反演规则 对偶规则 逻辑函数的化简方式 化简逻辑函数的意义 逻辑函数最简表示式的判别标准 公式化简法 并…

【Software Platform Bundle】

https://www.ni.com/zh-cn/support/downloads/software-products/download.software-platform-bundle.html

【python】六个常见爬虫案例【附源码】

大家好&#xff0c;我是博主英杰&#xff0c;整理了几个常见的爬虫案例&#xff0c;分享给大家&#xff0c;适合小白学习 一、爬取豆瓣电影排行榜Top250存储到Excel文件 近年来&#xff0c;Python在数据爬取和处理方面的应用越来越广泛。本文将介绍一个基于Python的爬虫程序&a…

linux系统的进程管理

文章目录 前言一、系统的进程的运转方式1、系统时间&#xff1a;&#xff08;jiffies 系统滴答&#xff09;2、task_struct 二、如何创建一个新的进程&#xff08;重要&#xff09;三、进程调度①、主要函数②、辅助函数 四、进程的退出内核的销毁 前言 本文讲解系统的进程管理…

LeetCode142题:环形链表II(python3)

代码思路&#xff1a; 双指针的第一次相遇&#xff1a; 设两指针 fast&#xff0c;slow 指向链表头部 head 。 令 fast 每轮走 2 步&#xff0c;slow 每轮走 1 步。 fast 指针走过链表末端&#xff0c;说明链表无环&#xff0c;此时直接返回 null。 如果链表存在环&#xff0c;…

学习Java的第二天

如何使用文本文档在cmd里打印出HelloWorld 1、创建一个文本文档&#xff0c;并命名为HelloWorld&#xff0c;将后缀改为java&#xff08;需要自己去把后缀打开显示出来&#xff09; 2、打开编辑 也可以双击打开 3、在里面写出以下代码 上面红框里为你要打印的语句&#xff0c;…

Mybatis-Plus——05,乐观锁(新注解)

乐观锁&#xff08;新注解&#xff09; 一、数据库添加一个字段二、实体类添加version注解三、注册乐观锁插件四、测试一下4.1成功的乐观锁4.2失败的乐观锁————————创作不易&#xff0c;笔记不易&#xff0c;如觉不错&#xff0c;请三连&#xff0c;谢谢~~ 乐观锁实现方…

zabbix监控中间件服务

zabbix监控Nginx 自定义nginx访问量的监控项&#xff0c;首先要通过脚本将各种状态的值取出来&#xff0c;然后通过zabbix监控。找到自定义脚本上传到指定目录/etc/zabbix/script/ 在zbx-client客户端主机操作 #创建目录&#xff0c;然后将脚本上传到该目录mkdir /etc/zabbix/…

信息安全与阿里云等保三级方案实践总结

信息安全在当今数字化时代变得至关重要&#xff0c;企业和组织需要采取有效措施来保护其数据和信息资产。阿里云作为中国领先的云服务提供商&#xff0c;提供了等保三级方案&#xff0c;帮助用户满足国家信息安全等级保护的要求。本文将探讨信息安全和阿里云等保三级方案的重要…

使用vite创建一个vue3项目

创建一个vue3项目 1.使用命令npm create vuelatest来创建一个vue3项目&#xff0c;注意&#xff1a;官网说明了必须node版本是18及以上的&#xff0c;这边需要注意下 2.然后根据提示进入项目目录 先npm install安装依赖&#xff0c;然后npm run dev启动项目 大家可以看到&am…

TCP与UDP基础

思维导图&#xff1a; TCP&#xff1a; 服务器 #include<myhead.h> #define SER_IP "192.168.252.163" #define SER_PORT 6666 int main(int argc, const char *argv[]) {//&#xff11;、创建用于监听的套接字int sfd-1;sfdsocket(AF_INET,SOCK_STREAM,0);/…

2023 电脑PC FetchV网页视频下载器 浏览器插件

FetchV&#xff0c;它可以下载网页视频&#xff0c;下载速度快到离谱&#xff0c;非常好用&#xff01; FetchV:网页视频下载器(HLS|m3u8|mp4|blob) 下载和录制各种格式的在线网页视频&#xff0c;包括HLS、m3u8、blob、mp4、webm等各种类型的视频。 这是一个通用的网页视频…

保留数据的重装系统教程!(win11系统)

上车警告&#xff01;&#xff01;&#xff01; 本教程无需思考&#xff0c;跟着操作一步一步来就能完成系统的重装。原理是将C盘系统重装&#xff0c;其他盘符数据保存。适用于系统盘重装数据或更改系统版本。 重要提示&#xff01;&#xff01;&#xff01; C盘有重要学习资…