Enable removal of nodes with multiple inputs
PiperOrigin-RevId: 157068219
This commit is contained in:
parent
4a09e96797
commit
70fc6abad7
@ -37,6 +37,9 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
||||
"remove_nodes expects at least one 'op'"
|
||||
"argument, e.g. remove_nodes(op=Identity)");
|
||||
}
|
||||
int32 max_inputs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context.GetOneInt32Parameter("max_inputs", 1, &max_inputs));
|
||||
|
||||
// Make sure we don't get rid of any nodes used as graph inputs or outputs.
|
||||
std::set<string> required_nodes;
|
||||
@ -50,39 +53,48 @@ Status RemoveNodes(const GraphDef& input_graph_def,
|
||||
std::vector<string> ops_to_remove = context.params.at("op");
|
||||
GraphDef current_graph_def = input_graph_def;
|
||||
for (const string& op : ops_to_remove) {
|
||||
// Keep looking for nodes to remove until there are no more changes.
|
||||
bool any_nodes_removed;
|
||||
do {
|
||||
any_nodes_removed = false;
|
||||
std::map<string, string> inputs_to_rename;
|
||||
GraphDef replaced_graph_def;
|
||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||
current_graph_def, {op, {{"*"}}},
|
||||
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
|
||||
const NodeMatch& match, const std::set<string>& input_nodes,
|
||||
const std::set<string>& output_nodes,
|
||||
std::vector<NodeDef>* new_nodes) {
|
||||
const NodeDef& replace_node = match.node;
|
||||
// If this node is needed in the inputs or outputs don't replace it.
|
||||
if (required_nodes.count(replace_node.name())) {
|
||||
LOG(INFO) << "Skipping replacement for " << replace_node.name();
|
||||
CopyOriginalMatch(match, new_nodes);
|
||||
for (int num_inputs = 1; num_inputs <= max_inputs; ++num_inputs) {
|
||||
// Look for a variable number of inputs.
|
||||
OpTypePattern pattern = {op};
|
||||
pattern.inputs.resize(num_inputs);
|
||||
for (int i = 0; i < num_inputs; ++i) {
|
||||
pattern.inputs[i] = {"*"};
|
||||
}
|
||||
// Keep looking for nodes to remove until there are no more changes.
|
||||
bool any_nodes_removed;
|
||||
do {
|
||||
any_nodes_removed = false;
|
||||
std::map<string, string> inputs_to_rename;
|
||||
GraphDef replaced_graph_def;
|
||||
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
|
||||
current_graph_def, pattern,
|
||||
[&inputs_to_rename, &required_nodes, &any_nodes_removed](
|
||||
const NodeMatch& match, const std::set<string>& input_nodes,
|
||||
const std::set<string>& output_nodes,
|
||||
std::vector<NodeDef>* new_nodes) {
|
||||
const NodeDef& replace_node = match.node;
|
||||
// If this node is needed in the inputs or outputs don't replace
|
||||
// it.
|
||||
if (required_nodes.count(replace_node.name())) {
|
||||
LOG(INFO) << "Skipping replacement for " << replace_node.name();
|
||||
CopyOriginalMatch(match, new_nodes);
|
||||
return Status::OK();
|
||||
}
|
||||
const NodeDef& input_node = match.inputs[0].node;
|
||||
inputs_to_rename[replace_node.name()] = input_node.name();
|
||||
inputs_to_rename["^" + replace_node.name()] =
|
||||
"^" + input_node.name();
|
||||
new_nodes->push_back(input_node);
|
||||
any_nodes_removed = true;
|
||||
return Status::OK();
|
||||
}
|
||||
const NodeDef& input_node = match.inputs[0].node;
|
||||
inputs_to_rename[replace_node.name()] = input_node.name();
|
||||
inputs_to_rename["^" + replace_node.name()] =
|
||||
"^" + input_node.name();
|
||||
new_nodes->push_back(input_node);
|
||||
any_nodes_removed = true;
|
||||
return Status::OK();
|
||||
},
|
||||
{true}, &replaced_graph_def));
|
||||
// Make sure all references to removed nodes now point to their inputs.
|
||||
TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
||||
std::unordered_set<string>(),
|
||||
¤t_graph_def));
|
||||
} while (any_nodes_removed);
|
||||
},
|
||||
{true}, &replaced_graph_def));
|
||||
// Make sure all references to removed nodes now point to their inputs.
|
||||
TF_RETURN_IF_ERROR(
|
||||
RenameNodeInputs(replaced_graph_def, inputs_to_rename,
|
||||
std::unordered_set<string>(), ¤t_graph_def));
|
||||
} while (any_nodes_removed);
|
||||
}
|
||||
}
|
||||
|
||||
*output_graph_def = current_graph_def;
|
||||
|
@ -210,6 +210,58 @@ class RemoveNodesTest : public ::testing::Test {
|
||||
EXPECT_EQ(0, node_lookup.count("identity_node2"));
|
||||
EXPECT_EQ(0, node_lookup.count("identity_node3"));
|
||||
}
|
||||
|
||||
void TestRemoveMultipleInputs() {
|
||||
GraphDef graph_def;
|
||||
|
||||
NodeDef* const_node1 = graph_def.add_node();
|
||||
const_node1->set_name("const_node1");
|
||||
const_node1->set_op("Const");
|
||||
|
||||
NodeDef* const_node2 = graph_def.add_node();
|
||||
const_node2->set_name("const_node2");
|
||||
const_node2->set_op("Const");
|
||||
|
||||
NodeDef* const_node3 = graph_def.add_node();
|
||||
const_node3->set_name("const_node3");
|
||||
const_node3->set_op("Const");
|
||||
|
||||
NodeDef* const_node4 = graph_def.add_node();
|
||||
const_node4->set_name("const_node4");
|
||||
const_node4->set_op("Const");
|
||||
|
||||
NodeDef* fake_quant_node = graph_def.add_node();
|
||||
fake_quant_node->set_name("fake_quant_node");
|
||||
fake_quant_node->set_op("FakeQuantWithMinMaxVars");
|
||||
fake_quant_node->add_input("const_node1");
|
||||
fake_quant_node->add_input("const_node2");
|
||||
fake_quant_node->add_input("const_node3");
|
||||
|
||||
NodeDef* add_node = graph_def.add_node();
|
||||
add_node->set_name("add_node");
|
||||
add_node->set_op("Add");
|
||||
add_node->add_input("fake_quant_node");
|
||||
add_node->add_input("const_node4");
|
||||
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {};
|
||||
context.output_names = {"add_node"};
|
||||
context.params.insert(std::pair<string, std::vector<string>>(
|
||||
{"op", {string("FakeQuantWithMinMaxVars")}}));
|
||||
context.params.insert(
|
||||
std::pair<string, std::vector<string>>({"max_inputs", {string("3")}}));
|
||||
TF_ASSERT_OK(RemoveNodes(graph_def, context, &result));
|
||||
|
||||
std::map<string, const NodeDef*> node_lookup;
|
||||
MapNamesToNodes(result, &node_lookup);
|
||||
ASSERT_EQ(1, node_lookup.count("const_node1"));
|
||||
ASSERT_EQ(1, node_lookup.count("const_node4"));
|
||||
ASSERT_EQ(0, node_lookup.count("fake_quant_node"));
|
||||
ASSERT_EQ(1, node_lookup.count("add_node"));
|
||||
EXPECT_EQ("const_node1", node_lookup.at("add_node")->input(0));
|
||||
EXPECT_EQ("const_node4", node_lookup.at("add_node")->input(1));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(RemoveNodesTest, TestRemoveNodes) { TestRemoveNodes(); }
|
||||
@ -218,5 +270,9 @@ TEST_F(RemoveNodesTest, TestRemoveOutputNodes) { TestRemoveOutputNodes(); }
|
||||
|
||||
TEST_F(RemoveNodesTest, TestRemoveChainedNodes) { TestRemoveChainedNodes(); }
|
||||
|
||||
TEST_F(RemoveNodesTest, TestRemoveMultipleInputs) {
|
||||
TestRemoveMultipleInputs();
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user