Add graph transformer to inline one level of PartitionedCall operations.
PiperOrigin-RevId: 299305422 Change-Id: I4502daffd998daff54fc9af5fd3fcb6c62dec234
This commit is contained in:
parent
6b3556f20d
commit
df9f9fb7f3
tensorflow/tools/graph_transforms
@ -98,6 +98,7 @@ cc_library(
|
||||
"fold_old_batch_norms.cc",
|
||||
"freeze_requantization_ranges.cc",
|
||||
"fuse_convolutions.cc",
|
||||
"inline_partitionedcall.cc",
|
||||
"insert_logging.cc",
|
||||
"obfuscate_names.cc",
|
||||
"quantize_nodes.cc",
|
||||
@ -122,6 +123,9 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":transform_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"//tensorflow/c:checkpoint_reader",
|
||||
"//tensorflow/core/util/tensor_bundle",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -175,6 +179,7 @@ tf_cc_test(
|
||||
"fold_old_batch_norms_test.cc",
|
||||
"freeze_requantization_ranges_test.cc",
|
||||
"fuse_convolutions_test.cc",
|
||||
"inline_partitionedcall_test.cc",
|
||||
"insert_logging_test.cc",
|
||||
"obfuscate_names_test.cc",
|
||||
"quantize_nodes_test.cc",
|
||||
|
151
tensorflow/tools/graph_transforms/inline_partitionedcall.cc
Normal file
151
tensorflow/tools/graph_transforms/inline_partitionedcall.cc
Normal file
@ -0,0 +1,151 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/op_def.pb.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
constexpr char kPartitionedCallOpName[] = "PartitionedCall";
|
||||
constexpr char kFunctionAttrName[] = "f";
|
||||
|
||||
namespace {
|
||||
absl::optional<FunctionDef> GetFunctionByNameFromLibrary(
|
||||
const GraphDef& graph, absl::string_view function_name) {
|
||||
for (const auto& fct : graph.library().function()) {
|
||||
if (fct.signature().name() == function_name) {
|
||||
return fct;
|
||||
}
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::string NormalizeNodeDefInput(const std::string& input_name) {
|
||||
std::vector<std::string> name_parts =
|
||||
absl::StrSplit(input_name, absl::ByChar(':'));
|
||||
if (name_parts.size() > 2) {
|
||||
return absl::StrCat(name_parts[0], ":", name_parts.back());
|
||||
}
|
||||
return input_name;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status InlinePartitionedCall(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def) {
|
||||
output_graph_def->Clear();
|
||||
absl::flat_hash_map<std::string, std::string> remap_input;
|
||||
|
||||
for (const NodeDef& node : input_graph_def.node()) {
|
||||
if (node.op() == kPartitionedCallOpName) {
|
||||
if (node.attr().count(kFunctionAttrName) == 0) {
|
||||
return Status(
|
||||
error::Code::NOT_FOUND,
|
||||
"Node " + node.name() + " has no attribute: " + kFunctionAttrName);
|
||||
}
|
||||
|
||||
if (!node.attr().at(kFunctionAttrName).has_func()) {
|
||||
return Status(error::Code::NOT_FOUND,
|
||||
"Cannot figure out function name");
|
||||
}
|
||||
const std::string function_name =
|
||||
node.attr().at(kFunctionAttrName).func().name();
|
||||
absl::optional<FunctionDef> function =
|
||||
GetFunctionByNameFromLibrary(input_graph_def, function_name);
|
||||
if (!function.has_value()) {
|
||||
return Status(error::Code::NOT_FOUND,
|
||||
"function " + function_name + " Not found");
|
||||
}
|
||||
|
||||
const std::string prefix = node.name();
|
||||
|
||||
const int kOutputArgumentCount =
|
||||
function->signature().output_arg().size();
|
||||
for (int k = 0; k < kOutputArgumentCount; ++k) {
|
||||
const std::string function_arg_output_name =
|
||||
function->ret().at(function->signature().output_arg()[k].name());
|
||||
remap_input.insert_or_assign(
|
||||
CanonicalInputName(absl::StrCat(node.name(), ":", k)),
|
||||
absl::StrCat(prefix, "/",
|
||||
NormalizeNodeDefInput(function_arg_output_name)));
|
||||
}
|
||||
|
||||
const int kInputArgumentCount = function->signature().input_arg().size();
|
||||
if (node.input().size() != kInputArgumentCount) {
|
||||
return Status(error::Code::INVALID_ARGUMENT,
|
||||
"Called function " + function_name +
|
||||
" has invalid input signature.");
|
||||
}
|
||||
absl::flat_hash_map<std::string, std::string> input_argument_map;
|
||||
for (int k = 0; k < kInputArgumentCount; ++k) {
|
||||
const std::string canonical_name =
|
||||
CanonicalInputName(function->signature().input_arg()[k].name());
|
||||
input_argument_map.insert_or_assign(canonical_name, node.input()[k]);
|
||||
}
|
||||
|
||||
for (const NodeDef& function_node : function->node_def()) {
|
||||
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||
*new_node = function_node;
|
||||
new_node->set_name(absl::StrCat(prefix, "/", function_node.name()));
|
||||
absl::c_transform(
|
||||
*new_node->mutable_input(), new_node->mutable_input()->begin(),
|
||||
[prefix, input_argument_map](const std::string& input_name) {
|
||||
const std::string canonical_input_name =
|
||||
CanonicalInputName(input_name);
|
||||
if (input_argument_map.find(canonical_input_name) !=
|
||||
input_argument_map.end()) {
|
||||
return input_argument_map.at(canonical_input_name);
|
||||
}
|
||||
return absl::StrCat(prefix, "/",
|
||||
NormalizeNodeDefInput(input_name));
|
||||
});
|
||||
}
|
||||
} else {
|
||||
NodeDef* new_node = output_graph_def->mutable_node()->Add();
|
||||
*new_node = node;
|
||||
}
|
||||
}
|
||||
|
||||
// Remap PartitionCall outputs to correct nodes.
|
||||
for (NodeDef& node : *output_graph_def->mutable_node()) {
|
||||
absl::c_transform(
|
||||
*node.mutable_input(), node.mutable_input()->begin(),
|
||||
[remap_input](const std::string& input_name) {
|
||||
const std::string canonical_input_name =
|
||||
CanonicalInputName(input_name);
|
||||
if (remap_input.find(canonical_input_name) != remap_input.end()) {
|
||||
return remap_input.at(canonical_input_name);
|
||||
}
|
||||
return input_name;
|
||||
});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
REGISTER_GRAPH_TRANSFORM("inline_partitionedcall", InlinePartitionedCall);
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
136
tensorflow/tools/graph_transforms/inline_partitionedcall_test.cc
Normal file
136
tensorflow/tools/graph_transforms/inline_partitionedcall_test.cc
Normal file
@ -0,0 +1,136 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/tools/graph_transforms/transform_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace graph_transforms {
|
||||
|
||||
constexpr char kGraphDefWithPartitionedCall[] =
|
||||
"node {\n"
|
||||
" name: \"y\"\n"
|
||||
" op: \"Placeholder\"\n"
|
||||
"}\n"
|
||||
"node {\n"
|
||||
" name: \"sub/y\"\n"
|
||||
" op: \"Const\"\n"
|
||||
"}\n"
|
||||
"node {\n"
|
||||
" name: \"PartitionedCall\"\n"
|
||||
" op: \"PartitionedCall\"\n"
|
||||
" input: \"y\"\n"
|
||||
" input: \"sub/y\"\n"
|
||||
" attr {\n"
|
||||
" key: \"f\"\n"
|
||||
" value {\n"
|
||||
" func {\n"
|
||||
" name: \"__inference_simple_add_14\"\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}\n"
|
||||
"node {\n"
|
||||
" name: \"add/y\"\n"
|
||||
" op: \"Const\"\n"
|
||||
"}\n"
|
||||
"node {\n"
|
||||
" name: \"add\"\n"
|
||||
" op: \"AddV2\"\n"
|
||||
" input: \"PartitionedCall\"\n"
|
||||
" input: \"add/y\"\n"
|
||||
"}\n"
|
||||
"node {\n"
|
||||
" name: \"Identity\"\n"
|
||||
" op: \"Identity\"\n"
|
||||
" input: \"add\"\n"
|
||||
"}\n"
|
||||
"library {\n"
|
||||
" function {\n"
|
||||
" signature {\n"
|
||||
" name: \"__inference_simple_add_14\"\n"
|
||||
" input_arg {\n"
|
||||
" name: \"x\"\n"
|
||||
" type: DT_FLOAT\n"
|
||||
" }\n"
|
||||
" input_arg {\n"
|
||||
" name: \"y\"\n"
|
||||
" type: DT_FLOAT\n"
|
||||
" }\n"
|
||||
" output_arg {\n"
|
||||
" name: \"identity\"\n"
|
||||
" type: DT_FLOAT\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
" node_def {\n"
|
||||
" name: \"mul\"\n"
|
||||
" op: \"Mul\"\n"
|
||||
" input: \"x\"\n"
|
||||
" input: \"y\"\n"
|
||||
" }\n"
|
||||
" node_def {\n"
|
||||
" name: \"add/y\"\n"
|
||||
" op: \"Const\"\n"
|
||||
" }\n"
|
||||
" node_def {\n"
|
||||
" name: \"add\"\n"
|
||||
" op: \"AddV2\"\n"
|
||||
" input: \"mul:z:0\"\n"
|
||||
" input: \"add/y:output:0\"\n"
|
||||
" }\n"
|
||||
" node_def {\n"
|
||||
" name: \"Identity\"\n"
|
||||
" op: \"Identity\"\n"
|
||||
" input: \"add:z:0\"\n"
|
||||
" }\n"
|
||||
" ret {\n"
|
||||
" key: \"identity\"\n"
|
||||
" value: \"Identity:output:0\"\n"
|
||||
" }\n"
|
||||
" }\n"
|
||||
"}\n";
|
||||
|
||||
// Declare here, so we don't need a public header.
|
||||
Status InlinePartitionedCall(const GraphDef& input_graph_def,
|
||||
const TransformFuncContext& context,
|
||||
GraphDef* output_graph_def);
|
||||
|
||||
TEST(InlinePartitionedCallTest, Inlining) {
|
||||
GraphDef in_graph;
|
||||
EXPECT_TRUE(::tensorflow::protobuf::TextFormat::ParseFromString(
|
||||
kGraphDefWithPartitionedCall, &in_graph));
|
||||
|
||||
GraphDef result;
|
||||
TransformFuncContext context;
|
||||
context.input_names = {"y"};
|
||||
context.output_names = {"Identity"};
|
||||
TF_ASSERT_OK(InlinePartitionedCall(in_graph, context, &result));
|
||||
|
||||
EXPECT_TRUE(std::none_of(
|
||||
result.node().cbegin(), result.node().cend(),
|
||||
[](const NodeDef& node) { return node.op() == "PartitionedCall"; }));
|
||||
EXPECT_EQ(9, result.node().size());
|
||||
TF_EXPECT_OK(IsGraphValid(result));
|
||||
}
|
||||
|
||||
} // namespace graph_transforms
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user