diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc index 1f7b9dbadbc..7bb559979e6 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -17,21 +17,22 @@ limitations under the License. namespace xla { -bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { +namespace { + +bool IsValueAllowedInAlternateMemory(const HloValue* value) { // If the buffer is a tuple, don't use this algorithm for now. The buffers // that are pointed to by the tuple will still use this algorithm. Because // tuples are cheap to place in the alternate memory (they are just pointers) // we don't need to use prefetch/evict logic. - if (interval.buffer->shape().IsTuple()) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (value->shape().IsTuple()) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a tuple."; return false; } // Don't place scalars in the alternate memory. - if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (ShapeUtil::IsEffectiveScalar(value->shape())) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a scalar."; return false; } @@ -44,10 +45,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // allocate TupleSelect in the alternate memory space. // TODO(berkin): Not allocating add-dependencies either since they need to be // treated specially. We should revisit this later. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if (position.instruction->opcode() == HloOpcode::kTupleSelect || position.instruction->opcode() == HloOpcode::kAddDependency) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it has a tuple-select or " << "add-dependency position."; return false; @@ -56,18 +57,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // Send and Recv HLOs return a request identifier. These should not be // allocated in the alternate memory. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if ((position.instruction->opcode() == HloOpcode::kSend || position.instruction->opcode() == HloOpcode::kRecv)) { // TODO(berkin): Send/recv buffers need a stable buffer allocation // throughout sending/receiving. Disable memory space allocation for these // for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a send/recv buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a request identifier for " "send/recv."; return false; @@ -78,11 +79,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { // Disable memory space allocation for these for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } @@ -92,4 +93,12 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( return true; } +} // namespace + +bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { + return IsValueAllowedInAlternateMemory(interval.buffer) && + absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory); +} + } // namespace xla