diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index e8fabc1d8f7..3e9daa96150 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1008,7 +1008,22 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { // Try allocate same buffer for dynamic update slice's operand and output. - // + + // If memory_space_assignment is run and there is information about a color in + // preset assignments, don't merge those buffers. We expect + // memory_space_assignment to have merged these buffers. If + // memory_space_assignment didn't merge these buffers and have assigned + // different offsets to the operand and the output buffer, merging the buffers + // can cause memory corruption if memory_space_assignment assigned a different + // buffer at the same offset. + absl::flat_hash_set excluded_colors; + if (preset_assignments_) { + for (const auto& color_and_info : + preset_assignments_->assignment_informations()) { + excluded_colors.insert(color_and_info.first); + } + } + // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule // to operations that can be done in place. for (HloComputation* computation : assignment->module().computations()) { @@ -1039,6 +1054,13 @@ Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { assignment->alias_analysis().GetUniqueBufferAt( instruction->operand(0), {}); + // The instruction or operand color is excluded because it was assigned by + // memory_space_assignment. + if (excluded_colors.contains(instruction_buffer.color().value()) || + excluded_colors.contains(operand_buffer.color().value())) { + continue; + } + // Already have the same buffer. No need to merge those. if (instruction_buffer.id() == operand_buffer.id()) { continue;