diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index fd3894553b9..91c159d7937 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -18,9 +18,12 @@ cc_library( hdrs = ["scc.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:utils", ], ) @@ -28,6 +31,9 @@ cc_test( name = "scc_test", size = "small", srcs = ["scc_test.cc"], + data = [ + "//tensorflow/core/grappler/costs:graph_properties_testdata", + ], deps = [ ":scc", "//tensorflow/core:lib_proto_parsing", diff --git a/tensorflow/core/grappler/utils/scc.cc b/tensorflow/core/grappler/utils/scc.cc index 6568e99aa3f..f2a6507d94a 100644 --- a/tensorflow/core/grappler/utils/scc.cc +++ b/tensorflow/core/grappler/utils/scc.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" namespace tensorflow { @@ -163,13 +164,83 @@ void StronglyConnectedComponents( DCHECK(component.second < *num_components); counts_per_component[component.second]++; } + bool has_single_element_component = false; for (auto& component : *components) { if (counts_per_component[component.second] == 1) { component.second = -1; (*num_components)--; + has_single_element_component = true; } } - (*num_components) += 1; + if (has_single_element_component) { + (*num_components) += 1; + } +} + +int IdentifyLoops(const GraphDef& graph, + std::unordered_map>* loops) { + int num_components = 0; + std::unordered_map components; + StronglyConnectedComponents(graph, &components, &num_components); + if (num_components <= 1) { + if (!components.empty() && components.begin()->second == -1) { + return 0; + } + } + + std::unordered_map> component_ids; + for (const auto it : components) { + int id = it.second; + if (id < 0) { + continue; + } + component_ids[id].push_back(it.first); + } + + int loop_id = 0; + for (const auto& component : component_ids) { + const std::vector& component_nodes = component.second; + std::vector> next_iter_nodes; + GraphDef subgraph; + + for (const auto& component_node : component_nodes) { + NodeDef* node = subgraph.add_node(); + *node = *component_node; + if (IsNextIteration(*node)) { + CHECK_EQ(1, node->input_size()); + next_iter_nodes.emplace_back(node, node->input(0)); + } + } + if (next_iter_nodes.size() == 1) { + for (const auto& component_node : component_nodes) { + (*loops)[component_node].push_back(loop_id); + } + ++loop_id; + } else { + for (int i = 0; i < next_iter_nodes.size(); ++i) { + for (int j = 0; j < next_iter_nodes.size(); ++j) { + next_iter_nodes[j].first->clear_input(); + if (i == j) { + *next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second; + } + } + int num_components = 0; + std::unordered_map components; + StronglyConnectedComponents(subgraph, &components, &num_components); + CHECK_EQ(1, num_components); + for (const auto it : components) { + int id = it.second; + if (id < 0) { + continue; + } + (*loops)[it.first].push_back(loop_id); + } + ++loop_id; + } + } + } + + return loop_id; } } // namespace grappler diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h index 8b0577763d6..4e46169971a 100644 --- a/tensorflow/core/grappler/utils/scc.h +++ b/tensorflow/core/grappler/utils/scc.h @@ -18,6 +18,8 @@ limitations under the License. #include #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/grappler/inputs/utils.h" +#include "tensorflow/core/lib/io/path.h" namespace tensorflow { namespace grappler { @@ -32,6 +34,12 @@ void StronglyConnectedComponents( const GraphDef& graph, std::unordered_map* components, int* num_ids); +// Returns the number of individual loops present in the graph, and populate the +// 'loops' argument with the collection of loops (denoted by their loop ids) a +// node is part of. Loops ids are arbitrary. +int IdentifyLoops(const GraphDef& graph, + std::unordered_map>* loops); + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/scc_test.cc b/tensorflow/core/grappler/utils/scc_test.cc index 3185cbe2326..b5fa76ef8bf 100644 --- a/tensorflow/core/grappler/utils/scc_test.cc +++ b/tensorflow/core/grappler/utils/scc_test.cc @@ -406,5 +406,28 @@ versions { } } +TEST_F(SCCTest, NestedLoops) { + GrapplerItem item; + string filename = io::JoinPath( + testing::TensorFlowSrcRoot(), + "core/grappler/costs/graph_properties_testdata/nested_loop.pbtxt"); + TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); + + for (const auto& node : item.graph.node()) { + std::cout << node.DebugString() << std::endl; + } + + std::unordered_map> loops; + int num_loops = IdentifyLoops(item.graph, &loops); + EXPECT_EQ(4, num_loops); + for (const auto& node_info : loops) { + std::cout << node_info.first->name() << " ["; + for (int i : node_info.second) { + std::cout << " " << i; + } + std::cout << "]" << std::endl; + } +} + } // namespace grappler } // namespace tensorflow