模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

模型部署:量化中的Post-Training-Quantization(PTQ)和Quantization-Aware-Training(QAT)

  • 前言
  • 量化
    • Post-Training-Quantization(PTQ)
    • Quantization-Aware-Training(QAT)
  • 参考文献

前言

随着人工智能的不断发展,深度学习网络被广泛应用于图像处理、自然语言处理等实际场景,将其部署至多种不同设备的需求也日益增加。然而,常见的深度学习网络模型通常包含大量参数和数百万的浮点数运算(例如ResNet50具有95MB的参数以及38亿浮点数运算),实时地运行这些模型需要消耗大量内存和算力,这使得它们难以部署到资源受限且需要满足实时性、低功耗等要求的边缘设备。为了进一步推动深度学习网络模型在移动端或边缘设备中的快速部署,深度学习领域提出了一系列的模型压缩与加速方法:

  • 知识蒸馏(Knowledge distillation):使用教师-学生网络结构,让小型的学生网络模仿大型教师网络的行为,以使得准确率尽可能高的同时,能够获得一个轻量化的网络。
  • 剪枝(Parameter pruning):删除不必要的网络参数,以减少模型的规模和计算复杂度。
  • 低秩分解(Low-rank factorization):将模型的参数矩阵分解为较低秩的小矩阵,以减少模型的复杂度和计算成本。
  • 参数共享(Parameter sharing):将多个层共用一组参数,以减少模型的参数数量。
  • 量化(Quantization):将模型的参数和运算转化为更小的数据类型,以减少内存占用和计算时间。

量化

模型量化(Quantization)是一种将浮点计算转化为定点计算的技术,例如从FP32降低至INT8,主要用于减少模型的计算强度、参数大小以及内存消耗,以提高模型在设备上的推理计算效率,但是也有可能会带来一定的精度损失。

模型量化精度损失的主要原因为量化-反量化(Quantization-Dequantization)过程中取整引起的误差。这里简单介绍一下量化的计算方法,以FP32到INT8的量化为例,量化的核心思想就是将浮点数区间的参数映射到INT8的离散区间中。
量化公式:
q = r s + Z q = \frac{r}{s} + Z q=sr+Z反量化公式:
r = S ( q − Z ) r = S(q-Z) r=S(qZ)其中, r r r 为FP32的浮点数(real value), q q q 为INT8的量化值(quantization value),
S S S Z Z Z 分别为缩放因子(Scale-factor)和零点(Zero-Point)。

量化最重要的便是确定 S S S Z Z Z 的值, S S S Z Z Z 的计算公式如下:
S = r m a x − r m i n q m a x − q m i n S = \frac{r_{max}-r_{min}}{q_{max}-q_{min}} S=qmaxqminrmaxrmin Z = − r m i n S + q m i n Z = -\frac{r_{min}}{S} + q_{min} Z=Srmin+qmin其中, r m a x r_{max} rmax r m i n r_{min} rmin 分别为FP32网络参数最大、最小值, q m a x q_{max} qmax q m i n q_{min} qmin 分别为INT8网络参数最大、最小值。

为了减少量化所带来的精度损失,学者提出了Quantization-Aware-Training(QAT)方法,再介绍此之前,由于Post-Training-Quantization(PTQ)方法也经常在文献中出现,此篇博客将着重介绍这两个方法的含义与区别。
在这里插入图片描述

Post-Training-Quantization(PTQ)

Post-Training-Quantization(PTQ)是目前常用的模型量化方法之一。以INT8量化为例,PTQ方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的模型;
  2. 使用小部分数据对FP32模型进行采样(Calibration),主要是为了得到网络各层参数的数据分布特性(比如统计最大最小值);
  3. 根据步骤2中的数据分布特性,计算出网络各层 S 和 Z 量化参数;
  4. 使用步骤3中的量化参数对FP32模型进行量化得到INT8模型,并将其部署至推理框架进行推理。

PTQ方法会使用小部分数据集来估计网络各层参数的数据分布,找到合适的S和Z的取值,从而一定程度上降低模型精度损失。然而,论文中指出PTQ方式虽然在大模型上效果较好(例如ResNet101),但是在小模型上经常会有较大的精度损失(例如MobileNet) 不同通道的输出范围相差可能会非常大(大于100x), 对异常值较为敏感。

Quantization-Aware-Training(QAT)

由上文可知PTQ方法中模型的训练和量化是分开的,而Quantization-Aware-Training(QAT)方法则是在模型训练时加入了伪量化节点,用于模拟模型量化时引起的误差,并通过微调使得模型在量化后尽可能减少精度损失。以INT8量化为例,QAT方法的处理流程为:

  1. 首先在数据集上以FP32精度进行模型训练,得到训练好的FP32模型;
  2. 在FP32模型中插入伪量化节点,得到QAT模型,并且在数据集上对QAT模型进行微调(Fine-tuning);
  3. 同PTQ方法中的采样(Calibration),并计算量化参数 S 和 Z ;
  4. 使用步骤3中得到的量化参数对QAT模型进行量化得到INT8模型,并部署至推理框架中进行推理。

