diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index bc95c9cf72a..6916bc8a950 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -115,14 +115,23 @@ std::unordered_set GrapplerItem::NodesToPreserve() const { } } - // Tensorflow functions do not prune stateful or dataset-output ops from - // the function body (see PruneFunctionBody in common_runtime/function.cc). + absl::optional fn_library; if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) { - FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library()); - for (const NodeDef& node : graph.node()) { - if (IsStateful(node, &fn_library) || IsDataset(node)) { - result.insert(node.name()); - } + fn_library.emplace(OpRegistry::Global(), graph.library()); + } + for (const NodeDef& node : graph.node()) { + // Tensorflow functions do not prune stateful or dataset-output ops from + // the function body (see PruneFunctionBody in common_runtime/function.cc). + if (!optimization_options_.allow_pruning_stateful_and_dataset_ops && + (IsStateful(node, &*fn_library) || IsDataset(node))) { + result.insert(node.name()); + } + + // Do not remove ops with attribute _grappler_do_not_remove. This is useful + // for debugging. + auto iter = node.attr().find("_grappler_do_not_remove"); + if (iter != node.attr().end() && iter->second.b()) { + result.insert(node.name()); } } diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 8186c81378a..c5d8ec0f872 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -65,6 +66,9 @@ class PyWrapOptimizeGraphTest(test.TestCase): 1.0) # Must be preserved since it's in the collection 'variables'. a2 = constant_op.constant(0, shape=[50, 50], name='keep') ops.add_to_collection('a2', a2) # Explicitly add to collection. + with g._attr_scope( + {'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}): + a3 = constant_op.constant(0, name='keep2') b = constant_op.constant(1, shape=[100, 10]) c = constant_op.constant(0, shape=[10, 30]) d = math_ops.matmul(b, c) @@ -80,7 +84,7 @@ class PyWrapOptimizeGraphTest(test.TestCase): # Check that the nodes referenced in various collections have been preserved optimized_graph_nodes = [node.name for node in optimized_graph.node] expected_nodes = [ - d.op.name, a1.op.name, a2.op.name, 'Variable/initial_value', + d.op.name, a1.op.name, a2.op.name, a3.op.name, 'Variable/initial_value', 'Variable/Assign' ] self.assertEqual(len(optimized_graph_nodes), len(expected_nodes))