Include Ops that are used via PartitionedCalls to MetaGraphDef.MetaInfoDef.stripped_op_list

PiperOrigin-RevId: 322706036
Change-Id: I3f307d07a9d38aeca34f7c550d857c76aed37005
This commit is contained in:
A. Unique TensorFlower 2020-07-22 19:32:01 -07:00 committed by TensorFlower Gardener
parent fffa1e6548
commit 0554618036
2 changed files with 30 additions and 2 deletions
tensorflow/python/framework

View File

@ -161,12 +161,17 @@ def ops_used_by_graph_def(graph_def):
functions_to_process.append(name_to_function[op])
used_ops.add(op)
for node in graph_def.node:
def process_node(node):
mark_op_as_used(node.op)
if node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
mark_op_as_used(node.attr["f"].func.name)
for node in graph_def.node:
process_node(node)
while functions_to_process:
fun = functions_to_process.pop()
for node in fun.node_def:
mark_op_as_used(node.op)
process_node(node)
return [op for op in used_ops if op not in name_to_function]

View File

@ -161,6 +161,29 @@ class SimpleMetaGraphTest(test.TestCase):
op_list = meta_graph.stripped_op_list_for_graph(graph)
self.assertEqual(["Const"], [op.name for op in op_list.op])
def testStrippedOpListPartitionedCalls(self):
# Function A calls B via StatefulPartitionedCall.
graph = graph_pb2.GraphDef()
a = graph.library.function.add()
b = graph.library.function.add()
a.signature.name = "A"
b.signature.name = "B"
node_in_a = a.node_def.add()
node_in_a.op = "StatefulPartitionedCall"
node_in_a.attr["f"].func.name = "B"
b.node_def.add().op = "Const"
b.node_def.add().op = "A"
# Use A in the graph via PartitionedCall.
node = graph.node.add()
node.op = "PartitionedCall"
node.attr["f"].func.name = "A"
op_list = meta_graph.stripped_op_list_for_graph(graph)
self.assertSameElements(
["Const", "PartitionedCall", "StatefulPartitionedCall"],
[op.name for op in op_list.op])
@test_util.run_deprecated_v1
def testDefaultAttrStripping(self):
"""Verifies that default attributes are stripped from a graph def."""