[XLA] Don't allocate request identifiers to alternate mem.
PiperOrigin-RevId: 289733771 Change-Id: Ib1a0324648952a4ea88be91890de568d34456018
This commit is contained in:
parent
1b7514a74b
commit
230ebd5d96
@ -1375,8 +1375,8 @@ Status BufferAssigner::AssignPresetBuffers(
|
||||
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
|
||||
auto preset_allocations_iter = preset_allocations.find(value.color());
|
||||
CHECK(preset_allocations_iter != preset_allocations.end())
|
||||
<< "No preset value allocation for color " << value.color()
|
||||
<< " found.";
|
||||
<< "No preset value allocation for color " << value.color() << " for "
|
||||
<< value.ToShortString() << " found.";
|
||||
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
|
||||
chunk.size);
|
||||
|
||||
|
@ -258,6 +258,51 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
|
||||
return colocated_intervals;
|
||||
}
|
||||
|
||||
bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
|
||||
const BufferInterval& interval) const {
|
||||
// If the buffer is a tuple, don't use this algorithm for now. The buffers
|
||||
// that are pointed to by the tuple will still use this algorithm. Because
|
||||
// tuples are cheap to place in the alternate memory (they are just pointers)
|
||||
// we don't need to use prefetch/evict logic.
|
||||
if (interval.buffer->shape().IsTuple()) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it is a tuple.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
|
||||
// buffer, but just forwards the buffers in the either left or right side.
|
||||
// This means the the two different inputs to TupleSelect must not alias, yet
|
||||
// they should be allocated in the same memory space, and both buffers must be
|
||||
// kept alive for the entire live range of TupleSelect. Instead, just don't
|
||||
// allocate TupleSelect in the alternate memory space.
|
||||
// TODO(berkin): Not allocating add-dependencies either since they need to be
|
||||
// treated specially. We should revisit this later.
|
||||
for (const HloPosition& position : interval.buffer->positions()) {
|
||||
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
|
||||
position.instruction->opcode() == HloOpcode::kAddDependency) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it has a tuple-select or "
|
||||
<< "add-dependency position.";
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Send and Recv HLOs return a request identifier. These should not be
|
||||
// allocated in the alternate memory.
|
||||
const HloPosition& defining_position = interval.buffer->defining_position();
|
||||
if ((defining_position.instruction->opcode() == HloOpcode::kSend ||
|
||||
defining_position.instruction->opcode() == HloOpcode::kRecv) &&
|
||||
defining_position.index == ShapeIndex({1})) {
|
||||
VLOG(4)
|
||||
<< "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it is a request identifier for send/recv.";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
std::vector<BufferInterval> sorted_buffer_intervals =
|
||||
GetSortedBufferIntervals();
|
||||
@ -279,36 +324,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If the buffer is a tuple, don't use this algorithm for now. The buffers
|
||||
// that are pointed to by the tuple will still use this algorithm. Because
|
||||
// tuples are cheap to place in the alternate memory (they are just
|
||||
// pointers) we don't need to use prefetch/evict logic.
|
||||
if (interval.buffer->shape().IsTuple()) {
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it is a tuple.";
|
||||
continue;
|
||||
}
|
||||
|
||||
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
|
||||
// buffer, but just forwards the buffers in the either left or right side.
|
||||
// This means the the two different inputs to TupleSelect must not alias,
|
||||
// yet they should be allocated in the same memory space, and both buffers
|
||||
// must be kept alive for the entire live range of TupleSelect. Instead,
|
||||
// just don't allocate TupleSelect in the alternate memory space.
|
||||
// TODO(berkin): Not allocating add-dependencies either since they need to
|
||||
// be treated specially. We should revisit this later.
|
||||
bool keep_in_default_mem = false;
|
||||
for (const HloPosition& position : interval.buffer->positions()) {
|
||||
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
|
||||
position.instruction->opcode() == HloOpcode::kAddDependency) {
|
||||
keep_in_default_mem = true;
|
||||
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
|
||||
<< " in default mem because it has a tuple-select or "
|
||||
<< "add-dependency position.";
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (keep_in_default_mem) {
|
||||
if (!IsIntervalAllowedInAlternateMemory(interval)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -621,6 +621,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
// it is a parameter in default memory or an ouput in default memory.
|
||||
bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
|
||||
|
||||
// Returns true if this buffer is allowed to be placed in the alternate
|
||||
// memory.
|
||||
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
|
||||
|
||||
// Finds an allocation for the given interval. Internally, it will attempt to
|
||||
// find a suitable chunk candidate within the heap size and prefetch interval
|
||||
// limits, and append the new allocation(s) to allocations. The new
|
||||
|
@ -1268,6 +1268,42 @@ TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
|
||||
AssignMemorySpace(module.get());
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest,
|
||||
RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
|
||||
// Ensure that request identifier returned by Send/Recv HLOs are not allocated
|
||||
// in the alternate memory.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule SendRecv, is_scheduled=true
|
||||
|
||||
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
|
||||
%p = f32[3]{0} parameter(0)
|
||||
%after-all = token[] after-all()
|
||||
%recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
|
||||
%recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
|
||||
%token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
|
||||
%data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
|
||||
%send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
|
||||
%send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
|
||||
ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
AssignMemorySpace(module.get());
|
||||
|
||||
for (const HloInstruction* instruction :
|
||||
module->entry_computation()->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kSend ||
|
||||
instruction->opcode() == HloOpcode::kRecv) {
|
||||
const Shape& request_identifier_shape =
|
||||
ShapeUtil::GetSubshape(instruction->shape(), {1});
|
||||
EXPECT_NE(request_identifier_shape.layout().memory_space(),
|
||||
kAlternateMemorySpace);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
|
||||
// Test that checks the last use optimization. It uses two buffers that should
|
||||
// be placed in alternate memory.
|
||||
|
Loading…
Reference in New Issue
Block a user