[XLA] When disallowing buffers from alternate mem allocation, also check its colocations.
PiperOrigin-RevId: 327123032 Change-Id: I9ec3ba91c29f40e8089b79b0eaeed7d8848ea07f
This commit is contained in:
parent
66d54d7de2
commit
a094af6dec
@ -17,21 +17,22 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||
const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
|
||||
namespace {
|
||||
|
||||
bool IsValueAllowedInAlternateMemory(const HloValue* value) {
|
||||
// If the buffer is a tuple, don't use this algorithm for now. The buffers
|
||||
// that are pointed to by the tuple will still use this algorithm. Because
|
||||
// tuples are cheap to place in the alternate memory (they are just pointers)
|
||||
// we don't need to use prefetch/evict logic.
|
||||
if (interval.buffer->shape().IsTuple()) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
if (value->shape().IsTuple()) {
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a tuple.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Don't place scalars in the alternate memory.
|
||||
if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
if (ShapeUtil::IsEffectiveScalar(value->shape())) {
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a scalar.";
|
||||
return false;
|
||||
}
|
||||
@ -44,10 +45,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||
// allocate TupleSelect in the alternate memory space.
|
||||
// TODO(berkin): Not allocating add-dependencies either since they need to be
|
||||
// treated specially. We should revisit this later.
|
||||
for (const HloPosition& position : interval.buffer->positions()) {
|
||||
for (const HloPosition& position : value->positions()) {
|
||||
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
|
||||
position.instruction->opcode() == HloOpcode::kAddDependency) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it has a tuple-select or "
|
||||
<< "add-dependency position.";
|
||||
return false;
|
||||
@ -56,18 +57,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||
|
||||
// Send and Recv HLOs return a request identifier. These should not be
|
||||
// allocated in the alternate memory.
|
||||
for (const HloPosition& position : interval.buffer->positions()) {
|
||||
for (const HloPosition& position : value->positions()) {
|
||||
if ((position.instruction->opcode() == HloOpcode::kSend ||
|
||||
position.instruction->opcode() == HloOpcode::kRecv)) {
|
||||
// TODO(berkin): Send/recv buffers need a stable buffer allocation
|
||||
// throughout sending/receiving. Disable memory space allocation for these
|
||||
// for now.
|
||||
if (position.index == ShapeIndex({0})) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a send/recv buffer.";
|
||||
return false;
|
||||
} else if (position.index == ShapeIndex({1})) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a request identifier for "
|
||||
"send/recv.";
|
||||
return false;
|
||||
@ -78,11 +79,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
|
||||
// Disable memory space allocation for these for now.
|
||||
if (position.index == ShapeIndex({0})) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a collective-permute buffer.";
|
||||
return false;
|
||||
} else if (position.index == ShapeIndex({1})) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||
<< " in default mem because it is a collective-permute buffer.";
|
||||
return false;
|
||||
}
|
||||
@ -92,4 +93,12 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||
const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
|
||||
return IsValueAllowedInAlternateMemory(interval.buffer) &&
|
||||
absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user