Pytorch量化之Post Train Static Quantization(训练后静态量化)

使用Pytorch训练出的模型权重为fp32,部署时,为了加快速度,一般会将模型量化至int8。与fp32相比,int8模型的大小为原来的1/4, 速度为2~4倍。
Pytorch支持三种量化方式:

  • 动态量化(Dynamic Quantization): 只量化权重,激活在推理过程中进行量化
  • 静态量化(Static Quantization): 量化权重和激活
  • 量化感知训练(Quantization Aware Training,QAT): 插入量化算子后进行训练,主要在静态量化精度不满足需求时进行。
    大多数情况下,我们只需要进行静态量化,少数情况下在量化感知训练不满足时使用QAT进行微调。所以本篇只重点讲静态量化,并且理论部分先略过(后面再专门总结),只关注实操。
    注:下面的代码是在pytorch1.10下,后面Pytorch对量化的接口有调整
    官方文档:Quantization — PyTorch 1.10 documentation

动态模式(Eager Mode)与静态模式(fx graph)

Pytorch支持用2种方式量化,一种是动态图模式,也是我们日常使用Pytorch训练所使用的方式,使用这种方式量化需要自己手动修改网络结构,在支持量化的算子前、后插入量化节点,优点是方便调试。静态模式则是由pytorch自动在计算图中插入量化节点,不需要手动修改网络。
网络上大部分的教程都是基于静态模式,这种方式比较大的问题就是需要手动修改网络结构,官方教程里的网络是属于demo型, 其中的QuantStub和DeQuantStub就分别是量化和反量化的节点:

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

Pytorch对于很多网络层是不支持量化的(比如很常用的Prelu),如果我们用这种方式,我们就必须在这些不支持的层前面插入DeQuantStub,然后在支持的层前面插入QuantStub。笔者体验下来,体验很差,个人觉得不太实用,会破坏原来的网络结构。
而静态图模式,我们只需要调用Pytorch提供的接口将原模型转换一下即可,不需要修改原来的网络结构文件,个人认为实用性更强。
image.png

静态模式量化

1. 载入fp32模型,并转成fx graph

其中量化参数有‘fbgemm’和‘qnnpack’两种,前者在x86运行,后者在arm运行。

model_fp32 = torch.load(xxx)
model_fp32_quantize = copy.deepcopy(model_fp32)
qconfig_dict = {"": torch.quantization.get_default_qconfig('fbgemm')}
model_fp32_quantize.eval()
# prepare

model_prepared = quantize_fx.prepare_fx(model_fp32_quantize, qconfig_dict)
model_prepared.eval()

2.读取量化数据,标定(Calibration)量化参数

标定的过程就是使用模型推理量化图片,然后统计权重和激活分布,从而得到量化参数。量化图片一般来源于训练集(几百张左右,根据测试情况调整)。量化图片可以通过Pytorch的Dataloader读取,也可以直接自行实现读图片然后送入网络。

### 使用dataloader读取
for i, (data, label) in enumerate(train_loader):
    data = data.to(torch.device("cpu:0"))
    outputs = model_prepared(data)
    print("calibrating {}".format(i))
    if i > 1000:
        break

3. 转换为量化模型并保存

quantized_model = quantize_fx.convert_fx(model_prepared)
torch.jit.save(torch.jit.script(quantized_model), "quantized_model.pt")

速度测试

量化后的模型使用方法与fp32模型一样:

import torch
import cv2
import numpy as np
torch.set_num_threads(1)

fused_model = torch.jit.load("jit_model.pt")
fused_model.eval()
fused_model.to(torch.device("cpu:0"))

img = cv2.imread("./1.png")
img_fp32 = img.astype(np.float32)
img_fp32 = (img_fp32-127.5) / 127.5
input = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).float()

def speed_test(model, input):
    # warm up
    for i in range(10):
        model(input)

    import time
    start = time.time()
    for i in range(100):
        model(input)
    end = time.time()
    print("model time: ", (end-start)/100)
    time.sleep(10)

