专栏算法工具链如何去除onnx模型里的scatterND算子

如何去除onnx模型里的scatterND算子

kuku2024-08-16
222
0

目前工具链BPU不支持scatterND算子,这些算子放在cpu上会严重影响模型耗时,这篇文章会分享一些等效替换的示例,帮大家消除模型里的scatterND算子,极大提高模型性能。

ScatterND算子产生的原因

onnx里的ScatterND算子产生的原因是因为模型里存在slice之后的inplace操作导致的,需要修改代码,等价替换掉相关操作即可。

举个例子:

原始代码:

x[:, :2] = x[:, :2].softmax()

return x

修改后:

a = x[:, :2]

b = x[:, 2:]

a = a.softmax()

return torch.cat([a,b] , dim=1)

从上面的例子可以看出,去除scatterND算子的核心是需要保证被赋值的一边不包含slice的inplace的操作即可。

实战演示

上面的例子非常简单,但实际模型结构基本比较复杂,接下来会以sparse4d 模型为例,对模型里的scatterND算子进行等效替换。

在导出onnx时,打开 verbose=True,这样就能在onnx模型里定位某个算子具体对应哪一行代码。

anchor_projection模块

修改前

修改后:

SparseBox3DKeyPointsGenerator模块

修改前:

修改后:

SparseBox3DRefinementModule模块

修改前:

修改后:

InstanceBankAddTrack模块

修改前:

修改后:


性能测试

之前sparse4d模型里存在25个scatterND算子,这些算子都被放在了cpu上,严重影响性能。去除这些算子后,时延降低了20ms,很有效果,推荐大家试一下。

优化前:板端(单线程)耗时 33.237717 ms

优化后:板端(单线程)耗时 13.581193 ms


算法工具链
社区征文技术深度解析征程6杂谈
评论0
0/1000