Yolov8 目标检测剪枝学习记录

最近在进行YOLOv8系列的轻量化,目前在网络结构方面的优化已经接近极限了,所以想要学习一下模型剪枝是否能够进一步优化模型的性能
这里主要参考了torch-pruning的基本使用,v8模型剪枝,Jetson nano部署剪枝YOLOv8
下面只是记录一个简单流程,用于后续使用在自己的任务和网络中,数据不作为参考

首先训练一个base模型用于参考

  • 环境:Ultralytics YOLOv8.2.18 🚀 Python-3.10.14 torch-2.4.0 CUDA:0 (NVIDIA H100 PCIe, 81008MiB)
  • 训练代码

参考网上或者自己写一个能训练即可,为了方便我将通用的记录下来,实测可用来自代码来源

from ultralytics import YOLO
import os

root = os.getcwd()
## 配置文件路径
name_yaml             = os.path.join(root, "ultralytics/datasets/VOC.yaml")
name_pretrain         = os.path.join(root, "yolov8s.pt")
## 原始训练路径
path_train            = os.path.join(root, "runs/detect/VOC")
name_train            = os.path.join(path_train, "weights/last.pt")
## 约束训练路径、剪枝模型文件
path_constraint_train = os.path.join(root, "runs/detect/VOC_Constraint")
name_prune_before     = os.path.join(path_constraint_train, "weights/last.pt")
name_prune_after      = os.path.join(path_constraint_train, "weights/last_prune.pt")
## 微调路径
path_fineturn         = os.path.join(root, "runs/detect/VOC_finetune")

def else_api():
    path_data = ""
    path_result = ""
    model = YOLO(name_pretrain) 
    metrics = model.val()  # evaluate model performance on the validation set
    model.export(format='onnx', opset=11, simplify=True, dynamic=False, imgsz=640)
    model.predict(path_data, device="0", save=True, show=False, save_txt=True, imgsz=[288,480], save_conf=True, name=path_result, iou=0.5)  # 这里的imgsz为高宽

def step1_train():
    model = YOLO(name_pretrain) 
    model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_train)  # train the model

## 2024.3.4添加【amp=False】
def step2_Constraint_train():
    model = YOLO(name_train) 
    model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, amp=False, workers=16, save_period=1,name=path_constraint_train)  # train the model

def step3_pruning():
    from LL_pruning import do_pruning
    do_pruning(name_prune_before, name_prune_after)

def step4_finetune():
    model = YOLO(name_prune_after)     # load a pretrained model (recommended for training)
    model.train(data=name_yaml, device="0,1", imgsz=640, epochs=50, batch=32, workers=16, save_period=1, name=path_fineturn)  # train the model

step1_train()
# step2_Constraint_train()
# step3_pruning()
# step4_finetune()

第一步,step1_train()

  • 即训练一个base模型,用于最后性能好坏的重要参考
    在这里插入图片描述

第二步,step2_Constraint_train()

训练之前在ultralytics\engine\trainer.py添加bn的L1正则,使得bn参数在训练时变得稀疏

  • 通过对参数的绝对值进行惩罚,使得一些不重要的权重变为零,从而实现模型的稀疏化和简化
     # Backward
     self.scaler.scale(self.loss).backward()
     ## add new code=============================duj
     ## add l1 regulation for step2_Constraint_train               
     l1_lambda = 1e-2 * (1 - 0.9 * epoch / self.epochs)
     for k, m in self.model.named_modules():
         if isinstance(m, nn.BatchNorm2d):
             m.weight.grad.data.add_(l1_lambda * torch.sign(m.weight.data))
             m.bias.grad.data.add_(1e-2 * torch.sign(m.bias.data))

     # Optimize - https://pytorch.org/docs/master/notes/amp_examples.html
     if ni - last_opt_step >= self.accumulate:
         self.optimizer_step()
         last_opt_step = ni