在PyTorch中,可以使用 torch.quantization.quantize_dynamic() 方法来执行 QAT。这是一个基本的 QAT 代码示例:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import quantize_dynamic, QuantStub, DeQuantStub

# 定义简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.quant = QuantStub()
        self.dequant = DeQuantStub()
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.quant(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dequant(x)
        return x

# 数据加载
# 这里使用 MNIST 数据集作为示例
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

train_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=True, download=True, transform=transform),
                                           batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./data', train=False, download=True, transform=transform),
                                          batch_size=64, shuffle=False)

# 定义损失函数和优化器
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 定义 QAT 训练函数
def train(model, train_loader, criterion, optimizer, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            output = model(data.view(data.shape[0], -1))
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

# 训练模型
train(model, train_loader, criterion, optimizer, num_epochs=5)

# 在训练完成后执行动态量化
quantized_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

# 评估量化模型
def test(model, test_loader, criterion):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data.view(data.shape[0], -1))
            _, predicted = torch.max(output.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    accuracy = correct / total
    print(f'Accuracy of the network on the test images: {accuracy * 100:.2f}%')

# 测试量化模型
test(quantized_model, test_loader, criterion)

上述代码示例中,我使用了一个简单的全连接神经网络,并在训练完成后使用torch.quantization.quantize_dynamic()对模型进行动态量化。在量化之前,我们通过QuantStub()DeQuantStub()添加了量化和反量化的辅助模块。这个示例使用了MNIST数据集,你可以根据你的实际需求替换成其他数据集和模型。

参考文献

量化感知训练(Quantization-aware-training)探索-从原理到实践

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/137011.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python---字符串 lstrip()--删除字符串两边的空白字符、rstrip()--删除字符串左边的空白字符、strip()--删除字符串右边的空白字符

strip() 方法主要作用:删除字符串两边的空白字符(如空格) lstrip() 方法 left strip,作用:只删除字符串左边的空白字符 left 英 /left/ 左 rstrip() 方法 right strip,作用:只删除字符…

阿里云严重故障,影响阿里系、淘宝、饿了么、语雀等都崩了...

作者:JavaPub 编程学习一条龙:http://luxian.javapub.net.cn 就在一年一度的双十一剁手节火热进行时,阿里云服务出现了严重故障。 关键是前不久刚发生了语雀事件,不了解的朋友阅读这里 阿里语雀突发P0级事故,一度崩溃…

Vector - CANoe - Vector Hardware Manager基础介绍

经常使用CANoe的人都知道,我们之前使用配置VN系列硬件通道的时候使用的是Vector Hardware Config,非常的方便,不过在Vector Driver Setup驱动版本大于22.14后,为了更好的适用车载以太网相关的配置,以及各个配置之间继承…

【Qt之Model/View】编程

Model/View编程介绍 Qt包含一组使用模型/视图架构来管理数据和用户呈现的关系的视图类。此架构引入的功能分离使开发人员可以更灵活地自定义项的呈现方式,并提供标准的模型接口,以允许各种数据源与现有项视图一起使用。在本文档中,我们简要介…

Leetcode—20.有效的括号【简单】

2023每日刷题&#xff08;二十七&#xff09; Leetcode—20.有效的括号 C实现代码 class Solution { public:bool isValid(string s) {stack<char> arr;int len s.size();if(len 1) {return false;}for(int i 0; i < len; i) {if(s[i] ( || s[i] [ || s[i] {)…

HDMI之编码篇

概述 HDMI 2.0b(含)以下版本,采用3个Channel方式输出。传输又分为3三种周期,视频数据,数据岛以及控制周期。视频传输采用8/10编码。数据岛采用4/10编码(TERC4)。控制周期采用2/10。编码都拓展成了10bits。 上图中,Pixel component(e.g.B)->D[7:0]表示视频数据周期…

深入了解JVM和垃圾回收算法

1.什么是JVM&#xff1f; JVM是Java虚拟机&#xff08;Java Virtual Machine&#xff09;的缩写&#xff0c;是Java程序运行的核心组件。JVM是一个虚拟的计算机&#xff0c;它提供了一个独立的运行环境&#xff0c;可以在不同的操作系统上运行Java程序。 2.如何判断可回收垃圾…

Git Commit 之道:规范化 Commit Message 写作指南

1 commit message 规范 commit message格式都包括三部分&#xff1a;Header&#xff0c;Body和Footer <type>(<scope>): <subject><body><footer>Header是必需的&#xff0c;Body和Footer则可以省略 1.1 Header Type&#xff08;必需&#xf…

JavaFX(其他控件02)(综合运用)

小技巧 图片控件的使用:Image/ImageViewnew ImageView(new Image(url,宽,高,true,true))--绝对路径: file:D:\\图片\\6.jpg --相对路径: src里面建了个文件夹 images/1.png滑块&#xff1a;Slider show(true) major(10) getValue() 保留2位小数&#xff1a;String.format(&q…

浙大恩特客户资源管理系统 fileupload.jsp 任意文件上传

一、漏洞描述 杭州恩软信息技术有限公司&#xff08;浙大恩特&#xff09;提供外贸管理软件、外贸客户管理软件等外贸软件&#xff0c;是一家专注于外贸客户资源管理及订单管理产品及服务的综合性公司。 浙大恩特客户资源管理系统中的fileupload.jsp接口存在安全漏洞&#xf…

VS Code画流程图:draw.io插件

文章目录 简介快捷键 简介 Draw.io是著名的流程图绘制软件&#xff0c;开源免费&#xff0c;对标Visio&#xff0c;用过的都说好。而且除了提供常规的桌面软件之外&#xff0c;直接访问draw.io就可以在线使用&#xff0c;堪称百分之百跨平台&#xff0c;便捷性直接拉满。 那么…

【python后端】- 初识Django框架

Django入门 &#x1f604;生命不息&#xff0c;写作不止 &#x1f525; 继续踏上学习之路&#xff0c;学之分享笔记 &#x1f44a; 总有一天我也能像各位大佬一样 &#x1f31d;分享学习心得&#xff0c;欢迎指正&#xff0c;大家一起学习成长&#xff01; 文章目录 Django入门…

SpringBoot的Data开发篇:整合JDBC、整合MybatisMP,YAML文件加密的实现,数据项目监控平台的使用和实现

SpringBoot整合JDBC 实现步骤&#xff1a; 导pom文件坐标 除springboot启动器和test坐标外&#xff0c;还需要导入spring jdbc和mysql的坐标 <dependencies><!--Spring JDBC--><dependency><groupId>org.springframework.boot</groupId><art…

【fast2021论文导读】 Learning Cache Replacement with Cacheus

文章:Learning Cache Replacement with Cacheus 导读摘要: 机器学习的最新进展为解决计算系统中的经典问题开辟了新的、有吸引力的方法。对于存储系统,缓存替换是一个这样的问题,因为它对性能有巨大的影响。 本文第一个贡献,确定了与缓存相关的特征,特别是,四种工作负载…

JSP运行环境搭建

将安装JSP引擎的计算机称作一个支持JSP的Web服务器。这个服务器负责运行JSP&#xff0c;并将运行结果返回给用户。 JSP的核心内容之一就是编写JSP页面,JSP页面是Web应用程序的重要组成部分之一。一个简单Web应用程序可能只有一个JSP页面,而一个复杂的Web应用程序可能由许多JSP…

CnosDB 狂欢!全面支持 Helm 部署,轻松搞定你的分布式时序数据库!

大家好&#xff01;今天有个热辣新闻要和大家分享——CnosDB 狂欢时刻来啦&#xff0c;全面支持 Helm 部署&#xff01;如果你是物联网、工业互联网、车联网或者IT运维的粉丝&#xff0c;那你绝对不能错过这个重磅消息&#xff01; CnosDB Helm Chart 究竟是啥&#xff1f; 别…

网络通讯基础

Socket Socket是应用层与TCP/IP协议簇通信的中间软件抽象层&#xff0c;它是一组接口。Socket通常用于实现客户端和服务器之间的通信。它允许客户端应用程序与服务器应用程序建立连接&#xff0c;并通过网络传输数据。 Socket包含了网络通讯必须的5种信息 Socket例子 { 协议: …

计算机毕业设计 基于Vue篮球联盟管理系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

vue:实现顶部消息横向滚动通知

前言 最近有个需求&#xff0c;是在系统顶部展示一个横向滚动的消息通知。需求很简单&#xff0c;就是消息内容从右往左一直滚动。 效果如下&#xff1a; 因为我的需求很简单&#xff0c;功能就这样。如果有什么其他需求&#xff0c;可以再继续修改。 代码 使用 <noti…

《深入理解计算机系统》书籍学习笔记 - 第二课 - 位,字节和整型

Lecture 02 Bits,Bytes, and Integer 位&#xff0c;字节和整型 文章目录 Lecture 02 Bits,Bytes, and Integer 位&#xff0c;字节和整型Byte 字节位操作布尔代数集合的表现形式和操作C语言的逻辑操作 位移操作整型数值范围无符号与有符号数值无符号与有符号在C中 拓展和截断拓…