[XLA] Automated g4 rollback of changelist 352846603.
*** Reason for rollback *** fix breakage due to transitive property of liverange analysis after rollback Original CL description: [XLA] remove extraneous copies in copy_insertion related to nested conditionals and while loops. The change increases the precision of LiveRangeBefore analysis inside copy_insertion to accommodate disjoint branches inside conditionals that never overlap. The breakage is due to the fact that when we allow def-use values that are in exclusive conditional branches to share buffers, the LiveRangeBefore relation is no longer transitive. In particular, suppose op_a's live range is before that of op_b, and live range of ob_b is before that of op_c, we may not have live range of op_a before op_c, because op_a and op_c may be in the same branch and overlapping with each other. This is fixed by modifying copy_insertion.cc to check all related HloValues without assuming they are ordered. This will lengthen the compilation time a bit, but because the number of copy instructions removed are fairly limited, the cost should be negligible. PiperOrigin-RevId: 353953760 Change-Id: Ia110e1a13047bf1d3dec37668bbe21fb10b47a5f
This commit is contained in:
parent
a4c269747a
commit
d96f268804
@ -734,10 +734,19 @@ class CopyRemover {
|
||||
// {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
|
||||
//
|
||||
// Removing the copy eliminates d_0, and uses of d_0 become uses of
|
||||
// s_x. In the above ordering, the live range of d_m must be ordered
|
||||
// s_x. In the above ordering, the live range of d_m will be ordered
|
||||
// before the live range of s_{x+1} and the definition and all uses of
|
||||
// s_x must be ordered before the definition of d_1. These conditions
|
||||
// are checked below prior to elision.
|
||||
// s_x will be ordered before the definition of d_1. To make sure the
|
||||
// copy elision is safe, the following code checks that this ordering is
|
||||
// valid --- in particular we check it is safe to order d_m ahead of all
|
||||
// the liverages at and after x_{x+1}, and it is safe to order all uses
|
||||
// of s_x before the definition of d_1, by checking the live range
|
||||
// constraints for each pair --- we cannot skip the later checks because
|
||||
// the live range ordering is not guranteed to be transitive --- while it
|
||||
// may be ok to have lr_1 before lr_2, and lr_2 before lv_3 while merging
|
||||
// their buffers, it may not be ok to merge the buffers of lr_1 and lv_3,
|
||||
// because the exclusiveness relation of non-overlapping computations is
|
||||
// not transitive.
|
||||
//
|
||||
// ** Technically it might be possible to have a non-interfering
|
||||
// non-trivial interleaving of the values of the source and
|
||||
@ -747,8 +756,8 @@ class CopyRemover {
|
||||
// buffer (d_1 through d_m) are spliced into the point where the copy
|
||||
// used to be.
|
||||
VLOG(2) << copy->name() << " defines the first value in its buffer";
|
||||
ValueNode* next_dest = Next(*dest);
|
||||
if (next_dest != nullptr) {
|
||||
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
|
||||
next_dest = Next(*next_dest)) {
|
||||
// Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
|
||||
if (!LiveRangeBefore(*src, *next_dest)) {
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
@ -757,9 +766,8 @@ class CopyRemover {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ValueNode* next_src = Next(*src);
|
||||
|
||||
if (next_src != nullptr) {
|
||||
for (ValueNode* next_src = Next(*src); next_src != nullptr;
|
||||
next_src = Next(*next_src)) {
|
||||
// Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
|
||||
ValueNode* last_dest = dest->prev;
|
||||
DCHECK(IsTail(*last_dest));
|
||||
@ -790,20 +798,21 @@ class CopyRemover {
|
||||
VLOG(2) << copy->name() << " copies the last value ("
|
||||
<< src->value->ToShortString() << ") in its buffer";
|
||||
|
||||
ValueNode* prev_dest = Prev(*dest);
|
||||
// nullptr condition handled above in the first 'if' case.
|
||||
DCHECK(prev_dest != nullptr);
|
||||
ValueNode* first_src = src->next;
|
||||
DCHECK(IsHead(*first_src));
|
||||
if (!LiveRangeBefore(*prev_dest, *first_src)) {
|
||||
// Live range of value d_{y-1} is not before s_0.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< prev_dest->value->ToShortString() << " is not before "
|
||||
<< first_src->value->ToShortString();
|
||||
return false;
|
||||
for (ValueNode* prev_dest = Prev(*dest);
|
||||
// nullptr condition handled above in the first 'if' case.
|
||||
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
|
||||
if (!LiveRangeBefore(*prev_dest, *first_src)) {
|
||||
// Live range of value d_{y-1} is not before s_0.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
<< prev_dest->value->ToShortString() << " is not before "
|
||||
<< first_src->value->ToShortString();
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ValueNode* next_dest = Next(*dest);
|
||||
if (next_dest != nullptr) {
|
||||
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
|
||||
next_dest = Next(*next_dest)) {
|
||||
if (!LiveRangeBefore(*src, *next_dest)) {
|
||||
// Live range of value s_n is not before d_{y+1}.
|
||||
VLOG(2) << "Not removing the copy: live range of "
|
||||
@ -814,7 +823,7 @@ class CopyRemover {
|
||||
}
|
||||
|
||||
// Splice source buffer values list right after 'prev_dest'.
|
||||
SpliceAfter(first_src, prev_dest);
|
||||
SpliceAfter(first_src, Prev(*dest));
|
||||
} else {
|
||||
VLOG(2) << copy->name()
|
||||
<< " copies value in middle of source buffer to value in middle "
|
||||
@ -880,9 +889,7 @@ class CopyRemover {
|
||||
VLOG(2) << "Empty uses for " << *a.value;
|
||||
return ordering_.IsDefinedBefore(*a.value, *b.value);
|
||||
}
|
||||
return absl::c_all_of(a.uses, [&](const HloUse* use) {
|
||||
return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_);
|
||||
});
|
||||
return ordering_.UsesBeforeValueDefinition(a.uses, *b.value, dataflow_);
|
||||
}
|
||||
|
||||
// Returns whether 'node' is the last node in its list.
|
||||
|
||||
@ -2473,6 +2473,101 @@ ENTRY TestComputation {
|
||||
op::While(op::Copy(op::Parameter())));
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, NestedWhileAndConditional2) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule TestModule
|
||||
|
||||
on_true
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
v2 = f32[2] add(v1,v1)
|
||||
ROOT t1 = (f32[2], f32[2]) tuple(v1,v2)
|
||||
}
|
||||
|
||||
on_false
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
v2 = f32[2] multiply(v1,v1)
|
||||
ROOT t2 = (f32[2], f32[2]) tuple(v1,v2)
|
||||
}
|
||||
|
||||
cond.outer {
|
||||
param.1 = (pred[], f32[2], f32[2]) parameter(0)
|
||||
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
|
||||
}
|
||||
|
||||
body.outer {
|
||||
param.1 = (pred[], f32[2], f32[2]) parameter(0)
|
||||
pred.1 = pred[] get-tuple-element(param.1), index=0
|
||||
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
|
||||
if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
|
||||
e1 = f32[2] get-tuple-element(if), index=0
|
||||
e2 = f32[2] get-tuple-element(if), index=1
|
||||
ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
|
||||
}
|
||||
|
||||
ENTRY TestComputation {
|
||||
entry_param.1 = pred[] parameter(0)
|
||||
float_param = f32[2] parameter(1)
|
||||
entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
|
||||
ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
VLOG(2) << module->ToString() << "\n";
|
||||
|
||||
// An extra copy must be kept inside the loop due to uses in the conditional.
|
||||
EXPECT_EQ(CountCopies(*module), 3);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule TestModule
|
||||
|
||||
on_true
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
ROOT v2 = f32[2] add(v1,v1)
|
||||
}
|
||||
|
||||
on_false
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
ROOT v2 = f32[2] multiply(v1,v1)
|
||||
}
|
||||
|
||||
cond.outer {
|
||||
param.1 = (pred[], f32[2]) parameter(0)
|
||||
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
|
||||
}
|
||||
|
||||
body.outer {
|
||||
param.1 = (pred[], f32[2]) parameter(0)
|
||||
pred.1 = pred[] get-tuple-element(param.1), index=0
|
||||
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
|
||||
if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
|
||||
ROOT res = (pred[], f32[2]) tuple(pred.1,if)
|
||||
}
|
||||
|
||||
ENTRY TestComputation {
|
||||
entry_param.1 = pred[] parameter(0)
|
||||
float_param = f32[2] parameter(1)
|
||||
entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
|
||||
ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
VLOG(2) << module->ToString() << "\n";
|
||||
|
||||
// There should only be two copies inserted, and in the entry and exit of the
|
||||
// computation.
|
||||
EXPECT_EQ(CountCopies(*module), 2);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule Module
|
||||
@ -2782,5 +2877,72 @@ ENTRY main {
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule TestModule
|
||||
|
||||
on_true.1
|
||||
{
|
||||
ROOT v1 = f32[2] parameter(0)
|
||||
}
|
||||
|
||||
on_false.1
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
ROOT v2 = f32[2] multiply(v1,v1)
|
||||
}
|
||||
|
||||
on_true
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
v2 = f32[2] add(v1,v1)
|
||||
v3 = (f32[2],f32[2]) tuple(v1,v2)
|
||||
v4 = f32[2] get-tuple-element(v3), index=1
|
||||
v5 = f32[2] multiply(v4,v2)
|
||||
ROOT t1 = (f32[2], f32[2]) tuple(v5,v2)
|
||||
}
|
||||
|
||||
on_false
|
||||
{
|
||||
v1 = f32[2] parameter(0)
|
||||
v2 = f32[2] multiply(v1,v1)
|
||||
pred.1 = pred[] constant(true)
|
||||
v4 = f32[2] conditional(pred.1, v1, v2), true_computation=on_true.1, false_computation=on_false.1
|
||||
v5 = f32[2] multiply(v4,v2)
|
||||
ROOT t2 = (f32[2], f32[2]) tuple(v2,v5)
|
||||
|
||||
}
|
||||
|
||||
cond.outer {
|
||||
param.1 = (pred[], f32[2], f32[2]) parameter(0)
|
||||
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
|
||||
}
|
||||
|
||||
body.outer {
|
||||
param.1 = (pred[], f32[2], f32[2]) parameter(0)
|
||||
pred.1 = pred[] get-tuple-element(param.1), index=0
|
||||
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
|
||||
if = (f32[2], f32[2]) conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
|
||||
e1 = f32[2] get-tuple-element(if), index=0
|
||||
e2 = f32[2] get-tuple-element(if), index=1
|
||||
ROOT res = (pred[], f32[2], f32[2]) tuple(pred.1,e1, e2)
|
||||
}
|
||||
|
||||
ENTRY TestComputation {
|
||||
entry_param.1 = pred[] parameter(0)
|
||||
float_param = f32[2] parameter(1)
|
||||
entry_param = (pred[], f32[2], f32[2]) tuple(entry_param.1, float_param, float_param)
|
||||
ROOT while = (pred[], f32[2], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
VLOG(2) << module->ToString() << "\n";
|
||||
|
||||
// An extra copy must be kept inside the loop due to uses in the conditional
|
||||
EXPECT_EQ(CountCopies(*module), 4);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
@ -34,10 +34,28 @@ namespace xla {
|
||||
|
||||
bool HloOrdering::ExecutesBefore(const HloInstruction* a,
|
||||
const HloInstruction* b) const {
|
||||
switch (GetExecutionConstraint(a, b)) {
|
||||
case ExecutionConstraint::kIsSame: // a and b are the same instruction;
|
||||
return false;
|
||||
case ExecutionConstraint::kRunBefore:
|
||||
case ExecutionConstraint::kRunExclusiveBefore:
|
||||
return true;
|
||||
case ExecutionConstraint::kRunExclusiveAfter:
|
||||
case ExecutionConstraint::kRunAfter:
|
||||
case ExecutionConstraint::kUnordered:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
|
||||
const HloInstruction* a, const HloInstruction* b) const {
|
||||
// 'a' and 'b' may be in different computations. In this case, find the
|
||||
// callgraph ancestor instructions which call (potentially transitively) the
|
||||
// computations containing 'a' and 'b' and use these ancestor instructions to
|
||||
// compare order.
|
||||
if (a == b) {
|
||||
return ExecutionConstraint::kIsSame;
|
||||
}
|
||||
const HloInstruction* a_ancestor;
|
||||
const HloInstruction* b_ancestor;
|
||||
std::tie(a_ancestor, b_ancestor) =
|
||||
@ -45,9 +63,10 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
|
||||
const_cast<HloInstruction*>(a), const_cast<HloInstruction*>(b));
|
||||
|
||||
if (a_ancestor == nullptr) {
|
||||
// Ancestors in a common computation could not be found so consider the
|
||||
// instructions 'a' and 'b' to be unordered.
|
||||
return false;
|
||||
VLOG(4) << "Ancestors in a common computation could not be found between"
|
||||
<< a->ToString() << "\n and \n"
|
||||
<< b->ToString() << "\n so consider them to be unordered.\n";
|
||||
return ExecutionConstraint::kUnordered;
|
||||
}
|
||||
// a_ancestor and b_ancestor must be either both null or both non-null.
|
||||
CHECK_NE(b_ancestor, nullptr);
|
||||
@ -62,7 +81,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
|
||||
const HloComputation* condition = a_ancestor->while_condition();
|
||||
if (call_graph_->InstructionIsNestedIn(a, condition) &&
|
||||
call_graph_->InstructionIsNestedIn(b, body)) {
|
||||
return true;
|
||||
return ExecutionConstraint::kRunBefore;
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,17 +104,40 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
|
||||
b_branch = j;
|
||||
}
|
||||
}
|
||||
if (a_branch != -1 && a_branch < b_branch) {
|
||||
return true;
|
||||
// If neither a nor b is inside the branches they both are the ancestor.
|
||||
if (a_branch == -1 && b_branch == -1) {
|
||||
CHECK_EQ(a, a_ancestor);
|
||||
CHECK_EQ(b, b_ancestor);
|
||||
CHECK_EQ(a, b);
|
||||
return ExecutionConstraint::kIsSame;
|
||||
}
|
||||
// If 'b' is the conditional ancestor, and 'a' is within a branch
|
||||
// computation, 'a' executes before 'b'.
|
||||
if (b == a_ancestor && a_branch != -1) {
|
||||
return true;
|
||||
if (b_branch == -1) {
|
||||
CHECK_EQ(b, a_ancestor);
|
||||
return ExecutionConstraint::kRunBefore;
|
||||
}
|
||||
if (a_branch == -1) {
|
||||
CHECK_EQ(a, a_ancestor);
|
||||
return ExecutionConstraint::kRunAfter;
|
||||
}
|
||||
if (a_branch < b_branch) {
|
||||
return ExecutionConstraint::kRunExclusiveBefore;
|
||||
}
|
||||
if (b_branch < a_branch) {
|
||||
return ExecutionConstraint::kRunExclusiveAfter;
|
||||
}
|
||||
}
|
||||
|
||||
return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
|
||||
if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
|
||||
return ExecutionConstraint::kRunBefore;
|
||||
}
|
||||
if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
|
||||
return ExecutionConstraint::kRunAfter;
|
||||
}
|
||||
VLOG(1) << "Cannot determine order between:" << a->ToString() << "\n"
|
||||
<< "and " << b->ToString() << " which are in the same computation\n";
|
||||
return ExecutionConstraint::kUnordered;
|
||||
}
|
||||
|
||||
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
|
||||
@ -167,102 +209,169 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
|
||||
}
|
||||
|
||||
/* static */
|
||||
bool HloOrdering::UseIsBeforeValueDefinition(
|
||||
const HloUse& use, const HloValue& value,
|
||||
bool HloOrdering::UsesBeforeValueDefinition(
|
||||
absl::Span<const HloUse* const> uses, const HloValue& value,
|
||||
const HloDataflowAnalysis& dataflow) const {
|
||||
VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
|
||||
<< ", value=" << value.ToShortString() << ")";
|
||||
if (ExecutesBefore(use.instruction, value.defining_instruction())) {
|
||||
VLOG(4) << " use instruction executes before value-defining instruction";
|
||||
return true;
|
||||
}
|
||||
|
||||
// If the use is at the instruction where the value is defined, then the use
|
||||
// is before the def if the instruction allows buffer sharing (in place
|
||||
// computation).
|
||||
if (use.instruction == value.defining_instruction() &&
|
||||
dataflow.CanShareOperandBufferWithUser(
|
||||
use.instruction->mutable_operand(use.operand_number),
|
||||
use.operand_index, value.defining_instruction(),
|
||||
value.defining_index())) {
|
||||
VLOG(4) << " use is value def, and instruction can share use buffer";
|
||||
return true;
|
||||
}
|
||||
|
||||
// The use at a while is an input to a phi, and logically occurs before values
|
||||
// are defined in the body. Note that the use is *not* before the value if the
|
||||
// value is defined in the condition and is not the condition parameter, since
|
||||
// the input of a while's life range is only ended at the start the body.
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
const HloInstruction* xla_while = use.instruction;
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
xla_while->while_body())) {
|
||||
VLOG(4) << " use is while " << use.instruction->name()
|
||||
<< " and def is in body";
|
||||
return true;
|
||||
bool has_use_in_exclusive_branches = false;
|
||||
bool has_escaped_use_in_conditional = false;
|
||||
auto UseIsBeforeValueDefinition = [&](const HloUse& use) {
|
||||
VLOG(4) << "UseIsBeforeValueDefinition(use=" << use
|
||||
<< ", value=" << value.ToShortString() << ")";
|
||||
switch (
|
||||
GetExecutionConstraint(use.instruction, value.defining_instruction())) {
|
||||
case HloOrdering::ExecutionConstraint::kIsSame:
|
||||
// If the use is at the instruction where the value is defined, then the
|
||||
// use is before the def if the instruction allows buffer sharing (in
|
||||
// place computation).
|
||||
if (dataflow.CanShareOperandBufferWithUser(
|
||||
use.instruction->mutable_operand(use.operand_number),
|
||||
use.operand_index, value.defining_instruction(),
|
||||
value.defining_index())) {
|
||||
VLOG(4)
|
||||
<< " use is value def, and instruction can share use buffer.";
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
|
||||
// If the use is located in a branch that is exclusive to the branch
|
||||
// where value is located, in order for them to interfere, there must be
|
||||
// an execution path where the value's definition can reach the use, so
|
||||
// that the wrong value would reach use if their live ranges are merged.
|
||||
// If there is such a path, it would have to pass through the point
|
||||
// where the two exclusive branches are joined --- specifically the end
|
||||
// of the conditional operation. For the join point to reach back to the
|
||||
// use at the other exclusive branch, there has to be a be a surrounding
|
||||
// loop, where the result of the conditional is passed back inside the
|
||||
// conditional through one of its parameters. This use-def conflict
|
||||
// between the parameter of a conditional and one of its branches is
|
||||
// caught in the has_escaped_use_in_conditinoal variable.
|
||||
VLOG(4) << " use and value def are in exclusive branches.";
|
||||
if (!has_escaped_use_in_conditional) {
|
||||
has_use_in_exclusive_branches = true;
|
||||
VLOG(4) << "Allowing them to share buffer.\n";
|
||||
return true;
|
||||
}
|
||||
VLOG(4) << "value def has escaped use in conditional. \n";
|
||||
break;
|
||||
case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
|
||||
case HloOrdering::ExecutionConstraint::kRunBefore:
|
||||
VLOG(4)
|
||||
<< " use instruction executes before value-defining instruction";
|
||||
return true;
|
||||
case HloOrdering::ExecutionConstraint::kRunAfter:
|
||||
case HloOrdering::ExecutionConstraint::kUnordered:
|
||||
break;
|
||||
}
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
xla_while->while_condition())) {
|
||||
if (value.defining_instruction() !=
|
||||
xla_while->while_condition()->parameter_instruction(0)) {
|
||||
|
||||
// The use at a while is an input to a phi, and logically occurs before
|
||||
// values are defined in the body. Note that the use is *not* before the
|
||||
// value if the value is defined in the condition and is not the condition
|
||||
// parameter, since the input of a while's live range is only ended at the
|
||||
// start the body.
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
const HloInstruction* xla_while = use.instruction;
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
xla_while->while_body())) {
|
||||
VLOG(4) << " use is while " << use.instruction->name()
|
||||
<< " and def is in condition and is not the parameter";
|
||||
return false;
|
||||
} else {
|
||||
VLOG(4) << " use is while " << use.instruction->name()
|
||||
<< " and def is in condition and is the parameter";
|
||||
<< " and def is in body";
|
||||
return true;
|
||||
}
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
xla_while->while_condition())) {
|
||||
if (value.defining_instruction() !=
|
||||
xla_while->while_condition()->parameter_instruction(0)) {
|
||||
VLOG(4) << " use is while " << use.instruction->name()
|
||||
<< " and def is in condition and is not the parameter";
|
||||
return false;
|
||||
} else {
|
||||
VLOG(4) << " use is while " << use.instruction->name()
|
||||
<< " and def is in condition and is the parameter";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Similarly if the value is defined at a while, it logically occurs after
|
||||
// any uses in the body or condition computations.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
|
||||
CHECK(value.is_phi());
|
||||
const HloInstruction* xla_while = value.defining_instruction();
|
||||
if (call_graph_->InstructionIsNestedIn(use.instruction,
|
||||
xla_while->while_body()) ||
|
||||
call_graph_->InstructionIsNestedIn(use.instruction,
|
||||
xla_while->while_condition())) {
|
||||
VLOG(4) << " value is while " << value.defining_instruction()->name()
|
||||
<< " and use is in condition or body";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Similarly if the value is defined at a while, it logically occurs after any
|
||||
// uses in the body or condition computations.
|
||||
if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
|
||||
CHECK(value.is_phi());
|
||||
const HloInstruction* xla_while = value.defining_instruction();
|
||||
if (call_graph_->InstructionIsNestedIn(use.instruction,
|
||||
xla_while->while_body()) ||
|
||||
call_graph_->InstructionIsNestedIn(use.instruction,
|
||||
xla_while->while_condition())) {
|
||||
VLOG(4) << " value is while " << value.defining_instruction()->name()
|
||||
<< " and use is in condition or body";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// The use at a call occurs before values that are defined in the called
|
||||
// computation.
|
||||
if (use.instruction->opcode() == HloOpcode::kCall) {
|
||||
const HloInstruction* call = use.instruction;
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
call->to_apply())) {
|
||||
VLOG(4) << " use is call " << use.instruction->name()
|
||||
<< " and def is in called computation";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (use.instruction->opcode() == HloOpcode::kConditional) {
|
||||
const HloInstruction* conditional = use.instruction;
|
||||
for (int j = 0; j < conditional->branch_count(); ++j) {
|
||||
if (call_graph_->InstructionIsNestedIn(
|
||||
value.defining_instruction(),
|
||||
conditional->branch_computation(j))) {
|
||||
VLOG(4) << " use is conditional " << use.instruction->name()
|
||||
<< " and def is in " << j << "th branch computation";
|
||||
// The use at a call occurs before values that are defined in the called
|
||||
// computation.
|
||||
if (use.instruction->opcode() == HloOpcode::kCall) {
|
||||
const HloInstruction* call = use.instruction;
|
||||
if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
|
||||
call->to_apply())) {
|
||||
VLOG(4) << " use is call " << use.instruction->name()
|
||||
<< " and def is in called computation";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (value.defining_instruction() == use.instruction) {
|
||||
VLOG(4) << " use is conditional " << use << " and def is "
|
||||
<< value.ToShortString();
|
||||
return true;
|
||||
if (use.instruction->opcode() == HloOpcode::kConditional) {
|
||||
const HloInstruction* conditional = use.instruction;
|
||||
// In general the use of a value in the conditional parameter should be
|
||||
// considered to be before a definition in one of its branches, and
|
||||
// therefore allowed in live range merging, if there is no
|
||||
// surrounding loop that creates a backward control flow path that
|
||||
// allows the definition in the branch to have its value flow backward
|
||||
// into the conditional and then flow into another branch in the
|
||||
// conditional that uses the value. This is reflected by checking that
|
||||
// the use-def in exclusive branches has not been already allowed.
|
||||
// Further, if the def value escapes its branch, we conservatively
|
||||
// assume a backward control flow path could exist, and set
|
||||
// has_escaped_use_in_conditinoal to disallow any later uses in
|
||||
// exclusive branches.
|
||||
for (int j = 0; j < conditional->branch_count(); ++j) {
|
||||
if (call_graph_->InstructionIsNestedIn(
|
||||
value.defining_instruction(),
|
||||
conditional->branch_computation(j))) {
|
||||
// If the use operand does not create a new value, and the value def
|
||||
// is returned by as part of the result of the conditional, it
|
||||
// is possible for the branch definition to flow backward through a
|
||||
// surrounding loop and then back into the conditional parameter.
|
||||
if (!dataflow.ValueIsDefinedAt(
|
||||
use.instruction->operand(use.operand_number), {})) {
|
||||
for (auto value_use : value.uses()) {
|
||||
VLOG(4) << "def have use:" << value_use << "\n";
|
||||
if (value_use.instruction ==
|
||||
value_use.instruction->parent()->root_instruction()) {
|
||||
VLOG(4) << "def use is conditional root \n";
|
||||
has_escaped_use_in_conditional = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!has_use_in_exclusive_branches) {
|
||||
VLOG(4) << " use is conditional " << use.instruction->name()
|
||||
<< " and def is in " << j << "th branch computation";
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (value.defining_instruction() == use.instruction) {
|
||||
VLOG(4) << " use is conditional " << use << " and def is "
|
||||
<< value.ToShortString();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(4) << " use is not before value definition";
|
||||
return false;
|
||||
};
|
||||
for (auto* use : uses) {
|
||||
if (!UseIsBeforeValueDefinition(*use)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(4) << " use is not before value";
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloOrdering::LiveRangeStrictlyBefore(
|
||||
@ -270,6 +379,7 @@ bool HloOrdering::LiveRangeStrictlyBefore(
|
||||
const HloDataflowAnalysis& dataflow) const {
|
||||
VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
|
||||
<< ", b = " << b.ToShortString() << ")";
|
||||
VLOG(4) << "Parent:" << a.instruction()->parent()->ToString() << "\n";
|
||||
if (!IsDefinedBefore(a, b)) {
|
||||
VLOG(4) << a << " not defined before " << b;
|
||||
return false;
|
||||
@ -294,16 +404,17 @@ bool HloOrdering::LiveRangeStrictlyBefore(
|
||||
}
|
||||
|
||||
// All uses of 'a' must be before 'b' is defined.
|
||||
std::vector<const HloUse*> uses;
|
||||
for (const HloUse& use : a.uses()) {
|
||||
if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
|
||||
use.instruction)) {
|
||||
continue;
|
||||
}
|
||||
if (!UseIsBeforeValueDefinition(use, b, dataflow)) {
|
||||
VLOG(4) << "use of " << a << " (" << use << ") not before " << b
|
||||
<< " is defined";
|
||||
return false;
|
||||
}
|
||||
uses.push_back(&use);
|
||||
}
|
||||
if (!UsesBeforeValueDefinition(uses, b, dataflow)) {
|
||||
VLOG(4) << "uses of " << a << "not before " << b << " is defined";
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a.instruction()->parent() == b.instruction()->parent()) {
|
||||
|
||||
@ -37,10 +37,30 @@ namespace xla {
|
||||
// determine live range overlap of HLO instruction output buffers.
|
||||
class HloOrdering {
|
||||
public:
|
||||
HloOrdering(const HloModule* module)
|
||||
explicit HloOrdering(const HloModule* module)
|
||||
: module_(module), call_graph_(CallGraph::Build(module)) {}
|
||||
virtual ~HloOrdering() = default;
|
||||
|
||||
// Specify the ordering constraints between a pair of instructions a and b.
|
||||
enum class ExecutionConstraint {
|
||||
// Indicate a and b are the same instruction;
|
||||
kIsSame,
|
||||
// Indicate a runs before b;
|
||||
kRunBefore,
|
||||
// Only one of a or b runs each time their common ancestor is evaluated,
|
||||
// and a is in an earlier branch than b.
|
||||
kRunExclusiveBefore,
|
||||
// Only one of a or b runs each time, and a is in a later branch than b.
|
||||
kRunExclusiveAfter,
|
||||
// Indicate a runs after b
|
||||
kRunAfter,
|
||||
// An order cannot be detrermined as a and b do not have a common ancestor.
|
||||
kUnordered,
|
||||
};
|
||||
// Return the execution constraint between a and b.
|
||||
HloOrdering::ExecutionConstraint GetExecutionConstraint(
|
||||
const HloInstruction* a, const HloInstruction* b) const;
|
||||
|
||||
// Returns true if instruction 'a' executes before instruction 'b'. This is
|
||||
// not reflexive, that is, an instruction does not execute before itself.
|
||||
bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const;
|
||||
@ -51,8 +71,9 @@ class HloOrdering {
|
||||
|
||||
// Returns whether the given use is before the given value definition under
|
||||
// the given ordering.
|
||||
bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
|
||||
const HloDataflowAnalysis& dataflow) const;
|
||||
bool UsesBeforeValueDefinition(absl::Span<const HloUse* const> uses,
|
||||
const HloValue& value,
|
||||
const HloDataflowAnalysis& dataflow) const;
|
||||
// Returns whether the given values interfere. Two values interfere if they
|
||||
// may both be simultaneously live.
|
||||
bool MayInterfere(const HloValue& a, const HloValue& b,
|
||||
@ -181,8 +202,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
|
||||
// interference is reduced relative to DependencyHloOrdering.
|
||||
class SequentialHloOrdering : public HloOrdering {
|
||||
public:
|
||||
SequentialHloOrdering(const HloSchedule& schedule);
|
||||
SequentialHloOrdering(HloSchedule&& schedule);
|
||||
explicit SequentialHloOrdering(const HloSchedule& schedule);
|
||||
explicit SequentialHloOrdering(HloSchedule&& schedule);
|
||||
~SequentialHloOrdering() override = default;
|
||||
|
||||
// Returns the sequential instruction order for the given computation.
|
||||
|
||||
@ -282,10 +282,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) {
|
||||
dataflow->GetValueDefinedAt(add)));
|
||||
ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
|
||||
|
||||
const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
|
||||
EXPECT_EQ(while_use.instruction, add);
|
||||
EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
|
||||
while_use, dataflow->GetValueDefinedAt(add), *dataflow));
|
||||
const HloUse* while_use = &dataflow->GetValueDefinedAt(xla_while).uses()[0];
|
||||
EXPECT_EQ(while_use->instruction, add);
|
||||
EXPECT_TRUE(ordering.UsesBeforeValueDefinition(
|
||||
{&while_use, 1}, dataflow->GetValueDefinedAt(add), *dataflow));
|
||||
EXPECT_TRUE(ordering.LiveRangeStrictlyBefore(
|
||||
dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add),
|
||||
*dataflow));
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user