[tf.data] Consolidating testing utility.
PiperOrigin-RevId: 256419683
This commit is contained in:
parent
82b4538721
commit
697c5cbfb4
@ -222,7 +222,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
values=[i * 3, i * 3 + 1, i * 3 + 2, i * 3 + 3, i * 3 + 4],
|
||||
dense_shape=[5, 1])
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
self.assertValuesEqual(actual, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -257,7 +257,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
values=expected_values,
|
||||
dense_shape=[5, i * 3 + 5 - 1])
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
self.assertValuesEqual(actual, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
@ -285,7 +285,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
values=[0, 1, 2, 3, 2, 3, 4, 5, 4, 5, 6, 7],
|
||||
dense_shape=[3, 4, 1])
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
self.assertValuesEqual(actual, expected)
|
||||
# Slide: 2nd batch.
|
||||
actual = sess.run(get_next)
|
||||
expected = sparse_tensor.SparseTensorValue(
|
||||
@ -295,7 +295,7 @@ class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
values=[2, 3, 4, 5, 4, 5, 6, 7, 6, 7, 8, 9],
|
||||
dense_shape=[3, 4, 1])
|
||||
self.assertTrue(sparse_tensor.is_sparse(actual))
|
||||
self.assertSparseValuesEqual(actual, expected)
|
||||
self.assertValuesEqual(actual, expected)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
sess.run(get_next)
|
||||
|
||||
|
@ -286,7 +286,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
values=[1, 2, 1, 2],
|
||||
dense_shape=[2, 3, 4])
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
self.assertValuesEqual(expected, actual)
|
||||
|
||||
def testMapDefunWithVariantTensorAsCaptured(self):
|
||||
|
||||
@ -307,7 +307,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
values=[1, 2, 1, 2],
|
||||
dense_shape=[2, 3, 4])
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
self.assertValuesEqual(expected, actual)
|
||||
|
||||
def testMapDefunWithStrTensor(self):
|
||||
|
||||
@ -328,7 +328,7 @@ class MapDefunTest(test_base.DatasetTestBase):
|
||||
values=[1, 2, 1, 2],
|
||||
dense_shape=[2, 3, 4])
|
||||
actual = self.evaluate(deserialized)
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
self.assertValuesEqual(expected, actual)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -57,11 +57,7 @@ class ParseExampleDatasetTest(test_base.DatasetTestBase):
|
||||
|
||||
for k, v in sorted(dict_tensors.items()):
|
||||
expected_v = expected_tensors[k]
|
||||
if sparse_tensor.is_sparse(v):
|
||||
self.assertSparseValuesEqual(expected_v, v)
|
||||
else:
|
||||
# One output for standard Tensor.
|
||||
self.assertAllEqual(expected_v, v)
|
||||
self.assertValuesEqual(expected_v, v)
|
||||
|
||||
def _test(self,
|
||||
input_tensor,
|
||||
|
@ -27,7 +27,6 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -165,10 +164,7 @@ class FromTensorSlicesTest(test_base.DatasetTestBase):
|
||||
results = self.evaluate(get_next())
|
||||
for component, result_component in zip(
|
||||
(list(zip(*components[:3]))[i] + expected[i]), results):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
else:
|
||||
self.assertAllEqual(component, result_component)
|
||||
self.assertValuesEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
@ -255,12 +251,7 @@ class FromTensorSlicesTest(test_base.DatasetTestBase):
|
||||
results = self.evaluate(get_next())
|
||||
for component, result_component in zip(
|
||||
(list(zip(*components[:3]))[i] + expected[i]), results):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
elif ragged_tensor.is_ragged(component):
|
||||
self.assertAllEqual(component, result_component)
|
||||
else:
|
||||
self.assertAllEqual(component, result_component)
|
||||
self.assertValuesEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
|
@ -101,7 +101,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
for i in range(10):
|
||||
ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
|
||||
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
||||
self.assertSparseValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
|
||||
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
|
||||
|
||||
def testNested(self):
|
||||
|
||||
@ -125,7 +125,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
result = ds.reduce(map_fn(0), reduce_fn)
|
||||
result = self.evaluate(result)
|
||||
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
||||
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||
self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||
|
||||
def testDatasetSideEffect(self):
|
||||
counter_var = variables.Variable(0)
|
||||
|
@ -47,11 +47,14 @@ class DatasetTestBase(test.TestCase):
|
||||
with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
|
||||
self.evaluate(op)
|
||||
|
||||
def assertSparseValuesEqual(self, a, b):
|
||||
"""Asserts that two SparseTensors/SparseTensorValues are equal."""
|
||||
self.assertAllEqual(a.indices, b.indices)
|
||||
self.assertAllEqual(a.values, b.values)
|
||||
self.assertAllEqual(a.dense_shape, b.dense_shape)
|
||||
def assertValuesEqual(self, expected, actual):
|
||||
"""Asserts that two values are equal."""
|
||||
if sparse_tensor.is_sparse(expected):
|
||||
self.assertAllEqual(expected.indices, actual.indices)
|
||||
self.assertAllEqual(expected.values, actual.values)
|
||||
self.assertAllEqual(expected.dense_shape, actual.dense_shape)
|
||||
else:
|
||||
self.assertAllEqual(expected, actual)
|
||||
|
||||
def getNext(self, dataset, requires_initialization=False, shared_name=None):
|
||||
"""Returns a callable that returns the next element of the dataset.
|
||||
@ -107,16 +110,7 @@ class DatasetTestBase(test.TestCase):
|
||||
nest.assert_same_structure(result_values[i], expected_values[i])
|
||||
for result_value, expected_value in zip(
|
||||
nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
|
||||
if sparse_tensor.is_sparse(result_value):
|
||||
self.assertSparseValuesEqual(result_value, expected_value)
|
||||
elif ragged_tensor.is_ragged(result_value):
|
||||
self.assertAllEqual(result_value, expected_value)
|
||||
else:
|
||||
self.assertAllEqual(
|
||||
result_value,
|
||||
expected_value,
|
||||
msg=("Result value: {}. Expected value: {}"
|
||||
.format(result_value, expected_value)))
|
||||
self.assertValuesEqual(expected_value, result_value)
|
||||
|
||||
def assertDatasetProduces(self,
|
||||
dataset,
|
||||
@ -208,10 +202,8 @@ class DatasetTestBase(test.TestCase):
|
||||
op2 = nest.flatten(op2)
|
||||
assert len(op1) == len(op2)
|
||||
for i in range(len(op1)):
|
||||
if sparse_tensor.is_sparse(op1[i]):
|
||||
self.assertSparseValuesEqual(op1[i], op2[i])
|
||||
elif ragged_tensor.is_ragged(op1[i]):
|
||||
self.assertAllEqual(op1[i], op2[i])
|
||||
if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]):
|
||||
self.assertValuesEqual(op1[i], op2[i])
|
||||
elif flattened_types[i] == dtypes.string:
|
||||
self.assertAllEqual(op1[i], op2[i])
|
||||
else:
|
||||
|
@ -703,12 +703,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
|
||||
for expected, actual in zip(
|
||||
nest.flatten(expected_element_0), nest.flatten(actual_element_0)):
|
||||
if sparse_tensor.is_sparse(expected):
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
elif ragged_tensor.is_ragged(expected):
|
||||
self.assertAllEqual(expected, actual)
|
||||
else:
|
||||
self.assertAllEqual(expected, actual)
|
||||
self.assertValuesEqual(expected, actual)
|
||||
|
||||
# pylint: enable=g-long-lambda
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user