Improve testing of tf.bool in tf.unstack op.
Previously randn()'s results were cast to np.bool resuling in True being the only value often. Also, check both the 0th and the -1th axis. PiperOrigin-RevId: 308109624 Change-Id: I2b1b8bc6e6111874ad6f33c634d4d8a85c444882
This commit is contained in:
parent
2bbab6a500
commit
634c902f4a
@ -39,22 +39,49 @@ def np_split_squeeze(array, axis):
|
||||
|
||||
class UnstackOpTest(test.TestCase):
|
||||
|
||||
def randn(self, shape, dtype):
|
||||
data = np.random.randn(*shape)
|
||||
if dtype == np.bool:
|
||||
return data < 0 # Naive casting yields True with P(1)!
|
||||
else:
|
||||
return data.astype(dtype)
|
||||
|
||||
def unstackReference(self, data, axis):
|
||||
"""Use numpy primitives to implement unstack equivalent."""
|
||||
result = []
|
||||
rank = len(data.shape)
|
||||
axis = axis + rank if axis < 0 else axis
|
||||
for k in range(data.shape[axis]):
|
||||
axis = rank + axis if axis < 0 else axis
|
||||
# Slice in axis dimension of k'th slice.
|
||||
# e.g. if rank=4 k=2, axis=2 then equivalent of data[:,:,2,:]
|
||||
# Give error with loop context
|
||||
slice_spec = tuple(
|
||||
slice(None) if i != axis else k for i in range(rank))
|
||||
result.append(data.__getitem__(slice_spec))
|
||||
return result
|
||||
|
||||
def testSimple(self):
|
||||
np.random.seed(7)
|
||||
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
|
||||
for dtype in [
|
||||
np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32,
|
||||
np.int64
|
||||
]:
|
||||
data = np.random.randn(*shape).astype(dtype)
|
||||
# Convert data to a single tensorflow tensor
|
||||
x = constant_op.constant(data)
|
||||
# Unstack into a list of tensors
|
||||
cs = array_ops.unstack(x, num=shape[0])
|
||||
self.assertEqual(type(cs), list)
|
||||
self.assertEqual(len(cs), shape[0])
|
||||
cs = [self.evaluate(c) for c in cs]
|
||||
self.assertAllEqual(cs, data)
|
||||
rank = len(shape)
|
||||
for axis in range(-rank, rank):
|
||||
for dtype in [
|
||||
np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32,
|
||||
np.int64
|
||||
]:
|
||||
data = self.randn(shape, dtype)
|
||||
# Convert data to a single tensorflow tensor
|
||||
x = constant_op.constant(data)
|
||||
|
||||
# Unstack into a list of tensors
|
||||
ref = self.unstackReference(data, axis)
|
||||
cs = array_ops.unstack(x, axis=axis)
|
||||
self.assertEqual(type(cs), list)
|
||||
self.assertEqual(len(cs), shape[axis])
|
||||
for k, c in enumerate(cs):
|
||||
with self.subTest(shape=shape, k=k, axis=axis, dtype=dtype):
|
||||
self.assertAllEqual(ref[k], self.evaluate(c))
|
||||
|
||||
def testSimpleGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
@ -63,19 +90,24 @@ class UnstackOpTest(test.TestCase):
|
||||
np.random.seed(7)
|
||||
with test_util.force_gpu():
|
||||
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
|
||||
for dtype in [
|
||||
np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32,
|
||||
np.int64
|
||||
]:
|
||||
data = np.random.randn(*shape).astype(dtype)
|
||||
# Convert data to a single tensorflow tensor
|
||||
x = constant_op.constant(data)
|
||||
# Unstack into a list of tensors
|
||||
cs = array_ops.unstack(x, num=shape[0])
|
||||
self.assertEqual(type(cs), list)
|
||||
self.assertEqual(len(cs), shape[0])
|
||||
cs = [self.evaluate(c) for c in cs]
|
||||
self.assertAllEqual(cs, data)
|
||||
rank = len(shape)
|
||||
for axis in range(-rank, rank):
|
||||
for dtype in [
|
||||
np.bool, np.float16, np.float32, np.float64, np.uint8, np.int32,
|
||||
np.int64
|
||||
]:
|
||||
data = self.randn(shape, dtype)
|
||||
# Convert data to a single tensorflow tensor
|
||||
x = constant_op.constant(data)
|
||||
# Unstack into a list of tensors
|
||||
ref = self.unstackReference(data, axis)
|
||||
cs = array_ops.unstack(x, axis=axis)
|
||||
self.assertEqual(type(cs), list)
|
||||
self.assertEqual(len(cs), shape[axis])
|
||||
for k, c in enumerate(cs):
|
||||
# Give error with loop context
|
||||
with self.subTest(shape=shape, k=k, axis=axis, dtype=dtype):
|
||||
self.assertAllEqual(ref[k], self.evaluate(c))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradientsAxis0(self):
|
||||
|
Loading…
Reference in New Issue
Block a user