# quantized model
quantized_model= torch.jit.load("quantized_model.pt")
quantized_model.eval()
quantized_model.to(torch.device("cpu:0"))

speed_test(fused_model, input)
speed_test(quantized_model, input)

实测fp32模型单核运行120ms, 量化后47ms

结语

本文介绍了fx graph模式下的Pytorch的PTSQ方法,并实测了一个模型,效果还比较不错。
1_995567224_161_79_3_732056265_62005da0d7c1b531a6cf91ea587d312e.jpg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/66947.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android 13 Hotseat定制化修改

一.背景 由于需求是需要自定义修改Hotseat,所以此篇文章是记录如何自定义修改hotseat的,应该可以覆盖大部分场景,修改点有修改hotseat布局方向,hotseat图标数量,hotseat图标大小,hotseat布局位置,hotseat图标禁止形成文件夹,hotseat图标禁止移动到Launcher中,下面开始…

Gpt微信小程序搭建的前后端流程 - 前端小程序部分-2.确定交互所需的后端API(二)

Gpt微信小程序搭建的前后端流程 - 前端小程序部分-2.确定交互所需的后端API(二) 参考微信小程序-小柠AI智能聊天,可自行先体验。 根据上一节的小程序静态页面设计,需要从后端获取数据的主要4个点: 登录流程;获取今日已提问次数&a…

[保研/考研机试] KY102 计算表达式 上海交通大学复试上机题 C++实现

描述 对于一个不存在括号的表达式进行计算 输入描述: 存在多组数据,每组数据一行,表达式不存在空格 输出描述: 输出结果 示例1 输入: 6/233*4输出: 18思路: ①设立运算符和运算数两个…

CSS的引入方式有哪些?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 内联样式(Inline Styles)⭐ 内部样式表(Internal Stylesheet)⭐ 外部样式表(External Stylesheet)⭐ 导入样式表(Import Stylesheet)⭐ 写在最…

HarmonyOS应用开发者基础认证考试题库

此博文为HarmonyOS应用开发者基础认证考试的最后的大考,要求100分取得90分方可获取证书、现将考试的题库进行分享,希望能帮到大家。但是需要注意的是,题库会不定时的进行题目删减,但是大概的内容是不会进行改变的。真心希望这篇博…

ArcGISPro随机森林自动化调参分类预测模型展示

更改ArcGISPro的python环境变量请参考文章 ArcGISPro中如何使用机器学习脚本_Z_W_H_的博客-CSDN博客 脚本文件如下 点击运行 结果展示 负类预测概率 正类预测概率 二值化概率 文件夹(模型验证结果) 数据集数据库 ROC曲线 由于个人数据量太少所以…

立即开始使用 3D 图像

一、说明 这个故事介绍了使用这种类型的数据来训练机器学习3D模型。特别是,我们讨论了Kaggle中可用的MNIST数据集的3D版本,以及如何使用Keras训练模型识别3D数字。 3D 数据无处不在。由于我们希望构建AI来与我们的物理世界进行交互,因此使用3…

Openlayers实战:多数据分散聚合

