Updating documentation for tf.scatter_nd_update to reflect the actual implementation: in case of duplicate indices, the last update wins. Also added tests for both CPU and GPU to assert this behavior.

PiperOrigin-RevId: 334021744
Change-Id: Ib045d04a9631b6244327d0acac6ad44ed4b5c65e
Rohan Jain 2020-09-27 09:07:06 -07:00 committed by TensorFlower Gardener
parent 3b382e8427
commit 7679280cbd
3 changed files with 35 additions and 7 deletions


@@ -33,12 +33,14 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
 scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
 for the existing tensor cannot be re-used, a copy is made and updated.
 
-If `indices` contains duplicates, then their updates are accumulated (summed).
+If `indices` contains duplicates, then we pick the last update for the index.
 
-**WARNING**: The order in which updates are applied is nondeterministic, so the
-output will be nondeterministic if `indices` contains duplicates -- because
-of some numerical approximation issues, numbers summed in different order
-may yield different results.
+If an out of bound index is found, an error is returned.
+
+**WARNING**: There are some GPU specific semantics for this operation.
+- If an out of bound index is found, the index is ignored.
+- The order in which updates are applied is nondeterministic, so the output
+will be nondeterministic if `indices` contains duplicates.
 
 `indices` is an integer tensor containing indices into a new tensor of shape
 `shape`. The last dimension of `indices` can be at most the rank of `shape`:
@@ -98,7 +100,5 @@ In Python, this scatter operation would look like this:
     [1 1 1 1]
     [1 1 1 1]]]
 
-Note that on CPU, if an out of bound index is found, an error is returned.
-On GPU, if an out of bound index is found, the index is ignored.
 END
 }
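The duplicate-index semantics documented above can be sketched in pure NumPy, without TensorFlow installed. The helper name `tensor_scatter_update_sketch` is hypothetical (it is not TensorFlow's implementation); it models only the CPU behavior, where updates are applied in order so the last duplicate wins:

```python
import numpy as np

def tensor_scatter_update_sketch(tensor, indices, updates):
    """Hypothetical sketch of tf.tensor_scatter_nd_update CPU semantics:
    updates are applied in order, so the last duplicate index wins."""
    out = np.array(tensor, dtype=float)
    for idx, upd in zip(indices, updates):
        out[tuple(idx)] = upd  # a later duplicate simply overwrites
    return out

a = np.zeros((10, 1))
# Two updates target index 5; the last one, [8.], wins (mirrors the new 1D test).
b = tensor_scatter_update_sketch(a, [[5], [5]], [[4.0], [8.0]])
print(b[5, 0])  # -> 8.0
```

On GPU the order of application is nondeterministic, so this sketch only matches the CPU result, which is why the new tests skip when a GPU is available.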


@@ -76,6 +76,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:gradients",
+        "//tensorflow/python:math_ops",
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:state_ops",
         "//tensorflow/python:variables",


@@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gradient_checker_v2
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
@@ -909,6 +910,32 @@ class ScatterNdTensorTest(test.TestCase):
             "hello", "there", "hello", "there", "there", "hello", "hello", "12"
         ]))
 
+  @test_util.run_in_graph_and_eager_modes
+  def testUpdateRepeatedIndices1D(self):
+    if test_util.is_gpu_available():
+      self.skipTest("Duplicate indices scatter is non-deterministic on GPU")
+    a = array_ops.zeros([10, 1])
+    b = array_ops.tensor_scatter_update(a, [[5], [5]], [[4], [8]])
+    self.assertAllEqual(
+        b,
+        constant_op.constant([[0.], [0.], [0.], [0.], [0.], [8.], [0.], [0.],
+                              [0.], [0.]]))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testUpdateRepeatedIndices2D(self):
+    if test_util.is_gpu_available():
+      self.skipTest("Duplicate indices scatter is non-deterministic on GPU")
+    a = array_ops.zeros([10, 10])
+    b = array_ops.tensor_scatter_update(
+        a, [[5], [6], [6]],
+        [math_ops.range(10),
+         math_ops.range(11, 21),
+         math_ops.range(10, 20)])
+    self.assertAllEqual(
+        b[6],
+        constant_op.constant(
+            [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]))
+
 if __name__ == "__main__":
   test.main()