[XLA] When disallowing buffers from alternate mem allocation, also check its colocations.

PiperOrigin-RevId: 327123032
Change-Id: I9ec3ba91c29f40e8089b79b0eaeed7d8848ea07f
This commit is contained in:
Berkin Ilbeyi 2020-08-17 16:26:40 -07:00 committed by TensorFlower Gardener
parent 66d54d7de2
commit a094af6dec

View File

@ -17,21 +17,22 @@ limitations under the License.
namespace xla {
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
namespace {
bool IsValueAllowedInAlternateMemory(const HloValue* value) {
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm. Because
// tuples are cheap to place in the alternate memory (they are just pointers)
// we don't need to use prefetch/evict logic.
if (interval.buffer->shape().IsTuple()) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
if (value->shape().IsTuple()) {
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a tuple.";
return false;
}
// Don't place scalars in the alternate memory.
if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
if (ShapeUtil::IsEffectiveScalar(value->shape())) {
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a scalar.";
return false;
}
@ -44,10 +45,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
// allocate TupleSelect in the alternate memory space.
// TODO(berkin): Not allocating add-dependencies either since they need to be
// treated specially. We should revisit this later.
for (const HloPosition& position : interval.buffer->positions()) {
for (const HloPosition& position : value->positions()) {
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
position.instruction->opcode() == HloOpcode::kAddDependency) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it has a tuple-select or "
<< "add-dependency position.";
return false;
@ -56,18 +57,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
// Send and Recv HLOs return a request identifier. These should not be
// allocated in the alternate memory.
for (const HloPosition& position : interval.buffer->positions()) {
for (const HloPosition& position : value->positions()) {
if ((position.instruction->opcode() == HloOpcode::kSend ||
position.instruction->opcode() == HloOpcode::kRecv)) {
// TODO(berkin): Send/recv buffers need a stable buffer allocation
// throughout sending/receiving. Disable memory space allocation for these
// for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a send/recv buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a request identifier for "
"send/recv.";
return false;
@ -78,11 +79,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
// Disable memory space allocation for these for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
VLOG(4) << "Keeping value " << value->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
}
@ -92,4 +93,12 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
return true;
}
} // namespace
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
return IsValueAllowedInAlternateMemory(interval.buffer) &&
absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory);
}
} // namespace xla