From 23da21150d988f7cf5780488f24adbb116675586 Mon Sep 17 00:00:00 2001 From: Mark Heffernan Date: Mon, 18 Sep 2017 23:34:22 -0700 Subject: [PATCH] Add liveness_util functions which use dataflow analysis. Also make the analysis argument (TuplePointsToAnalysis or HloDataflowAnalysis) non-optional as all callers were passing in the analysis. PiperOrigin-RevId: 169200824 --- tensorflow/compiler/xla/service/BUILD | 6 +- .../compiler/xla/service/buffer_liveness.cc | 2 +- .../compiler/xla/service/heap_simulator.cc | 2 +- .../xla/service/hlo_alias_analysis_test.cc | 4 +- .../xla/service/hlo_dataflow_analysis.cc | 1 - .../xla/service/hlo_dataflow_analysis.h | 1 - .../xla/service/hlo_dataflow_analysis_test.cc | 3 +- .../compiler/xla/service/hlo_ordering.cc | 22 ++-- .../compiler/xla/service/hlo_ordering.h | 11 +- .../compiler/xla/service/hlo_ordering_test.cc | 35 +++--- .../compiler/xla/service/liveness_util.cc | 116 +++++++++++++++++- .../compiler/xla/service/liveness_util.h | 22 +++- .../xla/service/liveness_util_test.cc | 97 +++++++++++---- 13 files changed, 254 insertions(+), 68 deletions(-) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 8361212337f..f23fa221079 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -720,6 +720,7 @@ cc_library( hdrs = ["liveness_util.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -838,6 +839,7 @@ cc_library( deps = [ ":call_graph", ":hlo", + ":hlo_dataflow_analysis", ":hlo_proto", ":hlo_value", ":liveness_util", @@ -1391,9 +1393,7 @@ cc_library( deps = [ ":call_graph", ":hlo", - ":hlo_ordering", ":hlo_value", - ":liveness_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", @@ -1412,6 +1412,7 @@ cc_test( ":hlo_dataflow_analysis", ":hlo_graph_dumper", ":hlo_matchers", + ":hlo_ordering", 
":instruction_fusion", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1470,6 +1471,7 @@ cc_test( ":hlo_alias_analysis", ":hlo_graph_dumper", ":hlo_matchers", + ":hlo_ordering", ":instruction_fusion", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index f085ffa6bc4..86100802037 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -123,7 +123,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), b.instruction(), b.index(), - &points_to_analysis())) { + points_to_analysis())) { return false; } } diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index c85e97b691c..34e2f7ee206 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -204,7 +204,7 @@ Status HeapSimulator::RunComputation( buffer->instruction()->opcode() != HloOpcode::kCopy && CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), &points_to_analysis)) { + buffer->instruction(), buffer->index(), points_to_analysis)) { ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index e7ff9e7cf31..a275628779b 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -93,7 +94,8 @@ class HloAliasAnalysisTest : public HloTestBase { for (const HloValue* value_a : buffer.values()) { for (const HloValue* value_b : buffer.values()) { if (*value_a != *value_b && - ordering.MayInterfere(*value_a, *value_b)) { + ordering.MayInterfere(*value_a, *value_b, + analysis_->dataflow_analysis())) { VLOG(1) << *value_a << " interferes with " << *value_b << " in buffer: " << buffer; return true; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 2be1645f1b0..213ff07b071 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index aae257dd09e..207e553bf7f 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_value.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 4939335e2f8..4b8eb237a67 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/instruction_fusion.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -73,7 +74,7 @@ class HloDataflowAnalysisTest : public HloTestBase, EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); return ordering.MayInterfere(analysis_->GetValueDefinedAt(a), - analysis_->GetValueDefinedAt(b)); + analysis_->GetValueDefinedAt(b), *analysis_); } std::unique_ptr module_; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 08f572bb2ab..3612c51ee82 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -123,8 +123,9 @@ bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const { } /* static */ -bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use, - const HloValue& value) const { +bool HloOrdering::UseIsBeforeValueDefinition( + const 
HloUse& use, const HloValue& value, + const HloDataflowAnalysis& dataflow) const { VLOG(4) << "UseIsBeforeValueDefinition(use=" << use << ", value=" << value.ToShortString() << ")"; if (ExecutesBefore(use.instruction, value.defining_instruction())) { @@ -139,7 +140,7 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use, CanShareOperandBufferWithUser( use.instruction->mutable_operand(use.operand_number), use.operand_index, value.defining_instruction(), - value.defining_index())) { + value.defining_index(), dataflow)) { VLOG(4) << " use is value def, and instruction can share use buffer"; return true; } @@ -172,12 +173,13 @@ bool HloOrdering::UseIsBeforeValueDefinition(const HloUse& use, return true; } } - VLOG(4) << " use is not before while"; + VLOG(4) << " use is not before value"; return false; } -bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a, - const HloValue& b) const { +bool HloOrdering::LiveRangeStrictlyBefore( + const HloValue& a, const HloValue& b, + const HloDataflowAnalysis& dataflow) const { VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() << ", b = " << b.ToShortString() << ")"; if (!IsDefinedBefore(a, b)) { @@ -204,7 +206,7 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a, // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { - if (!UseIsBeforeValueDefinition(use, b)) { + if (!UseIsBeforeValueDefinition(use, b, dataflow)) { VLOG(4) << "use of a (" << use << ") not before b is defined"; return false; } @@ -213,9 +215,11 @@ bool HloOrdering::LiveRangeStrictlyBefore(const HloValue& a, return true; } -bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b) const { +bool HloOrdering::MayInterfere(const HloValue& a, const HloValue& b, + const HloDataflowAnalysis& dataflow) const { // Buffers without disjoint liveness may interfere. 
- return !LiveRangeStrictlyBefore(a, b) && !LiveRangeStrictlyBefore(b, a); + return !LiveRangeStrictlyBefore(a, b, dataflow) && + !LiveRangeStrictlyBefore(b, a, dataflow); } HloOrderingProto HloOrdering::ToProto() const { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index e0c23a3a08a..ee526d8dd7f 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_value.h" @@ -48,15 +49,17 @@ class HloOrdering { // Returns whether the given use is before the given value definition under // the given ordering. - bool UseIsBeforeValueDefinition(const HloUse& use, - const HloValue& value) const; + bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value, + const HloDataflowAnalysis& dataflow) const; // Returns whether the given values interfere. Two values interfere if they // may both be simultaneously live. - bool MayInterfere(const HloValue& a, const HloValue& b) const; + bool MayInterfere(const HloValue& a, const HloValue& b, + const HloDataflowAnalysis& dataflow) const; // Returns true if the live range of the given value 'a' is strictly before // the live range of value 'b' using the given HLO ordering. - bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b) const; + bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b, + const HloDataflowAnalysis& dataflow) const; // Returns the sequential instruction order for the given computation, or // nullptr if the computation does not have a sequential ordering. 
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index dbd63eceede..33bafd05c15 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -269,29 +269,32 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { // while because of the use of the init value in the add. EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant), dataflow->GetValueDefinedAt(xla_while))); - EXPECT_FALSE( - ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant), - dataflow->GetValueDefinedAt(xla_while))); + EXPECT_FALSE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(constant), + dataflow->GetValueDefinedAt(xla_while), *dataflow)); EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant), - dataflow->GetValueDefinedAt(xla_while))); + dataflow->GetValueDefinedAt(xla_while), + *dataflow)); // Any value defined in the body or condition is defined before the while, and // has a live range strictly before the while. 
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate), dataflow->GetValueDefinedAt(xla_while))); - EXPECT_TRUE( - ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate), - dataflow->GetValueDefinedAt(xla_while))); + EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(negate), + dataflow->GetValueDefinedAt(xla_while), *dataflow)); EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate), - dataflow->GetValueDefinedAt(xla_while))); + dataflow->GetValueDefinedAt(xla_while), + *dataflow)); EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert), dataflow->GetValueDefinedAt(xla_while))); - EXPECT_TRUE( - ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert), - dataflow->GetValueDefinedAt(xla_while))); + EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(convert), + dataflow->GetValueDefinedAt(xla_while), *dataflow)); EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert), - dataflow->GetValueDefinedAt(xla_while))); + dataflow->GetValueDefinedAt(xla_while), + *dataflow)); // The live range of the while should be before the add. 
EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while), @@ -301,10 +304,10 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0]; EXPECT_EQ(while_use.instruction, add); EXPECT_TRUE(ordering.UseIsBeforeValueDefinition( - while_use, dataflow->GetValueDefinedAt(add))); - EXPECT_TRUE( - ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while), - dataflow->GetValueDefinedAt(add))); + while_use, dataflow->GetValueDefinedAt(add), *dataflow)); + EXPECT_TRUE(ordering.LiveRangeStrictlyBefore( + dataflow->GetValueDefinedAt(xla_while), dataflow->GetValueDefinedAt(add), + *dataflow)); } } // namespace diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 317271dfdd6..c27a8956a70 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -69,6 +69,36 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand, return false; } +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, + const HloDataflowAnalysis& dataflow) { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + if (user->opcode() == HloOpcode::kFusion && + user->fusion_kind() == HloInstruction::FusionKind::kLoop) { + // Find fusion parameter associated with 'operand'. + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + // Look up the value defined by the fusion parameter at 'index'. + // Return false if that value has any uses, return true otherwise. + const HloValue& value = dataflow.GetValueDefinedAt(fusion_param, index); + return value.uses().empty(); + } else { + // Return false if any value at 'operand' and 'index' is used at 'user'. 
+ for (const HloValue* value : + dataflow.GetValueSet(operand, index).values()) { + for (const HloUse& use : value->uses()) { + if (use.instruction == user) { + return false; + } + } + } + } + + return true; +} + namespace { // Returns all uses of all aliases of 'instruction' at 'index' in 'uses'. @@ -153,7 +183,7 @@ bool HasUniqueFusedUseOfOperandAt( bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis* points_to_analysis) { + const TuplePointsToAnalysis& points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); const Shape& operand_subshape = @@ -164,7 +194,7 @@ bool CanShareOperandBufferWithUser( if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; } - if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) { + if (user->opcode() == HloOpcode::kFusion) { if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { @@ -174,7 +204,7 @@ bool CanShareOperandBufferWithUser( // 'operand_index', and this singleton use is the fused root at operand // index 0. return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - *points_to_analysis); + points_to_analysis); } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -202,7 +232,85 @@ bool CanShareOperandBufferWithUser( // index 'other_add_operand_index'). 
return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, other_add_operand_index, - *points_to_analysis); + points_to_analysis); + } + } + if (user->opcode() == HloOpcode::kDynamicUpdateSlice || + user->opcode() == HloOpcode::kWhile) { + // We eliminated other users in BufferLiveness::live_range_strictly_before, + // so here we just need to check that the use is at operand index 0. + std::vector operand_indices = user->OperandIndices(operand); + return operand_indices.size() == 1 && operand_indices[0] == 0; + } + // Check if 'user' is element-wise. + return user->IsElementwise(); +} + +bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index, + const HloDataflowAnalysis& dataflow) { + CHECK(user->IsUserOf(operand)) + << "user: " << user->ToString() << " operand: " << operand->ToString(); + const Shape& operand_subshape = + ShapeUtil::GetSubshape(operand->shape(), operand_index); + const Shape& user_subshape = + ShapeUtil::GetSubshape(user->shape(), user_index); + // Check that operand and user emit the same shape and layout. + if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { + return false; + } + + if (user->opcode() == HloOpcode::kFusion) { + // Get the parameter associated with 'operand'; + HloInstruction* fusion_param = + user->fused_parameter(user->operand_index(operand)); + + const HloValue& value = + dataflow.GetValueDefinedAt(fusion_param, operand_index); + if (value.uses().size() != 1) { + return false; + } + const HloUse& use = value.uses()[0]; + + if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && + user->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice) { + // Loop fusion with kDynamicUpdateSlice fused root. + // + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root at operand + // index 0. 
+ return use.instruction == user->fused_expression_root() && + use.operand_number == 0; + } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && + user->fused_expression_root()->opcode() == HloOpcode::kAdd) { + // Output fusion with kAdd fused root. + + // Check if one operand of kAdd fused root is either kDot, or nested + // kFusion of kind kTransposeDot. + auto* add = user->fused_expression_root(); + auto add_operand_it = + std::find_if(add->operands().begin(), add->operands().end(), + [&](HloInstruction* operand) { + return operand->opcode() == HloOpcode::kDot || + (operand->opcode() == HloOpcode::kFusion && + operand->fusion_kind() == + HloInstruction::FusionKind::kTransposeDot); + }); + if (add_operand_it == add->operands().end()) { + return false; + } + auto* matched_add_operand = *add_operand_it; + // Calculate operand index of 'add' operand which was not matched above. + const int64 other_add_operand_index = + matched_add_operand == add->operand(0) ? 1 : 0; + // Returns true iff there is exactly one use of 'operand' at shape index + // 'operand_index', and this singleton use is the fused root (at operand + // index 'other_add_operand_index'). + return use.instruction == user->fused_expression_root() && + use.operand_number == other_add_operand_index; } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index c7799e5ab5d..28ef9918800 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -18,6 +18,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_LIVENESS_UTIL_H_ +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -29,21 +30,34 @@ namespace xla { // 'operand'. Returns false otherwise. // // REQUIRES: 'operand' is an operand of 'user'. +// +// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have +// moved over to the dataflow overload. bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); +bool DoesNotUseOperandBuffer(const HloInstruction* operand, + const ShapeIndex& index, + const HloInstruction* user, + const HloDataflowAnalysis& dataflow); // Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a -// points-to analysis argument. Without the analysis, the result is more -// conservative (returns false more often). +// 'operand' (at 'operand_index'). Returns false otherwise. // // REQUIRES: 'operand' is an operand of 'user'. +// +// TODO(b/65835246): Remove TuplePointsToAnalysis overload when all users have +// moved over to the dataflow overload. 
bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis* points_to_analysis = nullptr); + const TuplePointsToAnalysis& points_to_analysis); +bool CanShareOperandBufferWithUser(HloInstruction* operand, + const ShapeIndex& operand_index, + HloInstruction* user, + const ShapeIndex& user_index, + const HloDataflowAnalysis& dataflow); } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index d89dab4a82c..b5e15906d3c 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -35,6 +35,8 @@ class PointsToAnalysisTestBase : public HloTestBase { CHECK_NOTNULL(module_.get()); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); + dataflow_analysis_ = + HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { @@ -45,6 +47,7 @@ class PointsToAnalysisTestBase : public HloTestBase { std::unique_ptr module_; HloComputation* computation_ = nullptr; std::unique_ptr points_to_analysis_; + std::unique_ptr dataflow_analysis_; }; class DoesNotUseOperandBufferTest : public PointsToAnalysisTestBase {}; @@ -70,6 +73,11 @@ TEST_F(DoesNotUseOperandBufferTest, GetTupleElement) { EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *points_to_analysis_)); EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *points_to_analysis_)); EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, *points_to_analysis_)); + + EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, gte0, *dataflow_analysis_)); + EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {1}, gte1, *dataflow_analysis_)); + EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte0, *dataflow_analysis_)); + EXPECT_FALSE(DoesNotUseOperandBuffer(tuple, {}, gte1, 
*dataflow_analysis_)); } TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { @@ -105,6 +113,10 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { DoesNotUseOperandBuffer(tuple, {0}, fusion, *points_to_analysis_)); EXPECT_FALSE( DoesNotUseOperandBuffer(tuple, {1}, fusion, *points_to_analysis_)); + + EXPECT_TRUE(DoesNotUseOperandBuffer(tuple, {0}, fusion, *dataflow_analysis_)); + EXPECT_FALSE( + DoesNotUseOperandBuffer(tuple, {1}, fusion, *dataflow_analysis_)); } class CanShareOperandBufferWithUserTest : public PointsToAnalysisTestBase {}; @@ -122,10 +134,15 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, - points_to_analysis_.get())); - EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {}, - points_to_analysis_.get())); + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, log, {}, *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { @@ -143,9 +160,14 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { BuildModuleAndRunAnalysis(builder.Build()); EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - points_to_analysis_.get())); + *points_to_analysis_)); EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - points_to_analysis_.get())); + *points_to_analysis_)); + + EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, + *dataflow_analysis_)); + EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, + *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { @@ -161,10 +183,15 @@ 
TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, - points_to_analysis_.get())); - EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {}, - points_to_analysis_.get())); + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(param, {}, exp, {}, *dataflow_analysis_)); + EXPECT_TRUE( + CanShareOperandBufferWithUser(exp, {}, copy, {}, *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { @@ -197,9 +224,14 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // The fusion instruction can share with tuple element 1. EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - points_to_analysis_.get())); + *points_to_analysis_)); EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - points_to_analysis_.get())); + *points_to_analysis_)); + + EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, + *dataflow_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, + *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { @@ -221,12 +253,19 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { // The DynamicUpdateSlice instruction can share with the data operand, but not // with update or starts. 
- EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {}, - points_to_analysis_.get())); - EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {}, - points_to_analysis_.get())); - EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {}, - points_to_analysis_.get())); + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, dus, {}, *dataflow_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(update, {}, dus, {}, *dataflow_analysis_)); + EXPECT_FALSE( + CanShareOperandBufferWithUser(starts, {}, dus, {}, *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { @@ -256,7 +295,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { // Output fused dot add should be able to share buffer with 'add_operand'. EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - points_to_analysis_.get())); + *points_to_analysis_)); + + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { @@ -292,7 +334,10 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { // Output fused transpose-dot-add should be share buffer with 'add_operand'. EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - points_to_analysis_.get())); + *points_to_analysis_)); + + EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, + *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { @@ -320,7 +365,10 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { // Output fused operand->reverse->add cannot alias operand buffer 'operand'. 
EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - points_to_analysis_.get())); + *points_to_analysis_)); + + EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, + *dataflow_analysis_)); } TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { @@ -360,8 +408,11 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { RunAnalysis(); // The While instruction can share with the data operand. - EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {}, - points_to_analysis_.get())); + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); + + EXPECT_TRUE( + CanShareOperandBufferWithUser(data, {}, whil, {}, *dataflow_analysis_)); } } // namespace