[XLA] When disallowing buffers from alternate mem allocation, also check its colocations.
PiperOrigin-RevId: 327123032 Change-Id: I9ec3ba91c29f40e8089b79b0eaeed7d8848ea07f
This commit is contained in:
parent
66d54d7de2
commit
a094af6dec
@ -17,21 +17,22 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
namespace {
|
||||||
const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
|
|
||||||
|
bool IsValueAllowedInAlternateMemory(const HloValue* value) {
|
||||||
// If the buffer is a tuple, don't use this algorithm for now. The buffers
|
// If the buffer is a tuple, don't use this algorithm for now. The buffers
|
||||||
// that are pointed to by the tuple will still use this algorithm. Because
|
// that are pointed to by the tuple will still use this algorithm. Because
|
||||||
// tuples are cheap to place in the alternate memory (they are just pointers)
|
// tuples are cheap to place in the alternate memory (they are just pointers)
|
||||||
// we don't need to use prefetch/evict logic.
|
// we don't need to use prefetch/evict logic.
|
||||||
if (interval.buffer->shape().IsTuple()) {
|
if (value->shape().IsTuple()) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it is a tuple.";
|
<< " in default mem because it is a tuple.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Don't place scalars in the alternate memory.
|
// Don't place scalars in the alternate memory.
|
||||||
if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
|
if (ShapeUtil::IsEffectiveScalar(value->shape())) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it is a scalar.";
|
<< " in default mem because it is a scalar.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -44,10 +45,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
|||||||
// allocate TupleSelect in the alternate memory space.
|
// allocate TupleSelect in the alternate memory space.
|
||||||
// TODO(berkin): Not allocating add-dependencies either since they need to be
|
// TODO(berkin): Not allocating add-dependencies either since they need to be
|
||||||
// treated specially. We should revisit this later.
|
// treated specially. We should revisit this later.
|
||||||
for (const HloPosition& position : interval.buffer->positions()) {
|
for (const HloPosition& position : value->positions()) {
|
||||||
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
|
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
|
||||||
position.instruction->opcode() == HloOpcode::kAddDependency) {
|
position.instruction->opcode() == HloOpcode::kAddDependency) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it has a tuple-select or "
|
<< " in default mem because it has a tuple-select or "
|
||||||
<< "add-dependency position.";
|
<< "add-dependency position.";
|
||||||
return false;
|
return false;
|
||||||
@ -56,18 +57,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
|||||||
|
|
||||||
// Send and Recv HLOs return a request identifier. These should not be
|
// Send and Recv HLOs return a request identifier. These should not be
|
||||||
// allocated in the alternate memory.
|
// allocated in the alternate memory.
|
||||||
for (const HloPosition& position : interval.buffer->positions()) {
|
for (const HloPosition& position : value->positions()) {
|
||||||
if ((position.instruction->opcode() == HloOpcode::kSend ||
|
if ((position.instruction->opcode() == HloOpcode::kSend ||
|
||||||
position.instruction->opcode() == HloOpcode::kRecv)) {
|
position.instruction->opcode() == HloOpcode::kRecv)) {
|
||||||
// TODO(berkin): Send/recv buffers need a stable buffer allocation
|
// TODO(berkin): Send/recv buffers need a stable buffer allocation
|
||||||
// throughout sending/receiving. Disable memory space allocation for these
|
// throughout sending/receiving. Disable memory space allocation for these
|
||||||
// for now.
|
// for now.
|
||||||
if (position.index == ShapeIndex({0})) {
|
if (position.index == ShapeIndex({0})) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it is a send/recv buffer.";
|
<< " in default mem because it is a send/recv buffer.";
|
||||||
return false;
|
return false;
|
||||||
} else if (position.index == ShapeIndex({1})) {
|
} else if (position.index == ShapeIndex({1})) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it is a request identifier for "
|
<< " in default mem because it is a request identifier for "
|
||||||
"send/recv.";
|
"send/recv.";
|
||||||
return false;
|
return false;
|
||||||
@ -78,11 +79,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
|||||||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
|
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
|
||||||
// Disable memory space allocation for these for now.
|
// Disable memory space allocation for these for now.
|
||||||
if (position.index == ShapeIndex({0})) {
|
if (position.index == ShapeIndex({0})) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it is a collective-permute buffer.";
|
<< " in default mem because it is a collective-permute buffer.";
|
||||||
return false;
|
return false;
|
||||||
} else if (position.index == ShapeIndex({1})) {
|
} else if (position.index == ShapeIndex({1})) {
|
||||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
VLOG(4) << "Keeping value " << value->ToShortString()
|
||||||
<< " in default mem because it is a collective-permute buffer.";
|
<< " in default mem because it is a collective-permute buffer.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -92,4 +93,12 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
|
||||||
|
const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval) {
|
||||||
|
return IsValueAllowedInAlternateMemory(interval.buffer) &&
|
||||||
|
absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user