Individual loop identification

PiperOrigin-RevId: 166076525
This commit is contained in:
Benoit Steiner 2017-08-22 10:14:17 -07:00 committed by TensorFlower Gardener
parent 2109a2b3d9
commit 0bb9ddd426
4 changed files with 109 additions and 1 deletions

View File

@ -18,9 +18,12 @@ cc_library(
hdrs = ["scc.h"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:utils",
],
)
@ -28,6 +31,9 @@ cc_test(
name = "scc_test",
size = "small",
srcs = ["scc_test.cc"],
data = [
"//tensorflow/core/grappler/costs:graph_properties_testdata",
],
deps = [
":scc",
"//tensorflow/core:lib_proto_parsing",

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
@ -163,13 +164,83 @@ void StronglyConnectedComponents(
DCHECK(component.second < *num_components);
counts_per_component[component.second]++;
}
bool has_single_element_component = false;
for (auto& component : *components) {
if (counts_per_component[component.second] == 1) {
component.second = -1;
(*num_components)--;
has_single_element_component = true;
}
}
(*num_components) += 1;
if (has_single_element_component) {
(*num_components) += 1;
}
}
int IdentifyLoops(const GraphDef& graph,
std::unordered_map<const NodeDef*, std::vector<int>>* loops) {
int num_components = 0;
std::unordered_map<const NodeDef*, int> components;
StronglyConnectedComponents(graph, &components, &num_components);
if (num_components <= 1) {
if (!components.empty() && components.begin()->second == -1) {
return 0;
}
}
std::unordered_map<int, std::vector<const NodeDef*>> component_ids;
for (const auto it : components) {
int id = it.second;
if (id < 0) {
continue;
}
component_ids[id].push_back(it.first);
}
int loop_id = 0;
for (const auto& component : component_ids) {
const std::vector<const NodeDef*>& component_nodes = component.second;
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
GraphDef subgraph;
for (const auto& component_node : component_nodes) {
NodeDef* node = subgraph.add_node();
*node = *component_node;
if (IsNextIteration(*node)) {
CHECK_EQ(1, node->input_size());
next_iter_nodes.emplace_back(node, node->input(0));
}
}
if (next_iter_nodes.size() == 1) {
for (const auto& component_node : component_nodes) {
(*loops)[component_node].push_back(loop_id);
}
++loop_id;
} else {
for (int i = 0; i < next_iter_nodes.size(); ++i) {
for (int j = 0; j < next_iter_nodes.size(); ++j) {
next_iter_nodes[j].first->clear_input();
if (i == j) {
*next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second;
}
}
int num_components = 0;
std::unordered_map<const NodeDef*, int> components;
StronglyConnectedComponents(subgraph, &components, &num_components);
CHECK_EQ(1, num_components);
for (const auto it : components) {
int id = it.second;
if (id < 0) {
continue;
}
(*loops)[it.first].push_back(loop_id);
}
++loop_id;
}
}
}
return loop_id;
}
} // namespace grappler

View File

@ -18,6 +18,8 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/inputs/utils.h"
#include "tensorflow/core/lib/io/path.h"
namespace tensorflow {
namespace grappler {
@ -32,6 +34,12 @@ void StronglyConnectedComponents(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
int* num_ids);
// Returns the number of individual loops present in the graph, and populate the
// 'loops' argument with the collection of loops (denoted by their loop ids) a
// node is part of. Loops ids are arbitrary.
int IdentifyLoops(const GraphDef& graph,
std::unordered_map<const NodeDef*, std::vector<int>>* loops);
} // namespace grappler
} // namespace tensorflow

View File

@ -406,5 +406,28 @@ versions {
}
}
TEST_F(SCCTest, NestedLoops) {
GrapplerItem item;
string filename = io::JoinPath(
testing::TensorFlowSrcRoot(),
"core/grappler/costs/graph_properties_testdata/nested_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
for (const auto& node : item.graph.node()) {
std::cout << node.DebugString() << std::endl;
}
std::unordered_map<const NodeDef*, std::vector<int>> loops;
int num_loops = IdentifyLoops(item.graph, &loops);
EXPECT_EQ(4, num_loops);
for (const auto& node_info : loops) {
std::cout << node_info.first->name() << " [";
for (int i : node_info.second) {
std::cout << " " << i;
}
std::cout << "]" << std::endl;
}
}
} // namespace grappler
} // namespace tensorflow