1 背景介绍
使用深度学习开源框架训练完网络模型后,在部署之前通常需要进行格式转换,地平线工具链模型转换目前支持Caffe1.0和ONNX(opset_version=10/11 且 ir_version≤7)两种。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,例如Pytorch、PaddlePaddle、TensorFlow等。本文将详细介绍如何将TensorFlow2得到的模型导出为ONNX格式。
2 实验环境
本教程的实验环境如下:
Python库 | Version |
|---|---|
tensorflow-cpu | 2.11.0 |
tensorflow-intel | 2.11.0 |
tf2onnx | 1.13.0 |
protobuf | 3.20.2 |
onnx | 1.13.0 |
onnxruntime | 1.14.0 |
3 tf2onnx工具介绍
tf2onnx可以通过命令行的方式将TensorFlow/Keras的模型转换为ONNX,该工具的主要配置参数如下:
4 代码实操
TensorFlow2与ONNX模型导出
以下代码展示了如何搭建一个简单分类模型以TensorFlow2的save-model方式保存并转换为ONNX格式。
input1 = tf.keras.layers.Input(shape=(7, 7, 3))
model = MyNet()
model.save('model')
#调用tf2onnx将上一步保存的模型导出为ONNX
os.system("python -m tf2onnx.convert --saved-model model --output model.onnx --opset 11")
ONNX正确性验证
check = onnx.checker.check_model(onnx_model)
print('Check: ', check)
TensorFlow2与ONNX的一致性检查
可以使用以下代码检查导出的ONNX模型和原始的PaddlePaddle模型是否有相同的计算结果。
ort_inputs = {ort_sess.get_inputs()[0].name: input1}
ort_outs = ort_sess.run(None, ort_inputs)
tf_outs = tf_model(inputs=input1)
print(tf_outs.numpy())
np.testing.assert_allclose(tf_outs.numpy(), ort_outs[0], rtol=1e-03, atol=1e-05)
print("onnx model check finsh.")
多输入的情况
若您的模型存在多输入,则可参考下方代码以TensorFlow2的save-model方式保存并转换为ONNX格式。
input1 = tf.keras.layers.Input(shape=(7, 7, 3))
input2 = tf.keras.layers.Input(shape=(7, 7, 3))
model = MyNet()
os.system("python -m tf2onnx.convert --saved-model model --output model.onnx --opset 11")
设定输入/输出节点
有时考虑到部署难度,我们不希望TensorFlow网络结构的前后处理部分也导入进ONNX模型。此时可以使用tf2onnx工具的inputs和outputs参数,指定导出的首尾节点,这样首节点之前和尾节点之后的部分都不会导入进ONNX模型。
5 ONNX模型可视化

6 ir_version和opset_version修改
地平线工具链支持的ONNX模型需要满足 opset_version=10/11 且 ir_version≤7,当拿到的ONNX模型不满足这两个要求时,可以修改代码重新导出,或者尝试编写脚本直接修改ONNX模型的对应属性,第二种方式的示例代码如下:
model.ir_version = 6
model.opset_import[0].version = 11
onnx.save_model(model, "./model_version.onnx")
7 ONNX输入输出维度修改

可以使用如下代码进行修改:
onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_value = 1
onnx_model.graph.output[0].type.tensor_type.shape.dim[0].dim_value = 1
onnx.save(onnx_model, './model_dim.onnx')

至此,该ONNX模型已满足地平线工具链的转换条件。
