[tf.data] Migrating remaining experimental API tests to use TF combinations and performing various minor test cleanup.

PiperOrigin-RevId: 288533288
Change-Id: Iba6a980cd08fa0aba9e9703711b1dcdfbc3cb734
This commit is contained in:
Jiri Simsa 2020-01-07 11:10:42 -08:00 committed by TensorFlower Gardener
parent 1d41edaee6
commit febe171d32
6 changed files with 71 additions and 39 deletions

View File

@ -17,17 +17,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class AssertNextTest(test_base.DatasetTestBase):
class AssertNextTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
testing.assert_next(["Map"])).map(lambda x: x)
@ -36,6 +38,7 @@ class AssertNextTest(test_base.DatasetTestBase):
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[0])
@combinations.generate(test_base.default_test_combinations())
def testAssertNextInvalid(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
testing.assert_next(["Whoops"])).map(lambda x: x)
@ -49,6 +52,7 @@ class AssertNextTest(test_base.DatasetTestBase):
"Asserted Whoops transformation at offset 0 but encountered "
"Map transformation instead."))
@combinations.generate(test_base.default_test_combinations())
def testAssertNextShort(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
testing.assert_next(["Map", "Whoops"])).map(lambda x: x)

View File

@ -17,21 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Tests for `tf.data.experimental.cardinality()`."""
@parameterized.named_parameters(
# pylint: disable=g-long-lambda
def _test_combinations():
# pylint: disable=g-long-lambda
cases = [
("Batch1",
lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
("Batch2",
@ -151,9 +150,24 @@ class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
5), dataset_ops.Dataset.range(3).filter(lambda _: True))),
cardinality.UNKNOWN),
# pylint: enable=g-long-lambda
)
def testNumElements(self, dataset_fn, expected_result):
]
def reduce_fn(x, y):
name, dataset_fn, expected_result = y
return x + combinations.combine(
dataset_fn=combinations.NamedObject(name, dataset_fn),
expected_result=expected_result)
return functools.reduce(reduce_fn, cases, [])
class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
"""Tests for `tf.data.experimental.cardinality()`."""
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
_test_combinations()))
def testCardinality(self, dataset_fn, expected_result):
with self.cached_session() as sess:
self.assertEqual(
sess.run(cardinality.cardinality(dataset_fn())), expected_result)

View File

@ -22,14 +22,14 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testAutotuneOption(self):
dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset.map(lambda x: x).apply(

View File

@ -17,16 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework import combinations
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class NonSerializableTest(test_base.DatasetTestBase):
class NonSerializableTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testNonSerializable(self):
dataset = dataset_ops.Dataset.from_tensors(0)
dataset = dataset.apply(testing.assert_next(["FiniteSkip"]))
@ -41,6 +43,7 @@ class NonSerializableTest(test_base.DatasetTestBase):
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[0])
@combinations.generate(test_base.default_test_combinations())
def testNonSerializableAsDirectInput(self):
"""Tests that non-serializable dataset can be OptimizeDataset's input."""
dataset = dataset_ops.Dataset.from_tensors(0)

View File

