Torch-Pruning 库入门级使用介绍

在这里插入图片描述

项目地址:https://github.com/VainF/Torch-Pruning

Torch-Pruning 是一个专用于torch的模型剪枝库,其基于DepGraph 技术分析出模型layer中的依赖关系。DepGraph 与现有的修剪方法(如 Magnitude Pruning 或 Taylor Pruning)相结合可以达到良好的剪枝效果。

本博文结合项目官网案例,对信息进行结构话,抽离出剪枝技术说明、剪枝模型保存与加载、剪枝技术的基本使用,剪枝技术的具体使用案例。并结合外部信息,分析剪枝对模型性能精度的影响。

1、基本说明

1.1 项目安装

打开https://github.com/VainF/Torch-Pruning,下载项目
在这里插入图片描述
然后在终端中,进入项目目录,并执行pip install -r requirements.txt 安装项目依赖库
在这里插入图片描述
然后在执行 pip install -e . ,将项目安装在当前目录下,并设置为editing模式。
在这里插入图片描述
验证安装:执行命令python -c "import torch_pruning", 如果没有输出报错信息则表示安装成功。
在这里插入图片描述

1.2 DepGraph 技术说明

在结构修剪中,组被定义为深度网络中最小的可移除单元。每个组由多个相互依赖的层组成,需要同时修剪这些层以保持最终结构的完整性。然而,深度网络通常表现出层与层之间错综复杂的依赖关系,这对结构修剪提出了重大挑战。这项研究通过引入一种名为 DepGraph 的自动化机制来解决这一挑战,该机制可以轻松实现参数分组,并有助于修剪各种深度网络。
在这里插入图片描述

直接剪枝会会破坏layer间的依赖关系,会导致forward流程报错。具体如下面代码,移除model.conv1模块中的idxs为0与1的channel,导致后续的bn1层输入输入与参数格式对不上号,然后报错。

from torchvision.models import resnet18
import torch_pruning as tp
import torch

model = resnet18().eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) # remove channel 0 and channel 1
output = model(torch.randn(1,3,224,224)) # test

在这里插入图片描述
基本在后续层添加剪枝,运行代码也会保存,因为batchnorm的下一层要求的输出channel是64。

model = resnet18(pretrained=True).eval()
tp.prune_conv_out_channels(model.conv1, idxs=[0,1]) 
tp.prune_batchnorm_out_channels(model.bn1, idxs=[0,1])
tp.prune_batchnorm_in_channels(model.layer1[0].conv1, idxs=[0,1])
output = model(torch.randn(1,3,224,224)) 

使用DepGraph剪枝代码如下,先使用tp.DependencyGraph().build_dependenc构建出依赖图,然后基于DG.get_pruning_group函数获取目标剪枝层的依赖关系组,最后在检验关系并进行剪枝。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
print(group)
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()
    
# 4. Save & Load
model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object

代码执行后的输出如下所示,可以看到捕捉到group对应的依赖layer

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

1.3 剪枝模型的保存与加载

剪枝后的模型由于网络结构改变了,如果只保存模型参数,是无法支持原始网络结构,需要将模型结构连参数一并保存。加载时连同参数一起加载。

model.zero_grad() # We don't want to store gradient information
torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the pruned model

或者基于tp库中tp.state_dict函数提取目标参数进行保存,并基于tp.load_state_dict函数将剪枝后的参数赋值到原始模型中形成剪枝模型。

# save the pruned state_dict, which includes both pruned parameters and modified attributes
state_dict = tp.state_dict(pruned_model) # the pruned model, e.g., a resnet-18-half
torch.save(state_dict, 'pruned.pth')

# create a new model, e.g. resnet18
new_model = resnet18().eval()

# load the pruned state_dict into the unpruned model.
loaded_state_dict = torch.load('pruned.pth', map_location='cpu')
tp.load_state_dict(new_model, state_dict=loaded_state_dict)
print(new_model) # This will be a pruned model.

2、剪枝基本案例

