0基础学习地平线QAT量化感知训练

文章目录

  • 1. 背景
  • 2. 基础理论知识
  • 3. 文件准备与程序运行
  • 4. 代码详解
    • 4.1 导入必要依赖
    • 4.2 主函数
    • 4.3 构建fx模式所需要的float_model
    • 4.4 不同阶段模型的获取
    • 4.5 定义常规模型训练与验证的函数
    • 4.6 float与qat训练代码解读——float_model/qat_model
    • 4.7 模型校准部分的代码解读——calib_model
    • 4.8 量化模型看精度代码解读——quantized_model
    • 4.9 编译生成上板模型——script_model/model.hbm
  • 5. 建议or吐槽

1. 背景

首先感谢一下地平线工具链用户手册和官方提供的示例,给了我很大的帮助,特别是代码的注释写了很多的知识点,超赞!要是注释能再详细点,就是超超赞了!下面开始正文。
以前从0开始学习过地平线的PTQ(后量化)方案,写了一些基础知识文章,后来发现地平线的用户手册关于PTQ方面其实挺完善的,东西很多很全,就没再想着写。
最近想着学QAT(量化感知训练)玩玩,大体看了一下地平线的用户手册,不说精度调优、性能调优之类比较复杂的,光一个QAT上手,就感觉对我这种小白不是很友好,比如我这种小白,捣鼓了好久,感觉在用户手册中很多基础概念都没写,不同模块之间的关联性也没有详细地介绍,直到我“精读”用户手册 4.2量化感知训练(QAT) ,发现了这么一句话,

懂了,没用过Pytorch的QAT,直接看手册学起来有点费劲才是正常滴!
那针对只使用过Pytorch在服务器上训练过一些分类、检测模型,没接触过QAT的小白,又不想读PyTorch官方文档,只想简单入个门,怎么办嘞?欢迎看看这篇文章,提供实操代码和运行步骤,如果文章对你有点作用的话,麻烦收藏+点个赞再走~

该文章参考自J5 OE1.1.52中对应的示例以及用户手册,为啥不是用的XJ3 OE,请看第5节吐槽部分

2. 基础理论知识

深度学习量化通常是指以int类型的数据代替浮点float类型的数据进行计算和存储,从而减小模型大小,降低带宽需求,理论上,INT8 量化,与常规的 FP32 模型相比,模型大小减少 4 倍,内存带宽需求减少 4 倍。
量化可以分为PTQ与QAT,

  • PTQ:Post-training Quantization,训练后量化,指浮点模型训练完成后,基于一些校准数据,直接通过工具自动进行模型量化的过程,相比QAT,PTQ更简单一些,这篇文章不介绍PTQ。
  • QAT:Quantization aware training,量化感知训练,指浮点模型训练完成后,在模型中插入伪量化节点再进行量化训练的过程,大体过程如下图所示,相比PTQ,QAT精度更有保障一些,这篇文章介绍QAT

小白:图中伪量化节点FakeQuantize node是什么?有什么作用?

大黑:从命名看,就是假装量化呗,模拟将数据从float类型量化为int类型,主要作用于网络的权重和激活(节点输出,不是relu这种激活函数的意思)。在QAT中,通过使用伪量化节点,可以在训练期间优化模型以适应后续的真实量化操作,从而提高量化模型的准确性和性能。一旦模型训练完成后,伪量化节点将被替换为真实的量化操作,以生成最终的量化模型。

小白:插入伪量化节点后需要Retraining/Funetuning?感觉很浪费资源的样子…

大黑:通常再多训 1/10 浮点阶段训练的轮数就好了,比如浮点阶段训练了100epoch,QAT训个10epoch就好,为了精度,浪费就浪费点,小问题!

小白:从上面这个图看,感觉QAT还挺简单的,其实目前我就只会用pytorch搭一个卷积网络,然后去训练,那我要经历哪些阶段才能得到最终上板部署的模型呢?

大黑:整个过程会涉及到以下几个模型:

在每个阶段,还有一些需要注意的地方,比如…

小白:停停停,先别急,这里面新名词有点多,先帮我捋捋。float_model和我直接用pytorch搭建的有什么不同吗?fx是什么?calib是什么?qat_model和quantized_model还不是一个意思?script_model又是哪儿冒出来的?板端部署hbm模型我知道,就是可以在板子上推理的模型,类似于PTQ里的bin模型对吧?

