动手学深度学习(Pytorch版)代码实践 -卷积神经网络-26网络中的网络NiN

26网络中的网络NiN

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt

# 定义一个NiN块
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
    return nn.Sequential(
        # 传统的卷积层
        nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
        nn.ReLU(),  # 激活函数ReLU
        # 1x1卷积层
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),  
        # 另一个1x1卷积层
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU()   
    )

# 设置dropout的概率
dropout = 0.5 

# 定义NiN模型
net = nn.Sequential(
    # 第一个NiN块,输入通道数为1,输出通道数为96
    nin_block(1, 96, kernel_size=11, strides=4, padding=0),
    # 最大池化层
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 第二个NiN块,输入通道数为96,输出通道数为256
    nin_block(96, 256, kernel_size=5, strides=1, padding=2),
    # 最大池化层
    nn.MaxPool2d(kernel_size=3, stride=2),
    # 第三个NiN块,输入通道数为256,输出通道数为384
    nin_block(256, 384, kernel_size=3, strides=1, padding=1),
    # 最大池化层
    nn.MaxPool2d(kernel_size=3, stride=2),
    # Dropout层,用于防止过拟合
    nn.Dropout(dropout),

    # 最后一个NiN块,输入通道数为384,输出通道数为10
    nin_block(384, 10, kernel_size=3, strides=1, padding=1),
    # 全局平均池化层,将特征图的每个通道的空间维度调整为1x1
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten()
)

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
"""
Sequential output shape:         torch.Size([1, 96, 54, 54])
MaxPool2d output shape:  torch.Size([1, 96, 26, 26])
Sequential output shape:         torch.Size([1, 256, 26, 26])
MaxPool2d output shape:  torch.Size([1, 256, 12, 12])
Sequential output shape:         torch.Size([1, 384, 12, 12])
MaxPool2d output shape:  torch.Size([1, 384, 5, 5])
Dropout output shape:    torch.Size([1, 384, 5, 5])
Sequential output shape:         torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:  torch.Size([1, 10, 1, 1])
Flatten output shape:    torch.Size([1, 10])
"""

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size, resize=224)
lp.train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu())
plt.show()  # 显示绘图
# loss 0.342, train acc 0.873, test acc 0.871
# 1395.1 examples/sec on cuda:0

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/735606.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

兰州理工大学24计算机考研情况,好多专业都接受调剂,只有计算机专硕不接收调剂,复试线为283分!

兰州理工大学(Lanzhou University of Technology),位于甘肃省兰州市,是甘肃省人民政府、教育部、国家国防科技工业局共建高校,甘肃省高水平大学和“一流学科”建设高校;入选国家“中西部高校基础能力建设工…

年薪50w+的项目经理,手把手教你如何复盘

复盘是一种重要的学习和改进工具,对于项目经理来说,能帮助识别项目中的成功与失败,为未来的项目管理提供宝贵经验。 理论部分 定义目标。在开始复盘之前,明确复盘的目标是什么。是为了找出项目中的问题并提出解决方案&#xff0c…

Open MMLab 之 MMDetection3D框架

MMDetection框架入门教程(完全版)-CSDN博客 OpenMMLab MMDetection是商汤和港中文大学针对目标检测任务推出的一个开源项目,它基于Pytorch实现了大量的目标检测算法,把数据集构建、模型搭建、训练策略等过程都封装成了一个个模块…

Chromium 调试指南2024 Mac篇 - 准备工作 (一)

1.引言 Chromium是一个由Google主导开发的开源浏览器项目,它为Google Chrome浏览器提供了基础框架。Chromium不仅是研究和开发现代浏览器技术的重要平台,还为众多其他基于Chromium的浏览器(如Microsoft Edge、Brave等)提供了基础…

基于Openmv的色块识别代码及注意事项

在给出代码之前我先说注意事项以及需要用到的函数 1、白平衡和自动增益的关闭 打开白平衡和自动增益会影响颜色识别的效果,具体影响体现在可能使你颜色阈值发生改变 关闭代码如下 sensor.set_auto_gain(False) #关闭自动增益 sensor.set_whitebal(False) …

【笔记】HashMap的头插死循环问题

HashMap头插死循环是指在JDK1.7中,多线程环境下,HashMap进行扩容时由于多个线程一起执行扩容,可能会导致某一结点被错误插入头部并形成一个循环链表。 发生死循环的源码如下: // hashmap由数组链表构成 void transfer(Entry[] ne…

MySQL进阶——触发器

目录 1介绍 2语法 3案例 3.1 insert插入数据类型 3.2 update修改数据类型 3.3 delete删除数据类型 4视图/存储过程/触发器—小结 1介绍 触发器是与表有关的数据库对象,指在insert/update/delete之前(BEFORE)或之后(AFTER),触发并执行触发器中定义…

