[XLA] Fix a spurious verification failure with nested conditionals.
When allocating alternate memory for conditionals, we assume the conditional will always evict the value back to default memory before the end of the called computations. In other words, the output of a conditional can never get an alternate-memory allocation, due to difficulties with aliasing, but its inputs can. When verifying the correctness of memory space assignment, we have to split the conditional's uses by called computation and find, within each computation, the time bounds during which the buffer is used in alternate memory. We previously didn't do this splitting for nested conditionals, causing b/161935244. Now we do the splitting recursively.

PiperOrigin-RevId: 322865495
Change-Id: I522ccd9e54e73ee9d3b53997fb28198651338909
commit 18b810a970 (parent f300cac524)
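
For intuition, here is a minimal, self-contained C++ sketch of the recursive splitting described above. It is not part of the commit: Computation, Use, VerifyInterval, and SplitConditional are hypothetical stand-ins for HloComputation, HloUse, add_allocation_and_verify, and split_conditional_buffer, and all schedule times are invented. Each called computation of a conditional contributes its own (parameter time, last use time) window; when the last use inside a branch is itself a conditional, the function recurses instead of verifying the branch as one flat interval.

// Hypothetical, simplified sketch -- not the XLA implementation.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct Computation {
  std::string name;
  int64_t start_time;  // Schedule time of the computation's parameter.
};

struct Use {
  int64_t time;                            // Schedule time of the use.
  const Computation* parent;               // Computation containing the use.
  bool is_conditional;                     // True if the use is a conditional.
  std::vector<const Computation*> called;  // Branch computations, if any.
};

// Stand-in for add_allocation_and_verify: checks one contiguous interval of
// the alternate-memory allocation against the heap simulator trace.
void VerifyInterval(int64_t start, int64_t end, const std::string& indent) {
  std::cout << indent << "verify interval (" << start << ", " << end << ")\n";
}

// Splits a conditional's interval [start, end] by its called computations,
// recursing when a branch's last use is itself a (nested) conditional.
void SplitConditional(const Use& conditional, int64_t start, int64_t end,
                      const std::vector<Use>& all_uses,
                      const std::string& indent) {
  int64_t earliest_branch_start = end;
  for (const Computation* branch : conditional.called) {
    earliest_branch_start = std::min(earliest_branch_start, branch->start_time);
    // Find the latest use of the buffer inside this branch.
    const Use* last_use = nullptr;
    for (const Use& use : all_uses) {
      if (use.parent == branch &&
          (last_use == nullptr || use.time > last_use->time)) {
        last_use = &use;
      }
    }
    if (last_use == nullptr) continue;  // Buffer unused in this branch.
    std::cout << indent << "computation " << branch->name << ": ("
              << branch->start_time << ", " << last_use->time << ")\n";
    if (last_use->is_conditional) {
      // The last use is another (nested) conditional: recurse instead of
      // verifying the whole branch as a single interval.
      SplitConditional(*last_use, branch->start_time, last_use->time, all_uses,
                       indent + "  ");
    } else {
      VerifyInterval(branch->start_time, last_use->time, indent);
    }
  }
  // The buffer also lives from the conditional's start until the earliest
  // branch begins executing.
  VerifyInterval(start, earliest_branch_start - 1, indent);
}

int main() {
  // Mirrors the shape of the regression test below: the outer conditional's
  // true branch ends in another conditional. All schedule times are made up.
  Computation outer_true{"true_computation1", 3};
  Computation outer_false{"false_computation1", 16};
  Computation inner_true{"true_computation2", 10};
  Computation inner_false{"false_computation2", 14};

  Use inner_conditional{18, &outer_true, true, {&inner_true, &inner_false}};
  std::vector<Use> uses = {
      {4, &outer_true, false, {}},    // gte inside true_computation1
      inner_conditional,              // nested conditional: last use in branch
      {11, &inner_true, false, {}},   // gte inside true_computation2
      {15, &inner_false, false, {}},  // gte inside false_computation2
      {17, &outer_false, false, {}},  // gte inside false_computation1
  };

  Use outer_conditional{19, nullptr, true, {&outer_true, &outer_false}};
  SplitConditional(outer_conditional, /*start=*/1, /*end=*/19, uses, " ");
  return 0;
}

Compiled and run, the sketch prints one verified interval per (possibly nested) branch plus the leading interval from the conditional's start until the earliest branch begins, which is the split the verifier performs per conditional.
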
@@ -3001,18 +3001,23 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
       }
     }
 
-    if (last_use_instruction &&
-        last_use_instruction->opcode() == HloOpcode::kConditional) {
+    std::function<Status(const HloInstruction*, int64, int64,
+                         absl::string_view)>
+        split_conditional_buffer;
+    split_conditional_buffer = [&](const HloInstruction* use_instruction,
+                                   int64 start_time, int64 end_time,
+                                   absl::string_view indent_string) {
       // Special case when verifying conditional: we internally split the use
       // of alternate memory in conditionals, so fish them out from the
       // conditionals.
-      VLOG(3) << " Splitting conditional buffer: " << buffer.ToString()
-              << " value: " << value->ToShortString() << ": ("
-              << time_bound.start << ", " << time_bound.end
-              << ") off: " << chunk.offset << ", size: " << chunk.size;
-      int64 earliest_computation_start_time = time_bound.end;
+      VLOG(3) << indent_string
+              << "Splitting conditional buffer: " << buffer.ToString()
+              << " value: " << value->ToShortString() << ": (" << start_time
+              << ", " << end_time << ") off: " << chunk.offset
+              << ", size: " << chunk.size;
+      int64 earliest_computation_start_time = end_time;
       for (const HloComputation* called_computation :
-           last_use_instruction->called_computations()) {
+           use_instruction->called_computations()) {
         earliest_computation_start_time =
             std::min(earliest_computation_start_time,
                      hlo_live_range->computation_span_times()
@@ -3020,6 +3025,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
                          .start);
         int64 parameter_time = -1;
         int64 last_use_time = -1;
+        const HloInstruction* last_use_instruction = nullptr;
         for (const HloPosition& position : value->positions()) {
           if (position.instruction->opcode() == HloOpcode::kParameter &&
               position.instruction->parent() == called_computation) {
@@ -3029,26 +3035,44 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
           }
         }
         for (const HloUse& use : value->uses()) {
-          if (use.instruction->parent() == called_computation) {
-            last_use_time = std::max(
-                last_use_time,
-                hlo_live_range->instruction_schedule().at(use.instruction));
+          int64 use_time =
+              hlo_live_range->instruction_schedule().at(use.instruction);
+          if (use.instruction->parent() == called_computation &&
+              use_time > last_use_time) {
+            last_use_time = use_time;
+            last_use_instruction = use.instruction;
           }
         }
         if (last_use_time != -1) {
           CHECK_NE(parameter_time, -1);
-          VLOG(3) << " computation: " << called_computation->name() << ": ("
+          VLOG(3) << indent_string
+                  << " computation: " << called_computation->name() << ": ("
                   << parameter_time << ", " << last_use_time << ")";
-          TF_RETURN_IF_ERROR(add_allocation_and_verify(
-              parameter_time, last_use_time, chunk, value));
+          CHECK(last_use_instruction);
+          if (last_use_instruction->opcode() == HloOpcode::kConditional) {
+            // The last use is another (nested) conditional. Call this
+            // function recursively.
+            TF_RETURN_IF_ERROR(split_conditional_buffer(
+                last_use_instruction, parameter_time, last_use_time,
+                absl::StrCat(indent_string, " ")));
+          } else {
+            TF_RETURN_IF_ERROR(add_allocation_and_verify(
+                parameter_time, last_use_time, chunk, value));
+          }
         }
       }
-      VLOG(3) << " from beginning until first computation: ("
-              << time_bound.start << ", "
-              << (earliest_computation_start_time - 1) << ")";
+      VLOG(3) << indent_string << " from beginning until first computation: ("
+              << start_time << ", " << (earliest_computation_start_time - 1)
+              << ")";
       TF_RETURN_IF_ERROR(add_allocation_and_verify(
-          time_bound.start, earliest_computation_start_time - 1, chunk,
-          value));
+          start_time, earliest_computation_start_time - 1, chunk, value));
+      return Status::OK();
+    };
+
+    if (last_use_instruction &&
+        last_use_instruction->opcode() == HloOpcode::kConditional) {
+      TF_RETURN_IF_ERROR(split_conditional_buffer(
+          last_use_instruction, time_bound.start, time_bound.end, " "));
     } else {
       VLOG(3) << " buffer: " << buffer.ToString()
               << " value: " << value->ToShortString() << ": ("
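
An aside on the pattern used in the hunk above (my note, not part of the commit): a C++ lambda cannot refer to itself by name, which is why split_conditional_buffer is first declared as a std::function and then assigned a lambda that captures it by reference. A minimal standalone example of that recursive-lambda pattern, assuming nothing beyond the standard library:

#include <cstdint>
#include <functional>
#include <iostream>

int main() {
  // Declare the std::function first so the lambda body can refer to it, and
  // capture by reference ([&]) so the call inside the body sees the assigned
  // target. This is the same shape as split_conditional_buffer in the diff.
  std::function<int64_t(int64_t)> factorial;
  factorial = [&](int64_t n) -> int64_t {
    return n <= 1 ? 1 : n * factorial(n - 1);
  };
  std::cout << factorial(5) << "\n";  // Prints 120.
  return 0;
}
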
@@ -2153,6 +2153,58 @@ TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
   }
 }
 
+TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
+  // Tests a spurious verification failure when there are nested conditionals
+  // and the innermost conditional computation reuses the buffer. Here, both the
+  // parameter of true_computation2 and neg2 will get the same buffer. Make sure
+  // that verification doesn't claim a failure in this case.
+  absl::string_view hlo_string = R"(
+  HloModule CondAllocation, is_scheduled=true
+
+  true_computation2 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    neg1 = f32[3]{0} negate(gte)
+    neg2 = f32[3]{0} negate(neg1)
+    ROOT neg3 = f32[3]{0} negate(neg2)
+  }
+
+  false_computation2 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    ROOT neg4 = f32[3]{0} negate(gte)
+  }
+
+  true_computation1 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    slice = f32[1]{0} slice(gte), slice={[0:1]}
+    bitcast = f32[] bitcast(slice)
+    constant = f32[] constant(0.0)
+    compare = pred[] compare(bitcast, constant), direction=GT
+    tuple = (f32[3]{0}) tuple(gte)
+    ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
+  }
+
+  false_computation1 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    ROOT neg5 = f32[3]{0} negate(gte)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy = f32[3]{0} copy(p0)
+    tuple = (f32[3]{0}) tuple(copy)
+    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+}
+
 TEST_P(MemorySpaceAssignmentTest,
        RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
   // Ensure that request identifier returned by Send/Recv HLOs are not allocated