目录
- 1. 前言
- 2. 使用说明
- 3. Calibration 后的 QAT 训练建议
- 4. 数据集和Transorm设置建议
- 5. 在horizon_train_sample中使用Calibration
- 6. 常见问题
1. 前言
在量化训练中,一个重要的步骤是确定量化参数 scale ,合理的 scale 能够显著提升量化训练效果并加快收敛速度。
Calibration 是通过用浮点模型在训练集上跑少数 batch 的数据(只跑 forward 过程,没有backward),统计这些数据的分布直方图,通过一定方法去计算出 min_value 和 max_value,然后可以用这些 min_value 和 max_value 去计算得到 scale。
推荐大家在量化训练之前先使用Calibration,一方面是因为calibration时间较短,部分模型仅进行calibration即可满足精度要求,可以免去费时的QAT过程,另一方面通过calibration初始化参数之后也可以加速模型的训练。
calibration不支持 train() 模式和 eval() 模式行为不一致的Module(如 dropout)。
Calibration 是通过用浮点模型在训练集上跑少数 batch 的数据(只跑 forward 过程,没有backward),统计这些数据的分布直方图,通过一定方法去计算出 min_value 和 max_value,然后可以用这些 min_value 和 max_value 去计算得到 scale。
推荐大家在量化训练之前先使用Calibration,一方面是因为calibration时间较短,部分模型仅进行calibration即可满足精度要求,可以免去费时的QAT过程,另一方面通过calibration初始化参数之后也可以加速模型的训练。
calibration不支持 train() 模式和 eval() 模式行为不一致的Module(如 dropout)。
2. 使用说明
以下说明仅适用于Horizon Plugin Pytorch ≥ v1.2.2,相较于此前的版本新的 Calibration(后文称为Calibrationv2) 支持更多的方法,用法更灵活,支持直接将Calibration模型转换编译成部署模型,因此若Calibration已满足精度要求,可跳过qat的过程。 |
2.1 使用方式
# 1. 加载浮点模型
float_model = load_float_model(pretrain=True)
# 2. 设置 BPU 架构
set_march(march) # J5:march = March.BAYES; XJ3:march = March.BERNOULLI2
float_model = load_float_model(pretrain=True)
# 2. 设置 BPU 架构
set_march(march) # J5:march = March.BAYES; XJ3:march = March.BERNOULLI2
# 3. 准备 calib_model
calib_model = prepare_qat_fx(
copy.deepcopy(float_model), # 为不影响 float_model 的后续使用,建议进行 deepcopy
{
"": default_calib_8bit_fake_quant_qconfig,
"module_name": {
# 在模型的输出层为 Conv 或 Linear 时,可以使用 out_qconfig# 配置为高精度输出
# plugin ≤ v1.6.2 需使用 default_calib_out_8bit_fake_quant_qconfig ,但该参数后续将弃用
"last_layer_name": default_calib_8bit_weight_32bit_out_fake_quant_qconfig
}
}
)
# 4. 数据校准(通过前向推理初始化scale,无需训练)
# 注意此处对模型状态的控制,模型需要处于 eval 状态以使 Bn 的行为符合要求
calib_model.eval()
# 关于set_fake_quantize的使用介绍,非常建议您查看后文常见问题 3,以避免踩坑
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calibrate(calib_model, calib_dataloader, device) # 从训练集中选适量数据进行模型推理
# 5. 校准精度验证
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
val(calib_model,val_dataloader,device)
# 6. 定点模型转换
quantized_model = convert_fx(calib_model.eval(), inplace=False)
# 7. 模型检查
example_input = torch.ones(size=(neval_batches, 3, 28, 28), device="cpu")
quantized_model = quantized_model.cpu()
traced_model = torch.jit.trace(quantized_model, example_input)
check_model(quantized_model, example_input)
# 8. 编译模型
compile_model(traced_model, example_input, opt=3, hbm="model.hbm")
calib_model = prepare_qat_fx(
copy.deepcopy(float_model), # 为不影响 float_model 的后续使用,建议进行 deepcopy
{
"": default_calib_8bit_fake_quant_qconfig,
"module_name": {
# 在模型的输出层为 Conv 或 Linear 时,可以使用 out_qconfig# 配置为高精度输出
# plugin ≤ v1.6.2 需使用 default_calib_out_8bit_fake_quant_qconfig ,但该参数后续将弃用
"last_layer_name": default_calib_8bit_weight_32bit_out_fake_quant_qconfig
}
}
)
# 4. 数据校准(通过前向推理初始化scale,无需训练)
# 注意此处对模型状态的控制,模型需要处于 eval 状态以使 Bn 的行为符合要求
calib_model.eval()
# 关于set_fake_quantize的使用介绍,非常建议您查看后文常见问题 3,以避免踩坑
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)
calibrate(calib_model, calib_dataloader, device) # 从训练集中选适量数据进行模型推理
# 5. 校准精度验证
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
val(calib_model,val_dataloader,device)
# 6. 定点模型转换
quantized_model = convert_fx(calib_model.eval(), inplace=False)
# 7. 模型检查
example_input = torch.ones(size=(neval_batches, 3, 28, 28), device="cpu")
quantized_model = quantized_model.cpu()
traced_model = torch.jit.trace(quantized_model, example_input)
check_model(quantized_model, example_input)
# 8. 编译模型
compile_model(traced_model, example_input, opt=3, hbm="model.hbm")
2.2 调参建议
Calibrationv2支持min_max、mix、kl、mse、percentile这几种校准方法,每种方法的介绍可参考用户手册 calibration指南 。以下是基础调参建议和一些注意事项:
先配置min_max,调整batch size、average_constant得到最佳精度。
- 如果没有达到预期,固定前一步得到的最佳batch size、average_constant,尝试mix、mse、kl、percentile校准。
a. mse的stride调整只会带来计算速度的提高,调整得太大会影响精度表现;
b. kl 算法的update_interval需要小于step;
c. percentile需要先针对整个模型确定最佳的percentile(一般bin默认2048对绝大多数模型都够用了,可以不调整),然后再找出量化敏感层,依据输入输出统计量单独调整这些层的percentile。
以 percentile 为例,配置方式如下所示:
calib_qconfig = get_default_qconfig(
activation_fake_quant = "fake_quant",
weight_fake_quant = "fake_quant",
activation_observer = "percentile",
weight_observer = "min_max",
activation_qkwargs={
"percentile": 99.9
},
weight_qkwargs={
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0
}
)
activation_fake_quant = "fake_quant",
weight_fake_quant = "fake_quant",
activation_observer = "percentile",
weight_observer = "min_max",
activation_qkwargs={
"percentile": 99.9
},
weight_qkwargs={
"qscheme": torch.per_channel_symmetric,
"ch_axis": 0
}
)
3. Calibration 后的 QAT 训练建议
量化相关参数 | 推荐配置 | 高级配置 |
averaging_constant(qconfig_params) | 1. 不用calibration时,使用默认即可 2. 使用calibration时建议固定激活的scale weight averaging_constant=1.0 activation averaging_constant=0.0 | 1. calibration的精度和浮点差距较大时:activation averaging_constant建议保持默认值。 2. 对于部分特殊任务,固定activation的scale可能会导致精度变差,可根据实际情况调整。 3. weight averaging_constant一般不需要设置成0.0,实际情况可以在(0,1.0]之间调整。 |
4. 数据集和Transorm设置建议
做 Calibration 的数据集(dataset)不能是测试集(建议使用训练集)。一般来说,在数据干净的情况下,calibration 数据越多越好(同时建议尽量调大batchsize),但因为边际效应的存在,当数据量大到一定程度后,对精度的提升将非常有限。如果训练集较小,可以全部用来 calibration,如果训练集较大,可以结合 calibration 耗时挑选大小合适的子集(尽量覆盖所有典型场景),建议至少进行 10 - 100 个 step 的校准。 因为随机性和噪声数据的存在,实际效果不是和图片数量完全呈正相关,也需要根据实际情况调整。
数据可以做水平翻转这类 augmentation,不要做马赛克这种 augmentation。推荐使用 infer 阶段的前处理 + 训练数据进行校准。
数据可以做水平翻转这类 augmentation,不要做马赛克这种 augmentation。推荐使用 infer 阶段的前处理 + 训练数据进行校准。
5. 在horizon_train_sample中使用Calibration
calibration操作可直接在config中进行定义,具体请参考J5用户手册-FCOS-EfficientNetB0的config构造详细说明,需要设置calibration_data_loader、calibration_batch_processor以及calibration_trainer。训练时通过指定--stage即可启用calibration,可参考J5用户手册-FCOS检测模型训练:
后续使用calibration模型继续进行qat的话需要修改一下qat_qconfig以及checkpoint_path:
自 horizon_train_sample 1.3.2 版本后使用 CalibrationV2 替换了Calibration v1。 需要注意的是,calibrationv1 的精度评测可以通过predict.py --stage qat(train过程中默认每个epoch/训练结束会进行一次validation),但只能加载每个epoch/calib训练结束后保存的ckpt,因为中间step保存的ckpt是calibration_model的,qat_model加载这些权重会报错。calibrationv2不再存在这个问题,因此建议大家优先考虑升级版本。 |
6. 常见问题
1. 为何calibrate之后打印模型发现scale还是等于1?
答:推测可能是因为set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)后又设置了model.eval(),导致scale未正常更新,只可在set_fake_quantize之前执行model.eval()。
2. 为何prepare_qat之后测试calib_model精度很低甚至为0?
答:prepare_qat之后的模型处于初始状态,scale=1。
3. 为何calibrate、qat训练以及精度评测之前都要使用set_fake_quantize接口?
答:感兴趣的话大家可以到python path下找到horizon_plugin_pytorch/quantization/fake_quantize.py,查看set_fake_quantize的实现,以下为截取的关键片段:
答:推测可能是因为set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)后又设置了model.eval(),导致scale未正常更新,只可在set_fake_quantize之前执行model.eval()。
2. 为何prepare_qat之后测试calib_model精度很低甚至为0?
答:prepare_qat之后的模型处于初始状态,scale=1。
3. 为何calibrate、qat训练以及精度评测之前都要使用set_fake_quantize接口?
答:感兴趣的话大家可以到python path下找到horizon_plugin_pytorch/quantization/fake_quantize.py,查看set_fake_quantize的实现,以下为截取的关键片段:
从这段代码中我们可以得知:
\qquada. 在QAT之前要求模型设置的是train()的状态,CALIBRATION和VALIDATION则要求模型是eval()状态,这主要是为了使bn、dropout等处于正确的状态(训练的时候bn会更新,评测的时候bn不更新)。
\qquadb. CALIBRATION时会disable_fake_quant,并设置fake_quant状态为train(),即不进行伪量化操作,仅观测算子输入输出统计量,更新scale;QAT时会观测统计量并进行伪量化操作;VALIDATION时不会观测统计量,仅进行伪量化操作。
因此如下常见误操作会导致一些异常现象:
\qquada. 数据校准之前模型设置为train()的状态,且未使用set_fake_quantize,等于是在跑QAT训练;
\qquadb. 数据校准之前模型设置为eval()的状态,且未使用set_fake_quantize,会导致scale一直处于初始状态,全为1;
\qquadc. 数据校准之前模型设置为eval()的状态,且正确使用了set_fake_quantize,但是在这之后又设置了一遍model.eval(),这将导致fake_quant未处于训练状态,scale一直处于初始状态,全为1;
\qquada. 在QAT之前要求模型设置的是train()的状态,CALIBRATION和VALIDATION则要求模型是eval()状态,这主要是为了使bn、dropout等处于正确的状态(训练的时候bn会更新,评测的时候bn不更新)。
\qquadb. CALIBRATION时会disable_fake_quant,并设置fake_quant状态为train(),即不进行伪量化操作,仅观测算子输入输出统计量,更新scale;QAT时会观测统计量并进行伪量化操作;VALIDATION时不会观测统计量,仅进行伪量化操作。
因此如下常见误操作会导致一些异常现象:
\qquada. 数据校准之前模型设置为train()的状态,且未使用set_fake_quantize,等于是在跑QAT训练;
\qquadb. 数据校准之前模型设置为eval()的状态,且未使用set_fake_quantize,会导致scale一直处于初始状态,全为1;
\qquadc. 数据校准之前模型设置为eval()的状态,且正确使用了set_fake_quantize,但是在这之后又设置了一遍model.eval(),这将导致fake_quant未处于训练状态,scale一直处于初始状态,全为1;

