【PyTorch单点知识】神经元网络模型剪枝prune模块介绍(上)

文章目录

      • 0. 前言
      • 1. 剪枝`prune`主要功能分类
      • 2. `torch.nn.utils.prune`中的方法介绍
      • 3. PyTorch实例
        • 3.1 `BasePruningMethod`
        • 3.2`PruningContainer`
        • 3.3 `Identity`
        • 3.4`RandomUnstructured`
        • 3.5`L1Unstructured`
      • 4. 总结

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

PyTorch中的torch.nn.utils.prune模块是一个专门用于神经网络模型剪枝的工具集。模型剪枝是一种减少神经网络参数数量的技术,其目标是在保持模型性能的同时减少计算成本内存占用。这对于部署模型到资源受限的设备(如移动设备或嵌入式系统)特别有用。

本文将通过实例介绍torch.nn.utils.prune模块中的各个方法,由于内容较多分为上、下两篇。

1. 剪枝prune主要功能分类

torch.nn.utils.prune模块提供了一系列的剪枝方法,包括但不限于:

  1. 无结构剪枝:这种剪枝方法可以独立地移除网络中的权重,而不考虑权重之间的结构关系。例如,L1UnstructuredRandomUnstructured 就是两种无结构剪枝方法,它们分别根据权重的绝对值大小和随机选择的方式移除权重。

  2. 结构化剪枝:与无结构剪枝相反,结构化剪枝会移除整个的结构单位(如整个神经元或通道),而不是单独的权重。RandomStructuredLnStructured 就是这样的例子,它们可以移除整个的通道。

  3. 自定义剪枝CustomFromMask 方法允许用户自定义剪枝策略,通过提供一个掩码来指定哪些权重应该被保留或移除。

  4. 剪枝管理:除了剪枝方法本身,torch.nn.utils.prune还提供了工具来管理和应用剪枝,例如,prune.global_unstructuredprune.remove 方法。前者允许跨多个层执行全局剪枝,而后者则用于移除剪枝操作,恢复原始权重或应用剪枝掩码。

2. torch.nn.utils.prune中的方法介绍

下面是本文将介绍的torch.nn.utils.prune中的方法:

  • BasePruningMethod: 抽象基类,用于创建新的剪枝类。
  • PruningContainer:允许组合多种不同的剪枝策略,并按顺序应用这些策略。
  • Identity: 实现了一个不剪枝任何单元仅生成一个全为一的掩码的实用剪枝方法。
  • RandomUnstructured: 随机剪枝张量中的单元。
  • L1Unstructured: 根据L1范数(绝对值)剪枝张量中的单元。

3. PyTorch实例

为了介绍这些剪枝方法,我们将首先定义一个简单的模型,并使用torch.nn.utils.prune模块中的各种剪枝方法来处理这个模型的权重。我们将以一个简单的卷积层为例,然后应用上述提到的每种剪枝方法。

首先,让我们导入必要的库并定义一个包含单个卷积层的模型:

import torch
import torch.nn as nn
from torch.nn.utils import prune

torch.manual_seed(888)
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        # 创建一个简单的卷积层
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3)

model = SimpleModel()

这里有一个值得注意的地方就是prune的导入:如果不写from torch.nn.utils import prune,而直接在代码中使用torch.nn.utils.prune.xxxx(),会报错↓
在这里插入图片描述
这个报错我不太能理解,不知道会不会在后续版本中更正。

接下来,我们将逐一介绍并应用每种剪枝方法:

3.1 BasePruningMethod

这是一个抽象类,可以理解为自定义剪枝的类。

class BasePruningMethod(ABC):
    r"""Abstract base class for creation of new pruning techniques.

    Provides a skeleton for customization requiring the overriding of methods
    such as :meth:`compute_mask` and :meth:`apply`.
    """
3.2PruningContainer

一开始我觉得这个方法和nn.Sequential差不多,但是实际并不是!

PruningContainer通常不会直接由用户实例化,而是作为torch.nn.utils.prune中其他剪枝方法的基础。当调用如L1UnstructuredRandomUnstructuredLnStructured等剪枝方法时,内部会创建一个PruningContainer实例,并且将特定的剪枝方法添加到容器中。

3.3 Identity

这个方法不会剪枝(改变)任何权重,它只会生成一个全为1的掩码。

print("Weight before Identity pruning:")
print(model.conv.weight)
prune.identity(model.conv, name="weight")
print("Weight after Identity pruning:")
print(model.conv.weight)
print("mask:")
print(model.conv.weight_mask)

输出为:

Weight before Identity pruning:
Parameter containing:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]]], requires_grad=True)
Weight after Identity pruning:
tensor([[[[-0.3017,  0.1290, -0.2468],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.1403]]]], grad_fn=<MulBackward0>)
mask:
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])
3.4RandomUnstructured

