Move SetTransitiveFaninGraph and ComputeTransitiveFaninGraph as grappler utils and reuse within model_pruner optimizer. This change is NFC
PiperOrigin-RevId: 275066975 Change-Id: I430022b469f4b0938c283bf6deae428d80e7ba5b
This commit is contained in:
parent
671808b4b5
commit
7d8ee44de4
@ -142,6 +142,7 @@ cc_library(
|
||||
":utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -208,63 +209,5 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
bool* ill_formed) {
|
||||
*ill_formed = false;
|
||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||
std::unordered_map<string, const NodeDef*> name_to_send;
|
||||
for (const auto& node : graph.node()) {
|
||||
name_to_node[node.name()] = &node;
|
||||
if (node.op() == "_Send") {
|
||||
const auto& attr = node.attr();
|
||||
name_to_send[attr.at("tensor_name").s()] = &node;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> queue;
|
||||
for (const string& root : terminal_nodes) {
|
||||
const NodeDef* node = name_to_node[NodeName(root)];
|
||||
if (!node) {
|
||||
*ill_formed = true;
|
||||
VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
|
||||
return {};
|
||||
}
|
||||
queue.push_back(node);
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> result;
|
||||
std::unordered_set<const NodeDef*> visited;
|
||||
|
||||
while (!queue.empty()) {
|
||||
const NodeDef* node = queue.back();
|
||||
queue.pop_back();
|
||||
if (!visited.insert(node).second) {
|
||||
// The node has already been visited.
|
||||
continue;
|
||||
}
|
||||
result.push_back(node);
|
||||
for (const string& input : node->input()) {
|
||||
const NodeDef* in = name_to_node[NodeName(input)];
|
||||
if (!in) {
|
||||
VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
|
||||
*ill_formed = true;
|
||||
return {};
|
||||
}
|
||||
queue.push_back(in);
|
||||
}
|
||||
if (node->op() == "_Recv") {
|
||||
const auto& attr = node->attr();
|
||||
const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
|
||||
if (send) {
|
||||
queue.push_back(send);
|
||||
}
|
||||
// Subgraph after partitioning may have either _Send or _Recv, not both.
|
||||
// So, we do not set ill_formed for missing _Send.
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
@ -379,6 +379,7 @@ cc_library(
|
||||
"//tensorflow/core/grappler:mutable_graph_view",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/utils:transitive_fanin",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/mutable_graph_view.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -411,29 +412,6 @@ Status SplitIdentityNInputs(GraphDef* graph,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
|
||||
GraphDef* output_graph,
|
||||
const std::vector<string>& terminal_nodes) {
|
||||
// Determines transitive fanin nodes from terminal nodes and add them to the
|
||||
// output graph.
|
||||
bool ill_formed = false;
|
||||
std::vector<const NodeDef*> keep =
|
||||
ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed);
|
||||
if (ill_formed) {
|
||||
// Some graph edges are invalid, or some of the feeds/fetch don't exist:
|
||||
// let's be conservative and preserve the graph as is.
|
||||
return errors::InvalidArgument("Invalid input graph.");
|
||||
}
|
||||
// Try to keep the nodes ordered somewhat topologically since this helps
|
||||
// further optimizations perform better.
|
||||
output_graph->mutable_node()->Reserve(keep.size());
|
||||
for (int i = keep.size() - 1; i >= 0; --i) {
|
||||
*output_graph->add_node() = *keep[i];
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();
|
||||
|
@ -379,3 +379,31 @@ tf_cc_test(
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "transitive_fanin",
|
||||
srcs = ["transitive_fanin.cc"],
|
||||
hdrs = ["transitive_fanin.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "transitive_fanin_test",
|
||||
srcs = ["transitive_fanin_test.cc"],
|
||||
deps = [
|
||||
":transitive_fanin",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
],
|
||||
)
|
||||
|
109
tensorflow/core/grappler/utils/transitive_fanin.cc
Normal file
109
tensorflow/core/grappler/utils/transitive_fanin.cc
Normal file
@ -0,0 +1,109 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
bool* ill_formed) {
|
||||
*ill_formed = false;
|
||||
std::unordered_map<string, const NodeDef*> name_to_node;
|
||||
std::unordered_map<string, const NodeDef*> name_to_send;
|
||||
for (const auto& node : graph.node()) {
|
||||
name_to_node[node.name()] = &node;
|
||||
if (node.op() == "_Send") {
|
||||
const auto& attr = node.attr();
|
||||
name_to_send[attr.at("tensor_name").s()] = &node;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> queue;
|
||||
for (const string& root : terminal_nodes) {
|
||||
const NodeDef* node = name_to_node[NodeName(root)];
|
||||
if (!node) {
|
||||
*ill_formed = true;
|
||||
VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
|
||||
return {};
|
||||
}
|
||||
queue.push_back(node);
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> result;
|
||||
std::unordered_set<const NodeDef*> visited;
|
||||
|
||||
while (!queue.empty()) {
|
||||
const NodeDef* node = queue.back();
|
||||
queue.pop_back();
|
||||
if (!visited.insert(node).second) {
|
||||
// The node has already been visited.
|
||||
continue;
|
||||
}
|
||||
result.push_back(node);
|
||||
for (const string& input : node->input()) {
|
||||
const NodeDef* in = name_to_node[NodeName(input)];
|
||||
if (!in) {
|
||||
VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
|
||||
*ill_formed = true;
|
||||
return {};
|
||||
}
|
||||
queue.push_back(in);
|
||||
}
|
||||
if (node->op() == "_Recv") {
|
||||
const auto& attr = node->attr();
|
||||
const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
|
||||
if (send) {
|
||||
queue.push_back(send);
|
||||
}
|
||||
// Subgraph after partitioning may have either _Send or _Recv, not both.
|
||||
// So, we do not set ill_formed for missing _Send.
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
|
||||
GraphDef* output_graph,
|
||||
const std::vector<string>& terminal_nodes) {
|
||||
// Determines transitive fanin nodes from terminal nodes and add them to the
|
||||
// output graph.
|
||||
bool ill_formed = false;
|
||||
std::vector<const NodeDef*> keep =
|
||||
ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed);
|
||||
if (ill_formed) {
|
||||
// Some graph edges are invalid, or some of the feeds/fetch don't exist:
|
||||
// let's be conservative and preserve the graph as is.
|
||||
return errors::InvalidArgument("Invalid input graph.");
|
||||
}
|
||||
// Try to keep the nodes ordered somewhat topologically since this helps
|
||||
// further optimizations perform better.
|
||||
output_graph->mutable_node()->Reserve(keep.size());
|
||||
for (int i = keep.size() - 1; i >= 0; --i) {
|
||||
*output_graph->add_node() = *keep[i];
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
44
tensorflow/core/grappler/utils/transitive_fanin.h
Normal file
44
tensorflow/core/grappler/utils/transitive_fanin.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Computes the transitive fanin of the graph based on reachability from the
|
||||
// specified terminal nodes. ill_formed will be set to true if the graph is
|
||||
// deemed structurally invalid. Returns the set of nodes comprising the
|
||||
// transitive fanin.
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes,
|
||||
bool* ill_formed);
|
||||
|
||||
// Creates output_graph from input_graph using the transitive fanin from the
|
||||
// specified terminal nodes. Returns error if the input_graph is deemed
|
||||
// structurally invalid.
|
||||
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
|
||||
GraphDef* output_graph,
|
||||
const std::vector<string>& terminal_nodes);
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_
|
139
tensorflow/core/grappler/utils/transitive_fanin_test.cc
Normal file
139
tensorflow/core/grappler/utils/transitive_fanin_test.cc
Normal file
@ -0,0 +1,139 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class TransitiveFaninTest : public ::testing::Test {
|
||||
protected:
|
||||
struct NodeConfig {
|
||||
NodeConfig(string name, std::vector<string> inputs)
|
||||
: name(std::move(name)), inputs(std::move(inputs)) {}
|
||||
NodeConfig(string name, string op, std::vector<string> inputs)
|
||||
: name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}
|
||||
|
||||
string name;
|
||||
string op;
|
||||
std::vector<string> inputs;
|
||||
};
|
||||
|
||||
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
|
||||
GraphDef graph;
|
||||
|
||||
for (const NodeConfig& node : nodes) {
|
||||
NodeDef node_def;
|
||||
node_def.set_name(node.name);
|
||||
node_def.set_op(node.op);
|
||||
for (const string& input : node.inputs) {
|
||||
node_def.add_input(input);
|
||||
}
|
||||
*graph.add_node() = std::move(node_def);
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(TransitiveFaninTest, NoPruning) {
|
||||
GraphDef graph = CreateGraph({
|
||||
{"1", {"2"}}, //
|
||||
{"2", {"3"}}, //
|
||||
{"3", {"4"}}, //
|
||||
{"4", {}} //
|
||||
});
|
||||
|
||||
GraphDef output_graph;
|
||||
const std::vector<string> terminal_nodes = {"1"};
|
||||
TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes));
|
||||
NodeMap node_map(&output_graph);
|
||||
ASSERT_TRUE(node_map.NodeExists("1"));
|
||||
ASSERT_TRUE(node_map.NodeExists("2"));
|
||||
ASSERT_TRUE(node_map.NodeExists("3"));
|
||||
ASSERT_TRUE(node_map.NodeExists("4"));
|
||||
}
|
||||
|
||||
TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromSingleTerminalNode) {
|
||||
GraphDef graph = CreateGraph({
|
||||
{"1", {"2"}}, //
|
||||
{"2", {"3"}}, //
|
||||
{"3", {"4"}}, //
|
||||
{"4", {}}, //
|
||||
{"5", {"1"}} //
|
||||
});
|
||||
|
||||
GraphDef output_graph;
|
||||
const std::vector<string> terminal_nodes = {"1"};
|
||||
TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes));
|
||||
NodeMap node_map(&output_graph);
|
||||
ASSERT_TRUE(node_map.NodeExists("1"));
|
||||
ASSERT_TRUE(node_map.NodeExists("2"));
|
||||
ASSERT_TRUE(node_map.NodeExists("3"));
|
||||
ASSERT_TRUE(node_map.NodeExists("4"));
|
||||
ASSERT_FALSE(node_map.NodeExists("5"));
|
||||
}
|
||||
|
||||
TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromMultipleTerminalNodes) {
|
||||
GraphDef graph = CreateGraph({
|
||||
{"1", {"2"}}, //
|
||||
{"2", {"3"}}, //
|
||||
{"3", {"4"}}, //
|
||||
{"4", {}}, //
|
||||
{"5", {"2"}}, //
|
||||
{"6", {"1"}} //
|
||||
});
|
||||
|
||||
GraphDef output_graph;
|
||||
const std::vector<string> terminal_nodes = {"1", "5"};
|
||||
TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes));
|
||||
NodeMap node_map(&output_graph);
|
||||
ASSERT_TRUE(node_map.NodeExists("1"));
|
||||
ASSERT_TRUE(node_map.NodeExists("2"));
|
||||
ASSERT_TRUE(node_map.NodeExists("3"));
|
||||
ASSERT_TRUE(node_map.NodeExists("4"));
|
||||
ASSERT_TRUE(node_map.NodeExists("5"));
|
||||
ASSERT_FALSE(node_map.NodeExists("6"));
|
||||
}
|
||||
|
||||
TEST_F(TransitiveFaninTest, InvalidGraph) {
|
||||
GraphDef graph = CreateGraph({
|
||||
{"1", {"2"}}, //
|
||||
{"2", {"3"}}, //
|
||||
{"3", {"4"}}, //
|
||||
{"4", {}}, //
|
||||
{"5", {"6"}}, //
|
||||
{"7", {"8"}} //
|
||||
});
|
||||
|
||||
GraphDef output_graph;
|
||||
const std::vector<string> terminal_nodes = {"1", "5"};
|
||||
auto s = SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes);
|
||||
EXPECT_FALSE(s.ok());
|
||||
EXPECT_EQ(s.error_message(), "Invalid input graph.");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user