
1. ScatterND简介
1.1 理论
ScatterND 是一个常见的操作符,用于更新一个张量的部分值,将某些值散布到一个更大的张量中。
ScatterND 操作从一个输入张量 data 中选取值,然后根据给定的索引 indices 将这些值写入到一个输出张量 output 的指定位置。简单来说,它将数据分散(scatter)到一个预定义的位置。常用于更新稀疏矩阵中的部分值、图像处理中使用索引更新像素或特征图中的特定区域等。
1.2 示例
输入张量 data:包含要分散的值。
索引张量 indices:指定要更新的目标位置。indices 的每一行表示一个多维索引,指出 output 中要更新的具体位置。
更新 output:ScatterND 将 data 中的值按照 indices 中指定的位置更新 output。output 的初始状态通常是全零或其他初始值。
假设:
output 是要更新的张量,初始值为 zeros(shape) 或其他默认值。
indices 是一个 k x n 形状的张量,k 表示索引数目,n 表示每个索引的维度。
- data 是一个形状为 k x ... 的张量,其中 ... 表示 data 的形状可以有更多的维度。
ScatterND 的操作是:
1.3 代码
输出结果:
output 开始是一个 3x3 的零矩阵。
indices 指定了 data 中的每个值应写入 output 中的位置。
data 的值 [1, 2, 3] 被分别写入 output 中的 [0, 1], [1, 2], [2, 0] 位置。
介绍完ScatterND是干什么的,下面看一些python代码中没写ScatterND,但导出的onnx中有ScatterND的场景,以及如何修改代码。
2. 实操场景1
会产出ScatterND onnx代码可见:
对应的onnx可见:

可以发现,onnx中是有ScatterND的,造成ScatterND的代码是哪儿呢?很明显,就一行代码:x[:, :2] = x[:, :2].sigmoid(),slice 之后 inplace 进行运算,导致ScatterND的产生。想去掉ScatterND,等价替换这样的操作即可,参考代码如下。

3. 实操场景2
会产出ScatterND onnx代码可见:

可以发现,onnx中是有ScatterND的,造成ScatterND的代码是哪儿呢?很明显,就两行代码:
也是slice 之后 inplace 进行运算,导致ScatterND的产生。想去掉ScatterND,等价替换这样的操作即可,参考代码如下。

4. 实操场景3
如下代码会引入scatterND
修改与验证代码如下:

5. 总结
从上述介绍看,大部分场景,onnx中ScatterND 是由 slice 之后 inplace 进行运算导致,想要不产生ScatterND,等价替换相关操作即可。


