Generalize assumptions in IdentifyLoops and StronglyConnectedComponents.

PiperOrigin-RevId: 203066657
This commit is contained in:
A. Unique TensorFlower 2018-07-02 22:28:45 -07:00 committed by TensorFlower Gardener
parent f2a44bc813
commit e9a7b5e0e6

View File

@ -142,9 +142,13 @@ void StronglyConnectedComponents(
// Create a list of top-level parents (add them to object queue)
// Also create a mapping from nodes to their children.
// Inputs might not be present if called on a subgraph.
for (const NodeDef& node : graph.node()) {
for (const string& input : node.input()) {
name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]);
auto it = name_to_data.find(NodeName(input));
if (it != name_to_data.end()) {
it->second->children.push_back(node_to_data[&node]);
}
}
}
@ -202,10 +206,12 @@ int IdentifyLoops(const GraphDef& graph,
const std::vector<const NodeDef*>& component_nodes = component.second;
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
GraphDef subgraph;
std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping;
for (const auto& component_node : component_nodes) {
NodeDef* node = subgraph.add_node();
*node = *component_node;
subgraph_mapping[node] = component_node;
if (IsNextIteration(*node)) {
CHECK_EQ(1, node->input_size());
next_iter_nodes.emplace_back(node, node->input(0));
@ -227,13 +233,13 @@ int IdentifyLoops(const GraphDef& graph,
int num_components = 0;
std::unordered_map<const NodeDef*, int> components;
StronglyConnectedComponents(subgraph, &components, &num_components);
CHECK_EQ(1, num_components);
CHECK_GE(num_components, 1);
for (const auto it : components) {
int id = it.second;
if (id < 0) {
continue;
}
(*loops)[it.first].push_back(loop_id);
(*loops)[subgraph_mapping[it.first]].push_back(loop_id);
}
++loop_id;
}