【Datawhale组队学习】模型减肥秘籍:模型压缩技术6——项目实践

NNI (Neural Network Intelligence) 是由微软开发的一个开源自动化机器学习(AutoML)库,用于帮助研究人员和开发人员高效地进行机器学习实验。它提供了一套丰富的工具来进行模型调优、神经网络架构搜索、模型压缩以及自动化的超参数搜索。
在这里插入图片描述

1.模型剪枝

代码的主要目的是展示如何通过 NNI 进行神经网络模型剪枝,以减少模型大小和计算复杂度,之后通过微调来恢复模型的性能,并实现剪枝后的加速部署。这种方法有助于将复杂的神经网络压缩,使其更适合在资源受限的设备上运行,同时保持尽可能高的准确率。

核心代码:

config_list = [{
    'op_types': ['Linear', 'Conv2d'],
    'exclude_op_names': ['fc3'],
    'sparse_ratio': 0.8
}]

这段代码定义了一个剪枝配置 config_list,用于指定如何对模型的特定层进行剪枝。下面是每个参数的解释:
1. op_types: [‘Linear’, ‘Conv2d’]
这个参数指定了需要进行剪枝的层类型。‘Linear’ 和 ‘Conv2d’ 表示对模型中的全连接层和卷积层(Conv2d)进行剪枝。
2. exclude_op_names: [‘fc3’]
这个参数指定了在剪枝过程中需要排除的层。名为 ‘fc3’ 的层将不会被剪枝,即该层不受剪枝影响。
3. sparse_ratio: 0.8
这个参数指定了剪枝的稀疏度比例。‘sparse_ratio’: 0.8 表示在选定的层中,要将 80% 的参数剪枝掉,只保留 20% 的权重。

2.模型量化

使用 NNI 框架对深度学习模型进行量化处理的实践。代码使用了一种训练后量化的技术,通过减少模型中参数的位数,来减小模型大小并加快推理速度。

# 量化配置,将卷积层和全连接层量化为int8类型
config_list = [{
    'op_names': ['conv1', 'conv2', 'fc1', 'fc2'],  # 需要量化的操作
    'target_names': ['_input_', 'weight', '_output_'],  # 量化输入、权重和输出
    'quant_dtype': 'int8',  # 使用int8类型进行量化
    'quant_scheme': 'affine',  # 量化方法,使用仿射变换
    'granularity': 'default',
},{
    'op_names': ['relu1', 'relu2'],  # 需要量化的激活函数
    'target_names': ['_output_'],  # 量化输出
    'quant_dtype': 'int8',
    'quant_scheme': 'affine',
    'granularity': 'default',
}]

# 创建QATQuantizer对象进行量化感知训练
quantizer = QATQuantizer(model, config_list, evaluator, len(train_loader))

通过量化感知训练减少模型的复杂度,使模型可以被压缩为 int8 类型。这样能够在保证模型性能的同时,显著降低模型大小并加快推理速度,尤其适合在资源受限的环境中使用。

3.NAS

代码使用 NNI 进行神经网络架构搜索(NAS)的实践,展示如何通过定义模型空间、选择搜索策略、训练评估模型并启动实验来寻找最优模型架构。

核心代码:
定义了一个名为 MyModelSpace 的模型空间,通过 LayerChoice 和 MutableXXX 使其包含多种可能的结构,用于搜索不同的架构组合。

class MyModelSpace(ModelSpace):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        # LayerChoice用于选择卷积层类型(标准卷积或深度可分离卷积)
        self.conv2 = LayerChoice([
            nn.Conv2d(32, 64, 3, 1),
            DepthwiseSeparableConv(32, 64)
        ], label='conv2')
        # MutableDropout用于从指定的概率中选择一个dropout率
        self.dropout1 = MutableDropout(nni.choice('dropout', [0.25, 0.5, 0.75]))
        self.dropout2 = nn.Dropout(0.5)
        feature = nni.choice('feature', [64, 128, 256])
        self.fc1 = MutableLinear(9216, feature)
        self.fc2 = MutableLinear(feature, 10)

使用随机搜索策略来探索模型空间,以选择不同的层配置。

import nni.nas.strategy as strategy
search_strategy = strategy.Random()  # 使用随机搜索策略

代码通过 NNI 的 NAS 功能自动化地探索不同的神经网络结构组合,减少了人工设计架构的复杂性。结合随机搜索策略,能够有效地在不同架构之间进行搜索,并通过实验来评估每个模型的性能。

4.使用NNI对模型进行剪枝、量化、蒸馏压缩

使用 NNI 框架对 ResNet18 模型进行融合压缩,涉及模型剪枝、量化和知识蒸馏等方法。

使用了 TaylorPruner 和 AGPPruner,分别对模型进行基于泰勒展开法的重要性评估以及渐进剪枝。

