[tf.data] Disable application of options in reduce
until optimizing an input pipeline in tf.function is supported.
PiperOrigin-RevId: 269892020
This commit is contained in:
parent
5ff1b3dc30
commit
b45e94d159
@ -26,6 +26,7 @@ from tensorflow.python.data.experimental.ops import optimization
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
@ -38,15 +39,16 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSum(self):
|
||||
for i in range(10):
|
||||
ds = dataset_ops.Dataset.range(1, i + 1)
|
||||
result = ds.reduce(np.int64(0), lambda x, y: x + y)
|
||||
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSumTuple(self):
|
||||
|
||||
def reduce_fn(state, value):
|
||||
@ -59,6 +61,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
|
||||
self.assertEqual(((i + 1) * i), self.evaluate(result))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSumAndCount(self):
|
||||
|
||||
def reduce_fn(state, value):
|
||||
@ -74,8 +77,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(((i + 1) * i) // 2, s)
|
||||
self.assertEqual(i, c)
|
||||
|
||||
@test_util.run_v1_only("graph-mode specific test")
|
||||
def testSkipEagerSquareUsingPlaceholder(self):
|
||||
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
|
||||
def testSquareUsingPlaceholder(self):
|
||||
delta = array_ops.placeholder(dtype=dtypes.int64)
|
||||
|
||||
def reduce_fn(state, _):
|
||||
@ -88,6 +91,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
square = sess.run(result, feed_dict={delta: i})
|
||||
self.assertEqual(i * i, square)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSparse(self):
|
||||
|
||||
def reduce_fn(_, value):
|
||||
@ -104,6 +108,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
result = ds.reduce(make_sparse_fn(0), reduce_fn)
|
||||
self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testNested(self):
|
||||
|
||||
def reduce_fn(state, value):
|
||||
@ -128,6 +133,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(((i + 1) * i) // 2, result["dense"])
|
||||
self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testDatasetSideEffect(self):
|
||||
counter_var = variables.Variable(0)
|
||||
|
||||
@ -150,6 +156,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||
self.assertEqual(self.evaluate(counter_var), 10)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSideEffect(self):
|
||||
counter_var = variables.Variable(0)
|
||||
|
||||
@ -169,6 +176,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||
self.assertEqual(self.evaluate(counter_var), 10)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testAutomaticControlDependencies(self):
|
||||
counter_var = variables.Variable(1)
|
||||
|
||||
@ -193,6 +201,7 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(self.evaluate(fn()), b"hello")
|
||||
self.assertEqual(self.evaluate(counter_var), 4)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testStateOnGPU(self):
|
||||
if not test_util.is_gpu_available():
|
||||
self.skipTest("No GPUs available.")
|
||||
@ -208,8 +217,8 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
result = ds.reduce(state, reduce_fn)
|
||||
self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
|
||||
|
||||
@test_util.run_v1_only("graph-mode specific test")
|
||||
def testSkipEagerCancellation(self):
|
||||
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
|
||||
def testCancellation(self):
|
||||
ds = dataset_ops.Dataset.from_tensors(1).repeat()
|
||||
result = ds.reduce(0, lambda x, y: x + y)
|
||||
with self.cached_session() as sess:
|
||||
@ -221,12 +230,15 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
sess.close()
|
||||
thread.join()
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testInvalidFunction(self):
|
||||
ds = dataset_ops.Dataset.range(5)
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(ds.reduce(0, lambda _, __: ()))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testOptions(self):
|
||||
self.skipTest("b/141256846")
|
||||
dataset = dataset_ops.Dataset.range(5)
|
||||
dataset = dataset.apply(optimization.assert_next(["MapAndBatch"]))
|
||||
dataset = dataset.map(lambda x: x).batch(5)
|
||||
|
@ -25,12 +25,14 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -292,6 +294,21 @@ class ShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
self.assertEqual(first_epoch != second_epoch, seed is None)
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=2, mode="eager"))
|
||||
def testShuffleV2InFunction(self):
|
||||
self.skipTest("b/141256846")
|
||||
counter_var = variables.Variable(0)
|
||||
|
||||
@function.defun
|
||||
def consume():
|
||||
ds = dataset_ops.Dataset.range(10)
|
||||
ds = ds.shuffle(1)
|
||||
for _ in ds:
|
||||
counter_var.assign(counter_var + 1)
|
||||
|
||||
consume()
|
||||
self.assertAllEqual(self.evaluate(counter_var), 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -1609,7 +1609,10 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
reduce_func = wrapped_func.function
|
||||
reduce_func.add_to_graph(ops.get_default_graph())
|
||||
|
||||
dataset = self._apply_options()
|
||||
# TODO(b/141256846): Apply options once optimizing stateful input pipelines
|
||||
# in tf.functions is supported.
|
||||
# dataset = self._apply_options()
|
||||
dataset = self
|
||||
|
||||
# pylint: disable=protected-access
|
||||
return structure.from_compatible_tensor_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user