diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc index 0a92c06ad10..b98e479ca07 100644 --- a/tensorflow/compiler/jit/deadness_analysis.cc +++ b/tensorflow/compiler/jit/deadness_analysis.cc @@ -85,22 +85,28 @@ limitations under the License. // true on iteration 0, 1, 2 respectively. This is made more precise in the // comment on the AndRecurrence class. // -// The general algorithm that deals with cycles does two RPO (reverse post -// order) passes over the graph. On the first pass it assigns a symbolic -// predicate to merge nodes with backedges. On the second pass it tries to -// pattern matche the predicates for the backedges of these merges and infer an -// AndRecurrence for the merge. +// The general algorithm that deals with cycles does two topological-order +// iterations over the graph. On the first iteration it assigns a symbolic +// predicate to merge nodes with backedges. On the second iteration it tries +// to pattern match the predicates for the backedges of these merges and infer +// an AndRecurrence for the merge. In other words, we do a data flow analysis +// where the data-flow lattice has two elements, Symbolic and NonSymbolic with +// Symbolic > NonSymbolic. The lattice has height = 2 so two iterations are +// sufficient to converge. // -// In other words, we do a pessimistic data flow analysis where the data-flow -// lattice has two elements, Symbolic and NonSymbolic with Symbolic > -// NonSymbolic. The lattice has height = 2 so two iterations are sufficient to -// converge. We don't do an optimistic data flow analysis to make pattern -// matching easier: if we assigned the predicate of the initial value to the -// merge during the first pass, on the second pass the backedge may see a -// simplified value that would be difficult to pattern match. +// We first do an optimistic analysis and, if it does not converge, we then fall +// back to a pessimistic analysis. 
The optimistic analysis assigns the same +// symbolic predicate to all the merge nodes whose preceding enter nodes have +// the same frame name on the first iteration. On the second iteration, if all +// the merge nodes are pattern matched into the same AndRecurrence predicate +// instance, the optimistic assignment of the same symbolic predicate is correct +// and the analyzed result is taken. // -// We still use symbolic predicates for merges for which we can't pattern match -// on the backedge predicate. This is conservatively correct. +// Otherwise, if the optimistic analysis fails to converge, we then obtain the +// result by falling back to the pessimistic analysis which assigns a unique +// symbolic predicate to each merge on the first iteration. We still use +// symbolic predicates for merges for which we can't pattern match on the +// backedge predicate. This is conservatively correct. namespace tensorflow { @@ -636,6 +642,36 @@ Predicate* PredicateFactory::MakeAndOrImpl( negated_ops.insert(negated_op); } + // Simplify {S,&,X} & ~X & ... => S & ... + if (is_and) { + absl::flat_hash_set<Predicate*> to_remove; + std::vector<Predicate*> to_add; + for (Predicate* op : simplified_ops) { + if (op->kind() == Predicate::Kind::kAndRecurrence) { + AndRecurrencePredicate* and_rec = + static_cast<AndRecurrencePredicate*>(op); + if (negated_ops.count(and_rec->step())) { + // Remove and_rec and ~X and insert S. Note that checking the + // existence of ~X through negated_ops is sufficient since it makes + // sure the predicate is in the input operands. It does not need to + // be in simplified_ops if it was already cancelled out. 
+ to_remove.insert(and_rec); + to_remove.insert(MakeNotPredicate(and_rec->step())); + to_add.push_back(and_rec->start()); + } + } + } + auto it = simplified_ops.begin(); + while (it != simplified_ops.end()) { + if (to_remove.count(*it)) { + it = simplified_ops.erase(it); + } else { + ++it; + } + } + simplified_ops.insert(simplified_ops.end(), to_add.begin(), to_add.end()); + } + // If all ops contain the same subop, then factor it out thanks to the // distributive property. Such as: // - (A & B) | (A & C) | (A & D) => A & (B | C | D) @@ -699,8 +735,9 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { explicit DeadnessAnalysisImpl(const Graph* graph) : graph_(*graph), vlog_(VLOG_IS_ON(2)) {} - Status Populate(); - Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo); + Status Populate(bool force_pessimistic = false); + Status PopulateFrame(absl::Span<Node* const> tpo, bool use_optimistic_mode, + bool* is_success); StatusOr<DeadnessAnalysis::DeadnessPredicate> GetPredicateFor( Node* n, int oidx) const override; void Print() const override; @@ -742,16 +779,19 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis { Status HandleSwitch(Node* n, std::vector<bool>* should_revisit); - Status HandleMerge(Node* n, std::vector<bool>* should_revisit); + Status HandleMerge(Node* n, std::vector<bool>* should_revisit, + bool use_optimistic_mode); Status HandleRecv(Node* n, std::vector<bool>* should_revisit); Status HandleGeneric(Node* n, std::vector<bool>* should_revisit); - Status HandleNode(Node* n, std::vector<bool>* should_revisit); + Status HandleNode(Node* n, std::vector<bool>* should_revisit, + bool use_optimistic_mode = false); const Graph& graph_; absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_; PredicateFactory predicate_factory_; std::vector<ControlFlowInfo> control_flow_info_; bool vlog_; + absl::flat_hash_map<string, Node*> frame_to_merge_node_; }; TensorId InputEdgeToTensorId(const Edge* e) { @@ -914,10 +954,125 @@ Status GetFullFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos, return Status::OK(); } + +Status GetRootFrame(const Node* n, absl::Span<const ControlFlowInfo> cfi_infos, + const Node* src_node, string* frame) 
{ + int depth = 0; + const ControlFlowInfo* cfi_iter = &cfi_infos[n->id()]; + while (cfi_iter->parent_frame != src_node) { + n = cfi_iter->parent_frame; + cfi_iter = &cfi_infos[n->id()]; + + if (depth++ > 5000) { + return errors::Internal( + "Frame of depth > 5000: Probably malformed graph or a bug in " + "BuildControlFlowInfo"); + } + } + + *frame = cfi_iter->frame_name; + return Status::OK(); +} + +// Compute a special topological order for the Graph, where nodes having the +// same root frame are placed adjacent to each other. +Status GetFrameBasedTopologicalOrder(const Graph* g, + absl::Span<const ControlFlowInfo> cf_infos, + std::vector<Node*>* order) { + absl::flat_hash_map<string, int> num_enters; + absl::flat_hash_map<string, int> num_exits; + std::vector<int> num_ready_inputs(g->num_node_ids(), 0); + Node* src_node = g->source_node(); + std::deque<Node*> ready; + ready.push_back(src_node); + for (const auto* node : g->op_nodes()) { + const ControlFlowInfo& cf = cf_infos[node->id()]; + bool is_root_level = cf.parent_frame == src_node; + if (IsEnter(node) && is_root_level) { + // Since we care only about the root-level frame, full frame names are the same + // as frame names. + ++num_enters[cf.frame_name]; + } else if (IsExit(node) && is_root_level) { + ++num_exits[cf.frame_name]; + } + if (IsMerge(node)) { + for (const Edge* e : node->in_edges()) { + if (IsNextIteration(e->src())) { + ++num_ready_inputs[node->id()]; + } + } + } + } + + absl::flat_hash_map<string, std::vector<Node*>> staging_enter_vecs; + // Exit nodes shall all be from the same frame, as we process a frame at a + // time. So, one vector is enough. + std::vector<Node*> staging_exit_vec; + while (!ready.empty()) { + Node* curr_node = ready.front(); + ready.pop_front(); + + VLOG(4) << "Visiting " << curr_node->name(); + order->push_back(curr_node); + + for (const Edge* out_edge : curr_node->out_edges()) { + Node* out = out_edge->dst(); + int out_id = out->id(); + if (IsNextIteration(curr_node) && IsMerge(out)) { + // Edge NextIteration->Merge has been counted. 
+ continue; + } + ++num_ready_inputs[out->id()]; + if (!out->IsOp()) continue; // Skip Sink/Source nodes. + if (num_ready_inputs[out->id()] != out->in_edges().size()) continue; + + bool is_root_level = cf_infos[out_id].parent_frame == src_node; + string frame_name = cf_infos[out_id].frame_name; + if (IsEnter(out) && is_root_level) { + staging_enter_vecs[frame_name].push_back(out); + } else if (IsExit(out) && is_root_level) { + staging_exit_vec.push_back(out); + } else { + ready.push_back(out); + } + } + + if (ready.empty()) { + if (!staging_exit_vec.empty()) { + // Move staging nodes into the ready queue if any. If there are staging + // exits we must process them before processing the staging enters to + // make sure all nodes in the currently processing frame are visited + // before starting processing other frames. + string frame_name = cf_infos[staging_exit_vec.front()->id()].frame_name; + CHECK(staging_exit_vec.size() == num_exits[frame_name]); + ready.insert(ready.end(), staging_exit_vec.begin(), + staging_exit_vec.end()); + staging_exit_vec.clear(); + } else { + // Otherwise, try moving the staging enter nodes into the ready queue. + for (auto iter = staging_enter_vecs.begin(); + iter != staging_enter_vecs.end(); ++iter) { + string frame_name = iter->first; + const std::vector& staging_enters = iter->second; + if (staging_enters.size() == num_enters[frame_name]) { + ready.insert(ready.end(), staging_enters.begin(), + staging_enters.end()); + staging_enter_vecs.erase(iter); + break; + } + } + } + } + } + + CHECK(staging_enter_vecs.empty() && staging_exit_vec.empty()); + return Status::OK(); +} } // namespace Status DeadnessAnalysisImpl::HandleMerge(Node* n, - std::vector* should_revisit) { + std::vector* should_revisit, + bool use_optimistic_mode) { // Merge ignores deadness of its control inputs. A merge that isn't the // target of a backedge has is alive iff any of its data inputs are. 
The // liveness of a merge that is the target of a backedge can sometimes be @@ -937,8 +1092,21 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, // We're visiting this merge for the first time and it has an unvisited // backedge. Predicate* input_data_pred; - TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( - n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + if (use_optimistic_mode) { + // In the optimistic mode, we use the first-seen Merge node per + // frame as the representative Merge node. It is just convenient and + // does not affect the result after pattern-matching into the + // AndRecurrence form. + string frame_name = control_flow_info_[n->id()].frame_name; + auto ret = frame_to_merge_node_.insert({frame_name, n}); + Node* representative = ret.first->second; + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + representative, /*output_idx=*/0, /*must_be_true=*/false, + &input_data_pred)); + } else { + TF_RETURN_IF_ERROR(predicate_factory_.MakeSymbolPredicate( + n, /*output_idx=*/0, /*must_be_true=*/false, &input_data_pred)); + } SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, should_revisit); @@ -948,7 +1116,7 @@ Status DeadnessAnalysisImpl::HandleMerge(Node* n, std::vector input_preds; TF_RETURN_IF_ERROR(GetInputPreds(n, EdgeKind::kDataOnly, &input_preds)); - // We're visiting this merge for the first time and it is a acyclic merge. + // We're visiting this merge for the first time and it is an acyclic merge. 
Predicate* input_data_pred = predicate_factory_.MakeOrPredicate(input_preds); SetPredicate(n, {0, 1, Graph::kControlSlot}, input_data_pred, @@ -1022,11 +1190,12 @@ Status DeadnessAnalysisImpl::HandleGeneric(Node* n, } Status DeadnessAnalysisImpl::HandleNode(Node* n, - std::vector* should_revisit) { + std::vector* should_revisit, + bool use_optimistic_mode) { if (n->IsSwitch()) { TF_RETURN_IF_ERROR(HandleSwitch(n, should_revisit)); } else if (n->IsMerge()) { - TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit)); + TF_RETURN_IF_ERROR(HandleMerge(n, should_revisit, use_optimistic_mode)); } else if (n->IsControlTrigger()) { SetPredicate(n, Graph::kControlSlot, predicate_factory_.MakeTrue(), nullptr); @@ -1040,17 +1209,19 @@ Status DeadnessAnalysisImpl::HandleNode(Node* n, return Status::OK(); } -Status DeadnessAnalysisImpl::Populate() { - std::vector rpo; - GetReversePostOrder(graph_, &rpo, /*stable_comparator=*/NodeComparatorName(), - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); - return PopulateWithReversePostOrder(rpo); -} - -Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( - absl::Span rpo) { +// We populate the nodes along a special topological order where nodes having +// the same root frame are placed adjacent to each other. This grouping enables +// processing the graph per root frame at a time and guarantees that when a root +// frame is being processed, nodes in the downstream frames have not yet been +// processed. This property is important because we need to process an entire +// frame to know whether the optimistic mode converges or not. In other words, +// nodes in the downstream frames shall not be populated until all of its +// upstream frames are populated. In effect, this order enables processing each +// (nested) tf.while one-by-one, as each (nested) tf.while creates a unique +// (root) frame. 
Note that we don't separate while loops belonging to the same +// nested while, as there is no clean cut for separating them in the topological +// order. +Status DeadnessAnalysisImpl::Populate(bool force_pessimistic) { std::vector unreachable_nodes; // Compute the loop structure of the graph. TF_RETURN_IF_ERROR( @@ -1069,6 +1240,52 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( absl::StrJoin(unreachable_nodes, ", ")); } + std::vector tpo; + TF_RETURN_IF_ERROR( + GetFrameBasedTopologicalOrder(&graph_, control_flow_info_, &tpo)); + + size_t frame_start = 0; + for (size_t i = 0; i < tpo.size(); ++i) { + // Collect nodes until we see a node who has a different root frame. + if (i != tpo.size() - 1) { + string i_frame_name, next_frame_name; + TF_RETURN_IF_ERROR(GetRootFrame(tpo[i], control_flow_info_, + graph_.source_node(), &i_frame_name)); + TF_RETURN_IF_ERROR(GetRootFrame(tpo[i + 1], control_flow_info_, + graph_.source_node(), &next_frame_name)); + if (i_frame_name == next_frame_name) { + continue; + } + } + + string frame_name = control_flow_info_[tpo[i]->id()].frame_name; + absl::Span sub_tpo(tpo.data() + frame_start, i - frame_start + 1); + frame_start = i + 1; + + // First, try the optimistic mode. + bool is_success = false; + if (!force_pessimistic && !frame_name.empty()) { + TF_RETURN_IF_ERROR( + PopulateFrame(sub_tpo, /*use_optimistic_mode*/ true, &is_success)); + } + if (!is_success) { + // The optimistic mode does not converge. Let's fall back to the + // pessimistic mode. + TF_RETURN_IF_ERROR( + PopulateFrame(sub_tpo, /*use_optimistic_mode*/ false, nullptr)); + } + if (VLOG_IS_ON(2)) { + VLOG(2) << "Done populating frame " << frame_name << " using the " + << (is_success ? 
"optimistic" : "pessimistic") << " mode."; + } + } + + return Status::OK(); +} + +Status DeadnessAnalysisImpl::PopulateFrame(absl::Span tpo, + bool use_optimistic_mode, + bool* is_success) { // This an abstract interpretation over the deadness propagation semantics of // the graph executor. // @@ -1086,9 +1303,10 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( // delta should not change in the second iteration. std::vector should_revisit; should_revisit.resize(graph_.num_node_ids()); - for (Node* n : rpo) { + for (Node* n : tpo) { VLOG(4) << "Visiting " << n->name(); - TF_RETURN_IF_ERROR(HandleNode(n, /*should_revisit=*/nullptr)); + TF_RETURN_IF_ERROR( + HandleNode(n, /*should_revisit=*/nullptr, use_optimistic_mode)); if (n->IsNextIteration()) { // If this is a backedge for a merge node then remember to reprocess the // merge the next time we run. @@ -1100,11 +1318,11 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( } } - for (Node* n : rpo) { + for (Node* n : tpo) { // The nodes added to should_revisit in the previous loop need to be // revisited now. Reprocesing these initial nodes may add *their* consumers // to should_revisit, and these newly added nodes will also be processed by - // this very same loop. Since we're traversing the graph in reverse post + // this very same loop. Since we're traversing the graph in topological // order (producers before consumers) and HandleNode(n) can only ever add // n's consumers to should_revisit, we won't "miss" an addition to // should_revisit. @@ -1114,6 +1332,69 @@ Status DeadnessAnalysisImpl::PopulateWithReversePostOrder( } } + // Check if the optimistic analysis converges. Specifically, check whether + // all the predicates of the merge nodes in the same frame are the same. If + // yes, report success. If not, report failure and clear the assigned + // predicates. 
+ if (use_optimistic_mode) { + bool is_converged = true; + absl::flat_hash_map<string, Predicate*> frame_to_pred; + for (Node* n : tpo) { + if (!n->IsMerge()) { + continue; + } + const Edge* e; + TF_RETURN_IF_ERROR(FindUniqueBackedge(n, &e)); + if (e == nullptr) { + // Skip acyclic merge nodes. + continue; + } + Node* merge = n; + string frame_name = control_flow_info_[merge->id()].frame_name; + auto it = predicate_map_.find(TensorId(merge->name(), 0)); + Predicate* merge_pred = it->second; + if (merge_pred->kind() != Predicate::Kind::kAndRecurrence) { + is_converged = false; + VLOG(2) << "Running the optimistic mode on frame " << frame_name + << " does not converge because node " << merge->name() + << " cannot be mapped into the AndRecurrence form."; + break; + } + + auto insert_result = frame_to_pred.insert({frame_name, merge_pred}); + if (!insert_result.second) { + // If we have already seen this frame name, verify the predicate is the + // same as the previously seen one's. + AndRecurrencePredicate* curr_andrec = + static_cast<AndRecurrencePredicate*>(merge_pred); + AndRecurrencePredicate* prev_andrec = + static_cast<AndRecurrencePredicate*>(insert_result.first->second); + if (curr_andrec != prev_andrec) { + is_converged = false; + VLOG(2) << "Running the optimistic mode on frame " << frame_name + << " does not converge. Seeing different Merge predicates: \n" + << curr_andrec->ToString() << " and \n" + << prev_andrec->ToString(); + break; + } + } + } + + // Clear the assigned predicates if the optimistic mode does not converge. 
+ if (!is_converged) { + for (Node* n : tpo) { + for (int oid = 0; oid < n->num_outputs(); ++oid) { + predicate_map_.erase(TensorId(n->name(), oid)); + } + predicate_map_.erase(TensorId(n->name(), Graph::kControlSlot)); + } + } + + if (is_success != nullptr) { + *is_success = is_converged; + } + } + return Status::OK(); } @@ -1170,22 +1451,14 @@ DeadnessAnalysisImpl::PredicateMapAsString() const { } namespace deadness_analysis_internal { -Status ComputePredicates(const Graph& graph, - PredicateMapTy* out_predicate_map) { +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, + bool force_pessimistic) { DeadnessAnalysisImpl impl(&graph); - TF_RETURN_IF_ERROR(impl.Populate()); + TF_RETURN_IF_ERROR(impl.Populate(force_pessimistic)); *out_predicate_map = impl.PredicateMapAsString(); return Status::OK(); } -Status ComputePredicates(const Graph& graph, - absl::Span reverse_post_order, - PredicateMapTy* out_predicate_map) { - DeadnessAnalysisImpl impl(&graph); - TF_RETURN_IF_ERROR(impl.PopulateWithReversePostOrder(reverse_post_order)); - *out_predicate_map = impl.PredicateMapAsString(); - return Status::OK(); -} } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h index 354782374ad..3d216ccbac3 100644 --- a/tensorflow/compiler/jit/deadness_analysis_internal.h +++ b/tensorflow/compiler/jit/deadness_analysis_internal.h @@ -25,15 +25,9 @@ namespace deadness_analysis_internal { // Returns a map describing the predicate each Tensor was mapped to. For // testing purposes only. using PredicateMapTy = absl::flat_hash_map; -Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map); +Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map, + bool force_pessimistic = false); -// Returns a map describing the predicate each Tensor was mapped to. For -// testing purposes only. 
Makes deadness analysis visit the graph in the order -// specified in `reverse_post_order` which must be a valid RPO for the graph -// minus NextIteration->Merge edges. -Status ComputePredicates(const Graph& graph, - absl::Span reverse_post_order, - PredicateMapTy* out_predicate_map); } // namespace deadness_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc index 3a44eb7db75..8ce61a828f6 100644 --- a/tensorflow/compiler/jit/deadness_analysis_test.cc +++ b/tensorflow/compiler/jit/deadness_analysis_test.cc @@ -638,7 +638,22 @@ TEST(DeadnessAnalysisTest, ControlEquivalentLoopBodies) { } { PredicateMapTy predicate_map; - TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ false)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)], + predicate_map[ControlOutputFor(iv.induction_var)]); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)], + predicate_map[ControlOutputFor(iv.induction_var)]); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + predicate_map[ControlOutputFor(iv.induction_var)]); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], "{#true,&,*iv0/cond:0}"); @@ -660,16 +675,6 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { CreateDependentLoopInvariantValue(root, "div0", "frame", iv.loop_cond, 0); FixupSourceAndSinkEdges(root.graph()); - // To make deadness analysis think that dependent_iv is a loop we need an RPO - // that visits the merge before the backedge. This is a legal RPO for - // deadness analysis since it ignores NextIteration->Merge edges during RPO. 
- // Right now dependent_iv has an edge from Merge to NextIteration so do the - // RPO with this edge in place. Then remove this edge to get our test case. - std::vector rpo; - GetReversePostOrder(*root.graph(), &rpo, /*stable_comparator=*/{}, - /*edge_filter=*/[](const Edge& edge) { - return !edge.src()->IsNextIteration(); - }); TF_ASSERT_OK(root.graph()->UpdateEdge( iv.induction_var.node(), 0, dependent_iv.latch.output_true.node(), 0)); @@ -677,7 +682,16 @@ TEST(DeadnessAnalysisTest, LoopInvariantPredicateOnBackedge) { { PredicateMapTy predicate_map; - TF_ASSERT_OK(ComputePredicates(*root.graph(), rpo, &predicate_map)); + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ false)); + + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ true)); EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv.induction_var)], "div0/iv:0"); @@ -731,7 +745,34 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { } { PredicateMapTy predicate_map; - TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map)); + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ false)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], + "{#true,&,*iv_outer/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)], + "{(*iv_outer/cond:0 & " + "{#true,&,*iv_outer/cond:0}),&,*iv_inner/" + "cond:0}"); + + // force_pessimistic = true or not should produce the same results because + // of fallback. However, note that the order of iv_inner/cond:0 and + // iv_inner/iv:0 is different because the optimistic approach does not + // create predicates for all merges and it can change the predicate id and + // hence the symbol order. 
+ EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)], + "{{#true,&,(iv_outer/iv:0 & " + "*iv_outer/cond:0)},&,(*iv_inner/cond:0 & " + "iv_inner/iv:0)}"); + EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], + predicate_map[ControlOutputFor(dependent_inner_iv0)]); + EXPECT_EQ(predicate_map[ControlOutputFor(add0)], + predicate_map[ControlOutputFor(dependent_inner_iv0)]); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ true)); EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)], "{#true,&,*iv_outer/cond:0}"); @@ -744,15 +785,10 @@ TEST(DeadnessAnalysisTest, ControlEquivalentNestedLoopBodies) { "{{#true,&,(iv_outer/iv:0 & " "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " "*iv_inner/cond:0)}"); - EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)], - "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " - "*iv_inner/cond:0)}"); + predicate_map[ControlOutputFor(dependent_inner_iv0)]); EXPECT_EQ(predicate_map[ControlOutputFor(add0)], - "{{#true,&,(iv_outer/iv:0 & " - "*iv_outer/cond:0)},&,(iv_inner/iv:0 & " - "*iv_inner/cond:0)}"); + predicate_map[ControlOutputFor(dependent_inner_iv0)]); } } @@ -817,6 +853,104 @@ TEST(DeadnessAnalysisTest, ControlNonEquivalentNestedLoopBodies) { } } +TEST(DeadnessAnalysisTest, NestedLoopBodiesWithACapture) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv_outer = + CreateInductionVariable(root, "iv_outer", "outer_loop", 0); + Output enter_constant_outer_loop = ops::internal::Enter( + root.WithOpName("constant_enter_outer_loop"), + ops::Const(root.WithOpName("constant"), 5), "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + ops::Switch inner_value(root.WithOpName("outer_is_live"), + enter_constant_outer_loop, iv_outer.loop_cond); + InductionVarInfo iv_inner = CreateInductionVariable( + root, "iv_inner", "inner_loop", inner_value.output_true); + + 
DependentInductionVar div0_outer = CreateDependentLoopInvariantValue( + root, "div0_outer", "outer_loop", iv_outer.loop_cond, 0); + DependentInductionVar div1_outer = CreateDependentLoopInvariantValue( + root, "div1_outer", "outer_loop", iv_outer.loop_cond, 0); + + DependentInductionVar div0_inner = CreateDependentLoopInvariantValue( + root, "div0_inner", "inner_loop", iv_inner.loop_cond, + div0_outer.induction_var); + DependentInductionVar div1_inner = CreateDependentLoopInvariantValue( + root, "div1_inner", "inner_loop", iv_inner.loop_cond, + div1_outer.induction_var); + + Output captured = ops::_Recv(root.WithOpName("captured"), DT_INT32, + "tensor_a", "sender", 0, "receiver"); + Output capture_enter_outer = ops::internal::Enter( + root.WithOpName("capture_enter_outer"), captured, "outer_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + Output capture_enter_inner = ops::internal::Enter( + root.WithOpName("capture_enter_inner"), capture_enter_outer, "inner_loop", + ops::internal::Enter::Attrs().IsConstant(true)); + Output mul0 = ops::Mul(root.WithOpName("mul0"), div1_inner.induction_var, + capture_enter_inner); + TF_ASSERT_OK(root.graph()->UpdateEdge( + mul0.node(), 0, div1_inner.latch.output_true.node(), 0)); + + Output add0 = ops::Add(root.WithOpName("add0"), div0_inner.induction_var, + div1_inner.induction_var); + + VLogGraphIfAsked(*root.graph()); + + { + std::unique_ptr result; + TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result)); + + TF_ASSERT_OK_AND_ASSIGN( + bool has_inputs_with_mismatching_deadness, + HasInputsWithMismatchingDeadness(*result, *add0.node())); + EXPECT_TRUE(has_inputs_with_mismatching_deadness); + } +} + +TEST(DeadnessAnalysisTest, CyclicRecurrence) { + Scope root = Scope::NewRootScope().ExitOnError(); + InductionVarInfo iv = CreateInductionVariable(root, "iv0", "loop", 0); + DependentInductionVar div0 = + CreateDependentLoopInvariantValue(root, "div0", "loop", iv.loop_cond, 0); + DependentInductionVar div1 = + 
CreateDependentLoopInvariantValue(root, "div1", "loop", iv.loop_cond, 0); + FixupSourceAndSinkEdges(root.graph()); + TF_ASSERT_OK(root.graph()->UpdateEdge(div1.induction_var.node(), 0, + div0.latch.output_true.node(), 0)); + TF_ASSERT_OK(root.graph()->UpdateEdge(div0.induction_var.node(), 0, + div1.latch.output_true.node(), 0)); + + VLogGraphIfAsked(*root.graph()); + + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ false)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], + "{#true,&,*iv0/cond:0}"); + + // This tests the rule {S,&,X} & ~X => S. + TensorId switch_false_out = {div1.latch.output_false.node()->name(), + div1.latch.output_false.index()}; + EXPECT_EQ(predicate_map[switch_false_out], "(#true)"); + } + { + PredicateMapTy predicate_map; + TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map, + /*force_pessimistic*/ true)); + + EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)], + "{#true,&,*iv0/cond:0}"); + EXPECT_EQ(predicate_map[ControlOutputFor(div0.induction_var)], "div0/iv:0"); + EXPECT_EQ(predicate_map[ControlOutputFor(div1.induction_var)], "div1/iv:0"); + } +} + TEST(DeadnessAnalysisTest, AndRecurrenceNeedsFrameName) { Scope root = Scope::NewRootScope().ExitOnError(); InductionVarInfo iv_0 = CreateInductionVariable(root, "iv_0", "frame_0", 10);