diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 5d66be59eb3..7f72f35d16f 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -735,7 +735,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { explicit DeadnessAnalysisImpl(const Graph* graph) : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} - Status Populate(bool force_pessimistic = false); + Status Populate(bool enable_optimistic); Status PopulateFrame(absl::Span topo, bool use_optimistic_mode, bool* success); StatusOr GetPredicateFor( @@ -786,12 +786,14 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleNode(Node* n, std::vector* should_revisit, bool use_optimistic_mode = false); + Status GetFrameBasedTopologicalOrder(std::vector* order); + const Graph& graph_; absl::flat_hash_map predicate_map_; PredicateFactory predicate_factory_; std::vector control_flow_info_; bool vlog_; - absl::flat_hash_map frame_to_merge_node_; + absl::flat_hash_map frame_to_merge_node_; }; TensorId InputEdgeToTensorId(const Edge* e) { @@ -975,109 +977,6 @@ Status GetRootFrame(const Node* n, absl::Span cfi_infos, *frame = cfi_iter->frame_name; return Status::OK(); } - -// Compute a special topological order for the Graph, where nodes having the -// same root frame are placed adjacent to each other. The traversal uses a -// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how -// many inputs of each node are ready; a node is ready to be scheduled if all -// of its inputs are ready. -// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details. 
-Status GetFrameBasedTopologicalOrder(const Graph* g, - absl::Span cf_infos, - std::vector* order) { - absl::flat_hash_map num_enters; - absl::flat_hash_map num_exits; - std::vector num_ready_inputs(g->num_node_ids(), 0); - Node* src_node = g->source_node(); - for (const auto* node : g->op_nodes()) { - const ControlFlowInfo& cf = cf_infos[node->id()]; - bool is_root_level = cf.parent_frame == src_node; - if (IsEnter(node) && is_root_level) { - // Since we care only the root-level frame, full frame names are the same - // as frame names. - ++num_enters[cf.frame_name]; - } else if (IsExit(node) && is_root_level) { - ++num_exits[cf.frame_name]; - } - // Edge NextIteration->Merge is counted before starting the traveral to - // break the backedges. - if (IsMerge(node)) { - for (const Edge* e : node->in_edges()) { - if (IsNextIteration(e->src())) { - ++num_ready_inputs[node->id()]; - } - } - } - } - - std::deque ready; - ready.push_back(src_node); - // ready_enters_per_frame and ready_exits serve as a staging area to buffer - // the ready enters/exits before they are moved to the `ready` queue for - // controlling the start and end of a processing frame. - absl::flat_hash_map> ready_enters_per_frame; - // Exit nodes shall all be from the same frame, as we process a frame at a - // time. So, one vector is enough. - std::vector ready_exits; - while (!ready.empty()) { - Node* curr_node = ready.front(); - ready.pop_front(); - - VLOG(4) << "Visiting " << curr_node->name(); - order->push_back(curr_node); - - for (const Edge* out_edge : curr_node->out_edges()) { - Node* out = out_edge->dst(); - int out_id = out->id(); - if (IsNextIteration(curr_node) && IsMerge(out)) { - // Edge NextIteration->Merge has been counted. - continue; - } - ++num_ready_inputs[out->id()]; - if (!out->IsOp()) continue; // Skip Sink/Source nodes. 
- if (num_ready_inputs[out->id()] != out->in_edges().size()) continue; - - bool is_root_level = cf_infos[out_id].parent_frame == src_node; - string frame_name = cf_infos[out_id].frame_name; - if (IsEnter(out) && is_root_level) { - ready_enters_per_frame[frame_name].push_back(out); - } else if (IsExit(out) && is_root_level) { - ready_exits.push_back(out); - } else { - ready.push_back(out); - } - } - - if (ready.empty()) { - // Try moving nodes from ready_enters_per_frame and read_exits to `ready`. - if (!ready_exits.empty()) { - // If there are nodes in ready_exits we must process them before - // processing ready_enters_per_frame to make sure all nodes in the - // currently processing frame are visited before starting processing - // other frames. - string frame_name = cf_infos[ready_exits.front()->id()].frame_name; - CHECK_EQ(ready_exits.size(), num_exits[frame_name]); - ready.insert(ready.end(), ready_exits.begin(), ready_exits.end()); - ready_exits.clear(); - } else { - // Otherwise, try moving nodes from ready_enters to `ready`. - for (auto iter = ready_enters_per_frame.begin(); - iter != ready_enters_per_frame.end(); ++iter) { - string frame_name = iter->first; - const std::vector& ready_enters = iter->second; - if (ready_enters.size() == num_enters[frame_name]) { - ready.insert(ready.end(), ready_enters.begin(), ready_enters.end()); - ready_enters_per_frame.erase(iter); - break; - } - } - } - } - } - - CHECK(ready_enters_per_frame.empty() && ready_exits.empty()); - return Status::OK(); -} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, @@ -1107,7 +1006,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // frame as the representative Merge node. It is just convenient and // does not affect the result after pattern-matching into the // AndRecurrence form. 
- string frame_name = control_flow_info_[n->id()].frame_name; + absl::string_view frame_name = control_flow_info_[n->id()].frame_name; auto insert_result = frame_to_merge_node_.insert({frame_name, n}); Node* representative = insert_result.first->second; TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( @@ -1219,6 +1118,109 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, return Status::OK(); } +// Compute a special topological order for the Graph, where nodes having the +// same root frame are placed adjacent to each other. The traversal uses a +// variant of Kahn's algorithm. num_ready_inputs is used to keep track of how +// many inputs of each node are ready; a node is ready to be scheduled if all +// of its inputs are ready. +// Ref. to https://en.wikipedia.org/wiki/Topological_sorting for details. +Status DeadnessAnalysisImpl::GetFrameBasedTopologicalOrder( + std::vector* order) { + absl::flat_hash_map num_enters_for_frame; + absl::flat_hash_map num_exits_for_frame; + std::vector num_ready_inputs(graph_.num_node_ids(), 0); + Node* src_node = graph_.source_node(); + for (const auto* node : graph_.op_nodes()) { + const ControlFlowInfo& cf = control_flow_info_[node->id()]; + bool is_root_level = cf.parent_frame == src_node; + if (IsEnter(node) && is_root_level) { + // Since we care only the root-level frame, full frame names are the same + // as frame names. + ++num_enters_for_frame[cf.frame_name]; + } else if (IsExit(node) && is_root_level) { + ++num_exits_for_frame[cf.frame_name]; + } + // Edge NextIteration->Merge is counted before starting the traversal to + // break the backedges. 
+ if (IsMerge(node)) { + for (const Edge* e : node->in_edges()) { + if (IsNextIteration(e->src())) { + ++num_ready_inputs[node->id()]; + } + } + } + } + + std::deque ready; + ready.push_back(src_node); + // ready_enters_per_frame and ready_exits serve as a staging area to buffer + // the ready enters/exits before they are moved to the `ready` queue for + // controlling the start and end of a processing frame. + absl::flat_hash_map> ready_enters_per_frame; + // Exit nodes shall all be from the same frame, as we process a frame at a + // time. So, one vector is enough. + std::vector ready_exits; + while (!ready.empty()) { + Node* curr_node = ready.front(); + ready.pop_front(); + + VLOG(4) << "Visiting " << curr_node->name(); + order->push_back(curr_node); + + for (const Edge* out_edge : curr_node->out_edges()) { + Node* out = out_edge->dst(); + int out_id = out->id(); + if (IsNextIteration(curr_node) && IsMerge(out)) { + // Edge NextIteration->Merge has been counted. + continue; + } + ++num_ready_inputs[out->id()]; + if (!out->IsOp()) continue; // Skip Sink/Source nodes. + if (num_ready_inputs[out->id()] != out->in_edges().size()) continue; + + bool is_root_level = control_flow_info_[out_id].parent_frame == src_node; + absl::string_view frame_name = control_flow_info_[out_id].frame_name; + if (IsEnter(out) && is_root_level) { + ready_enters_per_frame[frame_name].push_back(out); + } else if (IsExit(out) && is_root_level) { + ready_exits.push_back(out); + } else { + ready.push_back(out); + } + } + + if (ready.empty()) { + // Try moving nodes from ready_enters_per_frame and ready_exits to `ready`. + if (!ready_exits.empty()) { + // If there are nodes in ready_exits we must process them before + // processing ready_enters_per_frame to make sure all nodes in the + // currently processing frame are visited before starting processing + // other frames. 
+ absl::string_view frame_name = + control_flow_info_[ready_exits.front()->id()].frame_name; + CHECK_EQ(ready_exits.size(), num_exits_for_frame[frame_name]); + ready.insert(ready.end(), ready_exits.begin(), ready_exits.end()); + ready_exits.clear(); + } else { + // Otherwise, try moving nodes from ready_enters to `ready`. + for (auto iter = ready_enters_per_frame.begin(); + iter != ready_enters_per_frame.end(); ++iter) { + absl::string_view frame_name = iter->first; + const std::vector& ready_enters = iter->second; + if (ready_enters.size() == num_enters_for_frame[frame_name]) { + ready.insert(ready.end(), ready_enters.begin(), ready_enters.end()); + ready_enters_per_frame.erase(iter); + break; + } + } + } + } + } + + CHECK(ready_enters_per_frame.empty() && ready_exits.empty()); + return Status::OK(); +} + // We populate the nodes along a special topological order where nodes having // the same root frame are placed adjacent to each other. This grouping enables // processing the graph per root frame at a time and guarantees that when a root @@ -1231,7 +1233,7 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, // (root) frame. Note that we don't separate while loops belonging to the same // nested while, as there is no clean cut for separating them in the topological // order. -Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) { +Status DeadnessAnalysisImpl::Populate(bool enable_optimistic) { std::vector unreachable_nodes; // Compute the loop structure of the graph. TF_RETURN_IF_ERROR( @@ -1251,8 +1253,7 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) { } std::vector topo; - TF_RETURN_IF_ERROR( - GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &topo)); + TF_RETURN_IF_ERROR(GetFrameBasedTopologicalOrder(&topo)); size_t frame_start = 0; while (frame_start < topo.size()) { @@ -1277,7 +1278,7 @@ Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) { // First, try the optimistic mode. 
bool success = false; - if (!force_pessimistic && !cur_frame_name.empty()) { + if (enable_optimistic && !cur_frame_name.empty()) { TF_RETURN_IF_ERROR( PopulateFrame(sub_topo, /*use_optimistic_mode=*/true, &success)); } @@ -1349,7 +1350,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, // predicates. if (use_optimistic_mode) { bool is_converged = true; - absl::flat_hash_map frame_to_pred; + absl::flat_hash_map frame_to_pred; for (Node* n : topo) { if (!n->IsMerge()) { continue; @@ -1361,7 +1362,7 @@ Status DeadnessAnalysisImpl::PopulateFrame(absl::Span topo, continue; } Node* merge = n; - string frame_name = control_flow_info_[merge->id()].frame_name; + absl::string_view frame_name = control_flow_info_[merge->id()].frame_name; auto it = predicate_map_.find(TensorId(merge->name(), 0)); Predicate* merge_pred = it->second; if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) { @@ -1441,7 +1442,7 @@ DeadnessAnalysis::~DeadnessAnalysis() {} const Graph& graph, std::unique_ptr* result) { std::unique_ptr analysis( new DeadnessAnalysisImpl(&graph)); - TF_RETURN_IF_ERROR(analysis->Populate()); + TF_RETURN_IF_ERROR(analysis->Populate(/*enable_optimistic=*/true)); if (VLOG_IS_ON(2)) { analysis->Print(); @@ -1463,9 +1464,9 @@ DeadnessAnalysisImpl::PredicateMapAsString() const { namespace deadness_analysis_internal { Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, - bool force_pessimistic) { + bool enable_optimistic) { DeadnessAnalysisImpl impl(&graph); - TF_RETURN_IF_ERROR(impl.Populate(force_pessimistic)); + TF_RETURN_IF_ERROR(impl.Populate(enable_optimistic)); *out_predicate_map = impl.PredicateMapAsString(); return Status::OK(); } diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 3d216ccbac3..b2f0e72bc14 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -26,7 +26,7 @@ 
namespace deadness_analysis_internal { // testing purposes only. using PredicateMapTy = absl::flat_hash_map; Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, - bool force_pessimistic = false); + bool enable_optimistic = true); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index abfce7ce2aa..fae1e55c6ba 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -639,7 +639,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/false)); + /*enable_optimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -653,7 +653,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/true)); + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -683,7 +683,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/false)); + /*enable_optimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -691,7 +691,7 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/true)); + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], "div0/iv:0"); @@ -746,7 +746,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { { 
PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/false)); + /*enable_optimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); @@ -755,7 +755,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" "cond:0}"); - // force_pessimistic = true or not should produce the same results because + // enable_optimistic = true or not should produce the same results because // of fallback. However, note that the order of iv_inner/cond:0 and // iv_inner/iv:0 is different because the optimistic approach does not // create predicates for all merges and it can change the predicate id and @@ -772,7 +772,7 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/true)); + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); @@ -925,7 +925,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/false)); + /*enable_optimistic=*/true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -942,7 +942,7 @@ TEST(DeadnessAnalysisTest, CyclicRecurrence) { { PredicateMapTy predicate_map; TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, - /*force_pessimistic=*/true)); + /*enable_optimistic=*/false)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}");