昇思25天学习打卡营第15天|应用实践之ShuffleNet图像分类

基本介绍

         今天的应用实践的领域是计算机视觉领域,更确切的说是图像分类任务,不过,与昨日不同的是,今天所使用的模型是ShuffleNet模型。ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。今天会简单介绍一些ShuffleNet模型,并使用CIFAR-10数据集进行训练与评估,最后进行模型预测

ShuffleNet模型简介

        ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速

  • Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示,相比于普通的卷积操作,分组卷积的情况下,每一组的卷积核大小为in_channels/g*k*k,一共有g组,所有组共有(in_channels/g*k*k)*out_channels个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量

  • Channel Shuffle

        Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积GConv层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

以上两个结构就是ShuffleNet的主要结构,ShuffleNet的模型代码(MindSpore版)如下:

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

数据集准备

        采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。可直接使用mindspore.dataset.Cifar10Dataset接口下载并加载CIFAR-10的训练集。这部分的操作和昨天几乎一样,就不进行展示

模型训练与评估

        采用随机初始化的参数做预训练。首先调用ShuffleNetV1定义网络,参数量选择"2.0x",并定义损失函数为交叉熵损失,学习率经过4轮的warmup后采用余弦退火,优化器采用Momentum,总共训练5轮。最后用train.model中的Model接口将模型、损失函数、优化器封装在model中,并用model.train()对网络进行训练。将ModelCheckpointCheckpointConfigTimeMonitorLossMonitor传入回调函数中,将会打印训练的轮数、损失和时间,并将ckpt文件保存在当前目录下。具体训练代码如下:

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    # 由于时间原因,epoch = 5,可根据需求进行调整
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

评估的时候直接使用model.eval()进行评估,具体代码如下:

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

模型预测

        训练完毕则可进行模型预测,并将预测结果可视化,结果如下:

可以看出,shuffleNet效果还是不错的,在轻量化的前提下也保证了一定的精度。

Jupyter运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/785264.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

26.6 Django模型层

1. 模型层 1.1 模型层的作用 模型层(Model Layer)是MVC或MTV架构中的一个核心组成部分, 它主要负责定义和管理应用程序中的数据结构及其行为. 具体职责包括: * 1. 封装数据: 模型层封装了应用程序所需的所有数据, 这些数据以结构化的形式存在, 如数据库表, 对象等. * 2. 数据…

LabVIEW电滞回线测试系统

铁电材料的性能评估依赖于电滞回线的测量,这直接关系到材料的应用效果和寿命。传统的电滞回线测量方法操作复杂且设备成本高。开发了一种基于LabVIEW的电滞回线测试系统,解决传统方法的不足,降低成本,提高操作便捷性和数据分析的自…

王老师 linux c++ 通信架构 笔记(二)配置服务器为固定的 ip 地址、远程登录、安装 gcc g++ 与虚拟机文件夹共享

(7)本条目开始配置 linux 的固定 ip 地址,以作为服务器使用: 首先解释 linux 的网口编号: linux 命令 cd : change directory 改变目录。 ls : list 列出某目录下的文件 根目录文件名 / etc &a…

Java创建多线程的几种方式详解

进程与线程 简单介绍下基础概念 进程:在内存中执行的应用程序 线程:进程中最小的执行单元,作用是负责当前进程中程序的运行。一个进程中至少有一个线程,一个进程还可以有多个线程,这样的应用程序就称之为多线程程序 …

【详细教程】PowerDesigner导出表结构word文档

📖【详细教程】PowerDesigner导出表结构word文档 ✅第一步:新建报告✅第二步:配置导出的参数✅第三步:导出 ✅第一步:新建报告 ✅第二步:配置导出的参数 如果你只需要导出纯粹的表结构,那么下面…

音频demo:使用opencore-amr将PCM数据与AMR-NB数据进行相互编解码

