本文提纲:
fx和eager两种量化训练方式介绍
量化训练的流程介绍:以mmdet的yolov3为例
常用的精度调优debug工具介绍
案例分析:模型精度调优经验分享
第一部分:fx和eager两种量化训练方式介绍
首先介绍一下量化训练的原理。

上图为单个神经元的计算,计算形式是加权求和,再经过非线性激活后得到输出,这个输出又可以作为下一个神经元的输入继续运输,所以神经网络的基础运算是矩阵的乘法。如果神经元的计算全部采用float32的形式,模型的内存占用和数据搬运都会很占资源。如果用int8替换float32,内存的搬运效率能提高75%,充分展示了量化的有效性。由于两个int8相乘会超出int8的表示范围,为了防止溢出,累加器使用int32类型的,累加后的结果会再次requantized到int8;
量化的目标就是在尽可能不影响模型精度的情况下降低模型的功耗,实现模型压缩效果,常见的量化方式有后量化训练PTQ和量化感知训练QAT。

量化感知训练其实是一种伪量化的过程,即在训练过程中模拟浮点转定点的量化过程,数据虽然都是表示为float32,但实际的值会间隔地受到量化参数的限制。具体方法是在某些op前插入伪量化节点(fake quantization nodes),伪量化节点有两个作用:
1.在训练时,用以统计流经该op的数据的最大最小值,便于在部署量化模型时对节点进行量化
2.伪量化节点参与模型训练的前向推理过程,因此会模型训练中导入了量化损失,但伪量化节点是不参与梯度更新过程的。

上图是模型学习量化损失的示意图, 正常的量化流程是quantize->mul(int)->dequantize,而伪量化是对原先的float先quantize到int,再dequantize到float,这个步骤用于模拟量化过程中round操作所带来的误差,用这个误差再去进行前向运算。上图可以比较直观的表示引起误差的原因,从左到右数第4个黑点表示一个浮点数,quantize后映射到253,dequantize后取到了第5个黑点,这就引起了误差。
地平线基于PyTorch开发的horizon_plugin_pytorch量化训练工具,同时支持Eager和fx两种模式。

第二部分:量化训练的流程介绍:以mmdet的yolov3为例
QAT流程介绍
准备好浮点模型,加载训好的浮点权重
设置BPU架构
算子融合(eager模式需要,fx可省略)
设置量化配置
整个model使用默认的qconfig
模型的输出,配置高精度输出
det模型head输出的loss损失函数的qconfig设置为None
将浮点模型转换为qat模型(示例使用eager模式)
开始qat训练
可以复用浮点的train_detector,替换model即可
qat模型转定点(需要load训练好的qat模型权重)
deploy_model 和 example_input准备
Trace模型构建静态graph,进行编译
eval()使bn、dropout等处于正确的状态
编译只能在cpu上做
- check_model用于检查算子是否能全部跑在bpu上,建议提前检查
如果qat精度不达标,如何插入calibration?
CALIBRATION模式:即不进行伪量化操作,仅观测算子输入输出统计量,更新scale
QAT模式:观测统计量并进行伪量化操作。
VALIDATION模式:不会观测统计量,仅进行伪量化操作。
以下常见误操作会导致一些异常现象:
- calibration 之前模型设置为train()的状态,且未使用set_fake_quantize,等于是在跑QAT训练;
- calibration 之前模型设置为eval()的状态,且未使用set_fake_quantize,会导致scale一直处于初始状态,全为1,calib不起作用。
- calibration 之前模型设置为eval()的状态,且正确使用了set_fake_quantize,但是在这之后又设置了一遍model.eval(),这将导致fake_quant未处于训练状态,scale一直处于初始状态,全为1;
对mobilenet_v2模型做qat训练的设置
量化节点设置
关键代码:
from horizon_plugin_pytorch.quantization import QuantStub
self.quant = QuantStub(scale=1/128) # 一般pyramid输入的Quant层,需要手动设置scale=1/128def fuse_modules(self):
x = self.quant(x)
算子融合
7.5.5. 算子融合 — Horizon Open Explorer

