模型剪枝及yolov5剪枝实践

文章目录

  • 1、模型剪枝
  • 1、 稀疏化训练
  • 2、模型剪枝
    • 2.1 非结构化剪枝
    • 2.2 结构化剪枝
    • 2.3 一些疑惑:
      • 2.3.1 剪枝后参数量不变?
  • 3、微调

【结构化剪枝掉点太多,不如一开始就选个小模型训练。非结构化剪枝只是checkpoint文件变小了,推理速度没有增益,精度还略微下降。总之目前没有认为模型剪枝在推理速度上有什么增益,比较鸡肋】

1、模型剪枝

模型剪枝的定义就不再赘述,剪枝就是将网络结构中对性能影响不大且权重值较小的节点去掉,从而降低模型的计算量,提高运算速率。

模型剪枝的流程:

  • 正常训练: 初始化模型并进行正常训练,获得一个初始模型。
  • 稀疏化训练: 在训练过程中引入稀疏化方法,使部分参数趋于零。
  • 剪枝操作: 根据稀疏化结果,去除零值或接近零的参数。
  • 微调(Fine-tuning): 对剪枝后的模型进行重新训练,以恢复性能。

很好理解,在剪枝时哪些节点需要被裁剪,需要依靠一个指标。稀疏化就是在训练过程中加以约束,使得某些不重要的节点的参数趋于零,为后续剪枝提供依据。

虽然某些阶段的参数趋于零,但是仍然具有一定的作用,当裁剪后可能会对模型的性能产生较大的波动。所以再通过微调重新训练,恢复模型的性能。

1、 稀疏化训练

数学层面的稀疏化: 稀疏化的核心是通过某种方式(如正则化或剪枝)将权重矩阵中的绝大部分元素置零,使得非零元素的比例非常低。引申到模型中,就是将参数趋于零的过程。

实现稀疏化训练比较简单的就是引入L1正则化,所谓的L1正则化就是‌指向量中各个元素的绝对值之和。在模型训练过程中,引入L1正则化损失的稀疏化训练损失为:
L = L o r i + λ ∣ ∣ w ∣ ∣ 1 L=L_{ori}+λ||w||_1 L=Lori+λ∣∣w1
其实就是一个正常的损失函数 L o r i L_{ori} Lori(在分类中可能是交叉熵,在检测中可能是obj和bbox损失),加上一个特征值参数的绝对和。

可以想象到,模型越稀疏,权重参数的绝对和越小,其损失函数越小。也就是损失函数会驱使模型向着参数和小,稀疏程度高的方向更新。

2、模型剪枝

剪枝分为非结构化和结构化剪枝。

2.1 非结构化剪枝

非结构化就是将某些权重值置0,但是整个模型结构不改变。据说这种只是会减小checkpoint的储存大小,对模型速度没有太大的优化【未测试】,因为原先需要计算300步,现在仍然需要计算300步。结构化测试就是删除模型结构中的某些组件,比如删除某个卷积层之类的,这种可能就只需要计算200步,可以进行加速。

pytorch提供了很多工具方便剪枝,比如对卷积层权重矩阵中某些权重值裁剪:

def prune(model, amount=0.3):
    # Prune model to requested global sparsity
    import torch.nn.utils.prune as prune
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.l1_unstructured(m, name='weight', amount=amount)  # 根据l1对m的weight进行裁剪,生成mask
            print(m.weight_mask[0])
            print(m.weight[0])
            prune.remove(m, 'weight')  # 根据mask对weight进行固化,并删除mask变量
    LOGGER.info(f'Model pruned to {sparsity(model):.3g} global sparsity')

这种方式会对卷积层的weight进行l1统计,从小到大排序将前30%的值置零。效果如下,这是模型中第一个卷积的权重,经过剪枝后,权重矩阵中有很多都被置0了,模型的大小也降低了一半。
在这里插入图片描述
在这里插入图片描述

2.2 结构化剪枝

非结构化只是对权重矩阵中某些个值置0,结构化则是可以将某些通道全部置零:

def structured_prune(model, amount=0.3):
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.ln_structured(m, name='weight', amount=amount, n=2, dim=0)  # 对输出通道剪枝
            prune.remove(m, 'weight')  # 使剪枝永久化

关于ln_structured函数的解释:ln_structured 是 PyTorch 中 torch.nn.utils.prune 模块中的一个剪枝方法,它用于进行 结构化剪枝(structured pruning)。该方法根据某种准则(例如 L2 范数、L1 范数)剪除模型中整个卷积核、通道、或其他结构化的单元,而不是单独的权重。结构化剪枝通常有助于减少模型的计算量和存储量,尤其是对于卷积神经网络(CNN)这样的模型。

经过结构化剪枝后,对(1,16,6,6)的卷积核,统计16个通道权重值,可以看到有16*0.3=4.8≈5个通道置0了。相对于非结构化剪枝,结构化剪枝在更高层面进行更大范围的剪枝。
在这里插入图片描述

2.3 一些疑惑:

2.3.1 剪枝后参数量不变?

