From e8b83cb7f02f86e0a33ca3dff1b7fc45936690b5 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Tue, 28 May 2019 10:47:28 -0700 Subject: [PATCH] Do not perform in-place dynamic update slice if the operand is read-only parameter. PiperOrigin-RevId: 250316823 --- .../compiler/xla/service/buffer_assignment.cc | 33 ++++++++++++++++- .../xla/service/buffer_assignment_test.cc | 36 +++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 5cbe6c44622..74b79e8f66b 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -93,6 +93,29 @@ std::vector ColorInterferenceGraph( return assigned_colors; } +// If an hlo buffer contains an entry parameter, the buffer is read-only unless +// it is aliased with an output. +bool HloBufferIsReadOnly(const HloBuffer& buffer) { + for (const HloValue* value : buffer.values()) { + const HloInstruction* instruction = value->instruction(); + const HloModule* module = instruction->parent()->parent(); + const bool is_entry_parameter = + instruction->opcode() == HloOpcode::kParameter && + instruction->parent() == module->entry_computation(); + + if (is_entry_parameter) { + bool parameter_has_alias = + module->input_output_alias_config().ParameterHasAlias( + instruction->parameter_number(), value->index()); + // The parameter doesn't have an alias, it must be read-only. + if (!parameter_has_alias) { + return true; + } + } + } + return false; +} + } // namespace Status GatherComputationsByAllocationType( @@ -902,7 +925,9 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { if (instruction->operand_count() == 0) { continue; } - // Can't share the buffer. + + // The operand can't share the same buffer with the user based on dataflow + // analysis. if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser( instruction->mutable_operand(0), {}, instruction, {})) { continue; @@ -919,6 +944,12 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { continue; } + // Do not perform in-place dynamic update slice if the operand buffer is + // read-only. + if (HloBufferIsReadOnly(operand_buffer)) { + continue; + } + bool interfere = false; for (const HloValue* instruction_value : instruction_buffer.values()) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8837e6d9344..b23f6fcabbe 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1720,6 +1720,42 @@ TEST_F(BufferAssignmentTest, PeakBuffers) { EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat)); } +TEST_F(BufferAssignmentTest, InPlaceBuffer) { + const char* hlo_text = R"( +HloModule Module + +ENTRY main { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={} + get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 + get-tuple-element.3 = s32[] get-tuple-element(state), index=0 + constant.2 = s32[] constant(128) + add.5 = s32[] add(get-tuple-element.3, constant.2) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); + HloInstruction* parameter = + m->entry_computation()->GetInstructionWithName("get-tuple-element.4"); + HloInstruction* dus = + m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5"); + + auto buffers = RunBufferAssignment(m.get()); + + { + const BufferAllocation& parameter_alloc = + GetTopLevelAllocation(*buffers, parameter); + + const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus); + EXPECT_NE(parameter_alloc, dus_alloc); + } +} + TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) { const char* hlo_text = R"( HloModule Module