[tf.data] Disable application of options in reduce until optimizing an input pipeline in tf.function is supported.

PiperOrigin-RevId: 269892020
This commit is contained in:
Jiri Simsa 2019-09-18 14:33:32 -07:00 committed by TensorFlower Gardener
parent 5ff1b3dc30
commit b45e94d159
3 changed files with 38 additions and 6 deletions

View File

@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
@ -38,15 +39,16 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testSum(self): def testSum(self):
for i in range(10): for i in range(10):
ds = dataset_ops.Dataset.range(1, i + 1) ds = dataset_ops.Dataset.range(1, i + 1)
result = ds.reduce(np.int64(0), lambda x, y: x + y) result = ds.reduce(np.int64(0), lambda x, y: x + y)
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result)) self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
@combinations.generate(test_base.default_test_combinations())
def testSumTuple(self): def testSumTuple(self):
def reduce_fn(state, value): def reduce_fn(state, value):
@ -59,6 +61,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn) result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
self.assertEqual(((i + 1) * i), self.evaluate(result)) self.assertEqual(((i + 1) * i), self.evaluate(result))
@combinations.generate(test_base.default_test_combinations())
def testSumAndCount(self): def testSumAndCount(self):
def reduce_fn(state, value): def reduce_fn(state, value):
@ -74,8 +77,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(((i + 1) * i) // 2, s) self.assertEqual(((i + 1) * i) // 2, s)
self.assertEqual(i, c) self.assertEqual(i, c)
@test_util.run_v1_only("graph-mode specific test") @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testSkipEagerSquareUsingPlaceholder(self): def testSquareUsingPlaceholder(self):
delta = array_ops.placeholder(dtype=dtypes.int64) delta = array_ops.placeholder(dtype=dtypes.int64)
def reduce_fn(state, _): def reduce_fn(state, _):
@ -88,6 +91,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
square = sess.run(result, feed_dict={delta: i}) square = sess.run(result, feed_dict={delta: i})
self.assertEqual(i * i, square) self.assertEqual(i * i, square)
@combinations.generate(test_base.default_test_combinations())
def testSparse(self): def testSparse(self):
def reduce_fn(_, value): def reduce_fn(_, value):
@ -104,6 +108,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(make_sparse_fn(0), reduce_fn) result = ds.reduce(make_sparse_fn(0), reduce_fn)
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result)) self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
@combinations.generate(test_base.default_test_combinations())
def testNested(self): def testNested(self):
def reduce_fn(state, value): def reduce_fn(state, value):
@ -128,6 +133,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(((i + 1) * i) // 2, result["dense"]) self.assertEqual(((i + 1) * i) // 2, result["dense"])
self.assertValuesEqual(make_sparse_fn(i), result["sparse"]) self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
@combinations.generate(test_base.default_test_combinations())
def testDatasetSideEffect(self): def testDatasetSideEffect(self):
counter_var = variables.Variable(0) counter_var = variables.Variable(0)
@ -150,6 +156,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 10) self.assertEqual(self.evaluate(counter_var), 10)
@combinations.generate(test_base.default_test_combinations())
def testSideEffect(self): def testSideEffect(self):
counter_var = variables.Variable(0) counter_var = variables.Variable(0)
@ -169,6 +176,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 10) self.assertEqual(self.evaluate(counter_var), 10)
@combinations.generate(test_base.default_test_combinations())
def testAutomaticControlDependencies(self): def testAutomaticControlDependencies(self):
counter_var = variables.Variable(1) counter_var = variables.Variable(1)
@ -193,6 +201,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello") self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 4) self.assertEqual(self.evaluate(counter_var), 4)
@combinations.generate(test_base.default_test_combinations())
def testStateOnGPU(self): def testStateOnGPU(self):
if not test_util.is_gpu_available(): if not test_util.is_gpu_available():
self.skipTest("No GPUs available.") self.skipTest("No GPUs available.")
@ -208,8 +217,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(state, reduce_fn) result = ds.reduce(state, reduce_fn)
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result)) self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
@test_util.run_v1_only("graph-mode specific test") @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testSkipEagerCancellation(self): def testCancellation(self):
ds = dataset_ops.Dataset.from_tensors(1).repeat() ds = dataset_ops.Dataset.from_tensors(1).repeat()
result = ds.reduce(0, lambda x, y: x + y) result = ds.reduce(0, lambda x, y: x + y)
with self.cached_session() as sess: with self.cached_session() as sess:
@ -221,12 +230,15 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.close() sess.close()
thread.join() thread.join()
@combinations.generate(test_base.default_test_combinations())
def testInvalidFunction(self): def testInvalidFunction(self):
ds = dataset_ops.Dataset.range(5) ds = dataset_ops.Dataset.range(5)
with self.assertRaises(errors.InvalidArgumentError): with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(ds.reduce(0, lambda _, __: ())) self.evaluate(ds.reduce(0, lambda _, __: ()))
@combinations.generate(test_base.default_test_combinations())
def testOptions(self): def testOptions(self):
self.skipTest("b/141256846")
dataset = dataset_ops.Dataset.range(5) dataset = dataset_ops.Dataset.range(5)
dataset = dataset.apply(optimization.assert_next(["MapAndBatch"])) dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
dataset = dataset.map(lambda x: x).batch(5) dataset = dataset.map(lambda x: x).batch(5)

View File

@ -25,12 +25,14 @@ import numpy as np
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -292,6 +294,21 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(first_epoch != second_epoch, seed is None) self.assertEqual(first_epoch != second_epoch, seed is None)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
def testShuffleV2InFunction(self):
self.skipTest("b/141256846")
counter_var = variables.Variable(0)
@function.defun
def consume():
ds = dataset_ops.Dataset.range(10)
ds = ds.shuffle(1)
for _ in ds:
counter_var.assign(counter_var + 1)
consume()
self.assertAllEqual(self.evaluate(counter_var), 10)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()

View File

@ -1609,7 +1609,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
reduce_func = wrapped_func.function reduce_func = wrapped_func.function
reduce_func.add_to_graph(ops.get_default_graph()) reduce_func.add_to_graph(ops.get_default_graph())
dataset = self._apply_options() # TODO(b/141256846): Apply options once optimizing stateful input pipelines
# in tf.functions is supported.
# dataset = self._apply_options()
dataset = self
# pylint: disable=protected-access # pylint: disable=protected-access
return structure.from_compatible_tensor_list( return structure.from_compatible_tensor_list(