2.1 具有目标结构的剪枝

以下代码使用TaylorImportance指标进行剪枝,设置忽略输出层的剪枝。并设置MagnitudePruner中对通道剪枝50%,一共分iterative_steps步完成剪枝,每一次剪枝都进行微调。
整体来说,具备目标结构的剪枝,效果是最差的。 基于https://blog.csdn.net/a486259/article/details/140407147 分析的数据得出的结论。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    #pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

代码的输出信息如下所示,可以看到macs与nparams在逐步降低。最终输出的模型结构,所有的chanel都减半了,只有输出层例外。

iter 0 | rate:0.8092  0.8111
iter 1 | rate:0.6469  0.6445
iter 2 | rate:0.4971  0.4979
iter 3 | rate:0.3718  0.3695
iter 4 | rate:0.2674  0.2614
ResNet(
  (conv1): Conv2d(3, 32, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
PS D:\开源项目\Torch-Pruning-master>
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=1000, bias=True)
)

2.2 自动结构剪枝

这里的自动结构是有一个预设目标,即将总体channel剪枝到原模型的多少,但没有预定的目标结构。可能有的laye通道剪枝数多,有的剪枝数少。 与2.1中的代码相比,主要是增加了参数 global_pruning=True。但这个剪枝方式比具有目标结构的剪枝更加有效。就像裁员一样,要求各个部门内裁员比例相同与在公司内控制裁员比例(各个部门裁员比例按重要度排列,裁员比例不一样),必然是第二种方式更有效。第一种方式,使低效率部门的靠前但无用员工保留下来了。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

#model = resnet18(pretrained=True)
model = resnet18()

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.TaylorImportance()

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 3 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50%的channel
    ignored_layers=ignored_layers,
    global_pruning=True
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    if isinstance(imp, tp.importance.TaylorImportance):
        # Taylor expansion requires gradients for importance estimation
        loss = model(example_inputs).sum() # a dummy loss for TaylorImportance
        loss.backward() # before pruner.step()
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"iter {i} | rate:{macs/base_macs:.4f}  {nparams/base_nparams:.4f}")
print(model)
    # finetune your model here
    # finetune(model)
    # ...

2.3 MagnitudePruner中的参数

指定特定层的剪枝比例 通过pruning_ratio_dict参数,指定model.layer2的剪枝比例为20%,这里适用于有先验经验的layer,控制对特定layer的剪枝比例。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    imp,
    pruning_ratio = 0.5,
    pruning_ratio_dict = {model.layer2: 0.2}
)
pruner.step()
print(model)

代码执行后的层为:ResNet{64, 128, 256, 512} => ResNet{32, 102, 128, 256}

设置最大剪枝比例 通过 max_pruning_ratio 参数设置最大剪枝比例,避免由于稀疏剪枝或者自动剪枝时某个层被严重剪枝或者移除。

剪枝次数与剪枝调度器 您打算分多轮修剪模型,iterative_steps 会很有用。默认情况下,修剪器会逐渐增加模型的稀疏度,直到达到所需的 pruning_ratio。如以下代码,分5次实现剪枝目标。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    pruning_ratio=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
)

# prune the model, iteratively if necessary.
base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print("Round %d/%d, Params: %.2f M" % (i+1, iterative_steps, nparams/1e6))
    # finetune your model here
    # finetune(model)
    # ...
print(model)

对应输出如下
Round 1/5, Params: 9.44 M
Round 2/5, Params: 7.45 M
Round 3/5, Params: 5.71 M
Round 4/5, Params: 4.20 M
Round 5/5, Params: 2.93 M

设置忽略的层 这主要是避免对输出层进行剪枝,修改模型的输出结构。使用代码如下,通过ignored_layers参数传入忽略的layer对象。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, # remove 50% channels
    ignored_layers=[model.conv1, model.fc] # ignore the first & last layers
)
pruner.step()
print(model)

