模型优化【2】-剪枝[局部剪枝]

模型剪枝是一种常见的模型压缩技术,它可以通过去除模型中不必要的参数和结构来减小模型的大小和计算量,从而提高模型的效率和速度。在 PyTorch 中,我们可以使用一些库和工具来实现模型剪枝。

pytorch实现剪枝的思路是生成一个掩码,然后同时保存原参数、mask、新参数,如下图:
在这里插入图片描述

Pytorch实现模型剪枝的基本步骤

  1. 加载模型:我们首先需要加载一个已经训练好的模型,可以使用 PyTorch 提供的模型库或者自己训练的模型。

  2. 定义剪枝方法:我们需要定义一种剪枝方法,来决定哪些参数和结构需要被剪枝。

  3. 执行剪枝操作:我们需要执行剪枝操作,将不必要的参数和结构从模型中去除。

  4. 保存剪枝后的模型:我们需要将剪枝后的模型保存下来,以便后续使用。

pytorch 剪枝分为 局部剪枝、全局剪枝、自定义剪枝;

局部剪枝

局部剪枝是指在什么网络的单个层或局部范围内进行剪枝。
Pytorch中与剪枝有关的接口封装在torch.nn.utils.prune中。下面开始演示三种剪枝在LeNet网络中的应用效果,首先给出LeNet网络结构。

加载模型

import torch
from torch import nn

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1: 图像的输入通道(1是黑白图像), 6: 输出通道, 3x3: 卷积核的尺寸
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 是经历卷积操作后的图片尺寸
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
model = LeNet().to(device=device)

局部剪枝实验,假定对模型的第一个卷积层中的权重进行剪枝

# 打印输出剪枝前的参数
module = model.conv1
print(list(module.named_parameters()))
print(list(module.buffers()))
print(module.weight)

运行结果

[('weight', Parameter containing:
tensor([[[[ 0.1158, -0.0091, -0.2742],
          [-0.1132,  0.1059, -0.0381],
          [ 0.0430, -0.1634, -0.1345]]],
		
		...

        [[[-0.0226,  0.2091, -0.1479],
          [ 0.2302, -0.0988,  0.2117],
          [-0.2000, -0.2531,  0.2770]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.2658,  0.2096, -0.2639, -0.3063, -0.1453,  0.1201], device='cuda:0',
       requires_grad=True))]
[]
Parameter containing:
tensor([[[[ 0.1158, -0.0091, -0.2742],
          [-0.1132,  0.1059, -0.0381],
          [ 0.0430, -0.1634, -0.1345]]],


        ...


        [[[-0.0226,  0.2091, -0.1479],
          [ 0.2302, -0.0988,  0.2117],
          [-0.2000, -0.2531,  0.2770]]]], device='cuda:0', requires_grad=True)

定义剪枝+执行剪枝

# 修剪是从 模块 中 删除 参数(如 weight),并用 weight_orig 保存该参数
# random_unstructured 是一种裁剪技术,随机非结构化裁剪
# 第一个参数:modeul,代表要进行剪枝的特定模型,之前我们已经制定了module=module.conv1,说明这里要对第一个卷积层执行剪枝
# 第二个参数:name,指定要对选中模块中的那些参数执行剪枝,这里设定为name='weight',意味着对连接网络的weight剪枝,而不死bias剪枝
# 第三个参数:amount,指定要对模型中的多大比例的参数执行剪枝,amount是一个介于0.0~1.0的float数值,或者一个正整数指定裁剪多少条连接边。
prune.random_unstructured(module, name="weight", amount=0.3)      # weight    bias
print(list(module.named_parameters()))
# 通过修剪技术会创建一个mask命名为 weight_mask 的模块缓冲区
print(list(module.named_buffers()))

# 经过裁剪操作后的模型,原始的参数存放在了weight-orig中,
# 对应的剪枝矩阵存放在weight-mask中,而将weight-mask视作掩码张量
# 再和weight-orig相乘的结果就存放在了weight中
print(module.weight)
print(module.bias)

运行结果

[('bias', Parameter containing:
tensor([ 0.1303,  0.1208, -0.0989, -0.0611, -0.1103, -0.2433], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-1.1443e-01,  3.2276e-01, -2.4664e-02],
          [ 4.6659e-02,  1.8311e-01,  6.6681e-02],
          [-2.5493e-01, -1.1471e-01,  2.8336e-01]]],


       ...

        [[[ 1.4041e-01,  2.0963e-02,  2.2884e-01],
          [ 3.5870e-02,  7.5861e-02,  8.4728e-02],
          [ 4.1965e-02, -1.2838e-01,  8.8462e-02]]]], device='cuda:0',
       requires_grad=True))]
