动手学深度学习(pytorch)学习记录30-含并行连接的网络(GoogLeNet)[学习记录]

目录

  • GoogLeNet
  • Inception块
  • GoogLeNet模型
  • 训练模型

GoogLeNet

GoogLeNet,也称为Inception v1,是由Google团队在2014年提出的深度学习模型,它在当年的ImageNet竞赛中取得了显著的成绩。GoogLeNet的设计引入了多个创新点,包括Inception模块、辅助分类器、全局平均池化层等,这些设计使得网络在保持深度的同时减少了参数数量和计算复杂度。

Inception模块是GoogLeNet的核心,它通过并行的方式使用不同尺寸的卷积核(1x1、3x3、5x5)和最大池化层来提取特征,然后将这些特征在通道维度上进行拼接。这种设计允许网络在不同的尺度上捕捉信息,并且通过1x1卷积进行降维,有效控制了参数数量和计算量。

GoogLeNet还引入了辅助分类器,这些分类器在训练过程中提供额外的梯度信号,有助于模型的收敛,并在一定程度上提高了最终的分类性能。

此外,GoogLeNet在最后一层使用了全局平均池化层代替传统的全连接层,这不仅进一步减少了参数数量,还提高了模型的泛化能力。在输出层之前,GoogLeNet还使用了Dropout技术来防止过拟合。

GoogLeNet的网络结构设计非常灵活,可以根据不同的需求调整Inception模块中的卷积层数量和通道数。这种设计使得GoogLeNet在图像分类任务中表现出色,同时也为后续的深度学习模型设计提供了重要的参考。

在实际应用中,GoogLeNet的变种如Inception v2、Inception v3等在原有的基础上进行了进一步的优化和改进,例如引入了批量归一化(Batch Normalization)和残差连接(Residual Connections),以提高训练效率和模型性能。

本文只介绍初代版本,后续会介绍改进版本。

Inception块

如图所示,一个块由四条并行路径组成,每条路径选用大小不同的卷积层,以实现从不同空间大小中提取信息。
中间两路在输入上使用1×1卷积核,减少像素级上的通道维数。这些通路使用合适的填充,使得输出尺寸一致,最后能够在输出通道维度上合并。
Inception块中,通常调整的超参数是每层输出通道数。

在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l


class Inception(nn.Module):
    # c1--c4是每条路径的输出通道数
    def __init__(self, in_channels, c1, c2, c3, c4, **kwargs):
        super(Inception, self).__init__(**kwargs)
        # 线路1,单1x1卷积层
        self.p1_1 = nn.Conv2d(in_channels, c1, kernel_size=1)
        # 线路2,1x1卷积层后接3x3卷积层
        self.p2_1 = nn.Conv2d(in_channels, c2[0], kernel_size=1)
        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)
        # 线路3,1x1卷积层后接5x5卷积层
        self.p3_1 = nn.Conv2d(in_channels, c3[0], kernel_size=1)
        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)
        # 线路4,3x3最大汇聚层后接1x1卷积层
        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.p4_2 = nn.Conv2d(in_channels, c4, kernel_size=1)

    def forward(self, x):
        p1 = F.relu(self.p1_1(x))
        p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))
        p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))
        p4 = F.relu(self.p4_2(self.p4_1(x)))
        # 在通道维度上连结输出
        return torch.cat((p1, p2, p3, p4), dim=1)

GoogLeNet模型

GoogleLeNet共使用9个Inception块和一个全局平均汇聚层的堆叠来生成估计值。
在这里插入图片描述

# 第一模块 64通道、7×7卷积层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# 第二模块 和Inception的第二条路径一样
b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),
                   nn.ReLU(),
                   nn.Conv2d(64, 192, kernel_size=3, padding=1),
                   nn.ReLU(),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

第三个模块串联两个完整的lnception块。第一个Inception块的输出通道数为64+128+32+32=256,四个路径之间的输出通道数量比为64:128:32:32 =2:4:1:1。第二个和第三个路径首先将输入通道的数量分别减少到96/192= 1/2和16/192 = 1/12,然后连接第二个卷积层。第二个|nception块的输出通道数增加到128+192+96+64=480,四个路径之间的输出通道数量比为128:192:96:64=4:6:3:2。第二条和第三条路径首先将输入通道的数量分别减少到128/256=1/2和32/256=1/8。

# 第三模块
b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),
                   Inception(256, 128, (128, 192), (32, 96), 64),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

