[tf.data] Consolidating testing utility.

PiperOrigin-RevId: 256419683
This commit is contained in:
Jiri Simsa 2019-07-03 12:39:20 -07:00 committed by TensorFlower Gardener
parent 82b4538721
commit 697c5cbfb4
7 changed files with 24 additions and 50 deletions

View File

@ -222,7 +222,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
dense_shape=[5, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
self.assertValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -257,7 +257,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
values=expected_values,
dense_shape=[5, i * 3 + 5 - 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
self.assertValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -285,7 +285,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
dense_shape=[3, 4, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
self.assertValuesEqual(actual, expected)
# Slide: 2nd batch.
actual = sess.run(get_next)
expected = sparse_tensor.SparseTensorValue(
@ -295,7 +295,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
dense_shape=[3, 4, 1])
self.assertTrue(sparse_tensor.is_sparse(actual))
self.assertSparseValuesEqual(actual, expected)
self.assertValuesEqual(actual, expected)
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)

View File

@ -286,7 +286,7 @@ class MapDefunTest(test_base.DatasetTestBase):
values=[1, 2, 1, 2],
dense_shape=[2, 3, 4])
actual = self.evaluate(deserialized)
self.assertSparseValuesEqual(expected, actual)
self.assertValuesEqual(expected, actual)
def testMapDefunWithVariantTensorAsCaptured(self):
@ -307,7 +307,7 @@ class MapDefunTest(test_base.DatasetTestBase):
values=[1, 2, 1, 2],
dense_shape=[2, 3, 4])
actual = self.evaluate(deserialized)
self.assertSparseValuesEqual(expected, actual)
self.assertValuesEqual(expected, actual)
def testMapDefunWithStrTensor(self):
@ -328,7 +328,7 @@ class MapDefunTest(test_base.DatasetTestBase):
values=[1, 2, 1, 2],
dense_shape=[2, 3, 4])
actual = self.evaluate(deserialized)
self.assertSparseValuesEqual(expected, actual)
self.assertValuesEqual(expected, actual)
if __name__ == "__main__":

View File

@ -57,11 +57,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
for k, v in sorted(dict_tensors.items()):
expected_v = expected_tensors[k]
if sparse_tensor.is_sparse(v):
self.assertSparseValuesEqual(expected_v, v)
else:
# One output for standard Tensor.
self.assertAllEqual(expected_v, v)
self.assertValuesEqual(expected_v, v)
def _test(self,
input_tensor,

View File

@ -27,7 +27,6 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
@ -165,10 +164,7 @@ class FromTensorSlicesTest(test_base.DatasetTestBase):
results = self.evaluate(get_next())
for component, result_component in zip(
(list(zip(*components[:3]))[i] + expected[i]), results):
if sparse_tensor.is_sparse(component):
self.assertSparseValuesEqual(component, result_component)
else:
self.assertAllEqual(component, result_component)
self.assertValuesEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
@ -255,12 +251,7 @@ class FromTensorSlicesTest(test_base.DatasetTestBase):
results = self.evaluate(get_next())
for component, result_component in zip(
(list(zip(*components[:3]))[i] + expected[i]), results):
if sparse_tensor.is_sparse(component):
self.assertSparseValuesEqual(component, result_component)
elif ragged_tensor.is_ragged(component):
self.assertAllEqual(component, result_component)
else:
self.assertAllEqual(component, result_component)
self.assertValuesEqual(component, result_component)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())

View File

@ -101,7 +101,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(10):
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
result = ds.reduce(make_sparse_fn(0), reduce_fn)
self.assertSparseValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
def testNested(self):
@ -125,7 +125,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(map_fn(0), reduce_fn)
result = self.evaluate(result)
self.assertEqual(((i + 1) * i) // 2, result["dense"])
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
def testDatasetSideEffect(self):
counter_var = variables.Variable(0)

View File

@ -47,11 +47,14 @@ class DatasetTestBase(test.TestCase):
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
self.evaluate(op)
def assertSparseValuesEqual(self, a, b):
"""Asserts that two SparseTensors/SparseTensorValues are equal."""
self.assertAllEqual(a.indices, b.indices)
self.assertAllEqual(a.values, b.values)
self.assertAllEqual(a.dense_shape, b.dense_shape)
def assertValuesEqual(self, expected, actual):
"""Asserts that two values are equal."""
if sparse_tensor.is_sparse(expected):
self.assertAllEqual(expected.indices, actual.indices)
self.assertAllEqual(expected.values, actual.values)
self.assertAllEqual(expected.dense_shape, actual.dense_shape)
else:
self.assertAllEqual(expected, actual)
def getNext(self, dataset, requires_initialization=False, shared_name=None):
"""Returns a callable that returns the next element of the dataset.
@ -107,16 +110,7 @@ class DatasetTestBase(test.TestCase):
nest.assert_same_structure(result_values[i], expected_values[i])
for result_value, expected_value in zip(
nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
if sparse_tensor.is_sparse(result_value):
self.assertSparseValuesEqual(result_value, expected_value)
elif ragged_tensor.is_ragged(result_value):
self.assertAllEqual(result_value, expected_value)
else:
self.assertAllEqual(
result_value,
expected_value,
msg=("Result value: {}. Expected value: {}"
.format(result_value, expected_value)))
self.assertValuesEqual(expected_value, result_value)
def assertDatasetProduces(self,
dataset,
@ -208,10 +202,8 @@ class DatasetTestBase(test.TestCase):
op2 = nest.flatten(op2)
assert len(op1) == len(op2)
for i in range(len(op1)):
if sparse_tensor.is_sparse(op1[i]):
self.assertSparseValuesEqual(op1[i], op2[i])
elif ragged_tensor.is_ragged(op1[i]):
self.assertAllEqual(op1[i], op2[i])
if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
self.assertValuesEqual(op1[i], op2[i])
elif flattened_types[i] == dtypes.string:
self.assertAllEqual(op1[i], op2[i])
else:

View File

@ -703,12 +703,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
for expected, actual in zip(
nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
if sparse_tensor.is_sparse(expected):
self.assertSparseValuesEqual(expected, actual)
elif ragged_tensor.is_ragged(expected):
self.assertAllEqual(expected, actual)
else:
self.assertAllEqual(expected, actual)
self.assertValuesEqual(expected, actual)
# pylint: enable=g-long-lambda