【J5】
量化后使用perf_model发现每一个torch.cat算子都会附加产生大量的hz_cat_1_rescale_1这样的算子,影响模型性能,请问应该如何规避呢?torch.cat的多个输入使用的都是相同的qconfig配置(均为默认default_qat_8bit_fake_quant_qconfig)
【J5】

请问要如何实现呢,我这样修改后,最后转hbm阶段执行torch.jit.trace会报错:AssertionError: input scale must be the same as op’s:
class xxxx(nn.Module):
super().__init__()
self.quant_k = QuantStub() # for cat scale
self.quant_v = QuantStub() # for cat scale
def forward(self):