对 MODNet 其他模块的剪枝探索

写在前面

先前笔者分享了《对 MODNet 主干网络 MobileNetV2的剪枝探索》,没想到被选为了CSDN每天值得看系列,因为笔者开设的专栏《MODNet-Compression探索之旅》仅仅只是记录笔者在模型压缩领域的探索历程,对此笔者深感荣幸,非常感谢官方大大的认可!!!接下来,笔者会加倍努力,创作更多优质文章,为社区贡献更多有价值、有意思的内容!!!!

本文将分享笔者对 MODNet 网络结构内部其他模块的剪枝探索,剪枝策略同前文主干网络是一样的,剪枝完成后对参数进行替换即可,接下来,就开启探索之旅吧~~

1 开展思路

  1. 访问 MODNet 获取模块;
  2. torch.save(model.state_dict(), path),并检测能否 load,注意参数;
  3. 修改替换脚本中 for 循环下的 if 条件判断;
  4. 修改backbone、MODNet中 IBNorm 以及 wrapper 中的 channels,run script;
  5. 加载替换后的模型参数,观察是否能够成功执行。

2 核心要义

  1. 模型分析:根据先前对剪枝后 MobileNet V2 的结构修改,以及嵌入 MODNet 后的 channel 修改情况,确定待修改的网络层;

  2. 通道裁剪:根据1得到的待修改的网络层进行裁剪,以满足结构与参数匹配的情况;

  3. 参数嵌入:确认 channel 匹配以后,将参入嵌入 MODNet;

3 探索过程

确定修改后的结构与原先的区别在于下列网络层:

  • backbone;
  • lr_branch中的 lr16x、lr8x;
  • hr_branch中 enc2x;

目前,已对 backbone 成功嵌入。

接下来,针对lr16x、lr8x进行剪枝处理,但通过观察可以发现,这两层的前面存在着 se_block 模块,因此,先对 se_block 进行处理。

3.1 se block

观察该部分在 MODNet 中的尺寸与网络层名称:

获取并替换成功!不过这部分详细的过程笔者没有记录!存在不周,请谅解~~

3.2 lr16x、lr8x

💥注意:由于起初缺乏对网络层的分析,因此,在进行这两层的嵌入时,仅仅只是单一的嵌入。

将lr16x嵌入以后,出现了“参数 shape > 结构 shape”的情况。

于是,笔者联想到先前的解决方案固定结构,重新进行参数替换。但即便如此,通过键值对获取参数时,参数中的通道数尺寸并未发生变化。(因此,先前的这种方法存在不合理性,但却在执行后可以成功匹配,目前还没有进一步探寻。)

合理的方案以及针对情况如下

  • 对于output channel:单独提取该层,进行剪枝。(但是,如果和它相连的下一层 input channel 也发生了变化,需要将其合并,同时处理,这样,上一次的输出决定着下一层的输入。
  • 对于input channel:如上,合并处理。但是,如果与该层相连的上一层channel保持不变,那就无法使用剪枝。目前的解决方案是,切片提取,先满足结构要求。

而 lr16x 与 lr8x 正适合第一种情况!

原结构:

修改后的结构:

将 lr16x 与 lr8x 作为一个 sequential,剪枝:

model = modnet.MODNet(backbone_pretrained=False)
pretrained_ckpt = 'modnet_photographic_portrait_matting.ckpt'
model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(pretrained_ckpt).items()})

# get model
model = nn.Sequential(model.lr_branch.conv_lr16x, model.lr_branch.conv_lr8x)
print(model)

# pruning
# 由于是针对lr16x的output以及lr8x的input,因此这里排除lr8x即可
config_list = [{'sparsity': 0.5,
                'op_types': ['Conv2d']},
               {'exclude': True,
                'op_names': ['1.layers.0']}
               ]

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
pruner._unwrap_model()
ModelSpeedup(model, dummy_input, masks).speedup_model()
print(model)

结构变化:

修改网络结构(mobilenet、wrapper、IBNorm),加载裁剪后的参数,能成功执行计算:

IBNorm结构变化,init部分:

    def __init__(self, in_channels):
        super(IBNorm, self).__init__()
        in_channels = in_channels

        # 针对lr_16x
        if in_channels == 48:
            self.bnorm_channels = 27
            self.inorm_channels = 21
        else:
            self.bnorm_channels = int(in_channels / 2)
            self.inorm_channels = in_channels - self.bnorm_channels 

加载:

model = modnet.MODNet(backbone_pretrained=False)
model = nn.Sequential(model.lr_branch.conv_lr16x, model.lr_branch.conv_lr8x)
model.load_state_dict(torch.load('test.pth'))

dummy_input = torch.randn([1, 1280, 32, 32])
flops, params, _ = count_flops_params(model, dummy_input, verbose=True)
print(f"Model FLOPs {flops / 1e6:.2f}M, Params {params / 1e6:.2f}M")

结果:

替换MODNet中,这一部分的参数,保存并加载:

3.3 enc2x

至此,三个模块的参数全部嵌入!

4 探索结果

4.1 模型大小

4.2 参数量与计算量

剪枝前剪枝后
参数量6.45 M3.36 M
计算量18117.07 M15315.94 M

4.3 推理时延

序号剪枝前剪枝后
10.890.67
20.960.68
30.860.67

4.4 精度

评估指标原模型针对MobileNet V2剪枝后微调后从头训练后
MSE0.0042990.3607810.1403840.104005
MAD0.0081410.5765600.2111690.124459

5 实际推理测试

使用微调后的pth导出onnx模型:

model.eval()
batch_size = 1
height = 512
width = 512
dummy_input = Variable(torch.randn(batch_size, 3, height, width))

torch.onnx.export(
    model, dummy_input, 'test_modnet.onnx', export_params=True,
    input_names=['input'], output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                  'output': {0: 'batch_size', 2: 'height', 3: 'width'}}, opset_version=11)

推理:

和微调前的推理结果并无差别,但在直接使用pth格式模型推理时差异较大。

为何会这样?难道是因为笔者选用的不是人像,而是天线宝宝?

在观察导出的 ONNX 格式模型时,笔者发现模型输出节点的个数发生了变化。

原因是笔者在导出时没有注意 output,使用官方脚本解决了~

💥注意:这也就告诉我们,模型导出时的成功提示并不一定是真正处理好了,很多内部细节的丢失会对模型的推理精度带来致命的效果,这时我们可以重新思考模型的输入与输出,或者采用可视化的方式进行查看!

再次推理:

虽然效果仍然不理想,但至少好了很多,而且可以看出来,笔者选用的测试样例确实不是人!

推理时延变化:240ms---> 192ms,有明显改进!


在导出时也遇到了一个error:

onnxruntime::UpsampleBase::ScalesValidation scale >= 1 was false. Scale value should be greater tha

分析原因:调用 torch.export 时未指定 op_version;

解决方案:考虑到 笔者的pytorch version>=1.3.1,因此直接指定其为op为11,完成了推理!

6 结论 

  1. 在替换除了 MobileNet V2 以外的其他部分时,没有考虑整体,仅仅只是对单一的卷积层剪枝,以致于相连的下一个卷积层无法修改通道数。因此,剪枝无法直接对 input channels 操作,只能针对 output channels,进而影响 input channels。
  2. 关于IBNorm,直接修改了channels,可以运行,但缺乏通用性!
  3. 成功嵌入了除 MobileNet V2 以外的参数,并成功导出 ONNX 模型,完成模型推理!
  4. 经测试,模型大小、参数量降低了一半,推理时延降低 20%,从模型压缩的轻量化角度来看,本次探索是成功的,但从模型本身的精度来看,还有很长一段路要走!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/342120.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

vue3 npm i 一直卡到不动