[('weight_mask', tensor([[[[1., 0., 0.],
          [1., 0., 0.],
          [0., 0., 0.]]],


        ...


        [[[1., 0., 0.],
          [1., 0., 1.],
          [0., 1., 0.]]]], device='cuda:0'))]
tensor([[[[-1.1443e-01,  0.0000e+00, -0.0000e+00],
          [ 4.6659e-02,  0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00]]],


        ...

        [[[ 1.4041e-01,  0.0000e+00,  0.0000e+00],
          [ 3.5870e-02,  0.0000e+00,  8.4728e-02],
          [ 0.0000e+00, -1.2838e-01,  0.0000e+00]]]], device='cuda:0',
       grad_fn=<MulBackward0>)
Parameter containing:
tensor([ 0.1303,  0.1208, -0.0989, -0.0611, -0.1103, -0.2433], device='cuda:0',
       requires_grad=True)

保存剪枝后的模型

# 保存剪枝后的模型
torch.save(model.state_dict(), 'pruned_model.pth')

模型经历剪枝以后,原始的权重矩阵weight参数不见了,变成了weight_orig。并且剪枝前打印为空的列表module.name_buffers(),此时拥有了一个weight_mask参数。经过剪枝操作后的模型,原始的参数存放在了weight_orig中,对应的剪枝矩阵存在weight_mask中,而将weight_mask视作掩码张量,再和weight_orig相乘的结果就存在了weight中。

Q1:打印经过剪枝处理的 weight 参数。这个 weight 实际上是原始的 weight_orig 和 weight_mask 的元素乘积,其中被剪枝的权重会被设置为0。这个weight不是剪枝了嘛?为什么还能打印出来?

答:在Pytorch的剪枝过程中,当我们说剪枝一个权重的参数时,并不是真的从网络中移除这些参数,而是通过一个掩码来“禁用”它们。这是通过将某些权重的值设为0来实现的,从而在网络的前向传播中这些权重不会有任何作用,这种方法允许我们在保留原始权重信息的同时,实现剪枝的效果。

在 PyTorch 的剪枝过程中,当我们说“剪枝”一个权重参数时,并不是真的从网络中移除这些权重,而是通过应用一个掩码来“禁用”它们,为什么禁用就可以达到模型压缩的目的?为什么剪枝完,执行print(list(module.named_parameters())),没有显示weight属性,但是执行print(module.weight)时,weight依然存在?

  • 为什么“禁用”权重可以达到模型压缩的目的?
    虽然剪枝后的权重仍然占据内存空间,但在实际计算中,值为0的权重不会对前向传播产生任何影响。这意味着在计算层面可以忽略这些权重,从而减少计算量。
  • 为什么 print(list(module.named_parameters())) 没有显示 weight 属性,但执行 print(module.weight) 时 weight 依然存在?
    • 修改参数列表:当执行剪枝操作时,PyTorch 会修改模块的参数列表。原始的 weight 参数被重命名为 weight_orig,并且创建了一个新的名为 weight_mask 的缓冲区。原始的 weight 参数(现在是 weight_orig)和 weight_mask 通过一个钩子(hook)相结合,生成了新的 weight 属性。

    • 动态权重生成:在调用 module.weight 时,由于剪枝过程中添加的前向钩子,weight 参数是动态生成的,它是 weight_orig 和 weight_mask 的元素乘积。因此,尽管 weight 在 named_parameters 列表中看起来已经不存在,但它实际上是在运行时动态生成的。

    • 参数与属性的区别:在 PyTorch 中,模块的参数(可通过 named_parameters 访问)和模块的属性(如直接通过 module.weight 访问)是不同的。module.weight 被视为一个可访问的属性,但由于剪枝过程的内部处理,它可能不再直接列在模块的参数列表中。

