Include Ops that are used via PartitionedCalls to MetaGraphDef.MetaInfoDef.stripped_op_list
PiperOrigin-RevId: 322706036 Change-Id: I3f307d07a9d38aeca34f7c550d857c76aed37005
This commit is contained in:
parent
fffa1e6548
commit
0554618036
@ -161,12 +161,17 @@ def ops_used_by_graph_def(graph_def):
|
|||||||
functions_to_process.append(name_to_function[op])
|
functions_to_process.append(name_to_function[op])
|
||||||
used_ops.add(op)
|
used_ops.add(op)
|
||||||
|
|
||||||
for node in graph_def.node:
|
def process_node(node):
|
||||||
mark_op_as_used(node.op)
|
mark_op_as_used(node.op)
|
||||||
|
if node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
|
||||||
|
mark_op_as_used(node.attr["f"].func.name)
|
||||||
|
|
||||||
|
for node in graph_def.node:
|
||||||
|
process_node(node)
|
||||||
while functions_to_process:
|
while functions_to_process:
|
||||||
fun = functions_to_process.pop()
|
fun = functions_to_process.pop()
|
||||||
for node in fun.node_def:
|
for node in fun.node_def:
|
||||||
mark_op_as_used(node.op)
|
process_node(node)
|
||||||
|
|
||||||
return [op for op in used_ops if op not in name_to_function]
|
return [op for op in used_ops if op not in name_to_function]
|
||||||
|
|
||||||
|
@ -161,6 +161,29 @@ class SimpleMetaGraphTest(test.TestCase):
|
|||||||
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
||||||
self.assertEqual(["Const"], [op.name for op in op_list.op])
|
self.assertEqual(["Const"], [op.name for op in op_list.op])
|
||||||
|
|
||||||
|
def testStrippedOpListPartitionedCalls(self):
|
||||||
|
# Function A calls B via StatefulPartitionedCall.
|
||||||
|
graph = graph_pb2.GraphDef()
|
||||||
|
a = graph.library.function.add()
|
||||||
|
b = graph.library.function.add()
|
||||||
|
a.signature.name = "A"
|
||||||
|
b.signature.name = "B"
|
||||||
|
node_in_a = a.node_def.add()
|
||||||
|
node_in_a.op = "StatefulPartitionedCall"
|
||||||
|
node_in_a.attr["f"].func.name = "B"
|
||||||
|
b.node_def.add().op = "Const"
|
||||||
|
b.node_def.add().op = "A"
|
||||||
|
|
||||||
|
# Use A in the graph via PartitionedCall.
|
||||||
|
node = graph.node.add()
|
||||||
|
node.op = "PartitionedCall"
|
||||||
|
node.attr["f"].func.name = "A"
|
||||||
|
|
||||||
|
op_list = meta_graph.stripped_op_list_for_graph(graph)
|
||||||
|
self.assertSameElements(
|
||||||
|
["Const", "PartitionedCall", "StatefulPartitionedCall"],
|
||||||
|
[op.name for op in op_list.op])
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testDefaultAttrStripping(self):
|
def testDefaultAttrStripping(self):
|
||||||
"""Verifies that default attributes are stripped from a graph def."""
|
"""Verifies that default attributes are stripped from a graph def."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user