Fixing run_v1_decorator for gather_op_test.py. Also moving it to an array_ops folder.

PiperOrigin-RevId: 324302642
Change-Id: I85d54f334537a25a4c7f6d6eaeb17721cce39e25
This commit is contained in:
Rohan Jain 2020-07-31 16:05:20 -07:00 committed by TensorFlower Gardener
parent cea9f19ebf
commit 5594fde93b
4 changed files with 156 additions and 103 deletions

View File

@ -78,10 +78,11 @@ class GatherOp : public OpKernel {
}
}
int64 min_params_dim = axis < 0 ? -axis : axis + 1;
OP_REQUIRES(
c, axis >= -params.dims() && axis < params.dims(),
errors::InvalidArgument("Expected axis in the range [", -params.dims(),
", ", params.dims(), "), but got ", axis));
c, params.dims() >= min_params_dim,
errors::InvalidArgument("Shape must be at least rank ", min_params_dim,
" but is rank ", params.dims()));
if (axis < 0) {
axis = params.dims() + axis;

View File

@ -2035,20 +2035,6 @@ cuda_py_test(
],
)
cuda_py_test(
name = "gather_op_test",
size = "medium",
srcs = ["gather_op_test.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
cuda_py_test(
name = "gradient_correctness_test",
size = "small",

View File

@ -46,3 +46,17 @@ cuda_py_test(
"//third_party/py/numpy",
],
)
cuda_py_test(
name = "gather_op_test",
size = "medium",
srcs = ["gather_op_test.py"],
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -107,18 +107,20 @@ class GatherTest(test.TestCase, parameterized.TestCase):
expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
self.assertEqual(expected_shape, gather_t.get_shape())
@test_util.run_deprecated_v1
def testHigherRank(self):
# We check that scalar and empty indices shapes work as well
shape = (2, 1, 3, 2)
for indices_shape in (), (0,), (2, 0), (2, 3):
for dtype in _TEST_TYPES:
for axis in range(len(shape)):
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)
with self.subTest(indices_shape=indices_shape, dtype=dtype, axis=axis,
indices=indices):
with self.cached_session(use_gpu=True) as sess:
with ops.Graph().as_default():
# We check that scalar and empty indices shapes work as well
shape = (2, 1, 3, 2)
for indices_shape in (), (0,), (2, 0), (2, 3):
for dtype in _TEST_TYPES:
for axis in range(len(shape)):
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)
with self.subTest(
indices_shape=indices_shape,
dtype=dtype,
axis=axis,
indices=indices):
tf_params = constant_op.constant(params)
tf_indices = constant_op.constant(indices)
# Check that both positive and negative indices for axis work.
@ -127,7 +129,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
gather_negative_axis = array_ops.gather(
tf_params, tf_indices, axis=tf_negative_axis)
gather_value, gather_negative_axis_value = sess.run(
gather_value, gather_negative_axis_value = self.evaluate(
[gather, gather_negative_axis])
gather_np = np.take(params, indices, axis)
self.assertAllEqual(gather_np, gather_value)
@ -144,10 +146,10 @@ class GatherTest(test.TestCase, parameterized.TestCase):
gather_grad -= 1j * gather_grad
params_grad, indices_grad, axis_grad = gradients_impl.gradients(
gather, [tf_params, tf_indices, tf_axis], gather_grad)
self.assertEqual(indices_grad, None)
self.assertEqual(axis_grad, None)
self.assertIsNone(indices_grad)
self.assertIsNone(axis_grad)
if dtype.is_integer:
self.assertEqual(params_grad, None)
self.assertIsNone(params_grad)
continue
# For axis 0, we are able to create an efficient IndexedSlices for
# the gradient.
@ -171,47 +173,113 @@ class GatherTest(test.TestCase, parameterized.TestCase):
atol=2e-6,
rtol=2e-6)
@test_util.run_deprecated_v1
def testHigherRankGradientTape(self):
# We check that scalar and empty indices shapes work as well
shape = (2, 1, 3, 2)
for indices_shape in (), (0,), (2, 0), (2, 3):
for dtype in _TEST_TYPES:
for axis in range(len(shape)):
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)
with self.subTest(
indices_shape=indices_shape,
dtype=dtype,
axis=axis,
indices=indices):
with backprop.GradientTape() as tape:
tf_params = constant_op.constant(params)
tf_indices = constant_op.constant(indices)
# Check that both positive and negative indices for axis work.
tf_axis = constant_op.constant(axis)
tape.watch(tf_params)
tape.watch(tf_indices)
tape.watch(tf_axis)
tf_negative_axis = constant_op.constant(-len(shape) + axis)
gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
gather_negative_axis = array_ops.gather(
tf_params, tf_indices, axis=tf_negative_axis)
gather_value, gather_negative_axis_value = self.evaluate(
[gather, gather_negative_axis])
gather_np = np.take(params, indices, axis)
self.assertAllEqual(gather_np, gather_value)
self.assertAllEqual(gather_np, gather_negative_axis_value)
expected_shape = (
params.shape[:axis] + indices.shape + params.shape[axis + 1:])
self.assertEqual(expected_shape, gather.shape)
self.assertEqual(expected_shape, gather_negative_axis.shape)
# Test gradients
gather_grad = np.random.randn(
*gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
if dtype.is_complex:
gather_grad -= 1j * gather_grad
params_grad, indices_grad, axis_grad = tape.gradient(
gather, [tf_params, tf_indices, tf_axis], gather_grad)
self.assertIsNone(indices_grad)
self.assertIsNone(axis_grad)
if dtype.is_integer:
self.assertIsNone(params_grad)
continue
# For axis 0, we are able to create an efficient IndexedSlices for
# the gradient.
if axis == 0:
self.assertEqual(type(params_grad), ops.IndexedSlices)
params_grad = ops.convert_to_tensor(params_grad)
correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
outer_dims = axis
inner_dims = len(shape) - axis - 1
gather_grad = gather_grad.reshape(shape[:axis] + (indices.size,) +
shape[axis + 1:])
for source_index, dest_index in enumerate(indices.flat):
dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
(slice(None),) * inner_dims)
source_slice = ((slice(None),) * outer_dims + (source_index,) +
(slice(None),) * inner_dims)
correct_params_grad[dest_slice] += gather_grad[source_slice]
self.assertAllClose(
correct_params_grad,
self.evaluate(params_grad),
atol=2e-6,
rtol=2e-6)
def testString(self):
params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
with self.cached_session():
self.assertAllEqual([b"qwer", b"uiop"],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([b"asdf", b"qwer"],
array_ops.gather(params, 0, axis=1).eval())
self.assertAllEqual([b"qwer", b"uiop"], array_ops.gather(params, 1, axis=0))
self.assertAllEqual([b"asdf", b"qwer"], array_ops.gather(params, 0, axis=1))
@test_util.run_deprecated_v1
def testUInt32AndUInt64(self):
for unsigned_type in (dtypes.uint32, dtypes.uint64):
with self.subTest(unsigned_type=unsigned_type):
params = self._buildParams(
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
with self.cached_session():
self.assertAllEqual([7, 8, 9],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([1, 7],
array_ops.gather(params, 0, axis=1).eval())
self.assertAllEqual([7, 8, 9], array_ops.gather(params, 1, axis=0))
self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1))
@test_util.run_deprecated_v1
def testUnknownIndices(self):
params = constant_op.constant([[0, 1, 2]])
indices = array_ops.placeholder(dtypes.int32)
gather_t = array_ops.gather(params, indices)
self.assertEqual(None, gather_t.get_shape())
# This test is purely a test for placeholder inputs which is only applicable
# in graph mode.
with ops.Graph().as_default():
params = constant_op.constant([[0, 1, 2]])
indices = array_ops.placeholder(dtypes.int32)
gather_t = array_ops.gather(params, indices)
self.assertEqual(None, gather_t.get_shape())
@test_util.run_deprecated_v1
def testUnknownAxis(self):
params = constant_op.constant([[0, 1, 2]])
indices = constant_op.constant([[0, 0], [0, 0]])
axis = array_ops.placeholder(dtypes.int32)
gather_t = array_ops.gather(params, indices, axis=axis)
# Rank 2 params with rank 2 indices results in a rank 3 shape.
self.assertEqual([None, None, None], gather_t.shape.as_list())
# This test is purely a test for placeholder inputs which is only applicable
# in graph mode.
with ops.Graph().as_default():
params = constant_op.constant([[0, 1, 2]])
indices = constant_op.constant([[0, 0], [0, 0]])
axis = array_ops.placeholder(dtypes.int32)
gather_t = array_ops.gather(params, indices, axis=axis)
# Rank 2 params with rank 2 indices results in a rank 3 shape.
self.assertEqual([None, None, None], gather_t.shape.as_list())
# If indices is also unknown the result rank is unknown.
indices = array_ops.placeholder(dtypes.int32)
gather_t = array_ops.gather(params, indices, axis=axis)
self.assertEqual(None, gather_t.shape)
# If indices is also unknown the result rank is unknown.
indices = array_ops.placeholder(dtypes.int32)
gather_t = array_ops.gather(params, indices, axis=axis)
self.assertEqual(None, gather_t.shape)
def testBadIndicesType(self):
with self.assertRaisesRegex(
@ -243,45 +311,36 @@ class GatherTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
array_ops.gather(params, [[7]], axis=1).eval()
@test_util.run_deprecated_v1
def testBadAxis(self):
with self.session(use_gpu=True):
params = [0, 1, 2]
params_ph = array_ops.placeholder(dtypes.int32)
indices = 0
for bad_axis in (1, 2, -2):
# Shape inference can validate axis for known params rank.
with self.subTest(bad_axis=bad_axis):
with self.assertRaisesWithPredicateMatch(
ValueError, "Shape must be at least rank . but is rank 1"):
array_ops.gather(params, indices, axis=bad_axis)
# If params rank is unknown, an op error occurs.
with self.assertRaisesOpError(
r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis):
array_ops.gather(params_ph, indices, axis=bad_axis).eval(
feed_dict={params_ph: params})
params = [0, 1, 2]
indices = 0
for bad_axis in (1, 2, -2):
# Shape inference can validate axis for known params rank.
with self.subTest(bad_axis=bad_axis):
with self.assertRaisesRegex(
(ValueError, errors.InvalidArgumentError),
"Shape must be at least rank .* but is rank 1"):
array_ops.gather(params, indices, axis=bad_axis)
@test_util.run_deprecated_v1
def testEmptySlices(self):
with self.session(use_gpu=True):
for dtype in _TEST_TYPES:
for itype in np.int32, np.int64:
# Leading axis gather.
with self.subTest(dtype=dtype, itype=itype):
params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
indices = np.array([3, 4], dtype=itype)
gather = array_ops.gather(params, indices, axis=0)
self.assertAllEqual(gather, np.zeros((2, 0, 0)))
for dtype in _TEST_TYPES:
for itype in np.int32, np.int64:
# Leading axis gather.
with self.subTest(dtype=dtype, itype=itype):
params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
indices = np.array([3, 4], dtype=itype)
gather = array_ops.gather(params, indices, axis=0)
self.assertAllEqual(gather, np.zeros((2, 0, 0)))
# Middle axis gather.
params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
gather = array_ops.gather(params, indices, axis=1)
self.assertAllEqual(gather, np.zeros((0, 2, 0)))
# Middle axis gather.
params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
gather = array_ops.gather(params, indices, axis=1)
self.assertAllEqual(gather, np.zeros((0, 2, 0)))
# Trailing axis gather.
params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
gather = array_ops.gather(params, indices, axis=2)
self.assertAllEqual(gather, np.zeros((0, 0, 2)))
# Trailing axis gather.
params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
gather = array_ops.gather(params, indices, axis=2)
self.assertAllEqual(gather, np.zeros((0, 0, 2)))
@parameterized.parameters([
# batch_dims=0 (equivalent to tf.gather)
@ -385,20 +444,13 @@ class GatherTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(expected, result)
# Test the gradients shape.
if context.executing_eagerly():
with backprop.GradientTape() as tape:
zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
tape.watch(zeros)
values = zeros * 2 + zeros
result = array_ops.gather(
values, indices, axis=axis, batch_dims=batch_dims)
gradients = tape.gradient(result, zeros)
else:
with backprop.GradientTape() as tape:
zeros = array_ops.zeros_like(params, dtype=dtypes.float32)
tape.watch(zeros)
values = zeros * 2 + zeros
result = array_ops.gather(
values, indices, axis=axis, batch_dims=batch_dims)
gradients = gradients_impl.gradients(result, [zeros])[0]
gradients = tape.gradient(result, zeros)
self.assertAllEqual(array_ops.shape(params), array_ops.shape(gradients))