专栏算法工具链0基础学习地平线J6 QAT量化感知训练

0基础学习地平线J6 QAT量化感知训练

Jade-self2024-09-02
555
0

1. 背景

首先感谢地平线工具链用户手册和官方提供的示例,给了我很大的帮助,特别是代码注释写了很多的知识点,超赞!要是注释能再详细点,就是超超赞了!下面开始正文。

最近想着学QAT(量化感知训练)玩玩,大体看了一下地平线的用户手册,不说精度调优之类比较复杂的,光一个QAT上手,就感觉对我这种小白不是很友好,捣鼓了好久,感觉在用户手册中很多基础概念都没写,不同模块之间的关联性也没有详细地介绍,后来发现在用户手册 量化感知训练(QAT) 简介章节 ,有这么一句话:

Description

懂了,没用过Pytorch的QAT,直接看手册学起来有点费劲才是正常滴!

那针对只使用过Pytorch在服务器上训练过一些分类、检测模型,没接触过QAT的小白,又不想读PyTorch官方文档,只想简单入个门,怎么办嘞?欢迎看看这篇文章,提供实操代码和运行步骤,如果文章对你有点作用的话,麻烦收藏+点个赞再走~

该文章参考自J6 OE3.0.17中对应的示例以及用户手册

2. 基础理论知识

深度学习量化通常是指以int类型的数据代替浮点float类型的数据进行计算和存储,从而减小模型大小,降低带宽需求,理论上,INT8 量化,与常规的 FP32 模型相比,模型大小减少 4 倍,内存带宽需求减少 4 倍。
量化可以分为PTQ与QAT,
  • PTQ:Post-training Quantization,训练后量化,指浮点模型训练完成后,基于一些校准数据,直接通过工具自动进行模型量化的过程,相比QAT,PTQ更简单一些,这篇文章不介绍PTQ。

  • QAT:Quantization aware training,量化感知训练,指浮点模型训练完成后,在模型中插入伪量化节点再进行量化训练的过程,大体过程如下图所示,相比PTQ,QAT精度更有保障一些,这篇文章介绍QAT。

Description
小白:图中伪量化节点FakeQuantize node是什么?有什么作用?
大黑:从命名看,就是假装量化呗,模拟将数据从float类型量化为int类型,主要作用于网络的权重和激活(节点输出,不是relu这种激活函数的意思)。在QAT中,通过使用伪量化节点,可以在训练期间优化模型以适应后续的真实量化操作,从而提高量化模型的准确性和性能。一旦模型训练完成后,伪量化节点将被替换为真实的量化操作,以生成最终的量化模型。
小白:插入伪量化节点后需要Retraining/Funetuning?感觉很浪费资源的样子...
大黑:通常再多训 1/10 浮点阶段训练的轮数就好了,比如浮点阶段训练了100epoch,QAT训个10epoch就好,为了精度,浪费就浪费点,小问题!
小白:从上面这个图看,感觉QAT还挺简单的,其实目前我就只会用pytorch搭一个卷积网络,然后去训练,那我要经历哪些阶段才能得到最终上板部署的模型呢?
大黑:整个过程会涉及到以下几个模型:
Description

在每个阶段,还有一些需要注意的地方,比如...

小白:停停停,先别急,这里面新名词有点多,先帮我捋捋。float_model和我直接用pytorch搭建的有什么不同吗?calib是什么?qat_model/qat.bc/quantized.bc这三者还不是一个意思?板端部署hbm模型我知道,就是可以在板子上推理的模型对吧?
大黑:这一连串问题问的挺好,我下面逐个简单解释一下。
  • float_model和我直接用pytorch搭建的有什么不同吗?
    这里float_model浮点模型,其实就是在pytorch搭建的常规网络输入位置插入QuantStub节点、输出位置插入DeQuantstub节点,用于区分量化和不量化的边界。
    在PyTorch中,QuantStub/DequantStub 是一种用于量化的辅助工具,用于标记量化过程中需要量化/反量化的层或操作,前期浮点训练时可以当它不存在,在定点量化时会自动被替换为对应的量化操作。从普遍而又常规意义上说,想让模型某部分量化,每个输入分支都要插入QuantStub,每个输出分支都要拆入DeQuantStub,别再追问为什么了,问就是甲鱼的臀部——“规定”。
  • calib是什么?
    calib是校准calibration的缩写,主要作用是确定量化参数,我们知道,合理的初始量化参数能够显著提升模型精度并加快模型的收敛速度。calibration 就是在浮点模型中插入 Observer,使用少量训练数据,在模型 forward 过程中统计各处的数据分布,以确定合理的量化参数的过程。虽然可以不做 Calibration直接进行qat量化训练,但一般来说,校准对量化训练有益无害,所以推荐大家将此步骤作为必选项。
  • qat_model/qat.bc/quantized.bc这三者还不是一个意思?
    确实不是一个意思。
    qat_model:一种插入了伪量化节点的伪量化模型,还是torch模型,简单理解为:qat_model是为了量化训练而存在的模型,里面还“流淌”着浮点的参数,通过伪量化节点在模拟量化而已。
    qat.bc:相比于qat_model,多了一步查表算子定点化的操作,精度与qat_model可能会存在微小的差异。
    quantized.bc:模型中浮点算子转换成定点算子,浮点参数全部转换为定点参数,这种转换后的模型称之为quantized_model /定点模型 / 量化模型。
  • 板端部署hbm模型我知道,就是可以在板子上推理的模型对吧?
    非常对。