一. 首先node 版本要18.0及以上 查看node版本并安装指定版本 二. 查看npm镜像源以及指定安装npm的镜像 三. 删除项目中的package-lock.json文件 最好是把node_modules安装包也删除掉,然后npm i 就可以了

如何用 500 行 SQL 实现 GPT2学习

目录 理论背景实现过程GenerationTokenizerEmbeddingsAttention为什么我们需要有因果掩码?为什么矩阵是 Q,K 和 V? BlocksTokens为什么要使用 softmax 转换概率?Inference 俄罗斯有个大佬每年都会用 SQL 来实现一个挑战庆祝新年&a…

1.23寒假集训

A: 解题思路&#xff1a; 大于x输出0&#xff0c;小于输出x减去这个数 下面是c代码&#xff1a; #include<iostream> using namespace std; int main() {int a,b,c,d,x;cin >> a >> b >> c >> d >> x;cout << (a < x ? x - a…

Gen AI大潮来袭!8个Salesforce新岗位,你会选择哪个?

人工智能席卷全球&#xff0c;企业对如何整合GenAI有着浓厚的兴趣。为启动企业的GenAI转型浪潮&#xff0c;Salesforce宣布与埃森哲和德勤建立合作伙伴关系&#xff0c;并计划推出更多支持项目。 目前&#xff0c;Salesforce领域的其他咨询公司正在提高员工技能&#xff0c;以…

Unity 适配器模式(实例详解)

文章目录 简介1. **Input Adapter 示例**2. **Component Adapter 示例**3. **网络数据解析适配器**4. **物理引擎适配**5. **跨平台服务适配** 简介 Unity中的适配器模式&#xff08;Adapter Pattern&#xff09;主要用于将一个类的接口转换为另一个接口&#xff0c;以便于原本…

贪吃蛇(C)

游戏背景&#xff1a;贪吃蛇是久负盛名的游戏&#xff0c;它也和俄罗斯⽅块&#xff0c;扫雷等游戏位列经典游戏的⾏列。 总&#xff1a; 游戏设计大纲&#xff1a; 使⽤C语⾔在Windows环境的控制台中模拟实现经典⼩游戏贪吃蛇。 实现的基本功能&#xff1a; 1、贪吃蛇地图绘制…

Whale 帷幄创始人叶生晅荣获亿欧 2023 中国泛人工智能优秀人物 TOP 20

近日&#xff0c;亿欧在 WIM 2023&#xff08;World Innovators Meet&#xff0c;世界创新者年会&#xff09;上发布 2023 世界创新奖「2023 中国泛人工智能优秀人物 TOP 20」&#xff0c;表彰那些过去一年中在泛人工智能领域做出突出贡献的领导者、开拓者。「Whale 帷幄」创始…

用ChatGPT教学、科研!亚利桑那州立大学与OpenAI合作

亚利桑那州立大学&#xff08;简称“ASU”&#xff09;在官网宣布与OpenAI达成技术合作。从2024年2月份开始&#xff0c;为所有学生提供ChatGPT企业版访问权限&#xff0c;主要用于学习、课程作业和学术研究等。 为了帮助学生更好地学习ChatGPT和大语言模型产品&#xff0c;AS…

图像分割实战-系列教程18:MaskRCNN项目介绍与配置

&#x1f341;&#x1f341;&#x1f341;图像分割实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 Mask R-CNN for Object Detection and Segmentation MaskRCNN是一个通用的物体检测框架&#xff…

MySQL学习(1):centos7安装MySQL

1.安装自己系统对应的MySQL版本 1.1查看自己系统的内核版本 cat /etc/redhat-release 可以看到我的系统版本是centos7.6 1.2去官网下载对应的MySQL安装文件 MySQL官网&#xff1a; https://dev.mysql.com/downloads/ 点击MYSQL Community Server 然后可以在索引的位置选…

PWM调光 降压恒流LED芯片FP7127:为照明系统注入新能量(台灯、GBR、调光电源、汽车大灯)

