[tf.data] Disable application of options in reduce
until optimizing an input pipeline in tf.function is supported.
PiperOrigin-RevId: 269892020
This commit is contained in:
parent
5ff1b3dc30
commit
b45e94d159
@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import optimization
|
|||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
@ -38,15 +39,16 @@ from tensorflow.python.ops import variables
|
|||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
|
||||||
class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testSum(self):
|
def testSum(self):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||||
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
||||||
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testSumTuple(self):
|
def testSumTuple(self):
|
||||||
|
|
||||||
def reduce_fn(state, value):
|
def reduce_fn(state, value):
|
||||||
@ -59,6 +61,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
|
result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
|
||||||
self.assertEqual(((i + 1) * i), self.evaluate(result))
|
self.assertEqual(((i + 1) * i), self.evaluate(result))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testSumAndCount(self):
|
def testSumAndCount(self):
|
||||||
|
|
||||||
def reduce_fn(state, value):
|
def reduce_fn(state, value):
|
||||||
@ -74,8 +77,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(((i + 1) * i) // 2, s)
|
self.assertEqual(((i + 1) * i) // 2, s)
|
||||||
self.assertEqual(i, c)
|
self.assertEqual(i, c)
|
||||||
|
|
||||||
@test_util.run_v1_only("graph-mode specific test")
|
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
|
||||||
def testSkipEagerSquareUsingPlaceholder(self):
|
def testSquareUsingPlaceholder(self):
|
||||||
delta = array_ops.placeholder(dtype=dtypes.int64)
|
delta = array_ops.placeholder(dtype=dtypes.int64)
|
||||||
|
|
||||||
def reduce_fn(state, _):
|
def reduce_fn(state, _):
|
||||||
@ -88,6 +91,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
square = sess.run(result, feed_dict={delta: i})
|
square = sess.run(result, feed_dict={delta: i})
|
||||||
self.assertEqual(i * i, square)
|
self.assertEqual(i * i, square)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testSparse(self):
|
def testSparse(self):
|
||||||
|
|
||||||
def reduce_fn(_, value):
|
def reduce_fn(_, value):
|
||||||
@ -104,6 +108,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
||||||
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
|
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testNested(self):
|
def testNested(self):
|
||||||
|
|
||||||
def reduce_fn(state, value):
|
def reduce_fn(state, value):
|
||||||
@ -128,6 +133,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
||||||
self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
|
self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testDatasetSideEffect(self):
|
def testDatasetSideEffect(self):
|
||||||
counter_var = variables.Variable(0)
|
counter_var = variables.Variable(0)
|
||||||
|
|
||||||
@ -150,6 +156,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||||
self.assertEqual(self.evaluate(counter_var), 10)
|
self.assertEqual(self.evaluate(counter_var), 10)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testSideEffect(self):
|
def testSideEffect(self):
|
||||||
counter_var = variables.Variable(0)
|
counter_var = variables.Variable(0)
|
||||||
|
|
||||||
@ -169,6 +176,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||||
self.assertEqual(self.evaluate(counter_var), 10)
|
self.assertEqual(self.evaluate(counter_var), 10)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testAutomaticControlDependencies(self):
|
def testAutomaticControlDependencies(self):
|
||||||
counter_var = variables.Variable(1)
|
counter_var = variables.Variable(1)
|
||||||
|
|
||||||
@ -193,6 +201,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||||
self.assertEqual(self.evaluate(counter_var), 4)
|
self.assertEqual(self.evaluate(counter_var), 4)
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testStateOnGPU(self):
|
def testStateOnGPU(self):
|
||||||
if not test_util.is_gpu_available():
|
if not test_util.is_gpu_available():
|
||||||
self.skipTest("No GPUs available.")
|
self.skipTest("No GPUs available.")
|
||||||
@ -208,8 +217,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
result = ds.reduce(state, reduce_fn)
|
result = ds.reduce(state, reduce_fn)
|
||||||
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||||
|
|
||||||
@test_util.run_v1_only("graph-mode specific test")
|
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
|
||||||
def testSkipEagerCancellation(self):
|
def testCancellation(self):
|
||||||
ds = dataset_ops.Dataset.from_tensors(1).repeat()
|
ds = dataset_ops.Dataset.from_tensors(1).repeat()
|
||||||
result = ds.reduce(0, lambda x, y: x + y)
|
result = ds.reduce(0, lambda x, y: x + y)
|
||||||
with self.cached_session() as sess:
|
with self.cached_session() as sess:
|
||||||
@ -221,12 +230,15 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
sess.close()
|
sess.close()
|
||||||
thread.join()
|
thread.join()
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testInvalidFunction(self):
|
def testInvalidFunction(self):
|
||||||
ds = dataset_ops.Dataset.range(5)
|
ds = dataset_ops.Dataset.range(5)
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
self.evaluate(ds.reduce(0, lambda _, __: ()))
|
self.evaluate(ds.reduce(0, lambda _, __: ()))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def testOptions(self):
|
def testOptions(self):
|
||||||
|
self.skipTest("b/141256846")
|
||||||
dataset = dataset_ops.Dataset.range(5)
|
dataset = dataset_ops.Dataset.range(5)
|
||||||
dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
|
dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
|
||||||
dataset = dataset.map(lambda x: x).batch(5)
|
dataset = dataset.map(lambda x: x).batch(5)
|
||||||
|
@ -25,12 +25,14 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import random_seed
|
from tensorflow.python.framework import random_seed
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -292,6 +294,21 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(first_epoch != second_epoch, seed is None)
|
self.assertEqual(first_epoch != second_epoch, seed is None)
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||||
|
def testShuffleV2InFunction(self):
|
||||||
|
self.skipTest("b/141256846")
|
||||||
|
counter_var = variables.Variable(0)
|
||||||
|
|
||||||
|
@function.defun
|
||||||
|
def consume():
|
||||||
|
ds = dataset_ops.Dataset.range(10)
|
||||||
|
ds = ds.shuffle(1)
|
||||||
|
for _ in ds:
|
||||||
|
counter_var.assign(counter_var + 1)
|
||||||
|
|
||||||
|
consume()
|
||||||
|
self.assertAllEqual(self.evaluate(counter_var), 10)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -1609,7 +1609,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
|||||||
reduce_func = wrapped_func.function
|
reduce_func = wrapped_func.function
|
||||||
reduce_func.add_to_graph(ops.get_default_graph())
|
reduce_func.add_to_graph(ops.get_default_graph())
|
||||||
|
|
||||||
dataset = self._apply_options()
|
# TODO(b/141256846): Apply options once optimizing stateful input pipelines
|
||||||
|
# in tf.functions is supported.
|
||||||
|
# dataset = self._apply_options()
|
||||||
|
dataset = self
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return structure.from_compatible_tensor_list(
|
return structure.from_compatible_tensor_list(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user