整个部署流程涉及到网络结构的数据流包括:float train训练、float eval推理、calib、qat训练、qat eval推理、导出部署、qat eval回灌推理,建议使用一套forward,不要存在多forward的情况(特别是prepare时 算子替换的hook注册在forward上,换成xxx_forward_xxx不保证能成功)。否则,即使每个流程都跑通了,也容易有各种forward不对齐的问题。常见的训练与部署链路不一致情况:
train时forward中多了loss,eval+导出部署时不需要loss
train时forward中多了前后处理,建议快速挪到forward外面去
train时forward中多了辅助分支,eval+导出部署时不需要辅助分支
train时forward中某结构for运行n遍(时序),eval+导出部署时该结构只运行一遍
train时forwad中多loss/前处理
- 在模型输入前插入 QuantStub。
- 在模型输出后插入 DequantStub。
- PyTorch 的 nn.Module 有一个布尔属性 self.training
- 在调用 model.train() 后为 True(训练模式)。
- 在调用 model.eval() 后为 False(推理模式)。
- assert target is not None:训练时必须有标签 target。
- x = self.dequant(x):前面模型的 量化输出,进loss前需要先 dequant,把量化的输出转成浮点。
else: return torch.argmax(x, dim=1).to(torch.int16)
推理阶段不需要计算损失,而是返回预测类别。
- torch.argmax(x, dim=1) 取出概率最大的位置(类别 ID),默认为int64输出


train时forward中多辅助分支
train时forward中多了辅助分支,eval+导出部署时不需要辅助分支。针对这种情况,需要注意:
仅对部署上板的部分 插入伪量化节点,辅助分支采用float进行qat训练。
QAT 训练为全局训练,若辅助分支量化,会导致训练难度增加,若辅助分支处的数据分布与其他分支差异较大还会加大精度风险。
辅助分支处quant/dequant的处理需要注意,示例如下:
train时forward中for 某结构 n遍(时序)
当train与eval存在某些结构运行次数不一样时,可能会遇到图不一致的问题。
先介绍几种常见错误写法
在进行qat训练时报错
修改如下:
此时依旧报错:
relu的generated_modules次数发生变化,考虑是relu的写法问题,修改为:
最终版本:
报错如下:可以看到是在qat推理时,在add地方又报类似图不一致的错了。
参考如上修改即可不再报错。
常见替换为FloatFunctional的方式如下
