Move SetTransitiveFaninGraph and ComputeTransitiveFaninGraph as grappler utils and reuse within model_pruner optimizer. This change is NFC

PiperOrigin-RevId: 275066975
Change-Id: I430022b469f4b0938c283bf6deae428d80e7ba5b
This commit is contained in:
Ashwin Murthy 2019-10-16 10:49:52 -07:00 committed by TensorFlower Gardener
parent 671808b4b5
commit 7d8ee44de4
8 changed files with 324 additions and 81 deletions

View File

@ -142,6 +142,7 @@ cc_library(
":utils",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@ -208,63 +209,5 @@ std::vector<const NodeDef*> ComputeTransitiveFanin(
return result;
}
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
bool* ill_formed) {
*ill_formed = false;
std::unordered_map<string, const NodeDef*> name_to_node;
std::unordered_map<string, const NodeDef*> name_to_send;
for (const auto& node : graph.node()) {
name_to_node[node.name()] = &node;
if (node.op() == "_Send") {
const auto& attr = node.attr();
name_to_send[attr.at("tensor_name").s()] = &node;
}
}
std::vector<const NodeDef*> queue;
for (const string& root : terminal_nodes) {
const NodeDef* node = name_to_node[NodeName(root)];
if (!node) {
*ill_formed = true;
VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
return {};
}
queue.push_back(node);
}
std::vector<const NodeDef*> result;
std::unordered_set<const NodeDef*> visited;
while (!queue.empty()) {
const NodeDef* node = queue.back();
queue.pop_back();
if (!visited.insert(node).second) {
// The node has already been visited.
continue;
}
result.push_back(node);
for (const string& input : node->input()) {
const NodeDef* in = name_to_node[NodeName(input)];
if (!in) {
VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
*ill_formed = true;
return {};
}
queue.push_back(in);
}
if (node->op() == "_Recv") {
const auto& attr = node->attr();
const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
if (send) {
queue.push_back(send);
}
// Subgraph after partitioning may have either _Send or _Recv, not both.
// So, we do not set ill_formed for missing _Send.
}
}
return result;
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -379,6 +379,7 @@ cc_library(
"//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/utils:transitive_fanin",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
],

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
namespace tensorflow {
namespace grappler {
@ -411,29 +412,6 @@ Status SplitIdentityNInputs(GraphDef* graph,
return Status::OK();
}
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
GraphDef* output_graph,
const std::vector<string>& terminal_nodes) {
// Determines transitive fanin nodes from terminal nodes and add them to the
// output graph.
bool ill_formed = false;
std::vector<const NodeDef*> keep =
ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed);
if (ill_formed) {
// Some graph edges are invalid, or some of the feeds/fetch don't exist:
// let's be conservative and preserve the graph as is.
return errors::InvalidArgument("Invalid input graph.");
}
// Try to keep the nodes ordered somewhat topologically since this helps
// further optimizations perform better.
output_graph->mutable_node()->Reserve(keep.size());
for (int i = keep.size() - 1; i >= 0; --i) {
*output_graph->add_node() = *keep[i];
}
return Status::OK();
}
Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
const std::unordered_set<string> nodes_to_preserve = item.NodesToPreserve();

View File

@ -379,3 +379,31 @@ tf_cc_test(
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "transitive_fanin",
srcs = ["transitive_fanin.cc"],
hdrs = ["transitive_fanin.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:utils",
],
)
tf_cc_test(
name = "transitive_fanin_test",
srcs = ["transitive_fanin_test.cc"],
deps = [
":transitive_fanin",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:utils",
],
)

View File

@ -0,0 +1,109 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include <queue>
#include <vector>
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
bool* ill_formed) {
*ill_formed = false;
std::unordered_map<string, const NodeDef*> name_to_node;
std::unordered_map<string, const NodeDef*> name_to_send;
for (const auto& node : graph.node()) {
name_to_node[node.name()] = &node;
if (node.op() == "_Send") {
const auto& attr = node.attr();
name_to_send[attr.at("tensor_name").s()] = &node;
}
}
std::vector<const NodeDef*> queue;
for (const string& root : terminal_nodes) {
const NodeDef* node = name_to_node[NodeName(root)];
if (!node) {
*ill_formed = true;
VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root;
return {};
}
queue.push_back(node);
}
std::vector<const NodeDef*> result;
std::unordered_set<const NodeDef*> visited;
while (!queue.empty()) {
const NodeDef* node = queue.back();
queue.pop_back();
if (!visited.insert(node).second) {
// The node has already been visited.
continue;
}
result.push_back(node);
for (const string& input : node->input()) {
const NodeDef* in = name_to_node[NodeName(input)];
if (!in) {
VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input;
*ill_formed = true;
return {};
}
queue.push_back(in);
}
if (node->op() == "_Recv") {
const auto& attr = node->attr();
const NodeDef* send = name_to_send[attr.at("tensor_name").s()];
if (send) {
queue.push_back(send);
}
// Subgraph after partitioning may have either _Send or _Recv, not both.
// So, we do not set ill_formed for missing _Send.
}
}
return result;
}
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
GraphDef* output_graph,
const std::vector<string>& terminal_nodes) {
// Determines transitive fanin nodes from terminal nodes and add them to the
// output graph.
bool ill_formed = false;
std::vector<const NodeDef*> keep =
ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed);
if (ill_formed) {
// Some graph edges are invalid, or some of the feeds/fetch don't exist:
// let's be conservative and preserve the graph as is.
return errors::InvalidArgument("Invalid input graph.");
}
// Try to keep the nodes ordered somewhat topologically since this helps
// further optimizations perform better.
output_graph->mutable_node()->Reserve(keep.size());
for (int i = keep.size() - 1; i >= 0; --i) {
*output_graph->add_node() = *keep[i];
}
return Status::OK();
}
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,44 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace grappler {
// Computes the transitive fanin of the graph based on reachability from the
// specified terminal nodes. ill_formed will be set to true if the graph is
// deemed structurally invalid. Returns the set of nodes comprising the
// transitive fanin.
std::vector<const NodeDef*> ComputeTransitiveFanin(
const GraphDef& graph, const std::vector<string>& terminal_nodes,
bool* ill_formed);
// Creates output_graph from input_graph using the transitive fanin from the
// specified terminal nodes. Returns error if the input_graph is deemed
// structurally invalid.
Status SetTransitiveFaninGraph(const GraphDef& input_graph,
GraphDef* output_graph,
const std::vector<string>& terminal_nodes);
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_

View File

@ -0,0 +1,139 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/transitive_fanin.h"
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
class TransitiveFaninTest : public ::testing::Test {
protected:
struct NodeConfig {
NodeConfig(string name, std::vector<string> inputs)
: name(std::move(name)), inputs(std::move(inputs)) {}
NodeConfig(string name, string op, std::vector<string> inputs)
: name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}
string name;
string op;
std::vector<string> inputs;
};
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
GraphDef graph;
for (const NodeConfig& node : nodes) {
NodeDef node_def;
node_def.set_name(node.name);
node_def.set_op(node.op);
for (const string& input : node.inputs) {
node_def.add_input(input);
}
*graph.add_node() = std::move(node_def);
}
return graph;
}
};
TEST_F(TransitiveFaninTest, NoPruning) {
GraphDef graph = CreateGraph({
{"1", {"2"}}, //
{"2", {"3"}}, //
{"3", {"4"}}, //
{"4", {}} //
});
GraphDef output_graph;
const std::vector<string> terminal_nodes = {"1"};
TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes));
NodeMap node_map(&output_graph);
ASSERT_TRUE(node_map.NodeExists("1"));
ASSERT_TRUE(node_map.NodeExists("2"));
ASSERT_TRUE(node_map.NodeExists("3"));
ASSERT_TRUE(node_map.NodeExists("4"));
}
TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromSingleTerminalNode) {
GraphDef graph = CreateGraph({
{"1", {"2"}}, //
{"2", {"3"}}, //
{"3", {"4"}}, //
{"4", {}}, //
{"5", {"1"}} //
});
GraphDef output_graph;
const std::vector<string> terminal_nodes = {"1"};
TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes));
NodeMap node_map(&output_graph);
ASSERT_TRUE(node_map.NodeExists("1"));
ASSERT_TRUE(node_map.NodeExists("2"));
ASSERT_TRUE(node_map.NodeExists("3"));
ASSERT_TRUE(node_map.NodeExists("4"));
ASSERT_FALSE(node_map.NodeExists("5"));
}
TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromMultipleTerminalNodes) {
GraphDef graph = CreateGraph({
{"1", {"2"}}, //
{"2", {"3"}}, //
{"3", {"4"}}, //
{"4", {}}, //
{"5", {"2"}}, //
{"6", {"1"}} //
});
GraphDef output_graph;
const std::vector<string> terminal_nodes = {"1", "5"};
TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes));
NodeMap node_map(&output_graph);
ASSERT_TRUE(node_map.NodeExists("1"));
ASSERT_TRUE(node_map.NodeExists("2"));
ASSERT_TRUE(node_map.NodeExists("3"));
ASSERT_TRUE(node_map.NodeExists("4"));
ASSERT_TRUE(node_map.NodeExists("5"));
ASSERT_FALSE(node_map.NodeExists("6"));
}
TEST_F(TransitiveFaninTest, InvalidGraph) {
GraphDef graph = CreateGraph({
{"1", {"2"}}, //
{"2", {"3"}}, //
{"3", {"4"}}, //
{"4", {}}, //
{"5", {"6"}}, //
{"7", {"8"}} //
});
GraphDef output_graph;
const std::vector<string> terminal_nodes = {"1", "5"};
auto s = SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(), "Invalid input graph.");
}
} // namespace
} // namespace grappler
} // namespace tensorflow