昇思25天学习打卡营第13天 | ShuffleNet图像分类

在这里插入图片描述

ShuffleNet网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速。

如下图所示,ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。

图片来源:Bianco S, Cadene R, Celona L, et al. Benchmark analysis of representative deep neural network architectures[J]. IEEE access, 2018, 6: 64270-64277.

模型架构

ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过对ResNet的Bottleneck单元进行改进,在较小的计算量的情况下达到了较高的准确率。

Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示,相比于普通的卷积操作,分组卷积的情况下,每一组的卷积核大小为in_channels/gkk,一共有g组,所有组共有(in_channels/gkk)*out_channels个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量。

在这里插入图片描述
Depthwise Convolution(深度可分离卷积)将组数g分为和输入通道相等的in_channels,然后对每一个in_channels做卷积操作,每个卷积核只处理一个通道,记卷积核大小为1kk,则卷积核参数量为:in_channelskk,得到的feature maps通道数与输入通道数相等;

Pointwise Group Convolution(逐点分组卷积)在分组卷积的基础上,令每一组的卷积核大小为 1×1
,卷积核参数量为(in_channels/g11)*out_channels。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
from mindspore import nn
import mindspore.ops as ops
from mindspore import Tensor

class GroupConv(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        self.groups = groups
        self.convs = nn.CellList()
        for _ in range(groups):
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        out = ops.cat(outputs, axis=1)
        return out

Channel Shuffle

Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积GConv层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。

为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

在这里插入图片描述
如下图所示,对于g组,每组有n个通道的特征图,首先reshape成g行n列的矩阵,再将矩阵转置成n行g列,最后进行flatten操作,得到新的排列。这些操作都是可微分可导的且计算简单,在解决了信息交互的同时符合了ShuffleNet轻量级网络设计的轻量特征。

在这里插入图片描述

ShuffleNet模块

如下图所示,ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), ©的更改:

  1. 将开始和最后的 1×1
    卷积模块(降维、升维)改成Point Wise Group Convolution;

  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle;

  3. 降采样模块中, 3×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为2的 3×3
    平均池化,并把相加改成拼接。

在这里插入图片描述

class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

构建ShuffleNet网络

ShuffleNet网络结构如下图所示,以输入图像 224×224
,组数3(g = 3)为例,首先通过数量24,卷积核大小为 3×3
,stride为2的卷积层,输出特征图大小为 112×112
,channel为24;然后通过stride为2的最大池化层,输出特征图大小为 56×56
,channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图©),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为 1×1×960
,再经过全连接层和softmax,得到分类概率。

在这里插入图片描述

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

模型训练和评估

采用CIFAR-10数据集对ShuffleNet进行预训练。

训练集准备与加载
采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。如下示例使用mindspore.dataset.Cifar10Dataset接口下载并加载CIFAR-10的训练集。目前仅支持二进制版本(CIFAR-10 binary version)。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    # 由于时间原因,epoch = 5,可根据需求进行调整
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()

在这里插入图片描述

学习心得

通过本次学习,我不仅掌握了ShuffleNetV1的网络结构和实现方法,还深入理解了分组卷积和通道重排在提高模型效率中的作用。未来,我希望能够进一步探索ShuffleNetV2以及其他高效模型的设计与应用,并尝试将其应用于更多复杂的数据集和任务中。同时,我还计划研究模型压缩和加速的其他技术,如模型剪枝和量化,以进一步提升模型的应用性能。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/782403.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring AOP实现操作日志记录示例

1. 准备工作 项目环境:jdk8springboot2.6.13mysql8 1.1 MySQL表 /*Navicat Premium Data TransferSource Server : localhostSource Server Type : MySQLSource Server Version : 50730Source Host : 127.0.0.1:3306Source Schema …

通过IDEA生成webapp及web.xml配置文件

1、选择File->Project Structure 2、选择Modules-> + -> Web 有的springboot工程选择是war工程,这个web可能已经存在了。 如果不存在,就手动创建,创建后,需要修改pom.xml中的配置 <packaging>war</packaging> 3、创建webapp根目录 这步重点就是创建…

ARM架构以及程序运行解析

文章目录 1. ARM架构 2. ARM处理器程序运行的过程 3. 示例 3. 基于ARM架构的STM32单片机 1. 运行模式 2. 寄存器组 3. STM32的基本结构 4. STM32的运行模式 4. 寄存器组详解 1. 未备份寄存器 2. 备份寄存器 3. 程序计数器 4. 程序状态寄存器 5. CPSR和SPSR寄存器…

运维Tips | Ubuntu 24.04 安装配置 xrdp 远程桌面服务

[ 知识是人生的灯塔&#xff0c;只有不断学习&#xff0c;才能照亮前行的道路 ] Ubuntu 24.04 Desktop 安装配置 xrdp 远程桌面服务 描述&#xff1a;Xrdp是一个微软远程桌面协议&#xff08;RDP&#xff09;的开源实现&#xff0c;它允许我们通过图形界面控制远程系统。这里使…

Banana Pi BPI-M5 Pro 低调 SBC 采用 Rockchip RK3576 八核 Cortex-A72/A53 AIoT SoC

Banana Pi BPI-M5 Pro&#xff0c;也称为 Armsom Sige5&#xff0c;是一款面向 AIoT 市场的低调单板计算机 (SBC)&#xff0c;由 Rockchip RK3576 八核 Cortex-A72/A53 SoC 驱动&#xff0c;提供Rockchip RK3588和RK3399 SoC 之间的中档产品。 该主板默认配备 16GB LPDDR4X 和…

学习笔记——动态路由——OSPF(特殊区域)

十、OSPF特殊区域 1、技术背景 早期路由器靠CPU计算转发&#xff0c;由于硬件技术限制问题&#xff0c;因此资源不是特别充足&#xff0c;因此是要节省资源使用&#xff0c;规划是非常必要的。 OSPF路由器需要同时维护域内路由、域间路由、外部路由信息数据库。当网络规模不…

NAT:地址转换技术

为什么会引入NAT&#xff1f; NAT&#xff08;网络地址转换&#xff09;的引入主要是为了解决两个问题 IPv4地址短缺&#xff1a;互联网快速发展&#xff0c;可用的公网IP地址越来越少。网络安全&#xff1a;需要一种方法来保护内部网络不被直接暴露在互联网上。 IPv4 &…

人脸检测(Python)

目录 环境&#xff1a; 初始化摄像头&#xff1a; 初始化FaceDetector对象&#xff1a; 获取摄像头帧&#xff1a; 获取数据&#xff1a; 绘制数据&#xff1a; 显示图像&#xff1a; 完整代码&#xff1a; 环境&#xff1a; cvzone库&#xff1a;cvzone是一个基于…

RAG实践:ES混合搜索BM25+kNN(cosine)

1 缘起 最近在研究与应用混合搜索&#xff0c; 存储介质为ES&#xff0c;ES作为大佬牌数据库&#xff0c; 非常友好地支持关键词检索和向量检索&#xff0c; 当然&#xff0c;支持混合检索&#xff08;关键词检索向量检索&#xff09;&#xff0c; 是提升LLM响应质量RAG(Retri…

spring boot(学习笔记第十二课)

spring boot(学习笔记第十二课) Spring Security内存认证&#xff0c;自定义认证表单 学习内容&#xff1a; Spring Security内存认证自定义认证表单 1. Spring Security内存认证 首先开始最简单的模式&#xff0c;内存认证。 加入spring security的依赖。<dependency>…

【TB作品】51单片机 Proteus仿真 MAX7219点阵驱动数码管驱动

1、8乘8点阵模块&#xff08;爱心&#xff09; 数码管测试程序与仿真 实验报告: MAX7219 数码管驱动测试 一、实验目的 通过对 MAX7219 芯片的编程与控制&#xff0c;了解如何使用单片机驱动数码管显示数字&#xff0c;并掌握 SPI 通信协议的基本应用。 二、实验器材 51…

触发器编程-创建(CREATE TRIGGER)、删除(DROP TRIGGER)

一、定义 1、触发器&#xff08;Trigger&#xff09;是用户对某一表中的数据做插入、更新和删除操作时被处罚执行的一段程序&#xff0c;通常我们使用触发器来检查用户对表的操作是否合乎整个应用系统的需求&#xff0c;是否合乎商业规则以维持表内数据的完整性和正确性 2、一…

SPL-404:如何彻底改变Solana上的NFT与DeFi

在不断发展的数字资产领域中&#xff0c;非同质化Token&#xff08;NFT&#xff09;已成为一股革命性力量&#xff0c;彻底改变了我们对数字所有权的看法和互动方式。从艺术和收藏品到游戏和虚拟房地产&#xff0c;NFT吸引了创作者、投资者和爱好者的想象力。 本指南将带您进入…

力扣-双指针1

何为双指针 双指针指向同一数组&#xff0c;然后配合着进行搜索等活动。 滑动窗口的时候很好使用。 167.两数之和Ⅱ-输入有序数组 167. 两数之和 II - 输入有序数组 题目 给你一个下标从 1 开始的整数数组 numbers &#xff0c;该数组已按 非递减顺序排列 &#xff0c;请你从…

[Flink]三、Flink1.13

11. Table API 和 SQL 如图 11-1 所示&#xff0c;在 Flink 提供的多层级 API 中&#xff0c;核心是 DataStream API &#xff0c;这是我们开发流 处理应用的基本途径&#xff1b;底层则是所谓的处理函数&#xff08; process function &#xff09;&#xff0c;可以访…

Android 简单快速实现 下弧形刻度尺(滑动事件)

效果图&#xff1a; 直接上代码&#xff1a; package com.my.view;import android.content.Context; import android.graphics.Bitmap; import android.graphics.BitmapFactory; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Pai…

iptables实现端口转发ssh

iptables实现端口转发 实现使用防火墙9898端口访问内网front主机的22端口&#xff08;ssh连接&#xff09; 1. 防火墙配置(lb01) # 配置iptables # 这条命令的作用是将所有目的地为192.168.100.155且目标端口为19898的TCP数据包的目标IP地址改为10.0.0.148&#xff0c;并将目标…

基于Java+SpringMvc+Vue技术的在线学习交流平台的设计与实现---60页论文参考

博主介绍&#xff1a;硕士研究生&#xff0c;专注于Java技术领域开发与管理&#xff0c;以及毕业项目实战✌ 从事基于java BS架构、CS架构、c/c 编程工作近16年&#xff0c;拥有近12年的管理工作经验&#xff0c;拥有较丰富的技术架构思想、较扎实的技术功底和资深的项目管理经…

【PB案例学习笔记】-29制作一个调用帮助文档的小功能

写在前面 这是PB案例学习笔记系列文章的第29篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

【Spring Boot】关系映射开发(三):多对多映射

《JPA 从入门到精通》系列包含以下文章&#xff1a; Java 持久层 API&#xff1a;JPA认识 JPA 的接口JPA 的查询方式基于 JPA 开发的文章管理系统&#xff08;CRUD&#xff09;关系映射开发&#xff08;一&#xff09;&#xff1a;一对一映射关系映射开发&#xff08;二&#…