* Use TensorScatterAdd to correctly handle repeated indices * Support Complex typed operands Also, enabled complex typed operands for ZerosLikeOp. PiperOrigin-RevId: 338347731 Change-Id: Ieade5166c3d8e234bda2f090bc636dc6c98931b1
227 lines
8.4 KiB
Python
227 lines
8.4 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for tensorflow.ops.tf.scatter_nd."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import functools
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.compiler.tests import xla_test
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.platform import test
|
|
|
|
|
|
def _AsType(v, vtype):
|
|
return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)
|
|
|
|
|
|
def _FlatInnerDims(tensor, ndims=2):
|
|
shape = list(tensor.shape)
|
|
return tensor.reshape(
|
|
[functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)] +
|
|
shape[-ndims + 1:])
|
|
|
|
|
|
def _FlatOuterDims(tensor, ndims=2):
|
|
shape = list(tensor.shape)
|
|
return tensor.reshape(
|
|
shape[:ndims - 1] +
|
|
[functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)])
|
|
|
|
|
|
def _NumpyScatterNd(ref, indices, updates, op):
  """Reference NumPy implementation of an N-d scatter.

  Each index vector along the last axis of `indices` (of length
  `ixdim = indices.shape[-1]`) addresses one slice `ref[idx]` of shape
  `ref.shape[ixdim:]`; that slice of the result becomes
  `op(current_slice, update_slice)`. Updates are applied in order, so when an
  index repeats, later updates see the result of earlier ones.

  Args:
    ref: NumPy array giving the initial values and the output shape. It is
      not modified.
    indices: NumPy array of index vectors; its last dimension has length
      `ixdim`.
    updates: NumPy array holding one slice of shape `ref.shape[ixdim:]` per
      index vector.
    op: binary callable `op(old_slice, update_slice)` returning the new slice.

  Returns:
    A new array with the shape of `ref` holding the scattered result.
  """
  ixdim = indices.shape[-1]
  num_updates = indices.size // ixdim
  # Number of elements addressed by a single index vector.
  slice_size = 1
  for dim in ref.shape[ixdim:]:
    slice_size *= dim
  flat_indices = _FlatInnerDims(indices)
  flat_updates = updates.reshape((num_updates, slice_size))
  # Scatter into a copy: `reshape` returns a view whenever it can, so writing
  # into the reshaped array would otherwise modify the caller's `ref`.
  output_flat = _FlatOuterDims(ref.copy(), ixdim + 1)
  for ix_updates, ix_output in enumerate(flat_indices):
    ix_output = tuple(ix_output)
    output_flat[ix_output] = op(output_flat[ix_output],
                                flat_updates[ix_updates])
  return output_flat.reshape(ref.shape)
|
|
|
|
|
|
def _NumpyUpdate(indices, updates, shape):
  """NumPy reference that scatters `updates` into a zero array of `shape`.

  The zero array takes the dtype of `updates`; every slice addressed by
  `indices` is overwritten with its matching update. Because each write
  replaces the previous value, a repeated index keeps only its last update,
  so this only matches an accumulating scatter when indices are unique.
  """
  def _replace(unused_old, new):
    return new

  zeros = np.zeros(shape, dtype=updates.dtype)
  return _NumpyScatterNd(zeros, indices, updates, _replace)