mysql启动时遇到:本地计算机上的MySQL服务启动后停止

1.问题重述: 今早启动数据库时发现无法启动,报错:本地计算机 上的 MySQL服务启动后停止。某些服务在未由其他服务或程序使用时将自动停止。 2.解决方案: 1.数据备份: 2.在bin目录下,命令行中输入 mysqld …

认识微服务

单体架构 单体架构:将业务的所有功能集中在一个项目中开发,打成一个包部署。 优点: 架构简单部署成本低缺点: 团队协作成本高系统发布效率低系统可用性差 总结: 单体架构适合开发功能相对简单,规模较小…

上位机图像处理和嵌入式模块部署(mcu开发注意事项)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 mcu的种类很多,生产的厂家很多。这里面有进口的,有国产的,并且国产替代的趋势越来越明显。但是,不管…

我国人工智能核心产业规模近6000亿元

以下文章来源:中国证券报 2024世界智能产业博览会6月20日至6月23日在天津举行。会上发布的《中国新一代人工智能科技产业发展报告2024》显示,我国人工智能企业数量已经超过4000家,人工智能已成为新一轮科技革命和产业变革的重要驱动力量和战略…

SAP PP学习笔记22 - 生产订单(制造指图)的元素1

前面几章讲了PP 里面生产计划的各种策略以及策略的Customize。 SAP PP学习笔记20 - 复习总结一下MTS,MTO,ATO的各种生产策略-CSDN博客 SAP PP学习笔记21 - 计划策略的Customize:策略组 > 策略 > 需求类型 &#xff1…

Java——集合(一)

前言: Collection集合,List集合 文章目录 一、Collection 集合1.1 集合和数组的区别1.2 集合框架1.3 Collection 集合常用方法1.4 Collction 集合的遍历 二、List 集合2.1 List 概述2.2 List集合的五种遍历方式2.3 List集合的实现类 一、Collection 集合 1.1 集合和…

正则表达式,linux文本三剑客

正则表达式匹配的是文本内容,linux的文本三剑客都是针对文本内容,按行进行匹配 文本三剑客: grep 过滤文本内容 sed 针对文本内容进行增删改查 awd 按行取列 一.grep命令 作用就是使用正则表达式来匹配文本内容 -m 数字:匹配…

什么是深度神经网络?与深度学习、机器学习、人工智能的关系是什么?

什么是深度神经网络?与深度学习、机器学习、人工智能的关系是什么? 🤖什么是深度神经网络?与深度学习、机器学习、人工智能的关系是什么?摘要引言正文内容1. 什么是深度神经网络?🧠1.1 深度神经…

git拉取gitee项目到本地

git安装等不做赘述。 根据需要选择不同操作 1.只是单纯拉取个项目,没有后续的追踪等操作 不需要使用git init初始化本地文件夹 新建一个文件夹用于存储项目,右键选择 git bash here 会出现命令行窗口 如果像我一样,只是拉取个项目作业&…

stm32学习笔记---GPIO输入(代码部分)按键控制LED/光敏传感器控制蜂鸣器

目录 第一个代码:按键控制LED 模块化程序 LED驱动程序 GPIO的四个读取函数 GPIO_ReadInputDataBit GPIO_ReadInputData GPIO_ReadOutputDataBit GPIO_ReadOutputData Key驱动程序 第二个代码:光敏传感器控制蜂鸣器 蜂鸣器驱动代码 光敏传感器…

[C++][设计模式][模板方法]详细讲解

目录 1.动机2.理解1.设计流程对比1.结构化软件设计流程2.面向对象软件设计流程 2.早绑定与晚绑定 3.模式定义4.要点总结5.代码感受1.代码一 -- 结构化1.lib.cpp2.app.cpp 2.代码二 -- 面向对象1.lib.cpp2.app.cpp 1.动机 在软件构建过程中,对于某一项任务&#xff…

“实习情报官”上万月薪!量化私募发力学界

选拔要求有“门道” 量化私募的招聘又有新花样! 顶尖高校的数理人才一直是量化机构紧盯的“资源”,包括各类奥赛奖牌得主、理科状元。 最头部的量化机构,也一直是高校学霸们流向的去处。 但是,当业内已有近30家百亿量化私募机…

LLVM——安装多版本LLVM和Clang并切换使用(Ubuntu)

1、描述 本机(Ubuntu22)已经安装了LLVM-14,但是需要使用LLVM-12。安装LLVM-12和Clang-12并切换使用。 2、过程 安装LLVM-12和Clang-12。 sudo apt-get install llvm-12 sudo apt-get install clang-12 【注】运行 sudo apt-get install ll…