专栏算法工具链【J6】train与eval graph不一致的部署方案

【J6】train与eval graph不一致的部署方案

Jade-self2025-10-27
38
0

整个部署流程涉及到网络结构的数据流包括:float train训练、float eval推理、calib、qat训练、qat eval推理、导出部署、qat eval回灌推理,建议使用一套forward,不要存在多forward的情况(特别是prepare时 算子替换的hook注册在forward上,换成xxx_forward_xxx不保证能成功)。否则,即使每个流程都跑通了,也容易有各种forward不对齐的问题。常见的训练与部署链路不一致情况:

  1. train时forward中多了loss,eval+导出部署时不需要loss

  2. train时forward中多了前后处理,建议快速挪到forward外面去

  3. train时forward中多了辅助分支,eval+导出部署时不需要辅助分支

  4. train时forward中某结构for运行n遍(时序),eval+导出部署时该结构只运行一遍

train时forwad中多loss/前处理

为了明确模型部署边界,需要在 非部署逻辑 与 部署逻辑 的边界处插入了 QuantStub 和 DequantStub,其中最典型的操作为:
  • 在模型输入前插入 QuantStub。
  • 在模型输出后插入 DequantStub。
具体到 train loss 时,loss本身并不需要部署,通常大家会放在forward外面,如果一定要在forward内部,可以通过if self.training 来控制训练与部署逻辑图,另外,需要在loss输入处插入 DequantStub 以保证 loss部分是不被量化的。以一个简单网络为例进行介绍:

 

if self.training:根据 训练/推理模式 执行不同的逻辑
  • PyTorch 的 nn.Module 有一个布尔属性 self.training
    • 在调用 model.train() 后为 True(训练模式)。
    • 在调用 model.eval() 后为 False(推理模式)。
  • assert target is not None:训练时必须有标签 target。
  • x = self.dequant(x):前面模型的 量化输出,进loss前需要先 dequant,把量化的输出转成浮点。
  • else: return torch.argmax(x, dim=1).to(torch.int16)

    • 推理阶段不需要计算损失,而是返回预测类别。

    • torch.argmax(x, dim=1) 取出概率最大的位置(类别 ID),默认为int64输出

 

 

 

train时forward中多辅助分支

train时forward中多了辅助分支,eval+导出部署时不需要辅助分支。针对这种情况,需要注意:

  1. 仅对部署上板的部分 插入伪量化节点,辅助分支采用float进行qat训练。

  2. QAT 训练为全局训练,若辅助分支量化,会导致训练难度增加,若辅助分支处的数据分布与其他分支差异较大还会加大精度风险。

具体到代码处理上,原则依旧是:train时需要,eval时不需要,可使用if self.training来处理。

辅助分支处quant/dequant的处理需要注意,示例如下:

 

总结:QuantStub / DeQuantStub 用于标记输入输出的量化边界。训练时的辅助分支,要 dequant 后使用fp32运行,避免影响eval(部署)的量化逻辑。

train时forward中for 某结构 n遍(时序)

当train与eval存在某些结构运行次数不一样时,可能会遇到图不一致的问题。

先介绍几种常见错误写法

在进行qat训练时报错

如果修改为:num_frames = 1 if self.training else 1,则无报错,可以确认,与train和eval时for循环中运行次数不同有关。
立刻想到:如果某块逻辑在模型中的调用次数会动态变化,则需要添加with Tracer.dynamic_block(self, "train_eval_num_frames_dynamic"):

修改如下:

此时依旧报错:

relu的generated_modules次数发生变化,考虑是relu的写法问题,修改为:

修改后不再报错

最终版本:

控制变量,看不加Tracer.dynamic_block会有什么问题

报错如下:可以看到是在qat推理时,在add地方又报类似图不一致的错了。

只有上述使用Tracer.dynamic_block方式来避免该问题吗?其实还可以将generated modules都手动替换为FloatFunctional()的形式,示例如下:

参考如上修改即可不再报错。

常见替换为FloatFunctional的方式如下

算法工具链
社区征文技术深度解析征程6杂谈
评论0
0/1000