Have grappler preserve nodes with attribute _grappler_do_not_remove.
This is useful for debugging. You can prevent grappler from optimizing nodes away with: with ops.Graph().as_default() as g: with g._attr_scope({"_grappler_do_not_remove": tf.attr_value_pb2.AttrValue(b=True)}): ... # Create ops here. PiperOrigin-RevId: 241386287
This commit is contained in:
parent
a89be16e7a
commit
865004e8aa
@ -115,14 +115,23 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
|
||||
}
|
||||
}
|
||||
|
||||
// Tensorflow functions do not prune stateful or dataset-output ops from
|
||||
// the function body (see PruneFunctionBody in common_runtime/function.cc).
|
||||
absl::optional<FunctionLibraryDefinition> fn_library;
|
||||
if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
|
||||
FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library());
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
if (IsStateful(node, &fn_library) || IsDataset(node)) {
|
||||
result.insert(node.name());
|
||||
}
|
||||
fn_library.emplace(OpRegistry::Global(), graph.library());
|
||||
}
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
// Tensorflow functions do not prune stateful or dataset-output ops from
|
||||
// the function body (see PruneFunctionBody in common_runtime/function.cc).
|
||||
if (!optimization_options_.allow_pruning_stateful_and_dataset_ops &&
|
||||
(IsStateful(node, &*fn_library) || IsDataset(node))) {
|
||||
result.insert(node.name());
|
||||
}
|
||||
|
||||
// Do not remove ops with attribute _grappler_do_not_remove. This is useful
|
||||
// for debugging.
|
||||
auto iter = node.attr().find("_grappler_do_not_remove");
|
||||
if (iter != node.attr().end() && iter->second.b()) {
|
||||
result.insert(node.name());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -65,6 +66,9 @@ class PyWrapOptimizeGraphTest(test.TestCase):
|
||||
1.0) # Must be preserved since it's in the collection 'variables'.
|
||||
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
|
||||
ops.add_to_collection('a2', a2) # Explicitly add to collection.
|
||||
with g._attr_scope(
|
||||
{'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}):
|
||||
a3 = constant_op.constant(0, name='keep2')
|
||||
b = constant_op.constant(1, shape=[100, 10])
|
||||
c = constant_op.constant(0, shape=[10, 30])
|
||||
d = math_ops.matmul(b, c)
|
||||
@ -80,7 +84,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
|
||||
# Check that the nodes referenced in various collections have been preserved
|
||||
optimized_graph_nodes = [node.name for node in optimized_graph.node]
|
||||
expected_nodes = [
|
||||
d.op.name, a1.op.name, a2.op.name, 'Variable/initial_value',
|
||||
d.op.name, a1.op.name, a2.op.name, a3.op.name, 'Variable/initial_value',
|
||||
'Variable/Assign'
|
||||
]
|
||||
self.assertEqual(len(optimized_graph_nodes), len(expected_nodes))
|
||||
|
Loading…
x
Reference in New Issue
Block a user