第四模块更加复杂,它串联了5个Inception块,其输出通道数分别是
192+208+48+64=512、
160+224+64+64=512、
128+256+64+64=512、
112+288+64+64=528和
256+320+128+128 =832。
这些路径的通道数分配和第三模块中的类似,首先是含3x3卷积层的第二条路径输出最多通道,其次是仅含1x1卷积层的第一条路径,之后是含5x5卷积层的第三条路径和含3x3最大汇聚层的第四条路径。 其中第二、第三条路径都会先按比例减小通道数。这些比例在各个Inception块中都略有不同。

# 第四模块
b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),
                   Inception(512, 160, (112, 224), (24, 64), 64),
                   Inception(512, 128, (128, 256), (24, 64), 64),
                   Inception(512, 112, (144, 288), (32, 64), 64),
                   Inception(528, 256, (160, 320), (32, 128), 128),
                   nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

第五模块包含输出通道数为256+320+128+128=832和384+384+128+128=1024的两个lnception块。其中每条路径通道数的分配思路和第三、第四模块中的一致,只是在具体数值上有所不同。 需要注意的是,第五模块的后面紧跟输出层,该模块同NiN一样使用全局平均汇聚层,将每个通道的高和宽变成1。 最后,将输出变成二维数组,再接上一个输出个数为标签类别数的全连接层。

# 第五模块
b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),
                   Inception(832, 384, (192, 384), (48, 128), 128),
                   nn.AdaptiveAvgPool2d((1,1)),
                   nn.Flatten())

net = nn.Sequential(b1, b2, b3, b4, b5, nn.Linear(1024, 10))

GoogLeNet模型的计算复杂,而且不如VGG那样便于修改通道数。

X = torch.rand(size=(1, 1, 96, 96))
for layer in net:
    X = layer(X)
    print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 64, 24, 24])
Sequential output shape:	 torch.Size([1, 192, 12, 12])
Sequential output shape:	 torch.Size([1, 480, 6, 6])
Sequential output shape:	 torch.Size([1, 832, 3, 3])
Sequential output shape:	 torch.Size([1, 1024])
Linear output shape:	 torch.Size([1, 10])

训练模型

使用Fashion-MNIST数据集来训练我们的模型。

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

在这里插入图片描述

· 本文使用了d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码
封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/875180.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++】string类的基本使用

一、string类的由来 在C语言中,字符串是以\0结尾的一些字符的集合,为了操作方便,C标准库中提供了一些str系列 的库函数,但是这些库函数与字符串是分离开的,不太符合OOP的思想,而且底层空间需要用户 自己管…

大数据之Flink(三)

9.3、转换算子 9.3.1、基本转换算子 9.3.1.1、映射map 一一映射 package transform;import bean.WaterSensor; import org.apache.flink.streaming.api.datastream.DataStreamSource; import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; impor…

分类预测|基于改进的灰狼IGWO优化支持向量机SVM的数据分类预测matlab程序 改进策略:Cat混沌与高斯变异

分类预测|基于改进的灰狼IGWO优化支持向量机SVM的数据分类预测matlab程序 改进策略:Cat混沌与高斯变异 文章目录 一、基本原理原理流程1. **定义目标函数**2. **初始化GWO**3. **评估适应度**4. **更新狼的位置**5. **更新狼的等级**6. **重复迭代**7. **选择最佳解…

春招审核新策略:Spring Boot系统实现

3系统分析 3.1可行性分析 通过对本大学生入学审核系统实行的目的初步调查和分析,提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本大学生入学审核系统采用Spring Boot框架,JA…

综合案例-数据可视化-柱状图

一、基础柱状图 我们绘制一个关于三种水果销售额的柱状图,X轴数据为三种水果的名称,用列表[苹果,香蕉,橘子]添加进去,Y轴数据为三种水果的销售额,用列表[50,70,60]添加进去。 步骤: 导包构建柱状图对象添加X轴数据生…

Android 12系统源码_窗口管理(八)WindowConfiguration的作用

前言 在Android系统中WindowConfiguration这个类用于管理与窗口相关的设置,该类存储了当前窗口的显示区域、屏幕的旋转方向、窗口模式等参数,应用程序通过该类提供的信息可以更好的适配不同的屏幕布局和窗口环境,以提高用户体验。 一、类定…

喜报 | 知从科技荣获 “AutoSec 安全之星 - 优秀汽车软件供应链安全方案奖”

近日,「AutoSec 2024第八届中国汽车网络安全周暨第五届智能汽车数据安全展」在上海盛大举行。本届大会由谈思实验室和谈思汽车主办、上海市车联网协会联合主办,以汽车“网络数据安全、软件安全、功能安全”为主题,设置了“31X”模式&#xff…