大黑:这一连串问题问的挺好,我下面逐个简单解释一下。

  • float_model和我直接用pytorch搭建的有什么不同吗?
    这里float_model浮点模型,其实就是在pytorch搭建的常规网络输入处插入QuantStub节点、输出处插入DeQuantstub节点,在PyTorch中,QuantStub/DequantStub 是一种用于量化的辅助工具,用于标记量化过程中需要量化/反量化的层或操作,前期浮点训练时可以当它不存在,在量化时会自动被替换为对应的量化操作。从普遍意义上说,每个分支都要对应插入QuantStub,别再追问为什么了,问就是甲鱼的臀部——“规定”。
  • fx是什么?
    pytorch中量化方式有两种,分别是Eager Mode Quantization和FX Graph Mode Quantization,它俩各有优劣。对于初学者,Eager模式需要手工修改网络代码,并对很多节点进行替换,比较复杂,而 FX模式不需要这种操作,使用起来比较简单,因此,推荐使用fx模式。
    关于fx与eager两种模式体现在地平线量化训练以及部署层面的差异,大家感兴趣的话,可参考地平线开发者社区专业介绍:QAT - 异构与非异构方案使用简介。
    地平线同时支持fx和eager两种模式,fx模式体现在地平线封装的各种函数中,例如prepare_qat_fx(),就是在函数最后有fx字样。
  • calib是什么?
    calib是校准calibration的缩写,主要作用是确定量化参数,我们知道,合理的初始量化参数能够显著提升模型精度并加快模型的收敛速度。calibration 就是在浮点模型中插入 Observer,使用少量训练数据,在模型 forward 过程中统计各处的数据分布,以确定合理的量化参数的过程。虽然不做 Calibration 也可以进行量化训练,但一般来说,它对量化训练有益无害,所以推荐大家将此步骤作为必选项。
  • qat_model和quantized_model还不是一个意思?
    不一样的。
    qat_model是一种插入了伪量化节点的伪量化模型,简单理解为:它是为了量化训练而存在的模型,里面还“流淌”着浮点的参数,伪量化节点在模拟量化而已。
    quantized_model:模型中的浮点参数转换为定点参数,且把浮点算子转换成定点算子,这种转换后的模型称之为quantized_model /定点模型 / 量化模型。
  • script_model又是哪儿冒出来的?
    scipt_model是一种可以序列化的Torch脚本(TorchScript),方便在不需要Python解释器的环境中使用模型,例如C++应用程序、移动端应用等。scipt_model的获取通过torch.jit.trace实现。torch.jit.trace是PyTorch中的一个静态图转换工具,用于将一个PyTorch模型转换成一个可以序列化的Torch脚本(TorchScript)。其工作流程是,首先使用输入张量对模型进行前向计算,然后将计算图转换为Torch脚本。在这个过程中,PyTorch会执行所有与输入相关的计算,从而记录下计算图的结构和参数的值。
    以下是torch.jit.trace方法的基本语法:script_model = torch.jit.trace(model, example_inputs, optimize=True),其中,model是待转换的PyTorch模型,并不一定需要是quantized_model,普通的也可以,这里是QAT场景,因此是quantized_model。example_inputs是一个输入张量或元组,用于为模型执行前向计算,并记录计算图的结构和参数的值。optimize是一个布尔值,用于指定是否对转换后的计算图进行优化。默认情况下,optimize为True,将对计算图进行常量折叠、运算融合等优化。
  • 板端部署hbm模型我知道,就是可以在板子上推理的模型,类似于PTQ里的bin模型对吧?
    非常对。

小白:这些模型是如何生成的?通过图中那几个函数?是地平线封装好的,直接用?
大黑:是的。

3. 文件准备与程序运行

  • 一共就需要3个文件
