专栏算法工具链argmax()节点在量化模型会自己增加一个cast()节点,且argmax()节点量化损失大

argmax()节点在量化模型会自己增加一个cast()节点,且argmax()节点量化损失大

已解决
driverli2025-04-14
78
4
  1. 芯片型号:J3

  2. 天工开物开发包 OpenExplorer 版本:horizon_xj3_open_explorer_v1.16.6

  3. 问题定位:模型转换

  4. 问题具体描述:argmax()节点紧接输出时,量化模型会自己增加一个cast()节点,且argmax()节点量化损失大

此图是原始onnx模型的输出节点
Description
此图是quantized_model增加的cast节点
Description
argmax()节点量化精度如下
Description
算法工具链
征程3
评论1
0/1000
  • Huanghui
    Lv.5

    可以按照下面这样修改一下模型的代码尝试一下

    2025-04-14
    0
    3
    • driverli回复Huanghui:
      更改后没有增加cast节点了,但是argmax()节点依然量化损失大
      2025-04-15
      0
    • driverli回复Huanghui:

      你好 这个问题麻烦再看下,谢谢~

      2025-04-17
      0
    • Huanghui回复driverli:

      增加cast转换节点可能是默认需要提高量化精度,所以转换成int64, 通过torch.argmax可以指定量化数据类型。量化损失较大是因为argmax() 对输入张量的微小变化敏感,量化误差可能导致完全不同的输出索引。 你可以尝试 对argmax输入层使用 FP16,其他层用 INT8 。

      2025-04-30
      0