这个方法会随机选择权重值进行剪枝:

prune.random_unstructured(model.conv, name="weight", amount=0.5) 
#amount参数指定的是要被剪枝(即置零)的权重比例。
print("Weight after RandomUnstructured pruning (50%):")
print(model.conv.weight)

输出为:

Weight after RandomUnstructured pruning (50%):
tensor([[[[-0.0000,  0.0000, -0.0000],
          [ 0.2107,  0.1799,  0.1923],
          [ 0.1887, -0.0527,  0.0000]]]], grad_fn=<MulBackward0>)

可以明显看出,对比3.3节的输出结果,有4个(50%)参数被剪枝(置零)了。

3.5L1Unstructured

这个方法会根据权重的L1范数选择要剪枝的权重。

prune.l1_unstructured(model.conv, name="weight", amount=0.5)
print("Weight after L1Unstructured pruning (50%):")
print(model.conv.weight)

输出为:

Weight after L1Unstructured pruning (50%):
tensor([[[[-0.3017,  0.0000, -0.2468],
          [ 0.2107,  0.0000,  0.1923],
          [ 0.1887, -0.0000,  0.0000]]]], grad_fn=<MulBackward0>)

Process finished with exit code 0

对比3.3输出的结果,可以看出L1范数(绝对值)最小的4个(50%)参数被剪枝(置零)了。

4. 总结

本文介绍了PyTorch中的prune模型剪枝模块的作用及部分方法的用法,后续将把所有方法的用法补齐!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/737712.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【AI大模型】驱动的未来:穿戴设备如何革新血液、皮肤检测与营养健康管理

文章目录 1. 引言2. 现状与挑战3. AI大模型与穿戴设备概述4. 数据采集与预处理4.1 数据集成与增强4.2 数据清洗与异常检测 5. 模型架构与训练5.1 高级模型架构5.2 模型训练与调优 6. 个性化营养建议系统6.1 营养建议生成优化6.2 用户反馈与系统优化 7. 关键血液成分与健康状况评…

Win11最适合打游戏的版本推荐:畅玩游戏,告别卡顿!

在Win11电脑操作中&#xff0c;用户不仅可以进行办公、学习等操作&#xff0c;也可以畅玩喜欢的游戏。如果喜欢打游戏的用户&#xff0c;就可以安装上适合打游戏的系统版本。但许多新手用户不知道去哪里找到最适合打游戏的Win11系统版本&#xff1f;以下小编就给大家带来这样的…

el-upload 上传图片及回显照片和预览图片,文件流和http线上链接格式操作

<div v-for"(info, index) in zsjzqwhxqList.helicopterTourInfoList" :key"info.id" >编辑上传图片// oss返回线上地址http链接格式&#xff1a;<el-form-itemlabel"巡视结果照片":label-width"formLabelWidth"><el…

Appium APP测试学习

1、安装client编程库(客户端) (1)如果遇到以下问题可以使用全路径安装 (2)安装后导致selenium升级&#xff0c;导致某些方法失效&#xff1a;如find_element_by_id。解决方法&#xff1a;卸载两个安装包&#xff0c;后面重新安装 2、安装appium Server:&#xff08;服务端&…

【论文阅读】场景生成及编辑3D定位论文阅读

<div id"content_views" class"htmledit_views" style"user-select: auto;"><div class"kdocs-document"> 前置知识 归纳偏置 关于归纳偏置的理解&#xff1a;首先推荐一篇解释归纳偏置非常好的博客&#xff1a;浅谈归纳…

【PostgreSQL】AUTO_EXPLAIN - 慢速查询的日志执行计划

本文为云贝教育 刘峰 原创&#xff0c;请尊重知识产权&#xff0c;转发请注明出处&#xff0c;不接受任何抄袭、演绎和未经注明出处的转载。 一、介绍 在本文中&#xff0c;我们将了解 PostgreSQL AUTO_EXPLAIN功能的工作原理&#xff0c;以及为什么应该使用它来收集在生产系统…

高考十字路口:24年考生如何权衡专业与学校的抉择?

文章目录 每日一句正能量前言专业解析理工科专业商科专业人文社科专业艺术与设计专业个人经验与思考过程结论 名校效应分析名校声誉与品牌效应资源获取学术氛围就业优势个人发展结论 好专业和好学校的权衡个人职业目标行业需求教育质量资源和机会学术氛围就业优势经济和地理位置…

SVM算法-人脸识别背后技术详解

引言 支持向量机&#xff08;SVM&#xff09;是一种强大的监督学习算法&#xff0c;广泛应用于分类和回归任务中。本文将详细介绍SVM算法在人脸识别任务中的应用&#xff0c;并通过代码示例来展示其背后的技术精髓。我们将分三大部分来展开&#xff0c;本部分将重点介绍SVM算法…

SpringBoot2+Vue3开发博客管理系统

