diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 198859598eb..19c17079031 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -107,6 +107,7 @@ cc_library( "remove_device.cc", "remove_nodes.cc", "rename_attribute.cc", + "rename_node.cc", "rename_op.cc", "round_weights.cc", "set_device.cc", @@ -182,6 +183,7 @@ tf_cc_test( "remove_device_test.cc", "remove_nodes_test.cc", "rename_attribute_test.cc", + "rename_node_test.cc", "rename_op_test.cc", "round_weights_test.cc", "set_device_test.cc", diff --git a/tensorflow/tools/graph_transforms/rename_node.cc b/tensorflow/tools/graph_transforms/rename_node.cc new file mode 100644 index 00000000000..bd40e842577 --- /dev/null +++ b/tensorflow/tools/graph_transforms/rename_node.cc @@ -0,0 +1,70 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +Status RenameNode(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + if (!context.params.count("old_node_name") || + (context.params.at("old_node_name").size() != 1) || + !context.params.count("new_node_name") || + (context.params.at("new_node_name").size() != 1)) { + return errors::InvalidArgument( + "rename_node expects exactly one 'old_node_name' and one " + "'new_node_name' argument, e.g. " + "rename_node(old_attribute_name=super/deep/output, " + "new_attribute_name=output)"); + } + + const std::string old_node_name = context.params.at("old_node_name")[0]; + const std::string new_node_name = context.params.at("new_node_name")[0]; + + output_graph_def->Clear(); + for (const NodeDef& input_node : input_graph_def.node()) { + NodeDef* node = output_graph_def->mutable_node()->Add(); + *node = input_node; + if (node->name() == new_node_name) { + return Status(error::Code::INVALID_ARGUMENT, + "A node is alreading using " + new_node_name + "as name."); + } + if (node->name() == old_node_name) { + node->set_name(new_node_name); + } + for (std::string& input_name : *node->mutable_input()) { + std::string prefix; + std::string input_node_name; + std::string suffix; + NodeNamePartsFromInput(input_name, &prefix, &input_node_name, &suffix); + if (input_node_name == old_node_name) { + std::string new_input_name = prefix + new_node_name + suffix; + input_name = new_input_name; + } + } + } + + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("rename_node", RenameNode); +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/rename_node_test.cc b/tensorflow/tools/graph_transforms/rename_node_test.cc new file mode 100644 index 00000000000..574272b8cca --- /dev/null +++ b/tensorflow/tools/graph_transforms/rename_node_test.cc @@ -0,0 +1,96 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Declare here, so we don't need a public header. +Status RenameNode(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def); + +TEST(RenameNodeTest, Rename) { + GraphDef in_graph; + NodeDef* node = in_graph.add_node(); + node->set_name("input"); + node->set_op("Placeholder"); + + NodeDef* node_splitter = in_graph.add_node(); + node_splitter->set_name("splitter"); + node_splitter->set_op("Split"); + + NodeDef* node_adder = in_graph.add_node(); + node_adder->set_op("Add"); + node_adder->set_name("adder"); + node_adder->add_input("splitter"); + node_adder->add_input("splitter:1"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"adder"}; + context.params.insert(std::pair>( + {"old_node_name", {std::string("splitter")}})); + context.params.insert(std::pair>( + {"new_node_name", {string("demux")}})); + TF_ASSERT_OK(RenameNode(in_graph, context, &result)); + + std::map node_lookup; + MapNamesToNodes(result, &node_lookup); + EXPECT_EQ(1, node_lookup.count("demux")); + EXPECT_EQ(1, node_lookup.count("adder")); + EXPECT_EQ(2, node_lookup["adder"]->input().size()); + EXPECT_EQ("demux", node_lookup["adder"]->input()[0]); + EXPECT_EQ("demux:1", node_lookup["adder"]->input()[1]); +} + +TEST(RenameNodeTest, FailWhenNameAlreadyExists) { + GraphDef in_graph; + NodeDef* node = in_graph.add_node(); + node->set_name("input"); + node->set_op("Placeholder"); + + NodeDef* node_splitter = in_graph.add_node(); + node_splitter->set_name("splitter"); + node_splitter->set_op("Split"); + + NodeDef* node_adder = in_graph.add_node(); + node_adder->set_op("Add"); + node_adder->set_name("adder"); + node_adder->add_input("splitter"); + node_adder->add_input("splitter:1"); + + GraphDef result; + TransformFuncContext context; + context.input_names = {}; + context.output_names = {"adder"}; + context.params.insert(std::pair>( + {"old_node_name", {std::string("splitter")}})); + context.params.insert(std::pair>( + {"new_node_name", {string("adder")}})); + EXPECT_FALSE(RenameNode(in_graph, context, &result).ok()); +} + +} // namespace graph_transforms +} // namespace tensorflow