1、README a. 编译 编译demo 由于提供的.a静态库是在x86_64的机器上编译的,所以仅支持该架构的主机上编译运行。 $ make编译opencore-amr 如果想要在其他架构的CPU上编译运行,可以使用以下命令(脚本)编译opencore-amr[下载地…

Html5前端基本知识整理与回顾上篇

今天我们结合之前上传的知识资源来回顾学习的Html5前端知识,与大家共勉,一起学习。 目录 介绍 了解 注释 标签结构 排版标签 标题标签 ​编辑 段落标签 ​编辑 换⾏标签 ​编辑 ⽔平分割线 ⽂本格式化标签 媒体标签 绝对路径 相对路径 …

【Python】不小心卸载pip后(手动安装pip的两种方式)

文章目录 方法一:使用get-pip.py脚本方法二:使用easy_install注意事项 不小心卸载pip后:手动安装pip的两种方式 在使用Python进行开发时,pip作为Python的包管理工具,是我们安装和管理Python库的重要工具。然而&#x…

接口调用的三种方式

例子: curl --location http://110.0.0.1:1024 \ --header Content-Type: application/json \ --data {"task_id": 1 }方式一:postman可视化图形调用 方式二:Vscode中powershell发送请求 #powershell (psh) Invoke-WebRequest -U…

用R在地图上绘制网络图的三种方法

地理网络图与传统的网络图不同,当引用地理位置进行节点网络可视化时,需要将这些节点放置在地图上,然后绘制他们之间的连结。Markus konrad的帖子(https://datascience.blog.wzb.eu/2018/05/31/three-ways-of-visualizing-a-graph-on-a-map/)&…

Linux系统编程——线程控制

目录 一,关于线程控制 二,线程创建 2.1 pthread_create函数 2.2 ps命令查看线程信息 三,线程等待 3.1 pthread_join函数 3.2 创建多个线程 3.3 pthread_join第二个参数 四,线程终止 4.1 关于线程终止 4.2 pthread_exit…

LeetCode 算法:腐烂的橘子 c++

原题链接🔗:腐烂的橘子 难度:中等⭐️⭐️ 题目 在给定的 m x n 网格 grid 中,每个单元格可以有以下三个值之一: 值 0 代表空单元格;值 1 代表新鲜橘子;值 2 代表腐烂的橘子。 每分钟&#…

Java版Flink使用指南——定制RabbitMQ数据源的序列化器

大纲 新建工程新增依赖数据对象序列化器接入数据源 测试修改Slot个数打包、提交、运行 工程代码 在《Java版Flink使用指南——从RabbitMQ中队列中接入消息流》一文中,我们从RabbitMQ队列中读取了字符串型数据。如果我们希望读取的数据被自动化转换为一个对象&#x…

JAVA案例ATM系统

一案例要求: 首先完成ATM的用户登录和用户开户两个大功能,用户开户有账户名,性别,账户密码,确认密码,每次取现额度,并且随机生成一个7位数的账号,用户登录功能有查询,存…

k8s 部署 metribeat 实现 kibana 可视化 es 多集群监控指标

文章目录 [toc]环境介绍老(来)板(把)真(展)帅(示)helm 包准备配置监控集群获取集群 uuid生成 api_key配置 values.yaml 配置 es 集群获取集群 uuid 和 api_key配置 values.yaml 查看监控 缺少角色的报错 开始之前,需要准备好以下场景 一套 k8s 环境 k8s 内有两套不同…

Aqara 发布多款智能照明新品,引领空间智能新时代

7月8日,全球 IoT 独角兽品牌 Aqara 以“光,重塑空间想象”为主题,举办了线上智能照明新品沟通会。 会上,Aqara 正式发布一系列引领行业的智能照明新品,包括银河系列轨道灯 V1 以及繁星系列妙控旋钮 V1 等,…

Hospital Management System v4.0 SQL 注入漏洞(CVE-2022-24263)

前言 CVE-2022-24263 是一个影响 Hospital Management System (HMS) v4.0 的 SQL 注入漏洞。这个漏洞允许攻击者通过注入恶意 SQL 代码来获取数据库的敏感信息,甚至可能控制整个数据库。以下是对这个漏洞的详细介绍: 漏洞描述 在 Hospital Management…

使用Keil 点亮LED灯 F103ZET6

1.新建项目 不截图了 2.startup_stm32f10x_hd.s Keil\Packs\Keil\STM32F1xx_DFP\2.2.0\Device\Source\ARM 搜索startup_stm32f10x_hd.s 复制到项目路径,双击Source Group 1 3.项目文件夹新建stm32f10x.h, 新建文件main.c #include "stm32f10x…

OS-HACKNOS-2.1

确定靶机IP地址 扫描靶机开放端口信息 目录扫描 访问后发现个邮箱地址 尝试爆破二级目录 确定为wordpress站 利用wpscan进行漏洞扫描 #扫描所有插件 wpscan --url http://192.168.0.2/tsweb -e ap 发现存在漏洞插件 cat /usr/share/exploitdb/exploits/php/webapps/46537.txt…