Generalize assumptions in IdentifyLoops and StronglyConnectedComponents.
PiperOrigin-RevId: 203066657
This commit is contained in:
parent
f2a44bc813
commit
e9a7b5e0e6
@ -142,9 +142,13 @@ void StronglyConnectedComponents(
|
|||||||
|
|
||||||
// Create a list of top-level parents (add them to object queue)
|
// Create a list of top-level parents (add them to object queue)
|
||||||
// Also create a mapping from nodes to their children.
|
// Also create a mapping from nodes to their children.
|
||||||
|
// Inputs might not be present if called on a subgraph.
|
||||||
for (const NodeDef& node : graph.node()) {
|
for (const NodeDef& node : graph.node()) {
|
||||||
for (const string& input : node.input()) {
|
for (const string& input : node.input()) {
|
||||||
name_to_data[NodeName(input)]->children.push_back(node_to_data[&node]);
|
auto it = name_to_data.find(NodeName(input));
|
||||||
|
if (it != name_to_data.end()) {
|
||||||
|
it->second->children.push_back(node_to_data[&node]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,10 +206,12 @@ int IdentifyLoops(const GraphDef& graph,
|
|||||||
const std::vector<const NodeDef*>& component_nodes = component.second;
|
const std::vector<const NodeDef*>& component_nodes = component.second;
|
||||||
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
|
std::vector<std::pair<NodeDef*, string>> next_iter_nodes;
|
||||||
GraphDef subgraph;
|
GraphDef subgraph;
|
||||||
|
std::unordered_map<const NodeDef*, const NodeDef*> subgraph_mapping;
|
||||||
|
|
||||||
for (const auto& component_node : component_nodes) {
|
for (const auto& component_node : component_nodes) {
|
||||||
NodeDef* node = subgraph.add_node();
|
NodeDef* node = subgraph.add_node();
|
||||||
*node = *component_node;
|
*node = *component_node;
|
||||||
|
subgraph_mapping[node] = component_node;
|
||||||
if (IsNextIteration(*node)) {
|
if (IsNextIteration(*node)) {
|
||||||
CHECK_EQ(1, node->input_size());
|
CHECK_EQ(1, node->input_size());
|
||||||
next_iter_nodes.emplace_back(node, node->input(0));
|
next_iter_nodes.emplace_back(node, node->input(0));
|
||||||
@ -227,13 +233,13 @@ int IdentifyLoops(const GraphDef& graph,
|
|||||||
int num_components = 0;
|
int num_components = 0;
|
||||||
std::unordered_map<const NodeDef*, int> components;
|
std::unordered_map<const NodeDef*, int> components;
|
||||||
StronglyConnectedComponents(subgraph, &components, &num_components);
|
StronglyConnectedComponents(subgraph, &components, &num_components);
|
||||||
CHECK_EQ(1, num_components);
|
CHECK_GE(num_components, 1);
|
||||||
for (const auto it : components) {
|
for (const auto it : components) {
|
||||||
int id = it.second;
|
int id = it.second;
|
||||||
if (id < 0) {
|
if (id < 0) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
(*loops)[it.first].push_back(loop_id);
|
(*loops)[subgraph_mapping[it.first]].push_back(loop_id);
|
||||||
}
|
}
|
||||||
++loop_id;
|
++loop_id;
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user