Address review comments.

- Make GetFrameBasedTopologicalOrder a member function of DeadnessAnalysisImpl.
- Use string_view when appropriate.
- Change force_pessimistic to enable_optimistic. And make it a non-default
  parameter.
- Rename num_enters/num_exits to num_enters_for_frame/num_exits_for_frame
  in GetFrameBasedTopologicalOrder().
This commit is contained in:
Trent Lo 2019-05-21 17:04:21 -07:00
parent c58c4aa634
commit a298e285fe
3 changed files with 126 additions and 125 deletions

View File

@ -735,7 +735,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
explicit DeadnessAnalysisImpl(const Graph* graph) explicit DeadnessAnalysisImpl(const Graph* graph)
: graph_(*graph), vlog_(VLOG_IS_ON(2)) {} : graph_(*graph), vlog_(VLOG_IS_ON(2)) {}
Status Populate(bool force_pessimistic = false); Status Populate(bool enable_optimistic);
Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode, Status PopulateFrame(absl::Span<Node* const> topo, bool use_optimistic_mode,
bool* success); bool* success);
StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor( StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor(
@ -786,12 +786,14 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status HandleNode(Node* n, std::vector<bool>* should_revisit, Status HandleNode(Node* n, std::vector<bool>* should_revisit,
bool use_optimistic_mode = false); bool use_optimistic_mode = false);
Status GetFrameBasedTopologicalOrder(std::vector<Node*>* order);
const Graph& graph_; const Graph& graph_;
absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_; absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_; PredicateFactory predicate_factory_;
std::vector<ControlFlowInfo> control_flow_info_; std::vector<ControlFlowInfo> control_flow_info_;
bool vlog_; bool vlog_;
absl::flat_hash_map<string, Node*> frame_to_merge_node_; absl::flat_hash_map<absl::string_view, Node*> frame_to_merge_node_;
}; };
TensorId InputEdgeToTensorId(const Edge* e) { TensorId InputEdgeToTensorId(const Edge* e) {
@ -975,109 +977,6 @@ Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos,
*frame = cfi_iter->frame_name; *frame = cfi_iter->frame_name;
return Status::OK(); return Status::OK();
} }
// Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other. The traversal uses a
// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
// many inputs of each node are ready; a node is ready to be scheduled if all
// of its inputs are ready.
// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
Status GetFrameBasedTopologicalOrder(const Graph* g,
absl::Span<const ControlFlowInfo> cf_infos,
std::vector<Node*>* order) {
absl::flat_hash_map<absl::string_view, size_t> num_enters;
absl::flat_hash_map<absl::string_view, size_t> num_exits;
std::vector<size_t> num_ready_inputs(g->num_node_ids(), 0);
Node* src_node = g->source_node();
for (const auto* node : g->op_nodes()) {
const ControlFlowInfo& cf = cf_infos[node->id()];
bool is_root_level = cf.parent_frame == src_node;
if (IsEnter(node) && is_root_level) {
// Since we care only the root-level frame, full frame names are the same
// as frame names.
++num_enters[cf.frame_name];
} else if (IsExit(node) && is_root_level) {
++num_exits[cf.frame_name];
}
// Edge NextIteration->Merge is counted before starting the traveral to
// break the backedges.
if (IsMerge(node)) {
for (const Edge* e : node->in_edges()) {
if (IsNextIteration(e->src())) {
++num_ready_inputs[node->id()];
}
}
}
}
std::deque<Node*> ready;
ready.push_back(src_node);
// ready_enters_per_frame and ready_exits serve as a staging area to buffer
// the ready enters/exits before they are moved to the `ready` queue for
// controlling the start and end of a processing frame.
absl::flat_hash_map<string, std::vector<Node*>> ready_enters_per_frame;
// Exit nodes shall all be from the same frame, as we process a frame at a
// time. So, one vector is enough.
std::vector<Node*> ready_exits;
while (!ready.empty()) {
Node* curr_node = ready.front();
ready.pop_front();
VLOG(4) << "Visiting " << curr_node->name();
order->push_back(curr_node);
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();
if (IsNextIteration(curr_node) && IsMerge(out)) {
// Edge NextIteration->Merge has been counted.
continue;
}
++num_ready_inputs[out->id()];
if (!out->IsOp()) continue; // Skip Sink/Source nodes.
if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
bool is_root_level = cf_infos[out_id].parent_frame == src_node;
string frame_name = cf_infos[out_id].frame_name;
if (IsEnter(out) && is_root_level) {
ready_enters_per_frame[frame_name].push_back(out);
} else if (IsExit(out) && is_root_level) {
ready_exits.push_back(out);
} else {
ready.push_back(out);
}
}
if (ready.empty()) {
// Try moving nodes from ready_enters_per_frame and read_exits to `ready`.
if (!ready_exits.empty()) {
// If there are nodes in ready_exits we must process them before
// processing ready_enters_per_frame to make sure all nodes in the
// currently processing frame are visited before starting processing
// other frames.
string frame_name = cf_infos[ready_exits.front()->id()].frame_name;
CHECK_EQ(ready_exits.size(), num_exits[frame_name]);
ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
ready_exits.clear();
} else {
// Otherwise, try moving nodes from ready_enters to `ready`.
for (auto iter = ready_enters_per_frame.begin();
iter != ready_enters_per_frame.end(); ++iter) {
string frame_name = iter->first;
const std::vector<Node*>& ready_enters = iter->second;
if (ready_enters.size() == num_enters[frame_name]) {
ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
ready_enters_per_frame.erase(iter);
break;
}
}
}
}
}
CHECK(ready_enters_per_frame.empty() && ready_exits.empty());
return Status::OK();
}
} // namespace } // namespace
Status DeadnessAnalysisImpl::HandleMerge(Node* n, Status DeadnessAnalysisImpl::HandleMerge(Node* n,
@ -1107,7 +1006,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n,
// frame as the representative Merge node. It is just convenient and // frame as the representative Merge node. It is just convenient and
// does not affect the result after pattern-matching into the // does not affect the result after pattern-matching into the
// AndRecurrence form. // AndRecurrence form.
string frame_name = control_flow_info_[n->id()].frame_name; absl::string_view frame_name = control_flow_info_[n->id()].frame_name;
auto insert_result = frame_to_merge_node_.insert({frame_name, n}); auto insert_result = frame_to_merge_node_.insert({frame_name, n});
Node* representative = insert_result.first->second; Node* representative = insert_result.first->second;
TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate(
@ -1219,6 +1118,109 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
return Status::OK(); return Status::OK();
} }
// Compute a special topological order for the Graph, where nodes having the
// same root frame are placed adjacent to each other. The traversal uses a
// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how
// many inputs of each node are ready; a node is ready to be scheduled if all
// of its inputs are ready.
// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details.
Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder(
std::vector<Node*>* order) {
absl::flat_hash_map<absl::string_view, size_t> num_enters_for_frame;
absl::flat_hash_map<absl::string_view, size_t> num_exits_for_frame;
std::vector<size_t> num_ready_inputs(graph_.num_node_ids(), 0);
Node* src_node = graph_.source_node();
for (const auto* node : graph_.op_nodes()) {
const ControlFlowInfo& cf = control_flow_info_[node->id()];
bool is_root_level = cf.parent_frame == src_node;
if (IsEnter(node) && is_root_level) {
// Since we care only the root-level frame, full frame names are the same
// as frame names.
++num_enters_for_frame[cf.frame_name];
} else if (IsExit(node) && is_root_level) {
++num_exits_for_frame[cf.frame_name];
}
// Edge NextIteration->Merge is counted before starting the traveral to
// break the backedges.
if (IsMerge(node)) {
for (const Edge* e : node->in_edges()) {
if (IsNextIteration(e->src())) {
++num_ready_inputs[node->id()];
}
}
}
}
std::deque<Node*> ready;
ready.push_back(src_node);
// ready_enters_per_frame and ready_exits serve as a staging area to buffer
// the ready enters/exits before they are moved to the `ready` queue for
// controlling the start and end of a processing frame.
absl::flat_hash_map<string, std::vector<Node*>> ready_enters_per_frame;
// Exit nodes shall all be from the same frame, as we process a frame at a
// time. So, one vector is enough.
std::vector<Node*> ready_exits;
while (!ready.empty()) {
Node* curr_node = ready.front();
ready.pop_front();
VLOG(4) << "Visiting " << curr_node->name();
order->push_back(curr_node);
for (const Edge* out_edge : curr_node->out_edges()) {
Node* out = out_edge->dst();
int out_id = out->id();
if (IsNextIteration(curr_node) && IsMerge(out)) {
// Edge NextIteration->Merge has been counted.
continue;
}
++num_ready_inputs[out->id()];
if (!out->IsOp()) continue; // Skip Sink/Source nodes.
if (num_ready_inputs[out->id()] != out->in_edges().size()) continue;
bool is_root_level = control_flow_info_[out_id].parent_frame == src_node;
absl::string_view frame_name = control_flow_info_[out_id].frame_name;
if (IsEnter(out) && is_root_level) {
ready_enters_per_frame[frame_name].push_back(out);
} else if (IsExit(out) && is_root_level) {
ready_exits.push_back(out);
} else {
ready.push_back(out);
}
}
if (ready.empty()) {
// Try moving nodes from ready_enters_per_frame and read_exits to `ready`.
if (!ready_exits.empty()) {
// If there are nodes in ready_exits we must process them before
// processing ready_enters_per_frame to make sure all nodes in the
// currently processing frame are visited before starting processing
// other frames.
absl::string_view frame_name =
control_flow_info_[ready_exits.front()->id()].frame_name;
CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]);
ready.insert(ready.end(), ready_exits.begin(), ready_exits.end());
ready_exits.clear();
} else {
// Otherwise, try moving nodes from ready_enters to `ready`.
for (auto iter = ready_enters_per_frame.begin();
iter != ready_enters_per_frame.end(); ++iter) {
absl::string_view frame_name = iter->first;
const std::vector<Node*>& ready_enters = iter->second;
if (ready_enters.size() == num_enters_for_frame[frame_name]) {
ready.insert(ready.end(), ready_enters.begin(), ready_enters.end());
ready_enters_per_frame.erase(iter);
break;
}
}
}
}
}
CHECK(ready_enters_per_frame.empty() && ready_exits.empty());
return Status::OK();
}
// We populate the nodes along a special topological order where nodes having // We populate the nodes along a special topological order where nodes having
// the same root frame are placed adjacent to each other. This grouping enables // the same root frame are placed adjacent to each other. This grouping enables
// processing the graph per root frame at a time and guarantees that when a root // processing the graph per root frame at a time and guarantees that when a root
@ -1231,7 +1233,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n,
// (root) frame. Note that we don't separate while loops belonging to the same // (root) frame. Note that we don't separate while loops belonging to the same
// nested while, as there is no clean cut for separating them in the topological // nested while, as there is no clean cut for separating them in the topological
// order. // order.
Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) { Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) {
std::vector<string> unreachable_nodes; std::vector<string> unreachable_nodes;
// Compute the loop structure of the graph. // Compute the loop structure of the graph.
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -1251,8 +1253,7 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) {
} }
std::vector<Node*> topo; std::vector<Node*> topo;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo));
GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &topo));
size_t frame_start = 0; size_t frame_start = 0;
while (frame_start < topo.size()) { while (frame_start < topo.size()) {
@ -1277,7 +1278,7 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) {
// First, try the optimistic mode. // First, try the optimistic mode.
bool success = false; bool success = false;
if (!force_pessimistic && !cur_frame_name.empty()) { if (enable_optimistic && !cur_frame_name.empty()) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success)); PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success));
} }
@ -1349,7 +1350,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
// predicates. // predicates.
if (use_optimistic_mode) { if (use_optimistic_mode) {
bool is_converged = true; bool is_converged = true;
absl::flat_hash_map<string, Predicate*> frame_to_pred; absl::flat_hash_map<absl::string_view, Predicate*> frame_to_pred;
for (Node* n : topo) { for (Node* n : topo) {
if (!n->IsMerge()) { if (!n->IsMerge()) {
continue; continue;
@ -1361,7 +1362,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span<Node* const> topo,
continue; continue;
} }
Node* merge = n; Node* merge = n;
string frame_name = control_flow_info_[merge->id()].frame_name; absl::string_view frame_name = control_flow_info_[merge->id()].frame_name;
auto it = predicate_map_.find(TensorId(merge->name(), 0)); auto it = predicate_map_.find(TensorId(merge->name(), 0));
Predicate* merge_pred = it->second; Predicate* merge_pred = it->second;
if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) { if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) {
@ -1441,7 +1442,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) { const Graph& graph, std::unique_ptr<DeadnessAnalysis>* result) {
std::unique_ptr<DeadnessAnalysisImpl> analysis( std::unique_ptr<DeadnessAnalysisImpl> analysis(
new DeadnessAnalysisImpl(&graph)); new DeadnessAnalysisImpl(&graph));
TF_RETURN_IF_ERROR(analysis->Populate()); TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true));
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
analysis->Print(); analysis->Print();
@ -1463,9 +1464,9 @@ DeadnessAnalysisImpl::PredicateMapAsString() const {
namespace deadness_analysis_internal { namespace deadness_analysis_internal {
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
bool force_pessimistic) { bool enable_optimistic) {
DeadnessAnalysisImpl impl(&graph); DeadnessAnalysisImpl impl(&graph);
TF_RETURN_IF_ERROR(impl.Populate(force_pessimistic)); TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic));
*out_predicate_map = impl.PredicateMapAsString(); *out_predicate_map = impl.PredicateMapAsString();
return Status::OK(); return Status::OK();
} }

View File

@ -26,7 +26,7 @@ namespace deadness_analysis_internal {
// testing purposes only. // testing purposes only.
using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>; using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map,
bool force_pessimistic = false); bool enable_optimistic = true);
} // namespace deadness_analysis_internal } // namespace deadness_analysis_internal
} // namespace tensorflow } // namespace tensorflow

