昇思25天学习打卡营第10天|ShuffleNet图像分类

ShuffleNet网络结构

ShuffleNet是一种专为移动设备设计的、计算效率极高的卷积神经网络(CNN)架构。其网络结构的设计主要围绕减少计算复杂度和提高模型效率展开,通过引入逐点分组卷积(Pointwise Group Convolution)和通道洗牌(Channel Shuffle)两种新技术,实现了在保持精度的同时大幅降低计算成本。

逐点分组卷积(Pointwise Group Convolution):

逐点分组卷积是ShuffleNet中用于减少1x1卷积计算复杂度的方法。它将输入特征图的通道分成多个组,每个组内的通道独立进行1x1卷积,从而显著降低了计算量。
在这里插入图片描述

然而,这种方法可能导致通道间的信息无法充分交流,影响模型的表达能力。可能会降低网络的特征提取能力

通道洗牌(Channel Shuffle):

为了解决逐点分组卷积带来的通道间信息交流不足的问题,ShuffleNet引入了通道洗牌操作。通过均匀地打乱不同分组中的通道,使得每个分组都能获得来自其他分组的信息,从而增强模型的特征提取能力。

在这里插入图片描述
在这里插入图片描述

ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), ©的更改:

  1. 将开始和最后的 1×1卷积模块(降维、升维)改成Point Wise Group Convolution

  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle

  3. 降采样模块中, 3×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为23×3平均池化,并把相加改成拼接。
    在这里插入图片描述
    ShuffleV1Block

class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

ShuffleNet的基本单元是在残差单元(residual block)的基础上改进而成的,具体结构如下:

1x1分组卷积:首先,输入特征图通过一个1x1的分组卷积进行降维,减少通道数。
通道洗牌:紧接着,对分组卷积的输出进行通道洗牌操作,以实现不同分组之间的信息交流。
3x3深度可分离卷积:然后,使用3x3的深度可分离卷积(depthwise separable convolution)进行特征提取。这里的3x3卷积是瓶颈层(bottleneck),用于降低计算量。
1x1分组卷积(可选):最后,根据需要,可以通过另一个1x1的分组卷积将通道数恢复到与输入相同或更大的数量。
短路连接:在基本单元中,还包含短路连接(shortcut),用于将输入特征图直接加到输出特征图上,以保留原始信息并帮助梯度回传。
在这里插入图片描述

ShuffleNet网络结构如上图所示,以输入图像 224×224 ,组数3(g = 3)为例,首先通过数量24,卷积核大小为 3×3stride2的卷积层,输出特征图大小为 112×112 ,channel为24;然后通过stride为2的最大池化层,输出特征图大小为 56×56channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图©),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为 1×1×960 ,再经过全连接层softmax,得到分类概率

ShuffleNetV1

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

设置model_size="2.0x",定义模型的复杂度。

 net = ShuffleNetV1(model_size="2.0x", n_class=10)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/786325.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

neo4j 图数据库:Cypher 查询语言、医学知识图谱

neo4j 图数据库:Cypher 查询语言、医学知识图谱 Cypher 查询语言创建数据查询数据查询并返回所有节点查询并返回所有带有特定标签的节点查询特定属性的节点及其所有关系和关系的另一端节点查询从名为“小明”的节点到名为“小红”的节点的路径 更新数据更新一个节点…

汇川Easy系列PLC使用本地脉冲5轴设置

根据官网手册可以看到,Easy302往上的系列都是支持本地5轴脉冲控制的 常规汇川PLC本地脉冲轴配置时,脉冲和方向的输出点都是成对出现的,但是easy如果要使用5轴的话,就需要自己定义方向 可以看到,Y0,Y1这两个点是单独…

SQLite 命令行客户端 + HTA 实现简易UI

SQLite 命令行客户端 HTA 实现简易UI SQLite 客户端.hta目录结构参考资料 仅用于探索可行性&#xff0c;就只实现了 SELECT。 SQLite 客户端.hta <!DOCTYPE html> <html> <head><meta http-equiv"Content-Type" content"text/html; cha…

27 岁的程序员 Gap 一年感受

最大的感受&#xff1a;变成 28 岁了 好吧&#xff0c;开个玩笑&#xff0c;下面是正文。 0.背景以及 Gap 原因 我硕士毕业时是 26 岁&#xff0c;然后校招进入一家航天国企&#xff0c;负责 Web 后端开发&#xff0c;工作了一年之后发现个人成长和挑战的空间极其有限&#…

SAP 新增移动类型简介

在SAP系统中新增移动类型的过程涉及多个步骤,‌包括复制现有的移动类型、‌调整科目设置以及进行必要的测试。‌以下是新增移动类型的一般步骤和关键点:‌ 复制现有的移动类型:‌ 使用事务代码OMJJ进入移动类型维护界面。‌ 勾选移动类型 这里不填写移动类型,然后直接下…

告别堆积,迎接清新:回收小程序,打造无废生活新选择

在快节奏的现代生活中&#xff0c;物质的丰富与便利似乎成为了我们日常的一部分&#xff0c;但随之而来的&#xff0c;是日益增长的废弃物堆积问题。街道边、社区里&#xff0c;甚至是我们的家中&#xff0c;废弃物品仿佛无孔不入&#xff0c;逐渐侵蚀着我们的生活空间与环境质…

Pyspider WebUI 未授权访问致远程代码执行漏洞复现

