diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 72aee7b2806..6e0736479b4 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -818,7 +818,7 @@ void PruneFunctionBody(Graph* g) { // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is // still needed. It would be preferable to prune entire loops and/or // conditionals if they are not used in the graph. - if (n->IsControlFlow() || n->IsDataset() || + if (n->IsControlFlow() || (n->op_def().is_stateful() && n->type_string() != kArgOp)) { nodes.insert(n); } diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 00d3549312a..3ea222c13c5 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -85,10 +85,6 @@ const std::unordered_map& Node::kNodeClassTable = {"CollectiveBcastSend", NC_COLLECTIVE}, {"CollectiveBcastRecv", NC_COLLECTIVE}, {"FakeParam", NC_FAKE_PARAM}, - {"IteratorGetNext", NC_DATASET}, - {"IteratorGetNextSync", NC_DATASET}, - {"DatasetToSingleElement", NC_DATASET}, - {"ReduceDataset", NC_DATASET}, }); #undef REF_CLASS diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index f65e4b921ef..289a3d2a230 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -174,8 +174,6 @@ class Node { bool IsMetadata() const { return class_ == NC_METADATA; } bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; } - bool IsDataset() const { return class_ == NC_DATASET; } - template void AddAttr(const string& name, const T& val) { SetAttrValue(val, AddAttrHelper(name)); @@ -256,7 +254,6 @@ class Node { NC_SCOPED_ALLOCATOR, NC_COLLECTIVE, NC_FAKE_PARAM, - NC_DATASET, NC_OTHER // Not a special kind of node }; diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 872a6da915b..cc7ce542579 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -554,13 +554,22 @@ REGISTER_OP("IteratorGetNextSync") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(IteratorGetNextShapeFn); +// TODO(b/124308596): Instead of conservatively marking this op as stateful, +// implement a mechanism to determine whether `dataset` has a side-effect +// and use it to decide whether to use a stateless or stateful version of this +// op. REGISTER_OP("DatasetToSingleElement") .Input("dataset: variant") .Output("components: output_types") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() .SetShapeFn(IteratorGetNextShapeFn); +// TODO(b/124308596): Instead of conservatively marking this op as stateful, +// implement a mechanism to determine whether `dataset` has a side-effect +// and use it to decide whether to use a stateless or stateful version of this +// op. REGISTER_OP("ReduceDataset") .Input("input_dataset: variant") .Input("initial_state: Tstate") @@ -572,6 +581,7 @@ REGISTER_OP("ReduceDataset") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("use_inter_op_parallelism: bool = true") + .SetIsStateful() .SetShapeFn(IteratorGetNextShapeFn); REGISTER_OP("IteratorToStringHandle") @@ -652,6 +662,8 @@ REGISTER_OP("ModelDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +// TODO(b/124308749): Add a stateful version of MapDefun and use it when `f` +// is stateful. REGISTER_OP("MapDefun") .Input("arguments: Targuments") .Input("captured_inputs: Tcaptured") diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 95230af5798..7b9d95a38d1 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -76,10 +76,15 @@ REGISTER_OP("ExperimentalDatasetCardinality") .Output("cardinality: int64") .SetShapeFn(shape_inference::ScalarShape); +// TODO(b/124308596): Instead of conservatively marking this op as stateful, +// implement a mechanism to determine whether `dataset` has a side-effect +// and use it to decide whether to use a stateless or stateful version of this +// op. REGISTER_OP("ExperimentalDatasetToTFRecord") .Input("input_dataset: variant") .Input("filename: string") .Input("compression_type: string") + .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs); REGISTER_OP("ExperimentalDenseToSparseBatchDataset") diff --git a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py index 3e2cf779a3f..f65740c5651 100644 --- a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py @@ -22,10 +22,12 @@ from absl.testing import parameterized from tensorflow.python.data.experimental.ops import get_single_element from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import function from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -71,6 +73,52 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertDatasetProduces( dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) + def testSideEffect(self): + counter_var = variables.Variable(0) + + def increment_fn(x): + counter_var.assign_add(1) + return x + + def dataset_fn(): + return dataset_ops.Dataset.range(1).map(increment_fn) + + @function.defun + def fn(): + _ = get_single_element.get_single_element(dataset_fn()) + return "hello" + + self.evaluate(counter_var.initializer) + self.assertEqual(self.evaluate(fn()), b"hello") + self.assertEqual(self.evaluate(counter_var), 1) + + def testAutomaticControlDependencies(self): + counter_var = variables.Variable(1) + + def increment_fn(x): + counter_var.assign(counter_var + 1) + return x + + def multiply_fn(x): + counter_var.assign(counter_var * 2) + return x + + def dataset1_fn(): + return dataset_ops.Dataset.range(1).map(increment_fn) + + def dataset2_fn(): + return dataset_ops.Dataset.range(1).map(multiply_fn) + + @function.defun + def fn(): + _ = get_single_element.get_single_element(dataset1_fn()) + _ = get_single_element.get_single_element(dataset2_fn()) + return "hello" + + self.evaluate(counter_var.initializer) + self.assertEqual(self.evaluate(fn()), b"hello") + self.assertEqual(self.evaluate(counter_var), 4) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py index 14a4241ec2e..783b2e6e22a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py @@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.ops import writers from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import readers +from tensorflow.python.eager import function from tensorflow.python.framework import test_util from tensorflow.python.lib.io import python_io from tensorflow.python.lib.io import tf_record @@ -94,6 +95,20 @@ class TFRecordWriterTest(test_base.DatasetTestBase): with self.assertRaises(TypeError): writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset) + def testSideEffect(self): + def writer_fn(): + input_dataset = readers.TFRecordDataset(self._createFile()) + return writers.TFRecordWriter(self._outputFilename()).write(input_dataset) + + @function.defun + def fn(): + _ = writer_fn() + return "hello" + + self.assertEqual(self.evaluate(fn()), b"hello") + for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())): + self.assertAllEqual(self._record(i), r) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/reduce_test.py b/tensorflow/python/data/kernel_tests/reduce_test.py index 93acc1565fd..846d9a6cef9 100644 --- a/tensorflow/python/data/kernel_tests/reduce_test.py +++ b/tensorflow/python/data/kernel_tests/reduce_test.py @@ -22,12 +22,14 @@ import numpy as np from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -123,6 +125,71 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertEqual(((i + 1) * i) // 2, result["dense"]) self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"]) + def testDatasetSideEffect(self): + counter_var = variables.Variable(0) + + def increment_fn(x): + counter_var.assign_add(1) + return x + + def dataset_fn(): + return dataset_ops.Dataset.range(10).map(increment_fn) + + def reduce_fn(state, value): + return state + value + + @function.defun + def fn(): + _ = dataset_fn().reduce(np.int64(0), reduce_fn) + return "hello" + + self.evaluate(counter_var.initializer) + self.assertEqual(self.evaluate(fn()), b"hello") + self.assertEqual(self.evaluate(counter_var), 10) + + def testSideEffect(self): + counter_var = variables.Variable(0) + + def dataset_fn(): + return dataset_ops.Dataset.range(10) + + def reduce_fn(state, value): + counter_var.assign_add(1) + return state + value + + @function.defun + def fn(): + _ = dataset_fn().reduce(np.int64(0), reduce_fn) + return "hello" + + self.evaluate(counter_var.initializer) + self.assertEqual(self.evaluate(fn()), b"hello") + self.assertEqual(self.evaluate(counter_var), 10) + + def testAutomaticControlDependencies(self): + counter_var = variables.Variable(1) + + def dataset_fn(): + return dataset_ops.Dataset.range(1) + + def reduce1_fn(state, value): + counter_var.assign(counter_var + 1) + return state + value + + def reduce2_fn(state, value): + counter_var.assign(counter_var * 2) + return state + value + + @function.defun + def fn(): + _ = dataset_fn().reduce(np.int64(0), reduce1_fn) + _ = dataset_fn().reduce(np.int64(0), reduce2_fn) + return "hello" + + self.evaluate(counter_var.initializer) + self.assertEqual(self.evaluate(fn()), b"hello") + self.assertEqual(self.evaluate(counter_var), 4) + if __name__ == "__main__": test.main()