既然原始的 weight 参数被重命名为 weight_orig,那参数是不是并没有发生变化,又怎么能达到剪枝的效果呢?
原始的 weight 参数在剪枝过程中被重命名为 weight_orig,并且保持不变。剪枝的效果是通过以下几个关键步骤实现的:

  1. 掩码(Mask)创建:
  • 在剪枝过程中,PyTorch 创建了一个掩码(weight_mask),它是一个与 weight 形状相同的二进制张量(由0和1组成)。
  • 在这个掩码中,1表示相应的权重保持不变,而0表示相应的权重被“剪枝”(实际上是被禁用)。
  1. 动态权重更新:
  • 尽管 weight_orig 保持不变,但是模块的 weight 属性被动态更新为 weight_orig 和 weight_mask 的元素乘积。
  • 这意味着,在模型的前向传播过程中,实际使用的 weight 是被掩码修改过的。在这个新的 weight 中,被剪枝的权重(在 weight_mask 中对应0的位置)的值为0,而其他位置的权重保持原始值。
  1. 前向传播的影响:
  • 当模型进行前向传播时,使用的是被掩码修改过的 weight。因此,尽管原始的 weight 参数(现在是 weight_orig)没有变化,模型实际上使用的权重已经被剪枝修改了。
  • 剪枝过程实际上通过使某些权重值为0,从而在模型的计算过程中禁用了这些权重。
  1. 模型复杂度的降低:
  • 通过这种方式,模型的复杂度在实际运行时降低了,因为一部分权重不再对输出产生影响。
  • 这可以提高计算效率,并且在某些情况下,可以通过专门的硬件和软件优化来利用权重的这种稀疏性。
    综上所述,虽然原始的 weight 参数作为 weight_orig 保留下来,但是实际上模型使用的是被掩码修改过的权重,这就是剪枝效果的实现方式。这种方法的一个优势是可以在不永久性地移除权重的情况下,测试和评估剪枝的影响,甚至可以在必要时撤销剪枝操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/188519.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

如何学习VBA:3.2.8 OnTime方法与OnKey方法

我给VBA的定义&#xff1a;VBA是个人小型自动化处理的有效工具。利用好了&#xff0c;可以大大提高自己的劳动效率&#xff0c;而且可以提高数据处理的准确度。我推出的VBA系列教程共九套和一部VBA汉英手册&#xff0c;现在已经全部完成&#xff0c;希望大家利用、学习。 如果…

基于51单片机的人体追踪可控的电风扇系统

**单片机设计介绍&#xff0c; 基于51单片机超声波测距汽车避障系统 文章目录 一 概要概述硬件组成工作原理优势应用场景总结 二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 # 基于51单片机的人体追踪可控的电风扇系统介绍 概述 该系统是基于51…

Flink实战(11)-Exactly-Once语义之两阶段提交

0 大纲 [Apache Flink]2017年12月发布的1.4.0版本开始&#xff0c;为流计算引入里程碑特性&#xff1a;TwoPhaseCommitSinkFunction。它提取了两阶段提交协议的通用逻辑&#xff0c;使得通过Flink来构建端到端的Exactly-Once程序成为可能。同时支持&#xff1a; 数据源&#…

ElasticSearch查询语法及深度分页问题

一、ES高级查询Query DSL ES中提供了一种强大的检索数据方式,这种检索方式称之为Query DSL&#xff08;Domain Specified Language 领域专用语言&#xff09; , Query DSL是利用Rest API传递JSON格式的请求体(RequestBody)数据与ES进行交互&#xff0c;这种方式的丰富查询语法…

html实现我的故乡,城市介绍网站(附源码)

文章目录 1. 我生活的城市北京&#xff08;网站&#xff09;1.1 首页1.2 关于北京1.3 北京文化1.4 加入北京1.5 北京景点1.6 北京美食1.7 联系我们 2.效果和源码2.1 动态效果2.2 源代码 源码下载 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_43…

【Linux】匿名管道与命名管道,进程池的简易实现

文章目录 前言一、匿名管道1.管道原理2.管道的四种情况3.管道的特点 二、命名管道1. 特点2.创建命名管道1.在命令行上2.在程序中 3.一个程序执行打开管道并不会真正打卡 三、进程池简易实现1.makefile2.Task.hpp3.ProcessPool.cpp 前言 一、匿名管道 #include <unistd.h&g…

链表?细!详细知识点总结!

链表 定义&#xff1a;链表是一种递归的数据结构&#xff0c;它或者为空&#xff08;null)&#xff0c;或者是指向一个结点&#xff08;node&#xff09;的引用&#xff0c;该结点含有一个泛型的元素和一个指向另一条链表的引用。 ​ 其实链表就是有序的列表&#xff0c;它在内…

批量按顺序1、2、3...重命名所有文件夹里的文件

最新&#xff1a; 最快方法&#xff1a;先用这个教程http://文件重命名1,2......nhttps://jingyan.baidu.com/article/495ba841281b7079b20ede2c.html再用这个教程去空格&#xff1a;利用批处理去掉文件名中的空格-百度经验 (baidu.com) 以下为原回答 注意文件名有空格会失败…

Javaweb之Vue组件库Element的详细解析

4 Vue组件库Element 4.1 Element介绍 不知道同学们还否记得我们之前讲解的前端开发模式MVVM&#xff0c;我们之前学习的vue是侧重于VM开发的&#xff0c;主要用于数据绑定到视图的&#xff0c;那么接下来我们学习的ElementUI就是一款侧重于V开发的前端框架&#xff0c;主要用…

