型号J6M
在网络中设置 nn.Div() 计算精度为int8不起作用,始终计算精度为int16
self.mul_div = horizon.nn.Div() # 查看文档 该算子会拆分为 b30.lut和hbir.mul 都是支持int8输入和输出的
# 使用默认的 default_qat_qconfig_setter 量化配置 和下方的自定义配置
div_i8 = horizon.quantization.QConfig(
input=horizon.quantization.FakeQuantize.with_args(
observer=horizon.quantization.MinMaxObserver,
quant_min=horizon.quantization.qinfo("qint8").min,
quant_max=horizon.quantization.qinfo("qint8").max,
dtype="qint8",
),
output=horizon.quantization.FakeQuantize.with_args(
observer=horizon.quantization.MinMaxObserver,
quant_min=horizon.quantization.qinfo("qint8").min,
quant_max=horizon.quantization.qinfo("qint8").max,
dtype="qint8",
),
weight=horizon.quantization.FakeQuantize.with_args(
observer=horizon.quantization.MinMaxObserver,
quant_min=horizon.quantization.qinfo("qint8").min,
quant_max=horizon.quantization.qinfo("qint8").max,
dtype="qint8",
)
)
"view_transformer.mul_div.reciprocal": div_i8,
"view_transformer.mul_div.mul": div_i8,
"view_transformer.mul_div": div_i8,
# 以上两种配置方式,都不能改变其self.mul_div 中间计算精度,都是为int16
view_transformer.mul_div.reciprocal | <class 'horizon_plugin_pytorch.nn.qat.segment_lut.segmentlut'=""> | ['qint8'] | ['qint16'] | -1 | MinMaxObserver
view_transformer.mul_div.mul[mul] |<class 'horizon_plugin_pytorch.nn.qat.functional_modules.floatfunctional'=""> | ['qint8', 'qint16'] | ['qint8'] | -1 | MinMaxObserver
下面为onnx结构图

当右侧输入为int16时,hbir_cast_type层会消失,当输入为int8时会自动添加上,已知通过PTQ转换的bc模型,该部分全为int8精度。
现在我的需求是:1.通过什么办法可以让中间计算全部转换为int8精度(PTQ方式得到的bc模型,中间计算全为int8精度)
2.在通过qconfig指定 mul_div.reciprocal 输入输出都为int8的情况下,为什么不起作用,输出还是为int16 并自动添加hbir_cast_type层



