Do not perform in-place dynamic update slice if the operand is read-only parameter.

PiperOrigin-RevId: 250316823
This commit is contained in:
Yunxing Dai 2019-05-28 10:47:28 -07:00 committed by TensorFlower Gardener
parent 53027266f0
commit e8b83cb7f0
2 changed files with 68 additions and 1 deletions

View File

@ -93,6 +93,29 @@ std::vector<int64> ColorInterferenceGraph(
return assigned_colors;
}
// If an hlo buffer contains an entry parameter, the buffer is read-only unless
// it is aliased with an output.
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
for (const HloValue* value : buffer.values()) {
const HloInstruction* instruction = value->instruction();
const HloModule* module = instruction->parent()->parent();
const bool is_entry_parameter =
instruction->opcode() == HloOpcode::kParameter &&
instruction->parent() == module->entry_computation();
if (is_entry_parameter) {
bool parameter_has_alias =
module->input_output_alias_config().ParameterHasAlias(
instruction->parameter_number(), value->index());
// The parameter doesn't have an alias, it must be read-only.
if (!parameter_has_alias) {
return true;
}
}
}
return false;
}
} // namespace
Status GatherComputationsByAllocationType(
@ -902,7 +925,9 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
if (instruction->operand_count() == 0) {
continue;
}
// Can't share the buffer.
// The operand can't share the same buffer with the user based on dataflow
// analysis.
if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
instruction->mutable_operand(0), {}, instruction, {})) {
continue;
@ -919,6 +944,12 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
continue;
}
// Do not perform in-place dynamic update slice if the operand buffer is
// read-only.
if (HloBufferIsReadOnly(operand_buffer)) {
continue;
}
bool interfere = false;
for (const HloValue* instruction_value : instruction_buffer.values()) {

View File

@ -1720,6 +1720,42 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
}
TEST_F(BufferAssignmentTest, InPlaceBuffer) {
const char* hlo_text = R"(
HloModule Module
ENTRY main {
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
constant.2 = s32[] constant(128)
add.5 = s32[] add(get-tuple-element.3, constant.2)
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
HloInstruction* parameter =
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
HloInstruction* dus =
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
auto buffers = RunBufferAssignment(m.get());
{
const BufferAllocation& parameter_alloc =
GetTopLevelAllocation(*buffers, parameter);
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
EXPECT_NE(parameter_alloc, dus_alloc);
}
}
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
const char* hlo_text = R"(
HloModule Module