View File

@ -639,7 +639,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/false)); /*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -653,7 +653,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/true)); /*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -683,7 +683,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/false)); /*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"{#true,&,*iv0/cond:0}<frame>"); "{#true,&,*iv0/cond:0}<frame>");
@ -691,7 +691,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/true)); /*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)],
"div0/iv:0"); "div0/iv:0");
@ -746,7 +746,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/false)); /*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>"); "{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -755,7 +755,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
"{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/" "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
"cond:0}<inner_loop;outer_loop>"); "cond:0}<inner_loop;outer_loop>");
// force_pessimistic = true or not should produce the same results because // enable_optimistic = true or not should produce the same results because
// of fallback. However, note that the order of iv_inner/cond:0 and // of fallback. However, note that the order of iv_inner/cond:0 and
// iv_inner/iv:0 is different because the optimistic approach does not // iv_inner/iv:0 is different because the optimistic approach does not
// create predicates for all merges and it can change the predicate id and // create predicates for all merges and it can change the predicate id and
@ -772,7 +772,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/true)); /*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>"); "{#true,&,*iv_outer/cond:0}<outer_loop>");
@ -925,7 +925,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/false)); /*enable_optimistic=*/true));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");
@ -942,7 +942,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) {
{ {
PredicateMapTy predicate_map; PredicateMapTy predicate_map;
TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map,
/*force_pessimistic=*/true)); /*enable_optimistic=*/false));
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>"); "{#true,&,*iv0/cond:0}<loop>");