【yolov8】yolov8剪枝训练流程

yolov8剪枝训练流程

流程:

  • 约束
  • 剪枝
  • 微调

一、正常训练

yolo train model=./weights/yolov8s.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=train

二、约束训练

2.1 修改YOLOv8代码:

ultralytics/yolo/engine/trainer.py
添加内容:

# Backward

self.scaler.scale(self.loss).backward()

# ========== 新增 ==========

l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
for k, m in self.model.named_modules():
    if isinstance(m, nn.BatchNorm2d):
        m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
        m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

# ========== 新增 ==========

# Optimize - https://pytorch.org/docs/master/notes/amp_examples.html

if ni - last_opt_step >= self.accumulate:
    self.optimizer_step()
    last_opt_step = ni

2.2 训练

需要注意的就是amp=False

yolo train model=prunt/train/weights/best.pt data=yolo_bvn.yaml epochs=100 amp=False project=prun name=constraint

训练完会得到一个best.pt和last.pt,推荐用last.pt

三、剪枝

上一步得到的last.pt作为剪枝对象,运行项目中的prun.py文件:

*这里的剪枝代码仅适用yolov8原模型,如有模块/模型的更改,则需要修改剪枝代码*

运行完会得到prune.pt和prune.onnx可以在netron.app网站拖入onnx文件查看是否剪枝成功了,成功的话可以看到某些通道数字为单数或者一些不规律的数字,如下图:

在这里插入图片描述

左侧为未剪枝的模型,右侧为剪枝后的模型。

关于yolov8剪枝有以下几点值得注意:

Pipeline:

    1. 为模型的BN增加L1约束,lambda用1e-2左右
    1. 剪枝模型使用的是全局阈值
    1. finetune模型时,一定要注意,此时需要去掉L1约束,最终的final的版本一定是去掉的(ultralytics/yolo/engine/trainer.py中注释)
    1. 对于yolo.model.named_parameters()循环,需要设置p.requires_gradTrue

Future work:

    1. 不能剪枝的layer,其实可以不用约束
    1. 对于低于全局阈值的,可以删掉整个module
    1. keep channels,对于保留的channels,它应该能整除n才是最合适的,否则硬件加速比较差
  • n怎么选呢?一般fp16时,n为8;int8时,n为16

四、 回调训练(finetune)

回调训练的唯一关键点就在于不让模型从yaml文件加载结构,直接加载pt文件

两种方法(因yolov8版本不同而选择不同方法):

方法一:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的443行左右

self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)  # calls Model(cfg, weights)

# ========== 新增该行代码 ==========

self.model = weights

# ========== 新增该行代码 ==========

return ckpt

方法二:

3.1 首先要把第一步约束训练的代码注释掉
3.2 修改相关代码,使模型不加载yaml文件

修改位置:yolo/engine/model.py的335行左右

if not args.get('resume'):  # manually set model only if not resuming
    # self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
    # self.model = self.trainer.model
    ######################上面两行注释掉,添加下面一行#####
    self.trainer.model = self.model.train()
    ##########################修改####################
    self.trainer.hub_session = self.session  # attach optional HUB session
3.3 修改完代码就可以进行finetun训练了

命令行输入:

yolo train model=prun/prune/weights/last_prune.pt data="yolo_bvn.yaml" amp=False epochs=100 project=prun name=finetune device=0

五、结果展示:

5.1模型大小:ONNX模型大小从42M减少到34M

在这里插入图片描述

5.2PR曲线:

正常训练约束训练100轮微调
在这里插入图片描述在这里插入图片描述在这里插入图片描述

5.3实测视频在ubuntu上检测速度:

未剪枝:平均每帧5毫秒

剪枝后:平均每帧3.7毫秒

六、问题及解决:

对剪枝完的yolov8进行finetune时遇到RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_mm)

self.proj 可能不在与 pred_dist 相同的设备上。这可能是因为 self.proj 被指定在 CPU 上,而 pred_dist 在 GPU 上(或反之)。
要解决这个问题,需要确保两个张量位于相同的设备上。可以使用 to() 方法将 self.proj 放到与 pred_dist 相同的设备上。

解决:在loss.py添加如下代码:

