From a094af6decab7c1689cc40f9497d1b6242a833e6 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Mon, 17 Aug 2020 16:26:40 -0700 Subject: [PATCH] [XLA] When disallowing buffers from alternate mem allocation, also check its colocations. PiperOrigin-RevId: 327123032 Change-Id: I9ec3ba91c29f40e8089b79b0eaeed7d8848ea07f --- .../service/memory_space_assignment_utils.cc | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc index 1f7b9dbadbc..7bb559979e6 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -17,21 +17,22 @@ limitations under the License. namespace xla { -bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { +namespace { + +bool IsValueAllowedInAlternateMemory(const HloValue* value) { // If the buffer is a tuple, don't use this algorithm for now. The buffers // that are pointed to by the tuple will still use this algorithm. Because // tuples are cheap to place in the alternate memory (they are just pointers) // we don't need to use prefetch/evict logic. - if (interval.buffer->shape().IsTuple()) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (value->shape().IsTuple()) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a tuple."; return false; } // Don't place scalars in the alternate memory. - if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (ShapeUtil::IsEffectiveScalar(value->shape())) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a scalar."; return false; } @@ -44,10 +45,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // allocate TupleSelect in the alternate memory space. // TODO(berkin): Not allocating add-dependencies either since they need to be // treated specially. We should revisit this later. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if (position.instruction->opcode() == HloOpcode::kTupleSelect || position.instruction->opcode() == HloOpcode::kAddDependency) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it has a tuple-select or " << "add-dependency position."; return false; @@ -56,18 +57,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // Send and Recv HLOs return a request identifier. These should not be // allocated in the alternate memory. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if ((position.instruction->opcode() == HloOpcode::kSend || position.instruction->opcode() == HloOpcode::kRecv)) { // TODO(berkin): Send/recv buffers need a stable buffer allocation // throughout sending/receiving. Disable memory space allocation for these // for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a send/recv buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a request identifier for " "send/recv."; return false; @@ -78,11 +79,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { // Disable memory space allocation for these for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } @@ -92,4 +93,12 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( return true; } +} // namespace + +bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { + return IsValueAllowedInAlternateMemory(interval.buffer) && + absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory); +} + } // namespace xla