Updating documentation for tf.scatter_nd_update to reflect the actual implementation, which is in case of duplicates we pick the first one. Also added tests for both CPU and GPU to assert this behavior.
PiperOrigin-RevId: 334021744 Change-Id: Ib045d04a9631b6244327d0acac6ad44ed4b5c65e
This commit is contained in:
parent
3b382e8427
commit
7679280cbd
@ -33,12 +33,14 @@ This operation is very similar to `tf.scatter_nd`, except that the updates are
|
||||
scattered onto an existing tensor (as opposed to a zero-tensor). If the memory
|
||||
for the existing tensor cannot be re-used, a copy is made and updated.
|
||||
|
||||
If `indices` contains duplicates, then their updates are accumulated (summed).
|
||||
If `indices` contains duplicates, then we pick the last update for the index.
|
||||
|
||||
**WARNING**: The order in which updates are applied is nondeterministic, so the
|
||||
output will be nondeterministic if `indices` contains duplicates -- because
|
||||
of some numerical approximation issues, numbers summed in different order
|
||||
may yield different results.
|
||||
If an out of bound index is found, an error is returned.
|
||||
|
||||
**WARNING**: There are some GPU specific semantics for this operation.
|
||||
- If an out of bound index is found, the index is ignored.
|
||||
- The order in which updates are applied is nondeterministic, so the output
|
||||
will be nondeterministic if `indices` contains duplicates.
|
||||
|
||||
`indices` is an integer tensor containing indices into a new tensor of shape
|
||||
`shape`. The last dimension of `indices` can be at most the rank of `shape`:
|
||||
@ -98,7 +100,5 @@ In Python, this scatter operation would look like this:
|
||||
[1 1 1 1]
|
||||
[1 1 1 1]]]
|
||||
|
||||
Note that on CPU, if an out of bound index is found, an error is returned.
|
||||
On GPU, if an out of bound index is found, the index is ignored.
|
||||
END
|
||||
}
|
||||
|
@ -76,6 +76,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:gradients",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:resource_variable_ops",
|
||||
"//tensorflow/python:state_ops",
|
||||
"//tensorflow/python:variables",
|
||||
|
@ -32,6 +32,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker_v2
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
@ -909,6 +910,32 @@ class ScatterNdTensorTest(test.TestCase):
|
||||
"hello", "there", "hello", "there", "there", "hello", "hello", "12"
|
||||
]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testUpdateRepeatedIndices1D(self):
|
||||
if test_util.is_gpu_available():
|
||||
self.skipTest("Duplicate indices scatter is non-deterministic on GPU")
|
||||
a = array_ops.zeros([10, 1])
|
||||
b = array_ops.tensor_scatter_update(a, [[5], [5]], [[4], [8]])
|
||||
self.assertAllEqual(
|
||||
b,
|
||||
constant_op.constant([[0.], [0.], [0.], [0.], [0.], [8.], [0.], [0.],
|
||||
[0.], [0.]]))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testUpdateRepeatedIndices2D(self):
|
||||
if test_util.is_gpu_available():
|
||||
self.skipTest("Duplicate indices scatter is non-deterministic on GPU")
|
||||
a = array_ops.zeros([10, 10])
|
||||
b = array_ops.tensor_scatter_update(
|
||||
a, [[5], [6], [6]],
|
||||
[math_ops.range(10),
|
||||
math_ops.range(11, 21),
|
||||
math_ops.range(10, 20)])
|
||||
self.assertAllEqual(
|
||||
b[6],
|
||||
constant_op.constant(
|
||||
[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user