专栏算法工具链J6 PTQ精度调优辅助代码,总有你用得上的

J6 PTQ精度调优辅助代码,总有你用得上的

DR_KAN2025-05-26
123
0

1 截取onnx模型片段

在模型编译的时候,往往会出现各种各样的报错,您可能会受限于公司要求,无法把完整的onnx模型发送给地平线做分析,此时可以考虑截取onnx模型,找到可复现的片段,再将该片段提供给地平线的技术支持人员(这种方式通常是可以被公司允许的)。我们可以直接利用onnx的python api去完成这件事,onnx.utils.extract_model的具体使用方式如下:

onnx.utils.extract_model 是 onnx 官方库提供的一个 从 ONNX 模型中提取子模型(子图) 的实用函数。 它的常见用途包括:
  • 提取模型中某一部分用于调试或加速推理;

  • 截取特定中间层的输出;

  • 制作用于 calibration 的简化子模型(例如量化前的子图)。

基本用法

from onnx import utils
utils.extract_model(
input_path, # 原始模型文件路径(.onnx)

output_path, # 提取后的子模型保存路径(.onnx)

input_names, # 子模型的输入节点名列表(字符串

output_names # 子模型的输出节点名列表(字符串)

)

ONNX 会自动计算以这些输入到输出为范围的所有依赖节点,然后构造出一个合法的新 ONNX 子模型。

举例来说,我们可以运行这样的代码截取onnx模型片段:

请注意,这个脚本不一定能一次直接运行成功,有些节点不支持extract,如果失败,可以尝试配置模型节点的name或者节点输入输出的name,或者尝试使用其他节点或输入输出。

2 打印模型的所有op type

当我们在做比较细致的精度调优时,会想把某一类算子全部配置int16或者float32,而如何快速知道模型有哪些算子类别呢?首先我们要知道,PTQ在校准的时候,回忆optimized onnx为处理对象,因此我们只要知道optimized onnx模型有哪些算子类别就可以了。用户使用PTQ时,可以使用我们提供好的功python api完成这个事情,举例如下:

yaml配置算子时,以optimized模型的node name为准,这里打印op type最好也选optimized.onnx,但该功能可能会有部分op type漏打印的现象,可以检查hb_compile之后的log查漏补缺。

 

3 手动计算量化前后相似度

虽然模型编译的时候,日志里会提供相似度的情况,但显示的并不够完整。如果我们想知道所有输入数据或者某个特定输入数据的相似度情况,或者不同阶段模型的相似度情况,就需要手写代码去做模型推理。这里我提供这样一份参考代码,用来读取文件夹里所有数据,并挨个打印相似度,满足我们的相似度计算需求,从而更好地了解模型的精度情况。

主函数:float_vs_quant(i)

这个函数是整个对比流程的核心:

读取输入数据

input_depth = np.fromfile(..., dtype=np.float32).reshape(1,1280,720)
xyzrgb_tensor = np.fromfile(..., dtype=np.float32).reshape(1,6,640,360)
  • 分别读取第 i 条样本的:
    • 深度图:维度是 (1, 1280, 720)
    • RGB + XYZ 张量:维度是 (1, 6, 640, 360)
  • 若文件读取失败(如不存在),直接返回 0(跳过该样本)。

推理浮点模型

sess = HBRuntime("./model_output/grasp_original_float_model.onnx")
...
pred_6d_grasp_float = output[0]
  • 使用 HBRuntime 加载浮点版模型 .onnx。
  • 输入数据为 input_depth 和 xyzrgb_tensor。
  • 获取输出(预测的 6D 抓取向量)。

推理量化模型

sess = HBRuntime("./model_output/grasp_quantized_model.bc")
...
pred_6d_grasp_quant = output[0]
  • 加载量化后的模型(通常是 .bc 格式,地平线编译后模型)。
  • 同样输入相同的数据。

  • 获取量化模型的推理结果。

余弦相似度对比

return cosine_similarity(pred_6d_grasp_float.reshape(-1), pred_6d_grasp_quant.reshape(-1))

  • 将浮点和量化预测结果展平成一维向量。

  • 计算并返回它们的余弦相似度。

主执行入口

if name == "__main__":for i in range(0, 100):
cos_sim = float_vs_quant(i)
...
  • 对编号为 0 到 99 的输入样本,逐个调用 float_vs_quant。
  • 跳过无效样本(返回值为 0)。

  • 打印样本编号及对应的相似度。

输出示例

float
[0.12 0.45 0.87 ...]
quant
[0.11 0.44 0.85 ...]
3 0.99875
...
  • 展示浮点和量化模型的推理输出;

  • 输出每个有效样本的相似度分数(越接近 1 表示差异越小)。

 

4 精度debug保存终端打印日志

通常来说,我们可以使用精度DEBUG功能,去查看哪些算子的量化风险高,从而为其设置更高的量化精度。但有时,我们的模型特别大,算子特别多,如果直接在vscode终端执行精度debug命令,打印的算子信息很可能不够完整,会被截断,因此可以使用下面介绍方法将精度debug日志完整地保存到本地文件里。

首先,我们先在debug.py脚本里写好我们要执行的命令,比如:

方法1、前台运行时保存,这个方法会持续占用该终端,直到程序结束。

方法2、后台运行时保存,这是更为推荐的方法,这样我们在这个终端还可以同时做其他事,比如同时运行其他node_type的静的debug功能。

在程序运行时,目录下就会生成对应的log日志,我们可以随时查看精度debug的进展,非常方便!

 

 

 

算法工具链
征程6社区征文官方教程技术深度解析
评论0
0/1000