主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力!~🌈  

     目录

🚀1. 基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:block.py文件修改

🍀🍀步骤2:__init__.py文件修改

🍀🍀步骤3:tasks.py文件修改

🍀🍀步骤4:创建自定义yaml文件

🍀🍀步骤5:新建train.py文件

🍀🍀步骤6:模型训练测试

🚀1. 基础概念

ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。

ShuffleNetV2的主要特点包括:

  1. 分组卷积:通过将输入通道分成多个组,并在组内进行卷积操作,减少了计算量和参数数量。
  2. 逐点卷积:使用1x1的卷积核进行逐点卷积,用于调整通道数和特征图的维度。
  3. 通道重排:通过将输入特征图按通道进行重排,实现信息的混洗和交互,增强了特征的表达能力。
  4. 瓶颈结构:采用瓶颈结构,即先降维再升维,减少了计算量和参数数量。
  5. 网络设计:ShuffleNet V2通过堆叠多个ShuffleNet单元来构建整个网络,可以根据任务的需求进行不同层数和宽度的配置。

ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力。

shuffleNetV2这篇论文比较硬核,提出了不少新的思想,推荐大家可以看看论文原文。主要思想包括:

  • 模型的计算复杂度不能只看FLOPs,还需要参考一些其他的指标
  • 作者提出了4条如何设计高效网络的准则
  • 基于该准则提出了新的block设置

FLOPS网上有两种:FLOPS和 FLOPs

FLOPS:全大写,指每秒浮点运算次数,可以理解为计算的速度,是衡量硬件性能的一个指标 (硬件)
FLOPs:s小写,指浮点运算数,理解为计算量,可以用来衡量算法/模型的复杂度,(模型)在论文中常用GFLOPs(1 GFLOPs = 10^9FLOPs)

 ShuffleNetV2网络结构:

 原理图:

其中,a、b为ShuffleNetV1原理图,c、d为ShuffleNetV2原理图。

论文题目:《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》

论文地址:  https://arxiv.org/pdf/1807.11164.pdf

代码实现:  GitHub - megvii-model/ShuffleNet-Series 


🚀2.网络结构

本文的改进是基于YOLOv8,关于其网络结构具体如下图所示:

YOLOv8官方仓库地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

针对本文的改进,作者将所使用的含有预训练权重文件的YOLOv8完整源码进行了上传,大家可在我的“资源”中自行下载。  


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:block.py文件修改

步骤2:__init__.py文件修改

步骤3:tasks.py文件修改

步骤4:创建自定义yaml文件

步骤5:新建train.py文件

步骤6:模型训练测试


🚀4.改进方法

🍀🍀步骤1:block.py文件修改

在源码中找到block.py文件,具体位置是ultralytics/nn/modules/block.py,然后将ShuffleNetV2模块代码添加到block.py文件末尾位置。

ShuffleNetV2模块代码:

