diff --git a/tensorflow/python/data/kernel_tests/reduce_test.py b/tensorflow/python/data/kernel_tests/reduce_test.py index 06f565d6d77..d1cb5a9c594 100644 --- a/tensorflow/python/data/kernel_tests/reduce_test.py +++ b/tensorflow/python/data/kernel_tests/reduce_test.py @@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import optimization from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.eager import function +from tensorflow.python.framework import combinations from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -38,15 +39,16 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test -@test_util.run_all_in_graph_and_eager_modes class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): + @combinations.generate(test_base.default_test_combinations()) def testSum(self): for i in range(10): ds = dataset_ops.Dataset.range(1, i + 1) result = ds.reduce(np.int64(0), lambda x, y: x + y) self.assertEqual(((i + 1) * i) // 2, self.evaluate(result)) + @combinations.generate(test_base.default_test_combinations()) def testSumTuple(self): def reduce_fn(state, value): @@ -59,6 +61,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn) self.assertEqual(((i + 1) * i), self.evaluate(result)) + @combinations.generate(test_base.default_test_combinations()) def testSumAndCount(self): def reduce_fn(state, value): @@ -74,8 +77,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(((i + 1) * i) // 2, s) self.assertEqual(i, c) - @test_util.run_v1_only("graph-mode specific test") - def testSkipEagerSquareUsingPlaceholder(self): + @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) + def testSquareUsingPlaceholder(self): delta = array_ops.placeholder(dtype=dtypes.int64) def reduce_fn(state, _): @@ -88,6 +91,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): square = sess.run(result, feed_dict={delta: i}) self.assertEqual(i * i, square) + @combinations.generate(test_base.default_test_combinations()) def testSparse(self): def reduce_fn(_, value): @@ -104,6 +108,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): result = ds.reduce(make_sparse_fn(0), reduce_fn) self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result)) + @combinations.generate(test_base.default_test_combinations()) def testNested(self): def reduce_fn(state, value): @@ -128,6 +133,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(((i + 1) * i) // 2, result["dense"]) self.assertValuesEqual(make_sparse_fn(i), result["sparse"]) + @combinations.generate(test_base.default_test_combinations()) def testDatasetSideEffect(self): counter_var = variables.Variable(0) @@ -150,6 +156,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(counter_var), 10) + @combinations.generate(test_base.default_test_combinations()) def testSideEffect(self): counter_var = variables.Variable(0) @@ -169,6 +176,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(counter_var), 10) + @combinations.generate(test_base.default_test_combinations()) def testAutomaticControlDependencies(self): counter_var = variables.Variable(1) @@ -193,6 +201,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(counter_var), 4) + @combinations.generate(test_base.default_test_combinations()) def testStateOnGPU(self): if not test_util.is_gpu_available(): self.skipTest("No GPUs available.") @@ -208,8 +217,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): result = ds.reduce(state, reduce_fn) self.assertEqual(((i + 1) * i) // 2, self.evaluate(result)) - @test_util.run_v1_only("graph-mode specific test") - def testSkipEagerCancellation(self): + @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) + def testCancellation(self): ds = dataset_ops.Dataset.from_tensors(1).repeat() result = ds.reduce(0, lambda x, y: x + y) with self.cached_session() as sess: @@ -221,12 +230,15 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): sess.close() thread.join() + @combinations.generate(test_base.default_test_combinations()) def testInvalidFunction(self): ds = dataset_ops.Dataset.range(5) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(ds.reduce(0, lambda _, __: ())) + @combinations.generate(test_base.default_test_combinations()) def testOptions(self): + self.skipTest("b/141256846") dataset = dataset_ops.Dataset.range(5) dataset = dataset.apply(optimization.assert_next(["MapAndBatch"])) dataset = dataset.map(lambda x: x).batch(5) diff --git a/tensorflow/python/data/kernel_tests/shuffle_test.py b/tensorflow/python/data/kernel_tests/shuffle_test.py index d1846e4eaeb..9f1c3ed8161 100644 --- a/tensorflow/python/data/kernel_tests/shuffle_test.py +++ b/tensorflow/python/data/kernel_tests/shuffle_test.py @@ -25,12 +25,14 @@ import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import function from tensorflow.python.framework import combinations from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -292,6 +294,21 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(first_epoch != second_epoch, seed is None) + @combinations.generate(combinations.combine(tf_api_version=2, mode="eager")) + def testShuffleV2InFunction(self): + self.skipTest("b/141256846") + counter_var = variables.Variable(0) + + @function.defun + def consume(): + ds = dataset_ops.Dataset.range(10) + ds = ds.shuffle(1) + for _ in ds: + counter_var.assign(counter_var + 1) + + consume() + self.assertAllEqual(self.evaluate(counter_var), 10) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 110f81d2dce..264c3a7ca3b 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -1609,7 +1609,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): reduce_func = wrapped_func.function reduce_func.add_to_graph(ops.get_default_graph()) - dataset = self._apply_options() + # TODO(b/141256846): Apply options once optimizing stateful input pipelines + # in tf.functions is supported. + # dataset = self._apply_options() + dataset = self # pylint: disable=protected-access return structure.from_compatible_tensor_list(