Enable removal of nodes with multiple inputs

PiperOrigin-RevId: 157068219
This commit is contained in:
Pete Warden 2017-05-24 19:05:37 -07:00 committed by TensorFlower Gardener
parent 4a09e96797
commit 70fc6abad7
2 changed files with 100 additions and 32 deletions

View File

@ -37,6 +37,9 @@ Status RemoveNodes(const GraphDef& input_graph_def,
"remove_nodes expects at least one 'op'"
"argument, e.g. remove_nodes(op=Identity)");
}
int32 max_inputs;
TF_RETURN_IF_ERROR(
context.GetOneInt32Parameter("max_inputs", 1, &max_inputs));
// Make sure we don't get rid of any nodes used as graph inputs or outputs.
std::set<string> required_nodes;
@ -50,6 +53,13 @@ Status RemoveNodes(const GraphDef& input_graph_def,
std::vector<string> ops_to_remove = context.params.at("op");
GraphDef current_graph_def = input_graph_def;
for (const string& op : ops_to_remove) {
for (int num_inputs = 1; num_inputs <= max_inputs; ++num_inputs) {
// Look for a variable number of inputs.
OpTypePattern pattern = {op};
pattern.inputs.resize(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
pattern.inputs[i] = {"*"};
}
// Keep looking for nodes to remove until there are no more changes.
bool any_nodes_removed;
do {
@ -57,13 +67,14 @@ Status RemoveNodes(const GraphDef& input_graph_def,
std::map<string, string> inputs_to_rename;
GraphDef replaced_graph_def;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, {op, {{"*"}}},
current_graph_def, pattern,
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
const NodeMatch& match, const std::set<string>& input_nodes,
const std::set<string>& output_nodes,
std::vector<NodeDef>* new_nodes) {
const NodeDef& replace_node = match.node;
// If this node is needed in the inputs or outputs don't replace it.
// If this node is needed in the inputs or outputs don't replace
// it.
if (required_nodes.count(replace_node.name())) {
LOG(INFO) << "Skipping replacement for " << replace_node.name();
CopyOriginalMatch(match, new_nodes);
@ -79,11 +90,12 @@ Status RemoveNodes(const GraphDef& input_graph_def,
},
{true}, &replaced_graph_def));
// Make sure all references to removed nodes now point to their inputs.
TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
std::unordered_set<string>(),
&current_graph_def));
TF_RETURN_IF_ERROR(
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
std::unordered_set<string>(), &current_graph_def));
} while (any_nodes_removed);
}
}
*output_graph_def = current_graph_def;
return Status::OK();

View File

@ -210,6 +210,58 @@ class RemoveNodesTest : public ::testing::Test {
EXPECT_EQ(0, node_lookup.count("identity_node2"));
EXPECT_EQ(0, node_lookup.count("identity_node3"));
}
void TestRemoveMultipleInputs() {
GraphDef graph_def;
NodeDef* const_node1 = graph_def.add_node();
const_node1->set_name("const_node1");
const_node1->set_op("Const");
NodeDef* const_node2 = graph_def.add_node();
const_node2->set_name("const_node2");
const_node2->set_op("Const");
NodeDef* const_node3 = graph_def.add_node();
const_node3->set_name("const_node3");
const_node3->set_op("Const");
NodeDef* const_node4 = graph_def.add_node();
const_node4->set_name("const_node4");
const_node4->set_op("Const");
NodeDef* fake_quant_node = graph_def.add_node();
fake_quant_node->set_name("fake_quant_node");
fake_quant_node->set_op("FakeQuantWithMinMaxVars");
fake_quant_node->add_input("const_node1");
fake_quant_node->add_input("const_node2");
fake_quant_node->add_input("const_node3");
NodeDef* add_node = graph_def.add_node();
add_node->set_name("add_node");
add_node->set_op("Add");
add_node->add_input("fake_quant_node");
add_node->add_input("const_node4");
GraphDef result;
TransformFuncContext context;
context.input_names = {};
context.output_names = {"add_node"};
context.params.insert(std::pair<string, std::vector<string>>(
{"op", {string("FakeQuantWithMinMaxVars")}}));
context.params.insert(
std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
std::map<string, const NodeDef*> node_lookup;
MapNamesToNodes(result, &node_lookup);
ASSERT_EQ(1, node_lookup.count("const_node1"));
ASSERT_EQ(1, node_lookup.count("const_node4"));
ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
ASSERT_EQ(1, node_lookup.count("add_node"));
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
}
};
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
@ -218,5 +270,9 @@ TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
TestRemoveMultipleInputs();
}
} // namespace graph_transforms
} // namespace tensorflow