在这里插入图片描述

  • 个人理解的稀疏化作用
    • 通过对 gamma 和 beta 添加 L1 正则化,可以促使某些通道的 BN 权重变得非常小,甚至为零。这意味着在剪枝时,可以将这些通道从模型中移除
    • 通过稀疏化 BN 层并剪除不重要的通道,剩下的通道会更有效地利用计算资源,减少无用计算。

第三步,step3_pruning()剪枝操作

LL_pruning.py

from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
import os


class PRUNE():
    def __init__(self) -> None:
        self.threshold = None

    def get_threshold(self, model, factor=0.8):
        ws = []
        bs = []
        for name, m in model.named_modules():
            if isinstance(m, torch.nn.BatchNorm2d):
                w = m.weight.abs().detach()
                b = m.bias.abs().detach()
                ws.append(w)
                bs.append(b)
                print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())
                print()
        # keep
        ws = torch.cat(ws)
        self.threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]

    def prune_conv(self, conv1: Conv, conv2: Conv):
        ## a. 根据BN中的参数,获取需要保留的index================
        gamma = conv1.bn.weight.data.detach()
        beta  = conv1.bn.bias.data.detach()
        
        keep_idxs = []
        local_threshold = self.threshold
        while len(keep_idxs) < 8:  ## 若剩余卷积核<8, 则降低阈值重新筛选
            keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]
            local_threshold = local_threshold * 0.5
        n = len(keep_idxs)
        # n = max(int(len(idxs) * 0.8), p)
        print(n / len(gamma) * 100)
        # scale = len(idxs) / n

        ## b. 利用index对BN进行剪枝============================
        conv1.bn.weight.data = gamma[keep_idxs]
        conv1.bn.bias.data   = beta[keep_idxs]
        conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]
        conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]
        conv1.bn.num_features = n
        conv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]
        conv1.conv.out_channels = n
        
        ## c. 利用index对conv1进行剪枝=========================
        if conv1.conv.bias is not None:
            conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]

        ## d. 利用index对conv2进行剪枝=========================
        if not isinstance(conv2, list):
            conv2 = [conv2]
        for item in conv2:
            if item is None: continue
            if isinstance(item, Conv):
                conv = item.conv
            else:
                conv = item
            conv.in_channels = n
            conv.weight.data = conv.weight.data[:, keep_idxs]
     
    def prune(self, m1, m2):
        if isinstance(m1, C2f):      # C2f as a top conv
            m1 = m1.cv2
        if not isinstance(m2, list): # m2 is just one module
            m2 = [m2]
        for i, item in enumerate(m2):
            if isinstance(item, C2f) or isinstance(item, SPPF):
                m2[i] = item.cv1
        self.prune_conv(m1, m2)
     
def do_pruning(modelpath, savepath):
    pruning = PRUNE()

    ### 0. 加载模型
    yolo = YOLO(modelpath)                  # build a new model from scratch
    pruning.get_threshold(yolo.model, 0.8)  # 获取剪枝时bn参数的阈值,这里的0.8为剪枝率。

    ### 1. 剪枝c2f 中的Bottleneck
    for name, m in yolo.model.named_modules():
        if isinstance(m, Bottleneck):
            pruning.prune_conv(m.cv1, m.cv2)

    ### 2. 指定剪枝不同模块之间的卷积核
    seq = yolo.model.model
    for i in [3,5,7,8]: 
        pruning.prune(seq[i], seq[i+1])

    ### 3. 对检测头进行剪枝
    # 在P3层: seq[15]之后的网络节点与其相连的有 seq[16]、detect.cv2[0] (box分支)、detect.cv3[0] (class分支)
    # 在P4层: seq[18]之后的网络节点与其相连的有 seq[19]、detect.cv2[1] 、detect.cv3[1] 
    # 在P5层: seq[21]之后的网络节点与其相连的有 detect.cv2[2] 、detect.cv3[2] 
    detect:Detect = seq[-1]
    last_inputs   = [seq[15], seq[18], seq[21]]
    colasts       = [seq[16], seq[19], None]
    for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):
        pruning.prune(last_input, [colast, cv2[0], cv3[0]])
        pruning.prune(cv2[0], cv2[1])
        pruning.prune(cv2[1], cv2[2])
        pruning.prune(cv3[0], cv3[1])
        pruning.prune(cv3[1], cv3[2])

    ### 4. 模型梯度设置与保存
    for name, p in yolo.model.named_parameters():
        p.requires_grad = True
     
    yolo.val()
    torch.save(yolo.ckpt, savepath)
    yolo.model.pt_path = yolo.model.pt_path.replace("last.pt", os.path.basename(savepath))
    yolo.export(format="onnx")

    ## 重新load模型,修改保存命名,用以比较剪枝前后的onnx的大小
    yolo = YOLO(modelpath)  # build a new model from scratch
    yolo.export(format="onnx")