举个例子:mmcv/cnn/bricks/conv_module.py
eager方案麻烦的是,基本每个模块都要手动去设置算子融合
反量化节点设置
mmdetection-master/mmdet/models/dense_heads/yolo_head.py
关键代码:
self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来
self.dequant.append(DeQuantStub())
def fuse_modules(self):
pred_map = self.dequant[i](self.convs_pred[i](x))
第三部分:常用的精度调优debug工具介绍

第四部分:模型精度调优分享
模型精度调优时常遇到的问题:
1. calib模型的精度和float对齐,quantized模型的精度损失较大
正常情况下,calib/qat模型的精度和quantized模型的精度损失很小(1%), 如果偏差过大,可能是calib/qat的流程不对。
原因:calib模型伪量化节点的状态不正确,导致calib阶段,测试的是float模型的精度,而quantized阶段,测试的是calib模型的精度,所以精度损失本质上还是量化精度的损失。
如何避免:
正确设置calib训练和评测时的伪量化节点状态。
让客户在calib的基础上,做qat, 评测qat模型的精度。(客户的数据量大,qat时间太长,一直没有选择qat,导致这个问题被暴露出来了)
如何设置正确的calib 伪量化节点的状态?(fx 和 eager都是一样的)
http://model.aidi.hobot.cc/api/docs/horizon_plugin_pytorch/latest/html/user_guide/calibration.html

注意:16行的train在评测时,也要设置FakeQuantState.VALIDATION,不然scale不生效,评测的指标也不对
常见问题:
- 数据校准之前模型设置为train()的状态,且未使用set_fake_quantize,等于caib阶段是在跑QAT训练;
校准的评测阶段,未设置伪量化节点的模式为 VALIDATION, 实际评测的是float模型;
总结2: 如果做calib,一定要仔细检查伪量化节点状态和模型状态是否正确,避免不符合预期的结果
2. 当量化精度损失超过大,如何调优?
使用 model_profiler() 这个集成接口,生成压缩包。
- 检查是否配置高精度输出、是否存在未融合的算子、是否共享op、是否算子分布过大int8兜不住?
注意:使用debug集成接口时,要保证浮点模型训练到位,并传入真实数据
3.多任务模型的精度调优建议
qat调优策略和常规模型一样,ptq+qat
如果只有一个head精度有损失,可以固定其他部分,单独使用这个head的数据做calib
4.calib和qat流程的正确衔接
calib:
qat:
5.检查conv高精度输出
方式1:查看 qconfig_info.txt,重点关注 DeQuantStub附近的conv是不是float32输出
qconfig_info.txt

方式2:打印qat_model的最后一层,查看该层是否有 (activation_post_process): FakeQuantize
高精度的conv:
int8的conv
6.检查共享op
打开qconfig_info.txt,后面标有(n)的就是共享的

特殊情况:layernorm在QAT阶段是多个小量化算子拼接而成,module的重复调用,也会产生大量op共享的问题
解决办法: 将 layernorm 替换为 batchnorm,测试了float精度,没有下降。


7.检查未融合的算子
打开qconfig_info.txt,全局搜BatchNorm2d 和 ReLU,如果前面有conv,那就是没做算子融合
可以融合的算子:
conv+bn
conv+relu
conv+add
conv+bn+relu
conv+bn+add
conv+bn+relu+add

8.检查数据分布特别大的算子
打开float模型的统计量分布,一般是 model0_statistic.txt
有两个表,第一个表是按模型结构排列的;第二个表是按数据分布范围排列的
拖到第二个表,看前几行是那些op
可以看到很多conv的分布很异常,使用的是int8量化

解决办法:
检查这些conv后面是否有bn,添加bn后,数据能收敛一些
如果结构上已经加了bn,数据分布还大,可以配置int16量化
int16调这两个接口,default_qat_16bit_fake_quant_qconfig 和 default_calib_16bit_fake_quant_qconfig
中间算子的写法和高精度输出类似 model.xx.qconfig = default_qat_16bit_fake_quant_qconfig ()