channel取整 在很多的时候都认为channel为16的倍数,gpu运行效率最高。使用代码如下,通过round_to参数,保持channel是特定数的倍数。

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18()
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.3, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    round_to=10 # round to 10x. Note: 10x is not a good practice.
)

pruner.step()
print(model)

channel_groups 某些层(例如 nn.GroupNorm 和 nn.Conv2d)具有 group 参数,这会在层内引入额外的依赖项。修剪后,保持所有组的大小相同至关重要。为了满足这一要求,引入了参数 channel_groups 以启用对这些通道的手动分组。如以下代码,通过channel_groups参数,控制model.group_conv1中的参数为8个一组

pruner = tp.pruner.MagnitudePruner(
            model,
            example_inputs=example_inputs,
            importance=importance,
            iterative_steps=1,
            pruning_ratio=0.5,
            channel_groups = {model.group_conv1: 8} # For Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), groups=8)
        )

额外参数剪枝 有些时候模型具备的可训练参数并非conv、fc等传统layer中,需要基于unwrapped_parameters参数将额外的可剪枝参数传入到剪枝器中。具体如下所示:

from torchvision.models.convnext import CNBlock, ConvNeXt
unwrapped_parameters = []
for m in model.modules():
    if isinstance(m, CNBlock):
        unwrapped_parameters.append( (m.layer_scale, 0) )

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    unwrapped_parameters=unwrapped_parameters 

限定剪枝范围 root_module_types 参数用于指定组的“根”或第一层。在许多情况下,我们专注于修剪线性层和卷积 (Conv) 层。要专门针对这些层启用修剪,我们可以使用以下参数:root_module_types=[nn.Conv2D, nn.Linear]。这可确保将修剪应用于所需的层。

pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    pruning_ratio=0.5, 
    root_module_types=[nn.Conv2D, nn.Linear]

3、具体应用案例

3.1 timm模型剪枝

官方代码为:examples\timm_models\prune_timm_models.py
具体详情如下,这里有一个特殊用法,是通过num_heads参数实现对于transformer layer的支持

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
os.environ['TIMM_FUSED_ATTN'] = '0'
import torch
import torch.nn as nn 
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparse

parser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='channel pruning ratio')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--pretrained', default=False, action='store_true', help='global pruning')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()

def main():
    timm_models = timm.list_models()
    if args.list_models:
        print(timm_models)
    if args.model is None: 
        return
    assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)

    imp = tp.importance.GroupNormImportance()
    print("Pruning %s..."%args.model)
        
    input_size = model.default_cfg['input_size']
    example_inputs = torch.randn(1, *input_size).to(device)
    test_output = model(example_inputs)
    ignored_layers = []
    num_heads = {}

    for m in model.modules():
        if hasattr(m, 'head'): #isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(model.head)
            print("Ignore classifier layer: ", m.head)
       
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                num_heads[m.qkv] = m.num_heads
                print("Attention layer: ", m.qkv, m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                num_heads[m.qkv_proj] = m.num_heads

    print("========Before pruning========")
    print(model)
    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
    pruner = tp.pruner.MetaPruner(
                    model, 
                    example_inputs, 
                    global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
                    importance=imp, # importance criterion for parameter selection
                    iterative_steps=1, # the number of iterations to achieve target pruning ratio
                    pruning_ratio=args.pruning_ratio, # target pruning ratio
                    num_heads=num_heads,
                    ignored_layers=ignored_layers,
                )
    for g in pruner.step(interactive=True):
        g.prune()

    for m in model.modules():
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qqkv_projkv]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)

    print("========After pruning========")
    print(model)
    test_output = model(example_inputs)
    pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
    print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
    print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))

if __name__=='__main__':
    main()

3.2 llm模型剪枝

在examples\LLMs\prune_llama.py中提供了一个对于llama模型的剪枝案例.
核心代码如下,可以看到也是基于num_heads记录transformer的结构信息,然后在剪枝后将num_heads数据赋值到对应模型参数上。与原始代码相比,这里删除了模型精度验证相关的代码。


