Include Ops that are used via PartitionedCalls to MetaGraphDef.MetaInfoDef.stripped_op_list
PiperOrigin-RevId: 322706036 Change-Id: I3f307d07a9d38aeca34f7c550d857c76aed37005
This commit is contained in:
parent
fffa1e6548
commit
0554618036
tensorflow/python/framework
@ -161,12 +161,17 @@ def ops_used_by_graph_def(graph_def):
|
||||
functions_to_process.append(name_to_function[op])
|
||||
used_ops.add(op)
|
||||
|
||||
for node in graph_def.node:
|
||||
def process_node(node):
|
||||
mark_op_as_used(node.op)
|
||||
if node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
|
||||
mark_op_as_used(node.attr["f"].func.name)
|
||||
|
||||
for node in graph_def.node:
|
||||
process_node(node)
|
||||
while functions_to_process:
|
||||
fun = functions_to_process.pop()
|
||||
for node in fun.node_def:
|
||||
mark_op_as_used(node.op)
|
||||
process_node(node)
|
||||
|
||||
return [op for op in used_ops if op not in name_to_function]
|
||||
|
||||
|
@ -161,6 +161,29 @@ class SimpleMetaGraphTest(test.TestCase):
|
||||
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
||||
self.assertEqual(["Const"], [op.name for op in op_list.op])
|
||||
|
||||
def testStrippedOpListPartitionedCalls(self):
|
||||
# Function A calls B via StatefulPartitionedCall.
|
||||
graph = graph_pb2.GraphDef()
|
||||
a = graph.library.function.add()
|
||||
b = graph.library.function.add()
|
||||
a.signature.name = "A"
|
||||
b.signature.name = "B"
|
||||
node_in_a = a.node_def.add()
|
||||
node_in_a.op = "StatefulPartitionedCall"
|
||||
node_in_a.attr["f"].func.name = "B"
|
||||
b.node_def.add().op = "Const"
|
||||
b.node_def.add().op = "A"
|
||||
|
||||
# Use A in the graph via PartitionedCall.
|
||||
node = graph.node.add()
|
||||
node.op = "PartitionedCall"
|
||||
node.attr["f"].func.name = "A"
|
||||
|
||||
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
||||
self.assertSameElements(
|
||||
["Const", "PartitionedCall", "StatefulPartitionedCall"],
|
||||
[op.name for op in op_list.op])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDefaultAttrStripping(self):
|
||||
"""Verifies that default attributes are stripped from a graph def."""
|
||||
|
Loading…
Reference in New Issue
Block a user