# ShuffleNetv2核心代码
# By CSDN 小哥谈
import torch
import torch.nn as nn

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class CBRM(nn.Module):  # Conv BN ReLU Maxpool2d
    def __init__(self, c1, c2):
        super(CBRM, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        return self.maxpool(self.conv(x))


class Shuffle_Block(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(Shuffle_Block, self).__init__()

        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = ch_out // 2
        assert (self.stride != 1) or (ch_in == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(ch_in),

                nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        self.branch2 = nn.Sequential(
            nn.Conv2d(ch_in if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),

            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),

            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

再然后,在block.py文件最上方下图所示位置加入CBRMShuffle_Block

🍀🍀步骤2:__init__.py文件修改

在源码中找到__init__.py文件,具体位置是ultralytics/nn/modules/__init__.py

修改1:加入CBRMShuffle_Block,具体如下图所示:

修改2:加入CBRMShuffle_Block,具体如下图所示:

🍀🍀步骤3:tasks.py文件修改

在源码中找到tasks.py文件,具体位置是ultralytics/nn/tasks.py

修改1:在下图所示位置导入类名CBRMShuffle_Block

修改2:找到parse_model函数(736行左右),在下图中所示位置添加如下代码。

 # -------ShuffleNetv2------------
        elif m in [CBRM, Shuffle_Block]:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
        # --------------------------------

具体添加位置如下图所示:

🍀🍀步骤4:创建自定义yaml文件

在源码ultralytics/cfg/models/v8目录下创建yaml文件,并命名为:yolov8_ShuffleNetV2.yaml。具体如下图所示:

yolov8_ShuffleNetV2.yaml文件完整代码如下所示:

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [ -1, 1, CBRM, [ 32 ] ] # 0-P2/4
  - [ -1, 1, Shuffle_Block, [ 128, 2 ] ]  # 1-P3/8
  - [ -1, 3, Shuffle_Block, [ 128, 1 ] ]  # 2
  - [ -1, 1, Shuffle_Block, [ 256, 2 ] ]  # 3-P4/16
  - [ -1, 7, Shuffle_Block, [ 256, 1 ] ]  # 4
  - [ -1, 1, Shuffle_Block, [ 512, 2 ] ]  # 5-P5/32
  - [ -1, 3, Shuffle_Block, [ 512, 1 ] ]  # 6


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 3], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 9

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 2], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 12 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 15 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 6], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 18 (P5/32-large)

  - [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
🍀🍀步骤5:新建train.py文件

在源码根目录下新建train.py文件,文件完整代码如下所示:

from ultralytics import YOLO

# Load a model
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml').load('yolov8n.pt')  # build from YAML and transfer weights

# Train the model
model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml', epochs=100, imgsz=640)

注意:一定要用绝对路径,以防发生报错。

🍀🍀步骤6:模型训练测试

train.py文件,点击“运行”,在作者自制的安全帽佩戴检测数据集上,模型可以正常训练。

模型训练过程: 

模型训练结果: 

 关于本次改进所使用的安全帽佩戴检测数据集,已上传至我的“资源”中,大家可免费下载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/463208.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

centos命令history设置记录10000行

今天在操作服务器的时候&#xff0c;用history查看操作记录的时候&#xff0c;发现只能查看10条&#xff0c;这样不行啊&#xff0c;我想查看所有人对服务器操作的命令。 [rootbogon ~]# history解决办法&#xff1a; #1、找到/etc/profile文件中的histsize 把10改成10000 […

机器学习周报第33周

目录 摘要Abstract一、文献阅读1.1 论文标题1.2 论文摘要1.3 论文背景1.4 过去研究1.5 论文介绍1.5.1 论文模型1.5.2 时空交互学习模块&#xff08;Spatiotemporal Interactive Learning Module&#xff09;1.5.3 动态图推理模块&#xff08;Dynamic Graph Inference Module&am…

ISIS接口认证实验简述

默认情况下&#xff0c;ISIS接口认证通过在ISIS协议数据单元&#xff08;PDU&#xff09;中添加认证字段&#xff0c;例如&#xff1a;一个密钥或密码&#xff0c;用于验证发送方的身份。 ISIS接口认证防止未经授权的设备加入到网络中&#xff0c;并确保邻居之间的通信是可信的…

java的前缀和算法

前缀和的概念 对于一个给定的数组A&#xff0c;它的前缀和数组S中S[i]表示从第1个元素到第i个元素的总和&#xff0c;用公式表示为&#xff1a; SiA1A2A3...An 前缀和的作用 在O(1)的时间求出数组任意区间的区间和。 降低求解的复杂度 算法模板 int n10; int [] arrnew in…

vxe-table表格组件的使用已经query函数扩展

最近新项目使用vue3typescript开发后台管理系统&#xff0c;基本上展示内容一致表格的方式展示&#xff0c;所以使用vxe-table组件来开发&#xff0c;主要是为了方便使用工具栏&#xff0c;以及其他表格操作。 vxe-table 开发文档&#xff1a;https://vxetable.cn/#/table/sta…

springboot项目学习-瑞吉外卖(1)

第一天任务如下&#xff1a; 建立基本架构完成登录、退出功能 注意&#xff1a;本博客没有使用网上教程里的mybatis-plus&#xff0c;使用的是mybatis&#xff1b;数据库连接池也没有使用教程里的druid&#xff0c;使用的是spring自带的连接池 基本架构 common包&#xff1a;存…

PostMan测试文件上传

后端代码 package com.example.backend.controller;import cn.hutool.core.io.FileUtil; import cn.hutool.core.util.StrUtil; import com.example.backend.common.Result; import lombok.extern.slf4j.Slf4j; import org.springframework.web.bind.annotation.*; import org…

macOS系统中通过brew安装MongoDB

Macos 修改目录权限&#xff1a; sudo chmod -R 777 你的文件夹 本文使用homebrew进行安装简单&#xff0c;因为从官网下载安装包并手动安装需要移动安装包到合适的目录下并配置环境变量等一大堆操作后才能使用数据库&#xff08;若没有安装过brew请自行百度进行安装brew&am…

Pikachu 靶场搭建

文章目录 环境说明1 Pikachu 简介2 Pikachu 安装 环境说明 操作系统&#xff1a;Windows 10PHPStudy 版本: 8.1.1.3Apache 版本&#xff1a;2.4.39MySQL 版本 5.7.26 1 Pikachu 简介 Pikachu是一个使用“PHP MySQL” 开发、包含常见的Web安全漏洞、适合Web渗透测试学习人员练…

python3GUI--qt仿暴风影音视频播放器By:PyQt5(附下载地址)

文章目录 一&#xff0e;前言二&#xff0e;环境1.开发环境2.打包环境3.运行环境 三&#xff0e;软件截图1.启动页2.视频播放3.音频播放4.其他1.托盘2.对话框 四&#xff0e;功能总览五&#xff0e;代码展示&心得1.UI设计2.如何防止卡顿3.如何自定义组件 五&#xff0e;思考…

特殊内齿轮加工的另一种选择

内齿轮加工普遍采用插齿或拉削&#xff0c;但对于一些特殊齿廓的内齿轮来说&#xff0c;插齿可能会有一定的困难&#xff0c;或者成本较高。在这种情况下&#xff0c;线切割加工不失为一种不错的选择。那么什么样的零件需要选择这种加工方式呢&#xff1f;一起来看看&#xff1…

政安晨:【深度学习处理实践】(九)—— Transformer架构

咱们接着这个系列的上一篇文章继续&#xff1a; 政安晨&#xff1a;【深度学习处理实践】&#xff08;八&#xff09;—— 表示单词组的两种方法&#xff1a;集合和序列https://blog.csdn.net/snowdenkeke/article/details/136762323 Transformer是一种架构&#xff0c;用于在…

C#,图论与图算法,图(Graph)的数据结构设计与源代码

因为后面即将发布的大量有关“图”的算法与源代码都需要用到下面的这些基础数据&#xff0c;为避免大家去下载&#xff0c;特意先发布于此。 一、图&#xff08;Graph&#xff09;的基础知识 图&#xff08;Graph&#xff09;是一组对象的图示&#xff0c;其中一些对象对通过链…

zabbix企业微信接入结合海螺问问编写的shell脚本

前言 博客懒得写详细了&#xff0c;视频剪的累死了&#xff0c;看视频就好了 白帽小丑的个人空间-白帽小丑个人主页-哔哩哔哩视频 shell脚本 #!/bin/bash #set -x CorpID"" #我的企业下面的CorpID Secret"" #创建的应用那…

web canvas系列——快速入门上手绘制二维空间点、线、面

文章目录 ⭐前言⭐基本用法&#x1f496;设置一个 canvas 2D 上下文&#x1f496;绘制矩形常用方法属性&#x1f496;绘制一个红蓝交替的矩形 &#x1f496;绘制路径常用方法属性&#x1f496;画一个点&#x1f496;画一条线&#x1f496;画一个三角形面&#x1f496;画一个笑脸…

Nginx高级技术: 代理缓存配置

一、缓存说明 Nginx缓存&#xff0c;Nginx 提供了一个强大的反向代理和 HTTP 服务器功能&#xff0c;同时也是一个高效的缓存服务器。一般情况下系统用到的缓存有以下三种&#xff1a; 1、服务端缓存&#xff1a;缓存存在后端服务器&#xff0c;如 redis。 2、代理缓存&#…

【QT入门】VS2019+QT的开发环境配置

声明&#xff1a;该专栏为本人学习Qt知识点时候的笔记汇总&#xff0c;希望能给初学的朋友们一点帮助(加油&#xff01;) 往期回顾&#xff1a; 【QT入门】什么是qt&#xff0c;发展历史&#xff0c;特征&#xff0c;应用&#xff0c;QtCreator-CSDN博客【QT入门】Windows平台下…

【Vue】Request模块 - axios 封装Vuex的持久化存储

&#x1f4dd;个人主页&#xff1a;五敷有你 &#x1f525;系列专栏&#xff1a;Vue ⛺️稳中求进&#xff0c;晒太阳 Request模块 - axios 封装 使用axios来请求后端接口&#xff0c;一般会对axios进行一些配置&#xff08;比如配置基础地址&#xff0c;请求响应拦截器…

金鸣表格文字识别大师:解决医学文档PDF生僻字识别难题的利器

在医学领域&#xff0c;文档资料常常涉及到大量的专业术语和生僻字&#xff0c;例如唑吡坦、哌替啶、氟桂利嗪等。这些专业词汇对于非专业人士来说可能较为陌生&#xff0c;但在医学研究和临床实践中却具有不可或缺的重要性。然而&#xff0c;当这些生僻字出现在PDF文档中&…

Rust学习02:推荐一本入门书,免费的

都说Rust的学习曲线很陡峭&#xff0c;试过才知雀实不容易。 先说我的基础&#xff0c;非科班&#xff0c;自学Python&#xff0c;写过几个小程序。 我买书从来不扣扣嗖嗖的&#xff0c;所以先啃了几本Rust的入门书&#xff0c;包括&#xff1a; Tim McNamara的《Rust实战》&am…