From 055461803630ad1a4461f2ffd6e488e9a3effbc1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 22 Jul 2020 19:32:01 -0700 Subject: [PATCH] Include Ops that are used via PartitionedCalls to MetaGraphDef.MetaInfoDef.stripped_op_list PiperOrigin-RevId: 322706036 Change-Id: I3f307d07a9d38aeca34f7c550d857c76aed37005 --- tensorflow/python/framework/meta_graph.py | 9 ++++++-- .../python/framework/meta_graph_test.py | 23 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 327b476c576..dbc2a894d65 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -161,12 +161,17 @@ def ops_used_by_graph_def(graph_def): functions_to_process.append(name_to_function[op]) used_ops.add(op) - for node in graph_def.node: + def process_node(node): mark_op_as_used(node.op) + if node.op in ["PartitionedCall", "StatefulPartitionedCall"]: + mark_op_as_used(node.attr["f"].func.name) + + for node in graph_def.node: + process_node(node) while functions_to_process: fun = functions_to_process.pop() for node in fun.node_def: - mark_op_as_used(node.op) + process_node(node) return [op for op in used_ops if op not in name_to_function] diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index ae44fbce0f0..36acd81fe26 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -161,6 +161,29 @@ class SimpleMetaGraphTest(test.TestCase): op_list = meta_graph.stripped_op_list_for_graph(graph) self.assertEqual(["Const"], [op.name for op in op_list.op]) + def testStrippedOpListPartitionedCalls(self): + # Function A calls B via StatefulPartitionedCall. + graph = graph_pb2.GraphDef() + a = graph.library.function.add() + b = graph.library.function.add() + a.signature.name = "A" + b.signature.name = "B" + node_in_a = a.node_def.add() + node_in_a.op = "StatefulPartitionedCall" + node_in_a.attr["f"].func.name = "B" + b.node_def.add().op = "Const" + b.node_def.add().op = "A" + + # Use A in the graph via PartitionedCall. + node = graph.node.add() + node.op = "PartitionedCall" + node.attr["f"].func.name = "A" + + op_list = meta_graph.stripped_op_list_for_graph(graph) + self.assertSameElements( + ["Const", "PartitionedCall", "StatefulPartitionedCall"], + [op.name for op in op_list.op]) + @test_util.run_deprecated_v1 def testDefaultAttrStripping(self): """Verifies that default attributes are stripped from a graph def."""