[XLA] Don't allocate request identifiers to alternate mem.

PiperOrigin-RevId: 289733771
Change-Id: Ib1a0324648952a4ea88be91890de568d34456018
This commit is contained in:
Berkin Ilbeyi 2020-01-14 14:40:14 -08:00 committed by TensorFlower Gardener
parent 1b7514a74b
commit 230ebd5d96
4 changed files with 88 additions and 32 deletions

View File

@ -1375,8 +1375,8 @@ Status BufferAssigner::AssignPresetBuffers(
const HeapSimulator::Chunk& chunk = position_and_chunk.second;
auto preset_allocations_iter = preset_allocations.find(value.color());
CHECK(preset_allocations_iter != preset_allocations.end())
<< "No preset value allocation for color " << value.color()
<< " found.";
<< "No preset value allocation for color " << value.color() << " for "
<< value.ToShortString() << " found.";
preset_allocations_iter->second->AddAssignment(value, chunk.offset,
chunk.size);

View File

@ -258,6 +258,51 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
return colocated_intervals;
}
bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
const BufferInterval& interval) const {
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm. Because
// tuples are cheap to place in the alternate memory (they are just pointers)
// we don't need to use prefetch/evict logic.
if (interval.buffer->shape().IsTuple()) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a tuple.";
return false;
}
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
// buffer, but just forwards the buffers in the either left or right side.
// This means the the two different inputs to TupleSelect must not alias, yet
// they should be allocated in the same memory space, and both buffers must be
// kept alive for the entire live range of TupleSelect. Instead, just don't
// allocate TupleSelect in the alternate memory space.
// TODO(berkin): Not allocating add-dependencies either since they need to be
// treated specially. We should revisit this later.
for (const HloPosition& position : interval.buffer->positions()) {
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
position.instruction->opcode() == HloOpcode::kAddDependency) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it has a tuple-select or "
<< "add-dependency position.";
return false;
}
}
// Send and Recv HLOs return a request identifier. These should not be
// allocated in the alternate memory.
const HloPosition& defining_position = interval.buffer->defining_position();
if ((defining_position.instruction->opcode() == HloOpcode::kSend ||
defining_position.instruction->opcode() == HloOpcode::kRecv) &&
defining_position.index == ShapeIndex({1})) {
VLOG(4)
<< "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a request identifier for send/recv.";
return false;
}
return true;
}
HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
std::vector<BufferInterval> sorted_buffer_intervals =
GetSortedBufferIntervals();
@ -279,36 +324,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm. Because
// tuples are cheap to place in the alternate memory (they are just
// pointers) we don't need to use prefetch/evict logic.
if (interval.buffer->shape().IsTuple()) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a tuple.";
continue;
}
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
// buffer, but just forwards the buffers in the either left or right side.
// This means the the two different inputs to TupleSelect must not alias,
// yet they should be allocated in the same memory space, and both buffers
// must be kept alive for the entire live range of TupleSelect. Instead,
// just don't allocate TupleSelect in the alternate memory space.
// TODO(berkin): Not allocating add-dependencies either since they need to
// be treated specially. We should revisit this later.
bool keep_in_default_mem = false;
for (const HloPosition& position : interval.buffer->positions()) {
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
position.instruction->opcode() == HloOpcode::kAddDependency) {
keep_in_default_mem = true;
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it has a tuple-select or "
<< "add-dependency position.";
break;
}
}
if (keep_in_default_mem) {
if (!IsIntervalAllowedInAlternateMemory(interval)) {
continue;
}

View File

@ -621,6 +621,10 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// it is a parameter in default memory or an ouput in default memory.
bool RequiredInDefaultMemory(const HloValue* buffer, int64 time) const;
// Returns true if this buffer is allowed to be placed in the alternate
// memory.
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
// Finds an allocation for the given interval. Internally, it will attempt to
// find a suitable chunk candidate within the heap size and prefetch interval
// limits, and append the new allocation(s) to allocations. The new

View File

@ -1268,6 +1268,42 @@ TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
AssignMemorySpace(module.get());
}
TEST_P(MemorySpaceAssignmentTest,
RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
// Ensure that request identifier returned by Send/Recv HLOs are not allocated
// in the alternate memory.
absl::string_view hlo_string = R"(
HloModule SendRecv, is_scheduled=true
ENTRY %AddDependency (p: f32[3]) -> f32[3] {
%p = f32[3]{0} parameter(0)
%after-all = token[] after-all()
%recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
%recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
%token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
%data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
%send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
%send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
AssignMemorySpace(module.get());
for (const HloInstruction* instruction :
module->entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kRecv) {
const Shape& request_identifier_shape =
ShapeUtil::GetSubshape(instruction->shape(), {1});
EXPECT_NE(request_identifier_shape.layout().memory_space(),
kAlternateMemorySpace);
}
}
}
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
// Test that checks the last use optimization. It uses two buffers that should
// be placed in alternate memory.