[tf.data] Mark dataset ops that consume a dataset without an iterator as stateful, so that they are not pruned from the graph when their output is unused.

This is a conservative approach that guarantees any side effects of the op are carried out. This CL also reverts a previous (incomplete) solution to the same problem.

PiperOrigin-RevId: 233663631
Parent: d7b3e49283
Commit: 5702b86c96
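For context, a minimal Python sketch of the behavior this change guarantees, mirroring the tests added in this commit (`function.defun` is the 1.x-era tracing API used in those tests; the variable and function names here are illustrative):

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.ops import variables

counter_var = variables.Variable(0)

def increment_fn(x):
  counter_var.assign_add(1)  # side effect of evaluating the dataset
  return x

@function.defun
def fn():
  # The result of the reduction is discarded. Because ReduceDataset is now
  # registered with SetIsStateful(), the op is kept in the function body and
  # the ten increments above still happen when fn() runs.
  _ = dataset_ops.Dataset.range(10).map(increment_fn).reduce(
      np.int64(0), lambda state, value: state + value)
  return "hello"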
@@ -818,7 +818,7 @@ void PruneFunctionBody(Graph* g) {
     // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
     // still needed. It would be preferable to prune entire loops and/or
     // conditionals if they are not used in the graph.
-    if (n->IsControlFlow() || n->IsDataset() ||
+    if (n->IsControlFlow() ||
         (n->op_def().is_stateful() && n->type_string() != kArgOp)) {
       nodes.insert(n);
     }
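Schematically, the seed predicate for the reachability walk in PruneFunctionBody now reduces to the sketch below (a hypothetical Python restatement; the attribute names are illustrative, not the real C++ Node interface):

def seeds_function_body(node):
  # Control-flow nodes and stateful non-"_Arg" nodes anchor the walk; anything
  # not reachable from a seed or a return value gets pruned. Dataset-consuming
  # ops now qualify via is_stateful because they are registered with
  # SetIsStateful() rather than special-cased here.
  return node.is_control_flow or (node.is_stateful and node.type_string != "_Arg")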
@@ -85,10 +85,6 @@ const std::unordered_map<string, Node::NodeClass>& Node::kNodeClassTable =
         {"CollectiveBcastSend", NC_COLLECTIVE},
         {"CollectiveBcastRecv", NC_COLLECTIVE},
         {"FakeParam", NC_FAKE_PARAM},
-        {"IteratorGetNext", NC_DATASET},
-        {"IteratorGetNextSync", NC_DATASET},
-        {"DatasetToSingleElement", NC_DATASET},
-        {"ReduceDataset", NC_DATASET},
     });
 
 #undef REF_CLASS
@@ -174,8 +174,6 @@ class Node {
   bool IsMetadata() const { return class_ == NC_METADATA; }
   bool IsFakeParam() const { return class_ == NC_FAKE_PARAM; }
 
-  bool IsDataset() const { return class_ == NC_DATASET; }
-
   template <typename T>
   void AddAttr(const string& name, const T& val) {
     SetAttrValue(val, AddAttrHelper(name));
@@ -256,7 +254,6 @@ class Node {
     NC_SCOPED_ALLOCATOR,
     NC_COLLECTIVE,
     NC_FAKE_PARAM,
-    NC_DATASET,
     NC_OTHER  // Not a special kind of node
   };
 
@@ -554,13 +554,22 @@ REGISTER_OP("IteratorGetNextSync")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(IteratorGetNextShapeFn);
 
+// TODO(b/124308596): Instead of conservatively marking this op as stateful,
+// implement a mechanism to determine whether `dataset` has a side-effect
+// and use it to decide whether to use a stateless or stateful version of this
+// op.
 REGISTER_OP("DatasetToSingleElement")
     .Input("dataset: variant")
     .Output("components: output_types")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
+    .SetIsStateful()
     .SetShapeFn(IteratorGetNextShapeFn);
 
+// TODO(b/124308596): Instead of conservatively marking this op as stateful,
+// implement a mechanism to determine whether `dataset` has a side-effect
+// and use it to decide whether to use a stateless or stateful version of this
+// op.
 REGISTER_OP("ReduceDataset")
     .Input("input_dataset: variant")
     .Input("initial_state: Tstate")
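At the Python level, `DatasetToSingleElement` is the op behind `get_single_element`; a short sketch of why the stateful marking matters (this mirrors the `GetSingleElementTest.testSideEffect` case added later in this commit; names are illustrative):

from tensorflow.python.data.experimental.ops import get_single_element
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.ops import variables

counter_var = variables.Variable(0)

def side_effecting_fn(x):
  counter_var.assign_add(1)
  return x

@function.defun
def fn():
  # The element itself is unused; without SetIsStateful() the underlying
  # DatasetToSingleElement op could be pruned and the increment would never run.
  _ = get_single_element.get_single_element(
      dataset_ops.Dataset.range(1).map(side_effecting_fn))
  return "done"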
@@ -572,6 +581,7 @@ REGISTER_OP("ReduceDataset")
     .Attr("output_types: list(type) >= 1")
     .Attr("output_shapes: list(shape) >= 1")
     .Attr("use_inter_op_parallelism: bool = true")
+    .SetIsStateful()
     .SetShapeFn(IteratorGetNextShapeFn);
 
 REGISTER_OP("IteratorToStringHandle")
@@ -652,6 +662,8 @@ REGISTER_OP("ModelDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+// TODO(b/124308749): Add a stateful version of MapDefun and use it when `f`
+// is stateful.
 REGISTER_OP("MapDefun")
     .Input("arguments: Targuments")
     .Input("captured_inputs: Tcaptured")
@@ -76,10 +76,15 @@ REGISTER_OP("ExperimentalDatasetCardinality")
     .Output("cardinality: int64")
     .SetShapeFn(shape_inference::ScalarShape);
 
+// TODO(b/124308596): Instead of conservatively marking this op as stateful,
+// implement a mechanism to determine whether `dataset` has a side-effect
+// and use it to decide whether to use a stateless or stateful version of this
+// op.
 REGISTER_OP("ExperimentalDatasetToTFRecord")
     .Input("input_dataset: variant")
     .Input("filename: string")
     .Input("compression_type: string")
+    .SetIsStateful()
     .SetShapeFn(shape_inference::NoOutputs);
 
 REGISTER_OP("ExperimentalDenseToSparseBatchDataset")
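`ExperimentalDatasetToTFRecord` is the op behind `TFRecordWriter.write`; a brief usage sketch of the same idea (it mirrors the writer test added later in this commit; the output path is a placeholder):

from tensorflow.python.data.experimental.ops import writers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import function
from tensorflow.python.ops import string_ops

@function.defun
def write_records():
  # Serializing the dataset to disk is pure side effect: the write op's output
  # is not consumed, so it would be pruned if it were not marked stateful.
  dataset = dataset_ops.Dataset.range(10).map(string_ops.as_string)
  _ = writers.TFRecordWriter("/tmp/out.tfrecord").write(dataset)
  return "done"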
@@ -22,10 +22,12 @@ from absl.testing import parameterized
 from tensorflow.python.data.experimental.ops import get_single_element
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import function
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -71,6 +73,52 @@ class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
     self.assertDatasetProduces(
         dataset, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]])
 
+  def testSideEffect(self):
+    counter_var = variables.Variable(0)
+
+    def increment_fn(x):
+      counter_var.assign_add(1)
+      return x
+
+    def dataset_fn():
+      return dataset_ops.Dataset.range(1).map(increment_fn)
+
+    @function.defun
+    def fn():
+      _ = get_single_element.get_single_element(dataset_fn())
+      return "hello"
+
+    self.evaluate(counter_var.initializer)
+    self.assertEqual(self.evaluate(fn()), b"hello")
+    self.assertEqual(self.evaluate(counter_var), 1)
+
+  def testAutomaticControlDependencies(self):
+    counter_var = variables.Variable(1)
+
+    def increment_fn(x):
+      counter_var.assign(counter_var + 1)
+      return x
+
+    def multiply_fn(x):
+      counter_var.assign(counter_var * 2)
+      return x
+
+    def dataset1_fn():
+      return dataset_ops.Dataset.range(1).map(increment_fn)
+
+    def dataset2_fn():
+      return dataset_ops.Dataset.range(1).map(multiply_fn)
+
+    @function.defun
+    def fn():
+      _ = get_single_element.get_single_element(dataset1_fn())
+      _ = get_single_element.get_single_element(dataset2_fn())
+      return "hello"
+
+    self.evaluate(counter_var.initializer)
+    self.assertEqual(self.evaluate(fn()), b"hello")
+    self.assertEqual(self.evaluate(counter_var), 4)
+
 
 if __name__ == "__main__":
   test.main()
@@ -23,6 +23,7 @@ from tensorflow.python.data.experimental.ops import writers
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import readers
+from tensorflow.python.eager import function
 from tensorflow.python.framework import test_util
 from tensorflow.python.lib.io import python_io
 from tensorflow.python.lib.io import tf_record
@@ -94,6 +95,20 @@ class TFRecordWriterTest(test_base.DatasetTestBase):
     with self.assertRaises(TypeError):
       writers.TFRecordWriter(self._outputFilename(), "").write(input_dataset)
 
+  def testSideEffect(self):
+    def writer_fn():
+      input_dataset = readers.TFRecordDataset(self._createFile())
+      return writers.TFRecordWriter(self._outputFilename()).write(input_dataset)
+
+    @function.defun
+    def fn():
+      _ = writer_fn()
+      return "hello"
+
+    self.assertEqual(self.evaluate(fn()), b"hello")
+    for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
+      self.assertAllEqual(self._record(i), r)
+
 
 if __name__ == "__main__":
   test.main()
@@ -22,12 +22,14 @@ import numpy as np
 
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 
 
@@ -123,6 +125,71 @@ class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
       self.assertEqual(((i + 1) * i) // 2, result["dense"])
       self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
 
+  def testDatasetSideEffect(self):
+    counter_var = variables.Variable(0)
+
+    def increment_fn(x):
+      counter_var.assign_add(1)
+      return x
+
+    def dataset_fn():
+      return dataset_ops.Dataset.range(10).map(increment_fn)
+
+    def reduce_fn(state, value):
+      return state + value
+
+    @function.defun
+    def fn():
+      _ = dataset_fn().reduce(np.int64(0), reduce_fn)
+      return "hello"
+
+    self.evaluate(counter_var.initializer)
+    self.assertEqual(self.evaluate(fn()), b"hello")
+    self.assertEqual(self.evaluate(counter_var), 10)
+
+  def testSideEffect(self):
+    counter_var = variables.Variable(0)
+
+    def dataset_fn():
+      return dataset_ops.Dataset.range(10)
+
+    def reduce_fn(state, value):
+      counter_var.assign_add(1)
+      return state + value
+
+    @function.defun
+    def fn():
+      _ = dataset_fn().reduce(np.int64(0), reduce_fn)
+      return "hello"
+
+    self.evaluate(counter_var.initializer)
+    self.assertEqual(self.evaluate(fn()), b"hello")
+    self.assertEqual(self.evaluate(counter_var), 10)
+
+  def testAutomaticControlDependencies(self):
+    counter_var = variables.Variable(1)
+
+    def dataset_fn():
+      return dataset_ops.Dataset.range(1)
+
+    def reduce1_fn(state, value):
+      counter_var.assign(counter_var + 1)
+      return state + value
+
+    def reduce2_fn(state, value):
+      counter_var.assign(counter_var * 2)
+      return state + value
+
+    @function.defun
+    def fn():
+      _ = dataset_fn().reduce(np.int64(0), reduce1_fn)
+      _ = dataset_fn().reduce(np.int64(0), reduce2_fn)
+      return "hello"
+
+    self.evaluate(counter_var.initializer)
+    self.assertEqual(self.evaluate(fn()), b"hello")
+    self.assertEqual(self.evaluate(counter_var), 4)
+
 
 if __name__ == "__main__":
   test.main()