# Code adapted from 
# https://github.com/IST-DASLab/sparsegpt/blob/master/datautils.py
# https://github.com/locuslab/wanda

import os, sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))

import argparse
import os 
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version
import time
import torch
import torch.nn as nn
from collections import defaultdict
import fnmatch
import numpy as np
import random

print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="./cache"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name, 
        torch_dtype=torch.float16, 
        cache_dir=cache_dir, 
        device_map="auto"
    )

    model.seqlen = model.config.max_position_embeddings 
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
    parser.add_argument('--pruning_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--cache_dir", default="./cache", type=str )
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')
    parser.add_argument("--eval_zero_shot", action="store_true")
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)       
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model: # for 30b and 65b we use device_map to load onto multiple A6000 GPUs, thus the processing here.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    ##############
    # Pruning
    ##############
    print("----------------- Before Pruning -----------------")
    print(model)
    text = "Hello world."
    inputs = torch.tensor(tokenizer.encode(text)).unsqueeze(0).to(model.device)
    import torch_pruning as tp 
    num_heads = {}
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            num_heads[m.q_proj] = model.config.num_attention_heads
            num_heads[m.k_proj] = model.config.num_key_value_heads
            num_heads[m.v_proj] = model.config.num_key_value_heads
            
    head_pruning_ratio = args.pruning_ratio
    hidden_size_pruning_ratio = args.pruning_ratio
    pruner = tp.pruner.MagnitudePruner(
        model, 
        example_inputs=inputs,
        importance=tp.importance.GroupNormImportance(),
        global_pruning=False,
        pruning_ratio=hidden_size_pruning_ratio,
        ignored_layers=[model.lm_head],
        num_heads=num_heads,
        prune_num_heads=True,
        prune_head_dims=False,
        head_pruning_ratio=head_pruning_ratio,
    )
    pruner.step()

    # Update model attributes
    num_heads = int( (1-head_pruning_ratio) * model.config.num_attention_heads )
    num_key_value_heads = int( (1-head_pruning_ratio) * model.config.num_key_value_heads )
    model.config.num_attention_heads = num_heads
    model.config.num_key_value_heads = num_key_value_heads
    for name, m in model.named_modules():
        if name.endswith("self_attn"):
            m.hidden_size = m.q_proj.out_features
            m.num_heads = num_heads
            m.num_key_value_heads = num_key_value_heads
        elif name.endswith("mlp"):
            model.config.intermediate_size = m.gate_proj.out_features
    print("----------------- After Pruning -----------------")
    print(model)

    #ppl_test = eval_ppl(args, model, tokenizer, device)
    #print(f"wikitext perplexity {ppl_test}")

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()

3.3 目标检测模型剪枝

在Torch-Pruning 库中提供了针对yolov8、yolov7、yolov5的剪枝案例。关于yolov8还提供了剪枝后的训练策略,其主要技巧在与对不可剪枝层的可剪枝话处理(C2f模块的剪枝,其含split操作,不利于剪枝索引)。后续会补充博客,说明对yolov8的剪枝使用。

4、其他信息

4.1 剪枝器中的评价指标

在torch_pruning\pruner\importance.py中有很多个剪枝评价指标

__all__ = [
    # Base Class
    "Importance",

    # Basic Group Importance
    "GroupNormImportance",
    "GroupTaylorImportance",
    "GroupHessianImportance",

    # Aliases
    "MagnitudeImportance",
    "TaylorImportance",
    "HessianImportance",

    # Other Importance
    "BNScaleImportance",
    "LAMPImportance",
    "RandomImportance",
]

整体来看是TaylorImportance最好,一直使用该值即可。
来看

4.2 剪枝对性能精度的影响

在博客https://blog.csdn.net/a486259/article/details/140407147?spm=1001.2014.3001.5501 中基本确定了剪枝50%,对模型精度是没有任何影响的。这里对Torch-Pruning 库相关的论文数据进行二次核验,以致于分析出剪枝中速度提升对精度的影响。

