From d96f268804e9bbbe0032a9bf64ff74d2b918d494 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 26 Jan 2021 14:41:57 -0800
Subject: [PATCH] [XLA] Automated g4 rollback of changelist 352846603.

*** Reason for rollback ***

fix breakage due to the transitivity assumption in live range analysis after
rollback

Original CL description:

[XLA] remove extraneous copies in copy_insertion related to nested
conditionals and while loops.

The change increases the precision of LiveRangeBefore analysis inside
copy_insertion to accommodate disjoint branches inside conditionals that
never overlap.

The breakage is due to the fact that when we allow def-use values that are in
exclusive conditional branches to share buffers, the LiveRangeBefore relation
is no longer transitive. In particular, suppose op_a's live range is before
that of op_b, and the live range of op_b is before that of op_c; we may not
have the live range of op_a before that of op_c, because op_a and op_c may be
in the same branch and overlap with each other. This is fixed by modifying
copy_insertion.cc to check all related HloValues without assuming they are
ordered. This will lengthen the compilation time a bit, but because the
number of copy instructions removed is fairly limited, the cost should be
negligible.

PiperOrigin-RevId: 353953760
Change-Id: Ia110e1a13047bf1d3dec37668bbe21fb10b47a5f
---
 .../compiler/xla/service/copy_insertion.cc    |  53 +--
 .../xla/service/copy_insertion_test.cc        | 162 +++++++++
 .../compiler/xla/service/hlo_ordering.cc      | 309 ++++++++++++------
 .../compiler/xla/service/hlo_ordering.h       |  31 +-
 .../compiler/xla/service/hlo_ordering_test.cc |   8 +-
 5 files changed, 432 insertions(+), 131 deletions(-)

diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index fb71bdfd2d2..80511dddea0 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -734,10 +734,19 @@ class CopyRemover {
       // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
       //
       // Removing the copy eliminates d_0, and uses of d_0 become uses of
-      // s_x. In the above ordering, the live range of d_m must be ordered
+      // s_x. In the above ordering, the live range of d_m will be ordered
       // before the live range of s_{x+1} and the definition and all uses of
-      // s_x must be ordered before the definition of d_1. These conditions
-      // are checked below prior to elision.
+      // s_x will be ordered before the definition of d_1. To make sure the
+      // copy elision is safe, the following code checks that this ordering is
+      // valid --- in particular we check it is safe to order d_m ahead of all
+      // the live ranges at and after s_{x+1}, and it is safe to order all
+      // uses of s_x before the definition of d_1, by checking the live range
+      // constraints for each pair --- we cannot skip the later checks because
+      // the live range ordering is not guaranteed to be transitive --- while
+      // it may be ok to have lr_1 before lr_2, and lr_2 before lr_3 while
+      // merging their buffers, it may not be ok to merge the buffers of lr_1
+      // and lr_3, because the exclusiveness relation of non-overlapping
+      // computations is not transitive.
       //
       // ** Technically it might be possible to have a non-interfering
       //    non-trivial interleaving of the values of the source and
@@ -747,8 +756,8 @@ class CopyRemover {
       // buffer (d_1 through d_m) are spliced into the point where the copy
       // used to be.
       VLOG(2) << copy->name() << " defines the first value in its buffer";
-      ValueNode* next_dest = Next(*dest);
-      if (next_dest != nullptr) {
+      for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
+           next_dest = Next(*next_dest)) {
         // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
         if (!LiveRangeBefore(*src, *next_dest)) {
           VLOG(2) << "Not removing the copy: live range of "
@@ -757,9 +766,8 @@ class CopyRemover {
           return false;
         }
       }
-      ValueNode* next_src = Next(*src);
-
-      if (next_src != nullptr) {
+      for (ValueNode* next_src = Next(*src); next_src != nullptr;
+           next_src = Next(*next_src)) {
         // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
         ValueNode* last_dest = dest->prev;
         DCHECK(IsTail(*last_dest));
@@ -790,20 +798,21 @@ class CopyRemover {
       VLOG(2) << copy->name() << " copies the last value ("
              << src->value->ToShortString() << ") in its buffer";
-      ValueNode* prev_dest = Prev(*dest);
-      // nullptr condition handled above in the first 'if' case.
-      DCHECK(prev_dest != nullptr);
       ValueNode* first_src = src->next;
       DCHECK(IsHead(*first_src));
-      if (!LiveRangeBefore(*prev_dest, *first_src)) {
-        // Live range of value d_{y-1} is not before s_0.
-        VLOG(2) << "Not removing the copy: live range of "
-                << prev_dest->value->ToShortString() << " is not before "
-                << first_src->value->ToShortString();
-        return false;
+      for (ValueNode* prev_dest = Prev(*dest);
+           // nullptr condition handled above in the first 'if' case.
+           prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
+        if (!LiveRangeBefore(*prev_dest, *first_src)) {
+          // Live range of value d_{y-1} is not before s_0.
+          VLOG(2) << "Not removing the copy: live range of "
+                  << prev_dest->value->ToShortString() << " is not before "
+                  << first_src->value->ToShortString();
+          return false;
+        }
       }
-      ValueNode* next_dest = Next(*dest);
-      if (next_dest != nullptr) {
+      for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
+           next_dest = Next(*next_dest)) {
         if (!LiveRangeBefore(*src, *next_dest)) {
           // Live range of value s_n is not before d_{y+1}.
           VLOG(2) << "Not removing the copy: live range of "
@@ -814,7 +823,7 @@ class CopyRemover {
       }
       // Splice source buffer values list right after 'prev_dest'.
-      SpliceAfter(first_src, prev_dest);
+      SpliceAfter(first_src, Prev(*dest));
     } else {
       VLOG(2) << copy->name()
               << " copies value in middle of source buffer to value in middle "
@@ -880,9 +889,7 @@ class CopyRemover {
       VLOG(2) << "Empty uses for " << *a.value;
       return ordering_.IsDefinedBefore(*a.value, *b.value);
     }
-    return absl::c_all_of(a.uses, [&](const HloUse* use) {
-      return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_);
-    });
+    return ordering_.UsesBeforeValueDefinition(a.uses, *b.value, dataflow_);
   }

   // Returns whether 'node' is the last node in its list.
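The non-transitivity that the comment above warns about is the crux of this
change. The following standalone C++ sketch (not part of the patch; the Value
struct and the LiveRangeBefore stand-in are simplified assumptions that model
only the "exclusive branches never overlap" rule, not XLA's real analysis)
reproduces the op_a/op_b/op_c scenario from the commit message:

    #include <iostream>

    // Toy model of a live range: 'branch' is the conditional branch the value
    // is defined in (-1 if outside the conditional); def_time and last_use
    // bound the range within its branch.
    struct Value {
      int branch;
      int def_time;
      int last_use;
    };

    // Stand-in for the patched LiveRangeBefore: values in different branches
    // of one conditional can never be live together, so either order is safe;
    // values in the same branch need a genuine non-overlap check.
    bool LiveRangeBefore(const Value& a, const Value& b) {
      if (a.branch != -1 && b.branch != -1 && a.branch != b.branch) {
        return true;  // exclusive branches never overlap
      }
      return a.last_use <= b.def_time;  // same branch: ranges must not overlap
    }

    int main() {
      Value op_a{/*branch=*/0, /*def_time=*/0, /*last_use=*/5};
      Value op_b{/*branch=*/1, /*def_time=*/0, /*last_use=*/5};
      Value op_c{/*branch=*/0, /*def_time=*/2, /*last_use=*/6};
      std::cout << LiveRangeBefore(op_a, op_b)   // 1: exclusive branches
                << LiveRangeBefore(op_b, op_c)   // 1: exclusive branches
                << LiveRangeBefore(op_a, op_c)   // 0: same branch, overlap
                << "\n";
    }

LiveRangeBefore(op_a, op_b) and LiveRangeBefore(op_b, op_c) both hold through
the exclusive-branch rule, yet LiveRangeBefore(op_a, op_c) fails. This is why
the loops above re-check the constraint for every pair instead of stopping at
the first neighbor.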
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 23d3be6e17d..04da56f57df 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2473,6 +2473,101 @@ ENTRY TestComputation {
               op::While(op::Copy(op::Parameter())));
 }

+TEST_F(CopyInsertionTest, NestedWhileAndConditional2) {
+  const string& hlo_string = R"(
+HloModule TestModule
+
+on_true {
+  v1 = f32[2] parameter(0)
+  v2 = f32[2] add(v1,v1)
+  ROOT t1 = (f32[2], f32[2]) tuple(v1,v2)
+}
+
+on_false {
+  v1 = f32[2] parameter(0)
+  v2 = f32[2] multiply(v1,v1)
+  ROOT t2 = (f32[2], f32[2]) tuple(v1,v2)
+}
+
+cond.outer {
+  param.1 = (pred[], f32[2], f32[2]) parameter(0)
+  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
+}
+
+body.outer {
+  param.1 = (pred[], f32[2], f32[2]) parameter(0)
+  pred.1 = pred[] get-tuple-element(param.1), index=0
+  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
+  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
+  e1 = f32[2] get-tuple-element(if), index=0
+  e2 = f32[2] get-tuple-element(if), index=1
+  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
+}
+
+ENTRY TestComputation {
+  entry_param.1 = pred[] parameter(0)
+  float_param = f32[2] parameter(1)
+  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
+  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  InsertCopies(module.get());
+  VLOG(2) << module->ToString() << "\n";
+
+  // An extra copy must be kept inside the loop due to uses in the conditional.
+  EXPECT_EQ(CountCopies(*module), 3);
+}
+
+TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
+  const string& hlo_string = R"(
+HloModule TestModule
+
+on_true {
+  v1 = f32[2] parameter(0)
+  ROOT v2 = f32[2] add(v1,v1)
+}
+
+on_false {
+  v1 = f32[2] parameter(0)
+  ROOT v2 = f32[2] multiply(v1,v1)
+}
+
+cond.outer {
+  param.1 = (pred[], f32[2]) parameter(0)
+  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
+}
+
+body.outer {
+  param.1 = (pred[], f32[2]) parameter(0)
+  pred.1 = pred[] get-tuple-element(param.1), index=0
+  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
+  if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
+  ROOT res = (pred[], f32[2]) tuple(pred.1,if)
+}
+
+ENTRY TestComputation {
+  entry_param.1 = pred[] parameter(0)
+  float_param = f32[2] parameter(1)
+  entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
+  ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  InsertCopies(module.get());
+  VLOG(2) << module->ToString() << "\n";
+
+  // There should be only two copies inserted: one at the entry and one at the
+  // exit of the computation.
+  EXPECT_EQ(CountCopies(*module), 2);
+}
+
 TEST_F(CopyInsertionTest, FixpointComputationRequired) {
   const string& hlo_string = R"(
 HloModule Module
@@ -2782,5 +2877,72 @@ ENTRY main {
   EXPECT_EQ(CountCopies(*module), 1);
 }

+TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
+  const string& hlo_string = R"(
+HloModule TestModule
+
+on_true.1 {
+  ROOT v1 = f32[2] parameter(0)
+}
+
+on_false.1 {
+  v1 = f32[2] parameter(0)
+  ROOT v2 = f32[2] multiply(v1,v1)
+}
+
+on_true {
+  v1 = f32[2] parameter(0)
+  v2 = f32[2] add(v1,v1)
+  v3 = (f32[2],f32[2]) tuple(v1,v2)
+  v4 = f32[2] get-tuple-element(v3), index=1
+  v5 = f32[2] multiply(v4,v2)
+  ROOT t1 = (f32[2], f32[2]) tuple(v5,v2)
+}
+
+on_false {
+  v1 = f32[2] parameter(0)
+  v2 = f32[2] multiply(v1,v1)
+  pred.1 = pred[] constant(true)
+  v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1
+  v5 = f32[2] multiply(v4,v2)
+  ROOT t2 = (f32[2], f32[2]) tuple(v2,v5)
+
+}
+
+cond.outer {
+  param.1 = (pred[], f32[2], f32[2]) parameter(0)
+  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
+}
+
+body.outer {
+  param.1 = (pred[], f32[2], f32[2]) parameter(0)
+  pred.1 = pred[] get-tuple-element(param.1), index=0
+  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
+  if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
+  e1 = f32[2] get-tuple-element(if), index=0
+  e2 = f32[2] get-tuple-element(if), index=1
+  ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
+}
+
+ENTRY TestComputation {
+  entry_param.1 = pred[] parameter(0)
+  float_param = f32[2] parameter(1)
+  entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
+  ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  InsertCopies(module.get());
+  VLOG(2) << module->ToString() << "\n";
+
+  // An extra copy must be kept inside the loop due to uses in the conditional.
+  EXPECT_EQ(CountCopies(*module), 4);
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index a4804a8faef..5ef592a4aa2 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -34,10 +34,28 @@ namespace xla {

 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
                                  const HloInstruction* b) const {
+  switch (GetExecutionConstraint(a, b)) {
+    case ExecutionConstraint::kIsSame:  // a and b are the same instruction;
+      return false;
+    case ExecutionConstraint::kRunBefore:
+    case ExecutionConstraint::kRunExclusiveBefore:
+      return true;
+    case ExecutionConstraint::kRunExclusiveAfter:
+    case ExecutionConstraint::kRunAfter:
+    case ExecutionConstraint::kUnordered:
+      return false;
+  }
+}
+
+HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
+    const HloInstruction* a, const HloInstruction* b) const {
   // 'a' and 'b' may be in different computations. In this case, find the
   // callgraph ancestor instructions which call (potentially transitively) the
   // computations containing 'a' and 'b' and use these ancestor instructions to
   // compare order.
+  if (a == b) {
+    return ExecutionConstraint::kIsSame;
+  }
   const HloInstruction* a_ancestor;
   const HloInstruction* b_ancestor;
   std::tie(a_ancestor, b_ancestor) =
@@ -45,9 +63,10 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));

   if (a_ancestor == nullptr) {
-    // Ancestors in a common computation could not be found so consider the
-    // instructions 'a' and 'b' to be unordered.
-    return false;
+    VLOG(4) << "Ancestors in a common computation could not be found between "
+            << a->ToString() << "\n and \n"
+            << b->ToString() << "\n so consider them to be unordered.\n";
+    return ExecutionConstraint::kUnordered;
   }
   // a_ancestor and b_ancestor must be either both null or both non-null.
   CHECK_NE(b_ancestor, nullptr);
@@ -62,7 +81,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
     const HloComputation* condition = a_ancestor->while_condition();
     if (call_graph_->InstructionIsNestedIn(a, condition) &&
         call_graph_->InstructionIsNestedIn(b, body)) {
-      return true;
+      return ExecutionConstraint::kRunBefore;
     }
   }

@@ -85,17 +104,40 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
         b_branch = j;
       }
     }
-    if (a_branch != -1 && a_branch < b_branch) {
-      return true;
+    // If neither a nor b is inside the branches, then both must be the
+    // ancestor itself.
+    if (a_branch == -1 && b_branch == -1) {
+      CHECK_EQ(a, a_ancestor);
+      CHECK_EQ(b, b_ancestor);
+      CHECK_EQ(a, b);
+      return ExecutionConstraint::kIsSame;
     }
     // If 'b' is the conditional ancestor, and 'a' is within a branch
     // computation, 'a' executes before 'b'.
-    if (b == a_ancestor && a_branch != -1) {
-      return true;
+    if (b_branch == -1) {
+      CHECK_EQ(b, a_ancestor);
+      return ExecutionConstraint::kRunBefore;
+    }
+    if (a_branch == -1) {
+      CHECK_EQ(a, a_ancestor);
+      return ExecutionConstraint::kRunAfter;
+    }
+    if (a_branch < b_branch) {
+      return ExecutionConstraint::kRunExclusiveBefore;
+    }
+    if (b_branch < a_branch) {
+      return ExecutionConstraint::kRunExclusiveAfter;
     }
   }

-  return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
+  if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
+    return ExecutionConstraint::kRunBefore;
+  }
+  if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
+    return ExecutionConstraint::kRunAfter;
+  }
+  VLOG(1) << "Cannot determine order between: " << a->ToString() << "\n"
+          << "and " << b->ToString() << " which are in the same computation\n";
+  return ExecutionConstraint::kUnordered;
 }

 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
@@ -167,102 +209,169 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
 }

 /* static */
-bool HloOrdering::UseIsBeforeValueDefinition(
-    const HloUse& use, const HloValue& value,
+bool HloOrdering::UsesBeforeValueDefinition(
+    absl::Span<const HloUse* const> uses, const HloValue& value,
     const HloDataflowAnalysis& dataflow) const {
-  VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
-          << ", value=" << value.ToShortString() << ")";
-  if (ExecutesBefore(use.instruction, value.defining_instruction())) {
-    VLOG(4) << "  use instruction executes before value-defining instruction";
-    return true;
-  }
-
-  // If the use is at the instruction where the value is defined, then the use
-  // is before the def if the instruction allows buffer sharing (in place
-  // computation).
-  if (use.instruction == value.defining_instruction() &&
-      dataflow.CanShareOperandBufferWithUser(
-          use.instruction->mutable_operand(use.operand_number),
-          use.operand_index, value.defining_instruction(),
-          value.defining_index())) {
-    VLOG(4) << "  use is value def, and instruction can share use buffer";
-    return true;
-  }
-
-  // The use at a while is an input to a phi, and logically occurs before values
-  // are defined in the body. Note that the use is *not* before the value if the
-  // value is defined in the condition and is not the condition parameter, since
-  // the input of a while's life range is only ended at the start the body.
-  if (use.instruction->opcode() == HloOpcode::kWhile) {
-    const HloInstruction* xla_while = use.instruction;
-    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
-                                           xla_while->while_body())) {
-      VLOG(4) << "  use is while " << use.instruction->name()
-              << " and def is in body";
-      return true;
+  bool has_use_in_exclusive_branches = false;
+  bool has_escaped_use_in_conditional = false;
+  auto UseIsBeforeValueDefinition = [&](const HloUse& use) {
+    VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
+            << ", value=" << value.ToShortString() << ")";
+    switch (
+        GetExecutionConstraint(use.instruction, value.defining_instruction())) {
+      case HloOrdering::ExecutionConstraint::kIsSame:
+        // If the use is at the instruction where the value is defined, then
+        // the use is before the def if the instruction allows buffer sharing
+        // (in place computation).
+        if (dataflow.CanShareOperandBufferWithUser(
+                use.instruction->mutable_operand(use.operand_number),
+                use.operand_index, value.defining_instruction(),
+                value.defining_index())) {
+          VLOG(4)
+              << "  use is value def, and instruction can share use buffer.";
+          return true;
+        }
+        break;
+      case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
+        // If the use is located in a branch that is exclusive to the branch
+        // where the value is located, then for them to interfere there must
+        // be an execution path where the value's definition can reach the
+        // use, so that the wrong value would reach the use if their live
+        // ranges are merged. If there is such a path, it would have to pass
+        // through the point where the two exclusive branches are joined ---
+        // specifically the end of the conditional operation. For the join
+        // point to reach back to the use at the other exclusive branch, there
+        // has to be a surrounding loop, where the result of the conditional
+        // is passed back inside the conditional through one of its
+        // parameters. This use-def conflict between the parameter of a
+        // conditional and one of its branches is caught in the
+        // has_escaped_use_in_conditional variable.
+        VLOG(4) << "  use and value def are in exclusive branches.";
+        if (!has_escaped_use_in_conditional) {
+          has_use_in_exclusive_branches = true;
+          VLOG(4) << "Allowing them to share buffer.\n";
+          return true;
+        }
+        VLOG(4) << "value def has escaped use in conditional. \n";
+        break;
+      case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
+      case HloOrdering::ExecutionConstraint::kRunBefore:
+        VLOG(4)
+            << "  use instruction executes before value-defining instruction";
+        return true;
+      case HloOrdering::ExecutionConstraint::kRunAfter:
+      case HloOrdering::ExecutionConstraint::kUnordered:
+        break;
     }
-    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
-                                           xla_while->while_condition())) {
-      if (value.defining_instruction() !=
-          xla_while->while_condition()->parameter_instruction(0)) {
+
+    // The use at a while is an input to a phi, and logically occurs before
+    // values are defined in the body. Note that the use is *not* before the
+    // value if the value is defined in the condition and is not the condition
+    // parameter, since the input of a while's live range is only ended at the
+    // start of the body.
+    if (use.instruction->opcode() == HloOpcode::kWhile) {
+      const HloInstruction* xla_while = use.instruction;
+      if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+                                             xla_while->while_body())) {
        VLOG(4) << "  use is while " << use.instruction->name()
-                << " and def is in condition and is not the parameter";
-        return false;
-      } else {
-        VLOG(4) << "  use is while " << use.instruction->name()
-                << " and def is in condition and is the parameter";
+                << " and def is in body";
+        return true;
+      }
+      if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+                                             xla_while->while_condition())) {
+        if (value.defining_instruction() !=
+            xla_while->while_condition()->parameter_instruction(0)) {
+          VLOG(4) << "  use is while " << use.instruction->name()
+                  << " and def is in condition and is not the parameter";
+          return false;
+        } else {
+          VLOG(4) << "  use is while " << use.instruction->name()
+                  << " and def is in condition and is the parameter";
+          return true;
+        }
+      }
+    }
+    // Similarly if the value is defined at a while, it logically occurs after
+    // any uses in the body or condition computations.
+    if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
+      CHECK(value.is_phi());
+      const HloInstruction* xla_while = value.defining_instruction();
+      if (call_graph_->InstructionIsNestedIn(use.instruction,
+                                             xla_while->while_body()) ||
+          call_graph_->InstructionIsNestedIn(use.instruction,
+                                             xla_while->while_condition())) {
+        VLOG(4) << "  value is while " << value.defining_instruction()->name()
+                << " and use is in condition or body";
         return true;
       }
     }
-  }
-
-  // Similarly if the value is defined at a while, it logically occurs after any
-  // uses in the body or condition computations.
-  if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
-    CHECK(value.is_phi());
-    const HloInstruction* xla_while = value.defining_instruction();
-    if (call_graph_->InstructionIsNestedIn(use.instruction,
-                                           xla_while->while_body()) ||
-        call_graph_->InstructionIsNestedIn(use.instruction,
-                                           xla_while->while_condition())) {
-      VLOG(4) << "  value is while " << value.defining_instruction()->name()
-              << " and use is in condition or body";
-      return true;
-    }
-  }
-
-  // The use at a call occurs before values that are defined in the called
-  // computation.
-  if (use.instruction->opcode() == HloOpcode::kCall) {
-    const HloInstruction* call = use.instruction;
-    if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
-                                           call->to_apply())) {
-      VLOG(4) << "  use is call " << use.instruction->name()
-              << " and def is in called computation";
-      return true;
-    }
-  }
-
-  if (use.instruction->opcode() == HloOpcode::kConditional) {
-    const HloInstruction* conditional = use.instruction;
-    for (int j = 0; j < conditional->branch_count(); ++j) {
-      if (call_graph_->InstructionIsNestedIn(
-              value.defining_instruction(),
-              conditional->branch_computation(j))) {
-        VLOG(4) << "  use is conditional " << use.instruction->name()
-                << " and def is in " << j << "th branch computation";
+    // The use at a call occurs before values that are defined in the called
+    // computation.
+    if (use.instruction->opcode() == HloOpcode::kCall) {
+      const HloInstruction* call = use.instruction;
+      if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+                                             call->to_apply())) {
+        VLOG(4) << "  use is call " << use.instruction->name()
+                << " and def is in called computation";
         return true;
       }
     }
-    if (value.defining_instruction() == use.instruction) {
-      VLOG(4) << "  use is conditional " << use << " and def is "
-              << value.ToShortString();
-      return true;
+    if (use.instruction->opcode() == HloOpcode::kConditional) {
+      const HloInstruction* conditional = use.instruction;
+      // In general the use of a value in the conditional parameter should be
+      // considered to be before a definition in one of its branches, and
+      // therefore allowed in live range merging, if there is no
+      // surrounding loop that creates a backward control flow path that
+      // allows the definition in the branch to have its value flow backward
+      // into the conditional and then flow into another branch in the
+      // conditional that uses the value. This is reflected by checking that
+      // the use-def in exclusive branches has not already been allowed.
+      // Further, if the def value escapes its branch, we conservatively
+      // assume a backward control flow path could exist, and set
+      // has_escaped_use_in_conditional to disallow any later uses in
+      // exclusive branches.
+      for (int j = 0; j < conditional->branch_count(); ++j) {
+        if (call_graph_->InstructionIsNestedIn(
+                value.defining_instruction(),
+                conditional->branch_computation(j))) {
+          // If the use operand does not create a new value, and the value def
+          // is returned as part of the result of the conditional, it
+          // is possible for the branch definition to flow backward through a
+          // surrounding loop and then back into the conditional parameter.
+          if (!dataflow.ValueIsDefinedAt(
+                  use.instruction->operand(use.operand_number), {})) {
+            for (auto value_use : value.uses()) {
+              VLOG(4) << "def has use: " << value_use << "\n";
+              if (value_use.instruction ==
+                  value_use.instruction->parent()->root_instruction()) {
+                VLOG(4) << "def use is conditional root \n";
+                has_escaped_use_in_conditional = true;
+                break;
+              }
+            }
+          }
+          if (!has_use_in_exclusive_branches) {
+            VLOG(4) << "  use is conditional " << use.instruction->name()
+                    << " and def is in " << j << "th branch computation";
+            return true;
+          }
+        }
+      }
+      if (value.defining_instruction() == use.instruction) {
+        VLOG(4) << "  use is conditional " << use << " and def is "
+                << value.ToShortString();
+        return true;
+      }
+    }
+
+    VLOG(4) << "  use is not before value definition";
+    return false;
+  };
+  for (auto* use : uses) {
+    if (!UseIsBeforeValueDefinition(*use)) {
+      return false;
     }
   }
-
-  VLOG(4) << "  use is not before value";
-  return false;
+  return true;
 }

 bool HloOrdering::LiveRangeStrictlyBefore(
@@ -270,6 +379,7 @@ bool HloOrdering::LiveRangeStrictlyBefore(
     const HloDataflowAnalysis& dataflow) const {
   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
           << ", b = " << b.ToShortString() << ")";
+  VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n";
   if (!IsDefinedBefore(a, b)) {
     VLOG(4) << a << " not defined before " << b;
     return false;
@@ -294,16 +404,17 @@ bool HloOrdering::LiveRangeStrictlyBefore(
   }

   // All uses of 'a' must be before 'b' is defined.
+  std::vector<const HloUse*> uses;
   for (const HloUse& use : a.uses()) {
     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
                                          use.instruction)) {
       continue;
     }
-    if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
-      VLOG(4) << "use of " << a << " (" << use << ") not before " << b
-              << " is defined";
-      return false;
-    }
+    uses.push_back(&use);
+  }
+  if (!UsesBeforeValueDefinition(uses, b, dataflow)) {
+    VLOG(4) << "uses of " << a << " not before " << b << " is defined";
+    return false;
   }

   if (a.instruction()->parent() == b.instruction()->parent()) {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index a07214c22c0..1075a78275f 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -37,10 +37,30 @@ namespace xla {
 // determine live range overlap of HLO instruction output buffers.
 class HloOrdering {
  public:
-  HloOrdering(const HloModule* module)
+  explicit HloOrdering(const HloModule* module)
       : module_(module), call_graph_(CallGraph::Build(module)) {}
   virtual ~HloOrdering() = default;

+  // Specify the ordering constraints between a pair of instructions a and b.
+  enum class ExecutionConstraint {
+    // Indicate a and b are the same instruction;
+    kIsSame,
+    // Indicate a runs before b;
+    kRunBefore,
+    // Only one of a or b runs each time their common ancestor is evaluated,
+    // and a is in an earlier branch than b.
+    kRunExclusiveBefore,
+    // Only one of a or b runs each time, and a is in a later branch than b.
+    kRunExclusiveAfter,
+    // Indicate a runs after b.
+    kRunAfter,
+    // An order cannot be determined as a and b do not have a common ancestor.
+    kUnordered,
+  };
+  // Return the execution constraint between a and b.
+  HloOrdering::ExecutionConstraint GetExecutionConstraint(
+      const HloInstruction* a, const HloInstruction* b) const;
+
   // Returns true if instruction 'a' executes before instruction 'b'. This is
   // not reflexive, that is, an instruction does not execute before itself.
   bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const;
@@ -51,8 +71,9 @@ class HloOrdering {

   // Returns whether the given use is before the given value definition under
   // the given ordering.
-  bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
-                                  const HloDataflowAnalysis& dataflow) const;
+  bool UsesBeforeValueDefinition(absl::Span<const HloUse* const> uses,
+                                 const HloValue& value,
+                                 const HloDataflowAnalysis& dataflow) const;

   // Returns whether the given values interfere. Two values interfere if they
   // may both be simultaneously live.
   bool MayInterfere(const HloValue& a, const HloValue& b,
@@ -181,8 +202,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
 // interference is reduced relative to DependencyHloOrdering.
 class SequentialHloOrdering : public HloOrdering {
  public:
-  SequentialHloOrdering(const HloSchedule& schedule);
-  SequentialHloOrdering(HloSchedule&& schedule);
+  explicit SequentialHloOrdering(const HloSchedule& schedule);
+  explicit SequentialHloOrdering(HloSchedule&& schedule);
   ~SequentialHloOrdering() override = default;

   // Returns the sequential instruction order for the given computation.
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index f8295d579fb..f1e4e0dda89 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -282,10 +282,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
                                       dataflow->GetValueDefinedAt(add)));

   ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
-  const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
-  EXPECT_EQ(while_use.instruction, add);
-  EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
-      while_use, dataflow->GetValueDefinedAt(add), *dataflow));
+  const HloUse* while_use = &dataflow->GetValueDefinedAt(xla_while).uses()[0];
+  EXPECT_EQ(while_use->instruction, add);
+  EXPECT_TRUE(ordering.UsesBeforeValueDefinition(
+      {&while_use, 1}, dataflow->GetValueDefinedAt(add), *dataflow));
   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
       dataflow->GetValueDefinedAt(xla_while),
       dataflow->GetValueDefinedAt(add), *dataflow));
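A note on calling the new API: UsesBeforeValueDefinition takes an
absl::Span<const HloUse* const>, and the patch itself builds that span in two
ways. The sketch below shows both call shapes; 'ordering', 'a', 'b', and
'dataflow' are assumed to be set up as in ValuesInWhileComputations above (an
HloOrdering, two HloValues, and an HloDataflowAnalysis):

    // 1) From a container, as LiveRangeStrictlyBefore does: collect pointers
    //    to the relevant uses and pass the vector; it converts implicitly.
    std::vector<const HloUse*> uses;
    for (const HloUse& use : a.uses()) {
      uses.push_back(&use);
    }
    bool all_before = ordering.UsesBeforeValueDefinition(uses, b, dataflow);

    // 2) From a single pointer, as the updated test does: a (pointer, length)
    //    pair also constructs an absl::Span of one element.
    const HloUse* one_use = &a.uses()[0];
    bool one_before =
        ordering.UsesBeforeValueDefinition({&one_use, 1}, b, dataflow);

Passing all of a value's uses in one call lets the implementation carry state
across uses (the has_use_in_exclusive_branches and
has_escaped_use_in_conditional flags), which the old per-use predicate
UseIsBeforeValueDefinition could not do.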