芯片信号:j6
大佬们,请教一个问题,可能有一些不成熟的地方,请见谅。
如题,我在对一个transformer的模型进行qat的时候,发现时序那部分改起来颇为复杂,flaot的forward中有一个buffer_length为10的时序循环逻辑,目前我用 导出onnx的forward (不包含时序循环逻辑)成功的生成了calib模型,并且在板端也能够成功跑通,目前calibration之后的精度掉点在10以内,进而想进行qat,但是又绕不开时序部分,所以想向大佬们求助一下。
如题,我在对一个transformer的模型进行qat的时候,发现时序那部分改起来颇为复杂,flaot的forward中有一个buffer_length为10的时序循环逻辑,目前我用 导出onnx的forward (不包含时序循环逻辑)成功的生成了calib模型,并且在板端也能够成功跑通,目前calibration之后的精度掉点在10以内,进而想进行qat,但是又绕不开时序部分,所以想向大佬们求助一下。
直接在float模型的forward中加入quant和dequant节点,加载之前的pretrained pth参数,然后做finetuning,然后再将finetuning后的权重信息进行calib,是否能提高精度?
致以崇高的敬意