在飞机、轮船等地图显示的应用中,很多时候会用到数据聚合,Openlayers中提供了Cluster这个API ,他作为souce的一部分,设定distance值,如果2个点的间距小于 distance 所设置的数时,就会以聚合的方式显示。从而解决了数据淤积显示的状态,非常实用。 效果图 源代码 /* * @…

js 正则表达式

js 正则表达式 http://tool.oschina.net/regex https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Guide/Regular_Expressions 11 22

1999-2021年全国各地级市专利申请与获得情况、绿色专利申请与获得情况面板数据

1999-2021年全国各地级市专利申请与获得情况、绿色专利申请与获得情况面板数据 1、时间:2000-2021年 2、来源:国家知识产权局 3、范围:地级市(具体每年地级市数量参看下文图片) 4、指标:申请专利数&…

数据结构----结构--线性结构--链式存储--链表

数据结构----结构–线性结构–链式存储–链表 1.链表的特点 空间可以不连续,长度不固定,相对于数组灵活自由 搜索: 时间复杂度O(n) 增删: 头增头删时间复杂度O(1) 其他时间复杂度为O(n) 扩展:单向循环链表的特性 从任意节…

基于kettle实现pg数据定时转存mongodb

mogodb 待创建 基于kettle实现pg数据定时转存mongodb_kettle 实时迁移 mongodb_呆呆的私房菜的博客-CSDN博客

【JUC】复习指南

JUC复习指南: JUC有哪些知识点? 什么是Juc Lock接口 线程间通信 集合的线程安全问题 多线程锁 Callable接口 JUC三大辅助类 CountDownLatch CyclicBarrier Semaphore 读写锁 ReetrantReadWriteLOck 阻塞队列 ThreadPool线程池 Fork/join Com…

CVE漏洞复现-CVE-2021-3493 Linux 提权内核漏洞

CVE-2021-3493 Linux 提权内核漏洞 漏洞描述 CVE-2021-3493 用户漏洞是 Linux 内核中没有文件系统中的 layfs 中的 Ubuntu over 特定问题,在 Ubuntu 中正确验证有关名称空间文件系统的应用程序。buntu 内核代码允许低权限用户在使用 unshare() 函数创建的用户命名…

JVM:运行时数据区域(白话文)

最近有时间在看一本<深入了解Java虚拟机>的书籍&#xff0c;这本书是一个中国人&#xff0c;名叫周志明的人写的。相比于其他翻译过来的技术书籍&#xff0c;这本书还是挺通俗易懂的。先前有和彬哥在聊&#xff0c;他说如果是自己一个人看的话会很枯燥&#xff0c;很难坚…

Docker安装Grafana以及Grafana应用

Doker基础 安装 1、 卸载旧的版本 sudo yum remove docker docker-client docker-client-latest docker-common docker-latest docker-latest-logrotate docker-logrotate docker-engine 2、需要的安装包 sudo yum install -y yum-utils 3、设置镜像的仓库 yum-config-m…

openlayers渲染rgb三波段cog时达到类似rgba的效果(去掉黑底)

图是arcgis渲染成rgb的&#xff0c;由于没有透明度波段&#xff0c;底下是黑的。 为了能在前端显示透明效果&#xff0c;之前是用python处理数据&#xff0c;给它加个透明度波段 后来研究了一下ol的样式表达式&#xff0c;可以直接在前端去掉黑底 样式设置代码如下 const s…

浪潮数字咨询专家孙崇虎受邀为第十二届中国PMO大会演讲嘉宾

浪潮数字企业技术有限公司集团管控事业部咨询专家孙崇虎先生受邀为由PMO评论主办的2023第十二届中国PMO大会演讲嘉宾&#xff0c;演讲议题&#xff1a;VUCA时代的项目管理信息化应对。大会将于8月12-13日在北京举办&#xff0c;敬请关注&#xff01; 议题简要&#xff1a; 当前…

HTML+CSS+JavaScript:两种方法实现商品价格筛选效果

一、需求 鼠标点击上方菜单栏中不同的价格区间&#xff0c;自动筛选出价格符合条件的商品&#xff0c;并渲染在页面中 二、代码素材 以下是缺失JS部分的代码&#xff0c;感兴趣的小伙伴可以先自己试着写一写 <!DOCTYPE html> <html lang"en"><head…

【递归算法实践】验证二叉搜索树

目录 1. 递归算法 2. 递归实现验证二叉搜索树 3. 递归解法的实现逻辑 4. 递归实现的实例分析 1. 递归算法 递归是一种通过函数自身调用来解决问题的算法&#xff0c;它可以使代码更加简洁和优雅&#xff0c;同时也能够解决许多复杂的问题。在递归中&#xff0c;函数会不断…