Update gather_op_test and unique_op_test to use subTest for easier debugging.
PiperOrigin-RevId: 311595699 Change-Id: I1a8cf8b5b314aada4aeeece2603e975bc8a4ff42
This commit is contained in:
parent
66769844a5
commit
6db3caf99b
@ -62,14 +62,15 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
data = np.array([0, 1, 2, 3, 7, 5])
|
data = np.array([0, 1, 2, 3, 7, 5])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in _TEST_TYPES:
|
||||||
for indices in 4, [1, 2, 2, 4, 5]:
|
for indices in 4, [1, 2, 2, 4, 5]:
|
||||||
params_np = self._buildParams(data, dtype)
|
with self.subTest(dtype=dtype, indices=indices):
|
||||||
params = constant_op.constant(params_np)
|
params_np = self._buildParams(data, dtype)
|
||||||
indices_tf = constant_op.constant(indices)
|
params = constant_op.constant(params_np)
|
||||||
gather_t = array_ops.gather(params, indices_tf)
|
indices_tf = constant_op.constant(indices)
|
||||||
gather_val = self.evaluate(gather_t)
|
gather_t = array_ops.gather(params, indices_tf)
|
||||||
np_val = params_np[indices]
|
gather_val = self.evaluate(gather_t)
|
||||||
self.assertAllEqual(np_val, gather_val)
|
np_val = params_np[indices]
|
||||||
self.assertEqual(np_val.shape, gather_t.get_shape())
|
self.assertAllEqual(np_val, gather_val)
|
||||||
|
self.assertEqual(np_val.shape, gather_t.get_shape())
|
||||||
|
|
||||||
def testScalar2D(self):
|
def testScalar2D(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
@ -77,14 +78,15 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
[9, 10, 11], [12, 13, 14]])
|
[9, 10, 11], [12, 13, 14]])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in _TEST_TYPES:
|
||||||
for axis in range(data.ndim):
|
for axis in range(data.ndim):
|
||||||
params_np = self._buildParams(data, dtype)
|
with self.subTest(dtype=dtype, axis=axis):
|
||||||
params = constant_op.constant(params_np)
|
params_np = self._buildParams(data, dtype)
|
||||||
indices = constant_op.constant(2)
|
params = constant_op.constant(params_np)
|
||||||
gather_t = array_ops.gather(params, indices, axis=axis)
|
indices = constant_op.constant(2)
|
||||||
gather_val = self.evaluate(gather_t)
|
gather_t = array_ops.gather(params, indices, axis=axis)
|
||||||
self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
|
gather_val = self.evaluate(gather_t)
|
||||||
expected_shape = data.shape[:axis] + data.shape[axis + 1:]
|
self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
|
||||||
self.assertEqual(expected_shape, gather_t.get_shape())
|
expected_shape = data.shape[:axis] + data.shape[axis + 1:]
|
||||||
|
self.assertEqual(expected_shape, gather_t.get_shape())
|
||||||
|
|
||||||
def testSimpleTwoD32(self):
|
def testSimpleTwoD32(self):
|
||||||
with self.session(use_gpu=True):
|
with self.session(use_gpu=True):
|
||||||
@ -92,16 +94,17 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
[9, 10, 11], [12, 13, 14]])
|
[9, 10, 11], [12, 13, 14]])
|
||||||
for dtype in _TEST_TYPES:
|
for dtype in _TEST_TYPES:
|
||||||
for axis in range(data.ndim):
|
for axis in range(data.ndim):
|
||||||
params_np = self._buildParams(data, dtype)
|
with self.subTest(dtype=dtype, axis=axis):
|
||||||
params = constant_op.constant(params_np)
|
params_np = self._buildParams(data, dtype)
|
||||||
# The indices must be in bounds for any axis.
|
params = constant_op.constant(params_np)
|
||||||
indices = constant_op.constant([0, 1, 0, 2])
|
# The indices must be in bounds for any axis.
|
||||||
gather_t = array_ops.gather(params, indices, axis=axis)
|
indices = constant_op.constant([0, 1, 0, 2])
|
||||||
gather_val = self.evaluate(gather_t)
|
gather_t = array_ops.gather(params, indices, axis=axis)
|
||||||
self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
|
gather_val = self.evaluate(gather_t)
|
||||||
gather_val)
|
self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
|
||||||
expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
|
gather_val)
|
||||||
self.assertEqual(expected_shape, gather_t.get_shape())
|
expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
|
||||||
|
self.assertEqual(expected_shape, gather_t.get_shape())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testHigherRank(self):
|
def testHigherRank(self):
|
||||||
@ -112,58 +115,60 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
for axis in range(len(shape)):
|
for axis in range(len(shape)):
|
||||||
params = self._buildParams(np.random.randn(*shape), dtype)
|
params = self._buildParams(np.random.randn(*shape), dtype)
|
||||||
indices = np.random.randint(shape[axis], size=indices_shape)
|
indices = np.random.randint(shape[axis], size=indices_shape)
|
||||||
with self.cached_session(use_gpu=True) as sess:
|
with self.subTest(indices_shape=indices_shape, dtype=dtype, axis=axis,
|
||||||
tf_params = constant_op.constant(params)
|
indices=indices):
|
||||||
tf_indices = constant_op.constant(indices)
|
with self.cached_session(use_gpu=True) as sess:
|
||||||
# Check that both positive and negative indices for axis work.
|
tf_params = constant_op.constant(params)
|
||||||
tf_axis = constant_op.constant(axis)
|
tf_indices = constant_op.constant(indices)
|
||||||
tf_negative_axis = constant_op.constant(-len(shape) + axis)
|
# Check that both positive and negative indices for axis work.
|
||||||
gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
|
tf_axis = constant_op.constant(axis)
|
||||||
gather_negative_axis = array_ops.gather(
|
tf_negative_axis = constant_op.constant(-len(shape) + axis)
|
||||||
tf_params, tf_indices, axis=tf_negative_axis)
|
gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
|
||||||
gather_value, gather_negative_axis_value = sess.run(
|
gather_negative_axis = array_ops.gather(
|
||||||
[gather, gather_negative_axis])
|
tf_params, tf_indices, axis=tf_negative_axis)
|
||||||
gather_np = np.take(params, indices, axis)
|
gather_value, gather_negative_axis_value = sess.run(
|
||||||
self.assertAllEqual(gather_np, gather_value)
|
[gather, gather_negative_axis])
|
||||||
self.assertAllEqual(gather_np, gather_negative_axis_value)
|
gather_np = np.take(params, indices, axis)
|
||||||
expected_shape = (params.shape[:axis] + indices.shape +
|
self.assertAllEqual(gather_np, gather_value)
|
||||||
params.shape[axis + 1:])
|
self.assertAllEqual(gather_np, gather_negative_axis_value)
|
||||||
self.assertEqual(expected_shape, gather.shape)
|
expected_shape = (params.shape[:axis] + indices.shape +
|
||||||
self.assertEqual(expected_shape, gather_negative_axis.shape)
|
params.shape[axis + 1:])
|
||||||
|
self.assertEqual(expected_shape, gather.shape)
|
||||||
|
self.assertEqual(expected_shape, gather_negative_axis.shape)
|
||||||
|
|
||||||
# Test gradients
|
# Test gradients
|
||||||
gather_grad = np.random.randn(
|
gather_grad = np.random.randn(
|
||||||
*gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
|
*gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
|
||||||
if dtype.is_complex:
|
if dtype.is_complex:
|
||||||
gather_grad -= 1j * gather_grad
|
gather_grad -= 1j * gather_grad
|
||||||
params_grad, indices_grad, axis_grad = gradients_impl.gradients(
|
params_grad, indices_grad, axis_grad = gradients_impl.gradients(
|
||||||
gather, [tf_params, tf_indices, tf_axis], gather_grad)
|
gather, [tf_params, tf_indices, tf_axis], gather_grad)
|
||||||
self.assertEqual(indices_grad, None)
|
self.assertEqual(indices_grad, None)
|
||||||
self.assertEqual(axis_grad, None)
|
self.assertEqual(axis_grad, None)
|
||||||
if dtype.is_integer:
|
if dtype.is_integer:
|
||||||
self.assertEqual(params_grad, None)
|
self.assertEqual(params_grad, None)
|
||||||
continue
|
continue
|
||||||
# For axis 0, we are able to create an efficient IndexedSlices for
|
# For axis 0, we are able to create an efficient IndexedSlices for
|
||||||
# the gradient.
|
# the gradient.
|
||||||
if axis == 0:
|
if axis == 0:
|
||||||
self.assertEqual(type(params_grad), ops.IndexedSlices)
|
self.assertEqual(type(params_grad), ops.IndexedSlices)
|
||||||
params_grad = ops.convert_to_tensor(params_grad)
|
params_grad = ops.convert_to_tensor(params_grad)
|
||||||
correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
|
correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
|
||||||
outer_dims = axis
|
outer_dims = axis
|
||||||
inner_dims = len(shape) - axis - 1
|
inner_dims = len(shape) - axis - 1
|
||||||
gather_grad = gather_grad.reshape(
|
gather_grad = gather_grad.reshape(
|
||||||
shape[:axis] + (indices.size,) + shape[axis + 1:])
|
shape[:axis] + (indices.size,) + shape[axis + 1:])
|
||||||
for source_index, dest_index in enumerate(indices.flat):
|
for source_index, dest_index in enumerate(indices.flat):
|
||||||
dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
|
dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
|
||||||
(slice(None),) * inner_dims)
|
|
||||||
source_slice = ((slice(None),) * outer_dims + (source_index,) +
|
|
||||||
(slice(None),) * inner_dims)
|
(slice(None),) * inner_dims)
|
||||||
correct_params_grad[dest_slice] += gather_grad[source_slice]
|
source_slice = ((slice(None),) * outer_dims + (source_index,) +
|
||||||
self.assertAllClose(
|
(slice(None),) * inner_dims)
|
||||||
correct_params_grad,
|
correct_params_grad[dest_slice] += gather_grad[source_slice]
|
||||||
self.evaluate(params_grad),
|
self.assertAllClose(
|
||||||
atol=2e-6,
|
correct_params_grad,
|
||||||
rtol=2e-6)
|
self.evaluate(params_grad),
|
||||||
|
atol=2e-6,
|
||||||
|
rtol=2e-6)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testString(self):
|
def testString(self):
|
||||||
@ -177,12 +182,14 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testUInt32AndUInt64(self):
|
def testUInt32AndUInt64(self):
|
||||||
for unsigned_type in (dtypes.uint32, dtypes.uint64):
|
for unsigned_type in (dtypes.uint32, dtypes.uint64):
|
||||||
params = self._buildParams(
|
with self.subTest(unsigned_type=unsigned_type):
|
||||||
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
|
params = self._buildParams(
|
||||||
with self.cached_session():
|
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
|
||||||
self.assertAllEqual([7, 8, 9],
|
with self.cached_session():
|
||||||
array_ops.gather(params, 1, axis=0).eval())
|
self.assertAllEqual([7, 8, 9],
|
||||||
self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
|
array_ops.gather(params, 1, axis=0).eval())
|
||||||
|
self.assertAllEqual([1, 7],
|
||||||
|
array_ops.gather(params, 0, axis=1).eval())
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testUnknownIndices(self):
|
def testUnknownIndices(self):
|
||||||
@ -237,14 +244,15 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
indices = 0
|
indices = 0
|
||||||
for bad_axis in (1, 2, -2):
|
for bad_axis in (1, 2, -2):
|
||||||
# Shape inference can validate axis for known params rank.
|
# Shape inference can validate axis for known params rank.
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.subTest(bad_axis=bad_axis):
|
||||||
ValueError, "Shape must be at least rank . but is rank 1"):
|
with self.assertRaisesWithPredicateMatch(
|
||||||
array_ops.gather(params, indices, axis=bad_axis)
|
ValueError, "Shape must be at least rank . but is rank 1"):
|
||||||
# If params rank is unknown, an op error occurs.
|
array_ops.gather(params, indices, axis=bad_axis)
|
||||||
with self.assertRaisesOpError(
|
# If params rank is unknown, an op error occurs.
|
||||||
r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis):
|
with self.assertRaisesOpError(
|
||||||
array_ops.gather(params_ph, indices, axis=bad_axis).eval(
|
r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis):
|
||||||
feed_dict={params_ph: params})
|
array_ops.gather(params_ph, indices, axis=bad_axis).eval(
|
||||||
|
feed_dict={params_ph: params})
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testEmptySlices(self):
|
def testEmptySlices(self):
|
||||||
@ -252,20 +260,21 @@ class GatherTest(test.TestCase, parameterized.TestCase):
|
|||||||
for dtype in _TEST_TYPES:
|
for dtype in _TEST_TYPES:
|
||||||
for itype in np.int32, np.int64:
|
for itype in np.int32, np.int64:
|
||||||
# Leading axis gather.
|
# Leading axis gather.
|
||||||
params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
|
with self.subTest(dtype=dtype, itype=itype):
|
||||||
indices = np.array([3, 4], dtype=itype)
|
params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
|
||||||
gather = array_ops.gather(params, indices, axis=0)
|
indices = np.array([3, 4], dtype=itype)
|
||||||
self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0)))
|
gather = array_ops.gather(params, indices, axis=0)
|
||||||
|
self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0)))
|
||||||
|
|
||||||
# Middle axis gather.
|
# Middle axis gather.
|
||||||
params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
|
params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
|
||||||
gather = array_ops.gather(params, indices, axis=1)
|
gather = array_ops.gather(params, indices, axis=1)
|
||||||
self.assertAllEqual(gather.eval(), np.zeros((0, 2, 0)))
|
self.assertAllEqual(gather.eval(), np.zeros((0, 2, 0)))
|
||||||
|
|
||||||
# Trailing axis gather.
|
# Trailing axis gather.
|
||||||
params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
|
params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
|
||||||
gather = array_ops.gather(params, indices, axis=2)
|
gather = array_ops.gather(params, indices, axis=2)
|
||||||
self.assertAllEqual(gather.eval(), np.zeros((0, 0, 2)))
|
self.assertAllEqual(gather.eval(), np.zeros((0, 0, 2)))
|
||||||
|
|
||||||
@parameterized.parameters([
|
@parameterized.parameters([
|
||||||
# batch_dims=0 (equivalent to tf.gather)
|
# batch_dims=0 (equivalent to tf.gather)
|
||||||
|
@ -61,17 +61,18 @@ class UniqueTest(test.TestCase):
|
|||||||
|
|
||||||
def testInt32Axis(self):
|
def testInt32Axis(self):
|
||||||
for dtype in [np.int32, np.int64]:
|
for dtype in [np.int32, np.int64]:
|
||||||
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
|
with self.subTest(dtype=dtype):
|
||||||
y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
|
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
|
||||||
self.assertEqual(y0.shape.rank, 2)
|
y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
|
||||||
tf_y0, tf_idx0 = self.evaluate([y0, idx0])
|
self.assertEqual(y0.shape.rank, 2)
|
||||||
y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype))
|
tf_y0, tf_idx0 = self.evaluate([y0, idx0])
|
||||||
self.assertEqual(y1.shape.rank, 2)
|
y1, idx1 = gen_array_ops.unique_v2(x, axis=np.array([1], dtype))
|
||||||
tf_y1, tf_idx1 = self.evaluate([y1, idx1])
|
self.assertEqual(y1.shape.rank, 2)
|
||||||
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
|
tf_y1, tf_idx1 = self.evaluate([y1, idx1])
|
||||||
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
|
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
|
||||||
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
|
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
|
||||||
self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
|
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
|
||||||
|
self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
|
||||||
|
|
||||||
def testInt32V2(self):
|
def testInt32V2(self):
|
||||||
# This test is only temporary, once V2 is used
|
# This test is only temporary, once V2 is used
|
||||||
@ -144,26 +145,28 @@ class UniqueWithCountsTest(test.TestCase):
|
|||||||
for i in range(len(x)):
|
for i in range(len(x)):
|
||||||
self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii'))
|
self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii'))
|
||||||
for value, count in zip(tf_y, tf_count):
|
for value, count in zip(tf_y, tf_count):
|
||||||
v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)]
|
with self.subTest(value=value, count=count):
|
||||||
self.assertEqual(count, sum(v))
|
v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)]
|
||||||
|
self.assertEqual(count, sum(v))
|
||||||
|
|
||||||
def testInt32Axis(self):
|
def testInt32Axis(self):
|
||||||
for dtype in [np.int32, np.int64]:
|
for dtype in [np.int32, np.int64]:
|
||||||
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
|
with self.subTest(dtype=dtype):
|
||||||
y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
|
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
|
||||||
x, axis=np.array([0], dtype))
|
y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
|
||||||
self.assertEqual(y0.shape.rank, 2)
|
x, axis=np.array([0], dtype))
|
||||||
tf_y0, tf_idx0, tf_count0 = self.evaluate([y0, idx0, count0])
|
self.assertEqual(y0.shape.rank, 2)
|
||||||
y1, idx1, count1 = gen_array_ops.unique_with_counts_v2(
|
tf_y0, tf_idx0, tf_count0 = self.evaluate([y0, idx0, count0])
|
||||||
x, axis=np.array([1], dtype))
|
y1, idx1, count1 = gen_array_ops.unique_with_counts_v2(
|
||||||
self.assertEqual(y1.shape.rank, 2)
|
x, axis=np.array([1], dtype))
|
||||||
tf_y1, tf_idx1, tf_count1 = self.evaluate([y1, idx1, count1])
|
self.assertEqual(y1.shape.rank, 2)
|
||||||
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
|
tf_y1, tf_idx1, tf_count1 = self.evaluate([y1, idx1, count1])
|
||||||
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
|
self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
|
||||||
self.assertAllEqual(tf_count0, np.array([2, 1]))
|
self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
|
||||||
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
|
self.assertAllEqual(tf_count0, np.array([2, 1]))
|
||||||
self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
|
self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
|
||||||
self.assertAllEqual(tf_count1, np.array([1, 2]))
|
self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
|
||||||
|
self.assertAllEqual(tf_count1, np.array([1, 2]))
|
||||||
|
|
||||||
def testInt32V2(self):
|
def testInt32V2(self):
|
||||||
# This test is only temporary, once V2 is used
|
# This test is only temporary, once V2 is used
|
||||||
|
Loading…
x
Reference in New Issue
Block a user