|
|
|
|
|
|
class ScatterNdTest(xla_test.XLATestCase):
  """Tests for `array_ops.scatter_nd` run under the XLA test scope."""

  def _VariableRankTest(self,
                        np_scatter,
                        tf_scatter,
                        vtype,
                        itype,
                        repeat_indices=False):
    """Checks `tf_scatter` against `np_scatter` for several ranks.

    For each (ref shape, indices shape) pair below, random distinct index
    vectors and random updates are generated (with a fixed seed) and the two
    implementations' outputs are compared with `assertAllClose`.

    Args:
      np_scatter: NumPy reference, called as
        `np_scatter(indices, updates, ref_shape)`.
      tf_scatter: implementation under test, called with the same arguments.
      vtype: NumPy dtype of the updates.
      itype: NumPy integer dtype of the indices.
      repeat_indices: if True (and more than one update is generated), only
        the first half of the index vectors is kept and the rest are filled
        with random copies of them, so some indices repeat.
    """
    np.random.seed(8)
    ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
    # Only the first and last entries of each indices shape are used: they
    # give the number of updates and the index depth respectively.
    indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
    for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
      num_updates = indices_shape[0]
      ixdim = indices_shape[-1]

      # Enumerate every coordinate of the indexable prefix of `ref_shape` in
      # row-major order, then pick `num_updates` distinct ones at random.
      indexable_area_shape = tuple(ref_shape[:ixdim])
      all_indices = [list(coord) for coord in np.ndindex(*indexable_area_shape)]
      np.random.shuffle(all_indices)
      indices = np.array(all_indices[:num_updates])

      if num_updates > 1 and repeat_indices:
        indices = indices[:num_updates // 2]
        for _ in range(num_updates - num_updates // 2):
          indices = np.append(
              indices, [indices[np.random.randint(num_updates // 2)]], axis=0)
        np.random.shuffle(indices)
      indices = _AsType(indices[:num_updates], itype)

      # Each update covers the dimensions of `ref_shape` that an index vector
      # does not address.
      updates_shape = (num_updates,) + tuple(ref_shape[ixdim:])
      updates = _AsType(np.random.randn(*updates_shape), vtype)

      # Scatter via numpy
      np_out = np_scatter(indices, updates, ref_shape)
      # Scatter via tensorflow
      tf_out = tf_scatter(indices, updates, ref_shape)

      self.assertAllClose(np_out, tf_out)

  def _VariableRankTests(self, np_scatter, tf_scatter):
    """Runs `_VariableRankTest` for each numeric type and int32/int64 index."""
    for vtype in self.numeric_types:
      for itype in {np.int32, np.int64} & set(self.int_types):
        self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)

  def _runScatterNd(self, indices, updates, shape):
    """Evaluates `scatter_nd(indices, updates, shape)` in the test scope.

    `indices` and `updates` are fed through placeholders.

    Returns:
      The scatter result as a NumPy array.
    """
    with self.session():
      updates_placeholder = array_ops.placeholder(updates.dtype)
      indices_placeholder = array_ops.placeholder(indices.dtype)
      with self.test_scope():
        output = array_ops.scatter_nd(indices_placeholder, updates_placeholder,
                                      shape)
      feed_dict = {updates_placeholder: updates, indices_placeholder: indices}
      return output.eval(feed_dict=feed_dict)

  def testSimple(self):
    indices = np.array([[4], [3], [1], [7]], dtype=np.int32)
    updates = np.array([9, 10, 11, 12], dtype=np.float32)
    # Expected values share the float32 dtype of `updates`.
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.float32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8]))

  def testRepeatedIndices(self):
    """Updates that share an index are summed (9 + 11 and 10 + 12)."""
    indices = np.array([[0], [1], [0], [1]], dtype=np.int32)
    updates = np.array([9, 10, 11, 12], dtype=np.float32)
    expected = np.array([20, 22], dtype=np.float32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2]))

  def testSimple2(self):
    indices = np.array([[1, 0], [1, 1]], dtype=np.int32)
    updates = np.array([11., 12.], dtype=np.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2]))

  def testSimple3(self):
    """A single index of depth 1 scatters a whole row."""
    indices = np.array([[1]], dtype=np.int32)
    updates = np.array([[11., 12.]], dtype=np.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2]))

  def testVariableRankUpdate(self):
    self._VariableRankTests(_NumpyUpdate, self._runScatterNd)

  def testExtraIndicesDimensions(self):
    """Indices with extra leading dimensions ([1, 1, 2]) are accepted."""
    indices = np.zeros([1, 1, 2], np.int32)
    updates = np.zeros([1, 1], np.int32)
    expected = np.zeros([2, 2], dtype=np.int32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))

  @test_util.disable_mlir_bridge("Error messages differ")
  def testRank3InvalidShape1(self):
    """A mismatched `updates` shape raises InvalidArgumentError."""
    indices = np.zeros([3, 2, 2], np.int32)
    updates = np.zeros([2, 2, 2], np.int32)
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             "Must have updates.shape"):
      self._runScatterNd(indices, updates, [2, 2, 2])

  @test_util.disable_mlir_bridge("Error messages differ")
  def testRank3InvalidShape2(self):
    """A mismatched `updates` shape raises InvalidArgumentError."""
    indices = np.zeros([2, 2, 1], np.int32)
    updates = np.zeros([2, 2], np.int32)
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             "Must have updates.shape"):
      self._runScatterNd(indices, updates, [2, 2, 2])

  def testScatterOutOfRange(self):
    updates = np.array([-3, -4, -5]).astype(np.float32)

    # Indices all in range, no problem.
    indices = np.array([[2], [0], [5]], dtype=np.int32)
    self._runScatterNd(indices, updates, [6])

    # Indices out of range should not fail. It produces implementation-defined
    # output.
    indices = np.array([[-1], [0], [5]], dtype=np.int32)
    self._runScatterNd(indices, updates, [6])
    indices = np.array([[2], [0], [6]], dtype=np.int32)
    self._runScatterNd(indices, updates, [6])
|
|
|
|
|
|
class ScatterNdTensorTest(xla_test.XLATestCase):
  """Tests `tensor_scatter_{add,sub,update}` run under the XLA test scope."""

  def _runScatter(self, op):
    """Runs `op(ones([8]), indices, updates)` and returns the NumPy result.

    The float32 base tensor is all ones; indices [[4], [3], [1], [7]] receive
    the float32 updates [9, 10, 11, 12]. Indices and updates are fed through
    shape-annotated placeholders.
    """
    index_values = np.array([[4], [3], [1], [7]], dtype=np.int32)
    update_values = np.array([9, 10, 11, 12], dtype=np.float32)
    with self.session() as sess, self.test_scope():
      index_ph = array_ops.placeholder(
          index_values.dtype, shape=index_values.shape)
      update_ph = array_ops.placeholder(
          update_values.dtype, shape=update_values.shape)
      base = array_ops.ones([8], dtype=np.float32)
      result = op(base, index_ph, update_ph)
      feeds = {index_ph: index_values, update_ph: update_values}
      return sess.run(result, feed_dict=feeds)

  def _assertScatterGives(self, op, expected):
    """Asserts that `_runScatter(op)` equals `expected` as a float32 array."""
    self.assertAllEqual(self._runScatter(op),
                        np.array(expected, dtype=np.float32))

  def testAdd(self):
    # Updated slots hold 1 + update; untouched slots stay 1.
    self._assertScatterGives(array_ops.tensor_scatter_add,
                             [1, 12, 1, 11, 10, 1, 1, 13])

  def testSub(self):
    # Updated slots hold 1 - update; untouched slots stay 1.
    self._assertScatterGives(array_ops.tensor_scatter_sub,
                             [1, -10, 1, -9, -8, 1, 1, -11])

  def testUpdate(self):
    # Updated slots are overwritten by the update value.
    self._assertScatterGives(array_ops.tensor_scatter_update,
                             [1, 11, 1, 10, 9, 1, 1, 12])
|
|
|
|
|
|
if __name__ == "__main__":
  # Entry point: run the test cases in this module through TensorFlow's test
  # runner (`tensorflow.python.platform.test`).
  test.main()
|