如果使用total_params = sum(p.numel() for p in model.parameters())统计模型的参数量,会发现无论是非结构化剪枝还是结构化剪枝的数量和原始模型相比都没有下降。原因在于:1)非结构化中只是将某些单个值置0,整个模型结构不变;2)结构化剪枝将通道全部置0了,理论上可以去掉这个通道,比如上方的(1,16,6,6)卷积核可以去掉5个0值通道后,变成(1,11,6,6)。但是上方使用的pytorch提供的prune没有自动删除这些,只是用0值标记这个通道没用。所以可以用手动自行判断是否为0值通道然后删除,比如:

def structured_prune2(model, amount=0.3):
    for name, m in model.named_modules():
        if isinstance(m, nn.Conv2d):
            prune.ln_structured(m, name='weight', amount=amount, n=2, dim=0)  # 对输出通道剪枝
            prune.remove(m, 'weight')  # 使剪枝永久化

            weight = m.weight.data
            nonzero_channels = weight.abs().sum(dim=(1, 2, 3)) > 0  # 找出哪些通道仍然有效
            m.weight.data = weight[nonzero_channels]  # 只保留有效通道的权重

            if m.bias is not None:
                m.bias.data = m.bias.data[nonzero_channels]  # 更新偏置

可以看到原先(1,16,6,6)的卷积变成了(1,11,6,6),此时再看参数量从1764118降到了1240371,下降了30%
在这里插入图片描述

3、微调

经过剪枝,模型的参数多多少少都发生了变化,肯定对模型的精度是有所影响的:

# 剪枝前
P        R        mAP50    mAP50-95
0.987    0.956    0.975    0.641

# 非结构化剪枝30%后
P        R        mAP50    mAP50-95
0.989    0.935    0.965    0.617

# 结构化剪枝30%后
P        R        mAP50    mAP50-95
0        0        0        0

可见结构化剪枝把握不住,非结构化剪枝稍有下降。

非结构化剪枝后微调应该没有意义,比较稍微训练下那些0值权重,就又变成了有值权重。这和剪枝前的模型没有区别了。
结构化剪枝后的微调同理,当然如果是结构化后剪枝并移除了0值通道,修改了模型结构,那么微调就是有意义的了【那为什么不一开始就选择一个较小的模型或者16的通道直接设定为11?】

这般来看,除了非结构化剪枝以较小的精度损失为代价,换来降低模型储存大小的优势外,似乎没有特别有用的好处。

将模型剪枝为小模型,直接用知识蒸馏用大模型训练小模型,似乎效果也不错,效果还稳定一点。这样模型剪枝就更加没有什么意义乱了?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/957846.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

黑马程序员C++ P1-P40

一.注释和常量 1.多行注释:/*...............*/ ; 单行注释://.............. 2.常量:用于记录程序中不可修改的量 。定义方式:宏常量#define定义在文件上方 ;const修饰变量 3.标识符命名规则:标识符不能是关键字&a…

Airflow:BranchOperator实现动态分支控制流程

Airflow是用于编排复杂工作流的开源平台,支持在有向无环图(dag)中定义、调度和监控任务。其中一个关键特性是能够使用BranchOperator创建动态的、有条件的工作流。在这篇博文中,我们将探索BranchOperator,讨论它是如何…

怎么使用CRM软件?操作方法和技巧有哪些?

什么是CRM? 嘿,大家好!你知道吗,在当今这个数字化时代里,我们每天都在与各种各样的客户打交道。无论是大公司还是小型企业,都希望能够更好地管理这些关系并提高业务效率。这时候就轮到我们的“老朋友”——…

java开发,IDEA转战VSCODE配置(mac)

一、基本java开发环境配置 前提:已经安装了jdk、maven、vscode,且配置了环境变量 1、安装java相关的插件 2、安装spring相关的插件 3、vscode配置maven环境 打开 VsCode -> 首选项 -> 设置,也可以在setting.json文件中直接编辑&…

AI模型提示词(prompt)优化-实战(一)

一、prompt作用 用户与AI模型沟通的核心工具,用于引导模型生成特定内容、控制输出质量、调整行为模式,并优化任务执行效果,从而提升用户体验和应用效果 二、prompt结构 基本结构 角色:设定一个角色,给AI模型确定一个基…

Unreal Engine 5 C++ Advanced Action RPG 十章笔记

第十章 Survival Game Mode 2-Game Mode Test Map 设置游戏规则进行游戏玩法 生成敌人玩家是否死亡敌人死亡是否需要刷出更多 肯定:难度增加否定:玩家胜利 流程 新的游戏模式类游戏状态新的数据表来指定总共有多少波敌人生成逻辑UI告诉当前玩家的敌人波数 3-Survival Game M…

设计模式的艺术-代理模式

结构性模式的名称、定义、学习难度和使用频率如下表所示: 1.如何理解代理模式 代理模式(Proxy Pattern):给某一个对象提供一个代理,并由代理对象控制对原对象的引用。代理模式是一种对象结构型模式。 代理模式类型较多…

每日一题洛谷P1423 小玉在游泳c++