0x01 产品简介 Pyspider是由国人binux编写的强大的网络爬虫系统,它带有强大的WebUI(Web用户界面),为用户提供了可视化的编写、调试和管理爬虫的能力。这一特点使得Pyspider在爬虫框架中脱颖而出,尤其适合那些希望快速上手并高效开发爬虫的用户。允许用户直接在网页上编写…

zabbix“专家坐诊”第245期问答

问题一 Q&#xff1a;vfs.dev.discovery拿的是哪里的文件&#xff0c;我看源码里面获取的是/proc/parttions里面的信息&#xff0c;但是我没有这个device&#xff0c;是怎么获取出来的&#xff1f; 在这里插入图片描述 A&#xff1a;检查下系统内核版本或者agent程序版本&…

15 CIG重量级监控

目录 1. docker stats原生命令 2. CIG CAdvisor InfluxDB Granfana 3. 安装部署 4. Grafana配置 4.1. 添加数据源 4.2. 添加工作台 grafana官网文档参考&#xff1a;Grafana documentation | Grafana documentation influxdb官网文档参考&#xff1a;https://docs.in…

拨开迷雾,寻找大模型应用落地的支点

自主可控大模型底座个性化刚需场景&#xff0c;这家大模型公司率先趟出一条个性化发展路径。 作者 | 辰纹 来源 | 洞见新研社 上海的温度很高&#xff0c;接近40度&#xff0c;比上海温度更高的是AI的热度。 7月4日&#xff0c;2024世界人工智能大会暨人工智能全球治理高…

tapd项目管理由完全免费的工具向付费工具转变

TAPD从2022年左右开始面由一个完全免费的工具向付费工具转变。从最新政策看&#xff0c;TAPD 针对不同规模和需求的团队&#xff0c;TAPD提供了多种版本&#xff0c;其中包括“卓越版”和“企业版”。免费版本人数规模由原来的100人不断缩小&#xff0c;2024年仅支持30人以内免…

Java-Redis-Clickhouse-Jenkins-MybatisPlus-Zookeeper-vscode-Docker-jdbc

文章目录 Clickhouse基础实操windows docker desktop 下载clickhousespringboot项目配置clickhouse Redis谈下你对Redis的了解&#xff1f;Redis一般都有哪些使用的场景&#xff1f;Redis有哪些常见的功能&#xff1f;Redis支持的数据类型有哪些&#xff1f;Redis为什么这么快…

科普文:深入理解负载均衡(四层负载均衡、七层负载均衡)

概叙 网络模型&#xff1a;OSI七层模型、TCP/IP四层模型、现实的五层模型 应用层&#xff1a;对软件提供接口以使程序能使用网络服务&#xff0c;如事务处理程序、文件传送协议和网络管理等。&#xff08;HTTP、Telnet、FTP、SMTP&#xff09; 表示层&#xff1a;程序和网络之…

循环练习题

代码&#xff1a; public static void main(String[] args) { for (char c1a;c1<z;c1){System.out.print(" "c1); }System.out.println();for (char c2Z;c2>A;c2--){System.out.print(" "c2);}} 结果为&#xff1a;

二. Linux内核

一. Linux内核源码目录分析 arch 包含与体系结构相关的代码&#xff0c;用于支持不同硬件体系结构的实现。这个目录下会根据不同的架构&#xff08;如x86、arm、mips等&#xff09;进一步细分。 block 用于处理块设备的子系统&#xff0c;包含与块设备驱动和I/O调度相关的代码。…

HTML(29)——立体呈现

作用&#xff1a;设置元素的子元素是位于3D空间中还是平面中 属性名&#xff1a;transform-style 属性值&#xff1a; flat&#xff1a;子级处于平面中preserve-3d:子级处于3D空间 步骤&#xff1a; 父级元素添加 transform-style:preserve-3d 子级定位调整子盒子的位置&a…

高智能土壤养分检测仪:农业生产的科技新助力

在科技日新月异的今天&#xff0c;农业领域也迎来了革命性的变革。其中&#xff0c;高智能土壤养分检测仪作为现代农业的科技新助力&#xff0c;正逐渐改变着传统的农业生产方式&#xff0c;为农民带来了前所未有的便利与效益。 高智能土壤养分检测仪&#xff0c;是一款集高科技…

PMON的解读和开发

提示&#xff1a;龙芯2K1000PMON相关记录 文章目录 1 PMON的发展和编译环境PMONPMON2000 2 PMON2000的目录结构3 Targets目录的组成4 PMON编译环境的建立5 PMON2000的框架6 异常向量表7 Pmon的空间分配8 PMON的汇编部分(starto.S或sbdreset.S)的解读Start.SC代码部分dbginit 9 …

为什么要参加学术会议?

为什么要参加学术会议&#xff1f; 学术会议是一种以促进科学发展、学术交流、课题研究等学术性话题为主题的会议。学术会议一般都具有国际性、权威性、高知识性、高互动性等特点&#xff0c;其参会者一般为科学家、学者、教师等具有高学历的研究人员。下面苏老师就跟大家详细…

92. 反转链表 II (Swift 版本)

题目描述 给你单链表的头指针 head 和两个整数 left 和 right &#xff0c;其中 left < right 。请你反转从位置 left 到位置 right 的链表节点&#xff0c;返回 反转后的链表 。 分析 这是一个经典的链表问题&#xff0c;要求反转链表的部分节点。我们可以通过以下步骤实…