小白:这些模型是如何生成的?通过图中那几个函数?是地平线封装好的,直接用?
大黑:是的。

3. 文件准备与程序运行

参考OE/samples/ai_toolchain/horizon_model_train_sample/plugin_basic下的示例,主要使用fx_mode.py脚本。注意:这儿的fx_mode不是传统意义上pytorch中量化方式,而是地平线自己开发的jit-strip方式,它具有使用简单、debug方便等优点,对我们基础用户而言,别学太多,all in就行,反正不会的可以找地平线技术支持。

  • 文件目录

代码运行,建议在地平线提供的docker里运行,当然,如果大家自己会配置本地环境的话,也可以不用docker。

  • 运行过程

运行完全程,所有产出物文件如下图:
Description

跑起来很简单,下面再和大家一起看看代码层面的情况。

4. 代码详解

该章节参考地平线用户手册以及自己的理解进行介绍,主要是添加了一些中文注释。

4.1 导入必要依赖

之所以写这一节,主要是希望大家可以从注释中,简单了解各个函数的作用,像torch、os等对于我而言,特别基础的导入省略没写,全部的依赖可以看提供的代码。其中,horizon_plugin_pytorch是地平线基于 PyTorch 开发的 的量化训练工具,可以理解成numpy这种库,里面有很多用于量化训练的的依赖,我们直接用就好了。

  • common.py

4.2 主函数

看了第2节理论知识部分,主函数部分的代码就是严格执行那几个阶段stage(详见第2节),很easy,关于内部细节,在后面几个小节挨个介绍。

4.3 获取不同阶段模型get_model_fx

Example input

在代码运行时,有个输入参数stage必须配置,表示拿到哪个model去整后面的事,当stage参数传入("float", "calib", "qat", "int_infer")中某一个时,会通过如下函数去获取,具体实现过程解读可见代码注释。

4.4 构建float_model

从torchvision.models中继承MobileNetV2,微调一下,以支持量化相关操作。模型改造必要的操作有:

  • 在模型所有输入分支前插入 QuantStub

  • 在模型所有输出分支后插入 DequantStub

这部分具体实现过程解读可见代码注释。

关于如何加载预训练权重部分的代码在函数load_pretrain里,详细内容可以看Python文件,这里不再呈现。

4.5 定义常规模型训练与验证的函数

具体实现,看py代码就行,很常规。

4.7 模型校的代码解读—calib_model

float模型训练完成后,需要进行参数校准,得到calib_model,如果calib_model精度满足要求,qat训练就不需要了,即使calib_model精度不行,calib_model_state_dict(校准后的权重)对qat训练收敛也非常有帮助。

4.8 定点模型评测精度 代码解读—quantized_model

定点模型/quantized模型/量化模型 eval推理一下看看精度

bc/hbir模型如何推理呢?使用HbirModule即可

4.9 编译生成上板模型—model.hbm

编译生成上板模型model.hbm,同时针对hbm模型预估性能

5. 后续计划

之前写过一篇XJ3/J5 0基础学习地平线QAT量化感知训练,在CSDN上的反馈还行,为了方便直接入手J6的用户,特意写了这篇文章。关于后续,应该会基于J6再学习一些工具链的知识,例如:

  • prepare阶段的qconfig_setter是重中之重,本文并未介绍,后续再看

  • 这个模型比较简单,针对forward中有动态循环的怎么处理?不需要处理?后续再看

  • 精度调优过程是什么?精度调优工具怎么用?后续再看

算法工具链
社区征文官方教程技术深度解析
+3
评论0
0/1000