以DepGraph: Towards Any Structural Pruning数据为例,可以发现最高支持6x速度剪枝后保持模型性能。
在这里插入图片描述
以LLM-Pruner: On the Structural Pruning of Large Language Models 论文数据为例,可以发现使用Vector评价方法的剪枝,移除10%的参数,zero-shot下对模型精度影响不大。而图4更表明,剪枝方法正确的话,移除50%的参数对模型性能影响也不大。
在这里插入图片描述
以论文 Structural Pruning for Diffusion Models 的数据为分析,同样可以发现剪枝50%左右的通道,对结果影响不对。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/798711.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

uniapp实现水印相机

uniapp实现水印相机-livePusher 水印相机 背景 前两天拿到了一个需求,要求在内部的oaApp中增加一个卫生检查模块,这个模块中的核心诉求就是要求拍照的照片添加水印。对于这个需求,我首先想到的是直接去插件市场,下一个水印相机…

《Python数据科学之五:模型评估与调优深入解析》

《Python数据科学之五:模型评估与调优深入解析》 在数据科学项目中,精确的模型评估和细致的调优过程是确保模型质量、提高预测准确性的关键步骤。本文将详细探讨如何利用 Python 及其强大的库进行模型评估和调优,确保您的模型能够达到最佳性能…

docker中1个nginx容器搭配多个django项目中设置uwsgi.ini的django项目路径

docker中,1个nginx容器搭配多个django项目容器,设置各个uwsgi.ini的django项目路径 被这个卡了一下,真是,哎 各个uwsgi配置应该怎样设置项目路径 django项目1中创建的django项目名为 web 那么uwsgi.ini中要设置为 chdir …

【Vue3 ts】echars图表展示统计的月份数据

