From 634c902f4aec948cc47e76c8806137cf3534e340 Mon Sep 17 00:00:00 2001 From: Andrew Selle Date: Thu, 23 Apr 2020 12:46:28 -0700 Subject: [PATCH] Improve testing of tf.bool in tf.unstack op. Previously randn()'s results were cast to np.bool resuling in True being the only value often. Also, check both the 0th and the -1th axis. PiperOrigin-RevId: 308109624 Change-Id: I2b1b8bc6e6111874ad6f33c634d4d8a85c444882 --- .../python/kernel_tests/unstack_op_test.py | 84 +++++++++++++------ 1 file changed, 58 insertions(+), 26 deletions(-) diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py index 7a15888686e..13611b278bc 100644 --- a/tensorflow/python/kernel_tests/unstack_op_test.py +++ b/tensorflow/python/kernel_tests/unstack_op_test.py @@ -39,22 +39,49 @@ def np_split_squeeze(array, axis): class UnstackOpTest(test.TestCase): + def randn(self, shape, dtype): + data = np.random.randn(*shape) + if dtype == np.bool: + return data < 0 # Naive casting yields True with P(1)! + else: + return data.astype(dtype) + + def unstackReference(self, data, axis): + """Use numpy primitives to implement unstack equivalent.""" + result = [] + rank = len(data.shape) + axis = axis + rank if axis < 0 else axis + for k in range(data.shape[axis]): + axis = rank + axis if axis < 0 else axis + # Slice in axis dimension of k'th slice. + # e.g. if rank=4 k=2, axis=2 then equivalent of data[:,:,2,:] + # Give error with loop context + slice_spec = tuple( + slice(None) if i != axis else k for i in range(rank)) + result.append(data.__getitem__(slice_spec)) + return result + def testSimple(self): np.random.seed(7) for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): - for dtype in [ - np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32, - np.int64 - ]: - data = np.random.randn(*shape).astype(dtype) - # Convert data to a single tensorflow tensor - x = constant_op.constant(data) - # Unstack into a list of tensors - cs = array_ops.unstack(x, num=shape[0]) - self.assertEqual(type(cs), list) - self.assertEqual(len(cs), shape[0]) - cs = [self.evaluate(c) for c in cs] - self.assertAllEqual(cs, data) + rank = len(shape) + for axis in range(-rank, rank): + for dtype in [ + np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32, + np.int64 + ]: + data = self.randn(shape, dtype) + # Convert data to a single tensorflow tensor + x = constant_op.constant(data) + + # Unstack into a list of tensors + ref = self.unstackReference(data, axis) + cs = array_ops.unstack(x, axis=axis) + self.assertEqual(type(cs), list) + self.assertEqual(len(cs), shape[axis]) + for k, c in enumerate(cs): + with self.subTest(shape=shape, k=k, axis=axis, dtype=dtype): + self.assertAllEqual(ref[k], self.evaluate(c)) def testSimpleGpu(self): if not test_util.is_gpu_available(): @@ -63,19 +90,24 @@ class UnstackOpTest(test.TestCase): np.random.seed(7) with test_util.force_gpu(): for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): - for dtype in [ - np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32, - np.int64 - ]: - data = np.random.randn(*shape).astype(dtype) - # Convert data to a single tensorflow tensor - x = constant_op.constant(data) - # Unstack into a list of tensors - cs = array_ops.unstack(x, num=shape[0]) - self.assertEqual(type(cs), list) - self.assertEqual(len(cs), shape[0]) - cs = [self.evaluate(c) for c in cs] - self.assertAllEqual(cs, data) + rank = len(shape) + for axis in range(-rank, rank): + for dtype in [ + np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32, + np.int64 + ]: + data = self.randn(shape, dtype) + # Convert data to a single tensorflow tensor + x = constant_op.constant(data) + # Unstack into a list of tensors + ref = self.unstackReference(data, axis) + cs = array_ops.unstack(x, axis=axis) + self.assertEqual(type(cs), list) + self.assertEqual(len(cs), shape[axis]) + for k, c in enumerate(cs): + # Give error with loop context + with self.subTest(shape=shape, k=k, axis=axis, dtype=dtype): + self.assertAllEqual(ref[k], self.evaluate(c)) @test_util.run_deprecated_v1 def testGradientsAxis0(self):