SqueezeNet模型详解

简介

SqueezeNet是一种轻量级卷积神经网络架构,旨在保持较高性能的同时减少模型的参数数量和计算复杂度。由于其小尺寸和高效性能,SqueezeNet适用于在资源受限的环境中部署,如移动设备和嵌入式系统。

SqueezeNet是通过使用一种"Fire Module"的结构来减小网络的深度和参数量。Fire Module包含一个称为"Squeeze"层和一个称为"Expand"层的组合。"Squeeze"层主要执行特征压缩,通过使用1x1卷积核来减小通道数量,从而降低模型的参数数量。"Expand"层使用1x1和3x3卷积核来增加通道数量,以产生更丰富的特征表示。

整个SqueezeNet网络由多个这样的Fire Modules组成,其中间使用了池化层来减小空间尺寸。

论文地址:https://arxiv.org/abs/1602.07360

Fire Module

它的结构图如下所示:

Fire 模块它是通过使用 1x1 的卷积核来进行压缩。1x1 卷积可以降低输入特征图的通道数,从而减少每个位置上的参数数量。如果一个模型中有很多层使用了 1x1 卷积,整个模型的参数数量会相对较小。同时,1x1 卷积通常比较轻量级,计算量相对较小。这也就意味着计算复杂度的减小,因为计算复杂度与参数数量相关。最后将经过不同卷积操作得到的特征在通道维度上拼接在一起。

class Fire(nn.Module):
    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes,
                                   kernel_size=1)
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes,
                                   kernel_size=3, padding=1)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.activation(self.squeeze(x))
        return torch.cat([
            self.activation(self.expand1x1(x)),
            self.activation(self.expand3x3(x))
        ], dim=1)

我想大家使用Conv+BN+ReLU使用习惯了,这里面没有使用BN是因为会引入额外的参数和计算,像这种轻量级的网络结构中,可能不太需要 BN 层,因为模型本身较浅,梯度的变化也不会过于剧烈。

SqueezeNet的两个版本的特点

以下是一些个人的见解:

卷积核的大小和步幅,在1_0版本当中第一个卷积层使用的是核大小为7,而1_1版本当中的卷积核为3,这可能导致更大的感受野。

输出通道数,在1_0版本中,第一个卷积层输出通道为 96,而1_1版本中为 64。这也意味着1_1版本的模型整体更加轻量。1_0 Total params: 737,476,1_1 Total params: 724,548

采用了 Dropout 在训练期间有助于防止过拟合。

