diff --git a/tensorflow/python/data/experimental/kernel_tests/assert_next_test.py b/tensorflow/python/data/experimental/kernel_tests/assert_next_test.py index c246122c92b..37d0f1586a4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/assert_next_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/assert_next_test.py @@ -17,17 +17,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class AssertNextTest(test_base.DatasetTestBase): +class AssertNextTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testAssertNext(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( testing.assert_next(["Map"])).map(lambda x: x) @@ -36,6 +38,7 @@ class AssertNextTest(test_base.DatasetTestBase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[0]) + @combinations.generate(test_base.default_test_combinations()) def testAssertNextInvalid(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( testing.assert_next(["Whoops"])).map(lambda x: x) @@ -49,6 +52,7 @@ class AssertNextTest(test_base.DatasetTestBase): "Asserted Whoops transformation at offset 0 but encountered " "Map transformation instead.")) + @combinations.generate(test_base.default_test_combinations()) def testAssertNextShort(self): dataset = dataset_ops.Dataset.from_tensors(0).apply( testing.assert_next(["Map", "Whoops"])).map(lambda x: x) diff --git a/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py index 993b511d5e3..904027a0de4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py @@ -17,21 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools + from absl.testing import parameterized from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase): - """Tests for `tf.data.experimental.cardinality()`.""" - - @parameterized.named_parameters( - # pylint: disable=g-long-lambda +def _test_combinations(): + # pylint: disable=g-long-lambda + cases = [ ("Batch1", lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2), ("Batch2", @@ -151,9 +150,24 @@ class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase): ("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range( 5), dataset_ops.Dataset.range(3).filter(lambda _: True))), cardinality.UNKNOWN), - # pylint: enable=g-long-lambda - ) - def testNumElements(self, dataset_fn, expected_result): + ] + + def reduce_fn(x, y): + name, dataset_fn, expected_result = y + return x + combinations.combine( + dataset_fn=combinations.NamedObject(name, dataset_fn), + expected_result=expected_result) + + return functools.reduce(reduce_fn, cases, []) + + +class CardinalityTest(test_base.DatasetTestBase, parameterized.TestCase): + """Tests for `tf.data.experimental.cardinality()`.""" + + @combinations.generate( + combinations.times(test_base.default_test_combinations(), + _test_combinations())) + def testCardinality(self, dataset_fn, expected_result): with self.cached_session() as sess: self.assertEqual( sess.run(cardinality.cardinality(dataset_fn())), expected_result) diff --git a/tensorflow/python/data/experimental/kernel_tests/model_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/model_dataset_test.py index 511990d6d27..634cf1aa2e8 100644 --- a/tensorflow/python/data/experimental/kernel_tests/model_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/model_dataset_test.py @@ -22,14 +22,14 @@ from absl.testing import parameterized from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import errors -from tensorflow.python.framework import test_util from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class ModelDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testAutotuneOption(self): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.map(lambda x: x).apply( diff --git a/tensorflow/python/data/experimental/kernel_tests/non_serializable_test.py b/tensorflow/python/data/experimental/kernel_tests/non_serializable_test.py index 7b07853384b..24b60ad9b35 100644 --- a/tensorflow/python/data/experimental/kernel_tests/non_serializable_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/non_serializable_test.py @@ -17,16 +17,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.framework import test_util +from tensorflow.python.framework import combinations from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes -class NonSerializableTest(test_base.DatasetTestBase): +class NonSerializableTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testNonSerializable(self): dataset = dataset_ops.Dataset.from_tensors(0) dataset = dataset.apply(testing.assert_next(["FiniteSkip"])) @@ -41,6 +43,7 @@ class NonSerializableTest(test_base.DatasetTestBase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[0]) + @combinations.generate(test_base.default_test_combinations()) def testNonSerializableAsDirectInput(self): """Tests that non-serializable dataset can be OptimizeDataset's input.""" dataset = dataset_ops.Dataset.from_tensors(0) diff --git a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py index 90c269a6825..59e41528ea4 100644 --- a/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/optimize_dataset_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import functools import warnings from absl.testing import parameterized @@ -30,23 +31,17 @@ from tensorflow.python.data.experimental.ops import testing from tensorflow.python.data.experimental.ops import threadpool from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -def _generate_captured_refvar_test_cases(): - """Generates testcases. - - Returns: - A list of tuples of (testcase_name, make_dataset_fn). make_dataset_fn takes - a tf.Variable as input and creates a test dataset that uses that variable. - """ +def _captured_refvar_test_combinations(): def make_map_dataset(var): return dataset_ops.Dataset.from_tensors(0).map(lambda x: x + var) @@ -88,7 +83,7 @@ def _generate_captured_refvar_test_cases(): scan_ops.scan( 0, lambda old_state, elem: (old_state + 1, elem + old_state + var))) - return [ + cases = [ # Core datasets ("Map", make_map_dataset), ("FlatMap", make_flat_map_dataset), @@ -100,10 +95,17 @@ def _generate_captured_refvar_test_cases(): ("Scan", make_scan_dataset) ] + def reduce_fn(x, y): + name, dataset_fn = y + return x + combinations.combine( + dataset_fn=combinations.NamedObject(name, dataset_fn)) + + return functools.reduce(reduce_fn, cases, []) + -@test_util.run_all_in_graph_and_eager_modes class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testOptimizationStatefulFunction(self): dataset = dataset_ops.Dataset.range( 10).map(lambda _: random_ops.random_uniform([])).batch(10) @@ -113,8 +115,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): get_next = self.getNext(dataset) self.evaluate(get_next()) - @test_util.run_v1_only("b/123902160") - def testSkipEagerOptimizationLargeInputFromTensor(self): + # TODO(b/123902160) + @combinations.generate(test_base.graph_only_combinations()) + def testOptimizationLargeInputFromTensor(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None)) dataset = dataset_ops.Dataset.from_tensors(input_t) options = dataset_ops.Options() @@ -128,8 +131,9 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)}) self.evaluate(get_next) - @test_util.run_v1_only("b/123902160") - def testSkipEagerOptimizationLargeInputFromTensorSlices(self): + # TODO(b/123902160) + @combinations.generate(test_base.graph_only_combinations()) + def testOptimizationLargeInputFromTensorSlices(self): input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None)) dataset = dataset_ops.Dataset.from_tensor_slices(input_t) options = dataset_ops.Options() @@ -143,6 +147,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)}) self.evaluate(get_next) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDataset(self): def flat_map_fn(_): @@ -160,6 +165,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[0]) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationNestedDatasetWithModifiedRetval(self): def flat_map_fn(_): @@ -179,6 +185,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): dataset = dataset.with_options(options) self.assertDatasetProduces(dataset, expected_output=[[0]]) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationThreadPoolDataset(self): dataset = dataset_ops.Dataset.range(10).batch(10) @@ -195,9 +202,11 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): expected_output=[list(range(10))], requires_initialization=True) - @parameterized.named_parameters(_generate_captured_refvar_test_cases()) - @test_util.run_v1_only("RefVariables are not supported in eager mode.") - def testSkipEagerOptimizationWithCapturedRefVar(self, dataset_fn): + # Reference variables are not supported in eager mode. + @combinations.generate( + combinations.times(test_base.graph_only_combinations(), + _captured_refvar_test_combinations())) + def testOptimizationWithCapturedRefVar(self, dataset_fn): """Tests that default optimizations are disabled with ref variables.""" variable = variable_scope.get_variable( "v", initializer=0, use_resource=False) @@ -241,6 +250,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): except errors.OutOfRangeError: break + @combinations.generate(test_base.default_test_combinations()) def testOptimizationEnabledByDefault(self): """Tests that some optimizations are applied to datasets by default.""" options = dataset_ops.Options() @@ -252,6 +262,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual( set(options._graph_rewrites()), set(expected_optimizations)) + @combinations.generate(test_base.default_test_combinations()) def testOptimizationDisableDefault(self): """Tests that we can disable all graph optimizations enabled by default. @@ -269,6 +280,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual( set(options._graph_rewrites()), set(expected_optimizations)) + @combinations.generate(test_base.default_test_combinations()) def testAutotuningDefaults(self): options = dataset_ops.Options() @@ -279,6 +291,7 @@ class OptimizeDatasetTest(test_base.DatasetTestBase, parameterized.TestCase): optimization_options._AutotuneAlgorithm.HILL_CLIMB) self.assertEqual(cpu_budget, 0) + @combinations.generate(test_base.default_test_combinations()) def testAutotuningBufferSizes(self): options = dataset_ops.Options() options.experimental_optimization.autotune_buffers = True diff --git a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py index 3fd252ab3ac..44c351ef2d2 100644 --- a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py @@ -44,10 +44,8 @@ class WrapDatasetVariantTest(test_base.DatasetTestBase, parameterized.TestCase): for i in range(100): self.assertEqual(i, self.evaluate(get_next())) - # TODO(b/123901304) - @combinations.generate( - combinations.combine(tf_api_version=[1], mode=["graph"])) - def testSkipEagerGPU(self): + @combinations.generate(test_base.graph_only_combinations()) + def testGPU(self): ds = dataset_ops.Dataset.range(100) ds_variant = ds._variant_tensor # pylint: disable=protected-access wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)