QAT量化 demo

一、QAT量化基本流程

QAT过程可以分解为以下步骤:

  1. 定义模型:定义一个浮点模型,就像常规模型一样。
  2. 定义量化模型:定义一个与原始模型结构相同但增加了量化操作(如torch.quantization.QuantStub())和反量化操作(如torch.quantization.DeQuantStub())的量化模型。
  3. 准备数据:准备训练数据并将其量化为适当的位宽。
  4. 训练模型:在训练过程中,使用量化模型进行正向和反向传递,并在每个 epoch 或 batch 结束时使用反量化操作计算精度损失。
  5. 重新量化:在训练过程中,使用反量化操作重新量化模型参数,并使用新的量化参数继续训练。
  6. Fine-tuning:训练结束后,使用fine-tuning技术进一步提高模型的准确率。

在这里插入图片描述

二、QAT量化代码示例

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.quantization import QuantStub, DeQuantStub, quantize_dynamic, prepare_qat, convert

# 模型
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 量化
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 10)
        # 反量化
        self.dequant = DeQuantStub()

    def forward(self, x):
        # 量化
        x = self.quant(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # 反量化
        x = self.dequant(x)
        return x

# 数据
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
train_data = datasets.CIFAR10(root='./data', train=True, download=True,transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1,shuffle=True, num_workers=0)

# 模型 优化器
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Prepare the model
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = prepare_qat(model)

# 训练
model.train()
for epoch in range(1):
    for i, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = nn.CrossEntropyLoss()(output, target)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('Epoch: [%d/%d], Step: [%d/%d], Loss: %.4f' %
                  (epoch+1, 10, i+1, len(train_loader), loss.item()))

    # Re-quantize the model
    model = quantize_dynamic(model, {'': torch.quantization.default_dynamic_qconfig}, dtype=torch.qint8)

# 微调
model.eval()
for data, target in train_loader:
    model(data)
model = convert(model, inplace=True)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/538721.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Canvas技法】四条C形色带填满一个圆/环形

【关键点】 通过三角函数计算控制点的位置。 【成果图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>四条C形色带填满一个…

【复现】浙大恩特客户资源管理系统 SQL注入漏洞_71

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一&#xff1a; 四.修复建议&#xff1a; 五. 搜索语法&#xff1a; 六.免责声明 一.概述 浙大恩特客户资源管理系统是一款针对企业客户资源管理的软件产品。该系统旨在帮助企业高效地管理和利用客户资源&#xff0c;提升…

【进阶篇】三、Java Agent实现自定义Arthas工具

文章目录 0、客户端代码1、JMX2、实现&#xff1a;查看内存使用情况3、实现&#xff1a;查看直接内存4、实现&#xff1a;生成堆内存快照5、实现&#xff1a;打印栈信息6、实现&#xff1a;打印类加载器的信息7、实现&#xff1a;打印类的源码8、需求&#xff1a;打印方法的耗时…

OpenHarmony开发学习:【源码下载和编译】

本文介绍了如何下载鸿蒙系统源码&#xff0c;如何一次性配置可以编译三个目标平台&#xff08;Hi3516&#xff0c;Hi3518和Hi3861&#xff09;的编译环境&#xff0c;以及如何将源码编译为三个目标平台的二进制文件。 坑点总结&#xff1a; 下载源码基本上没有太多坑&#xff…

【Android surface 】二:源码分析App的surface创建过程

文章目录 画布surfaceViewRoot的创建&setView分析setViewrequestLayoutViewRoot和WMS的关系 activity的UI绘制draw surfacejni层分析Surface无参构造SurfaceSessionSurfaceSession_init surface的有参构造Surface_copyFromSurface_writeToParcelSurface_readFromParcel 总结…

无人机概述

1、中英文对照表 中文中文简称英文全称英文简称无人驾驶飞机无人机Unmanned Aerial VehicleUAV无人机自组织网络无人机网络flying Ad-Hoc networkFANET 2、相关概念 2.1鲁棒性 网络鲁棒性是指网络系统在面对随机故障、蓄意攻击或其他异常情况时&#xff0c;能够保持其基本功…

Linux的网口名字的命名规则

在工作中&#xff0c;偶尔看到有些机器的网口名字是以ethX命令&#xff0c;有些则以enpXsX这种名字命名。网上的资料说的都不太明白,资料也无据可查&#xff0c;很难让人信服。于是决定自己查了下官方的资料和源码&#xff0c;把这些搞清楚。 官方文档&#xff1a;Predictable…

MobX进阶:从基础到高级特性全面探索

MobX 提供了丰富的高级特性,包括计算属性、反应式视图、中间件、异步流程控制、依赖注入和动态 observable 、在服务端渲染和 TypeScript 支持方面提供了良好的集成。这些特性进一步增强了 MobX 在状态管理方面的灵活性和可扩展性,使其成为一个功能强大、易于使用的状态管理解决…

【NLP练习】调用Gensim库训练Word2Vec模型

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 一、准备工作 1.安装Gensim库 使用pip安装&#xff1a; !pip install gensim2. 对原始语料分词 选择《人民的民义》的小说原文作为语料&#xff0c;先采用…

在Windows 10中打开PowerShell的几种方法,总有一种适合你

PowerShell是一种比命令提示符更强大的命令行shell和脚本语言。自Windows10发布以来,它已成为默认选择,并且有许多方法可以打开它。 PowerShell和命令提示符之间的区别是什么 PowerShell的使用更复杂,但它比命令提示符强大得多。这就是为什么它成为超级用户和it专业人员的…

CTF之Flask_FileUpload

拿到题目就一个上传文件的网页 查看源码&#xff0c;发现注释告诉我们会运行python的文件&#xff0c;但是系统只能上传图片格式&#xff08;这个是自己尝试知道的&#xff09; 那我们就写一个python代码改成jpg或者png格式的文件 内容为 import os os.system(cat /flag) 上传…

算法 囚犯幸存者

题目 主类 public static List<Prisoner> prisoners new ArrayList<Prisoner>(); public static List<Prisoner> remainPrisoners new ArrayList<Prisoner>(); public static Prisoner lastPrisoner null;public static void main(String[] args) …

jeecg-boot安装

我看大家都挺关注&#xff0c;所以集中上传了下代码和相关工具&#xff0c;方便大家快速完成 链接&#xff1a;https://pan.baidu.com/s/1-Y9yHVZ-4DQFDjPBWUk4-A 提取码&#xff1a;op1r 1. 下载代码 下载地址 : JEECG官方网站 - 基于BPM的低代码开发平台(低代码平台_零代…

单链表的应用

上篇博客中&#xff0c;我们学习了单链表&#xff0c;为了更加熟练掌握这一知识点&#xff0c;就让我们将单链表的应用操练起来吧&#xff01; 203. 移除链表元素 - 力扣&#xff08;LeetCode&#xff09; 思路一&#xff1a;遍历原链表&#xff0c;将值为val的节点释放掉。 …

华为校园公开课走入上海交大,鸿蒙成为专业核心课程

4月12日&#xff0c;华为校园公开课在中国上海交通大学成功举办&#xff0c;吸引了来自计算机等相关专业的150余名学生参加。据了解&#xff0c;由吴帆、陈贵海、过敏意、吴晨涛、刘生钟等教授在中国上海交通大学面向计算机系本科生开设的《操作系统》课程&#xff0c;是该系学…

CS224N第二课作业--word2vec与skipgram

文章目录 CS224N: 作业2 word2vec (49 Points)1. Math: 理解 word2vec计算 J n a i v e − s o f t m a x ( v c , o , U ) J_{naive-softmax}(v_c, o, U) Jnaive−softmax​(vc​,o,U) 关于 v c v_c vc​ 的偏导数计算 J n a i v e − s o f t m a x ( v c , o , U ) J_{na…

【SpringBoot】mybatis-plus实现增删改查

mapper继承BaseMapper service 继承ServiceImpl 使用方法新增 save,updateById新增和修改方法返回boolean值,或者使用saveOrUpdate方法有id执行修改操作,没有id 执行新增操作 案例 Service public class UserService extends ServiceImpl<UserMapper,User> {// Au…

第四百五十六回

文章目录 1. 概念介绍2. 思路与方法2.1 实现思路2.2 使用方法 3. 内容总结 我们在上一章回中介绍了"overlay_tooltip用法"相关的内容&#xff0c;本章回中将介绍onBoarding包.闲话休提&#xff0c;让我们一起Talk Flutter吧。 1. 概念介绍 我们在本章回中介绍的onBo…

Python 以点生成均匀二维圆点数据

已知一个数&#xff0c;但这个数不是最对的&#xff0c;最对的可能在它附近&#xff0c;想要从附近随机生成一群数&#xff0c;放入模型中暴力查找最对的那个值。 以下是代码片段 n 800 # 个数 m 2 # 角度&#xff0c;只有2才是均匀的&#xff0c;1为半圆&#xff0c;以此类…

HashMap的常见问题

Entry中的hash属性为什么不直接使用key的hashCode()返回值呢&#xff1f; 不管是JDK1.7还是JDK1.8中&#xff0c;都不是直接用key的hashCode值直接与table.length-1计算求下标的&#xff0c;而是先对key的hashCode值进行了一个运算&#xff0c;JDK1.7和JDK1.8关于hash()的实现…