Have grappler preserve nodes with attribute _grappler_do_not_remove.

This is useful for debugging. You can prevent grappler from optimizing nodes away with:

    with ops.Graph().as_default() as g:
        with g._attr_scope({"_grappler_do_not_remove":
                            tf.attr_value_pb2.AttrValue(b=True)}):
            ... # Create ops here.

PiperOrigin-RevId: 241386287
This commit is contained in:
Reed Wanderman-Milne 2019-04-01 13:21:03 -07:00 committed by TensorFlower Gardener
parent a89be16e7a
commit 865004e8aa
2 changed files with 21 additions and 8 deletions

View File

@ -115,14 +115,23 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
}
}
// Tensorflow functions do not prune stateful or dataset-output ops from
// the function body (see PruneFunctionBody in common_runtime/function.cc).
absl::optional<FunctionLibraryDefinition> fn_library;
if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library());
for (const NodeDef& node : graph.node()) {
if (IsStateful(node, &fn_library) || IsDataset(node)) {
result.insert(node.name());
}
fn_library.emplace(OpRegistry::Global(), graph.library());
}
for (const NodeDef& node : graph.node()) {
// Tensorflow functions do not prune stateful or dataset-output ops from
// the function body (see PruneFunctionBody in common_runtime/function.cc).
if (!optimization_options_.allow_pruning_stateful_and_dataset_ops &&
(IsStateful(node, &*fn_library) || IsDataset(node))) {
result.insert(node.name());
}
// Do not remove ops with attribute _grappler_do_not_remove. This is useful
// for debugging.
auto iter = node.attr().find("_grappler_do_not_remove");
if (iter != node.attr().end() && iter->second.b()) {
result.insert(node.name());
}
}

View File

@ -17,6 +17,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -65,6 +66,9 @@ class PyWrapOptimizeGraphTest(test.TestCase):
1.0) # Must be preserved since it's in the collection 'variables'.
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
ops.add_to_collection('a2', a2) # Explicitly add to collection.
with g._attr_scope(
{'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}):
a3 = constant_op.constant(0, name='keep2')
b = constant_op.constant(1, shape=[100, 10])
c = constant_op.constant(0, shape=[10, 30])
d = math_ops.matmul(b, c)
@ -80,7 +84,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
# Check that the nodes referenced in various collections have been preserved
optimized_graph_nodes = [node.name for node in optimized_graph.node]
expected_nodes = [
d.op.name, a1.op.name, a2.op.name, 'Variable/initial_value',
d.op.name, a1.op.name, a2.op.name, a3.op.name, 'Variable/initial_value',
'Variable/Assign'
]
self.assertEqual(len(optimized_graph_nodes), len(expected_nodes))