diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index f1fc608caa0..e2b550fc022 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -212,6 +212,13 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
   return GetUniqueSlice(instruction, /*index=*/{});
 }
 
+bool BufferAssignment::SharesSliceAtIndex(
+    const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
+    const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
+  return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
+         GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
+}
+
 StatusOr<BufferAllocation::Slice>
 BufferAssignment::GetUniqueTopLevelOutputSlice() const {
   return GetUniqueTopLevelSlice(
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 82b9bf49ece..b82acb19b34 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -294,6 +294,15 @@ class BufferAssignment {
     return GetPointsToSet(instruction).element(index);
   }
 
+  // Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
+  // share the same BufferAllocation::Slice.
+  // Returns false otherwise.
+  // REQUIRES: BufferAssignment assigned allocations to both instructions.
+  bool SharesSliceAtIndex(const HloInstruction* hlo_a,
+                          const ShapeIndex& shape_index_a,
+                          const HloInstruction* hlo_b,
+                          const ShapeIndex& shape_index_b) const;
+
   // Returns the underlying points-to analysis used for this assignment.
   const TuplePointsToAnalysis& points_to_analysis() const {
     return liveness_->points_to_analysis();
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc
index 0fe6e37c00f..736f227aa42 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness.cc
@@ -121,6 +121,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
   // *) Is element-wise.
   // *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where
   //    the singleton use of 'a' at 'a.index' is the fused root at operand 0.
+  // *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
   for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
     if (b.instruction()->IsUserOf(alias.instruction()) &&
         !CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
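Note for reviewers: the new BufferAssignment::SharesSliceAtIndex folds the two-slice comparison previously written out at call sites (see the ir_emitter_unnested.cc hunk below) into a single call. A minimal usage sketch, assuming a BufferAssignment has already been computed; the wrapper name FusionOutputAliasesOperand is hypothetical and not part of this change:

    #include "tensorflow/compiler/xla/service/buffer_assignment.h"

    // Hypothetical wrapper: does the fusion's top-level output (ShapeIndex {}
    // addresses the top level) occupy the same BufferAllocation::Slice as
    // 'operand' at 'index'? SharesSliceAtIndex REQUIRES that allocations were
    // assigned to both instructions, so this must run after buffer assignment.
    bool FusionOutputAliasesOperand(const xla::BufferAssignment& assignment,
                                    const xla::HloInstruction* fusion,
                                    const xla::HloInstruction* operand,
                                    const xla::ShapeIndex& index) {
      return assignment.SharesSliceAtIndex(fusion, /*shape_index_a=*/{},
                                           operand, index);
    }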
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index e7aa93f8dbc..e71b98298b3 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -612,6 +612,93 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
   EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
 }
 
+class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
+ protected:
+  // Builds and runs a computation (see test case computation graphs below).
+  // Runs BufferLiveness on this computation.
+  // Returns whether buffer interference is detected between tuple-shaped
+  // parameter and root instructions at tuple element 1.
+  bool Run(const bool tuple_element1_has_two_uses) {
+    auto builder = HloComputation::Builder(TestName());
+    // Create param0 Tuple.
+    Shape data_shape = ShapeUtil::MakeShape(F32, {8});
+    Shape update_shape = ShapeUtil::MakeShape(F32, {3});
+    auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
+        0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
+
+    auto gte0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
+
+    auto gte1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
+
+    auto update = builder.AddInstruction(HloInstruction::CreateConstant(
+        LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
+
+    if (tuple_element1_has_two_uses) {
+      // Add 'gte0' and 'gte1' to create another user of 'gte1'.
+      gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
+          data_shape, HloOpcode::kAdd, gte0, gte1));
+    }
+    // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
+    auto starts = builder.AddInstruction(
+        HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
+    auto dynamic_update_slice =
+        builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
+            data_shape, gte1, update, starts));
+    // Create output tuple.
+    auto tuple_root = builder.AddInstruction(
+        HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
+    // Build module and get reference to entry computation.
+    auto module = MakeUnique<HloModule>(TestName());
+    module->AddEntryComputation(builder.Build());
+    // Run BufferLiveness on 'module'.
+    auto liveness =
+        BufferLiveness::Run(module.get(),
+                            MakeUnique<DependencyHloOrdering>(module.get()))
+            .ConsumeValueOrDie();
+    // Return whether or not buffer interference is detected between
+    // 'tuple_param0' and 'tuple_root' at shape index '{1}'.
+    return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
+  }
+};
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
+// the following computation (because DynamicUpdateSlice (at operand 0) is the
+// unique user):
+//
+//     Parameter0
+//      |      |
+//    GTE(0) GTE(1) Const Const
+//      |      \      |    /
+//      |    DynamicUpdateSlice
+//       \          /
+//          Tuple
+//
+TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
+  EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
+}
+
+// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
+// GTE(1) has two users:
+// 1) DynamicUpdateSlice at operand 0.
+// 2) Add at operand 1.
+//
+//     Parameter0
+//      |      |
+//    GTE(0) GTE(1)
+//      |   /  |
+//      |  /   |
+//      Add    |   Const Const
+//       |     |     |     |
+//       |   DynamicUpdateSlice
+//        \        /
+//          Tuple
+//
+TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
+  EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
+}
+
 }  // namespace
 
 }  // namespace xla
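The two tests above differ only in whether GTE(1) has a second user. The property they pin down is the operand-slot check added in liveness_util.cc below; condensed, the rule looks like this (helper name hypothetical, not part of this change):

    #include <vector>
    #include "tensorflow/compiler/xla/service/hlo_instruction.h"

    // 'user' is a kDynamicUpdateSlice consuming 'operand'. Buffer sharing is
    // only safe when 'operand' is used exactly once, as the in-place
    // destination at slot 0: DynamicUpdateSlice(data /*0*/, update /*1*/,
    // start_indices /*2*/).
    bool IsSoleInPlaceDusDestination(xla::HloInstruction* operand,
                                     xla::HloInstruction* user) {
      std::vector<xla::int64> operand_indices = user->OperandIndices(operand);
      return operand_indices.size() == 1 && operand_indices[0] == 0;
    }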
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f9698a06747..9b7aa7c860b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -283,14 +283,7 @@ bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment,
     return false;
   }
   auto* operand = fusion->operand(fusion_operand->parameter_number());
-
-  BufferAllocation::Slice operand_slice =
-      assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie();
-
-  BufferAllocation::Slice fusion_slice =
-      assignment.GetUniqueTopLevelSlice(fusion).ConsumeValueOrDie();
-
-  return operand_slice == fusion_slice;
+  return assignment.SharesSliceAtIndex(fusion, {}, operand, index);
 }
 
 }  // namespace
@@ -387,9 +380,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
 
   // Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0.
-  ShapeIndex index_unused;
-  auto* fusion_operand =
-      LatestNonGteAncestorAndIndex(root->operand(0), &index_unused);
+  auto* fusion_operand = LatestNonGteAncestor(root->operand(0));
   CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode());
 
   // Operand(0) the input array which shares an allocation with the output.
diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc
index 7d157e8fd5f..caaf56a5516 100644
--- a/tensorflow/compiler/xla/service/liveness_util.cc
+++ b/tensorflow/compiler/xla/service/liveness_util.cc
@@ -106,6 +106,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
 // *) Is a loop fusion instruction where the only use of 'operand' at 'index'
 //    in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
 //    at operand 0.
+// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
 bool CanShareOperandBufferWithUser(
     HloInstruction* operand, const ShapeIndex& operand_index,
     HloInstruction* user, const ShapeIndex& user_index,
@@ -143,6 +144,11 @@ bool CanShareOperandBufferWithUser(
       break;
     }
     return false;
+  } else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
+    // We eliminated other users in BufferLiveness::live_range_strictly_before,
+    // so here we just need to check that the use is at operand index 0.
+    std::vector<int64> operand_indices = user->OperandIndices(operand);
+    return operand_indices.size() == 1 && operand_indices[0] == 0;
   }
   // Check if 'user' is element-wise.
   return user->IsElementwise();
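Why operand 0 is the interesting slot: DynamicUpdateSlice(data, update, start_indices) semantically copies 'data' and overwrites the window [start, start + update_size) with 'update', so reusing the 'data' buffer for the result is sound only when the DUS is the sole remaining user of 'data'. A plain-C++ analogue of the in-place form, purely illustrative:

    #include <cstddef>
    #include <vector>

    // In-place analogue of a 1-D DynamicUpdateSlice: overwrite
    // data[start, start + update.size()) and reuse 'data' as the result.
    // Sound here because the caller passes 'data' by value, so no other
    // reader can observe the mutation.
    std::vector<float> DynamicUpdateSliceInPlace(
        std::vector<float> data, const std::vector<float>& update,
        size_t start) {
      for (size_t i = 0; i < update.size(); ++i) {
        data[start + i] = update[i];
      }
      return data;
    }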