# 设置剪枝配置
bn_list = [module_name for module_name, module in model.named_modules() if isinstance(module, torch.nn.BatchNorm2d)]
p_config_list = [{
    'op_types': ['Conv2d'],
    'sparse_ratio': 0.5
}, *[{
    'op_names': [name],
    'target_names': ['_output_'],
    'target_settings': {
        '_output_': {
            'align': {
                'module_name': name.replace('bn', 'conv') if 'bn' in name else name.replace('downsample.1', 'downsample.0'),
                'target_name': 'weight',
                'dims': [0],
            },
            'granularity': 'per_channel'
        }
    }
} for name in bn_list]]

# 使用 TaylorPruner 和 AGPPruner 进行剪枝
sub_pruner = TaylorPruner(model, p_config_list, evaluator, training_steps=100)
scheduled_pruner = AGPPruner(sub_pruner, interval_steps=100, total_times=30)

使用了 QATQuantizer,以量化感知训练(QAT)的方式对模型进行量化。

q_config_list = [{
    'op_types': ['Conv2d'],
    'quant_dtype': 'int8',
    'target_names': ['_input_'],
    'granularity': 'per_channel'
}, {
    'op_types': ['BatchNorm2d'],
    'quant_dtype': 'int8',
    'target_names': ['_output_'],
    'granularity': 'per_channel'
}]

quantizer = QATQuantizer.from_compressor(scheduled_pruner, q_config_list, quant_start_step=100)

使用 DynamicLayerwiseDistiller 对模型进行蒸馏,将教师模型的知识传递给学生模型。

def teacher_predict(batch, teacher_model):
    return teacher_model(batch[0])

d_config_list = [{
    'op_types': ['Conv2d'],
    'lambda': 0.1,
    'apply_method': 'mse',
}]
distiller = DynamicLayerwiseDistiller.from_compressor(quantizer, d_config_list, teacher_model, teacher_predict, 0.1)

通过剪枝和量化得到的稀疏性,可以利用 ModelSpeedup 来加速模型推理。

masks = scheduled_pruner.get_masks()
speedup = ModelSpeedup(model, dummy_input, masks)
model = speedup.speedup_model()

展示了如何使用 NNI 对 ResNet18 模型进行融合压缩,在减少模型参数数量和加速推理的同时,最大限度地保留模型的性能。这种方式能够显著减小模型的存储需求和计算成本,非常适合在资源受限的设备上进行模型部署。
在这里插入图片描述

参考文献

  1. https://www.datawhale.cn/learn/content/68/966
  2. https://github.com/datawhalechina/awesome-compression/tree/main/docs/notebook/ch07
  3. https://nni.readthedocs.io/en/latest/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/927128.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

通讯专题4.1——CAN通信之计算机网络与现场总线

