[XLA] Fix for a CHECK fail when we disallow a use in the alternate mem space.

The latest buffer in the allocation sequence might not be the one in the
default memory space. We need to search the previous allocations for a buffer
that is in the default memory space.
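
A minimal standalone sketch of the search pattern (illustration only: the
simplified Allocation struct, its fields, and the helper name below are
stand-ins, not the actual XLA types or API):

#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>

enum class MemorySpace { kDefault, kAlternate };

struct Allocation {
  MemorySpace memory_space;  // Stand-in for the real memory_space() accessor.
  int end_time;
  void Extend(int time) { end_time = std::max(end_time, time); }
};

// Extend the lifetime of the most recent allocation that lives in the default
// memory space; it is not necessarily allocation_sequence.back().
void ExtendLatestDefaultMemAllocation(
    std::vector<std::unique_ptr<Allocation>>& allocation_sequence,
    int start_time) {
  auto it = std::find_if(
      allocation_sequence.rbegin(), allocation_sequence.rend(),
      [](const std::unique_ptr<Allocation>& allocation) {
        return allocation->memory_space == MemorySpace::kDefault;
      });
  assert(it != allocation_sequence.rend());  // The real fix uses CHECK().
  (*it)->Extend(start_time);
}

The actual change (diff below) additionally requires the found allocation's
defining_position() to match the position of the segment being allocated.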

PiperOrigin-RevId: 348106302
Change-Id: If524183bdef29ebf4e7822eef125eb135d279332
Berkin Ilbeyi 2020-12-17 15:34:35 -08:00 committed by TensorFlower Gardener
parent f806d113c0
commit c25e6c9976
2 changed files with 86 additions and 11 deletions

@@ -2005,17 +2005,25 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
   if (required_assignment_at_start) {
     if (!allocation_sequence->empty()) {
-      const auto& prev_allocation = allocation_sequence->back();
-      CHECK(prev_allocation->memory_space() ==
-            required_assignment_at_start->memory_space);
-      if (required_assignment_at_start->memory_space ==
-          MemorySpace::kAlternate) {
-        // We expect the required assignment offset to match the offset of the
-        // previous allocation.
-        CHECK_EQ(GetAliasedOffset(*prev_allocation),
-                 required_assignment_at_start->offset);
-      }
-      prev_allocation->Extend(request.start_time);
+      // We shouldn't have a situation where the required assignment at start
+      // is at alternate memory space and we have existing allocations in the
+      // allocation sequence. The only time we'll have required assignment at
+      // start to be in the alternate memory space is in called computations
+      // (e.g., while body) and we shouldn't have any allocations in the
+      // allocation sequence so far.
+      CHECK(required_assignment_at_start->memory_space ==
+            MemorySpace::kDefault);
+      // Find the previous allocation in default memory (might not be the very
+      // last one) and extend its lifetime to include the start time of this
+      // segment.
+      auto prev_allocation_in_default_mem_it = std::find_if(
+          allocation_sequence->rbegin(), allocation_sequence->rend(),
+          [&](const auto& allocation) {
+            return allocation->memory_space() == MemorySpace::kDefault &&
+                   allocation->defining_position() == defining_position;
+          });
+      CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
+      (*prev_allocation_in_default_mem_it)->Extend(request.start_time);
     } else {
       absl::optional<Chunk> aliased_chunk = absl::nullopt;
       if (required_assignment_at_start->memory_space ==
           MemorySpace::kAlternate) {

@@ -4221,6 +4221,73 @@ TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
                     options);
 }
 
+TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
+  // Test for situations where we disallow a use (tanh in this case) in the
+  // alternate memory space and there is a subsequent use that also requires
+  // the buffer to be in the default memory space. In this case, the
+  // allocation in the default memory space might not be the very last one, so
+  // we need to search the allocation sequence and find the one in the default
+  // memory space.
+  absl::string_view hlo_string = R"(
+  HloModule module, is_scheduled=true
+
+  while_cond {
+    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    ROOT gte = pred[] get-tuple-element(p0), index=3
+  }
+
+  while_body {
+    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
+    gte0 = f32[3]{0} get-tuple-element(p0), index=0
+    gte1 = f32[3]{0} get-tuple-element(p0), index=1
+    gte2 = f32[3]{0} get-tuple-element(p0), index=2
+    gte3 = pred[] get-tuple-element(p0), index=3
+    add = f32[3]{0} add(gte0, gte0)
+    negate0 = f32[3]{0} negate(add)
+    negate1 = f32[3]{0} negate(negate0)
+    negate2 = f32[3]{0} negate(negate1)
+    negate3 = f32[3]{0} negate(negate2)
+    negate4 = f32[3]{0} negate(negate3)
+    negate5 = f32[3]{0} negate(negate4)
+    negate6 = f32[3]{0} negate(negate5)
+    negate7 = f32[3]{0} negate(negate6)
+    negate8 = f32[3]{0} negate(negate7)
+    negate9 = f32[3]{0} negate(negate8)
+    negate10 = f32[3]{0} negate(negate9)
+    negate11 = f32[3]{0} negate(negate10)
+    negate12 = f32[3]{0} negate(negate11)
+    negate13 = f32[3]{0} negate(negate12)
+    negate14 = f32[3]{0} negate(negate13)
+    negate15 = f32[3]{0} negate(gte2)
+    tanh = f32[3]{0} tanh(gte2)
+    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy0 = f32[3]{0} copy(p0)
+    copy1 = f32[3]{0} copy(p0)
+    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
+    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
+    ROOT gte = f32[3]{0} get-tuple-element(while), index=2
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  MemorySpaceAssignment::Options options;
+  options.max_size_in_bytes = 128;
+  options.alignment_in_bytes = 8;
+  options.verify = true;
+  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
+    return use.instruction->opcode() != HloOpcode::kTanh;
+  };
+  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
+                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
+                    options);
+}
+
 TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
   // Tests against a bug where the root of entry computation is a bitcast
   // instruction and it ends up getting an allocation in the alternate memory.