专栏算法工具链Onnx中ScatterND的产生与去除

Onnx中ScatterND的产生与去除

Jade-self2024-08-20
250
0
前言:使用地平线工具链进行模型部署,当ONNX模型中存在ScatterND算子时,使用hb_compile编译模型后,会发现ScatterNd运行在CPU上,查看算子支持与约束列表,目前确实不支持ScatterND算子。
Description
疑问:ScatterND是必须的吗?看自己python代码里没有用ScatterND,为什么导出成onnx时出现了呢?能否去掉ONNX中的ScatterND呢?可以去掉的话,应该怎么去掉呢?

1. ScatterND简介

1.1 理论

ScatterND 是一个常见的操作符,用于更新一个张量的部分值,将某些值散布到一个更大的张量中。

ScatterND 操作从一个输入张量 data 中选取值,然后根据给定的索引 indices 将这些值写入到一个输出张量 output 的指定位置。简单来说,它将数据分散(scatter)到一个预定义的位置。常用于更新稀疏矩阵中的部分值、图像处理中使用索引更新像素或特征图中的特定区域等。

1.2 示例

  1. 输入张量 data:包含要分散的值。

  2. 索引张量 indices:指定要更新的目标位置。indices 的每一行表示一个多维索引,指出 output 中要更新的具体位置。

  3. 更新 output:ScatterND 将 data 中的值按照 indices 中指定的位置更新 output。output 的初始状态通常是全零或其他初始值。

假设:

  • output 是要更新的张量,初始值为 zeros(shape) 或其他默认值。

  • indices 是一个 k x n 形状的张量,k 表示索引数目,n 表示每个索引的维度。

  • data 是一个形状为 k x ... 的张量,其中 ... 表示 data 的形状可以有更多的维度。
    ScatterND 的操作是:

1.3 代码

输出结果:

  • output 开始是一个 3x3 的零矩阵。

  • indices 指定了 data 中的每个值应写入 output 中的位置。

  • data 的值 [1, 2, 3] 被分别写入 output 中的 [0, 1], [1, 2], [2, 0] 位置。

介绍完ScatterND是干什么的,下面看一些python代码中没写ScatterND,但导出的onnx中有ScatterND的场景,以及如何修改代码。

2. 实操场景1

会产出ScatterND onnx代码可见:

对应的onnx可见:

Description

可以发现,onnx中是有ScatterND的,造成ScatterND的代码是哪儿呢?很明显,就一行代码:x[:, :2] = x[:, :2].sigmoid(),slice 之后 inplace 进行运算,导致ScatterND的产生。想去掉ScatterND,等价替换这样的操作即可,参考代码如下。

此时,可视化结果如下:
Description

3. 实操场景2

会产出ScatterND onnx代码可见:

对应的onnx可见:
Description

可以发现,onnx中是有ScatterND的,造成ScatterND的代码是哪儿呢?很明显,就两行代码:

也是slice 之后 inplace 进行运算,导致ScatterND的产生。想去掉ScatterND,等价替换这样的操作即可,参考代码如下。

此时,可视化结果如下:
Description

4. 实操场景3

如下代码会引入scatterND

修改与验证代码如下:

方案1与方案2思想一致:
Description

5. 总结

从上述介绍看,大部分场景,onnx中ScatterND 是由 slice 之后 inplace 进行运算导致,想要不产生ScatterND,等价替换相关操作即可。

算法工具链
社区征文杂谈技术深度解析
+4
评论0
0/1000