主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2(包括完整代码+添加步骤+网络结构图)

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力!~🌈  

     目录

🚀1. 基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:block.py文件修改

🍀🍀步骤2:__init__.py文件修改

🍀🍀步骤3:tasks.py文件修改

🍀🍀步骤4:创建自定义yaml文件

🍀🍀步骤5:新建train.py文件

🍀🍀步骤6:模型训练测试

🚀1. 基础概念

ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。

ShuffleNetV2的主要特点包括:

  1. 分组卷积:通过将输入通道分成多个组,并在组内进行卷积操作,减少了计算量和参数数量。
  2. 逐点卷积:使用1x1的卷积核进行逐点卷积,用于调整通道数和特征图的维度。
  3. 通道重排:通过将输入特征图按通道进行重排,实现信息的混洗和交互,增强了特征的表达能力。
  4. 瓶颈结构:采用瓶颈结构,即先降维再升维,减少了计算量和参数数量。
  5. 网络设计:ShuffleNet V2通过堆叠多个ShuffleNet单元来构建整个网络,可以根据任务的需求进行不同层数和宽度的配置。

ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力。

shuffleNetV2这篇论文比较硬核,提出了不少新的思想,推荐大家可以看看论文原文。主要思想包括:

  • 模型的计算复杂度不能只看FLOPs,还需要参考一些其他的指标
  • 作者提出了4条如何设计高效网络的准则
  • 基于该准则提出了新的block设置

FLOPS网上有两种:FLOPS和 FLOPs

FLOPS:全大写,指每秒浮点运算次数,可以理解为计算的速度,是衡量硬件性能的一个指标 (硬件)
FLOPs:s小写,指浮点运算数,理解为计算量,可以用来衡量算法/模型的复杂度,(模型)在论文中常用GFLOPs(1 GFLOPs = 10^9FLOPs)

 ShuffleNetV2网络结构:

 原理图:

其中,a、b为ShuffleNetV1原理图,c、d为ShuffleNetV2原理图。

论文题目:《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》

论文地址:  https://arxiv.org/pdf/1807.11164.pdf

代码实现:  GitHub - megvii-model/ShuffleNet-Series 


🚀2.网络结构

本文的改进是基于YOLOv8,关于其网络结构具体如下图所示:

YOLOv8官方仓库地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

针对本文的改进,作者将所使用的含有预训练权重文件的YOLOv8完整源码进行了上传,大家可在我的“资源”中自行下载。  


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:block.py文件修改

步骤2:__init__.py文件修改

步骤3:tasks.py文件修改

步骤4:创建自定义yaml文件

步骤5:新建train.py文件

步骤6:模型训练测试


🚀4.改进方法

🍀🍀步骤1:block.py文件修改

在源码中找到block.py文件,具体位置是ultralytics/nn/modules/block.py,然后将ShuffleNetV2模块代码添加到block.py文件末尾位置。

ShuffleNetV2模块代码:

# ShuffleNetv2核心代码
# By CSDN 小哥谈
import torch
import torch.nn as nn

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class CBRM(nn.Module):  # Conv BN ReLU Maxpool2d
    def __init__(self, c1, c2):
        super(CBRM, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        return self.maxpool(self.conv(x))


class Shuffle_Block(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(Shuffle_Block, self).__init__()

        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = ch_out // 2
        assert (self.stride != 1) or (ch_in == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(ch_in),

                nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        self.branch2 = nn.Sequential(
            nn.Conv2d(ch_in if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),

            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),

            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

再然后,在block.py文件最上方下图所示位置加入CBRMShuffle_Block

🍀🍀步骤2:__init__.py文件修改

在源码中找到__init__.py文件,具体位置是ultralytics/nn/modules/__init__.py

修改1:加入CBRMShuffle_Block,具体如下图所示:

修改2:加入CBRMShuffle_Block,具体如下图所示:

🍀🍀步骤3:tasks.py文件修改

在源码中找到tasks.py文件,具体位置是ultralytics/nn/tasks.py

修改1:在下图所示位置导入类名CBRMShuffle_Block

修改2:找到parse_model函数(736行左右),在下图中所示位置添加如下代码。

 # -------ShuffleNetv2------------
        elif m in [CBRM, Shuffle_Block]:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
        # --------------------------------

具体添加位置如下图所示:

🍀🍀步骤4:创建自定义yaml文件

在源码ultralytics/cfg/models/v8目录下创建yaml文件,并命名为:yolov8_ShuffleNetV2.yaml。具体如下图所示:

yolov8_ShuffleNetV2.yaml文件完整代码如下所示:

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [ -1, 1, CBRM, [ 32 ] ] # 0-P2/4
  - [ -1, 1, Shuffle_Block, [ 128, 2 ] ]  # 1-P3/8
  - [ -1, 3, Shuffle_Block, [ 128, 1 ] ]  # 2
  - [ -1, 1, Shuffle_Block, [ 256, 2 ] ]  # 3-P4/16
  - [ -1, 7, Shuffle_Block, [ 256, 1 ] ]  # 4
  - [ -1, 1, Shuffle_Block, [ 512, 2 ] ]  # 5-P5/32
  - [ -1, 3, Shuffle_Block, [ 512, 1 ] ]  # 6


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 3], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 9

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 2], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 12 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 15 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 6], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 18 (P5/32-large)

  - [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
🍀🍀步骤5:新建train.py文件

在源码根目录下新建train.py文件,文件完整代码如下所示:

from ultralytics import YOLO

# Load a model
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml').load('yolov8n.pt')  # build from YAML and transfer weights

# Train the model
model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml', epochs=100, imgsz=640)

注意:一定要用绝对路径,以防发生报错。

🍀🍀步骤6:模型训练测试

train.py文件,点击“运行”,在作者自制的安全帽佩戴检测数据集上,模型可以正常训练。

模型训练过程: 

模型训练结果: 

 关于本次改进所使用的安全帽佩戴检测数据集,已上传至我的“资源”中,大家可免费下载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/462014.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

THM学习笔记—Simple CTF

nmap扫描&#xff0c;发现2222端口很奇怪啊&#xff0c;重新换一种方式扫描2222端口 发现是ssh 先用ftp试试&#xff0c;尝试匿名登录 下载所有文件 发现只有一个ForMitch.txt&#xff0c;告诉我们其账号密码为弱密码&#xff0c;我们猜测Mitch为其用户名&#xff0c;尝试暴力…

力扣112、113、101--树

112. 路径总和 题目描述&#xff1a; 给你二叉树的根节点 root 和一个表示目标和的整数 targetSum 。 判断该树中是否存在 根节点到叶子节点 的路径&#xff0c;这条路径上所有节点值相加等于目标和 targetSum 。 如果存在&#xff0c;返回 true &#xff1b;否则&#xff0c…

体系班第十七节(经典递归)

1汉诺塔 从左移到最右&#xff0c;圆盘必须满足小压大原则 写一个大方法&#xff0c;大方法包括两步&#xff1a;第一步将最后一个圆盘上面的所有的放到第二个塔上面&#xff0c;然后将最后一个圆盘放到最后塔上面&#xff0c;再把第二个塔上面圆盘全放在第三个塔上面 #incl…

一文搞懂分布式事务解决方案

前言 在当今的分布式系统中&#xff0c;分布式事务管理是一个关键挑战。在面对跨多个服务的复杂业务流程时&#xff0c;确保数据一致性和事务的原子性变得至关重要。本文将深入探讨分布式事务的概念、原理、实现方式以及在Java领域的应用。 什么是分布式事务 分布式事务是指涉…

leetcode代码记录(动态规划基础题(斐波那契数列)

目录 1. 题目&#xff1a;2. 斐波那契数列&#xff1a;小结&#xff1a; 1. 题目&#xff1a; 斐波那契数 &#xff08;通常用 F(n) 表示&#xff09;形成的序列称为 斐波那契数列 。该数列由 0 和 1 开始&#xff0c;后面的每一项数字都是前面两项数字的和。也就是&#xff1a…

【Spring Boot】创建你的第一个 Spring Boot 应用

创建你的第一个 Spring Boot 应用 1.环境配置2.步骤详解3.项目结构分析3.1 入口类 DemoApplication3.2 控制器 PathVariableController3.3 控制器 BasicController3.4 模型 User 4.运行 Spring Boot 目前已经成为了 Java 开发领域的框架范式。本篇博客&#xff0c;我将带领大家…

c语言:操作符详解(上)

目录 一、操作符的分类二、二进制和进制转换1.2进制转10进制2.10进制转2进制3.2进制转8进制4.2进制转16进制 三、原码、反码、补码四、算术操作符、-、*、/、%1.**和-**2.*3./4.% 五、移位操作符1.左移操作符2.右移操作符 六、位操作符&#xff1a;&、|、^、~七、赋值操作符…

【Linux进程状态】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 一、直接谈论Linux的进程状态 看看Linux内核源代码怎么说 1.1、R状态 -----> 进程运行的状态 1.2、S状态 -----> 休眠状态(进程在等待“资源”就绪) 1.3、T状…

想兼职赚钱?盘点6个靠谱兼职,赚钱更轻松!

1&#xff0c;微头条搬砖 微头条搬砖是一个门槛不高的赚钱方式&#xff0c;而且不需要你有多么好的原创能力&#xff0c;去收集一些热门文章的素材进行文章伪原创&#xff0c;十分钟就能搞定&#xff0c;只要你的文章有爆点&#xff0c;足够吸人眼球&#xff0c;就能够获取不低…

liunx离线安装mysql

liunx离线安装mysql 一.安装二.添加系统mysql组和mysql用户三.创建并修改mysql数据目录四.修改目录权限五.初始化数据库如果报错关于libaio.so.1六.修改权限为root七.添加启动服务**八.** ***\*登录数据库\**** 博主 默语带您 Go to New World. ✍ 个人主页—— 默语 的博客&a…

OpenHarmony 鸿蒙操作系统初体验

一、测试环境 硬件&#xff1a;野火开发板 RK3568 &#xff0c;LubanCat2-N 系统&#xff1a;OpenHarmony_3.2.3 二、驱动安装 下载驱动 Rockchip_DriverAssitant_v5.1.1&#xff0c;并安装。 安装完成 三、镜像烧录 准备工作 注&#xff1a;仅支持烧录镜像到EMMC&…

Arthas使用案例(二)

说明&#xff1a;记录一次使用Arthas排查测试环境正在运行的项目BUG&#xff1b; 场景 有一个定时任务&#xff0c;该定时任务是定时去拉取某FTP服务器上的文件&#xff0c;进行备份、读取、解析等一系列操作。 而现在&#xff0c;因为开发环境是Windows&#xff0c; 线上项…

代码随想录 -- 回溯算法

文章目录 回溯算法理论什么是回溯法回溯法的效率回溯法解决的问题理解回溯法回溯法模板 组合问题I描述题解优化 组合总和III描述题解 电话号码的字母组合描述题解 组合总和描述题解 组合总和II描述题解 分割回文串描述题解 复原IP地址描述题解 子集描述题解 子集II描述题解 递增…

Linux 学习笔记(16)

十六、 计划任务 在很多时候为了自动化管理系统&#xff0c;我们都会用到计划任务&#xff0c;比如关机&#xff0c;管理&#xff0c;备份之类的操作&#xff0c;我 们都可以使用计划任务来完成&#xff0c;这样可以是管理员的工作量大大降低&#xff0c;而且可靠度更好。 l…

软件测试之学习测试用例的设计(等价类法、边界值法、错误猜测法、场景法、因果图法、正交法)

1. 测试用例的概念 软件测试人员向被测试系统提供的一组数据的集合&#xff0c;包括 测试环境、测试步骤、测试数据、预期结果 2. 为什么在测试前要设计测试用例 测试用例是执行测试的依据 在回归测试的时候可以进行复用 是自动化测试编写测试脚本的依据 衡量需求的覆盖率…

通过简单的案例入门Mybatis~

目录 一.概述 二.JDBC的缺点 三.案例 1.创建测试类 2.加载Mybatis核心配置文件获取SqlSessionFactory 3.获取SqlSession对象 4.执行sql 5.释放资源 一.概述 Mybatis是一款持久层框架&#xff0c;用于简化JDBC开发。所谓框架&#xff0c;就是一个半成品软件&#xff0c;…

2024年【P气瓶充装】考试资料及P气瓶充装考试总结

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 2024年P气瓶充装考试资料为正在备考P气瓶充装操作证的学员准备的理论考试专题&#xff0c;每个月更新的P气瓶充装考试总结祝您顺利通过P气瓶充装考试。 1、【多选题】CNG双燃料汽车系统主要包括&#xff08;&#xff…

数据链路层_以太网

IP协议确定数据跨网络从主机A到主机B的路径&#xff0c;即IP协议解决了路径选择问题&#xff0c;但在这之前&#xff0c;必须先解决数据在一个子网内的传输的问题。跨网络的本质就是跨多个子网&#xff0c;只要一个子网内可以通信&#xff0c;那么便可以跨网络通信。 一.以太…

Echarts+Vue 首页大屏静态示例Demo 第四版 支持自适应

效果: 源码: <template><ScaleScreenclass="scale-wrap":selfAdaption="true":autoScale="true":class="{ fullscreen-container: isFullScreen }"><div class="bg"><dv-loading v-if="loading&…

Docker 中 MySQL 的部署与管理

目录 一、Docker 中部署 MySQL1.1 部署 MySQL1.2 进入容器并创建数据库1.3 Navicat 可视化工具连接 二、可能存在的问题2.1 1130 - Host ‘172.17.0.1‘ is not allowed to connect to this MySQL server 参考资料 一、Docker 中部署 MySQL 1.1 部署 MySQL 首先&#xff0c;从…