#include<iostream> using namespace std; int main() {double s;cin >> s;int n 0;double sum 0;double k 2;while (sum < s) {sum k;n;k * 0.98;}cout << n << endl;return 0; }

Python3 OS模块中的文件/目录方法六

一. 简介 前面文章简单学习了Python3中 OS模块中的文件/目录的部分函数。 本文继续来学习 OS模块中文件、目录的操作方法。 二. Python3 OS模块中的文件/目录方法 1. os.lseek() 方法、os.lstat() 方法 os.lseek() 方法用于在打开的文件中移动文件指针的位置。在Unix&#…

HTB:Heist[WriteUP]

目录 连接至HTB服务器并启动靶机 信息收集 使用rustscan对靶机TCP端口进行开放扫描 将靶机TCP开放端口号提取并保存 使用nmap对靶机TCP开放端口进行脚本、服务扫描 使用nmap对靶机TCP开放端口进行漏洞、系统扫描 使用nmap对靶机常用UDP端口进行开放扫描 使用smbclient匿…

【HarmonyOS NEXT】华为分享-碰一碰开发分享

关键词&#xff1a;鸿蒙、碰一碰、systemShare、harmonyShare、Share Kit 华为分享新推出碰一碰分享&#xff0c;支持用户通过手机碰一碰发起跨端分享&#xff0c;可实现传输图片、共享wifi等。我们只需调用系统 api 传入所需参数拉起对应分享卡片模板即可&#xff0c;无需对 U…

使用Inno Setup软件制作.exe安装包

1.下一步&#xff1a; 2. 填写 程序名字 和 版本号&#xff1a; 3.设置安装路径信息 4.添加要打包的exe和依赖文件 5.为应用程序创建关联的文件 如果不需要就直接取消勾选 6.创建快捷方式 &#xff08;1&#xff09;第一种&#xff1a;常用 &#xff08;1&#xff09;第二种&am…

CPU 缓存基础知识

并发编程首先需要简单了解下现代CPU相关知识。通过一些简单的图&#xff0c;简单的代码&#xff0c;来认识CPU以及一些常见的问题。 目录 CPU存储与缓存的引入常见的三级缓存结构缓存一致性协议MESI协议缓存行 cache line 通过代码实例认识缓存行的重要性 CPU指令的乱序执行通过…

初步搭建并使用Scrapy框架

目录 目标 版本 实战 搭建框架 获取图片链接、书名、价格 通过管道下载数据 通过多条管道下载数据 下载多页数据 目标 掌握Scrapy框架的搭建及使用&#xff0c;本文以爬取当当网魔幻小说为案例做演示。 版本 Scrapy 2.12.0 实战 搭建框架 第一步&#xff1a;在D:\pyt…

Python - itertools- pairwise函数的详解

前言&#xff1a; 最近在leetcode刷题时用到了重叠对pairwise,这里就讲解一下迭代工具函数pairwise,既介绍给大家&#xff0c;同时也提醒一下自己&#xff0c;这个pairwise其实在刷题中十分有用&#xff0c;相信能帮助到你。 参考官方讲解&#xff1a;itertools --- 为高效循…

YOLO-cls训练及踩坑记录

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 一、模型训练 二、测试 三、踩坑记录 1、推理时设置的imgsz不生效 方法一&#xff1a; 方法二&#xff1a; 2、Windows下torchvision版本问题导致报错 总结 前…

云计算、AI与国产化浪潮下DBA职业之路风云变幻,如何谋破局启新途?

引言 在近日举办的一场「云和恩墨大讲堂」直播栏目中&#xff0c;云和恩墨联合创始人李轶楠、副总经理熊军和欧冶云商数据库首席薛晓刚共同探讨了DBA的现状与未来发展。三位专家从云计算、人工智能、国产化替代等多个角度进行了深入的分析和探讨&#xff0c;为从业者提供了宝贵…

PAT甲级-1017 Queueing at Bank

题目 题目大意 银行有k个窗口&#xff0c;每个窗口只能服务1个人。如果3个窗口已满&#xff0c;就需要等待。给出n个人到达银行的时间和服务时间&#xff0c;要求计算每个人的平均等待时间。如果某个人的到达时间超过17:00:00&#xff0c;则不被服务&#xff0c;等待时间也不计…

从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑

文章目录 从零安装 LLaMA-Factory 微调 Qwen 大模型成功及所有的坑一 参考二 安装三 启动准备大模型文件 四 数据集&#xff08;关键&#xff09;&#xff01;4.1 Alapaca格式4.2 sharegpt4.3 在 dataset_info.json 中注册4.4 官方 alpaca_zh_demo 例子 999条数据, 本机微调 5分…

AI刷题-策略大师:小I与小W的数字猜谜挑战

问题描述 有 1, 2,..., n &#xff0c;n 个数字&#xff0c;其中有且仅有一个数字是中奖的&#xff0c;这个数字是等概率随机生成的。 Alice 和 Bob 进行一个游戏&#xff1a; 两人轮流猜一个 1 到 n 的数字&#xff0c;Alice 先猜。 每完成一次猜测&#xff0c;主持会大声…