if __name__ == "__main__":

    modelpath = "runs/detect1/14_Constraint/weights/last.pt"
    savepath  = "runs/detect1/14_Constraint/weights/last_prune.pt"
    do_pruning(modelpath, savepath)

在这里插入图片描述

  • 如下图可用看到剪枝前后还是有区别的,参数量减少很多,网络性能将不可用,需要微调恢复精度
    在这里插入图片描述
  • 查看剪枝前后模型大小 du -sh ./runs/detect/VOC_Constraint/weights/last*yolov8n模型
    在这里插入图片描述

微调

该部分内容我也存在一些疑问,例如很多博主让ultralytics\engine\trainer.py添加加载模型代码,经过我8.2版本测试代码添加是完全失效的,因为setup_model在执行if isinstance(self.model, torch.nn.Module)就已经return了。

 def setup_model(self):
        """Load/create/download model for any task."""
        if isinstance(self.model, torch.nn.Module):  # if model is loaded beforehand. No setup needed
            return
  • 例如ultralytics\engine\trainer.py
  • v8…x添加代码:548行 参考这里
self.model = self.get_model(cfg=cfg, weights=weights, verbose=RANK == -1)
# duj add code to finetune
self.model = weights
return ckpt
  • 如果是v8.0.x 参考这里

在看到这篇中的修改1启发

  • v8.2.x上面我不确定是哪个版本需要添加的,但是我实测都不起作用
  • 我尝试在ultralytics\engine\model.py添加如下代码加载模型成功
 self.trainer = (trainer or self._smart_load("trainer"))(overrides=args, _callbacks=self.callbacks)
        if not args.get("resume"):  # manually set model only if not resuming
            # self.trainer.model = self.trainer.get_model(weights=self.model if self.ckpt else None, cfg=self.model.yaml)
            # self.model = self.trainer.model
            # dujiang edit 
            self.trainer.model = self.model.train()

            if SETTINGS["hub"] is True and not self.session:
  • 这里就是确保自己加载的是剪枝后的模型,但是不同版本好像不同,后续在探究原因。。。
  • 这里有个小插曲,我在使用自己模型稀疏训练后剪枝发现(步骤2)发现BN层全没了,这里后面我将别人的稀疏训练的v8s模型拿来进行剪枝就没问题
  • 可能是v8n的问题,也可能是我训练的问题,这里先不做深究继续查看剪枝是否成功且微调加载成功后能否恢复精度
    在这里插入图片描述
  • 此时多次尝试我基本确定微调加载的是我剪枝后的模型,接下来就是等待训练结果是否参数量正确了。
    在这里插入图片描述

总结

总的来说跑通整个流程了,接下来尝试在自己的任务和数据上面进行剪枝,看看更换了模型结构又会有哪些坑等着我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/954889.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【深度学习】关键技术-激活函数(Activation Functions)

激活函数&#xff08;Activation Functions&#xff09; 激活函数是神经网络的重要组成部分&#xff0c;它的作用是将神经元的输入信号映射到输出信号&#xff0c;同时引入非线性特性&#xff0c;使神经网络能够处理复杂问题。以下是常见激活函数的种类、公式、图形特点及其应…

