From 3f7afc54bdc08fe923d3cbd8ab4afc39c7d54f4f Mon Sep 17 00:00:00 2001 From: Trent Lo Date: Fri, 17 May 2019 23:34:00 -0700 Subject: [PATCH] Address review comments. --- tensorflow/compiler/jit/deadness_analysis.cc | 95 ++++++++++--------- .../compiler/jit/deadness_analysis_test.cc | 16 ++-- 2 files changed, 59 insertions(+), 52 deletions(-) diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index b98e479ca07..97dc2f979fa 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" #include "tensorflow/compiler/jit/deadness_analysis_internal.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -648,9 +649,8 @@ Predicate* PredicateFactory::MakeAndOrImpl( std::vector to_add; for (Predicate* op : simplified_ops) { if (op->kind() == Predicate::Kind::kAndRecurrence) { - AndRecurrencePredicate* and_rec = - static_cast(op); - if (negated_ops.count(and_rec->step())) { + auto* and_rec = static_cast(op); + if (negated_ops.contains(and_rec->step())) { // Remove and_rec and ~X and insert S. Note that checking the // existence of ~X through negated_ops is sufficient since it makes // sure the predicate is in the input operands. 
It does not need to @@ -663,7 +663,7 @@ Predicate* PredicateFactory::MakeAndOrImpl( } auto it = simplified_ops.begin(); while (it != simplified_ops.end()) { - if (to_remove.count(*it)) { + if (to_remove.contains(*it)) { it = simplified_ops.erase(it); } else { ++it; @@ -736,8 +736,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} Status Populate(bool force_pessimistic = false); - Status PopulateFrame(absl::Span tpo, bool use_optimistic_mode, - bool* is_success); + Status PopulateFrame(absl::Span topo, bool use_optimistic_mode, + bool* success); StatusOr GetPredicateFor( Node* n, int oidx) const override; void Print() const override; @@ -955,11 +955,13 @@ Status GetFullFrame(const Node* n, absl::Span cfi_infos, return Status::OK(); } +// If the node is inside some frames, get the name of the outermost non-empty +// frame. Otherwise, get an empty frame name. Status GetRootFrame(const Node* n, absl::Span cfi_infos, - const Node* src_node, string* frame) { + absl::string_view frame) { int depth = 0; const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; - while (cfi_iter->parent_frame != src_node) { + while (!cfi_iter->parent_frame->IsSource()) { n = cfi_iter->parent_frame; cfi_iter = &cfi_infos[n->id()]; @@ -970,21 +972,23 @@ Status GetRootFrame(const Node* n, absl::Span cfi_infos, } } - *frame = cfi_iter->frame_name; + frame = cfi_iter->frame_name; return Status::OK(); } // Compute a special topological order for the Graph, where nodes having the -// same root frame are placed adjacent to each other. +// same root frame are placed adjacent to each other. The traversal is a +// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how +// many inputs of each node are ready; a node is ready to be scheduled if all +// of its inputs are ready. 
+// For details, see https://en.wikipedia.org/wiki/Topological_sorting Status GetFrameBasedTopologicalOrder(const Graph* g, absl::Span cf_infos, std::vector* order) { - absl::flat_hash_map num_enters; - absl::flat_hash_map num_exits; + absl::flat_hash_map num_enters; + absl::flat_hash_map num_exits; std::vector num_ready_inputs(g->num_node_ids(), 0); Node* src_node = g->source_node(); - std::deque ready; - ready.push_back(src_node); for (const auto* node : g->op_nodes()) { const ControlFlowInfo& cf = cf_infos[node->id()]; bool is_root_level = cf.parent_frame == src_node; @@ -995,6 +999,8 @@ Status GetFrameBasedTopologicalOrder(const Graph* g, } else if (IsExit(node) && is_root_level) { ++num_exits[cf.frame_name]; } + // Edge NextIteration->Merge is counted before starting the traversal to + // break the backedges. if (IsMerge(node)) { for (const Edge* e : node->in_edges()) { if (IsNextIteration(e->src())) { @@ -1004,6 +1010,8 @@ Status GetFrameBasedTopologicalOrder(const Graph* g, } } + std::deque ready; + ready.push_back(src_node); absl::flat_hash_map> staging_enter_vecs; // Exit nodes shall all be from the same frame, as we process a frame at a // time. So, one vector is enough. @@ -1044,7 +1052,7 @@ Status GetFrameBasedTopologicalOrder(const Graph* g, // make sure all nodes in the currently processing frame are visited // before starting processing other frames. string frame_name = cf_infos[staging_exit_vec.front()->id()].frame_name; - CHECK(staging_exit_vec.size() == num_exits[frame_name]); + CHECK_EQ(staging_exit_vec.size(), num_exits[frame_name]); ready.insert(ready.end(), staging_exit_vec.begin(), staging_exit_vec.end()); staging_exit_vec.clear(); @@ -1098,8 +1106,8 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // does not affect the result after pattern-matching into the // AndRecurrence form. 
string frame_name = control_flow_info_[n->id()].frame_name; - auto ret = frame_to_merge_node_.insert({frame_name, n}); - Node* representative = ret.first->second; + auto insert_result = frame_to_merge_node_.insert({frame_name, n}); + Node* representative = insert_result.first->second; TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( representative, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); @@ -1240,52 +1248,51 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) { absl::StrJoin(unreachable_nodes, ", ")); } - std::vector tpo; + std::vector topo; TF_RETURN_IF_ERROR( - GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &tpo)); + GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &topo)); size_t frame_start = 0; - for (size_t i = 0; i < tpo.size(); ++i) { + for (size_t i = 0; i < topo.size(); ++i) { // Collect nodes until we see a node who has a different root frame. - if (i != tpo.size() - 1) { - string i_frame_name, next_frame_name; - TF_RETURN_IF_ERROR(GetRootFrame(tpo[i], control_flow_info_, - graph_.source_node(), &i_frame_name)); - TF_RETURN_IF_ERROR(GetRootFrame(tpo[i + 1], control_flow_info_, - graph_.source_node(), &next_frame_name)); + if (i != topo.size() - 1) { + absl::string_view i_frame_name, next_frame_name; + TF_RETURN_IF_ERROR(GetRootFrame(topo[i], control_flow_info_, + i_frame_name)); + TF_RETURN_IF_ERROR(GetRootFrame(topo[i + 1], control_flow_info_, + next_frame_name)); if (i_frame_name == next_frame_name) { continue; } } - string frame_name = control_flow_info_[tpo[i]->id()].frame_name; - absl::Span sub_tpo(tpo.data() + frame_start, i - frame_start + 1); + string frame_name = control_flow_info_[topo[i]->id()].frame_name; + absl::Span sub_topo(topo.data() + frame_start, + /*length=*/i - frame_start + 1); frame_start = i + 1; // First, try the optimistic mode. 
- bool is_success = false; + bool success = false; if (!force_pessimistic && !frame_name.empty()) { TF_RETURN_IF_ERROR( - PopulateFrame(sub_tpo, /*use_optimistic_mode*/ true, &is_success)); + PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success)); } - if (!is_success) { + if (!success) { // The optimistic mode does not converge. Let's fall back to the // pessimistic mode. TF_RETURN_IF_ERROR( - PopulateFrame(sub_tpo, /*use_optimistic_mode*/ false, nullptr)); - } - if (VLOG_IS_ON(2)) { - VLOG(2) << "Done populating frame " << frame_name << " using the " - << (is_success ? "optimistic" : "pessimistic") << " mode."; + PopulateFrame(sub_topo, /*use_optimistic_mode=*/false, nullptr)); } + VLOG(2) << "Done populating frame " << frame_name << " using the " + << (success ? "optimistic" : "pessimistic") << " mode."; } return Status::OK(); } -Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, +Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, bool use_optimistic_mode, - bool* is_success) { + bool* success) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // @@ -1303,7 +1310,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, // delta should not change in the second iteration. std::vector should_revisit; should_revisit.resize(graph_.num_node_ids()); - for (Node* n : tpo) { + for (Node* n : topo) { VLOG(4) << "Visiting " << n->name(); TF_RETURN_IF_ERROR( HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode)); @@ -1318,7 +1325,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, } } - for (Node* n : tpo) { + for (Node* n : topo) { // The nodes added to should_revisit in the previous loop need to be // revisited now. 
Reprocesing these initial nodes may add *their* consumers // to should_revisit, and these newly added nodes will also be processed by @@ -1339,7 +1346,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, if (use_optimistic_mode) { bool is_converged = true; absl::flat_hash_map frame_to_pred; - for (Node* n : tpo) { + for (Node* n : topo) { if (!n->IsMerge()) { continue; } @@ -1382,7 +1389,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, // Clear the assigned predicates if the optimistic mode does not converge. if (!is_converged) { - for (Node* n : tpo) { + for (Node* n : topo) { for (int oid = 0; oid < n->num_outputs(); ++oid) { predicate_map_.erase(TensorId(n->name(), oid)); } @@ -1390,8 +1397,8 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, } } - if (is_success != nullptr) { - *is_success = is_converged; + if (success != nullptr) { + *success = is_converged; } } diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 8ce61a828f6..abfce7ce2aa 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -639,7 +639,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ false)); + /*force_pessimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -653,7 +653,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ true)); + /*force_pessimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -683,7 +683,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), 
&predicate_map, - /*force_pessimistic*/ false)); + /*force_pessimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -691,7 +691,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ true)); + /*force_pessimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], "div0/iv:0"); @@ -746,7 +746,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ false)); + /*force_pessimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); @@ -772,7 +772,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ true)); + /*force_pessimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); @@ -925,7 +925,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ false)); + /*force_pessimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -942,7 +942,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic*/ true)); + /*force_pessimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}");