def bbox_decode(self, anchor_points, pred_dist):
    """Decode predicted object bounding box coordinates from anchor points and distribution."""
    if self.use_dfl:
        b, a, c = pred_dist.shape  # batch, anchors, channels
        ####添加
        device = pred_dist.device
        self.proj = self.proj.to(device)
        #####

        pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        # pred_dist = pred_dist.view(b, a, c // 4, 4).transpose(2,3).softmax(3).matmul(self.proj.type(pred_dist.dtype))
        # pred_dist = (pred_dist.view(b, a, c // 4, 4).softmax(2) * self.proj.type(pred_dist.dtype).view(1, 1, -1, 1)).sum(2)
    return dist2bbox(pred_dist, anchor_points, xywh=False)

七、参考:

7.1 【yolov8系列】 yolov8 目标检测的模型剪枝_yolov8 剪枝-CSDN博客
7.2 YOLOv8剪枝全过程-CSDN博客

7.3 剪枝与重参第七课:YOLOv8剪枝-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/588229.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

freertos入门---创建FreeRTOS工程

freertos入门—创建FreeRTOS工程 1 STM32CubeMx配置 双击运行STM32CubeMX,在首页选择“ACCESS TO MCU SELECTOR”,如下图所示:   在MCU选型界面,输入自己想要开发的芯片型号,如:STM32F103C8T6: 2 配置时钟 在“System Core”…

手机测试之-adb

一、Android Debug Bridge 1.1 Android系统主要的目录 1.2 ADB工具介绍 ADB的全称为Android Debug Bridge,就是起到调试桥的作用,是Android SDK里面一个多用途调试工具,通过它可以和Android设备或模拟器通信,借助adb工具,我们可以管理设备或手机模拟器的状态。还可以进行很多…

与Apollo共创生态:探索自动驾驶的未来蓝图

目录 引言Apollo开放平台Apollo开放平台企业生态计划Apollo X 企业自动驾驶解决方案:加速企业场景应用落地Apollo开放平台携手伙伴共创生态生态共创会员权益 个人心得与展望技术的多元化应用数据驱动的智能化安全与可靠性的重视 结语 引言 就在2024年4月19日&#x…

简约大气的全屏背景壁纸导航网源码(免费)

简约大气的全屏背景壁纸导航网模板 效果图部分代码领取源码下期更新预报 效果图 部分代码 <!DOCTYPE html> <html lang"zh-CN"> <!--版权归孤独 --> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible…

pyqt QSplitter控件

pyqt QSplitter控件 QSplitter控件效果代码 QSplitter控件 PyQt中的QSplitter控件是一个强大的布局管理器&#xff0c;它允许用户通过拖动边界来动态调整子控件的大小。这个控件对于创建灵活的、用户可定制的用户界面非常有用。 QSplitter控件可以水平或垂直地分割其包含的子…

阿里云开源大模型开发环境搭建

ModelScope是阿里云通义千问开源的大模型开发者社区&#xff0c;本文主要描述AI大模型开发环境的搭建。 如上所示&#xff0c;安装ModelScope大模型基础库开发框架的命令行参数&#xff0c;使用清华大学提供的镜像地址 如上所示&#xff0c;在JetBrains PyCharm的项目工程终端控…

交通 | 电动汽车车辆路径问题及FRVCP包的调用以及代码案例

编者按&#xff1a; 电动汽车的应用给车辆路线问题带来了更多的挑战&#xff0c;如何为给定路线行驶的电动汽车设计充电决策是一个需要解决的难题&#xff0c;本文介绍了开源python包frvcpy使用精确式算法对该问题求解。 文献解读&#xff1a;Aurelien Froger, Jorge E Mendo…

前端开发框架Vue

版权声明 本文原创作者&#xff1a;谷哥的小弟作者博客地址&#xff1a;http://blog.csdn.net/lfdfhl Vue概述 Vue.js&#xff08;简称Vue&#xff09;是由尤雨溪&#xff08;Evan You&#xff09;创建并维护的一款开源前端开发框架。Vue以其轻量级、易上手和高度灵活的特点&…

06_电子设计教程基础篇(学习视频推荐)

文章目录 前言一、基础视频1、电路原理3、模电4、高频电子线路5、电力电子技术6、数学物理方法7、电磁场与电磁波8、信号系统9、自动控制原理10、通信原理11、单片机原理 二、科普视频1、工科男孙老师2、达尔闻3、爱上半导体4、华秋商城5、JT硬件乐趣6、洋桃电子 三、教学视频1…

【openLooKeng集成Hive连接器完整过程】