@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import warnings
from absl.testing import parameterized
@ -30,23 +31,17 @@ from tensorflow.python.data.experimental.ops import testing
from tensorflow.python.data.experimental.ops import threadpool
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
def _generate_captured_refvar_test_cases():
"""Generates testcases.
Returns:
A list of tuples of (testcase_name, make_dataset_fn). make_dataset_fn takes
a tf.Variable as input and creates a test dataset that uses that variable.
"""
def _captured_refvar_test_combinations():
def make_map_dataset(var):
return dataset_ops.Dataset.from_tensors(0).map(lambda x: x + var)
@ -88,7 +83,7 @@ def _generate_captured_refvar_test_cases():
scan_ops.scan(
0, lambda old_state, elem: (old_state + 1, elem + old_state + var)))
return [
cases = [
# Core datasets
("Map", make_map_dataset),
("FlatMap", make_flat_map_dataset),
@ -100,10 +95,17 @@ def _generate_captured_refvar_test_cases():
("Scan", make_scan_dataset)
]
def reduce_fn(x, y):
name, dataset_fn = y
return x + combinations.combine(
dataset_fn=combinations.NamedObject(name, dataset_fn))
return functools.reduce(reduce_fn, cases, [])
@test_util.run_all_in_graph_and_eager_modes
class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testOptimizationStatefulFunction(self):
dataset = dataset_ops.Dataset.range(
10).map(lambda _: random_ops.random_uniform([])).batch(10)
@ -113,8 +115,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset)
self.evaluate(get_next())
@test_util.run_v1_only("b/123902160")
def testSkipEagerOptimizationLargeInputFromTensor(self):
# TODO(b/123902160)
@combinations.generate(test_base.graph_only_combinations())
def testOptimizationLargeInputFromTensor(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
dataset = dataset_ops.Dataset.from_tensors(input_t)
options = dataset_ops.Options()
@ -128,8 +131,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
self.evaluate(get_next)
@test_util.run_v1_only("b/123902160")
def testSkipEagerOptimizationLargeInputFromTensorSlices(self):
# TODO(b/123902160)
@combinations.generate(test_base.graph_only_combinations())
def testOptimizationLargeInputFromTensorSlices(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
options = dataset_ops.Options()
@ -143,6 +147,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
self.evaluate(get_next)
@combinations.generate(test_base.default_test_combinations())
def testOptimizationNestedDataset(self):
def flat_map_fn(_):
@ -160,6 +165,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[0])
@combinations.generate(test_base.default_test_combinations())
def testOptimizationNestedDatasetWithModifiedRetval(self):
def flat_map_fn(_):
@ -179,6 +185,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset = dataset.with_options(options)
self.assertDatasetProduces(dataset, expected_output=[[0]])
@combinations.generate(test_base.default_test_combinations())
def testOptimizationThreadPoolDataset(self):
dataset = dataset_ops.Dataset.range(10).batch(10)
@ -195,9 +202,11 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output=[list(range(10))],
requires_initialization=True)
@parameterized.named_parameters(_generate_captured_refvar_test_cases())
@test_util.run_v1_only("RefVariables are not supported in eager mode.")
def testSkipEagerOptimizationWithCapturedRefVar(self, dataset_fn):
# Reference variables are not supported in eager mode.
@combinations.generate(
combinations.times(test_base.graph_only_combinations(),
_captured_refvar_test_combinations()))
def testOptimizationWithCapturedRefVar(self, dataset_fn):
"""Tests that default optimizations are disabled with ref variables."""
variable = variable_scope.get_variable(
"v", initializer=0, use_resource=False)
@ -241,6 +250,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
except errors.OutOfRangeError:
break
@combinations.generate(test_base.default_test_combinations())
def testOptimizationEnabledByDefault(self):
"""Tests that some optimizations are applied to datasets by default."""
options = dataset_ops.Options()
@ -252,6 +262,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(
set(options._graph_rewrites()), set(expected_optimizations))
@combinations.generate(test_base.default_test_combinations())
def testOptimizationDisableDefault(self):
"""Tests that we can disable all graph optimizations enabled by default.
@ -269,6 +280,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(
set(options._graph_rewrites()), set(expected_optimizations))
@combinations.generate(test_base.default_test_combinations())
def testAutotuningDefaults(self):
options = dataset_ops.Options()
@ -279,6 +291,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
optimization_options._AutotuneAlgorithm.HILL_CLIMB)
self.assertEqual(cpu_budget, 0)
@combinations.generate(test_base.default_test_combinations())
def testAutotuningBufferSizes(self):
options = dataset_ops.Options()
options.experimental_optimization.autotune_buffers = True

View File

@ -44,10 +44,8 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase):
for i in range(100):
self.assertEqual(i, self.evaluate(get_next()))
# TODO(b/123901304)
@combinations.generate(
combinations.combine(tf_api_version=[1], mode=["graph"]))
def testSkipEagerGPU(self):
@combinations.generate(test_base.graph_only_combinations())
def testGPU(self):
ds = dataset_ops.Dataset.range(100)
ds_variant = ds._variant_tensor # pylint: disable=protected-access
wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)