[XLA] Don't merge buffers with colors that are assigned by memory space assignment
PiperOrigin-RevId: 302980440 Change-Id: I41f9d62a1ce3b107587ab97c4267788f68350328
This commit is contained in:
parent
f0e6060b43
commit
43121ac128
|
@ -1008,7 +1008,22 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
|
|||
|
||||
Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
// Try allocate same buffer for dynamic update slice's operand and output.
|
||||
//
|
||||
|
||||
// If memory_space_assignment is run and there is information about a color in
|
||||
// preset assignments, don't merge those buffers. We expect
|
||||
// memory_space_assignment to have merged these buffers. If
|
||||
// memory_space_assignment didn't merge these buffers and have assigned
|
||||
// different offsets to the operand and the output buffer, merging the buffers
|
||||
// can cause memory corruption if memory_space_assignment assigned a different
|
||||
// buffer at the same offset.
|
||||
absl::flat_hash_set<int64> excluded_colors;
|
||||
if (preset_assignments_) {
|
||||
for (const auto& color_and_info :
|
||||
preset_assignments_->assignment_informations()) {
|
||||
excluded_colors.insert(color_and_info.first);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yunxing): Moving this logic to alias analysis and add must-alias rule
|
||||
// to operations that can be done in place.
|
||||
for (HloComputation* computation : assignment->module().computations()) {
|
||||
|
@ -1039,6 +1054,13 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
|||
assignment->alias_analysis().GetUniqueBufferAt(
|
||||
instruction->operand(0), {});
|
||||
|
||||
// The instruction or operand color is excluded because it was assigned by
|
||||
// memory_space_assignment.
|
||||
if (excluded_colors.contains(instruction_buffer.color().value()) ||
|
||||
excluded_colors.contains(operand_buffer.color().value())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Already have the same buffer. No need to merge those.
|
||||
if (instruction_buffer.id() == operand_buffer.id()) {
|
||||
continue;
|
||||
|
|
Loading…
Reference in New Issue