图数据库 | 18、高可用分布式设计(中)

上文我们聊了在设计高性能、高可用图数据库的时候&#xff0c;从单实例、单节点出发&#xff0c;一般有3种架构演进选项&#xff1a;主备高可用&#xff0c;今天我们具体讲讲分布式共识&#xff0c;以及大规模水平分布式。 主备高可用、分布式共识、大规模水平分布式&#xff…

Oracle 终止正在执行的SQL

目录 一. 背景二. 操作简介三. 投入数据四. 效果展示 一. 背景 项目中要求进行性能测试&#xff0c;需要向指定的表中投入几百万条数据。 在数据投入的过程中发现投入的数据不对&#xff0c;需要紧急停止SQL的执行。 二. 操作简介 &#x1f449;需要DBA权限&#x1f448; ⏹…

Datawhale组队学习笔记task1——leetcode面试题

文章目录 写在前面刷题流程刷题技巧 Day1题目1、0003.无重复字符的最长子串解答&#xff1a;2.00004 寻找两个正序数组的中位数解答&#xff1a;3.0005.最长回文子串解答4.0008.字符串转换整数解答&#xff1a; Day2题目1.0151.反转字符串中的单词解答2.0043.字符串相乘解答3.0…

K3二开:在工业老单工具栏增加按钮,实现打印锐浪报表

在上次实现用GridRepot报表实现打印任务单后&#xff0c;在想着能不能给将生产任务单原来要通过点击菜单栏&#xff0c;打印任务单的功能&#xff0c;在工具栏上也增加按钮实现&#xff0c;这样就不需要多点了。 原本是需要点击菜单栏才能实现的 现在在工具栏上增加按钮实现同…

[计算机网络]一. 计算机网络概论第一部分

作者申明&#xff1a;作者所有文章借助了各个渠道的图片视频以及资料&#xff0c;在此致谢。作者所有文章不用于盈利&#xff0c;只是用于个人学习。 1.0推荐动画 【网络】半小时看懂<计算机网络>_哔哩哔哩_bilibili 1.1计算机网络在信息时代的作用 在当今信息时代&…

机器学习之支持向量机SVM及测试

目录 1 支持向量机SVM1.1 概念1.2 基本概念1.3 主要特点1.4 优缺点1.5 核函数1.6 常用的核函数1.7 函数导入1.8 函数参数 2 实际测试2.1 二维可视化测试代码2.2 多维测试 1 支持向量机SVM 1.1 概念 支持向量机&#xff08;Support Vector Machine&#xff0c;简称SVM&#xff…

云服务信息安全管理体系认证,守护云端安全

在数据驱动的时代&#xff0c;云计算已成为企业业务的超级引擎&#xff0c;推动着企业飞速发展。然而&#xff0c;随着云计算的广泛应用&#xff0c;信息安全问题也日益凸显&#xff0c;如同暗流涌动下的礁石&#xff0c;时刻威胁着企业的航行安全。这时&#xff0c;云服务信息…

服务器数据恢复—Zfs文件系统数据恢复案例

服务器数据恢复环境&故障&#xff1a; 一台zfs文件系统的服务器&#xff0c;管理员误操作删除了服务器上的数据。 服务器数据恢复过程&#xff1a; 1、将故障服务器中所有硬盘做好标记后取出&#xff0c;硬件工程师检测后没有发现有硬盘存在硬件故障。以只读方式将所有硬盘…

​​​​​​​​​​​​​​★3.3 事件处理

★3.3.1 ※MouseArea Item <-- MouseArea 属性 acceptedButtons : Qt::MouseButtons containsMouse : bool 【书】只读属性。表明当前鼠标光标是否在MouseArea上&#xff0c;默认只有鼠标的一个按钮处于按下状态时才可以被检测到。 containsPress : bool curs…

GIS大模型:三维重建与建模

文章目录 数据收集预处理特征提取深度估计点云生成表面重建纹理映射大模型的角色 大模型在三维重建与建模方面&#xff0c;尤其是在处理低空地图数据时&#xff0c;展现了其强大的能力。通过使用深度学习算法&#xff0c;特别是那些基于卷积神经网络&#xff08;CNNs&#xff0…

