[XLA] Remove extraneous copies in copy_insertion related to nested conditionals and while loops. The change increases the precision of the LiveRangeBefore analysis inside copy_insertion to account for disjoint branches of a conditional, whose executions are mutually exclusive and therefore never overlap.

PiperOrigin-RevId: 347677692
Change-Id: I1fa3de590042078b7a6dde4c4bf8227c4c6416bb
A. Unique TensorFlower 2020-12-15 13:12:12 -08:00 committed by TensorFlower Gardener
parent bb4a8a8b49
commit 849bcce3b0
4 changed files with 124 additions and 11 deletions
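To make the intent concrete, here is a minimal sketch of how a client of the new HloOrdering::GetExecutionConstraint query can recognize instructions that sit in disjoint branches of the same conditional. The helper name DefinedInMutuallyExclusiveBranches is hypothetical and not part of this commit; the includes assume the usual tensorflow/compiler/xla/service locations.

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"

namespace xla {

// Hypothetical helper: returns true when 'a' and 'b' live in different
// branches of the same conditional, so only one of them runs per evaluation
// of that conditional.
bool DefinedInMutuallyExclusiveBranches(const HloOrdering& ordering,
                                        const HloInstruction* a,
                                        const HloInstruction* b) {
  switch (ordering.GetExecutionConstraint(a, b)) {
    case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
    case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
      // Only one of the two instructions runs per evaluation of their common
      // conditional ancestor, so their values are never live together within
      // that evaluation.
      return true;
    default:
      // Same instruction, ordered, or unordered: fall back to the usual
      // live-range analysis.
      return false;
  }
}

}  // namespace xla

CopyRemover::LiveRangeBefore in the first hunk below applies the same idea when deciding whether a copy can be removed.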

tensorflow/compiler/xla/service/copy_insertion.cc

@@ -881,7 +881,16 @@ class CopyRemover {
return ordering_.IsDefinedBefore(*a.value, *b.value);
}
return absl::c_all_of(a.uses, [&](const HloUse* use) {
return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_);
switch (ordering_.GetExecutionConstraint(use->instruction,
b.value->instruction())) {
case HloOrdering::ExecutionConstraint::kIsSame:
case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
return true;
default:
return ordering_.UseIsBeforeValueDefinition(*use, *b.value,
dataflow_);
}
});
}

tensorflow/compiler/xla/service/copy_insertion_test.cc

@@ -2473,6 +2473,52 @@ ENTRY TestComputation {
op::While(op::Copy(op::Parameter())));
}
TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
const string& hlo_string = R"(
HloModule TestModule
on_true
{
v1 = f32[2] parameter(0)
ROOT v2 = f32[2] add(v1,v1)
}
on_false
{
v1 = f32[2] parameter(0)
ROOT v2 = f32[2] multiply(v1,v1)
}
cond.outer {
param.1 = (pred[], f32[2]) parameter(0)
ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
}
body.outer {
param.1 = (pred[], f32[2]) parameter(0)
pred.1 = pred[] get-tuple-element(param.1), index=0
arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
ROOT res = (pred[], f32[2]) tuple(pred.1,if)
}
ENTRY TestComputation {
entry_param.1 = pred[] parameter(0)
float_param = f32[2] parameter(1)
entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
VLOG(2) << module->ToString() << "\n";
// Copies are inserted only in the entry computation (for the while init);
// no extra copy is needed inside the loop body or the conditional branches.
EXPECT_EQ(CountCopies(*module), 2);
}
TEST_F(CopyInsertionTest, FixpointComputationRequired) {
const string& hlo_string = R"(
HloModule Module

tensorflow/compiler/xla/service/hlo_ordering.cc

@@ -34,6 +34,21 @@ namespace xla {
bool HloOrdering::ExecutesBefore(const HloInstruction* a,
const HloInstruction* b) const {
switch (GetExecutionConstraint(a, b)) {
case ExecutionConstraint::kIsSame: // a and b are the same instruction;
return false;
case ExecutionConstraint::kRunBefore:
case ExecutionConstraint::kRunExclusiveBefore:
return true;
case ExecutionConstraint::kRunExclusiveAfter:
case ExecutionConstraint::kRunAfter:
case ExecutionConstraint::kUnordered:
return false;
}
}
HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
const HloInstruction* a, const HloInstruction* b) const {
// 'a' and 'b' may be in different computations. In this case, find the
// callgraph ancestor instructions which call (potentially transitively) the
// computations containing 'a' and 'b' and use these ancestor instructions to
@@ -47,7 +62,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
if (a_ancestor == nullptr) {
// Ancestors in a common computation could not be found so consider the
// instructions 'a' and 'b' to be unordered.
return false;
return ExecutionConstraint::kUnordered;
}
// a_ancestor and b_ancestor must be either both null or both non-null.
CHECK_NE(b_ancestor, nullptr);
@@ -62,7 +77,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
const HloComputation* condition = a_ancestor->while_condition();
if (call_graph_->InstructionIsNestedIn(a, condition) &&
call_graph_->InstructionIsNestedIn(b, body)) {
return true;
return ExecutionConstraint::kRunBefore;
}
}
@@ -85,17 +100,40 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
b_branch = j;
}
}
if (a_branch != -1 && a_branch < b_branch) {
return true;
// If neither 'a' nor 'b' is inside a branch computation, both must be the conditional ancestor itself.
if (a_branch == -1 && b_branch == -1) {
CHECK_EQ(a, a_ancestor);
CHECK_EQ(b, b_ancestor);
CHECK_EQ(a, b);
return ExecutionConstraint::kIsSame;
}
// If one of 'a' and 'b' is the conditional ancestor and the other is inside
// a branch computation, the branch instruction executes before the ancestor.
if (b == a_ancestor && a_branch != -1) {
return true;
if (b_branch == -1) {
CHECK_EQ(b, a_ancestor);
return ExecutionConstraint::kRunBefore;
}
if (a_branch == -1) {
CHECK_EQ(a, a_ancestor);
return ExecutionConstraint::kRunAfter;
}
if (a_branch < b_branch) {
return ExecutionConstraint::kRunExclusiveBefore;
}
if (b_branch < a_branch) {
return ExecutionConstraint::kRunExclusiveAfter;
}
}
return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
return ExecutionConstraint::kRunBefore;
}
if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
return ExecutionConstraint::kRunAfter;
}
VLOG(1) << "Cannot determine order between:" << a_ancestor->ToString() << "\n"
<< "and " << b_ancestor->ToString() << "\n";
return ExecutionConstraint::kUnordered;
}
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {

tensorflow/compiler/xla/service/hlo_ordering.h

@@ -37,10 +37,30 @@ namespace xla {
// determine live range overlap of HLO instruction output buffers.
class HloOrdering {
public:
HloOrdering(const HloModule* module)
explicit HloOrdering(const HloModule* module)
: module_(module), call_graph_(CallGraph::Build(module)) {}
virtual ~HloOrdering() = default;
// Specify the ordering constraints between a pair of instructions a and b.
enum class ExecutionConstraint {
// Indicates that a and b are the same instruction.
kIsSame,
// Indicates that a runs before b.
kRunBefore,
// Only one of a or b runs each time their common ancestor is evaluated,
// and a is in an earlier branch than b.
kRunExclusiveBefore,
// Only one of a or b runs each time, and a is in a later branch than b.
kRunExclusiveAfter,
// Indicates that a runs after b.
kRunAfter,
// An order cannot be determined as a and b do not have a common ancestor.
kUnordered,
};
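// For intuition, a sketch of the constraints this query is expected to report
// for a simple two-branch conditional; the instruction names %cond, %t, and %f
// below are hypothetical and not part of this change:
//   %cond = conditional(%pred, %x, %y), true_computation=T, false_computation=F
//   with %t an instruction inside T and %f an instruction inside F:
//     GetExecutionConstraint(%cond, %cond) -> kIsSame
//     GetExecutionConstraint(%t, %cond)    -> kRunBefore (branch code finishes before its conditional)
//     GetExecutionConstraint(%cond, %t)    -> kRunAfter
//     GetExecutionConstraint(%t, %f)       -> kRunExclusiveBefore (T is the earlier branch)
//     GetExecutionConstraint(%f, %t)       -> kRunExclusiveAfter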
// Return the execution constraint between a and b.
HloOrdering::ExecutionConstraint GetExecutionConstraint(
const HloInstruction* a, const HloInstruction* b) const;
// Returns true if instruction 'a' executes before instruction 'b'. This is
// not reflexive, that is, an instruction does not execute before itself.
bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const;
@@ -181,8 +201,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
// interference is reduced relative to DependencyHloOrdering.
class SequentialHloOrdering : public HloOrdering {
public:
SequentialHloOrdering(const HloSchedule& schedule);
SequentialHloOrdering(HloSchedule&& schedule);
explicit SequentialHloOrdering(const HloSchedule& schedule);
explicit SequentialHloOrdering(HloSchedule&& schedule);
~SequentialHloOrdering() override = default;
// Returns the sequential instruction order for the given computation.