[XLA] Fix a spurious verification failure with nested conditionals.

When allocating alternate memory for conditionals, we assume the conditional
will always evict the value back to default memory before the end of the called
computations. In other words, the output of a conditional can never get alternate
memory allocations due to difficulties with aliasing, but the inputs can. When
verifying the correctness of memory space assignment, we have to split the
conditional uses by the called computations and find the time bounds in each
computation that the buffer in alternate memory is used. We previously didn't do
this splitting for nested conditionals, causing b/161935244. Now we do this
splitting recursively.

PiperOrigin-RevId: 322865495
Change-Id: I522ccd9e54e73ee9d3b53997fb28198651338909
This commit is contained in:
Berkin Ilbeyi 2020-07-23 14:23:06 -07:00 committed by TensorFlower Gardener
parent f300cac524
commit 18b810a970
2 changed files with 96 additions and 20 deletions

View File

@ -3001,18 +3001,23 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
}
}
if (last_use_instruction &&
last_use_instruction->opcode() == HloOpcode::kConditional) {
std::function<Status(const HloInstruction*, int64, int64,
absl::string_view)>
split_conditional_buffer;
split_conditional_buffer = [&](const HloInstruction* use_instruction,
int64 start_time, int64 end_time,
absl::string_view indent_string) {
// Special case when verifying conditional: we internally split the use
// of alternate memory in conditionals, so fish them out from the
// conditionals.
VLOG(3) << " Splitting conditional buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": ("
<< time_bound.start << ", " << time_bound.end
<< ") off: " << chunk.offset << ", size: " << chunk.size;
int64 earliest_computation_start_time = time_bound.end;
VLOG(3) << indent_string
<< "Splitting conditional buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": (" << start_time
<< ", " << end_time << ") off: " << chunk.offset
<< ", size: " << chunk.size;
int64 earliest_computation_start_time = end_time;
for (const HloComputation* called_computation :
last_use_instruction->called_computations()) {
use_instruction->called_computations()) {
earliest_computation_start_time =
std::min(earliest_computation_start_time,
hlo_live_range->computation_span_times()
@ -3020,6 +3025,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
.start);
int64 parameter_time = -1;
int64 last_use_time = -1;
const HloInstruction* last_use_instruction = nullptr;
for (const HloPosition& position : value->positions()) {
if (position.instruction->opcode() == HloOpcode::kParameter &&
position.instruction->parent() == called_computation) {
@ -3029,26 +3035,44 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
}
}
for (const HloUse& use : value->uses()) {
if (use.instruction->parent() == called_computation) {
last_use_time = std::max(
last_use_time,
hlo_live_range->instruction_schedule().at(use.instruction));
int64 use_time =
hlo_live_range->instruction_schedule().at(use.instruction);
if (use.instruction->parent() == called_computation &&
use_time > last_use_time) {
last_use_time = use_time;
last_use_instruction = use.instruction;
}
}
if (last_use_time != -1) {
CHECK_NE(parameter_time, -1);
VLOG(3) << " computation: " << called_computation->name() << ": ("
VLOG(3) << indent_string
<< " computation: " << called_computation->name() << ": ("
<< parameter_time << ", " << last_use_time << ")";
TF_RETURN_IF_ERROR(add_allocation_and_verify(
parameter_time, last_use_time, chunk, value));
CHECK(last_use_instruction);
if (last_use_instruction->opcode() == HloOpcode::kConditional) {
// The last use is another (nested) conditional. Call this
// function recursively.
TF_RETURN_IF_ERROR(split_conditional_buffer(
last_use_instruction, parameter_time, last_use_time,
absl::StrCat(indent_string, " ")));
} else {
TF_RETURN_IF_ERROR(add_allocation_and_verify(
parameter_time, last_use_time, chunk, value));
}
}
}
VLOG(3) << " from beginning until first computation: ("
<< time_bound.start << ", "
<< (earliest_computation_start_time - 1) << ")";
VLOG(3) << indent_string << " from beginning until first computation: ("
<< start_time << ", " << (earliest_computation_start_time - 1)
<< ")";
TF_RETURN_IF_ERROR(add_allocation_and_verify(
time_bound.start, earliest_computation_start_time - 1, chunk,
value));
start_time, earliest_computation_start_time - 1, chunk, value));
return Status::OK();
};
if (last_use_instruction &&
last_use_instruction->opcode() == HloOpcode::kConditional) {
TF_RETURN_IF_ERROR(split_conditional_buffer(
last_use_instruction, time_bound.start, time_bound.end, " "));
} else {
VLOG(3) << " buffer: " << buffer.ToString()
<< " value: " << value->ToShortString() << ": ("

View File

@ -2153,6 +2153,58 @@ TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
}
}
TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
  // Tests a spurious verification failure when there are nested conditionals
  // and the innermost conditional computation reuses the buffer. Here, both the
  // parameter of true_computation2 and neg2 will get the same buffer. Make sure
  // that verification doesn't claim a failure in this case.
  //
  // Structure of the module below (regression test for b/161935244):
  //   entry -> conditional(true_computation1 | false_computation1)
  //   true_computation1 -> conditional(true_computation2 | false_computation2)
  // The verifier must split the alternate-memory use of the tupled buffer
  // recursively across both conditional levels rather than only the outermost
  // one; before the fix, the nested conditional's reuse tripped the
  // heap-simulator overlap check.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true
  true_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    neg1 = f32[3]{0} negate(gte)
    neg2 = f32[3]{0} negate(neg1)
    ROOT neg3 = f32[3]{0} negate(neg2)
  }
  false_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg4 = f32[3]{0} negate(gte)
  }
  true_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    slice = f32[1]{0} slice(gte), slice={[0:1]}
    bitcast = f32[] bitcast(slice)
    constant = f32[] constant(0.0)
    compare = pred[] compare(bitcast, constant), direction=GT
    tuple = (f32[3]{0}) tuple(gte)
    ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
  }
  false_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg5 = f32[3]{0} negate(gte)
  }
  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // AssignMemorySpace runs the built-in verification pass; the test passes
  // as long as no spurious verification error is raised.
  AssignMemorySpace(module.get());
}
TEST_P(MemorySpaceAssignmentTest,
RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
// Ensure that request identifier returned by Send/Recv HLOs are not allocated