Do not perform in-place dynamic update slice if the operand is read-only parameter.
PiperOrigin-RevId: 250316823
This commit is contained in:
parent
53027266f0
commit
e8b83cb7f0
@ -93,6 +93,29 @@ std::vector<int64> ColorInterferenceGraph(
|
||||
return assigned_colors;
|
||||
}
|
||||
|
||||
// If an hlo buffer contains an entry parameter, the buffer is read-only unless
|
||||
// it is aliased with an output.
|
||||
bool HloBufferIsReadOnly(const HloBuffer& buffer) {
|
||||
for (const HloValue* value : buffer.values()) {
|
||||
const HloInstruction* instruction = value->instruction();
|
||||
const HloModule* module = instruction->parent()->parent();
|
||||
const bool is_entry_parameter =
|
||||
instruction->opcode() == HloOpcode::kParameter &&
|
||||
instruction->parent() == module->entry_computation();
|
||||
|
||||
if (is_entry_parameter) {
|
||||
bool parameter_has_alias =
|
||||
module->input_output_alias_config().ParameterHasAlias(
|
||||
instruction->parameter_number(), value->index());
|
||||
// The parameter doesn't have an alias, it must be read-only.
|
||||
if (!parameter_has_alias) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status GatherComputationsByAllocationType(
|
||||
@ -902,7 +925,9 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
if (instruction->operand_count() == 0) {
|
||||
continue;
|
||||
}
|
||||
// Can't share the buffer.
|
||||
|
||||
// The operand can't share the same buffer with the user based on dataflow
|
||||
// analysis.
|
||||
if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
|
||||
instruction->mutable_operand(0), {}, instruction, {})) {
|
||||
continue;
|
||||
@ -919,6 +944,12 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not perform in-place dynamic update slice if the operand buffer is
|
||||
// read-only.
|
||||
if (HloBufferIsReadOnly(operand_buffer)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool interfere = false;
|
||||
|
||||
for (const HloValue* instruction_value : instruction_buffer.values()) {
|
||||
|
@ -1720,6 +1720,42 @@ TEST_F(BufferAssignmentTest, PeakBuffers) {
|
||||
EXPECT_THAT(peak_instructions, UnorderedElementsAre(rev, neg, concat));
|
||||
}
|
||||
|
||||
TEST_F(BufferAssignmentTest, InPlaceBuffer) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
|
||||
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
|
||||
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
|
||||
constant.2 = s32[] constant(128)
|
||||
add.5 = s32[] add(get-tuple-element.3, constant.2)
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
|
||||
HloInstruction* parameter =
|
||||
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
|
||||
HloInstruction* dus =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
|
||||
|
||||
auto buffers = RunBufferAssignment(m.get());
|
||||
|
||||
{
|
||||
const BufferAllocation& parameter_alloc =
|
||||
GetTopLevelAllocation(*buffers, parameter);
|
||||
|
||||
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
|
||||
EXPECT_NE(parameter_alloc, dus_alloc);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(BufferAssignmentTest, ConstantBuffersAreNotReused) {
|
||||
const char* hlo_text = R"(
|
||||
HloModule Module
|
||||
|
Loading…
x
Reference in New Issue
Block a user