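Background
Some toolchain customers have models that were trained in fp16 (that is, all data in the model is fp16). The PTQ in the Horizon toolchain does not currently support fp16 models (support for an fp16 data flow may be added later, but there is no development plan for it in the short term).
The PTQ in the Horizon toolchain does support models with an fp32 data flow, and fp32 can represent every fp16 value exactly.
Given this, if a customer's model was trained in fp16, the script below can be used to convert the fp16 model to an fp32 model first. PTQ can then be run on the converted model, and the PTQ flow completes normally.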
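The claim that fp32 can hold fp16 data can be checked directly. The following is a minimal sketch using only NumPy; the function name `fp16_roundtrip_is_lossless` is made up for this illustration and is not part of the toolchain. It enumerates every finite fp16 bit pattern, widens it to fp32, narrows it back, and compares the bits.

```python
import numpy as np


def fp16_roundtrip_is_lossless():
    """Return True if every finite fp16 value survives fp16 -> fp32 -> fp16."""
    # All 65536 possible 16-bit patterns, reinterpreted as fp16 values.
    bits = np.arange(0, 1 << 16, dtype=np.uint32).astype(np.uint16)
    half = bits.view(np.float16)

    # NaN and infinity are left out; only finite values are compared.
    finite = half[np.isfinite(half)]

    # Widen to fp32, then narrow back to fp16.
    roundtrip = finite.astype(np.float32).astype(np.float16)

    # Compare raw bit patterns so that +0.0 and -0.0 are told apart.
    return bool(np.array_equal(roundtrip.view(np.uint16), finite.view(np.uint16)))


if __name__ == "__main__":
    print(fp16_roundtrip_is_lossless())
```

Because the widening step loses nothing, converting the stored weights to fp32 should leave their numerical values unchanged; only the storage type differs.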
Code
1. Install the dependencies
2. Run the fp16-to-fp32 conversion code
3. The contents of requirements.txt are as follows
4. The code of convert.py is as follows
from onnx import helper as h
from onnx import checker as ch
from onnx import TensorProto, GraphProto
from onnx import numpy_helper as nph
import numpy as np
from collections import OrderedDict
from logger import log
import typer
def make_param_dictionary(initializer):
    # Map each initializer tensor to its name, preserving the original order.
    params = OrderedDict()
    for data in initializer:
        params[data.name] = data
    return params
def convert_params_to_fp32(params_dict):
    converted_params = []
    for param in params_dict:
        data = params_dict[param]
        if data.data_type == TensorProto.FLOAT16:
            # Decode the fp16 tensor, widen it to fp32, and rebuild it under the same name.
            data_cvt = nph.to_array(data).astype(np.float32)
            data = nph.from_array(data_cvt, data.name)
        # Keep every parameter, converted or not.
        converted_params += [data]
    return converted_params
def convert_constant_nodes_to_fp32(nodes):
    """
    convert_constant_nodes_to_fp32 Convert Constant nodes to FLOAT32. If a constant node has data type FLOAT16,
    a new version of the node is created with FLOAT32 data type and stored.
    """
def convert_model_to_fp32(model_path: str, out_path: str):
    """
    convert_model_to_fp32 Converts ONNX model with FLOAT16 params to FLOAT32 params.\n
    """
if __name__ == "__main__":
    typer.run(convert_model_to_fp32)
Example
The figure below illustrates the result: if the script above runs successfully, every data flow in the model becomes fp32.
Result of a successful conversion (left: fp16 model, right: fp32 model)
