Enable removal of nodes with multiple inputs
PiperOrigin-RevId: 157068219
This commit is contained in:
parent
4a09e96797
commit
70fc6abad7
@ -37,6 +37,9 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
|||||||
"remove_nodes expects at least one 'op'"
|
"remove_nodes expects at least one 'op'"
|
||||||
"argument, e.g. remove_nodes(op=Identity)");
|
"argument, e.g. remove_nodes(op=Identity)");
|
||||||
}
|
}
|
||||||
|
int32 max_inputs;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
context.GetOneInt32Parameter("max_inputs", 1, &max_inputs));
|
||||||
|
|
||||||
// Make sure we don't get rid of any nodes used as graph inputs or outputs.
|
// Make sure we don't get rid of any nodes used as graph inputs or outputs.
|
||||||
std::set<string> required_nodes;
|
std::set<string> required_nodes;
|
||||||
@ -50,6 +53,13 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
|||||||
std::vector<string> ops_to_remove = context.params.at("op");
|
std::vector<string> ops_to_remove = context.params.at("op");
|
||||||
GraphDef current_graph_def = input_graph_def;
|
GraphDef current_graph_def = input_graph_def;
|
||||||
for (const string& op : ops_to_remove) {
|
for (const string& op : ops_to_remove) {
|
||||||
|
for (int num_inputs = 1; num_inputs <= max_inputs; ++num_inputs) {
|
||||||
|
// Look for a variable number of inputs.
|
||||||
|
OpTypePattern pattern = {op};
|
||||||
|
pattern.inputs.resize(num_inputs);
|
||||||
|
for (int i = 0; i < num_inputs; ++i) {
|
||||||
|
pattern.inputs[i] = {"*"};
|
||||||
|
}
|
||||||
// Keep looking for nodes to remove until there are no more changes.
|
// Keep looking for nodes to remove until there are no more changes.
|
||||||
bool any_nodes_removed;
|
bool any_nodes_removed;
|
||||||
do {
|
do {
|
||||||
@ -57,13 +67,14 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
|||||||
std::map<string, string> inputs_to_rename;
|
std::map<string, string> inputs_to_rename;
|
||||||
GraphDef replaced_graph_def;
|
GraphDef replaced_graph_def;
|
||||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||||
current_graph_def, {op, {{"*"}}},
|
current_graph_def, pattern,
|
||||||
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
|
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
|
||||||
const NodeMatch& match, const std::set<string>& input_nodes,
|
const NodeMatch& match, const std::set<string>& input_nodes,
|
||||||
const std::set<string>& output_nodes,
|
const std::set<string>& output_nodes,
|
||||||
std::vector<NodeDef>* new_nodes) {
|
std::vector<NodeDef>* new_nodes) {
|
||||||
const NodeDef& replace_node = match.node;
|
const NodeDef& replace_node = match.node;
|
||||||
// If this node is needed in the inputs or outputs don't replace it.
|
// If this node is needed in the inputs or outputs don't replace
|
||||||
|
// it.
|
||||||
if (required_nodes.count(replace_node.name())) {
|
if (required_nodes.count(replace_node.name())) {
|
||||||
LOG(INFO) << "Skipping replacement for " << replace_node.name();
|
LOG(INFO) << "Skipping replacement for " << replace_node.name();
|
||||||
CopyOriginalMatch(match, new_nodes);
|
CopyOriginalMatch(match, new_nodes);
|
||||||
@ -79,11 +90,12 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
|||||||
},
|
},
|
||||||
{true}, &replaced_graph_def));
|
{true}, &replaced_graph_def));
|
||||||
// Make sure all references to removed nodes now point to their inputs.
|
// Make sure all references to removed nodes now point to their inputs.
|
||||||
TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
TF_RETURN_IF_ERROR(
|
||||||
std::unordered_set<string>(),
|
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
||||||
¤t_graph_def));
|
std::unordered_set<string>(), ¤t_graph_def));
|
||||||
} while (any_nodes_removed);
|
} while (any_nodes_removed);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
*output_graph_def = current_graph_def;
|
*output_graph_def = current_graph_def;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -210,6 +210,58 @@ class RemoveNodesTest : public ::testing::Test {
|
|||||||
EXPECT_EQ(0, node_lookup.count("identity_node2"));
|
EXPECT_EQ(0, node_lookup.count("identity_node2"));
|
||||||
EXPECT_EQ(0, node_lookup.count("identity_node3"));
|
EXPECT_EQ(0, node_lookup.count("identity_node3"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TestRemoveMultipleInputs() {
|
||||||
|
GraphDef graph_def;
|
||||||
|
|
||||||
|
NodeDef* const_node1 = graph_def.add_node();
|
||||||
|
const_node1->set_name("const_node1");
|
||||||
|
const_node1->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node2 = graph_def.add_node();
|
||||||
|
const_node2->set_name("const_node2");
|
||||||
|
const_node2->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node3 = graph_def.add_node();
|
||||||
|
const_node3->set_name("const_node3");
|
||||||
|
const_node3->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* const_node4 = graph_def.add_node();
|
||||||
|
const_node4->set_name("const_node4");
|
||||||
|
const_node4->set_op("Const");
|
||||||
|
|
||||||
|
NodeDef* fake_quant_node = graph_def.add_node();
|
||||||
|
fake_quant_node->set_name("fake_quant_node");
|
||||||
|
fake_quant_node->set_op("FakeQuantWithMinMaxVars");
|
||||||
|
fake_quant_node->add_input("const_node1");
|
||||||
|
fake_quant_node->add_input("const_node2");
|
||||||
|
fake_quant_node->add_input("const_node3");
|
||||||
|
|
||||||
|
NodeDef* add_node = graph_def.add_node();
|
||||||
|
add_node->set_name("add_node");
|
||||||
|
add_node->set_op("Add");
|
||||||
|
add_node->add_input("fake_quant_node");
|
||||||
|
add_node->add_input("const_node4");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"add_node"};
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"op", {string("FakeQuantWithMinMaxVars")}}));
|
||||||
|
context.params.insert(
|
||||||
|
std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
|
||||||
|
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
ASSERT_EQ(1, node_lookup.count("const_node1"));
|
||||||
|
ASSERT_EQ(1, node_lookup.count("const_node4"));
|
||||||
|
ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
|
||||||
|
ASSERT_EQ(1, node_lookup.count("add_node"));
|
||||||
|
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
|
||||||
|
EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
|
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
|
||||||
@ -218,5 +270,9 @@ TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
|
|||||||
|
|
||||||
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
|
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
|
||||||
|
|
||||||
|
TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
|
||||||
|
TestRemoveMultipleInputs();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace graph_transforms
|
} // namespace graph_transforms
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user