(plugin) [xxx plugin_basic]$ tree -L 3
.
|-- data
|   |-- cifar-10-batches-py    					# cifar10数据集
|-- mobilenet_example_release_fx_only.py    	# 代码
|-- model
|   `-- mobilenetv2
|       |-- mobilenet_v2-b0353104.pth       	# 预训练权重

为了方便大家获取,以上文件均存放在网盘链接中:

链接:https://pan.baidu.com/s/1yJjjWEOB9rtHug77yA5mIw 
提取码:zdi5

代码运行,建议在地平线提供的docker里运行,当然,如果大家自己会配置本地环境的话,也可以不用docker,我两种都试了,都是ok的。

  • 运行过程
# 生成float-checkpoint.ckpt
python3 mobilenet_example_release_fx_only.py --stage=float 
# 生成calib-checkpoint.ckpt   
python3 mobilenet_example_release_fx_only.py --stage=calib
# 生成qat-checkpoint.ckpt    
python3 mobilenet_example_release_fx_only.py --stage=qat
# 使用定点quantized model evaluate一次      
python3 mobilenet_example_release_fx_only.py --stage=int_infer    
# 编译生成model.hbm,并对script_model进行可视化
python3 mobilenet_example_release_fx_only.py --stage=compile    

特别是在stage=compile,产出物有点多,在这儿具体介绍一下

模型名称模型解读
int_model.pttorch.jit.save(script_model, “int_model.pt”)生成的,指 torchscript 模型
model.ptcompile_model函数产出的中间产物,和int_model.pt是一回事,指 torchscript 模型
model.hbircompile_model函数产出的中间产物,用于出现问题时提供给地平线技术支持分析,我们不需要关注
model.hbmcompile_model函数产出的最终产物,即板端可部署模型
xxx.htmlperf_model函数的产物,两个html文件,里面提供一些编译器层面分析出的性能信息

运行完全程,所有文件如下图:

跑起来很简单,下面再和大家一起看看代码层面的情况。

4. 代码详解

该章节参考地平线用户手册:XJ3用户手册 4.2.3 快速上手、J5用户手册 4.2.3. 快速入门,由于XJ3 OE包中未提供对应示例,代码参考的是J5 OE ddk/samples/ai_toolchain/horizon_model_train_sample/plugin_basic/mobilenet_example_release.py,OE包中代码是fx模式和eager模式混合在一起的,为了防止大家搞混,我给拆开了,这里只放fx模式的例子,其实XJ3用户手册 4.2.3 快速上手、J5用户手册 4.2.3. 快速入门都有提供fx模式对应ipynb的代码,只是我不太习惯而已,大家可以根据自己偏好使用。

4.1 导入必要依赖

之所以写这一节,主要是希望大家可以从注释中,简单了解各个函数的作用,像torch、os这种导入就省略没写,全部的依赖可以看提供的代码。其中,horizon_plugin_pytorch是地平线基于 PyTorch 开发的 的量化训练工具,可以理解成numpy这种库,里面有很多用于量化训练的的依赖,我们直接用就好了。

# 定义程序需要接收哪些命令行参数,以及这些参数的类型、默认值等信息。
import argparse     
# torch中的一个类,主要用于将量化操作的结果转换回浮点数,也就是对输出数据转换回浮点数
from torch.quantization import DeQuantStub
# 用CIFAR10数据集,简单快速
from torchvision.datasets import CIFAR10
# 导入两个类,用来当父类,目的是构建float_model。model_urls是一个字典
from torchvision.models.mobilenetv2 import (
    InvertedResidual,
    MobileNetV2,
    model_urls,
)      
# 从url中下载预训练权重
from torchvision._internally_replaced_utils import load_state_dict_from_url
# 硬件芯片架构,J5:bayes;XJ3:bernoulli2,具体可看源码
from horizon_plugin_pytorch.march import March, set_march       
from horizon_plugin_pytorch.quantization import (
    QuantStub,      # 类似于torch中的类QuantStub,用于将输入数据量化,使用plugin中的QuantStub是因为它支持通过参数手动固定 scale
    convert_fx,     # 将伪量化模型qat_model转换为定点模型quantized_model
    prepare_qat_fx, # 将float模型转成calib/qat模型,变动表现:进行一些conv+bn等算子融合
    set_fake_quantize,  # 用于设置qat/calib model 伪量化状态,内参包括FakeQuantState
    FakeQuantState,     # 用于设置伪量化状态,有FakeQuantState.QAT用于qat model train,FakeQuantState.VALIDATION用于qat/calib model eval,FakeQuantState.CALIBRATION用于 calib eval
    check_model,        # hbdk中函数,用于检查模型是否可以被hbdk编译,本例中输入是可序列化的script_model,并给出一些根据硬件对齐规则可以提升性能的建议
    compile_model,      # hbdk中函数,用于编译生成可以上板的hbm模型
    perf_model,         # hbdk中函数,用于推测模型耗时等信息
    visualize_model,    # hbdk中函数,用于可视化算子优化替换后的模型结构
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,      # 校准时,模型总体伪量化节点的量化配置
    default_qat_8bit_fake_quant_qconfig,        # 量化训练时,模型总体伪量化节点的量化配置
    default_qat_out_8bit_fake_quant_qconfig,    # 模型输出的伪量化节点配置,用于配置输出conv节点高精度int32输出
    default_calib_out_8bit_fake_quant_qconfig,  # 和上一行是一个东西
)

4.2 主函数

看了第2节理论知识部分,主函数部分的代码就是严格执行那几个阶段stage(详见第2节),很easy,关于内部细节,在后面几个小节挨个介绍。

def main(
    stage: str,
    data_path: str,
    model_path: str,
    train_batch_size: int,
    eval_batch_size: int,
    epoch_num: int,
    device_id: int = 0,
    quant_method: str = "fx",
    march: str = March.BAYES,
    compile_opt: int = 0,
):
    # 对应操作几个阶段的模型
    assert stage in ("float", "calib", "qat", "int_infer", "compile")
    assert quant_method in ("fx")

    device = torch.device(
        "cuda:{}".format(device_id) if device_id >= 0 else "cpu"
    )

    if not os.path.exists(model_path):
        os.makedirs(model_path, exist_ok=True)

    # 浮点训练阶段优化器
    def float_optim_config(model: nn.Module):
        # This is an example to illustrate the usage of QAT training tool, so
        # we do not fine tune the training hyper params to get optimized
        # float model accuracy.
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=2e-4)

        return optimizer, None

    # qat训练阶段优化器
    def qat_optim_config(model: nn.Module):
        # QAT training is targeted at fine tuning model params to match the
        # numerical quantization, so the learning rate should not be too large.
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.0001, weight_decay=2e-4
        )

        return optimizer, None

    default_epoch_num = {
        "float": 20,     
        "qat": 2,       # 通常float训练epoch数量是qat训练epoch数量的10倍
    }

    if stage in ("float", "qat"):
        if epoch_num is None:
            epoch_num = default_epoch_num[stage]

        train(
            data_path,
            model_path,
            train_batch_size,
            eval_batch_size,
            epoch_num,
            device,
            float_optim_config if stage == "float" else qat_optim_config,
            stage,
            march,
            quant_method,
        )

    elif stage == "calib":
        calibrate(
            data_path,
            model_path,
            train_batch_size,
            eval_batch_size,
            device,
            march=march,
            quant_method=quant_method,
        )

    elif stage == "int_infer":
        int_infer(
            data_path,
            model_path,
            eval_batch_size,
            device,
            march=march,
            quant_method=quant_method,
        )

    else:
        compile(
            data_path,
            model_path,
            compile_opt,
            march=march,
            quant_method=quant_method,
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run mobilenet example.")
    parser.add_argument(
        "--stage",
        type=str,
        choices=("float", "calib", "qat", "int_infer", "compile"),
        help=(
            "Pipeline stage, must be executed in following order: "
            "float -> calib(optional) -> qat(optional) -> int_infer -> compile"
        ),
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="data",
        help="Path to the cifar-10 dataset",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="model/mobilenetv2",
        help="Where to save the model and other results",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=256,
        help="Batch size for training",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=256,
        help="Batch size for evaluation",
    )
    parser.add_argument(
        "--epoch_num",
        type=int,
        default=None,
        help=(
            "Rewrite the default training epoch number, pass 0 to skip "
            "training and only do evaluation (in stage 'float' or 'qat')"
        ),
    )
    parser.add_argument(
        "--device_id",
        type=int,
        default=2,
        help="Specify which device to use, pass a negative value to use cpu",
    )
    parser.add_argument(
        "--quant_method",
        type=str,
        choices=["fx"],
        default="fx",
        help=(
            "Specify fx mode quantization."
            " Please do not change quant method "
            "between different stages, or the model may fail to load"
        ),
    )
    parser.add_argument(
        "--opt",
        type=str,
        choices=["0", "1", "2", "3", "ddr", "fast", "balance"],
        default=0,
        help="Specity optimization level for compilation",
    )
    args = parser.parse_args()
    print(args)

    main(
        args.stage,
        args.data_path,
        args.model_path,
        args.train_batch_size,
        args.eval_batch_size,
        args.epoch_num,
        args.device_id,
        args.quant_method,
        compile_opt=args.opt,
    )

4.3 构建fx模式所需要的float_model

从torchvision.models中继承MobileNetV2,微调一下,以支持量化相关操作。模型改造必要的操作有:

  • 在模型所有输入分支前插入 QuantStub
  • 在模型所有输出分支后插入 DequantStub

这部分具体实现过程解读可见代码注释。

# ----------------------------------------------------------------------------#
# At first, we do necessary modify to the MobilenetV2 model from torchvision.
# For FX mode, we need to:
# 1. Insert QuantStub before first layer and DequantStub after last layer.
# Operation replacement and fusion will be carried out automatically (^_^).
# ----------------------------------------------------------------------------#
# 在PyTorch中,QuantStub/DequantStub 是一种用于量化的辅助工具,
# 用于标记量化过程中需要量化/反量化的层或操作,
# 前期浮点训练时当它不存在,在量化时会自动被替换为对应的量化操作
# ----------------------------------------------------------------------------#
# 从torchvision.models中继承MobileNetV2,微调一下
class FxQATReadyMobileNetV2(MobileNetV2):
    def __init__(
        self,
        num_classes: int = 10,      # 实例变量,使用self.来引用变量
        width_mult: float = 0.5,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
    ):
        super().__init__(   # 类变量,使用类名来引用变量,如ClassName.variable_name
            num_classes, width_mult, inverted_residual_setting, round_nearest
        )
        # --------------------------------------------------------------------#
        # 简单理解,在模型首尾部包一层类似于量化反量化操作,每个输入分支都需要包一下
        # --------------------------------------------------------------------#
        # 地平线plugin中的QuantStub可以配置scale
        # 这里的scale=1/128是后面模型输入配置为pyramid必备的
        # pyramid是地平线的芯片上的一个硬件,数据输入可以从这儿来,也可以从DDR来
        # --------------------------------------------------------------------#
        self.quant = QuantStub(scale=1 / 128)   
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = super().forward(x)
        x = self.dequant(x)

        return x

关于如何加载预训练权重部分的代码在函数load_pretrain里,详细内容可以看Python文件,这里不再呈现。

def load_pretrain(model: nn.Module, model_path: str):
    state_dict = load_state_dict_from_url(
        model_urls["mobilenet_v2"], model_dir=model_path, progress=True
    )   # model_urls是一个字典,取里面mobilenet_v2的对应url,下载路径到model_dir,progress是下载进度条显示

4.4 不同阶段模型的获取

在代码运行时,有个输入参数stage必须配置,表示拿到哪个model去整后面的事,当stage参数传入(“float”, “calib”, “qat”, “int_infer”)中某一个时,会通过如下函数去获取,具体实现过程解读可见代码注释。

# --------------------------------------------------------------------------#
# Next, we define the model convert pipeline to generate model for each stage.
# --------------------------------------------------------------------------#
def get_model(
    stage: str,
    model_path: str,
    device: torch.device,
    march=March.BAYES,
    quant_method="fx",
) -> nn.Module:
    # 运行代码时,有个输入参数stage必须配置,表示拿到哪个model去整后面的事
    assert stage in ("float", "calib", "qat", "int_infer")
    assert quant_method in ("fx")

    model_kwargs = dict(num_classes=10, width_mult=1.0)
    float_model = FxQATReadyMobileNetV2(**model_kwargs).to(device)

    if stage == "float":
        # Load pretrained model (on ImageNet) to speed up float training.
        load_pretrain(float_model, model_path)

        return float_model      # float的时候,到这儿就退出了

    # 浮点训练完成后的权重
    float_ckpt_path = os.path.join(model_path, "float-checkpoint.ckpt")
    assert os.path.exists(float_ckpt_path)
    float_state_dict = torch.load(float_ckpt_path, map_location=device)

    # A global march indicating the target hardware version must be setted
    # before prepare qat.
    set_march(march)

    # Preserve a clean float_model for calibration and qat training.
    ori_float_model = float_model         
    float_model = copy.deepcopy(ori_float_model)

    float_model.load_state_dict(float_state_dict)
    # -----------------------------------------------------------#
    # The op fusion is included in `prepare_qat_fx`.
    # -----------------------------------------------------------#
    # Make sure the output model is on target device.
    # CAUTION: prepare_qat_fx and convert_fx do not guarantee the
    # output model is on the same device as input model.
    # ----------------------------------------------------------#

    # ----------------从float_model转成calib_model----------------#
    float_model.qconfig = default_calib_8bit_fake_quant_qconfig
    # ----------------------------------------------------------------------#
    #   不配置输出层的qconfig,其输出默认是int8输出
    #   尾部conv/linear,calib和qat配置为
    #   default_{calib/qat}_out_8bit_fake_quant_qconfig时,表示int32高精度输出
    # ----------------------------------------------------------------------#
    float_model.classifier.qconfig = (
        default_calib_out_8bit_fake_quant_qconfig
    )
    calib_model = prepare_qat_fx(float_model).to(device)

    # calib stage时,函数到这儿就会返回了
    if stage == "calib":
        return calib_model

    calib_ckpt_path = os.path.join(model_path, "calib-checkpoint.ckpt")
    assert os.path.exists(calib_ckpt_path)
    calib_state_dict = torch.load(calib_ckpt_path, map_location=device)

    # ---------------------------------------------#
    #   这一行是必须的,上面的float_model已经"变味"了
    # ---------------------------------------------#
    float_model = copy.deepcopy(ori_float_model)

    # 尾部conv/linear,qat配置为default_qat_out_***_qconfig时,可为int32高精度输出
    qat_model = prepare_qat_fx(
        float_model,        # 这儿必须是float_model,不能是calib_model,也不能是"变味"的float_model
        {
            "": default_qat_8bit_fake_quant_qconfig,
            "module_name": {
                "classifier": default_qat_out_8bit_fake_quant_qconfig,
            },
        },
    ).to(device)    # prepare_qat_fx 接口不保证输出模型的 device 和输入模型完全一致

    # qat_model加载的是calib_state_dict!!!
    qat_model.load_state_dict(calib_state_dict)

    if stage == "qat":    # qat阶段到这儿就退出了
        return qat_model

    qat_ckpt_path = os.path.join(model_path, "qat-checkpoint.ckpt")
    assert os.path.exists(qat_ckpt_path)
    qat_model.load_state_dict(torch.load(qat_ckpt_path, map_location=device))

    # 将模型转为定点状态
    # 通过参数转换把伪量化模型中的浮点参数转换成定点参数,
    # 并且把浮点算子转换成定点算子,该转换后的模型称为 Quantized 模型 / 定点模型 / 量化模型
    quantized_model = convert_fx(qat_model).to(device)

    return quantized_model    # int_infer阶段会到这儿才退出

4.5 定义常规模型训练与验证的函数

具体实现,看py代码就行,很常规。

# --------------------------------------------------------------------------#
# Next, we define dataloaders and other helper functions used in training
# and evaluation.
# --------------------------------------------------------------------------#

def prepare_data_loaders(
    data_path: str, train_batch_size: int, eval_batch_size: int
) -> Tuple[data.DataLoader, data.DataLoader]:


class AverageMeter(object):
    """Computes and stores the average and current value"""
    

def accuracy(output: Tensor, target: Tensor, topk=(1,)) -> List[Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k"""
    

def train_one_epoch(
    model: nn.Module,
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    data_loader: data.DataLoader,
    device: torch.device,
) -> None:


def evaluate(
    model: nn.Module, data_loader: data.DataLoader, device: torch.device
) -> Tuple[AverageMeter, AverageMeter]:

4.6 float与qat训练代码解读——float_model/qat_model

针对float_model和qat_model的参数训练,代码解读如下,

# --------------------------------------------------------------------------#
# Next, we define the main function for each stage.
# --------------------------------------------------------------------------#

# Float and qat share the same training procedure.
def train(
    data_path: str,
    model_path: str,
    train_batch_size: int,
    eval_batch_size: int,
    epoch_num: int,
    device: torch.device,
    optim_config: Callable,
    stage: str,
    march=March.BAYES,
    quant_method="fx",
):
    # --------------------------------------------#
    #   qat模型训练和普通浮点模型训练的不同之处!
    # --------------------------------------------#
    model = get_model(stage, model_path, device, march, quant_method)

    train_data_loader, eval_data_loader = prepare_data_loaders(
        data_path, train_batch_size, eval_batch_size
    )

    optimizer, scheduler = optim_config(model)

    best_acc = 0

    for nepoch in range(epoch_num):
        # Train/Eval state must be setted correctly
        # before `set_fake_quantize`
        model.train()
        # --------------------------------------------#
        #   qat模型训练和普通浮点模型训练的不同之处!
        # --------------------------------------------#
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.QAT)

        train_one_epoch(
            model,
            nn.CrossEntropyLoss(),
            optimizer,
            scheduler,
            train_data_loader,
            device,
        )

        model.eval()
        # --------------------------------------------#
        #   qat模型训练和普通浮点模型训练的不同之处!
        # --------------------------------------------#
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.VALIDATION)

        top1, top5 = evaluate(
            model,
            eval_data_loader,
            device,
        )
        print(
            "{} Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
                stage.capitalize(), nepoch, top1.avg, top5.avg
            )
        )

        if top1.avg > best_acc:
            best_acc = top1.avg

            torch.save(
                model.state_dict(),
                os.path.join(model_path, "{}-checkpoint.ckpt".format(stage)),
            )   # 可用于保存 float-checkpoint.ckpt 和 qat-checkpoint.ckpt

    # ----------------------------------------------#
    #   当传入epoch=1,用于qat eval
    # ----------------------------------------------#
    if nepoch == 0:
        model.eval()
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.VALIDATION)

        top1, top5 = evaluate(
            model,
            eval_data_loader,
            device,
        )
        print(
            "{} eval only: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
                stage.capitalize(), top1.avg, top5.avg
            )
        )   # stage.capitalize()表示将字符串首字母大写

    print("Best Acc@1 {:.3f}".format(best_acc))

    return model

4.7 模型校准部分的代码解读——calib_model

float模型训练完成后,需要进行参数校准,得到calib_model,如果calib_model精度满足要求,qat训练就不需要了,即使calib_model精度不行,calib_model_state_dict(校准后的权重)对qat训练收敛也非常有帮助。

def calibrate(
    data_path,
    model_path,
    calib_batch_size,
    eval_batch_size,
    device,
    num_examples=float("inf"),  # float("inf")表示无穷大,主要用于控制使用多少数据进行校准,默认使用所有数据集
    march=March.BAYES,
    quant_method="fx",
):
    calib_model = get_model("calib", model_path, device, march, quant_method)
    # Please note that calibration need the model in eval mode
    # to make BatchNorm act properly.
    calib_model.eval()  # 即使下面用的是train数据集,这儿也是eval
    # set CALIBRATION state will make FakeQuantize in training mode.
    set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)

    train_data_loader, eval_data_loader = prepare_data_loaders(
        data_path, calib_batch_size, eval_batch_size
    )

    with torch.no_grad():
        cnt = 0
        for image, target in train_data_loader:
            image, target = image.to(device), target.to(device)
            calib_model(image)
            print(".", end="", flush=True)
            cnt += image.size(0)
            if cnt >= num_examples:     # 主要用于控制使用多少数据进行校准,默认使用所有数据集
                break
        print()

    # Must set eval mode again before validation, because
    # set CALIBRATION state will make FakeQuantize in training mode.
    calib_model.eval()  
    set_fake_quantize(calib_model, FakeQuantState.VALIDATION)

    top1, top5 = evaluate(
        calib_model,
        eval_data_loader,
        device,
    )
    print(
        "Calibration: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

    torch.save(
        calib_model.state_dict(),
        os.path.join(model_path, "calib-checkpoint.ckpt"),
    )

    return calib_model

4.8 量化模型看精度代码解读——quantized_model

定点模型/quantized模型/量化模型 eval推理一下看看精度

# 定点模型/quantized模型/量化模型 eval推理一下看看精度
def int_infer(
    data_path,
    model_path,
    eval_batch_size,
    device,
    march=March.BAYES,
    quant_method="fx",
):
    # 定点模型/quantized模型/量化模型
    quantized_model = get_model(
        "int_infer", model_path, device, march, quant_method
    )

    _, eval_data_loader = prepare_data_loaders(
        data_path, eval_batch_size, eval_batch_size
    )

    top1, top5 = evaluate(
        quantized_model,
        eval_data_loader,
        device,
    )
    print(
        "Quantized: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

    return quantized_model

4.9 编译生成上板模型——script_model/model.hbm

编译生成上板模型model.hbm,同时针对script_model预估模型性能,并进行可视化

def compile(
    data_path,
    model_path,
    compile_opt=0,
    march=March.BAYES,
    quant_method="fx",
):
    # It is recommended to do compile on cpu, because associated interfaces
    # do not fully support cuda.
    device = torch.device("cpu")

    # 定点模型
    quantized_model = get_model(
        "int_infer", model_path, device, march, quant_method
    )

    # prepare_data_loaders(data_path: str, train_batch_size: int, eval_batch_size: int)
    _, eval_data_loader = prepare_data_loaders(data_path, 1, 1)

    # We can generate random input data (in proper shape) for
    # tracing and compiling and so on.
    # Use real data in `perf_model` will get more accurate perf result.
    example_input = next(iter(eval_data_loader))[0]     # Tensor

    # ------------------------------------------------------------------#
    #   torch.jit.trace是PyTorch中的一个静态图转换工具,
    #   用于将一个PyTorch模型转换成一个可以序列化的Torch脚本(TorchScript),
    #   以便在不需要Python解释器的环境中使用模型。
    #   model并不一定需要是quantized_model,普通的也可以,这里是QAT场景
    # ------------------------------------------------------------------#
    script_model = torch.jit.trace(quantized_model.cpu(), example_input)    # 单纯为了更保险,这儿再次加上.cpu()
    # 这个.pt结尾,就和手册中术语约定对上了:文档中的 pt 模型指 torchscript 模型
    torch.jit.save(script_model, os.path.join(model_path, "int_model.pt"))  

    check_model(script_model, [example_input], advice=1)

    compile_model(
        script_model,
        [example_input],
        hbm=os.path.join(model_path, "model.hbm"),
        input_source="pyramid",     # 上板时输入的数据来源,通常有ddr/resizer/pyramid,多输入时配置为字符串列表
        opt=compile_opt,
    )

    # hbdk预估模型性能,生成html文件,里面提供一些性能评测信息
    perf_model(
        script_model,
        [example_input],
        out_dir=os.path.join(model_path, "perf_out"),
        input_source="pyramid",
        opt=compile_opt,
        layer_details=True,     # html中会提供逐层算子耗时
    )

    # 可视化torchscript模型,也就是hbdk眼中的模型,会考虑到layout的变换、硬件对齐、算子融合、算子等效替换等情况
    visualize_model(
        script_model,
        [example_input],
        save_path=os.path.join(model_path, "model.svg"),
        show=False,
    )

    return script_model

5. 建议or吐槽

免责声明:纯纯吐槽,如有雷同,请勿当真!

  • 提供用户手册、提供上手示例,很棒!只是说好的快速上手示例,能麻烦大佬们写的基础一点嘛~

  • 一定要善于看源码,里面有函数的作用和使用方法的介绍,很有用!可惜我用vscode在docker里总是无法跳转,馋哭了,其实可以有个笨方法,如下图

  • 初次上手的例子,建议和我们说一个最标准的流程就好了,像float_model到底选用origin_float_model更好还是FxQATReadyModel更好?calib这一步到底要不要?qat_model到底加载float_state_dict更合适还是calib_state_dict更合适?这些问题在我初次看代码时产生了一些疑惑~

  • X3的OE包里,能否像J5 OE包里一样提供plugin_basic的例子?要不是J5 OE包也对外释放了,都学不到这种好东西,偏心了啊!

  • J5 OE包里提供的plugin_basic例子,能否把fx和eager拆开成两个py文件?放到一起,刚开始学的时候总搞混…(当然,也可能是我水平问题)

  • 用户手册中把快速上手部分全部可执行代码放出来,感觉还挺好的,适合我这种小白,当然,在OE包里还有一份全面的代码,感觉在手册里告诉我它在OE包里的位置,这样也可以接受。其实我想说:手册中更建议多放点需要跟着操作的步骤,或者理论介绍,或者代码多点注释,不是很理解为啥把全部log日志都贴出来了(4.2.3 快速上手)!输出日志部分,放点开头、结尾、关键部分说明意思就行,想看全部的话,我自己会去跑跑试一下的,难道手册有最低字数限制?

  • 想让尾部conv以高精度int32输出,竟然配置的是default_qat_out_8bit_fake_quant_qconfig,大大问号脸?明明是out_8bit啊!后来咨询技术支持,原来这里的8bit是weight的量化方式为8bit。感觉这个命名有点容易造成误解,不知道能否修改为qat_out_int32_weight_8bit_fake_quant_qconfig?(反正都已经很长了…)

  • OE包中看着提供了很多例子,但例子之间又有很多共用的代码,造成非常多的嵌套,我就参考其中一个,还得下载整个OE包,不知道能否拆开例子,放到github或者gitee上,想参考哪个我就下载哪个多好!

  • 能否给点从浮点训练 到 量化转换编译 到 上板部署(python/c++) 到 可视化 的全流程示例仓库,本来生态就不如英伟达,支持国产总得让我们用起来很顺溜才好吧!建议搞点全流程例子给我们!(理直气不壮)

都看到这儿了,如果对您有帮助的话,麻烦给点个赞呀~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/31444.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

七、docker-compose方式运行Jenkins,更新Jenkins版本,添加npm node环境

docker-compose方式运行Jenkins,更新Jenkins版本,添加npm node环境 一、docker-compose方式安装运行Jenkins 中发现Jenkins版本有点老,没有node环境,本节来说下更新jenkins 及添加构建前端的node环境。 1. 准备好docker-compose…

三种方法将Word文档转换为PDF文件格式

如何将Word文档转换为PDF文件格式呢?大家在传输文件时,很多人喜欢使用PDF文件格式,因为它非常稳定,不会出现格式混乱的问题。但有些人可能不知道如何进行转换,今天我将介绍三种转换方法,让我们一起来学习一…

mysql 删表引出的问题

背景 将测试环境的表同步到另外一个数据库服务器中,但有些表里面数据巨大,(其实不同步该表的数据就行,当时没想太多),几千万的数据!! 步骤 1. 既然已经把数据同步过来的话&#x…

环境配置 | Git的安装及配置[图文详情]

Git是一个开源的分布式版本控制系统,可以有效、高速地处理从小到大的项目版本管理。下面介绍了基础概念及详细的用图文形式介绍一下git安装过程. 目录 1.Git基础概念 2.Git的下载及安装 3.常见的git命令 Git高级技巧 Git与团队协作 1.Git基础概念 仓库&#…

认识异常

目录 异常的概念与体系结构 异常的概念 异常的体系结构 异常的分类 异常的处理 防御式编程 1. LBYL: 2. EAFP: 异常的抛出 异常的捕获 异常声明throws try-catch捕获并处理 关于异常的处理方式 finally 异常的处理流程 自定义异常类 异常的概念与体系结构 异常…

基于SpringBoot+kaptcha的验证码生成

教程 1.添加 Kaptcha 依赖 在 pom.xml 文件中添加 Kaptcha 依赖&#xff1a; <dependency><groupId>com.github.penggle</groupId><artifactId>kaptcha</artifactId><version>2.3.2</version> </dependency> <!--或者 都…

平凯星辰重磅支持 2023 开放原子全球开源峰会,开源数据库分论坛成功召开

2023 年 6 月 11 日至 13 日&#xff0c;以“开源赋能&#xff0c;普惠未来”为主题的 2023 开放原子全球开源峰会开幕式暨高峰论坛在北京成功举办。企业级开源分布式数据库厂商平凯星辰联合创始人兼 CTO 黄东旭受邀出席峰会参与开源论道圆桌&#xff0c;担任开源数据库分论坛出…

第一章 数据可视化简介(复习)

第一章 数据可视化简介 什么是可视化 定义&#xff1a;通过可视表达增强人们完成某些 任务的效率 The American Heritage Dictionary&#xff1a; The act or process of interpreting in visual terms or of putting into visible form&#xff08;用可视形式进行解释的 动作…

VUE L ∠脚手架 配置代理 ⑩⑧

目录 文章有误请指正&#xff0c;如果觉得对你有用&#xff0c;请点三连一波&#xff0c;蟹蟹支持✨ V u e j s Vuejs Vuejs初识 V u e C L I VueCLI VueCLI C L I CLI CLI V u e Vue Vue配置代理 C L I CLI CLI配置方法一 C L I CLI CLI配置方法二 C L I CLI CLI V u …

EL标签-给JSP减负

https://blog.csdn.net/weixin_42259823/article/details/85945149 安装使用 1. 通过命令行创建maven项目 2. 安装jstl包 <dependency><groupId>jstl</groupId><artifactId>jstl</artifactId><version>1.2</version> </depen…

【工程项目管理】工程项目管理实践报告

前言&#xff1a; 1.大学课程的大作业&#xff0c;觉得存起来也没什么用就干脆发出来了。。。 2.很可能有不严谨之处&#xff0c;各位看官如若发现欢迎指出~ 创作者文章管理系统 1 实践环节作业1&#xff1a;选题及任务分解WBS &#xff08;1&#xff09;选题 a.项目名称&a…

【论文阅读】Adap-t: Adaptively Modulating Embedding Magnitude for Recommendation

【论文阅读】Adap-&#x1d70f;: Adaptively Modulating Embedding Magnitude for Recommendation 文章目录 【论文阅读】Adap-&#x1d70f;: Adaptively Modulating Embedding Magnitude for Recommendation1. 来源2. 介绍3. 模型解读3.1 准备工作3.1.1 任务说明3.1.2 基于嵌…

「已解决」已有Umi Antd 环境下安装 formily v2 依赖报错问题

背景 在一个项目中想引入 formily v2 试一下这个针对复杂表单的解决方案&#xff0c;结果发现安装后报错&#xff0c;目前已有的第三方库大致为 “ant-design/icons”: “^5.0.1”, “ant-design/pro-components”: “^2.4.4”, “umijs/max”: “^4.0.68”, “ahooks”: “^3…

textract OCR的安装使用

安装 pip install textract使用 在 Python 中&#xff0c;textract 是一个用于提取文本和信息的库。它提供了一个函数 textract.process()&#xff0c;用于处理不同类型的文档并提取文本内容。下面是 textract.process() 函数的各个参数的介绍&#xff1a; filename&#xf…

第3章 运输层

1​、在 ISO/OSI 参考模型中&#xff0c;对于运输层描述正确的有&#xff08; &#xff09; A. 为传输数据选择数据链路层所提供的最合适的服务B. 为系统之间提供面向连接的数据传输服务C. 可以提供端到端的差错恢复和流量控制&#xff0c;实现可靠的数据传输D. 提供路由选择…

HarmonyOS学习路之开发篇—多媒体开发(相机开发 一)

HarmonyOS相机模块支持相机业务的开发&#xff0c;开发者可以通过已开放的接口实现相机硬件的访问、操作和新功能开发&#xff0c;最常见的操作如&#xff1a;预览、拍照、连拍和录像等。 基本概念 相机静态能力 用于描述相机的固有能力的一系列参数&#xff0c;比如朝向、支持…

20分钟做一套采购审批系统

1、设计输入模板 excel画表格界面 # 公式代表新建时以默认值代替 2、设置单元格为签名控件 双击单元格后&#xff0c;会默认显示当前用户的信息,用于签名 3、设置要合计的数据 生成的合计公式会默认放到下一行 4、设置单元格的ID与标题&#xff0c;在添加或者删除行或者列时&am…

GEE:为每个对象(斑块/超像素)添加属性

作者:CSDN @ _养乐多_ 本文将介绍为每个对象(斑块/超像素)添加属性的代码。并举例将最近距离作为属性添加到每个对象(斑块/超像素)特征中。 结果如下图所示, 文章目录 一、代码二、代码链接一、代码 这段代码的目的是对动态世界土地覆盖图像进行分析,并提取出其中的目…

第九章 ShuffleNetv1网络详解

系列文章目录 第一章 AlexNet网络详解 第二章 VGG网络详解 第三章 GoogLeNet网络详解 第四章 ResNet网络详解 第五章 ResNeXt网络详解 第六章 MobileNetv1网络详解 第七章 MobileNetv2网络详解 第八章 MobileNetv3网络详解 第九章 ShuffleNetv1网络详解 第十章…

React之state详解

目录 执行过程 异步 React18与自动批处理 setState 推荐用法 ()>{return }&#xff0c;this.state. 生命周期 数据没改变时​不渲染 shouldComponentUpdate PureComponent自动&#xff08;推荐&#xff09; 你真的理解setState吗&#xff1f; - 掘金 组件的私有…