Add a graph transformer to rename node.
PiperOrigin-RevId: 298577065 Change-Id: Id689e6f60cee1ab4922c1b8e9f6d2988a58e8028
This commit is contained in:
parent
5b1e6d3ceb
commit
27e215e2ed
@ -107,6 +107,7 @@ cc_library(
|
|||||||
"remove_device.cc",
|
"remove_device.cc",
|
||||||
"remove_nodes.cc",
|
"remove_nodes.cc",
|
||||||
"rename_attribute.cc",
|
"rename_attribute.cc",
|
||||||
|
"rename_node.cc",
|
||||||
"rename_op.cc",
|
"rename_op.cc",
|
||||||
"round_weights.cc",
|
"round_weights.cc",
|
||||||
"set_device.cc",
|
"set_device.cc",
|
||||||
@ -182,6 +183,7 @@ tf_cc_test(
|
|||||||
"remove_device_test.cc",
|
"remove_device_test.cc",
|
||||||
"remove_nodes_test.cc",
|
"remove_nodes_test.cc",
|
||||||
"rename_attribute_test.cc",
|
"rename_attribute_test.cc",
|
||||||
|
"rename_node_test.cc",
|
||||||
"rename_op_test.cc",
|
"rename_op_test.cc",
|
||||||
"round_weights_test.cc",
|
"round_weights_test.cc",
|
||||||
"set_device_test.cc",
|
"set_device_test.cc",
|
||||||
|
70
tensorflow/tools/graph_transforms/rename_node.cc
Normal file
70
tensorflow/tools/graph_transforms/rename_node.cc
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
Status RenameNode(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def) {
|
||||||
|
if (!context.params.count("old_node_name") ||
|
||||||
|
(context.params.at("old_node_name").size() != 1) ||
|
||||||
|
!context.params.count("new_node_name") ||
|
||||||
|
(context.params.at("new_node_name").size() != 1)) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"rename_node expects exactly one 'old_node_name' and one "
|
||||||
|
"'new_node_name' argument, e.g. "
|
||||||
|
"rename_node(old_attribute_name=super/deep/output, "
|
||||||
|
"new_attribute_name=output)");
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::string old_node_name = context.params.at("old_node_name")[0];
|
||||||
|
const std::string new_node_name = context.params.at("new_node_name")[0];
|
||||||
|
|
||||||
|
output_graph_def->Clear();
|
||||||
|
for (const NodeDef& input_node : input_graph_def.node()) {
|
||||||
|
NodeDef* node = output_graph_def->mutable_node()->Add();
|
||||||
|
*node = input_node;
|
||||||
|
if (node->name() == new_node_name) {
|
||||||
|
return Status(error::Code::INVALID_ARGUMENT,
|
||||||
|
"A node is alreading using " + new_node_name + "as name.");
|
||||||
|
}
|
||||||
|
if (node->name() == old_node_name) {
|
||||||
|
node->set_name(new_node_name);
|
||||||
|
}
|
||||||
|
for (std::string& input_name : *node->mutable_input()) {
|
||||||
|
std::string prefix;
|
||||||
|
std::string input_node_name;
|
||||||
|
std::string suffix;
|
||||||
|
NodeNamePartsFromInput(input_name, &prefix, &input_node_name, &suffix);
|
||||||
|
if (input_node_name == old_node_name) {
|
||||||
|
std::string new_input_name = prefix + new_node_name + suffix;
|
||||||
|
input_name = new_input_name;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_GRAPH_TRANSFORM("rename_node", RenameNode);
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
96
tensorflow/tools/graph_transforms/rename_node_test.cc
Normal file
96
tensorflow/tools/graph_transforms/rename_node_test.cc
Normal file
@ -0,0 +1,96 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
|
#include "tensorflow/core/framework/node_def.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace graph_transforms {
|
||||||
|
|
||||||
|
// Declare here, so we don't need a public header.
|
||||||
|
Status RenameNode(const GraphDef& input_graph_def,
|
||||||
|
const TransformFuncContext& context,
|
||||||
|
GraphDef* output_graph_def);
|
||||||
|
|
||||||
|
TEST(RenameNodeTest, Rename) {
|
||||||
|
GraphDef in_graph;
|
||||||
|
NodeDef* node = in_graph.add_node();
|
||||||
|
node->set_name("input");
|
||||||
|
node->set_op("Placeholder");
|
||||||
|
|
||||||
|
NodeDef* node_splitter = in_graph.add_node();
|
||||||
|
node_splitter->set_name("splitter");
|
||||||
|
node_splitter->set_op("Split");
|
||||||
|
|
||||||
|
NodeDef* node_adder = in_graph.add_node();
|
||||||
|
node_adder->set_op("Add");
|
||||||
|
node_adder->set_name("adder");
|
||||||
|
node_adder->add_input("splitter");
|
||||||
|
node_adder->add_input("splitter:1");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"adder"};
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"old_node_name", {std::string("splitter")}}));
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"new_node_name", {string("demux")}}));
|
||||||
|
TF_ASSERT_OK(RenameNode(in_graph, context, &result));
|
||||||
|
|
||||||
|
std::map<string, const NodeDef*> node_lookup;
|
||||||
|
MapNamesToNodes(result, &node_lookup);
|
||||||
|
EXPECT_EQ(1, node_lookup.count("demux"));
|
||||||
|
EXPECT_EQ(1, node_lookup.count("adder"));
|
||||||
|
EXPECT_EQ(2, node_lookup["adder"]->input().size());
|
||||||
|
EXPECT_EQ("demux", node_lookup["adder"]->input()[0]);
|
||||||
|
EXPECT_EQ("demux:1", node_lookup["adder"]->input()[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RenameNodeTest, FailWhenNameAlreadyExists) {
|
||||||
|
GraphDef in_graph;
|
||||||
|
NodeDef* node = in_graph.add_node();
|
||||||
|
node->set_name("input");
|
||||||
|
node->set_op("Placeholder");
|
||||||
|
|
||||||
|
NodeDef* node_splitter = in_graph.add_node();
|
||||||
|
node_splitter->set_name("splitter");
|
||||||
|
node_splitter->set_op("Split");
|
||||||
|
|
||||||
|
NodeDef* node_adder = in_graph.add_node();
|
||||||
|
node_adder->set_op("Add");
|
||||||
|
node_adder->set_name("adder");
|
||||||
|
node_adder->add_input("splitter");
|
||||||
|
node_adder->add_input("splitter:1");
|
||||||
|
|
||||||
|
GraphDef result;
|
||||||
|
TransformFuncContext context;
|
||||||
|
context.input_names = {};
|
||||||
|
context.output_names = {"adder"};
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"old_node_name", {std::string("splitter")}}));
|
||||||
|
context.params.insert(std::pair<string, std::vector<string>>(
|
||||||
|
{"new_node_name", {string("adder")}}));
|
||||||
|
EXPECT_FALSE(RenameNode(in_graph, context, &result).ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace graph_transforms
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user