Update gather_op_test and unique_op_test to use subTest for easier debugging.

PiperOrigin-RevId: 311595699
Change-Id: I1a8cf8b5b314aada4aeeece2603e975bc8a4ff42
This commit is contained in:
Andrew Selle 2020-05-14 13:28:33 -07:00 committed by TensorFlower Gardener
parent 66769844a5
commit 6db3caf99b
2 changed files with 142 additions and 130 deletions

View File

@ -62,6 +62,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
data = np.array([0, 1, 2, 3, 7, 5])
for dtype in _TEST_TYPES:
for indices in 4, [1, 2, 2, 4, 5]:
with self.subTest(dtype=dtype, indices=indices):
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
indices_tf = constant_op.constant(indices)
@ -77,6 +78,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
[9, 10, 11], [12, 13, 14]])
for dtype in _TEST_TYPES:
for axis in range(data.ndim):
with self.subTest(dtype=dtype, axis=axis):
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
indices = constant_op.constant(2)
@ -92,6 +94,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
[9, 10, 11], [12, 13, 14]])
for dtype in _TEST_TYPES:
for axis in range(data.ndim):
with self.subTest(dtype=dtype, axis=axis):
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
# The indices must be in bounds for any axis.
@ -112,6 +115,8 @@ class GatherTest(test.TestCase, parameterized.TestCase):
for axis in range(len(shape)):
params = self._buildParams(np.random.randn(*shape), dtype)
indices = np.random.randint(shape[axis], size=indices_shape)
with self.subTest(indices_shape=indices_shape, dtype=dtype, axis=axis,
indices=indices):
with self.cached_session(use_gpu=True) as sess:
tf_params = constant_op.constant(params)
tf_indices = constant_op.constant(indices)
@ -177,12 +182,14 @@ class GatherTest(test.TestCase, parameterized.TestCase):
@test_util.run_deprecated_v1
def testUInt32AndUInt64(self):
for unsigned_type in (dtypes.uint32, dtypes.uint64):
with self.subTest(unsigned_type=unsigned_type):
params = self._buildParams(
np.array([[1, 2, 3], [7, 8, 9]]), unsigned_type)
with self.cached_session():
self.assertAllEqual([7, 8, 9],
array_ops.gather(params, 1, axis=0).eval())
self.assertAllEqual([1, 7], array_ops.gather(params, 0, axis=1).eval())
self.assertAllEqual([1, 7],
array_ops.gather(params, 0, axis=1).eval())
@test_util.run_deprecated_v1
def testUnknownIndices(self):
@ -237,6 +244,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
indices = 0
for bad_axis in (1, 2, -2):
# Shape inference can validate axis for known params rank.
with self.subTest(bad_axis=bad_axis):
with self.assertRaisesWithPredicateMatch(
ValueError, "Shape must be at least rank . but is rank 1"):
array_ops.gather(params, indices, axis=bad_axis)
@ -252,6 +260,7 @@ class GatherTest(test.TestCase, parameterized.TestCase):
for dtype in _TEST_TYPES:
for itype in np.int32, np.int64:
# Leading axis gather.
with self.subTest(dtype=dtype, itype=itype):
params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
indices = np.array([3, 4], dtype=itype)
gather = array_ops.gather(params, indices, axis=0)

View File

@ -61,6 +61,7 @@ class UniqueTest(test.TestCase):
def testInt32Axis(self):
for dtype in [np.int32, np.int64]:
with self.subTest(dtype=dtype):
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
y0, idx0 = gen_array_ops.unique_v2(x, axis=np.array([0], dtype))
self.assertEqual(y0.shape.rank, 2)
@ -144,11 +145,13 @@ class UniqueWithCountsTest(test.TestCase):
for i in range(len(x)):
self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii'))
for value, count in zip(tf_y, tf_count):
with self.subTest(value=value, count=count):
v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)]
self.assertEqual(count, sum(v))
def testInt32Axis(self):
for dtype in [np.int32, np.int64]:
with self.subTest(dtype=dtype):
x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
y0, idx0, count0 = gen_array_ops.unique_with_counts_v2(
x, axis=np.array([0], dtype))