class SqueezeNet(nn.Module):
    def __init__(self, version='1_0', num_classes=1000):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == '1_0':
            self.features = nn.Sequential(
                nn.Conv2d(3, 96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == '1_1':
            self.features = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            raise ValueError("Unsupported SqueezeNet version {version}:"
                             "1_0 or 1_1 expected".format(version=version))

        # Final convolution is initialized differently from the rest
        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            final_conv,
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if m is final_conv:
                    init.normal_(m.weight, mean=0.0, std=0.01)
                else:
                    init.kaiming_uniform_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return torch.flatten(x, 1)

 结构图如下所示:

下面是论文作者提供的一些细节和设计的选择

  • 1x1和3x3滤波器的输出尺寸一致性:为了确保1x1和3x3滤波器的输出在高度和宽度上一致,作者在expand模块的3x3滤波器的输入数据周围添加了1像素的零填充。
  • 激活函数的选择:使用了ReLU作为squeeze和expand层的激活函数。
  • Dropout的应用:在fire9模块之后应用了Dropout,丢弃率为50%。
  • 无全连接层:SqueezeNet没有使用全连接层,这个设计灵感来源于NiN(Lin et al., 2013)架构。
  • 学习率和训练策略:在训练SqueezeNet时,初始学习率为0.04,通过线性递减学习率的方式进行训练。

我感觉它给出的这些放在现在看来都比较的常规了(当时可能比较新颖),现在也有很多的其他方式去和技术去提升。

SqueezeNet分类实验测试

100轮训练和验证损失记录 

100轮错误率记录  

验证集记录最佳的指标 

总结

SqueezeNet相对于传统的深层CNN模型,如VGG或ResNet,具有更小的模型大小和更少的参数,但在一些任务上仍能取得不错的性能。这使得SqueezeNet成为在资源受限环境中进行实时图像分类等应用的有力选择。

目前我也是在学习的阶段,学习这一部分也是为了积累轻量化模型的方法,因为因为轻量化模型在移动设备、嵌入式系统以及一些资源受限的环境中都具有重要的应用价值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/361883.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI的安全应答之道

作者:统信UOS技术团队 2023,随着各种大语言模型的爆发,整个AI生态正处于从决策式AI进化到生成式AI的进程中。各类AI模型和AI应用层出不穷,也随之带来了与AI相关的各类潜在风险。AI开发和使用过程中的风险防范和治理,成为了不可忽…

神经网络的一些常规概念

epoch:是指所有样本数据在神经网络训练一次(单次epoch(全部训练样本/batchsize)/iteration1)或者(1个epochiteration数 batchsize数) batch-size:顾名思义就是批次大小,也就是一次训练选取的样…

力扣hot100 前 K 个高频元素 小根堆 流 IntStream

Problem: 347. 前 K 个高频元素 文章目录 思路复杂度Code 思路 &#x1f468;‍&#x1f3eb; 参考 小根堆&#xff08;维护k个高频元素&#xff09;遍历所有元素&#xff0c;当前堆大小 < k 或者 当前元素出现次数大于堆顶元素出现次数&#xff1a;替换掉堆顶元素 复杂…

2024Node.js零基础教程(小白友好型),nodejs新手到高手,(三)NodeJS入门——http协议

033_HTTP协议_初识HTTP协议 hello&#xff0c;大家好&#xff0c;这个小节我们来认识一下 http协议。 http是几个单词的首字母拼写&#xff0c;全称为Hypertext Transfer Protocol 译为超文本传输协议&#xff0c;那么这个http协议是互联网上应用最广泛的协议之一。顺便说一下…

使用 axios 请求库,设置请求拦截

什么是 axios&#xff1f; 基于promise网络请求库&#xff0c;可以同构&#xff08;同一套代码可以运行在浏览器&#xff09;&#xff0c;在服务端&#xff0c;使用原生node.js的http模块&#xff0c;在客户端&#xff08;浏览器&#xff09;中&#xff0c;使用XMLHttpRequests…

【Godot4自学手册】第十节将场景添加到TileSet绘制背景,主人公走到房子后面房子变得半透明

这节主要学习将场景添加到TileSet作为TileMap来搭建背景。同时&#xff0c;主人公进入房子后面&#xff0c;房子变得半透明&#xff0c;离开房子后房子变的不透明。 一、创建新场景 首先导入房子素材&#xff0c;最终文件系统内容如下&#xff1a; 点击新建场景按钮&#x…

【Qt学习笔记】(一)初识Qt

Qt学习笔记 1 使用Qt Creator 新建项目2 项目代码解释3 创建第一个 Hello World 程序4 关于内存泄漏问题5 Qt 中的对象树6 关于 qDebug&#xff08;&#xff09;的使用7 使用其他方式创建一个 Hello World 程序&#xff08;编辑框和按钮方式&#xff09;8 关于 Qt 中的命名规范…

阿里云智能集团副总裁安筱鹏:企业数字化的终局是什么?

以下文章来源于数字化企业 &#xff0c;作者安筱鹏博士 回答数字化终局追问的起点是&#xff0c;企业需要重新定义我是谁。成为有竞争力的行业领导厂商&#xff0c;你应当成为一个客户运营商&#xff0c;即能够实时洞察、实时满足客户需求&#xff0c;追求极致的客户体验。而要…

使用 Docker 部署扫雷小游戏

1&#xff09;源码 介绍&#xff1a;扫雷游戏是一款经典的单人益智游戏&#xff0c;旨在通过揭示方块和避开地雷来展示玩家的逻辑思维和推理能力。 源码&#xff1a;saolei.zip 个人文件站&#xff1a;https://share.wuhanjiayou.cn/ 2&#xff09;部署 2.1&#xff09;安装…

SpringBoot中处理校验逻辑的两种方式:Hibernate Validator+全局异常处理

最近正在开发一个知识库学习网站编程喵&#x1f431;&#xff0c;需要对请求参数进行校验&#xff0c;比如说非空啊、长度限制啊等等&#xff0c;可选的解决方案有两种&#xff1a; 一种是用 Hibernate Validator 来处理一种是用全局异常来处理 两种方式&#xff0c;我们一一…

基于EdgeWorkers的边缘应用如何进行单元测试?

随着各行各业数字化转型的持续深入&#xff0c;越来越多企业开始选择将一些应用程序放在距离最终用户更近的边缘位置来运行&#xff0c;借此降低延迟&#xff0c;提高应用程序响应速度&#xff0c;打造更出色的用户体验。 相比传统集中部署和运行的方式&#xff0c;这种边缘应…

websocket编写聊天室

【黑马程序员】WebSocket打造在线聊天室【配套资料源码】 总时长 02:45:00 共6P 此文章包含第1p-第p6的内容 简介 温馨提示&#xff1a;现在都是第三方支持聊天&#xff0c;如极光&#xff0c;学这个用于自己项目完全没问题&#xff0c;大项目不建议使用 需求分析 代码

Vue学习总结

声明&#xff1a;本文来源于黑马程序员PDF讲义 双向绑定&#xff1a; 修改表单项标签&#xff0c;发现vue对象data中的数据也发生了变化 双向绑定的作用&#xff1a;可以获取表单的数据的值&#xff0c;然后提交给服务器 事件绑定 v-on: 用来给html标签绑定事件的。需要注意…

了解 Redis Channel:消息传递机制、发布与订阅,以及打造简易聊天室的实战应用。

文章目录 1. Redis Channel 是什么2. Redis-Cli 中演示使用3. 利用 Channel 打造一个简易的聊天室参考文献 1. Redis Channel 是什么 Redis Channel 是一种消息传递机制&#xff0c;允许发布者向特定频道发布消息&#xff0c;而订阅者则通过订阅频道实时接收消息。 Redis Cha…

LRU 缓存置换策略:提升系统效率的秘密武器(下)

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

小程序定制开发前,应该考虑些什么?

引言 在移动互联网时代&#xff0c;小程序已经成为许多企业和个人推广业务、提供服务的理想平台。然而&#xff0c;在进行小程序定制开发之前&#xff0c;开发者和业务方需要细致入微地考虑一系列关键因素&#xff0c;以确保最终的小程序既能满足用户需求&#xff0c;又能够顺…

Linux第40步_移植ST公司uboot的第1步_创建配置文件_设备树_修改电源管理和sdmmc节点

ST公司uboot移植分两步走&#xff1a; 第1步&#xff1a;完成“创建配置文件&#xff0c;设备树&#xff0c;修改电源管理和sdmmc节点&#xff0c;以及shell脚本和编译”。 第2步“完成”修改网络驱动、USB OTG设备树和LCD驱动&#xff0c;以及编译和烧写测试“。 移植太复杂…

牛客——中位数图(连续子数组和二维前缀和)

链接&#xff1a;登录—专业IT笔试面试备考平台_牛客网 来源&#xff1a;牛客网 题目描述 给出1~n的一个排列&#xff0c;统计该排列有多少个长度为奇数的连续子序列的中位数是b。中位数是指把所有元素从小到大排列后&#xff0c;位于中间的数。 输入描述: 第一行为两个正…

Mysql基础篇笔记

数据表 链接&#xff1a;https://pan.baidu.com/s/1dPitBSxLznogqsbfwmih2Q 提取码&#xff1a;b0rp --来自百度网盘超级会员V5的分享 sql的执行顺序 根据顺序 也就是说 select后面的字段别名 只能在order by中使用 mysql不支持sql92的外连接 mysql不支持满外连接 可以…

springBoot+Vue汽车销售源码

源码描述: 汽车销售管理系统源码基于spring boot以及Vue开发。 针对汽车销售提供客户信息、车辆信息、订单信息、销售人员管理、 财务报表等功能&#xff0c;提供经理和销售两种角色进行管理。 技术架构&#xff1a; idea(推荐)、jdk1.8、mysql5.X(不能为8驱动不匹配)、ma…