【openLooKeng集成Hive连接器完整过程】 一、摘要二、正文2.1 环境说明2.2 Hadoop安装2.2.1. 准备工作2.2.2 在协调节点coordinator上进行安装hadoop2.2.3、将Hadoop安装目录分发到从节点worker2.2.4、在协调节点coordinator上启动hadoop集群2.3 MySQL安装2.4 Hive安装及基本操…

讯饶科技 X2Modbus 敏感信息泄露

讯饶科技 X2Modbus 敏感信息泄露 文章目录 讯饶科技 X2Modbus 敏感信息泄露漏洞描述影响版本实现原理漏洞复现修复建议 漏洞描述 X2Modbus是一款功能很强大的协议转换网关&#xff0c; 这里的X代表各家不同 的通信协议&#xff0c;2是To的谐音表示转换&#xff0c;Modbus就是最…

【C++题解】1035. 判断成绩等级

问题&#xff1a;1035. 判断成绩等级 类型&#xff1a;多分支结构 题目描述&#xff1a; 输入某学生成绩&#xff0c;如果 86 分以上(包括 86 分&#xff09;则输出 VERY GOOD &#xff0c;如果在 60 到 85 之间的则输出 GOOD (包括 60 和 85 )&#xff0c;小于 60 的则输出 …

MySQL数据库安装——zip压缩包形式

安装压缩包zip形式的 MySQL 8数据库 一 、先进入官网下载 https://dev.mysql.com/downloads/mysql/ 二、解压到某个文件夹 我解压到了D:\mysql\mysql8 下面 然后在这个文件夹下手动创建 my.ini 文件和 data 文件夹 my.ini 内容如下&#xff1a; 注意 basedir 和 datadi…

企业气候风险披露、报表词频、文本分析数据集合(2007-2022年)

01、数据介绍 企业气候风险披露是指企业通过一定的方式&#xff0c;将气候变化对其影响、自身采取的应对措施等信息披露出来。这有助于投资者更准确地评估企业价值&#xff0c;发现投资机会&#xff0c;规避投资风险。解企业在气候风险方面的关注度和披露情况。 可以帮助利益…

JavaEE_操作系统之进程(计算机体系,,指令,进程的概念、组成、特性、PCB)

一、冯诺依曼体系&#xff08;Von Neumann Architecture&#xff09; 现代的计算机, 大多遵守冯诺依曼体系结构 CPU 中央处理器: 进行算术运算和逻辑判断.存储器: 分为外存和内存, 用于存储数据(使用二进制方式存储)输入设备: 用户给计算机发号施令的设备.输出设备: 计算机个…

Pixelmator Pro for Mac:简洁而强大的图像编辑软件

Pixelmator Pro for Mac是一款专为Mac用户设计的图像编辑软件&#xff0c;它集简洁的操作界面与强大的功能于一身&#xff0c;为用户提供了卓越的图像编辑体验。 Pixelmator Pro for Mac v3.5.9中文激活版下载 该软件支持多种文件格式&#xff0c;包括常见的JPEG、PNG、TIFF等&…

原生IP和住宅IP有什么区别?

原生IP和住宅IP在多个方面存在显著的区别。 从定义和来源来看&#xff0c;原生IP是指未经NAT&#xff08;网络地址转换&#xff09;处理的真实、公网可路由的IP地址&#xff0c;它直接从互联网服务提供商&#xff08;ISP&#xff09;获得&#xff0c;而不是通过代理服务器或VP…

springboot3整合redis

redis在我们的日常开发中是必不可少的&#xff0c;本次来介绍使用spring boot整合redis实现一些基本的操作&#xff1b; 1、新建一个spring boot项目&#xff0c;并导入相应的依赖&#xff1b; <dependency><groupId>org.springframework.boot</groupId><…

【前端探索者:从零到精通的Web前端实战专栏】

🚀 在这个代码编织梦想的时代,Web前端作为互联网的颜值担当,正以日新月异的速度重塑数字世界。想要在前端江湖里游刃有余,你需要的不仅仅是一把锋利的剑,更是一套完整的武功秘籍!今天,我们就为你揭开【Web前端】专栏的神秘面纱,带你从菜鸟到大神,一飞冲天! 📚 专栏…

hadoop学习---基于hive的航空公司客户价值的LRFCM模型案例

案例需求&#xff1a; RFM模型的复习 在客户分类中&#xff0c;RFM模型是一个经典的分类模型&#xff0c;模型利用通用交易环节中最核心的三个维度——最近消费(Recency)、消费频率(Frequency)、消费金额(Monetary)细分客户群体&#xff0c;从而分析不同群体的客户价值。在某些…