关于 PC打开“我的电脑”后有一些快捷如腾讯视频、百度网盘、夸克网盘、迅雷等各种捷方式在磁盘驱动器上面统一删除 的解决方法

若该文为原创文章,转载请注明原文出处 本文章博客地址:https://hpzwl.blog.csdn.net/article/details/142029325 长沙红胖子Qt(长沙创微智科)博文大全:开发技术集合(包含Qt实用技术、树莓派、三维、OpenCV…

numpy(基于Numpy外文文档的学习)

学习目标: Understand the difference between one-, two- and n-dimensional arrays in NumPy; Understand how to apply some linear algebra operations to n-dimensional arrays without using for-loops;(调用一些简单的方法) Underst…

外包干了三年,快要废了。。。

先简单说一下自己的情况,普通本科,在外包干了3年多的功能测试,这几年因为大环境不好,我整个人心惊胆战的,怕自己卷铺盖走人了,我感觉自己不能够在这样蹉跎下去了,长时间呆在一个舒适的环境真的会…

Docker 部署 Redis (图文并茂超详细)

部署 Redis ( Docker ) [Step 1] : 拉取 Redis 镜像, 推荐使用 7 的 Redis 版本 docker pull redis:7.0.12[Step 2] : 创建 Redis 相关目录 ➡️ 启动 Redis 容器 ➡️ 拷贝文件 ➡️ 授权文件夹 ➡️ 删除容器 # 创建 Redis 相关目录 mkdir -p /data/redis/{conf,data,log…

写的一致性问题之失效模式

文章目录 1、先删除redis缓存,再写入mysql:1.1、高并发情况下分析出现的问题 2、先写入mysql,再删除redis缓存 失效模式存在的问题:在事务提交之前可能会有其他读操作重新把旧数据放入redis缓存中 1、先删除redis缓存,…

深入解析全连接层:PyTorch 中的 nn.Linear、nn.Parameter 及矩阵运算

文章目录 数学概念(全连接层,线性层)nn.Linear()nn.Parameter()Q1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_featur…

Bev pool 加速(2):自定义c++扩展

文章目录 1. c++扩展2. 案例2.1 案例12. 1.1 代码实现(1) c++ 文件(2) setup.py编写(3) python 代码编写2.2 案例22.2.1 模型搭建2.2.2 c++ 扩展实现(1)c++ 扩展代码(2)setup.py编写(3)python 调用c++扩展在bevfusion论文中,将bev_pooling定义为view transform中的效率瓶…

PROTOTYPICAL II - The Practice of FPGA Prototyping for SoC Design

The Art of the “Start” The semiconductor industry revolves around the “start.” Chip design starts lead to more EDA tool purchases, more wafer starts, and eventually to more product shipments. Product roadmaps develop to extend shipments by integrating…

FloodFill算法

文章目录 1. 图像渲染(733)2. 岛屿数量(200)3. 岛屿的最大面积(695)4. 被围绕的区域(130) 1. 图像渲染(733) 题目描述: 算法原理: …

DAY13信息打点-Web 应用源码泄漏开源闭源指纹识别GITSVNDS备份

#知识点 0、Web架构资产-平台指纹识别 1、开源-CMS指纹识别源码获取方式 2、闭源-习惯&配置&特性等获取方式 3、闭源-托管资产平台资源搜索监控 演示案例: ➢后端-开源-指纹识别-源码下载 ➢后端-闭源-配置不当-源码泄漏 ➢后端-方向-资源码云-源码泄漏 …

1、https的全过程

目录 一、概述二、SSL过程如何获取会话秘钥1、首先认识几个概念:2、没有CA机构的SSL过程:3、没有CA机构下的安全问题4、有CA机构下的SSL过程 一、概述 https是非对称加密和对称加密的过程,首先建立https链接需要经过两轮握手: T…

算法提高模板强连通分量tarjan算法

AC代码&#xff1a; #include<bits/stdc.h>using namespace std;typedef long long ll; const int MOD 998244353; const int N 2e5 10;//强联通分量模板 //tarjan算法 vector<int>e[N]; int n, m, cnt; int dfn[N], low[N], ins[N], idx; int bel[N];//记录每…

Redis高可用,Redis性能管理

文章目录 一&#xff0c;Redis高可用&#xff0c;Redis性能管理二&#xff0c;Redis持久化1.RDB持久化1.1触发条件&#xff08;1&#xff09;手动触发&#xff08;2&#xff09;自动触发 1.2 Redis 的 RDB 持久化配置1.3 RDB执行流程(1) 判断是否有其他持久化操作在执行(2) 父进…