目录 一、降压恒流LED芯片FP7127 二、降压恒流LED芯片FP7127具有以下特点&#xff1a; 三、降压恒流LED芯片FP7127应用领域&#xff1a; LED照明和调光的新纪元随着LED照明技术的不断发展&#xff0c;人们对于照明调光的需求也越来越高。PWM调光技术作为一种常用的调光方法&…

获取货币供应量

用bs库&#xff1a; import baostock as bs import pandas as pd# 登陆系统 lg bs.login() # 显示登陆返回信息 print(login respond error_code:lg.error_code) print(login respond error_msg:lg.error_msg)# 获取货币供应量 rs bs.query_money_supply_data_month(start_…

App各大应用商城的排名被哪些因素影响着?(小米/vivo篇)

小米&#xff1a; ①关键词设置&#xff1a; 小米应用商店允许在后台设置关键词&#xff0c;8个关键词&#xff0c;每个词不超过5个字&#xff0c;权重从左到右逐渐降低。 关键词内最好不要填写应用名称里面已有的关键词&#xff0c;不叠加权重&#xff0c;浪费位置。 ②应…

5G+物联网:连接万物,重塑智慧社区,开启未来生活新纪元,助力智慧社区的革新与发展

一、5G与物联网&#xff1a;技术概述与基础 随着科技的飞速发展&#xff0c;第五代移动通信技术&#xff08;5G&#xff09;和物联网&#xff08;IoT&#xff09;已经成为当今社会的热门话题。这两项技术作为现代信息社会的核心基础设施&#xff0c;正深刻地改变着人们的生活和…

宿舍安全用电监模块

学校宿舍安全用电监测模块是针对 0.4kV 以下的 TT、TN 系统设计的智能电力装置&#xff0c;具有单、三相交流电测量、四象限电能计量、谐波分析、开关量输入、继电器输出功能&#xff0c;以及 RS485 通讯或 GPRS 无线通讯功能&#xff0c;通过对配电回路的剩余电流、导线温度等…

教师转行适合做什么工作

当教师转型成为社会话题时&#xff0c;无数同仁都开始思考&#xff1a;我要转行吗&#xff1f;转到哪里去呢&#xff1f;作为一位曾经的教师&#xff0c;我想说&#xff0c;转行不是盲目地跳出教育界&#xff0c;而是基于自身优势和兴趣的理性选择。 作为教师&#xff0c;我们…

k8s集群异常恢复

前提、我自己的k8s采用的是单master节点两个从节点部署&#xff0c;我针对单master情况进行恢复说明 场景一&#xff1a;正常开关虚拟机&#xff0c;可直接重启kubelet进行恢复 1、1、一般重启后三个节点都需要检查&#xff0c;输入命令检查kubelet&#xff1a; systemctl s…

gitlab设置/修改克隆clone地址端口

最近由于公司要停测试库云服务器? 什么?要停测试库服务器??? 是的! 你没听错。 真是醉了,多大的集团,为了省钱,也真是拼了, 作为开发人员,没有测试服务器,犹如断臂之人。 所以,在之前搭建环境的时候都没有写文档,今天算是弥补上,以后都可以作为参考了, …

数据结构:完全二叉树(递归实现)

如果完全二叉树的深度为h&#xff0c;那么除了第h层外&#xff0c;其他层的节点个数都是满的&#xff0c;第h层的节点都靠左排列。 完全二叉树的编号方法是从上到下&#xff0c;从左到右&#xff0c;根节点为1号节点&#xff0c;设完全二叉树的节点数为sum&#xff0c;某节点编…

C++提高编程——STL:string容器、vector容器

本专栏记录C学习过程包括C基础以及数据结构和算法&#xff0c;其中第一部分计划时间一个月&#xff0c;主要跟着黑马视频教程&#xff0c;学习路线如下&#xff0c;不定时更新&#xff0c;欢迎关注。 当前章节处于&#xff1a; ---------第1阶段-C基础入门 ---------第2阶段实战…