[tf.data] Disable application of options in reduce until optimizing an input pipeline in tf.function is supported.

PiperOrigin-RevId: 269892020
This commit is contained in:
Jiri Simsa 2019-09-18 14:33:32 -07:00 committed by TensorFlower Gardener
parent 5ff1b3dc30
commit b45e94d159
3 changed files with 38 additions and 6 deletions

View File

@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@ -38,15 +39,16 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testSum(self):
for i in range(10):
ds = dataset_ops.Dataset.range(1, i + 1)
result = ds.reduce(np.int64(0), lambda x, y: x + y)
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
@combinations.generate(test_base.default_test_combinations())
def testSumTuple(self):
def reduce_fn(state, value):
@ -59,6 +61,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
self.assertEqual(((i + 1) * i), self.evaluate(result))
@combinations.generate(test_base.default_test_combinations())
def testSumAndCount(self):
def reduce_fn(state, value):
@ -74,8 +77,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(((i + 1) * i) // 2, s)
self.assertEqual(i, c)
@test_util.run_v1_only("graph-mode specific test")
def testSkipEagerSquareUsingPlaceholder(self):
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testSquareUsingPlaceholder(self):
delta = array_ops.placeholder(dtype=dtypes.int64)
def reduce_fn(state, _):
@ -88,6 +91,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
square = sess.run(result, feed_dict={delta: i})
self.assertEqual(i * i, square)
@combinations.generate(test_base.default_test_combinations())
def testSparse(self):
def reduce_fn(_, value):
@ -104,6 +108,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(make_sparse_fn(0), reduce_fn)
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
@combinations.generate(test_base.default_test_combinations())
def testNested(self):
def reduce_fn(state, value):
@ -128,6 +133,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(((i + 1) * i) // 2, result["dense"])
self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
@combinations.generate(test_base.default_test_combinations())
def testDatasetSideEffect(self):
counter_var = variables.Variable(0)
@ -150,6 +156,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 10)
@combinations.generate(test_base.default_test_combinations())
def testSideEffect(self):
counter_var = variables.Variable(0)
@ -169,6 +176,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 10)
@combinations.generate(test_base.default_test_combinations())
def testAutomaticControlDependencies(self):
counter_var = variables.Variable(1)
@ -193,6 +201,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 4)
@combinations.generate(test_base.default_test_combinations())
def testStateOnGPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPUs available.")
@ -208,8 +217,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
result = ds.reduce(state, reduce_fn)
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
@test_util.run_v1_only("graph-mode specific test")
def testSkipEagerCancellation(self):
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testCancellation(self):
ds = dataset_ops.Dataset.from_tensors(1).repeat()
result = ds.reduce(0, lambda x, y: x + y)
with self.cached_session() as sess:
@ -221,12 +230,15 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
sess.close()
thread.join()
@combinations.generate(test_base.default_test_combinations())
def testInvalidFunction(self):
ds = dataset_ops.Dataset.range(5)
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(ds.reduce(0, lambda _, __: ()))
@combinations.generate(test_base.default_test_combinations())
def testOptions(self):
self.skipTest("b/141256846")
dataset = dataset_ops.Dataset.range(5)
dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
dataset = dataset.map(lambda x: x).batch(5)

View File

@ -25,12 +25,14 @@ import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -292,6 +294,21 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(first_epoch != second_epoch, seed is None)
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
def testShuffleV2InFunction(self):
self.skipTest("b/141256846")
counter_var = variables.Variable(0)
@function.defun
def consume():
ds = dataset_ops.Dataset.range(10)
ds = ds.shuffle(1)
for _ in ds:
counter_var.assign(counter_var + 1)
consume()
self.assertAllEqual(self.evaluate(counter_var), 10)
if __name__ == "__main__":
test.main()

View File

@ -1609,7 +1609,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
reduce_func = wrapped_func.function
reduce_func.add_to_graph(ops.get_default_graph())
dataset = self._apply_options()
# TODO(b/141256846): Apply options once optimizing stateful input pipelines
# in tf.functions is supported.
# dataset = self._apply_options()
dataset = self
# pylint: disable=protected-access
return structure.from_compatible_tensor_list(