Individual loop identification
PiperOrigin-RevId: 166076525
This commit is contained in:
parent
2109a2b3d9
commit
0bb9ddd426
@ -18,9 +18,12 @@ cc_library(
|
||||
hdrs = ["scc.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:utils",
|
||||
],
|
||||
)
|
||||
|
||||
@ -28,6 +31,9 @@ cc_test(
|
||||
name = "scc_test",
|
||||
size = "small",
|
||||
srcs = ["scc_test.cc"],
|
||||
data = [
|
||||
"//tensorflow/core/grappler/costs:graph_properties_testdata",
|
||||
],
|
||||
deps = [
|
||||
":scc",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -163,13 +164,83 @@ void StronglyConnectedComponents(
|
||||
DCHECK(component.second < *num_components);
|
||||
counts_per_component[component.second]++;
|
||||
}
|
||||
bool has_single_element_component = false;
|
||||
for (auto& component : *components) {
|
||||
if (counts_per_component[component.second] == 1) {
|
||||
component.second = -1;
|
||||
(*num_components)--;
|
||||
has_single_element_component = true;
|
||||
}
|
||||
}
|
||||
(*num_components) += 1;
|
||||
if (has_single_element_component) {
|
||||
(*num_components) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
int IdentifyLoops(const GraphDef& graph,
|
||||
std::unordered_map<const NodeDef*, std::vector<int>>* loops) {
|
||||
int num_components = 0;
|
||||
std::unordered_map<const NodeDef*, int> components;
|
||||
StronglyConnectedComponents(graph, &components, &num_components);
|
||||
if (num_components <= 1) {
|
||||
if (!components.empty() && components.begin()->second == -1) {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::unordered_map<int, std::vector<const NodeDef*>> component_ids;
|
||||
for (const auto it : components) {
|
||||
int id = it.second;
|
||||
if (id < 0) {
|
||||
continue;
|
||||
}
|
||||
component_ids[id].push_back(it.first);
|
||||
}
|
||||
|
||||
int loop_id = 0;
|
||||
for (const auto& component : component_ids) {
|
||||
const std::vector<const NodeDef*>& component_nodes = component.second;
|
||||
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
|
||||
GraphDef subgraph;
|
||||
|
||||
for (const auto& component_node : component_nodes) {
|
||||
NodeDef* node = subgraph.add_node();
|
||||
*node = *component_node;
|
||||
if (IsNextIteration(*node)) {
|
||||
CHECK_EQ(1, node->input_size());
|
||||
next_iter_nodes.emplace_back(node, node->input(0));
|
||||
}
|
||||
}
|
||||
if (next_iter_nodes.size() == 1) {
|
||||
for (const auto& component_node : component_nodes) {
|
||||
(*loops)[component_node].push_back(loop_id);
|
||||
}
|
||||
++loop_id;
|
||||
} else {
|
||||
for (int i = 0; i < next_iter_nodes.size(); ++i) {
|
||||
for (int j = 0; j < next_iter_nodes.size(); ++j) {
|
||||
next_iter_nodes[j].first->clear_input();
|
||||
if (i == j) {
|
||||
*next_iter_nodes[j].first->add_input() = next_iter_nodes[j].second;
|
||||
}
|
||||
}
|
||||
int num_components = 0;
|
||||
std::unordered_map<const NodeDef*, int> components;
|
||||
StronglyConnectedComponents(subgraph, &components, &num_components);
|
||||
CHECK_EQ(1, num_components);
|
||||
for (const auto it : components) {
|
||||
int id = it.second;
|
||||
if (id < 0) {
|
||||
continue;
|
||||
}
|
||||
(*loops)[it.first].push_back(loop_id);
|
||||
}
|
||||
++loop_id;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return loop_id;
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
|
@ -18,6 +18,8 @@ limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/grappler/inputs/utils.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -32,6 +34,12 @@ void StronglyConnectedComponents(
|
||||
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
|
||||
int* num_ids);
|
||||
|
||||
// Returns the number of individual loops present in the graph, and populate the
|
||||
// 'loops' argument with the collection of loops (denoted by their loop ids) a
|
||||
// node is part of. Loops ids are arbitrary.
|
||||
int IdentifyLoops(const GraphDef& graph,
|
||||
std::unordered_map<const NodeDef*, std::vector<int>>* loops);
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -406,5 +406,28 @@ versions {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(SCCTest, NestedLoops) {
|
||||
GrapplerItem item;
|
||||
string filename = io::JoinPath(
|
||||
testing::TensorFlowSrcRoot(),
|
||||
"core/grappler/costs/graph_properties_testdata/nested_loop.pbtxt");
|
||||
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
|
||||
|
||||
for (const auto& node : item.graph.node()) {
|
||||
std::cout << node.DebugString() << std::endl;
|
||||
}
|
||||
|
||||
std::unordered_map<const NodeDef*, std::vector<int>> loops;
|
||||
int num_loops = IdentifyLoops(item.graph, &loops);
|
||||
EXPECT_EQ(4, num_loops);
|
||||
for (const auto& node_info : loops) {
|
||||
std::cout << node_info.first->name() << " [";
|
||||
for (int i : node_info.second) {
|
||||
std::cout << " " << i;
|
||||
}
|
||||
std::cout << "]" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user