diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index b003045e66c..803140b804e 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -3001,18 +3001,23 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
         }
       }
 
-      if (last_use_instruction &&
-          last_use_instruction->opcode() == HloOpcode::kConditional) {
+      std::function<Status(const HloInstruction*, int64, int64,
+                           absl::string_view)>
+          split_conditional_buffer;
+      split_conditional_buffer = [&](const HloInstruction* use_instruction,
+                                     int64 start_time, int64 end_time,
+                                     absl::string_view indent_string) {
         // Special case when verifying conditional: we internally split the use
         // of alternate memory in conditionals, so fish them out from the
         // conditionals.
-        VLOG(3) << " Splitting conditional buffer: " << buffer.ToString()
-                << " value: " << value->ToShortString() << ": ("
-                << time_bound.start << ", " << time_bound.end
-                << ") off: " << chunk.offset << ", size: " << chunk.size;
-        int64 earliest_computation_start_time = time_bound.end;
+        VLOG(3) << indent_string
+                << "Splitting conditional buffer: " << buffer.ToString()
+                << " value: " << value->ToShortString() << ": (" << start_time
+                << ", " << end_time << ") off: " << chunk.offset
+                << ", size: " << chunk.size;
+        int64 earliest_computation_start_time = end_time;
         for (const HloComputation* called_computation :
-             last_use_instruction->called_computations()) {
+             use_instruction->called_computations()) {
           earliest_computation_start_time =
               std::min(earliest_computation_start_time,
                        hlo_live_range->computation_span_times()
@@ -3020,6 +3025,7 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
                            .start);
           int64 parameter_time = -1;
           int64 last_use_time = -1;
+          const HloInstruction* last_use_instruction = nullptr;
           for (const HloPosition& position : value->positions()) {
             if (position.instruction->opcode() == HloOpcode::kParameter &&
                 position.instruction->parent() == called_computation) {
@@ -3029,26 +3035,44 @@ Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
             }
           }
           for (const HloUse& use : value->uses()) {
-            if (use.instruction->parent() == called_computation) {
-              last_use_time = std::max(
-                  last_use_time,
-                  hlo_live_range->instruction_schedule().at(use.instruction));
+            int64 use_time =
+                hlo_live_range->instruction_schedule().at(use.instruction);
+            if (use.instruction->parent() == called_computation &&
+                use_time > last_use_time) {
+              last_use_time = use_time;
+              last_use_instruction = use.instruction;
             }
           }
           if (last_use_time != -1) {
             CHECK_NE(parameter_time, -1);
-            VLOG(3) << " computation: " << called_computation->name() << ": ("
+            VLOG(3) << indent_string
+                    << " computation: " << called_computation->name() << ": ("
                     << parameter_time << ", " << last_use_time << ")";
-            TF_RETURN_IF_ERROR(add_allocation_and_verify(
-                parameter_time, last_use_time, chunk, value));
+            CHECK(last_use_instruction);
+            if (last_use_instruction->opcode() == HloOpcode::kConditional) {
+              // The last use is another (nested) conditional. Call this
+              // function recursively.
+              TF_RETURN_IF_ERROR(split_conditional_buffer(
+                  last_use_instruction, parameter_time, last_use_time,
+                  absl::StrCat(indent_string, " ")));
+            } else {
+              TF_RETURN_IF_ERROR(add_allocation_and_verify(
+                  parameter_time, last_use_time, chunk, value));
+            }
           }
         }
-        VLOG(3) << " from beginning until first computation: ("
-                << time_bound.start << ", "
-                << (earliest_computation_start_time - 1) << ")";
+        VLOG(3) << indent_string << " from beginning until first computation: ("
+                << start_time << ", " << (earliest_computation_start_time - 1)
+                << ")";
         TF_RETURN_IF_ERROR(add_allocation_and_verify(
-            time_bound.start, earliest_computation_start_time - 1, chunk,
-            value));
+            start_time, earliest_computation_start_time - 1, chunk, value));
+        return Status::OK();
+      };
+
+      if (last_use_instruction &&
+          last_use_instruction->opcode() == HloOpcode::kConditional) {
+        TF_RETURN_IF_ERROR(split_conditional_buffer(
+            last_use_instruction, time_bound.start, time_bound.end, " "));
       } else {
         VLOG(3) << " buffer: " << buffer.ToString()
                 << " value: " << value->ToShortString() << ": ("
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index d609f7edd1d..c0fdc5fc00d 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -2153,6 +2153,58 @@ TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
   }
 }
 
+TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
+  // Tests a spurious verification failure when there are nested conditionals
+  // and the innermost conditional computation reuses the buffer. Here, both the
+  // parameter of true_computation2 and neg2 will get the same buffer. Make sure
+  // that verification doesn't claim a failure in this case.
+  absl::string_view hlo_string = R"(
+  HloModule CondAllocation, is_scheduled=true
+
+  true_computation2 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    neg1 = f32[3]{0} negate(gte)
+    neg2 = f32[3]{0} negate(neg1)
+    ROOT neg3 = f32[3]{0} negate(neg2)
+  }
+
+  false_computation2 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    ROOT neg4 = f32[3]{0} negate(gte)
+  }
+
+  true_computation1 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    slice = f32[1]{0} slice(gte), slice={[0:1]}
+    bitcast = f32[] bitcast(slice)
+    constant = f32[] constant(0.0)
+    compare = pred[] compare(bitcast, constant), direction=GT
+    tuple = (f32[3]{0}) tuple(gte)
+    ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
+  }
+
+  false_computation1 {
+    p0 = (f32[3]{0}) parameter(0)
+    gte = f32[3]{0} get-tuple-element(p0), index=0
+    ROOT neg5 = f32[3]{0} negate(gte)
+  }
+
+  ENTRY entry {
+    p0 = f32[3]{0} parameter(0)
+    p1 = pred[] parameter(1)
+    copy = f32[3]{0} copy(p0)
+    tuple = (f32[3]{0}) tuple(copy)
+    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
+  }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  AssignMemorySpace(module.get());
+}
+
 TEST_P(MemorySpaceAssignmentTest,
        RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
   // Ensure that request identifier returned by Send/Recv HLOs are not allocated