专栏算法工具链【倾力推荐】量化训练及精度调优经验分享

【倾力推荐】量化训练及精度调优经验分享

kuku2024-04-12
303
0

本文提纲:

  1. fx和eager两种量化训练方式介绍

  2. 量化训练的流程介绍:以mmdet的yolov3为例

  3. 常用的精度调优debug工具介绍

  4. 案例分析:模型精度调优经验分享

第一部分:fx和eager两种量化训练方式介绍

首先介绍一下量化训练的原理。

上图为单个神经元的计算,计算形式是加权求和,再经过非线性激活后得到输出,这个输出又可以作为下一个神经元的输入继续运输,所以神经网络的基础运算是矩阵的乘法。如果神经元的计算全部采用float32的形式,模型的内存占用和数据搬运都会很占资源。如果用int8替换float32,内存的搬运效率能提高75%,充分展示了量化的有效性。由于两个int8相乘会超出int8的表示范围,为了防止溢出,累加器使用int32类型的,累加后的结果会再次requantized到int8;

量化的目标就是在尽可能不影响模型精度的情况下降低模型的功耗,实现模型压缩效果,常见的量化方式有后量化训练PTQ和量化感知训练QAT。

量化感知训练其实是一种伪量化的过程,即在训练过程中模拟浮点转定点的量化过程,数据虽然都是表示为float32,但实际的值会间隔地受到量化参数的限制。具体方法是在某些op前插入伪量化节点(fake quantization nodes),伪量化节点有两个作用:

1.在训练时,用以统计流经该op的数据的最大最小值,便于在部署量化模型时对节点进行量化

2.伪量化节点参与模型训练的前向推理过程,因此会模型训练中导入了量化损失,但伪量化节点是不参与梯度更新过程的。

上图是模型学习量化损失的示意图, 正常的量化流程是quantize->mul(int)->dequantize,而伪量化是对原先的float先quantize到int,再dequantize到float,这个步骤用于模拟量化过程中round操作所带来的误差,用这个误差再去进行前向运算。上图可以比较直观的表示引起误差的原因,从左到右数第4个黑点表示一个浮点数,quantize后映射到253,dequantize后取到了第5个黑点,这就引起了误差。

地平线基于PyTorch开发的horizon_plugin_pytorch量化训练工具,同时支持Eager和fx两种模式。

eager模式的使用方式建议参考用户手册-4.2量化感知训练章节(4.2.2. 快速上手中有完整的快速上手示例,各使用阶段注意事项建议参考4.2.3. 使用指南)。fx模式的相关API介绍请参考用户手册-4.2.3.4.2. 主要接口参数说明章节

第二部分:量化训练的流程介绍:以mmdet的yolov3为例

QAT流程介绍

准备好浮点模型,加载训好的浮点权重

设置BPU架构

算子融合(eager模式需要,fx可省略)

设置量化配置

  • 整个model使用默认的qconfig

  • 模型的输出,配置高精度输出

  • det模型head输出的loss损失函数的qconfig设置为None

将浮点模型转换为qat模型(示例使用eager模式)

开始qat训练

  1.   可以复用浮点的train_detector,替换model即可

qat模型转定点(需要load训练好的qat模型权重)

deploy_model 和 example_input准备

Trace模型构建静态graph,进行编译

  • eval()使bn、dropout等处于正确的状态

  • 编译只能在cpu上做

  • check_model用于检查算子是否能全部跑在bpu上,建议提前检查

如果qat精度不达标,如何插入calibration?

伪量化节点(fake quantize)的三种状态:
  • CALIBRATION模式:即不进行伪量化操作,仅观测算子输入输出统计量,更新scale

  • QAT模式:观测统计量并进行伪量化操作。

  • VALIDATION模式:不会观测统计量,仅进行伪量化操作。

以下常见误操作会导致一些异常现象:

  1. calibration 之前模型设置为train()的状态,且未使用set_fake_quantize,等于是在跑QAT训练;
  2. calibration 之前模型设置为eval()的状态,且未使用set_fake_quantize,会导致scale一直处于初始状态,全为1,calib不起作用。
  3. calibration 之前模型设置为eval()的状态,且正确使用了set_fake_quantize,但是在这之后又设置了一遍model.eval(),这将导致fake_quant未处于训练状态,scale一直处于初始状态,全为1;

对mobilenet_v2模型做qat训练的设置

量化节点设置

关键代码:

from horizon_plugin_pytorch.quantization import QuantStub