项目介绍 博客管理系统&#xff0c;可以帮助使用者管理自己的经验文章、学习心得、知识文章、技术文章&#xff0c;以及对文章进行分类&#xff0c;打标签等功能。便于日后的复习和回忆。 架构介绍 博客管理系统采用前后端分离模式进行开发。前端主要使用技术&#xff1a;Vu…

网管工作实践_02_IP/MAC地址管理工具

1、ipconfig命令格式及参数 ipconfig是内置于Windows的TCP/IP应用程序&#xff0c;用于显示本地计算机网络适配器的MAC地址和IP地址等配置信息&#xff0c;这些信息一般用来榆验手动配置的TCP/IP设置是否正确。当在网络中使用 DHCP服务时&#xff0c;IPConfig可以检测计算机中分…

利用定时器1产生全双工软件串口

代码; /*《AVR专题精选》随书例程3.通信接口使用技巧项目&#xff1a;使用AVR定时器1和外中断实现全双工软件串口文件&#xff1a;softuart.c说明&#xff1a;软件串口驱动文件作者&#xff1a;邵子扬时间&#xff1a;2012年12月16日*/ #include "softuart.h"// 内部…

达梦数据守护集群部署

接上篇 达梦8单机规范化部署 https://blog.csdn.net/qq_25045631/article/details/139898690 1. 集群规划 在正式生产环境中&#xff0c;两台机器建议使用统一配置的服务器。使用千兆或千兆以上网络。 两台虚拟机各加一块网卡&#xff0c;仅主机模式&#xff0c;作为心跳网卡…

【日记】原来真的有人不适合谈恋爱(1194 字)

正文 21 日正是周五&#xff0c;夏至。全年当中&#xff0c;白天时长最长的一天。而恰好那天也是银行扣息的日子。所以很忙&#xff0c;我差点没能走掉。 所幸最终还是有惊无险。 到斯的家里&#xff0c;是晚上 9 点钟。比我想得要早。这个周周四&#xff0c;他过生日。但是那天…

提升效率新选择——运营必备API合集

在数字化、自动化的时代背景下&#xff0c;运营小工具API已经成为了众多企业不可或缺的重要组成部分。这些API工具为运营工作提供了极大的便利&#xff0c;不仅简化了工作流程&#xff0c;还提高了工作效率。 运营小工具API是针对运营工作特定需求而设计的接口&#xff0c;通过…

7.系统工具——黑马程序员Java最新AI+若依框架项目

目录 前言一、表单构建任务&#xff1a;设计添加课程表单 二、 代码生成1.任务&#xff1a;将部门表在页面端显示改为树形结构 三、系统接口任务&#xff1a;使用sagger进行接口测试 前言 提示&#xff1a;本篇讲解若依框架 系统工具 一、表单构建 功能&#xff1a;完成前端…

【持续集成_01课_Git版本管理及基本应用】

一、什么是持续集成 测试代码、开发代码、测试报告、发送邮件...组合到一起来 --- 流水线。 &#xff08;有一个规范的流程&#xff09; 持续集成&#xff08;Continuous Integration&#xff0c;简称CI&#xff09; 是一种软件开发实践&#xff0c;旨在将代码集成到共享仓库…

网页抓取和网页爬取之间有何区别?

随着互联网的发展和信息的爆炸式增长&#xff0c;数据收集和处理已成为企业和个人不可或缺的需求。在此背景下&#xff0c;网页抓取和网络爬虫已成为两种常见的数据收集方法。虽然这两种方法看似相似&#xff0c;但它们的方法和目标存在显著差异。本文将为您详细介绍网页抓取和…

【科技前沿】电子设计新贵SmartEDA:为何它引领行业风潮?

在当今这个电子科技日新月异的时代&#xff0c;电子设计工具如同设计师的魔法棒&#xff0c;不断推动着产品创新的速度。而近期&#xff0c;一款名为SmartEDA的电子国产设计仿真软件异军突起&#xff0c;成为了行业内的新宠。那么&#xff0c;SmartEDA究竟有何过人之处&#xf…

人工和AI大语言模型成本对比 ai语音模型

这里既有AI&#xff0c;又有生活大道理&#xff0c;无数渺小的思考填满了一生。 上一专题搭建了一套GMM-HMM系统&#xff0c;来识别连续0123456789的英文语音。 但若不是仅针对数字&#xff0c;而是所有普通词汇&#xff0c;可能达到十几万个词&#xff0c;解码过程将非常复杂…

# bash: chkconfig: command not found 解决方法

bash: chkconfig: command not found 解决方法 一、chkconfig 错误描述&#xff1a; 这个错误表明在 Bash 环境下&#xff0c;尝试执行 chkconfig 命令&#xff0c;但是系统找不到这个命令。chkconfig 命令是一个用于管理 Linux 系统中服务的启动和停止的工具&#xff0c;通常…