wireshark抓路由器上的包 抓包路由器数据

文字目录 抓包流程概述设置抓包配置选项 设置信道设置无线数据包加密信息设置MAC地址过滤器 抓取联网过程 抓包流程概述 使用Omnipeek软件分析网络数据包的流程大概可以分为以下几个步骤&#xff1a; 扫描路由器信息&#xff0c;确定抓包信道&#xff1b;设置连接路由器的…

阿里云无影云电脑的使用场景

阿里云无影云电脑是一种安全、高效的云上虚拟桌面服务&#xff0c;广泛应用于多种场景&#xff0c;包括教育、企业办公、设计与视频制作、客服中心等。以下是九河云总结的无影云电脑的几个典型使用场景&#xff1a; #### 1. 教育机构 - **业务痛点**&#xff1a; - 学生实践操…

力扣 查找元素的位置

二分查找经典例题。 题目 要是只是从数组中用二分查找对应的元素&#xff0c;套一下模板一下就可以得出了&#xff0c;然后这题就在于其中会有多个目标元素&#xff0c;要用不同的方式在找到第一个元素时再做偏移。 时间复杂度&#xff1a;O(log n)&#xff0c;空间复杂度&am…

Profibus DP转Modbus TCP协议转换网关模块功能详解

Profibus DP 和 Modbus TCP 是两种不同的工业现场总线协议&#xff0c;Profibus DP 常用于制造业自动化领域&#xff0c;而 Modbus TCP 则在工业自动化和楼宇自动化等领域广泛应用。实现 Profibus DP 转 Modbus TCP 功能&#xff0c;通常需要特定的网关设备&#xff0c;以下为你…

SQL Prompt 插件

SQL Prompt 插件 注&#xff1a;SQL Prompt插件提供智能代码补全、SQL格式化、代码自动提示和快捷输入等功能&#xff0c;非常方便&#xff0c;可以自行去尝试体会。 1、问题 SSMS&#xff08;SQL Server Management Studio&#xff09;是SQL Server自带的管理工具&#xff0c…

OpenCV基础:矩阵的创建、检索与赋值

本文主要是介绍如何使用numpy进行矩阵的创建&#xff0c;以及从矩阵中读取数据&#xff0c;修改矩阵数据。 创建矩阵 import numpy as npa np.array([1,2,3]) b np.array([[1,2,3],[4,5,6]]) #print(a) #print(b)# 创建全0数组 eros矩阵 c np.zeros((8,8), np.uint8) #prin…

【Flink系列】9. Flink容错机制

9. 容错机制 在Flink中&#xff0c;有一套完整的容错机制来保证故障后的恢复&#xff0c;其中最重要的就是检查点。 9.1 检查点&#xff08;Checkpoint&#xff09; 9.1.1 检查点的保存 1&#xff09;周期性的触发保存 “随时存档”确实恢复起来方便&#xff0c;可是需要我…

于灵动的变量变幻间:函数与计算逻辑的浪漫交织(上)

大家好啊&#xff0c;我是小象٩(๑ω๑)۶ 我的博客&#xff1a;Xiao Xiangζั͡ޓއއ 很高兴见到大家&#xff0c;希望能够和大家一起交流学习&#xff0c;共同进步。 这一节我们主要来学习函数的概念&#xff0c;了解库函数中的标准库、头文件&#xff0c;了解自定义函数…

【CSS】---- CSS 实现超过固定高度后出现展开折叠按钮

1. 实现效果 2. 实现方法 使用 JS 获取盒子的高度&#xff0c;来添加对应的按钮和样式&#xff1b;使用 CSS 的浮动效果&#xff0c;参考CSS 实现超过固定高度后出现展开折叠按钮&#xff1b;使用容器查询 – container 语法&#xff1b;使用 clamp 函数进行样式判断。 3. 优…