从通讯专题4开始,来学习CAN总线的内容。 为了更好的学习CAN,先从计算机网络与现场总线开始了解。 1 计算机网络体系的结构 在我们生活当中,有许多的网络,如交通网(铁路、公路等)、通信网(电信、…

【51单片机】程序实验910.直流电机-步进电机

主要参考学习资料:B站【普中官方】51单片机手把手教学视频 前置知识:C语言 单片机套装:普中STC51单片机开发板A4标准版套餐7 码字不易,求点赞收藏加关注(•ω•̥) 有问题欢迎评论区讨论~ 目录 程序实验9&10.直流电机-步进电机…

Qt支持RKMPP硬解的视频监控系统/性能卓越界面精美/实时性好延迟低/录像存储和回放/云台控制

一、前言 之前做的监控系统,已经实现了在windows上硬解码比如dxva2和d3d11va,后续又增加了linux上的硬解vdpau的支持,这几种方式都是跨系统的硬解实现方案,也是就是如果都是windows系统,无论X86还是ARM都通用&#xf…

Web API基本认知

作用和分类 作用:就是使用JS去操作html和浏览器 分类:DOM(文档对象模型)、BOM(浏览器对象模型) 什么是DOM DOM(Document Object Model ——文档对象模型)是用来呈现以及与任意 HTM…

Linux——自定义简单shell

shell 自定义shell目标普通命令和内建命令(补充) shell实现实现原理实现代码 自定义shell 目标 能处理普通命令能处理内建命令要能帮助我们理解内建命令/本地变量/环境变量这些概念理解shell的运行 普通命令和内建命令(补充) …

智能桥梁安全运行监测系统守护桥梁安全卫士

一、方案背景 桥梁作为交通基础设施中不可或缺的重要组成部分,其安全稳定的运行直接关联到广大人民群众的生命财产安全以及整个社会的稳定与和谐。桥梁不仅是连接两地的通道,更是经济发展和社会进步的重要纽带。为了确保桥梁的安全运行,桥梁安…

【Python网络爬虫笔记】5-(Request 带参数的get请求) 爬取豆瓣电影排行信息

目录 1.抓包工具查看网站信息2.代码实现3.运行结果 1.抓包工具查看网站信息 请求路径 url:https://movie.douban.com/typerank请求参数 页面往下拉,出现新的请求结果,参数start更新,每次刷新出20条新的电影数据 2.代码实现 # 使用网络爬…

新质驱动·科东软件受邀出席2024智能网联+低空经济暨第二届湾区汽车T9+N闭门会议

为推进广东省加快发展新质生产力,贯彻落实“百县千镇万村高质量发展工程”,推动韶关市新丰县智能网联新能源汽车、低空经济与数字技术的创新与发展,充分发挥湾区汽车产业链头部企业的带动作用。韶关市指导、珠三角湾区智能网联新能源汽车产业…

C#使用ExcelDataReader读取Xlsx文件为DataTable对象

创建控制台项目 在NuGet中安装ExcelDataReader.DataSet 3.7.0 创建一个xlsx文件 测试代码 读取xlsx文件内容,为一个DataTable对象。 读取xlsx时,xlsx文件不能被其他软件打开,否则会报“进程无法访问此文件”的错。 using ExcelDataRead…

“harmony”整合不同平台的单细胞数据之旅

其实在Seurat v3官方网站的Vignettes中就曾见过该算法,但并没有太多关注,直到看了北大张泽民团队在2019年10月31日发表于Cell的《Landscap and Dynamics of Single Immune Cells in Hepatocellular Carcinoma》,为了同时整合两类数据&#xf…

智慧银行反欺诈大数据管控平台方案(一)

智慧银行反欺诈大数据管控平台建设方案的核心在于通过整合先进的大数据技术和深度学习算法,打造一个全面、智能且实时的反欺诈系统,以有效识别、预防和应对各类金融欺诈行为。该方案涵盖数据采集、存储、处理和分析的全流程,利用多元化的数据…

基于 JNI + Rust 实现一种高性能 Excel 导出方案(上篇)

每个不曾起舞的日子,都是对生命的辜负。 ——尼采 一、背景:Web 导出 Excel 的场景 Web 导出 Excel 功能在数据处理、分析和共享方面提供了极大的便利,是许多 Web 应用程序中的重要功能。以下是一些典型的场景: 数据报表导出:在企业管理系统(如ERP、CRM)中,用户经常需…

使用 Tkinter 创建一个简单的 GUI 应用程序来合并视频和音频文件

使用 Tkinter 创建一个简单的 GUI 应用程序来合并视频和音频文件 Python 是一门强大的编程语言,它不仅可以用于数据处理、自动化脚本,还可以用于创建图形用户界面 (GUI) 应用程序。在本教程中,我们将使用 Python 的标准库模块 tkinter 创建一…

「Mac畅玩鸿蒙与硬件35」UI互动应用篇12 - 简易日历

本篇将带你实现一个简易日历应用,显示当前月份的日期,并支持选择特定日期的功能。用户可以通过点击日期高亮选中,还可以切换上下月份,体验动态界面的交互效果。 关键词 UI互动应用简易日历动态界面状态管理用户交互 一、功能说明…

江协科技最新OLED保姆级移植hal库

江协科技最新OLED移植到hal库保姆级步骤 源码工程存档 工程和源码下载(密码 1i8y) 原因 江协科技的开源OLED封装的非常完美, 可以满足我们日常的大部分开发, 如果可以用在hal库 ,将是如虎添翼, 为我们开发调试又增加一个新的瑞士军刀, 所以我们接下来手把手的去官网移植源码…

NLTK工具包

NLTK工具包 NLTK工具包安装 非常实用的文本处理工具,主要用于英文数据,历史悠久~ 安装命令: pip install nltk import nltk # nltk.download() # nltk.download(punkt) # nltk.download(stopwords) # nltk.download(maxent_ne_chunker) nl…

HarmonyOS:使用Emitter进行线程间通信

Emitter主要提供线程间发送和处理事件的能力,包括对持续订阅事件或单次订阅事件的处理、取消订阅事件、发送事件到事件队列等。 一、Emitter的开发步骤如下: 订阅事件 import { emitter } from kit.BasicServicesKit; import { promptAction } from kit.…

Unity之一键创建自定义Package包

内容将会持续更新,有错误的地方欢迎指正,谢谢! Unity之一键创建自定义Package包 TechX 坚持将创新的科技带给世界! 拥有更好的学习体验 —— 不断努力,不断进步,不断探索 TechX —— 心探索、心进取! …

【html网页页面007】html+css制作旅游主题内蒙古网页制作含注册表单(4页面附效果及源码)

旅游家乡主题网页制作 🥤1、写在前面🍧2、涉及知识🌳3、网页效果🌈4、网页源码4.1 html4.2 CSS4.3 源码获取 🐋5、作者寄语 🥤1、写在前面 家乡网站主题内蒙古的网页 一共4个页面 网页使用htmlcss制作页面…

Ardupilot开源无人机之Geek SDK讨论

Ardupilot开源无人机之Geek SDK讨论 1. 源由2. 假设3. 思考3.1 结构构型3.2 有限资源3.3 软硬件构架 4.Ardupilot构架 - 2024kaga Update5. 讨论5.1 话题1:工作模式5.2 话题2:关键要点5.3 话题3:产品设计 6. Geek SDK - OpenFire6.1 开源技术…