[XLA] Automated g4 rollback of changelist 352846603.
*** Reason for rollback ***

Fix breakage due to the transitive property of live range analysis after rollback.

Original CL description:

[XLA] Remove extraneous copies in copy_insertion related to nested conditionals and while loops. The change increases the precision of the LiveRangeBefore analysis inside copy_insertion to accommodate disjoint branches inside conditionals that never overlap.

The breakage is due to the fact that when we allow def-use values in exclusive conditional branches to share buffers, the LiveRangeBefore relation is no longer transitive. In particular, if op_a's live range is before that of op_b, and op_b's live range is before that of op_c, the live range of op_a may still not be before that of op_c, because op_a and op_c may be in the same branch and overlap each other.

This is fixed by modifying copy_insertion.cc to check all related HloValues without assuming they are ordered. This lengthens compilation time a bit, but because the number of copy instructions removed is fairly limited, the cost should be negligible.

PiperOrigin-RevId: 353953760
Change-Id: Ia110e1a13047bf1d3dec37668bbe21fb10b47a5f
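To make the rollback rationale concrete, below is a minimal sketch (hypothetical types and
helper names, not the actual XLA API) of the pairwise checking idea described above: once
values in exclusive conditional branches may share buffers, "live range A ends before live
range B starts" is no longer transitive, so a buffer-merge decision has to test every
(source, destination) pair instead of relying on a chain of comparisons.

    #include <set>
    #include <utility>
    #include <vector>

    // Hypothetical stand-ins for HLO value live ranges and the ordering oracle.
    struct LiveRange { int id; };

    struct Ordering {
      // Pairs (a, b) for which live range a is known to end before b starts.
      std::set<std::pair<int, int>> before;
      bool Before(const LiveRange& a, const LiveRange& b) const {
        return before.count({a.id, b.id}) > 0;
      }
    };

    // Pairwise check: Before(s1, s2) and Before(s2, d) may both hold while
    // Before(s1, d) fails (for example when s1 and d sit in the same
    // conditional branch), so no pair can be skipped.
    bool SafeToMergeBuffers(const std::vector<const LiveRange*>& srcs,
                            const std::vector<const LiveRange*>& dsts,
                            const Ordering& ordering) {
      for (const LiveRange* s : srcs) {
        for (const LiveRange* d : dsts) {
          if (!ordering.Before(*s, *d)) return false;
        }
      }
      return true;
    }

These extra per-pair checks are what lengthen compilation time slightly; the copy_insertion.cc
hunk below implements the same idea directly on the value lists of the source and destination
buffers.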
parent a4c269747a
commit d96f268804
				| @ -734,10 +734,19 @@ class CopyRemover { | ||||
|       // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
 | ||||
|       //
 | ||||
|       // Removing the copy eliminates d_0, and uses of d_0 become uses of
 | ||||
|       // s_x. In the above ordering, the live range of d_m must be ordered
 | ||||
|       // s_x. In the above ordering, the live range of d_m will be ordered
 | ||||
|       // before the live range of s_{x+1} and the definition and all uses of
 | ||||
|       // s_x must be ordered before the definition of d_1. These conditions
 | ||||
|       // are checked below prior to elision.
 | ||||
|       // s_x will be ordered before the definition of d_1. To make sure the
 | ||||
|       // copy elision is safe, the following code checks that this ordering is
 | ||||
|       // valid --- in particular, we check that it is safe to order d_m ahead
 | ||||
|       // of all the live ranges at and after s_{x+1}, and that it is safe to
 | ||||
|       // order all uses of s_x before the definition of d_1, by checking the
 | ||||
|       // live range constraints for each pair --- we cannot skip the later
 | ||||
|       // checks because the live range ordering is not guaranteed to be
 | ||||
|       // transitive --- while it may be ok to have lr_1 before lr_2, and lr_2
 | ||||
|       // before lr_3 while merging their buffers, it may not be ok to merge
 | ||||
|       // the buffers of lr_1 and lr_3, because the exclusiveness relation of
 | ||||
|       // non-overlapping computations is not transitive.
 | ||||
|       //
 | ||||
|       // ** Technically it might be possible to have a non-interfering
 | ||||
|       //    non-trivial interleaving of the values of the source and
 | ||||
| @ -747,8 +756,8 @@ class CopyRemover { | ||||
|       //    buffer (d_1 through d_m) are spliced into the point where the copy
 | ||||
|       //    used to be.
 | ||||
|       VLOG(2) << copy->name() << " defines the first value in its buffer"; | ||||
|       ValueNode* next_dest = Next(*dest); | ||||
|       if (next_dest != nullptr) { | ||||
|       for (ValueNode* next_dest = Next(*dest); next_dest != nullptr; | ||||
|            next_dest = Next(*next_dest)) { | ||||
|         // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
 | ||||
|         if (!LiveRangeBefore(*src, *next_dest)) { | ||||
|           VLOG(2) << "Not removing the copy: live range of " | ||||
| @ -757,9 +766,8 @@ class CopyRemover { | ||||
|           return false; | ||||
|         } | ||||
|       } | ||||
|       ValueNode* next_src = Next(*src); | ||||
| 
 | ||||
|       if (next_src != nullptr) { | ||||
|       for (ValueNode* next_src = Next(*src); next_src != nullptr; | ||||
|            next_src = Next(*next_src)) { | ||||
|         // Live range of 'last_dest' (d_m) must be before 'next_src' (s_{x+1}).
 | ||||
|         ValueNode* last_dest = dest->prev; | ||||
|         DCHECK(IsTail(*last_dest)); | ||||
| @ -790,20 +798,21 @@ class CopyRemover { | ||||
|       VLOG(2) << copy->name() << " copies the last value (" | ||||
|               << src->value->ToShortString() << ") in its buffer"; | ||||
| 
 | ||||
|       ValueNode* prev_dest = Prev(*dest); | ||||
|       // nullptr condition handled above in the first 'if' case.
 | ||||
|       DCHECK(prev_dest != nullptr); | ||||
|       ValueNode* first_src = src->next; | ||||
|       DCHECK(IsHead(*first_src)); | ||||
|       if (!LiveRangeBefore(*prev_dest, *first_src)) { | ||||
|         // Live range of value d_{y-1} is not before s_0.
 | ||||
|         VLOG(2) << "Not removing the copy: live range of " | ||||
|                 << prev_dest->value->ToShortString() << " is not before " | ||||
|                 << first_src->value->ToShortString(); | ||||
|         return false; | ||||
|       for (ValueNode* prev_dest = Prev(*dest); | ||||
|            // nullptr condition handled above in the first 'if' case.
 | ||||
|            prev_dest != nullptr; prev_dest = Prev(*prev_dest)) { | ||||
|         if (!LiveRangeBefore(*prev_dest, *first_src)) { | ||||
|           // Live range of value d_{y-1} is not before s_0.
 | ||||
|           VLOG(2) << "Not removing the copy: live range of " | ||||
|                   << prev_dest->value->ToShortString() << " is not before " | ||||
|                   << first_src->value->ToShortString(); | ||||
|           return false; | ||||
|         } | ||||
|       } | ||||
|       ValueNode* next_dest = Next(*dest); | ||||
|       if (next_dest != nullptr) { | ||||
|       for (ValueNode* next_dest = Next(*dest); next_dest != nullptr; | ||||
|            next_dest = Next(*next_dest)) { | ||||
|         if (!LiveRangeBefore(*src, *next_dest)) { | ||||
|           // Live range of value s_n is not before d_{y+1}.
 | ||||
|           VLOG(2) << "Not removing the copy: live range of " | ||||
| @ -814,7 +823,7 @@ class CopyRemover { | ||||
|       } | ||||
| 
 | ||||
|       // Splice source buffer values list right after 'prev_dest'.
 | ||||
|       SpliceAfter(first_src, prev_dest); | ||||
|       SpliceAfter(first_src, Prev(*dest)); | ||||
|     } else { | ||||
|       VLOG(2) << copy->name() | ||||
|               << " copies value in middle of source buffer to value in middle " | ||||
| @ -880,9 +889,7 @@ class CopyRemover { | ||||
|       VLOG(2) << "Empty uses for " << *a.value; | ||||
|       return ordering_.IsDefinedBefore(*a.value, *b.value); | ||||
|     } | ||||
|     return absl::c_all_of(a.uses, [&](const HloUse* use) { | ||||
|       return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_); | ||||
|     }); | ||||
|     return ordering_.UsesBeforeValueDefinition(a.uses, *b.value, dataflow_); | ||||
|   } | ||||
| 
 | ||||
|   // Returns whether 'node' is the last node in its list.
 | ||||
|  | ||||
| @ -2473,6 +2473,101 @@ ENTRY TestComputation { | ||||
|               op::While(op::Copy(op::Parameter()))); | ||||
| } | ||||
| 
 | ||||
| TEST_F(CopyInsertionTest, NestedWhileAndConditional2) { | ||||
|   const string& hlo_string = R"( | ||||
| HloModule TestModule | ||||
| 
 | ||||
| on_true | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   v2 = f32[2] add(v1,v1) | ||||
|   ROOT t1 = (f32[2], f32[2]) tuple(v1,v2) | ||||
| } | ||||
| 
 | ||||
| on_false | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   v2 = f32[2] multiply(v1,v1) | ||||
|   ROOT t2 = (f32[2], f32[2]) tuple(v1,v2) | ||||
| } | ||||
| 
 | ||||
| cond.outer { | ||||
|   param.1 = (pred[], f32[2], f32[2]) parameter(0) | ||||
|   ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 | ||||
| } | ||||
| 
 | ||||
| body.outer { | ||||
|   param.1 = (pred[], f32[2], f32[2]) parameter(0) | ||||
|   pred.1 = pred[] get-tuple-element(param.1), index=0 | ||||
|   arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 | ||||
|   if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false | ||||
|   e1 = f32[2] get-tuple-element(if), index=0 | ||||
|   e2 = f32[2] get-tuple-element(if), index=1 | ||||
|   ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2) | ||||
| } | ||||
| 
 | ||||
| ENTRY TestComputation { | ||||
|   entry_param.1 = pred[] parameter(0) | ||||
|   float_param = f32[2] parameter(1) | ||||
|   entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param) | ||||
|   ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer | ||||
| } | ||||
| )"; | ||||
|   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, | ||||
|                           ParseAndReturnVerifiedModule(hlo_string)); | ||||
|   InsertCopies(module.get()); | ||||
|   VLOG(2) << module->ToString() << "\n"; | ||||
| 
 | ||||
|   // An extra copy must be kept inside the loop due to uses in the conditional.
 | ||||
|   EXPECT_EQ(CountCopies(*module), 3); | ||||
| } | ||||
| 
 | ||||
| TEST_F(CopyInsertionTest, NestedWhileAndConditional) { | ||||
|   const string& hlo_string = R"( | ||||
| HloModule TestModule | ||||
| 
 | ||||
| on_true | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   ROOT v2 = f32[2] add(v1,v1) | ||||
| } | ||||
| 
 | ||||
| on_false | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   ROOT v2 = f32[2] multiply(v1,v1) | ||||
| } | ||||
| 
 | ||||
| cond.outer { | ||||
|   param.1 = (pred[], f32[2]) parameter(0) | ||||
|   ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 | ||||
| } | ||||
| 
 | ||||
| body.outer { | ||||
|   param.1 = (pred[], f32[2]) parameter(0) | ||||
|   pred.1 = pred[] get-tuple-element(param.1), index=0 | ||||
|   arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 | ||||
|   if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false | ||||
|   ROOT res = (pred[], f32[2]) tuple(pred.1,if) | ||||
| } | ||||
| 
 | ||||
| ENTRY TestComputation { | ||||
|   entry_param.1 = pred[] parameter(0) | ||||
|   float_param = f32[2] parameter(1) | ||||
|   entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param) | ||||
|   ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer | ||||
| } | ||||
| )"; | ||||
|   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, | ||||
|                           ParseAndReturnVerifiedModule(hlo_string)); | ||||
|   InsertCopies(module.get()); | ||||
|   VLOG(2) << module->ToString() << "\n"; | ||||
| 
 | ||||
|   // Only two copies should be inserted: one at the entry and one at the exit
 | ||||
|   // of the computation.
 | ||||
|   EXPECT_EQ(CountCopies(*module), 2); | ||||
| } | ||||
| 
 | ||||
| TEST_F(CopyInsertionTest, FixpointComputationRequired) { | ||||
|   const string& hlo_string = R"( | ||||
| HloModule Module | ||||
| @ -2782,5 +2877,72 @@ ENTRY main { | ||||
|   EXPECT_EQ(CountCopies(*module), 1); | ||||
| } | ||||
| 
 | ||||
| TEST_F(CopyInsertionTest, NestedWhileAndConditional3) { | ||||
|   const string& hlo_string = R"( | ||||
| HloModule TestModule | ||||
| 
 | ||||
| on_true.1 | ||||
|  { | ||||
|   ROOT v1 = f32[2] parameter(0) | ||||
| } | ||||
| 
 | ||||
| on_false.1 | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   ROOT v2 = f32[2] multiply(v1,v1) | ||||
| } | ||||
| 
 | ||||
| on_true | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   v2 = f32[2] add(v1,v1) | ||||
|   v3 = (f32[2],f32[2]) tuple(v1,v2) | ||||
|   v4 = f32[2] get-tuple-element(v3), index=1 | ||||
|   v5 = f32[2] multiply(v4,v2) | ||||
|    ROOT t1 = (f32[2], f32[2]) tuple(v5,v2) | ||||
| } | ||||
| 
 | ||||
| on_false | ||||
|  { | ||||
|   v1 = f32[2] parameter(0) | ||||
|   v2 = f32[2] multiply(v1,v1) | ||||
|   pred.1 = pred[] constant(true) | ||||
|   v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1 | ||||
|   v5 = f32[2] multiply(v4,v2) | ||||
|   ROOT t2 = (f32[2], f32[2]) tuple(v2,v5) | ||||
|    | ||||
| } | ||||
| 
 | ||||
| cond.outer { | ||||
|   param.1 = (pred[], f32[2], f32[2]) parameter(0) | ||||
|   ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 | ||||
| } | ||||
| 
 | ||||
| body.outer { | ||||
|   param.1 = (pred[], f32[2], f32[2]) parameter(0) | ||||
|   pred.1 = pred[] get-tuple-element(param.1), index=0 | ||||
|   arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 | ||||
|   if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false | ||||
|   e1 = f32[2] get-tuple-element(if), index=0 | ||||
|   e2 = f32[2] get-tuple-element(if), index=1 | ||||
|   ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2) | ||||
| } | ||||
| 
 | ||||
| ENTRY TestComputation { | ||||
|   entry_param.1 = pred[] parameter(0) | ||||
|   float_param = f32[2] parameter(1) | ||||
|   entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param) | ||||
|   ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer | ||||
| } | ||||
| )"; | ||||
|   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, | ||||
|                           ParseAndReturnVerifiedModule(hlo_string)); | ||||
|   InsertCopies(module.get()); | ||||
|   VLOG(2) << module->ToString() << "\n"; | ||||
| 
 | ||||
|   // An extra copy must be kept inside the loop due to uses in the conditional
 | ||||
|   EXPECT_EQ(CountCopies(*module), 4); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace xla
 | ||||
|  | ||||
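The hunks below rework HloOrdering: ExecutesBefore is re-expressed on top of a new
GetExecutionConstraint query, and the per-use UseIsBeforeValueDefinition becomes the batched
UsesBeforeValueDefinition so that state about exclusive conditional branches can be carried
across all uses of a value. As a rough illustration only (ShareDecision and MayShareBuffers
are hypothetical names, not part of this change; ExecutionConstraint is the enum declared in
hlo_ordering.h further down), a client could interpret the constraint kinds roughly like this:

    #include "tensorflow/compiler/xla/service/hlo_ordering.h"

    // Sketch of how the constraint kinds map onto buffer-sharing decisions.
    enum class ShareDecision {
      kYes,
      kNo,
      kOnlyIfInPlace,                  // same instruction, in-place op
      kOnlyIfDefDoesNotEscapeBranch,   // exclusive branches, see below
    };

    ShareDecision MayShareBuffers(xla::HloOrdering::ExecutionConstraint c) {
      using EC = xla::HloOrdering::ExecutionConstraint;
      switch (c) {
        case EC::kRunBefore:
        case EC::kRunExclusiveBefore:
          return ShareDecision::kYes;
        case EC::kIsSame:
          // Sharing at the defining instruction itself is allowed only for
          // in-place operations (CanShareOperandBufferWithUser in the real code).
          return ShareDecision::kOnlyIfInPlace;
        case EC::kRunExclusiveAfter:
          // Exclusive branches never run together, but a surrounding loop can
          // feed the conditional's result back into its parameter, so sharing
          // is only safe when the definition does not escape its branch (see
          // UsesBeforeValueDefinition below).
          return ShareDecision::kOnlyIfDefDoesNotEscapeBranch;
        case EC::kRunAfter:
        case EC::kUnordered:
          return ShareDecision::kNo;
      }
      return ShareDecision::kNo;
    }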
| @ -34,10 +34,28 @@ namespace xla { | ||||
| 
 | ||||
| bool HloOrdering::ExecutesBefore(const HloInstruction* a, | ||||
|                                  const HloInstruction* b) const { | ||||
|   switch (GetExecutionConstraint(a, b)) { | ||||
|     case ExecutionConstraint::kIsSame:  // a and b are the same instruction;
 | ||||
|       return false; | ||||
|     case ExecutionConstraint::kRunBefore: | ||||
|     case ExecutionConstraint::kRunExclusiveBefore: | ||||
|       return true; | ||||
|     case ExecutionConstraint::kRunExclusiveAfter: | ||||
|     case ExecutionConstraint::kRunAfter: | ||||
|     case ExecutionConstraint::kUnordered: | ||||
|       return false; | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint( | ||||
|     const HloInstruction* a, const HloInstruction* b) const { | ||||
|   // 'a' and 'b' may be in different computations. In this case, find the
 | ||||
|   // callgraph ancestor instructions which call (potentially transitively) the
 | ||||
|   // computations containing 'a' and 'b' and use these ancestor instructions to
 | ||||
|   // compare order.
 | ||||
|   if (a == b) { | ||||
|     return ExecutionConstraint::kIsSame; | ||||
|   } | ||||
|   const HloInstruction* a_ancestor; | ||||
|   const HloInstruction* b_ancestor; | ||||
|   std::tie(a_ancestor, b_ancestor) = | ||||
| @ -45,9 +63,10 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, | ||||
|           const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b)); | ||||
| 
 | ||||
|   if (a_ancestor == nullptr) { | ||||
|     // Ancestors in a common computation could not be found so consider the
 | ||||
|     // instructions 'a' and 'b' to be unordered.
 | ||||
|     return false; | ||||
|     VLOG(4) << "Ancestors in a common computation could not be found between " | ||||
|             << a->ToString() << "\n and \n" | ||||
|             << b->ToString() << "\n so consider them to be unordered.\n"; | ||||
|     return ExecutionConstraint::kUnordered; | ||||
|   } | ||||
|   // a_ancestor and b_ancestor must be either both null or both non-null.
 | ||||
|   CHECK_NE(b_ancestor, nullptr); | ||||
| @ -62,7 +81,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, | ||||
|     const HloComputation* condition = a_ancestor->while_condition(); | ||||
|     if (call_graph_->InstructionIsNestedIn(a, condition) && | ||||
|         call_graph_->InstructionIsNestedIn(b, body)) { | ||||
|       return true; | ||||
|       return ExecutionConstraint::kRunBefore; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
| @ -85,17 +104,40 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, | ||||
|         b_branch = j; | ||||
|       } | ||||
|     } | ||||
|     if (a_branch != -1 && a_branch < b_branch) { | ||||
|       return true; | ||||
|     // If neither a nor b is inside a branch, they both must be the ancestor.
 | ||||
|     if (a_branch == -1 && b_branch == -1) { | ||||
|       CHECK_EQ(a, a_ancestor); | ||||
|       CHECK_EQ(b, b_ancestor); | ||||
|       CHECK_EQ(a, b); | ||||
|       return ExecutionConstraint::kIsSame; | ||||
|     } | ||||
|     // If 'b' is the conditional ancestor, and 'a' is within a branch
 | ||||
|     // computation, 'a' executes before 'b'.
 | ||||
|     if (b == a_ancestor && a_branch != -1) { | ||||
|       return true; | ||||
|     if (b_branch == -1) { | ||||
|       CHECK_EQ(b, a_ancestor); | ||||
|       return ExecutionConstraint::kRunBefore; | ||||
|     } | ||||
|     if (a_branch == -1) { | ||||
|       CHECK_EQ(a, a_ancestor); | ||||
|       return ExecutionConstraint::kRunAfter; | ||||
|     } | ||||
|     if (a_branch < b_branch) { | ||||
|       return ExecutionConstraint::kRunExclusiveBefore; | ||||
|     } | ||||
|     if (b_branch < a_branch) { | ||||
|       return ExecutionConstraint::kRunExclusiveAfter; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); | ||||
|   if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) { | ||||
|     return ExecutionConstraint::kRunBefore; | ||||
|   } | ||||
|   if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) { | ||||
|     return ExecutionConstraint::kRunAfter; | ||||
|   } | ||||
|   VLOG(1) << "Cannot determine order between:" << a->ToString() << "\n" | ||||
|           << "and " << b->ToString() << " which are in the same computation\n"; | ||||
|   return ExecutionConstraint::kUnordered; | ||||
| } | ||||
| 
 | ||||
| bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { | ||||
| @ -167,102 +209,169 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { | ||||
| } | ||||
| 
 | ||||
| /* static */ | ||||
| bool HloOrdering::UseIsBeforeValueDefinition( | ||||
|     const HloUse& use, const HloValue& value, | ||||
| bool HloOrdering::UsesBeforeValueDefinition( | ||||
|     absl::Span<const HloUse* const> uses, const HloValue& value, | ||||
|     const HloDataflowAnalysis& dataflow) const { | ||||
|   VLOG(4) << "UseIsBeforeValueDefinition(use=" << use | ||||
|           << ", value=" << value.ToShortString() << ")"; | ||||
|   if (ExecutesBefore(use.instruction, value.defining_instruction())) { | ||||
|     VLOG(4) << "  use instruction executes before value-defining instruction"; | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   // If the use is at the instruction where the value is defined, then the use
 | ||||
|   // is before the def if the instruction allows buffer sharing (in place
 | ||||
|   // computation).
 | ||||
|   if (use.instruction == value.defining_instruction() && | ||||
|       dataflow.CanShareOperandBufferWithUser( | ||||
|           use.instruction->mutable_operand(use.operand_number), | ||||
|           use.operand_index, value.defining_instruction(), | ||||
|           value.defining_index())) { | ||||
|     VLOG(4) << "  use is value def, and instruction can share use buffer"; | ||||
|     return true; | ||||
|   } | ||||
| 
 | ||||
|   // The use at a while is an input to a phi, and logically occurs before values
 | ||||
|   // are defined in the body. Note that the use is *not* before the value if the
 | ||||
|   // value is defined in the condition and is not the condition parameter, since
 | ||||
|   // the live range of a while's input only ends at the start of the body.
 | ||||
|   if (use.instruction->opcode() == HloOpcode::kWhile) { | ||||
|     const HloInstruction* xla_while = use.instruction; | ||||
|     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                            xla_while->while_body())) { | ||||
|       VLOG(4) << "  use is while " << use.instruction->name() | ||||
|               << " and def is in body"; | ||||
|       return true; | ||||
|   bool has_use_in_exclusive_branches = false; | ||||
|   bool has_escaped_use_in_conditional = false; | ||||
|   auto UseIsBeforeValueDefinition = [&](const HloUse& use) { | ||||
|     VLOG(4) << "UseIsBeforeValueDefinition(use=" << use | ||||
|             << ", value=" << value.ToShortString() << ")"; | ||||
|     switch ( | ||||
|         GetExecutionConstraint(use.instruction, value.defining_instruction())) { | ||||
|       case HloOrdering::ExecutionConstraint::kIsSame: | ||||
|         // If the use is at the instruction where the value is defined, then the
 | ||||
|         // use is before the def if the instruction allows buffer sharing (in
 | ||||
|         // place computation).
 | ||||
|         if (dataflow.CanShareOperandBufferWithUser( | ||||
|                 use.instruction->mutable_operand(use.operand_number), | ||||
|                 use.operand_index, value.defining_instruction(), | ||||
|                 value.defining_index())) { | ||||
|           VLOG(4) | ||||
|               << "  use is value def, and instruction can share use buffer."; | ||||
|           return true; | ||||
|         } | ||||
|         break; | ||||
|       case HloOrdering::ExecutionConstraint::kRunExclusiveAfter: | ||||
|         // If the use is located in a branch that is exclusive to the branch
 | ||||
|         // where value is located, in order for them to interfere, there must be
 | ||||
|         // an execution path where the value's definition can reach the use, so
 | ||||
|         // that the wrong value would reach the use if their live ranges are
 | ||||
|         // merged. If there is such a path, it would have to pass through the
 | ||||
|         // point where the two exclusive branches are joined --- specifically
 | ||||
|         // the end of the conditional operation. For the join point to reach
 | ||||
|         // back to the use at the other exclusive branch, there has to be a
 | ||||
|         // surrounding loop, where the result of the conditional is passed back
 | ||||
|         // into the conditional through one of its parameters. This use-def
 | ||||
|         // conflict between the parameter of a conditional and one of its
 | ||||
|         // branches is caught in the has_escaped_use_in_conditional variable.
 | ||||
|         VLOG(4) << " use and value def are in exclusive branches."; | ||||
|         if (!has_escaped_use_in_conditional) { | ||||
|           has_use_in_exclusive_branches = true; | ||||
|           VLOG(4) << "Allowing them to share buffer.\n"; | ||||
|           return true; | ||||
|         } | ||||
|         VLOG(4) << "value def has escaped use in conditional. \n"; | ||||
|         break; | ||||
|       case HloOrdering::ExecutionConstraint::kRunExclusiveBefore: | ||||
|       case HloOrdering::ExecutionConstraint::kRunBefore: | ||||
|         VLOG(4) | ||||
|             << "  use instruction executes before value-defining instruction"; | ||||
|         return true; | ||||
|       case HloOrdering::ExecutionConstraint::kRunAfter: | ||||
|       case HloOrdering::ExecutionConstraint::kUnordered: | ||||
|         break; | ||||
|     } | ||||
|     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                            xla_while->while_condition())) { | ||||
|       if (value.defining_instruction() != | ||||
|           xla_while->while_condition()->parameter_instruction(0)) { | ||||
| 
 | ||||
|     // The use at a while is an input to a phi, and logically occurs before
 | ||||
|     // values are defined in the body. Note that the use is *not* before the
 | ||||
|     // value if the value is defined in the condition and is not the condition
 | ||||
|     // parameter, since the live range of a while's input only ends at the
 | ||||
|     // start of the body.
 | ||||
|     if (use.instruction->opcode() == HloOpcode::kWhile) { | ||||
|       const HloInstruction* xla_while = use.instruction; | ||||
|       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                              xla_while->while_body())) { | ||||
|         VLOG(4) << "  use is while " << use.instruction->name() | ||||
|                 << " and def is in condition and is not the parameter"; | ||||
|         return false; | ||||
|       } else { | ||||
|         VLOG(4) << "  use is while " << use.instruction->name() | ||||
|                 << " and def is in condition and is the parameter"; | ||||
|                 << " and def is in body"; | ||||
|         return true; | ||||
|       } | ||||
|       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                              xla_while->while_condition())) { | ||||
|         if (value.defining_instruction() != | ||||
|             xla_while->while_condition()->parameter_instruction(0)) { | ||||
|           VLOG(4) << "  use is while " << use.instruction->name() | ||||
|                   << " and def is in condition and is not the parameter"; | ||||
|           return false; | ||||
|         } else { | ||||
|           VLOG(4) << "  use is while " << use.instruction->name() | ||||
|                   << " and def is in condition and is the parameter"; | ||||
|           return true; | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     // Similarly if the value is defined at a while, it logically occurs after
 | ||||
|     // any uses in the body or condition computations.
 | ||||
|     if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { | ||||
|       CHECK(value.is_phi()); | ||||
|       const HloInstruction* xla_while = value.defining_instruction(); | ||||
|       if (call_graph_->InstructionIsNestedIn(use.instruction, | ||||
|                                              xla_while->while_body()) || | ||||
|           call_graph_->InstructionIsNestedIn(use.instruction, | ||||
|                                              xla_while->while_condition())) { | ||||
|         VLOG(4) << "  value is while " << value.defining_instruction()->name() | ||||
|                 << " and use is in condition or body"; | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // Similarly if the value is defined at a while, it logically occurs after any
 | ||||
|   // uses in the body or condition computations.
 | ||||
|   if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { | ||||
|     CHECK(value.is_phi()); | ||||
|     const HloInstruction* xla_while = value.defining_instruction(); | ||||
|     if (call_graph_->InstructionIsNestedIn(use.instruction, | ||||
|                                            xla_while->while_body()) || | ||||
|         call_graph_->InstructionIsNestedIn(use.instruction, | ||||
|                                            xla_while->while_condition())) { | ||||
|       VLOG(4) << "  value is while " << value.defining_instruction()->name() | ||||
|               << " and use is in condition or body"; | ||||
|       return true; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   // The use at a call occurs before values that are defined in the called
 | ||||
|   // computation.
 | ||||
|   if (use.instruction->opcode() == HloOpcode::kCall) { | ||||
|     const HloInstruction* call = use.instruction; | ||||
|     if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                            call->to_apply())) { | ||||
|       VLOG(4) << "  use is call " << use.instruction->name() | ||||
|               << " and def is in called computation"; | ||||
|       return true; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   if (use.instruction->opcode() == HloOpcode::kConditional) { | ||||
|     const HloInstruction* conditional = use.instruction; | ||||
|     for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|       if (call_graph_->InstructionIsNestedIn( | ||||
|               value.defining_instruction(), | ||||
|               conditional->branch_computation(j))) { | ||||
|         VLOG(4) << "  use is conditional " << use.instruction->name() | ||||
|                 << " and def is in " << j << "th branch computation"; | ||||
|     // The use at a call occurs before values that are defined in the called
 | ||||
|     // computation.
 | ||||
|     if (use.instruction->opcode() == HloOpcode::kCall) { | ||||
|       const HloInstruction* call = use.instruction; | ||||
|       if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), | ||||
|                                              call->to_apply())) { | ||||
|         VLOG(4) << "  use is call " << use.instruction->name() | ||||
|                 << " and def is in called computation"; | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
|     if (value.defining_instruction() == use.instruction) { | ||||
|       VLOG(4) << "  use is conditional " << use << " and def is " | ||||
|               << value.ToShortString(); | ||||
|       return true; | ||||
|     if (use.instruction->opcode() == HloOpcode::kConditional) { | ||||
|       const HloInstruction* conditional = use.instruction; | ||||
|       // In general the use of a value in the conditional parameter should be
 | ||||
|       // considered to be before a definition in one of its branches, and
 | ||||
|       // therefore allowed in live range merging, if there is no
 | ||||
|       // surrounding loop that creates a backward control flow path that
 | ||||
|       // allows the definition in the branch to have its value flow backward
 | ||||
|       // into the conditional and then flow into another branch in the
 | ||||
|       // conditional that uses the value. This is reflected by checking that
 | ||||
|       // the use-def in exclusive branches has not already been allowed.
 | ||||
|       // Further, if the def value escapes its branch, we conservatively
 | ||||
|       // assume a backward control flow path could exist, and set
 | ||||
|       // has_escaped_use_in_conditional to disallow any later uses in
 | ||||
|       // exclusive branches.
 | ||||
|       for (int j = 0; j < conditional->branch_count(); ++j) { | ||||
|         if (call_graph_->InstructionIsNestedIn( | ||||
|                 value.defining_instruction(), | ||||
|                 conditional->branch_computation(j))) { | ||||
|           // If the use operand does not create a new value, and the value def
 | ||||
|           // is returned as part of the result of the conditional, it
 | ||||
|           // is possible for the branch definition to flow backward through a
 | ||||
|           // surrounding loop and then back into the conditional parameter.
 | ||||
|           if (!dataflow.ValueIsDefinedAt( | ||||
|                   use.instruction->operand(use.operand_number), {})) { | ||||
|             for (auto value_use : value.uses()) { | ||||
|               VLOG(4) << "def have use:" << value_use << "\n"; | ||||
|               if (value_use.instruction == | ||||
|                   value_use.instruction->parent()->root_instruction()) { | ||||
|                 VLOG(4) << "def use is conditional root \n"; | ||||
|                 has_escaped_use_in_conditional = true; | ||||
|                 break; | ||||
|               } | ||||
|             } | ||||
|           } | ||||
|           if (!has_use_in_exclusive_branches) { | ||||
|             VLOG(4) << "  use is conditional " << use.instruction->name() | ||||
|                     << " and def is in " << j << "th branch computation"; | ||||
|             return true; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       if (value.defining_instruction() == use.instruction) { | ||||
|         VLOG(4) << "  use is conditional " << use << " and def is " | ||||
|                 << value.ToShortString(); | ||||
|         return true; | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     VLOG(4) << "  use is not before value definition"; | ||||
|     return false; | ||||
|   }; | ||||
|   for (auto* use : uses) { | ||||
|     if (!UseIsBeforeValueDefinition(*use)) { | ||||
|       return false; | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   VLOG(4) << "  use is not before value"; | ||||
|   return false; | ||||
|   return true; | ||||
| } | ||||
| 
 | ||||
| bool HloOrdering::LiveRangeStrictlyBefore( | ||||
| @ -270,6 +379,7 @@ bool HloOrdering::LiveRangeStrictlyBefore( | ||||
|     const HloDataflowAnalysis& dataflow) const { | ||||
|   VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() | ||||
|           << ", b = " << b.ToShortString() << ")"; | ||||
|   VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n"; | ||||
|   if (!IsDefinedBefore(a, b)) { | ||||
|     VLOG(4) << a << " not defined before " << b; | ||||
|     return false; | ||||
| @ -294,16 +404,17 @@ bool HloOrdering::LiveRangeStrictlyBefore( | ||||
|   } | ||||
| 
 | ||||
|   // All uses of 'a' must be before 'b' is defined.
 | ||||
|   std::vector<const HloUse*> uses; | ||||
|   for (const HloUse& use : a.uses()) { | ||||
|     if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), | ||||
|                                          use.instruction)) { | ||||
|       continue; | ||||
|     } | ||||
|     if (!UseIsBeforeValueDefinition(use, b, dataflow)) { | ||||
|       VLOG(4) << "use of " << a << " (" << use << ") not before " << b | ||||
|               << " is defined"; | ||||
|       return false; | ||||
|     } | ||||
|     uses.push_back(&use); | ||||
|   } | ||||
|   if (!UsesBeforeValueDefinition(uses, b, dataflow)) { | ||||
|     VLOG(4) << "uses of " << a << " not before " << b << " is defined"; | ||||
|     return false; | ||||
|   } | ||||
| 
 | ||||
|   if (a.instruction()->parent() == b.instruction()->parent()) { | ||||
|  | ||||
| @ -37,10 +37,30 @@ namespace xla { | ||||
| // determine live range overlap of HLO instruction output buffers.
 | ||||
| class HloOrdering { | ||||
|  public: | ||||
|   HloOrdering(const HloModule* module) | ||||
|   explicit HloOrdering(const HloModule* module) | ||||
|       : module_(module), call_graph_(CallGraph::Build(module)) {} | ||||
|   virtual ~HloOrdering() = default; | ||||
| 
 | ||||
|   // Specify the ordering constraints between a pair of instructions a and b.
 | ||||
|   enum class ExecutionConstraint { | ||||
|     // Indicates that a and b are the same instruction.
 | ||||
|     kIsSame, | ||||
|     // Indicates that a runs before b.
 | ||||
|     kRunBefore, | ||||
|     // Only one of a or b runs each time their common ancestor is evaluated,
 | ||||
|     // and a is in an earlier branch than b.
 | ||||
|     kRunExclusiveBefore, | ||||
|     // Only one of a or b runs each time, and a is in a later branch than b.
 | ||||
|     kRunExclusiveAfter, | ||||
|     // Indicates that a runs after b.
 | ||||
|     kRunAfter, | ||||
|     // An order cannot be determined as a and b do not have a common ancestor.
 | ||||
|     kUnordered, | ||||
|   }; | ||||
|   // Returns the execution constraint between a and b.
 | ||||
|   HloOrdering::ExecutionConstraint GetExecutionConstraint( | ||||
|       const HloInstruction* a, const HloInstruction* b) const; | ||||
| 
 | ||||
|   // Returns true if instruction 'a' executes before instruction 'b'. This is
 | ||||
|   // not reflexive, that is, an instruction does not execute before itself.
 | ||||
|   bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; | ||||
| @ -51,8 +71,9 @@ class HloOrdering { | ||||
| 
 | ||||
|   // Returns whether all of the given uses are before the given value
 | ||||
|   // definition under the given ordering.
 | ||||
|   bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value, | ||||
|                                   const HloDataflowAnalysis& dataflow) const; | ||||
|   bool UsesBeforeValueDefinition(absl::Span<const HloUse* const> uses, | ||||
|                                  const HloValue& value, | ||||
|                                  const HloDataflowAnalysis& dataflow) const; | ||||
|   // Returns whether the given values interfere. Two values interfere if they
 | ||||
|   // may both be simultaneously live.
 | ||||
|   bool MayInterfere(const HloValue& a, const HloValue& b, | ||||
| @ -181,8 +202,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { | ||||
| // interference is reduced relative to DependencyHloOrdering.
 | ||||
| class SequentialHloOrdering : public HloOrdering { | ||||
|  public: | ||||
|   SequentialHloOrdering(const HloSchedule& schedule); | ||||
|   SequentialHloOrdering(HloSchedule&& schedule); | ||||
|   explicit SequentialHloOrdering(const HloSchedule& schedule); | ||||
|   explicit SequentialHloOrdering(HloSchedule&& schedule); | ||||
|   ~SequentialHloOrdering() override = default; | ||||
| 
 | ||||
|   // Returns the sequential instruction order for the given computation.
 | ||||
|  | ||||
| @ -282,10 +282,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { | ||||
|                                        dataflow->GetValueDefinedAt(add))); | ||||
|   ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1); | ||||
| 
 | ||||
|   const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0]; | ||||
|   EXPECT_EQ(while_use.instruction, add); | ||||
|   EXPECT_TRUE(ordering.UseIsBeforeValueDefinition( | ||||
|       while_use, dataflow->GetValueDefinedAt(add), *dataflow)); | ||||
|   const HloUse* while_use = &dataflow->GetValueDefinedAt(xla_while).uses()[0]; | ||||
|   EXPECT_EQ(while_use->instruction, add); | ||||
|   EXPECT_TRUE(ordering.UsesBeforeValueDefinition( | ||||
|       {&while_use, 1}, dataflow->GetValueDefinedAt(add), *dataflow)); | ||||
|   EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( | ||||
|       dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add), | ||||
|       *dataflow)); | ||||
|  | ||||