From 849bcce3b00c5af63d56649cea63696114346228 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 15 Dec 2020 13:12:12 -0800
Subject: [PATCH] [XLA] remove extraneous copies in copy_insertion related to
 nested conditionals and while loops.

The change increases the precision of the LiveRangeBefore analysis inside
copy_insertion to account for disjoint conditional branches whose executions
are mutually exclusive and therefore never overlap.

PiperOrigin-RevId: 347677692
Change-Id: I1fa3de590042078b7a6dde4c4bf8227c4c6416bb
---
 .../compiler/xla/service/copy_insertion.cc    | 11 +++-
 .../xla/service/copy_insertion_test.cc        | 46 ++++++++++++++++
 .../compiler/xla/service/hlo_ordering.cc      | 52 ++++++++++++++++---
 .../compiler/xla/service/hlo_ordering.h       | 26 ++++++++--
 4 files changed, 124 insertions(+), 11 deletions(-)

diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index fb71bdfd2d2..b8243ada7bd 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -881,7 +881,16 @@ class CopyRemover {
       return ordering_.IsDefinedBefore(*a.value, *b.value);
     }
     return absl::c_all_of(a.uses, [&](const HloUse* use) {
-      return ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_);
+      switch (ordering_.GetExecutionConstraint(use->instruction,
+                                               b.value->instruction())) {
+        case HloOrdering::ExecutionConstraint::kIsSame:
+        case HloOrdering::ExecutionConstraint::kRunExclusiveAfter:
+        case HloOrdering::ExecutionConstraint::kRunExclusiveBefore:
+          return true;
+        default:
+          return ordering_.UseIsBeforeValueDefinition(*use, *b.value,
+                                                      dataflow_);
+      }
     });
   }
 
diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc
index 23d3be6e17d..e4bce512ae2 100644
--- a/tensorflow/compiler/xla/service/copy_insertion_test.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc
@@ -2473,6 +2473,52 @@ ENTRY TestComputation {
               op::While(op::Copy(op::Parameter())));
 }
 
+TEST_F(CopyInsertionTest, NestedWhileAndConditional) {
+  const string& hlo_string = R"(
+HloModule TestModule
+
+on_true
+ {
+  v1 = f32[2] parameter(0)
+  ROOT v2 = f32[2] add(v1,v1)
+}
+
+on_false
+ {
+  v1 = f32[2] parameter(0)
+  ROOT v2 = f32[2] multiply(v1,v1)
+}
+
+cond.outer {
+  param.1 = (pred[], f32[2]) parameter(0)
+  ROOT param.cond.outer = pred[] get-tuple-element(param.1), index=0
+}
+
+body.outer {
+  param.1 = (pred[], f32[2]) parameter(0)
+  pred.1 = pred[] get-tuple-element(param.1), index=0
+  arg_tuple.11 = f32[2] get-tuple-element(param.1), index=1
+  if = f32[2] conditional(pred.1, arg_tuple.11, arg_tuple.11), true_computation=on_true, false_computation=on_false
+  ROOT res = (pred[], f32[2]) tuple(pred.1,if)
+}
+
+ENTRY TestComputation {
+  entry_param.1 = pred[] parameter(0)
+  float_param = f32[2] parameter(1)
+  entry_param = (pred[], f32[2]) tuple(entry_param.1, float_param)
+  ROOT while = (pred[], f32[2]) while(entry_param), condition=cond.outer, body=body.outer
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  InsertCopies(module.get());
+  VLOG(2) << module->ToString() << "\n";
+
+  // There should only be two copies inserted, and both are in the entry
+  // computation.
+  EXPECT_EQ(CountCopies(*module), 2);
+}
+
 TEST_F(CopyInsertionTest, FixpointComputationRequired) {
   const string& hlo_string = R"(
 HloModule Module
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index a4804a8faef..d2c9b21186e 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -34,6 +34,21 @@ namespace xla {
 
 bool HloOrdering::ExecutesBefore(const HloInstruction* a,
                                  const HloInstruction* b) const {
+  switch (GetExecutionConstraint(a, b)) {
+    case ExecutionConstraint::kIsSame:  // a and b are the same instruction;
+      return false;
+    case ExecutionConstraint::kRunBefore:
+    case ExecutionConstraint::kRunExclusiveBefore:
+      return true;
+    case ExecutionConstraint::kRunExclusiveAfter:
+    case ExecutionConstraint::kRunAfter:
+    case ExecutionConstraint::kUnordered:
+      return false;
+  }
+}
+
+HloOrdering::ExecutionConstraint HloOrdering::GetExecutionConstraint(
+    const HloInstruction* a, const HloInstruction* b) const {
   // 'a' and 'b' may be in different computations. In this case, find the
   // callgraph ancestor instructions which call (potentially transitively) the
   // computations containing 'a' and 'b' and use these ancestor instructions to
@@ -47,7 +62,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
   if (a_ancestor == nullptr) {
     // Ancestors in a common computation could not be found so consider the
     // instructions 'a' and 'b' to be unordered.
-    return false;
+    return ExecutionConstraint::kUnordered;
   }
   // a_ancestor and b_ancestor must be either both null or both non-null.
   CHECK_NE(b_ancestor, nullptr);
@@ -62,7 +77,7 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
     const HloComputation* condition = a_ancestor->while_condition();
     if (call_graph_->InstructionIsNestedIn(a, condition) &&
         call_graph_->InstructionIsNestedIn(b, body)) {
-      return true;
+      return ExecutionConstraint::kRunBefore;
     }
   }
 
@@ -85,17 +100,40 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
         b_branch = j;
       }
     }
-    if (a_branch != -1 && a_branch < b_branch) {
-      return true;
+    // If neither a nor b is inside the branches, they are both the ancestor.
+    if (a_branch == -1 && b_branch == -1) {
+      CHECK_EQ(a, a_ancestor);
+      CHECK_EQ(b, b_ancestor);
+      CHECK_EQ(a, b);
+      return ExecutionConstraint::kIsSame;
     }
     // If 'b' is the conditional ancestor, and 'a' is within a branch
     // computation, 'a' executes before 'b'.
-    if (b == a_ancestor && a_branch != -1) {
-      return true;
+    if (b_branch == -1) {
+      CHECK_EQ(b, a_ancestor);
+      return ExecutionConstraint::kRunBefore;
+    }
+    if (a_branch == -1) {
+      CHECK_EQ(a, a_ancestor);
+      return ExecutionConstraint::kRunAfter;
+    }
+    if (a_branch < b_branch) {
+      return ExecutionConstraint::kRunExclusiveBefore;
+    }
+    if (b_branch < a_branch) {
+      return ExecutionConstraint::kRunExclusiveAfter;
     }
   }
 
-  return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor);
+  if (ExecutesBeforeInSameComputation(a_ancestor, b_ancestor)) {
+    return ExecutionConstraint::kRunBefore;
+  }
+  if (ExecutesBeforeInSameComputation(b_ancestor, a_ancestor)) {
+    return ExecutionConstraint::kRunAfter;
+  }
+  VLOG(1) << "Cannot determine order between:" << a_ancestor->ToString() << "\n"
+          << "and " << b_ancestor->ToString() << "\n";
+  return ExecutionConstraint::kUnordered;
 }
 
 bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index a07214c22c0..083603a4a18 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -37,10 +37,30 @@ namespace xla {
 // determine live range overlap of HLO instruction output buffers.
 class HloOrdering {
  public:
-  HloOrdering(const HloModule* module)
+  explicit HloOrdering(const HloModule* module)
       : module_(module), call_graph_(CallGraph::Build(module)) {}
   virtual ~HloOrdering() = default;
 
+  // Specify the ordering constraints between a pair of instructions a and b.
+  enum class ExecutionConstraint {
+    // Indicate a and b are the same instruction;
+    kIsSame,
+    // Indicate a runs before b;
+    kRunBefore,
+    // Only one of a or b runs each time their common ancestor is evaluated,
+    // and a is in an earlier branch than b.
+    kRunExclusiveBefore,
+    // Only one of a or b runs each time, and a is in a later branch than b.
+    kRunExclusiveAfter,
+    // Indicate a runs after b;
+    kRunAfter,
+    // An order cannot be determined as a and b do not have a common ancestor.
+    kUnordered,
+  };
+  // Return the execution constraint between a and b.
+  HloOrdering::ExecutionConstraint GetExecutionConstraint(
+      const HloInstruction* a, const HloInstruction* b) const;
+
   // Returns true if instruction 'a' executes before instruction 'b'. This is
   // not reflexive, that is, an instruction does not execute before itself.
   bool ExecutesBefore(const HloInstruction* a, const HloInstruction* b) const;
@@ -181,8 +201,8 @@ class DependencyHloOrdering : public PredecessorHloOrdering {
 // interference is reduced relative to DependencyHloOrdering.
 class SequentialHloOrdering : public HloOrdering {
  public:
-  SequentialHloOrdering(const HloSchedule& schedule);
-  SequentialHloOrdering(HloSchedule&& schedule);
+  explicit SequentialHloOrdering(const HloSchedule& schedule);
+  explicit SequentialHloOrdering(HloSchedule&& schedule);
   ~SequentialHloOrdering() override = default;
 
   // Returns the sequential instruction order for the given computation.
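
Reviewer note (not part of the commit): the minimal standalone C++ sketch below illustrates the reasoning this patch adds to CopyRemover::LiveRangeBefore. It mirrors HloOrdering::ExecutionConstraint with a local copy of the enum so it compiles on its own; the helper name LiveRangesTriviallyDisjoint and the main() driver are illustrative assumptions and do not exist in XLA. The point being demonstrated: when a use and a value definition sit in different branches of the same conditional (kRunExclusiveBefore/kRunExclusiveAfter), or the use's instruction is the defining instruction itself (kIsSame), their live ranges can never be active in the same execution, so the more expensive UseIsBeforeValueDefinition check can be skipped.

// Standalone illustration only; mirrors HloOrdering::ExecutionConstraint from
// hlo_ordering.h rather than including the real XLA headers.
#include <iostream>

enum class ExecutionConstraint {
  kIsSame,              // a and b are the same instruction.
  kRunBefore,           // a runs before b.
  kRunExclusiveBefore,  // a and b are in sibling branches; a's branch is earlier.
  kRunExclusiveAfter,   // a and b are in sibling branches; a's branch is later.
  kRunAfter,            // a runs after b.
  kUnordered,           // no common ancestor, so no order can be determined.
};

// Returns true when the constraint alone already guarantees that a use of one
// value cannot be live at the same time as the definition of another, which is
// what lets LiveRangeBefore skip the UseIsBeforeValueDefinition fallback.
// (Hypothetical helper; the real logic is written inline in copy_insertion.cc.)
bool LiveRangesTriviallyDisjoint(ExecutionConstraint c) {
  switch (c) {
    case ExecutionConstraint::kIsSame:
    case ExecutionConstraint::kRunExclusiveBefore:
    case ExecutionConstraint::kRunExclusiveAfter:
      // Mutually exclusive branches (or the very same instruction): the two
      // live ranges can never overlap in any single execution.
      return true;
    default:
      // Every other constraint still requires the ordering-based check.
      return false;
  }
}

int main() {
  std::cout << std::boolalpha
            << LiveRangesTriviallyDisjoint(ExecutionConstraint::kRunExclusiveAfter)
            << "\n"  // true: sibling conditional branches never both execute.
            << LiveRangesTriviallyDisjoint(ExecutionConstraint::kUnordered)
            << "\n";  // false: must fall back to UseIsBeforeValueDefinition.
  return 0;
}

This is why the kRunExclusiveAfter case can return true in LiveRangeBefore even though the use does not execute before the definition: the two never run in the same evaluation of the conditional, so no copy is needed to separate them.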