图片展示 此处内容为展示24年各个月份产品的创建数量。在后端统计24年各个月份产品数量后,以数组的格式发送给前端,前端负责展示。 后端 entity层: Data Schema(description "月份统计")public class MonthCount {private Stri…

得物六宫格验证码分析

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用,不用于其他任何目的,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关! 前言(lianxi a…

算法的时间复杂度和空间复杂度-例题

一、消失的数字 . - 力扣&#xff08;LeetCode&#xff09; 本题要求的时间复杂度是O(n) &#xff0c;所以我们不能用循环嵌套&#xff1b; 解法一&#xff1a; int missingNumber(int* nums, int numsSize){int sum10;for(int i0;i<numsSize;i){sum1i;}int sum20;for(i…

C到C嘎嘎的衔接篇

本篇文章&#xff0c;是帮助大家从C向C嘎嘎的过渡&#xff0c;那么我们直接开始吧 不知道大家是否有这样一个问题&#xff0c;学完C的时候感觉还能听懂&#xff0c;但是听C嘎嘎感觉就有点难度或者说很难听懂&#xff0c;那么本篇文章就是帮助大家从C过渡到C嘎嘎。 C嘎嘎与C的区…

MPC轨迹跟踪控制器推导及Simulink验证

文章目录 MPC轨迹跟踪控制器推导及Simulink验证MPC的特点MPC轨迹跟踪控制器推导一 系统离散化二 预测区间状态和变量推导三 代价函数推导四 优化求解 <center> 基于MPC的倒立摆控制系统相关资料Reference&#xff1a; MPC轨迹跟踪控制器推导及Simulink验证 MPC的特点 多…

SAP 消息输出 - Adobe Form

目录 1 安装链接 2 前台配置 - Fiori app 2.1 维护表单模板 (maintain form templates) 2.2 管理微标 (manage logos) 2.3 管理文本 (manage texts) 3 后台配置 3.1 定义表单输出规则 3.2 分配表单模板 SAP 消息输出&#xff0c;不仅是企业内部用来记录关键业务操作也是…

Win11任务栏当中对 STM32CubeMX 的堆叠问题

当打开多个 CubeMX 程序的时候&#xff0c;Win11 自动将其进行了堆叠&#xff0c;这时候就无法进行预览与打开。 问题分析&#xff1a;大部分ST的工具都是基于 JDK 来进行开发的&#xff0c;Win11 将其识别成了同一个 Binary 但是实际上他们并不是同一个&#xff0c;通过配置…

基于conda包的环境创建、激活、管理与删除

Anaconda是一个免费、易于安装的包管理器、环境管理器和 Python 发行版&#xff0c;支持平台包括Windows、macOS 和 Linux。下载安装地址&#xff1a;Download Anaconda Distribution | Anaconda 很多不同的项目可能需要使用不同的环境。例如某个项目需要使用pytorch1.6&#x…

C语言详解(结构体)

Hi~&#xff01;这里是奋斗的小羊&#xff0c;很荣幸各位能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎~~ &#x1f4a5;个人主页&#xff1a;小羊在奋斗 &#x1f4a5;所属专栏&#xff1a;C语言 本系列文章为个人学习笔记&#xff0c;在这里撰写成文一…

《后端程序猿 · EasyPOI 导入导出》

&#x1f4e2; 大家好&#xff0c;我是 【战神刘玉栋】&#xff0c;有10多年的研发经验&#xff0c;致力于前后端技术栈的知识沉淀和传播。 &#x1f497; &#x1f33b; CSDN入驻不久&#xff0c;希望大家多多支持&#xff0c;后续会继续提升文章质量&#xff0c;绝不滥竽充数…

Android OkHttp3中HttpLoggingInterceptor使用

目录 一 概述1.1 日志级别 二 使用2.1 引入依赖2.2 创建对象2.3 添加拦截器 三 结果展示3.1 日志级别为BODY3.2 日志级别为BASIC3.3 日志级别为HEADERS 参考 一 概述 HttpLoggingInterceptor是OkHttp3提供的拦截器&#xff0c;用来记录HTTP请求和响应的详细信息。 1.1 日志级…

Dify中的经济索引模式实现过程

当索引模式为经济时&#xff0c;使用离线的向量引擎、关键词索引等方式&#xff0c;降低了准确度但无需花费 Token。 一.提取函数**_extract** 根据不同文档类型进行内容的提取&#xff1a; def _extract(self, index_processor: BaseIndexProcessor, dataset_document: Data…

pico+unity预设配置

picosdk中有很多预设的配置、使用预设配置的方法有 1、创建 XR Origin、展开 XR Origin > Camera Offset&#xff0c;选中 LeftHand Controller。点击 XR Controller (Action-Based) 面板右上角的 预设 按钮 2、打开Assets\Samples\XR Interaction Toolkit\2.5.2\Starter A…

《人工智能 从小白到大神》:一本让你彻底掌握AI的书

在当今这个快速发展的时代&#xff0c;人工智能&#xff08;AI&#xff09;已经成为改变世界的关键力量。你是否曾想过&#xff0c;如何从一个对AI一无所知的小白&#xff0c;成长为一名真正的AI大神&#xff1f;今天&#xff0c;我要向大家推荐一本能够帮助你实现这一目标的书…

51单片机11(蜂鸣器硬件设计和软件设计)

一、蜂鸣器硬件设计 1、 2、上面两张图&#xff0c;是针对不同产品的电路图。像左边这一块&#xff0c;是我们的A2&#xff0c;A3&#xff0c;A4的一个产品对应的一个封闭器的硬件电路。而右边的这一块是对应的A5到A7的一个硬件电路。因为A5到A7的一个产品&#xff0c;它的各…

Python和C++全球导航卫星系统和机器人姿态触觉感知二分图算法

&#x1f3af;要点 &#x1f3af;马尔可夫随机场网格推理学习 | &#x1f3af;二维伊辛模型四连网格模型推理 | &#x1f3af;统计物理学模型扰动与最大乘积二值反卷积 | &#x1f3af;受限玻尔兹曼机扰动和最大乘积采样 | &#x1f3af;视觉概率生成模型测试图像 &#x1f3…