diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index ac7a45e4bf0..338cd9bb278 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -142,6 +142,7 @@ cc_library( ":utils", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/utils:transitive_fanin", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 80d01341d6f..effaa2a23f1 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/transitive_fanin.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { @@ -208,63 +209,5 @@ std::vector ComputeTransitiveFanin( return result; } -std::vector ComputeTransitiveFanin( - const GraphDef& graph, const std::vector& terminal_nodes, - bool* ill_formed) { - *ill_formed = false; - std::unordered_map name_to_node; - std::unordered_map name_to_send; - for (const auto& node : graph.node()) { - name_to_node[node.name()] = &node; - if (node.op() == "_Send") { - const auto& attr = node.attr(); - name_to_send[attr.at("tensor_name").s()] = &node; - } - } - - std::vector queue; - for (const string& root : terminal_nodes) { - const NodeDef* node = name_to_node[NodeName(root)]; - if (!node) { - *ill_formed = true; - VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root; - return {}; - } - queue.push_back(node); - } - - std::vector result; - std::unordered_set visited; - - while (!queue.empty()) { - const NodeDef* node = queue.back(); - queue.pop_back(); - if (!visited.insert(node).second) { - // The node has already been visited. - continue; - } - result.push_back(node); - for (const string& input : node->input()) { - const NodeDef* in = name_to_node[NodeName(input)]; - if (!in) { - VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input; - *ill_formed = true; - return {}; - } - queue.push_back(in); - } - if (node->op() == "_Recv") { - const auto& attr = node->attr(); - const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; - if (send) { - queue.push_back(send); - } - // Subgraph after partitioning may have either _Send or _Recv, not both. - // So, we do not set ill_formed for missing _Send. - } - } - return result; -} - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 1422ccfb3e7..1f85d679478 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -379,6 +379,7 @@ cc_library( "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:transitive_fanin", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", ], diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index cf9ce6fa32c..7cb96a21c5b 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/transitive_fanin.h" namespace tensorflow { namespace grappler { @@ -411,29 +412,6 @@ Status SplitIdentityNInputs(GraphDef* graph, return Status::OK(); } -Status SetTransitiveFaninGraph(const GraphDef& input_graph, - GraphDef* output_graph, - const std::vector& terminal_nodes) { - // Determines transitive fanin nodes from terminal nodes and add them to the - // output graph. - bool ill_formed = false; - std::vector keep = - ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed); - if (ill_formed) { - // Some graph edges are invalid, or some of the feeds/fetch don't exist: - // let's be conservative and preserve the graph as is. - return errors::InvalidArgument("Invalid input graph."); - } - // Try to keep the nodes ordered somewhat topologically since this helps - // further optimizations perform better. - output_graph->mutable_node()->Reserve(keep.size()); - for (int i = keep.size() - 1; i >= 0; --i) { - *output_graph->add_node() = *keep[i]; - } - - return Status::OK(); -} - Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { const std::unordered_set nodes_to_preserve = item.NodesToPreserve(); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index fef002b2788..fe07e769ef2 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -379,3 +379,31 @@ tf_cc_test( "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "transitive_fanin", + srcs = ["transitive_fanin.cc"], + hdrs = ["transitive_fanin.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:utils", + ], +) + +tf_cc_test( + name = "transitive_fanin_test", + srcs = ["transitive_fanin_test.cc"], + deps = [ + ":transitive_fanin", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:utils", + ], +) diff --git a/tensorflow/core/grappler/utils/transitive_fanin.cc b/tensorflow/core/grappler/utils/transitive_fanin.cc new file mode 100644 index 00000000000..dffba729be4 --- /dev/null +++ b/tensorflow/core/grappler/utils/transitive_fanin.cc @@ -0,0 +1,109 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/transitive_fanin.h" + +#include +#include + +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { + +std::vector ComputeTransitiveFanin( + const GraphDef& graph, const std::vector& terminal_nodes, + bool* ill_formed) { + *ill_formed = false; + std::unordered_map name_to_node; + std::unordered_map name_to_send; + for (const auto& node : graph.node()) { + name_to_node[node.name()] = &node; + if (node.op() == "_Send") { + const auto& attr = node.attr(); + name_to_send[attr.at("tensor_name").s()] = &node; + } + } + + std::vector queue; + for (const string& root : terminal_nodes) { + const NodeDef* node = name_to_node[NodeName(root)]; + if (!node) { + *ill_formed = true; + VLOG(2) << "ComputeTransitiveFanin: problem with root node: " << root; + return {}; + } + queue.push_back(node); + } + + std::vector result; + std::unordered_set visited; + + while (!queue.empty()) { + const NodeDef* node = queue.back(); + queue.pop_back(); + if (!visited.insert(node).second) { + // The node has already been visited. + continue; + } + result.push_back(node); + for (const string& input : node->input()) { + const NodeDef* in = name_to_node[NodeName(input)]; + if (!in) { + VLOG(2) << "ComputeTransitiveFanin: problem with node: " << input; + *ill_formed = true; + return {}; + } + queue.push_back(in); + } + if (node->op() == "_Recv") { + const auto& attr = node->attr(); + const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; + if (send) { + queue.push_back(send); + } + // Subgraph after partitioning may have either _Send or _Recv, not both. + // So, we do not set ill_formed for missing _Send. + } + } + return result; +} + +Status SetTransitiveFaninGraph(const GraphDef& input_graph, + GraphDef* output_graph, + const std::vector& terminal_nodes) { + // Determines transitive fanin nodes from terminal nodes and add them to the + // output graph. + bool ill_formed = false; + std::vector keep = + ComputeTransitiveFanin(input_graph, terminal_nodes, &ill_formed); + if (ill_formed) { + // Some graph edges are invalid, or some of the feeds/fetch don't exist: + // let's be conservative and preserve the graph as is. + return errors::InvalidArgument("Invalid input graph."); + } + // Try to keep the nodes ordered somewhat topologically since this helps + // further optimizations perform better. + output_graph->mutable_node()->Reserve(keep.size()); + for (int i = keep.size() - 1; i >= 0; --i) { + *output_graph->add_node() = *keep[i]; + } + + return Status::OK(); +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/transitive_fanin.h b/tensorflow/core/grappler/utils/transitive_fanin.h new file mode 100644 index 00000000000..1f89af4b69b --- /dev/null +++ b/tensorflow/core/grappler/utils/transitive_fanin.h @@ -0,0 +1,44 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_ + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { + +// Computes the transitive fanin of the graph based on reachability from the +// specified terminal nodes. ill_formed will be set to true if the graph is +// deemed structurally invalid. Returns the set of nodes comprising the +// transitive fanin. +std::vector ComputeTransitiveFanin( + const GraphDef& graph, const std::vector& terminal_nodes, + bool* ill_formed); + +// Creates output_graph from input_graph using the transitive fanin from the +// specified terminal nodes. Returns error if the input_graph is deemed +// structurally invalid. +Status SetTransitiveFaninGraph(const GraphDef& input_graph, + GraphDef* output_graph, + const std::vector& terminal_nodes); + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_TRANSITIVE_FANIN_H_ diff --git a/tensorflow/core/grappler/utils/transitive_fanin_test.cc b/tensorflow/core/grappler/utils/transitive_fanin_test.cc new file mode 100644 index 00000000000..94d98b93078 --- /dev/null +++ b/tensorflow/core/grappler/utils/transitive_fanin_test.cc @@ -0,0 +1,139 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/utils/transitive_fanin.h" + +#include + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class TransitiveFaninTest : public ::testing::Test { + protected: + struct NodeConfig { + NodeConfig(string name, std::vector inputs) + : name(std::move(name)), inputs(std::move(inputs)) {} + NodeConfig(string name, string op, std::vector inputs) + : name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {} + + string name; + string op; + std::vector inputs; + }; + + static GraphDef CreateGraph(const std::vector& nodes) { + GraphDef graph; + + for (const NodeConfig& node : nodes) { + NodeDef node_def; + node_def.set_name(node.name); + node_def.set_op(node.op); + for (const string& input : node.inputs) { + node_def.add_input(input); + } + *graph.add_node() = std::move(node_def); + } + + return graph; + } +}; + +TEST_F(TransitiveFaninTest, NoPruning) { + GraphDef graph = CreateGraph({ + {"1", {"2"}}, // + {"2", {"3"}}, // + {"3", {"4"}}, // + {"4", {}} // + }); + + GraphDef output_graph; + const std::vector terminal_nodes = {"1"}; + TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes)); + NodeMap node_map(&output_graph); + ASSERT_TRUE(node_map.NodeExists("1")); + ASSERT_TRUE(node_map.NodeExists("2")); + ASSERT_TRUE(node_map.NodeExists("3")); + ASSERT_TRUE(node_map.NodeExists("4")); +} + +TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromSingleTerminalNode) { + GraphDef graph = CreateGraph({ + {"1", {"2"}}, // + {"2", {"3"}}, // + {"3", {"4"}}, // + {"4", {}}, // + {"5", {"1"}} // + }); + + GraphDef output_graph; + const std::vector terminal_nodes = {"1"}; + TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes)); + NodeMap node_map(&output_graph); + ASSERT_TRUE(node_map.NodeExists("1")); + ASSERT_TRUE(node_map.NodeExists("2")); + ASSERT_TRUE(node_map.NodeExists("3")); + ASSERT_TRUE(node_map.NodeExists("4")); + ASSERT_FALSE(node_map.NodeExists("5")); +} + +TEST_F(TransitiveFaninTest, PruneNodesUnreachableFromMultipleTerminalNodes) { + GraphDef graph = CreateGraph({ + {"1", {"2"}}, // + {"2", {"3"}}, // + {"3", {"4"}}, // + {"4", {}}, // + {"5", {"2"}}, // + {"6", {"1"}} // + }); + + GraphDef output_graph; + const std::vector terminal_nodes = {"1", "5"}; + TF_EXPECT_OK(SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes)); + NodeMap node_map(&output_graph); + ASSERT_TRUE(node_map.NodeExists("1")); + ASSERT_TRUE(node_map.NodeExists("2")); + ASSERT_TRUE(node_map.NodeExists("3")); + ASSERT_TRUE(node_map.NodeExists("4")); + ASSERT_TRUE(node_map.NodeExists("5")); + ASSERT_FALSE(node_map.NodeExists("6")); +} + +TEST_F(TransitiveFaninTest, InvalidGraph) { + GraphDef graph = CreateGraph({ + {"1", {"2"}}, // + {"2", {"3"}}, // + {"3", {"4"}}, // + {"4", {}}, // + {"5", {"6"}}, // + {"7", {"8"}} // + }); + + GraphDef output_graph; + const std::vector terminal_nodes = {"1", "5"}; + auto s = SetTransitiveFaninGraph(graph, &output_graph, terminal_nodes); + EXPECT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), "Invalid input graph."); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow