[tf.data] Marking dataset ops that consume a dataset without an iterator as stateful to make sure they are not pruned from the graph in case their output is not used.

This is a conservative approach to guarantee that any side-effects of the op are carried out.

This CL also reverts a previous (incomplete) solution to the same problem.

PiperOrigin-RevId: 233663631
This commit is contained in:
Jiri Simsa 2019-02-12 13:22:55 -08:00 committed by TensorFlower Gardener
parent d7b3e49283
commit 5702b86c96
8 changed files with 148 additions and 8 deletions

View File

@ -818,7 +818,7 @@ void PruneFunctionBody(Graph* g) {
// TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
// still needed. It would be preferable to prune entire loops and/or
// conditionals if they are not used in the graph.
if (n->IsControlFlow() || n->IsDataset() ||
if (n->IsControlFlow() ||
(n->op_def().is_stateful() && n->type_string() != kArgOp)) {
nodes.insert(n);
}

View File

@ -85,10 +85,6 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
{"CollectiveBcastSend", NC_COLLECTIVE},
{"CollectiveBcastRecv", NC_COLLECTIVE},
{"FakeParam", NC_FAKE_PARAM},
{"IteratorGetNext", NC_DATASET},
{"IteratorGetNextSync", NC_DATASET},
{"DatasetToSingleElement", NC_DATASET},
{"ReduceDataset", NC_DATASET},
});
#undef REF_CLASS

View File

@ -174,8 +174,6 @@ class Node {
bool IsMetadata() const { return class_ == NC_METADATA; }
bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
bool IsDataset() const { return class_ == NC_DATASET; }
template <typename T>
void AddAttr(const string& name, const T& val) {
SetAttrValue(val, AddAttrHelper(name));
@ -256,7 +254,6 @@ class Node {
NC_SCOPED_ALLOCATOR,
NC_COLLECTIVE,
NC_FAKE_PARAM,
NC_DATASET,
NC_OTHER // Not a special kind of node
};

View File

@ -554,13 +554,22 @@ REGISTER_OP("IteratorGetNextSync")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("DatasetToSingleElement")
.Input("dataset: variant")
.Output("components: output_types")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetIsStateful()
.SetShapeFn(IteratorGetNextShapeFn);
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("ReduceDataset")
.Input("input_dataset: variant")
.Input("initial_state: Tstate")
@ -572,6 +581,7 @@ REGISTER_OP("ReduceDataset")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("use_inter_op_parallelism: bool = true")
.SetIsStateful()
.SetShapeFn(IteratorGetNextShapeFn);
REGISTER_OP("IteratorToStringHandle")
@ -652,6 +662,8 @@ REGISTER_OP("ModelDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
// TODO(b/124308749): Add a stateful version of MapDefun and use it when `f`
// is stateful.
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
.Input("captured_inputs: Tcaptured")

View File

@ -76,10 +76,15 @@ REGISTER_OP("ExperimentalDatasetCardinality")
.Output("cardinality: int64")
.SetShapeFn(shape_inference::ScalarShape);
// TODO(b/124308596): Instead of conservatively marking this op as stateful,
// implement a mechanism to determine whether `dataset` has a side-effect
// and use it to decide whether to use a stateless or stateful version of this
// op.
REGISTER_OP("ExperimentalDatasetToTFRecord")
.Input("input_dataset: variant")
.Input("filename: string")
.Input("compression_type: string")
.SetIsStateful()
.SetShapeFn(shape_inference::NoOutputs);
REGISTER_OP("ExperimentalDenseToSparseBatchDataset")

View File

@ -22,10 +22,12 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -71,6 +73,52 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertDatasetProduces(
dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
def testSideEffect(self):
counter_var = variables.Variable(0)
def increment_fn(x):
counter_var.assign_add(1)
return x
def dataset_fn():
return dataset_ops.Dataset.range(1).map(increment_fn)
@function.defun
def fn():
_ = get_single_element.get_single_element(dataset_fn())
return "hello"
self.evaluate(counter_var.initializer)
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 1)
def testAutomaticControlDependencies(self):
counter_var = variables.Variable(1)
def increment_fn(x):
counter_var.assign(counter_var + 1)
return x
def multiply_fn(x):
counter_var.assign(counter_var * 2)
return x
def dataset1_fn():
return dataset_ops.Dataset.range(1).map(increment_fn)
def dataset2_fn():
return dataset_ops.Dataset.range(1).map(multiply_fn)
@function.defun
def fn():
_ = get_single_element.get_single_element(dataset1_fn())
_ = get_single_element.get_single_element(dataset2_fn())
return "hello"
self.evaluate(counter_var.initializer)
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 4)
if __name__ == "__main__":
test.main()

View File

@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import function
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.lib.io import tf_record
@ -94,6 +95,20 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
with self.assertRaises(TypeError):
writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
def testSideEffect(self):
def writer_fn():
input_dataset = readers.TFRecordDataset(self._createFile())
return writers.TFRecordWriter(self._outputFilename()).write(input_dataset)
@function.defun
def fn():
_ = writer_fn()
return "hello"
self.assertEqual(self.evaluate(fn()), b"hello")
for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
self.assertAllEqual(self._record(i), r)
if __name__ == "__main__":
test.main()

View File

@ -22,12 +22,14 @@ import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -123,6 +125,71 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertEqual(((i + 1) * i) // 2, result["dense"])
self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
def testDatasetSideEffect(self):
counter_var = variables.Variable(0)
def increment_fn(x):
counter_var.assign_add(1)
return x
def dataset_fn():
return dataset_ops.Dataset.range(10).map(increment_fn)
def reduce_fn(state, value):
return state + value
@function.defun
def fn():
_ = dataset_fn().reduce(np.int64(0), reduce_fn)
return "hello"
self.evaluate(counter_var.initializer)
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 10)
def testSideEffect(self):
counter_var = variables.Variable(0)
def dataset_fn():
return dataset_ops.Dataset.range(10)
def reduce_fn(state, value):
counter_var.assign_add(1)
return state + value
@function.defun
def fn():
_ = dataset_fn().reduce(np.int64(0), reduce_fn)
return "hello"
self.evaluate(counter_var.initializer)
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 10)
def testAutomaticControlDependencies(self):
counter_var = variables.Variable(1)
def dataset_fn():
return dataset_ops.Dataset.range(1)
def reduce1_fn(state, value):
counter_var.assign(counter_var + 1)
return state + value
def reduce2_fn(state, value):
counter_var.assign(counter_var * 2)
return state + value
@function.defun
def fn():
_ = dataset_fn().reduce(np.int64(0), reduce1_fn)
_ = dataset_fn().reduce(np.int64(0), reduce2_fn)
return "hello"
self.evaluate(counter_var.initializer)
self.assertEqual(self.evaluate(fn()), b"hello")
self.assertEqual(self.evaluate(counter_var), 4)
if __name__ == "__main__":
test.main()