[XLA] Fix a CHECK failure when a use is disallowed in the alternate memory space.
The latest buffer in the allocation sequence might not be the one that is in the default memory space. We need to search the previous allocations for a buffer that is in the default memory space. PiperOrigin-RevId: 348106302 Change-Id: If524183bdef29ebf4e7822eef125eb135d279332
This commit is contained in:
parent
f806d113c0
commit
c25e6c9976
tensorflow/compiler/xla/service
@ -2005,17 +2005,25 @@ AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
|
||||
|
||||
if (required_assignment_at_start) {
|
||||
if (!allocation_sequence->empty()) {
|
||||
const auto& prev_allocation = allocation_sequence->back();
|
||||
CHECK(prev_allocation->memory_space() ==
|
||||
required_assignment_at_start->memory_space);
|
||||
if (required_assignment_at_start->memory_space ==
|
||||
MemorySpace::kAlternate) {
|
||||
// We expect the required assignment offset to match the offset of the
|
||||
// previous allocation.
|
||||
CHECK_EQ(GetAliasedOffset(*prev_allocation),
|
||||
required_assignment_at_start->offset);
|
||||
}
|
||||
prev_allocation->Extend(request.start_time);
|
||||
// We shouldn't have a situation where the required assignment at start is
|
||||
// at alternate memory space and we have existing allocations in the
|
||||
// allocation sequence. The only time we'll have required assignment at
|
||||
// start to be in the alternate memory space is in called computations
|
||||
// (e.g., while body) and we shouldn't have any allocations in the
|
||||
// allocation sequence so far.
|
||||
CHECK(required_assignment_at_start->memory_space ==
|
||||
MemorySpace::kDefault);
|
||||
// Find the previous allocation in default memory (might not be the very
|
||||
// last one) and extend its lifetime to include the start time of this
|
||||
// segment.
|
||||
auto prev_allocation_in_default_mem_it = std::find_if(
|
||||
allocation_sequence->rbegin(), allocation_sequence->rend(),
|
||||
[&](const auto& allocation) {
|
||||
return allocation->memory_space() == MemorySpace::kDefault &&
|
||||
allocation->defining_position() == defining_position;
|
||||
});
|
||||
CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
|
||||
(*prev_allocation_in_default_mem_it)->Extend(request.start_time);
|
||||
} else {
|
||||
absl::optional<Chunk> aliased_chunk = absl::nullopt;
|
||||
if (required_assignment_at_start->memory_space ==
|
||||
|
@ -4221,6 +4221,73 @@ TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
|
||||
options);
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
|
||||
// Test for situations where we disallow a use (tanh in this case) in the
|
||||
// alternate memory space and there is a subsequent use that also requires the
|
||||
// buffer to be in the default memory space. In this case, the allocation in
|
||||
// the default memory space might not be the very last one, so we need to
|
||||
// search the allocation sequence and find the one in the default memory
|
||||
// space.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule module, is_scheduled=true
|
||||
|
||||
while_cond {
|
||||
p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
|
||||
ROOT gte = pred[] get-tuple-element(p0), index=3
|
||||
}
|
||||
|
||||
while_body {
|
||||
p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
|
||||
gte0 = f32[3]{0} get-tuple-element(p0), index=0
|
||||
gte1 = f32[3]{0} get-tuple-element(p0), index=1
|
||||
gte2 = f32[3]{0} get-tuple-element(p0), index=2
|
||||
gte3 = pred[] get-tuple-element(p0), index=3
|
||||
add = f32[3]{0} add(gte0, gte0)
|
||||
negate0 = f32[3]{0} negate(add)
|
||||
negate1 = f32[3]{0} negate(negate0)
|
||||
negate2 = f32[3]{0} negate(negate1)
|
||||
negate3 = f32[3]{0} negate(negate2)
|
||||
negate4 = f32[3]{0} negate(negate3)
|
||||
negate5 = f32[3]{0} negate(negate4)
|
||||
negate6 = f32[3]{0} negate(negate5)
|
||||
negate7 = f32[3]{0} negate(negate6)
|
||||
negate8 = f32[3]{0} negate(negate7)
|
||||
negate9 = f32[3]{0} negate(negate8)
|
||||
negate10 = f32[3]{0} negate(negate9)
|
||||
negate11 = f32[3]{0} negate(negate10)
|
||||
negate12 = f32[3]{0} negate(negate11)
|
||||
negate13 = f32[3]{0} negate(negate12)
|
||||
negate14 = f32[3]{0} negate(negate13)
|
||||
negate15 = f32[3]{0} negate(gte2)
|
||||
tanh = f32[3]{0} tanh(gte2)
|
||||
ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
p0 = f32[3]{0} parameter(0)
|
||||
p1 = pred[] parameter(1)
|
||||
copy0 = f32[3]{0} copy(p0)
|
||||
copy1 = f32[3]{0} copy(p0)
|
||||
tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
|
||||
while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT gte = f32[3]{0} get-tuple-element(while), index=2
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
MemorySpaceAssignment::Options options;
|
||||
options.max_size_in_bytes = 128;
|
||||
options.alignment_in_bytes = 8;
|
||||
options.verify = true;
|
||||
options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
|
||||
return use.instruction->opcode() != HloOpcode::kTanh;
|
||||
};
|
||||
AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
|
||||
/*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
|
||||
options);
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
|
||||
// Tests against a bug where the root of entry computation is a bitcast
|
||||
// instruction and it ends up getting an allocation in the alternate memory.
|
||||
|
Loading…
Reference in New Issue
Block a user