深信服实验学习笔记——nmap常用命令

文章目录 1. 主机存活探测2. 常见端口扫描、服务版本探测、服务器版本识别3. 全端口&#xff08;TCP/UDP&#xff09;扫描4. 最详细的端口扫描5. 三种TCP扫描方式 1. 主机存活探测 nmap -sP <靶机IP>-sP代表 2. 常见端口扫描、服务版本探测、服务器版本识别 推荐加上-v参…

C++初阶(十二)string的模拟实现

&#x1f4d8;北尘_&#xff1a;个人主页 &#x1f30e;个人专栏:《Linux操作系统》《经典算法试题 》《C》 《数据结构与算法》 ☀️走在路上&#xff0c;不忘来时的初心 文章目录 一、string类的模拟实现1、构造、拷贝构造、赋值运算符重载以及析构函数2、迭代器类3、增删查…

Nacos安装使用

Nacos安装使用 官方下载地址: https://github.com/alibaba/nacos/releases 官方文档地址: https://nacos.io/zh-cn/docs/quick-start.html Nacos介绍 Nacos是阿里巴巴开源的一款支持服务注册与发现&#xff0c;配置管理以及微服务管理的组件。用来取代以前常用的注册中心&a…

史上最全前端知识点+高频面试题合集,十二大专题,命中率高达95%

前言&#xff1a; 下面分享一些关于阿里&#xff0c;美团&#xff0c;深信服等公司的面经&#xff0c;供大家参考一下。大家也可以去收集一些其他的面试题&#xff0c;可以通过面试题来看看自己有哪里不足。也可以了解自己想去的公司会问什么问题&#xff0c;进行有针对的复习。…

基于springboot网上超市管理系统

基于springboot网上超市管理系统 摘要 随着互联网的快速发展&#xff0c;电子商务行业迎来了蓬勃的发展&#xff0c;网上超市作为电子商务的一种形式&#xff0c;为消费者提供了便利的购物体验。本文基于Spring Boot框架&#xff0c;设计和实现了一个网上超市管理系统&#xff…

项目中如何配置数据可视化展现

在现今数据驱动的时代&#xff0c;可视化已逐渐成为数据分析的主要途径&#xff0c;可视化大屏的广泛使用便应运而生。很多公司及政务机构&#xff0c;常利用大屏的手段展现其实力或演示业务&#xff0c;可视化的效果能让观者更快速的理解结果并直观的看到数据展现。因此&#…

【Web】NewStarCtf Week2 个人复现

目录 ①游戏高手 ②include 0。0 ③ez_sql ④Unserialize&#xff1f; ⑤Upload again! ⑥ R!!C!!E!! ①游戏高手 经典前端js小游戏 检索与分数相关的变量 控制台直接修改分数拿到flag ②include 0。0 禁了base64和rot13 尝试过包含/var/log/apache/access.log,ph…

图论|知识图谱——详解自下而上构建知识图谱全过程

导读&#xff1a;知识图谱的构建技术主要有自顶向下和自底向上两种。其中自顶向下构建是指借助百科类网站等结构化数据源&#xff0c;从高质量数据中提取本体和模式信息&#xff0c;加入到知识库里。而自底向上构建&#xff0c;则是借助一定的技术手段&#xff0c;从公开采集的…

3.2 CPU的自动化

CPU的自动化 改造1-使用2进制导线改造2根据整体流程开始改造指令分析指令MOV_A的开关2进制表格手动时钟gif自动时钟gif 根据之前的CPU内部结构改造,制造一个cpu控制单元 改造一 之前的CPU全由手动开关自己控制,极度繁琐,而开关能跟二进制一一对应, 开:1, 关:0图1是之前的, …

Vue3的计算属性(computed)和监听器(watch)案例语法

一&#xff1a;前言 Vue3 是 Vue2 的一个升级版&#xff0c;随着 2023年12月31日起 Vue2 停止维护。这意味着 Vue3 将会为未来国内一段时间里&#xff0c;前端的开发主流。因此熟练的掌握好 Vue3 是前端开发程序员所不可避免的一门技术栈。而 Vue3 是 Vue2 的一个升级版&#x…

易错知识点(数学一)

一、反常积分判敛 1、构造使其极限等于一个大于0的常数 1&#xff09;前者通过&#xff1a;化等价无穷小 or 泰勒展开 2&#xff09;若存在p>1使得等式成立&#xff0c;则收敛 考察形式&#xff1a;1、已知收敛&#xff0c;求f(x)中的幂次取值范围 主要思想&#xff1a;比较…