diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index fb71bdfd2d2..12694401155 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -881,7 +881,24 @@ class CopyRemover { return ordering_.IsDefinedBefore(*a.value, *b.value); } return absl::c_all_of(a.uses, [&](const HloUse* use) { - return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_); + // Here if the HloUse is located in a branch that is exclusive to b's + // branch, it can be skipped, because in order for them to interfere, + // there must be an execution path from b's definition to the HloUse. If + // there is such a path, it would have to pass through the point where the + // two exclusive branches are joined. The join point would have to contain + // a phi operation because b's definition is not guaranteed to reach a. The + // phi operation would be another use of a that would ensure correct + // answer is returned. 
+ switch (ordering_.GetExecutionConstraint( + use->instruction, b.value->defining_instruction())) { + case HloOrdering::ExecutionConstraint::kIsSame: + case HloOrdering::ExecutionConstraint::kRunExclusiveAfter: + case HloOrdering::ExecutionConstraint::kRunExclusiveBefore: + return true; + default: + return ordering_.UseIsBeforeValueDefinition(*use, *b.value, + dataflow_); + } }); } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 23d3be6e17d..e4bce512ae2 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2473,6 +2473,52 @@ ENTRY TestComputation { op::While(op::Copy(op::Parameter()))); } +TEST_F(CopyInsertionTest, NestedWhileAndConditional) { + const string& hlo_string = R"( +HloModule TestModule + +on_true + { + v1 = f32[2] parameter(0) + ROOT v2 = f32[2] add(v1,v1) +} + +on_false + { + v1 = f32[2] parameter(0) + ROOT v2 = f32[2] multiply(v1,v1) +} + +cond.outer { + param.1 = (pred[], f32[2]) parameter(0) + ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0 +} + +body.outer { + param.1 = (pred[], f32[2]) parameter(0) + pred.1 = pred[] get-tuple-element(param.1), index=0 + arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1 + if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false + ROOT res = (pred[], f32[2]) tuple(pred.1,if) +} + +ENTRY TestComputation { + entry_param.1 = pred[] parameter(0) + float_param = f32[2] parameter(1) + entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param) + ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + VLOG(2) << module->ToString() << "\n"; + + // There should be two copies inserted, and one of them is in the entry 
+ // computation. + EXPECT_EQ(CountCopies(*module), 2); +} + TEST_F(CopyInsertionTest, FixpointComputationRequired) { const string& hlo_string = R"( HloModule Module diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index a4804a8faef..463d70b30e3 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -34,6 +34,21 @@ namespace xla { bool HloOrdering::ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const { + switch (GetExecutionConstraint(a, b)) { + case ExecutionConstraint::kIsSame: // a and b are the same instruction; + return false; + case ExecutionConstraint::kRunBefore: + case ExecutionConstraint::kRunExclusiveBefore: + return true; + case ExecutionConstraint::kRunExclusiveAfter: + case ExecutionConstraint::kRunAfter: + case ExecutionConstraint::kUnordered: + return false; + } +} + +HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint( + const HloInstruction* a, const HloInstruction* b) const { // 'a' and 'b' may be in different computations. In this case, find the // callgraph ancestor instructions which call (potentially transitively) the // computations containing 'a' and 'b' and use these ancestor instructions to @@ -47,7 +62,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, if (a_ancestor == nullptr) { // Ancestors in a common computation could not be found so consider the // instructions 'a' and 'b' to be unordered. - return false; + return ExecutionConstraint::kUnordered; } // a_ancestor and b_ancestor must be either both null or both non-null. 
CHECK_NE(b_ancestor, nullptr); @@ -62,7 +77,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, const HloComputation* condition = a_ancestor->while_condition(); if (call_graph_->InstructionIsNestedIn(a, condition) && call_graph_->InstructionIsNestedIn(b, body)) { - return true; + return ExecutionConstraint::kRunBefore; } } @@ -85,17 +100,40 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, b_branch = j; } } - if (a_branch != -1 && a_branch < b_branch) { - return true; + // If neither a nor b is inside the branches they both are the ancestor. + if (a_branch == -1 && b_branch == -1) { + CHECK_EQ(a, a_ancestor); + CHECK_EQ(b, b_ancestor); + CHECK_EQ(a, b); + return ExecutionConstraint::kIsSame; } // If 'b' is the conditional ancestor, and 'a' is within a branch // computation, 'a' executes before 'b'. - if (b == a_ancestor && a_branch != -1) { - return true; + if (b_branch == -1) { + CHECK_EQ(b, a_ancestor); + return ExecutionConstraint::kRunBefore; + } + if (a_branch == -1) { + CHECK_EQ(a, a_ancestor); + return ExecutionConstraint::kRunAfter; + } + if (a_branch < b_branch) { + return ExecutionConstraint::kRunExclusiveBefore; + } + if (b_branch < a_branch) { + return ExecutionConstraint::kRunExclusiveAfter; } } - return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); + if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) { + return ExecutionConstraint::kRunBefore; + } + if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) { + return ExecutionConstraint::kRunAfter; + } + VLOG(1) << "Cannot determine order between:" << a_ancestor->ToString() << "\n" + << "and " << b_ancestor->ToString() << "\n"; + return ExecutionConstraint::kUnordered; } bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { @@ -299,10 +337,25 @@ bool HloOrdering::LiveRangeStrictlyBefore( use.instruction)) { continue; } - if (!UseIsBeforeValueDefinition(use, b, dataflow)) { - VLOG(4) << "use of " << a << " (" << use << ") not 
before " << b - << " is defined"; - return false; + // Here if the HloUse is located in a branch that is exclusive to b's + // branch, it can be skipped, because in order for them to interfere, there + // must be an execution path from b's definition to the HloUse. If there is + // such a path, it would have to pass through the point where the two + // exclusive branches are joined. The join point would have to contain a + // phi operation because b's definition is not guaranteed to reach a. The + // phi operation would be another use of a that would ensure correct answer + // is returned. + switch (GetExecutionConstraint(use.instruction, b.defining_instruction())) { + case HloOrdering::ExecutionConstraint::kIsSame: + case HloOrdering::ExecutionConstraint::kRunExclusiveAfter: + case HloOrdering::ExecutionConstraint::kRunExclusiveBefore: + continue; + default: + if (!UseIsBeforeValueDefinition(use, b, dataflow)) { + VLOG(4) << "use of " << a << " (" << use << ") not before " << b + << " is defined"; + return false; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index a07214c22c0..083603a4a18 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -37,10 +37,30 @@ namespace xla { // determine live range overlap of HLO instruction output buffers. class HloOrdering { public: - HloOrdering(const HloModule* module) + explicit HloOrdering(const HloModule* module) : module_(module), call_graph_(CallGraph::Build(module)) {} virtual ~HloOrdering() = default; + // Specify the ordering constraints between a pair of instructions a and b. + enum class ExecutionConstraint { + // Indicate a and b are the same instruction; + kIsSame, + // Indicate a runs before b; + kRunBefore, + // Only one of a or b runs each time their common ancestor is evaluated, + // and a is in an earlier branch than b. 
+ kRunExclusiveBefore, + // Only one of a or b runs each time, and a is in a later branch than b. + kRunExclusiveAfter, + // Indicate a runs after b; + kRunAfter, + // An order cannot be determined as a and b do not have a common ancestor. + kUnordered, + }; + // Return the execution constraint between a and b. + HloOrdering::ExecutionConstraint GetExecutionConstraint( + const HloInstruction* a, const HloInstruction* b) const; + // Returns true if instruction 'a' executes before instruction 'b'. This is // not reflexive, that is, an instruction does not execute before itself. bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const; @@ -181,8 +201,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering { // interference is reduced relative to DependencyHloOrdering. class SequentialHloOrdering : public HloOrdering { public: - SequentialHloOrdering(const HloSchedule& schedule); - SequentialHloOrdering(HloSchedule&& schedule); + explicit SequentialHloOrdering(const HloSchedule& schedule); + explicit SequentialHloOrdering(HloSchedule&& schedule); ~SequentialHloOrdering() override = default; // Returns the sequential instruction order for the given computation.