diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc index 20f40d06402..119b44d6a4a 100644 --- a/tensorflow/tools/graph_transforms/remove_nodes.cc +++ b/tensorflow/tools/graph_transforms/remove_nodes.cc @@ -37,6 +37,9 @@ Status RemoveNodes(const GraphDef& input_graph_def, "remove_nodes expects at least one 'op'" "argument, e.g. remove_nodes(op=Identity)"); } + int32 max_inputs; + TF_RETURN_IF_ERROR( + context.GetOneInt32Parameter("max_inputs", 1, &max_inputs)); // Make sure we don't get rid of any nodes used as graph inputs or outputs. std::set required_nodes; @@ -50,39 +53,48 @@ Status RemoveNodes(const GraphDef& input_graph_def, std::vector ops_to_remove = context.params.at("op"); GraphDef current_graph_def = input_graph_def; for (const string& op : ops_to_remove) { - // Keep looking for nodes to remove until there are no more changes. - bool any_nodes_removed; - do { - any_nodes_removed = false; - std::map inputs_to_rename; - GraphDef replaced_graph_def; - TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( - current_graph_def, {op, {{"*"}}}, - [&inputs_to_rename, &required_nodes, &any_nodes_removed]( - const NodeMatch& match, const std::set& input_nodes, - const std::set& output_nodes, - std::vector* new_nodes) { - const NodeDef& replace_node = match.node; - // If this node is needed in the inputs or outputs don't replace it. - if (required_nodes.count(replace_node.name())) { - LOG(INFO) << "Skipping replacement for " << replace_node.name(); - CopyOriginalMatch(match, new_nodes); + for (int num_inputs = 1; num_inputs <= max_inputs; ++num_inputs) { + // Look for a variable number of inputs. + OpTypePattern pattern = {op}; + pattern.inputs.resize(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + pattern.inputs[i] = {"*"}; + } + // Keep looking for nodes to remove until there are no more changes. + bool any_nodes_removed; + do { + any_nodes_removed = false; + std::map inputs_to_rename; + GraphDef replaced_graph_def; + TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( + current_graph_def, pattern, + [&inputs_to_rename, &required_nodes, &any_nodes_removed]( + const NodeMatch& match, const std::set& input_nodes, + const std::set& output_nodes, + std::vector* new_nodes) { + const NodeDef& replace_node = match.node; + // If this node is needed in the inputs or outputs don't replace + // it. + if (required_nodes.count(replace_node.name())) { + LOG(INFO) << "Skipping replacement for " << replace_node.name(); + CopyOriginalMatch(match, new_nodes); + return Status::OK(); + } + const NodeDef& input_node = match.inputs[0].node; + inputs_to_rename[replace_node.name()] = input_node.name(); + inputs_to_rename["^" + replace_node.name()] = + "^" + input_node.name(); + new_nodes->push_back(input_node); + any_nodes_removed = true; return Status::OK(); - } - const NodeDef& input_node = match.inputs[0].node; - inputs_to_rename[replace_node.name()] = input_node.name(); - inputs_to_rename["^" + replace_node.name()] = - "^" + input_node.name(); - new_nodes->push_back(input_node); - any_nodes_removed = true; - return Status::OK(); - }, - {true}, &replaced_graph_def)); - // Make sure all references to removed nodes now point to their inputs. - TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename, - std::unordered_set(), - ¤t_graph_def)); - } while (any_nodes_removed); + }, + {true}, &replaced_graph_def)); + // Make sure all references to removed nodes now point to their inputs. + TF_RETURN_IF_ERROR( + RenameNodeInputs(replaced_graph_def, inputs_to_rename, + std::unordered_set(), ¤t_graph_def)); + } while (any_nodes_removed); + } } *output_graph_def = current_graph_def; diff --git a/tensorflow/tools/graph_transforms/remove_nodes_test.cc b/tensorflow/tools/graph_transforms/remove_nodes_test.cc index e87ea1daa6f..d8d85a3b471 100644 --- a/tensorflow/tools/graph_transforms/remove_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/remove_nodes_test.cc @@ -210,6 +210,58 @@ class RemoveNodesTest : public ::testing::Test { EXPECT_EQ(0, node_lookup.count("identity_node2")); EXPECT_EQ(0, node_lookup.count("identity_node3")); } + + void TestRemoveMultipleInputs() { + GraphDef graph_def; + + NodeDef* const_node1 = graph_def.add_node(); + const_node1->set_name("const_node1"); + const_node1->set_op("Const"); + + NodeDef* const_node2 = graph_def.add_node(); + const_node2->set_name("const_node2"); + const_node2->set_op("Const"); + + NodeDef* const_node3 = graph_def.add_node(); + const_node3->set_name("const_node3"); + const_node3->set_op("Const"); + + NodeDef* const_node4 = graph_def.add_node(); + const_node4->set_name("const_node4"); + const_node4->set_op("Const"); + + NodeDef* fake_quant_node = graph_def.add_node(); + fake_quant_node->set_name("fake_quant_node"); + fake_quant_node->set_op("FakeQuantWithMinMaxVars"); + fake_quant_node->add_input("const_node1"); + fake_quant_node->add_input("const_node2"); + fake_quant_node->add_input("const_node3"); + + NodeDef* add_node = graph_def.add_node(); + add_node->set_name("add_node"); + add_node->set_op("Add"); + add_node->add_input("fake_quant_node"); + add_node->add_input("const_node4"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"add_node"}; + context.params.insert(std::pair>( + {"op", {string("FakeQuantWithMinMaxVars")}})); + context.params.insert( + std::pair>({"max_inputs", {string("3")}})); + TF_ASSERT_OK(RemoveNodes(graph_def, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + ASSERT_EQ(1, node_lookup.count("const_node1")); + ASSERT_EQ(1, node_lookup.count("const_node4")); + ASSERT_EQ(0, node_lookup.count("fake_quant_node")); + ASSERT_EQ(1, node_lookup.count("add_node")); + EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0)); + EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1)); + } }; TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); } @@ -218,5 +270,9 @@ TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); } TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); } +TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) { + TestRemoveMultipleInputs(); +} + } // namespace graph_transforms } // namespace tensorflow