self.quant = QuantStub(scale=1/128) # 一般pyramid输入的Quant层,需要手动设置scale=1/128def fuse_modules(self):

x = self.quant(x)

算子融合

7.5.5. 算子融合 — Horizon Open Explorer

举个例子:mmcv/cnn/bricks/conv_module.py

eager方案麻烦的是,基本每个模块都要手动去设置算子融合

反量化节点设置

mmdetection-master/mmdet/models/dense_heads/yolo_head.py

关键代码:

self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来

self.dequant.append(DeQuantStub())

def fuse_modules(self):

pred_map = self.dequant[i](self.convs_pred[i](x))

第三部分:常用的精度调优debug工具介绍

第四部分:模型精度调优分享

模型精度调优时常遇到的问题:

1. calib模型的精度和float对齐,quantized模型的精度损失较大

正常情况下,calib/qat模型的精度和quantized模型的精度损失很小(1%), 如果偏差过大,可能是calib/qat的流程不对。

原因:calib模型伪量化节点的状态不正确,导致calib阶段,测试的是float模型的精度,而quantized阶段,测试的是calib模型的精度,所以精度损失本质上还是量化精度的损失。

如何避免:

  • 正确设置calib训练和评测时的伪量化节点状态。

  • 让客户在calib的基础上,做qat, 评测qat模型的精度。(客户的数据量大,qat时间太长,一直没有选择qat,导致这个问题被暴露出来了)

如何设置正确的calib 伪量化节点的状态?(fx 和 eager都是一样的)

http://model.aidi.hobot.cc/api/docs/horizon_plugin_pytorch/latest/html/user_guide/calibration.html

注意:16行的train在评测时,也要设置FakeQuantState.VALIDATION,不然scale不生效,评测的指标也不对

常见问题:

  1. 数据校准之前模型设置为train()的状态,且未使用set_fake_quantize,等于caib阶段是在跑QAT训练;
  2. 校准的评测阶段,未设置伪量化节点的模式为 VALIDATION, 实际评测的是float模型;

总结2: 如果做calib,一定要仔细检查伪量化节点状态和模型状态是否正确,避免不符合预期的结果

2. 当量化精度损失超过大,如何调优?

  1. 使用 model_profiler() 这个集成接口,生成压缩包。

  2. 检查是否配置高精度输出、是否存在未融合的算子、是否共享op、是否算子分布过大int8兜不住?
  • 注意:使用debug集成接口时,要保证浮点模型训练到位,并传入真实数据

3.多任务模型的精度调优建议

  1. qat调优策略和常规模型一样,ptq+qat

  2. 如果只有一个head精度有损失,可以固定其他部分,单独使用这个head的数据做calib

4.calib和qat流程的正确衔接

calib:

qat:

5.检查conv高精度输出

方式1:查看 qconfig_info.txt,重点关注 DeQuantStub附近的conv是不是float32输出

qconfig_info.txt

方式2:打印qat_model的最后一层,查看该层是否有 (activation_post_process): FakeQuantize

高精度的conv:

int8的conv

6.检查共享op

打开qconfig_info.txt,后面标有(n)的就是共享的

特殊情况:layernorm在QAT阶段是多个小量化算子拼接而成,module的重复调用,也会产生大量op共享的问题

解决办法: 将 layernorm 替换为 batchnorm,测试了float精度,没有下降。

7.检查未融合的算子

打开qconfig_info.txt,全局搜BatchNorm2d 和 ReLU,如果前面有conv,那就是没做算子融合

可以融合的算子:

  • conv+bn

  • conv+relu

  • conv+add

  • conv+bn+relu

  • conv+bn+add

  • conv+bn+relu+add

8.检查数据分布特别大的算子

打开float模型的统计量分布,一般是 model0_statistic.txt

有两个表,第一个表是按模型结构排列的;第二个表是按数据分布范围排列的

拖到第二个表,看前几行是那些op

可以看到很多conv的分布很异常,使用的是int8量化

解决办法:

  1. 检查这些conv后面是否有bn,添加bn后,数据能收敛一些

  2. 如果结构上已经加了bn,数据分布还大,可以配置int16量化

  • int16调这两个接口,default_qat_16bit_fake_quant_qconfig 和 default_calib_16bit_fake_quant_qconfig

  • 中间算子的写法和高精度输出类似 model.xx.qconfig = default_qat_16bit_fake_quant_qconfig ()

算法工具链
社区征文技术深度解析征程5杂谈
评论0
0/1000