diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index df0e5bb4999..46bde33d7c0 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -787,14 +787,12 @@ void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
   }
 }
 
-void AlternateMemoryBestFitHeap::DumpIfEnabled(
-    absl::string_view buffer_info_str,
-    absl::string_view allocation_info_str) const {
+void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
   if (!options_.dump_fn) {
     return;
   }
-  options_.dump_fn("bufferinfo", buffer_info_str);
-  options_.dump_fn("allocinfo", allocation_info_str);
+  options_.dump_fn("bufferinfo", buffer_info_str_);
+  options_.dump_fn("allocinfo", allocation_info_str_);
 }
 
 HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
@@ -816,9 +814,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
     }
   }
-  std::string buffer_info_str;
-  std::string allocation_info_str;
-
   for (auto& interval : sorted_buffer_intervals) {
     if (!interval.need_allocation) {
       continue;
     }
@@ -842,12 +837,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
     }
 
     auto colocated_intervals = GetSortedColocatedIntervals(interval);
-    // Create AllocationValues for all the
-    // colocated intervals.
-    std::vector<AllocationValue> allocation_values;
-    for (const auto& colocated_interval : colocated_intervals) {
-      CreateAllocationValues(colocated_interval->buffer, &allocation_values);
-    }
 
     if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
       VLOG(3) << "Interval " << interval.buffer->ToShortString()
@@ -890,8 +879,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       continue;
     }
 
-    const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
-
     // TODO(berkin): For now, place the phi values due to conditionals in
     // default memory.
     for (const BufferInterval* colocated_interval : colocated_intervals) {
@@ -911,192 +898,203 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       }
     }
 
-    AppendBufferInfoDebugString(interval, &buffer_info_str);
-
-    // Data structure to contain the preferred offset for a given computation.
-    // We ensure that the same offset will be allocated outside the while loop
-    // as well as inside the while loop.
-    absl::flat_hash_map<const HloComputation*, int64>
-        preferred_offset_for_computation;
-    bool allocation_success = true;
-    for (auto& allocation_value : allocation_values) {
-      int64 definition_time =
-          instruction_schedule.at(allocation_value.defining_instruction());
-
-      absl::optional<int64> preferred_offset;
-      auto preferred_offset_it =
-          preferred_offset_for_computation.find(allocation_value.computation());
-      if (preferred_offset_it != preferred_offset_for_computation.end()) {
-        preferred_offset = preferred_offset_it->second;
-      }
-
-      // Iterate over the uses.
-      for (int use_idx = 0; use_idx < allocation_value.uses().size();
-           ++use_idx) {
-        const HloUse& use = allocation_value.uses().at(use_idx);
-        int64 use_time = instruction_schedule.at(use.instruction);
-        int64 latest_prefetch_time = use_time;
-        bool allow_no_copy_alternate_mem_allocation = true;
-        absl::optional<int64> earliest_prefetch_time = absl::nullopt;
-
-        // Sequential calls include kWhile, kCall, and kConditional opcodes.
-        bool is_sequential_call =
-            (GetInstructionCallContext(use.instruction->opcode()) ==
-             CallContext::kSequential);
-        if (is_sequential_call) {
-          for (const HloComputation* called_computation :
-               use.instruction->called_computations()) {
-            const HloLiveRange::TimeBound& computation_span =
-                hlo_live_range_.computation_span_times().at(called_computation);
-            latest_prefetch_time =
-                std::min(computation_span.start, latest_prefetch_time);
-          }
-          if (use.instruction->opcode() == HloOpcode::kWhile) {
-            // Given an example while loop and flattened schedule (logical times
-            // shown on the left):
-            //
-            // 0: a = ...
-            // 1: ...
-            // cond {
-            // 2: p = param(0)
-            // 3: ...
-            // }
-            // body {
-            // 4: p = param(0)
-            // 5: ...
-            // 6: ROOT ...
-            // }
-            // 7: w = while(a), body=body, cond=cond
-            //
-            // When processing "a" (time 0) and its while use (time 7), we
-            // update the interval to time 0-4. This is so that the remaining
-            // interval (5-6) can be allocated separately and this buffer
-            // doesn't waste alternate memory space within the while loop body.
-            HloComputation* while_body = use.instruction->while_body();
-            // We require while body ROOTs to be the last in the schedule.
-            CHECK_EQ(
-                instruction_schedule.at(while_body->root_instruction()) + 1,
-                instruction_schedule.at(use.instruction))
-                << "While body ROOTs need to be the last in the schedule! "
-                   "Please run RootInstructionSinker.";
-            // Replace the use time with the parameter time so that we can
-            // decide on alternate memory allocations within the while loop body
-            // when we look at uses within the while loop body.
-            use_time =
-                instruction_schedule.at(while_body->parameter_instruction(0));
-          } else if (use.instruction->opcode() == HloOpcode::kConditional) {
-            // Replace the use time with the earliest parameter of called
-            // computations.
-            for (const HloComputation* called_computation :
-                 use.instruction->called_computations()) {
-              use_time = std::min(
-                  use_time, instruction_schedule.at(
-                                called_computation->parameter_instruction(0)));
-            }
-          }
-        }
-
-        // Add a required assignment in default memory if the use not allowed in
-        // alternate memory.
-        if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
-          AddRequiredAssignment(allocation_value.value(), use.instruction,
-                                MemorySpace::kDefault, use_time);
-        } else if (use_idx > 0) {
-          // We allow buffers in alternate memory that are passed into
-          // conditionals to give up their alternate memory allocation inside
-          // the called computation. This means that if a conditional operator
-          // has an alternate memory allocation, subsequent uses cannot use the
-          // same alternate memory allocation in order not to clobber data. So
-          // we force default memory allocation for these subsequent uses.
-          const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
-          if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
-              previous_use.instruction != use.instruction) {
-            allow_no_copy_alternate_mem_allocation = false;
-            earliest_prefetch_time =
-                instruction_schedule.at(previous_use.instruction);
-            VLOG(3) << "Previous use (" << previous_use.ToString()
-                    << ") of use (" << use.ToString()
-                    << ") is a conditional, so this use will need to evict. "
-                    << "Earliest prefetch time = " << *earliest_prefetch_time;
-          }
-        }
-
-        // Bitcasts don't define buffers and don't directly consume buffers.
-        // Skip allocating buffers for bitcast uses. The uses that feed from
-        // bitcasts will be handled specially.
-        if (use.instruction->opcode() != HloOpcode::kBitcast) {
-          AllocationRequest request;
-          // Rarely, (e.g., when conditional true and false parameters are the
-          // same), definition time can be the time of the conditional and use
-          // time is the parameter use, which is less.
-          request.start_time = std::min(definition_time, use_time);
-          request.end_time = use_time;
-          request.latest_prefetch_time = latest_prefetch_time;
-          request.size = interval.size;
-          request.allow_no_copy_alternate_mem_allocation =
-              allow_no_copy_alternate_mem_allocation;
-          request.earliest_prefetch_time = earliest_prefetch_time;
-          request.preferred_offset = preferred_offset;
-          request.use = use;
-          request.allocation_value = &allocation_value;
-          if (!FindAllocation(request)) {
-            // If the allocation finding failed (e.g., due to running out of
-            // asynchronous copies), then fall back to allocating the buffer
-            // entirely in the default memory.
-            UncommitPendingChunks();
-            allocation_success = false;
-            break;
-          }
-
-          // If there are multiple uses, they can try using the memory
-          // allocation already at the alternate memory.
-          definition_time = instruction_schedule.at(use.instruction);
-        }
-
-        // If the use has been a sequential call (e.g. a while loop), the other
-        // colocated intervals must alias with this allocation.
-        if (is_sequential_call) {
-          MemorySpaceAssignment::Allocation* aliased_allocation =
-              GetLiveAllocationAt(*allocation_value.allocation_sequence(),
-                                  use_time);
-          AddAliasedRequiredAssignmentsForSequentialCall(use,
-                                                         aliased_allocation);
-          // Remember the preferred offset to be used inside while loop body
-          // computations.
-          if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
-              use.instruction->opcode() == HloOpcode::kWhile) {
-            preferred_offset_for_computation[use.instruction->while_body()] =
-                aliased_allocation->chunk().offset;
-          }
-        }
-      }
-      if (!allocation_success) {
-        break;
-      }
-    }
-    if (allocation_success) {
-      for (AllocationValue& allocation_value : allocation_values) {
-        for (auto& allocation : *allocation_value.allocation_sequence()) {
-          AppendAllocationInfoDebugString(interval, *allocation,
-                                          &allocation_info_str);
-          allocations_->push_back(std::move(allocation));
-        }
-      }
-    }
-
-    pending_chunks_.clear();
-    pending_async_copies_.clear();
+    AllocateColocatedIntervals(colocated_intervals);
   }
 
   VLOG(3) << "Debug buffer info: ";
-  VLOG(3) << buffer_info_str;
+  VLOG(3) << buffer_info_str_;
   VLOG(3) << "Debug allocation info: ";
-  VLOG(3) << allocation_info_str;
-  DumpIfEnabled(buffer_info_str, allocation_info_str);
+  VLOG(3) << allocation_info_str_;
+  DumpDebugStringsIfEnabled();
 
   return result_;
 }
 
+void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
+    const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
+        colocated_intervals) {
+  // Create AllocationValues for all the colocated intervals.
+  std::vector<AllocationValue> allocation_values;
+  for (const auto& colocated_interval : colocated_intervals) {
+    CreateAllocationValues(colocated_interval->buffer, &allocation_values);
+  }
+  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
+
+  // Data structure to contain the preferred offset for a given computation.
+  // We ensure that the same offset will be allocated outside the while loop
+  // as well as inside the while loop.
+  absl::flat_hash_map<const HloComputation*, int64>
+      preferred_offset_for_computation;
+
+  AppendBufferInfoDebugString(*colocated_intervals[0], &buffer_info_str_);
+
+  bool allocation_success = true;
+  for (auto& allocation_value : allocation_values) {
+    int64 definition_time =
+        instruction_schedule.at(allocation_value.defining_instruction());
+
+    absl::optional<int64> preferred_offset;
+    auto preferred_offset_it =
+        preferred_offset_for_computation.find(allocation_value.computation());
+    if (preferred_offset_it != preferred_offset_for_computation.end()) {
+      preferred_offset = preferred_offset_it->second;
+    }
+
+    // Iterate over the uses.
+    for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
+      const HloUse& use = allocation_value.uses().at(use_idx);
+      int64 use_time = instruction_schedule.at(use.instruction);
+      int64 latest_prefetch_time = use_time;
+      bool allow_no_copy_alternate_mem_allocation = true;
+      absl::optional<int64> earliest_prefetch_time = absl::nullopt;
+
+      // Sequential calls include kWhile, kCall, and kConditional opcodes.
+      bool is_sequential_call =
+          (GetInstructionCallContext(use.instruction->opcode()) ==
+           CallContext::kSequential);
+      if (is_sequential_call) {
+        for (const HloComputation* called_computation :
+             use.instruction->called_computations()) {
+          const HloLiveRange::TimeBound& computation_span =
+              hlo_live_range_.computation_span_times().at(called_computation);
+          latest_prefetch_time =
+              std::min(computation_span.start, latest_prefetch_time);
+        }
+        if (use.instruction->opcode() == HloOpcode::kWhile) {
+          // Given an example while loop and flattened schedule (logical times
+          // shown on the left):
+          //
+          // 0: a = ...
+          // 1: ...
+          // cond {
+          // 2: p = param(0)
+          // 3: ...
+          // }
+          // body {
+          // 4: p = param(0)
+          // 5: ...
+          // 6: ROOT ...
+          // }
+          // 7: w = while(a), body=body, cond=cond
+          //
+          // When processing "a" (time 0) and its while use (time 7), we update
+          // the interval to time 0-4. This is so that the remaining interval
+          // (5-6) can be allocated separately and this buffer doesn't waste
+          // alternate memory space within the while loop body.
+          HloComputation* while_body = use.instruction->while_body();
+          // We require while body ROOTs to be the last in the schedule.
+          CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
+                   instruction_schedule.at(use.instruction))
+              << "While body ROOTs need to be the last in the schedule! "
+                 "Please run RootInstructionSinker.";
+          // Replace the use time with the parameter time so that we can decide
+          // on alternate memory allocations within the while loop body when we
+          // look at uses within the while loop body.
+          use_time =
+              instruction_schedule.at(while_body->parameter_instruction(0));
+        } else if (use.instruction->opcode() == HloOpcode::kConditional) {
+          // Replace the use time with the earliest parameter of called
+          // computations.
+          for (const HloComputation* called_computation :
+               use.instruction->called_computations()) {
+            use_time = std::min(
+                use_time, instruction_schedule.at(
+                              called_computation->parameter_instruction(0)));
+          }
+        }
+      }
+
+      // Add a required assignment in default memory if the use not allowed in
+      // alternate memory.
+      if (!IsUseAllowedInAlternateMemory(allocation_value, use)) {
+        AddRequiredAssignment(allocation_value.value(), use.instruction,
+                              MemorySpace::kDefault, use_time);
+      } else if (use_idx > 0) {
+        // We allow buffers in alternate memory that are passed into
+        // conditionals to give up their alternate memory allocation inside the
+        // called computation. This means that if a conditional operator has an
+        // alternate memory allocation, subsequent uses cannot use the same
+        // alternate memory allocation in order not to clobber data. So we force
+        // default memory allocation for these subsequent uses.
+        const HloUse& previous_use = allocation_value.uses().at(use_idx - 1);
+        if (previous_use.instruction->opcode() == HloOpcode::kConditional &&
+            previous_use.instruction != use.instruction) {
+          allow_no_copy_alternate_mem_allocation = false;
+          earliest_prefetch_time =
+              instruction_schedule.at(previous_use.instruction);
+          VLOG(3) << "Previous use (" << previous_use.ToString() << ") of use ("
+                  << use.ToString()
+                  << ") is a conditional, so this use will need to evict. "
+                  << "Earliest prefetch time = " << *earliest_prefetch_time;
+        }
+      }
+
+      // Bitcasts don't define buffers and don't directly consume buffers. Skip
+      // allocating buffers for bitcast uses. The uses that feed from bitcasts
+      // will be handled specially.
+      if (use.instruction->opcode() != HloOpcode::kBitcast) {
+        AllocationRequest request;
+        // Rarely, (e.g., when conditional true and false parameters are the
+        // same), definition time can be the time of the conditional and use
+        // time is the parameter use, which is less.
+        request.start_time = std::min(definition_time, use_time);
+        request.end_time = use_time;
+        request.latest_prefetch_time = latest_prefetch_time;
+        request.size = colocated_intervals[0]->size;
+        request.allow_no_copy_alternate_mem_allocation =
+            allow_no_copy_alternate_mem_allocation;
+        request.earliest_prefetch_time = earliest_prefetch_time;
+        request.preferred_offset = preferred_offset;
+        request.use = use;
+        request.allocation_value = &allocation_value;
+        if (!AllocateSegment(request)) {
+          // If the allocation finding failed (e.g., due to running out of
+          // asynchronous copies), then fall back to allocating the buffer
+          // entirely in the default memory.
+          UncommitPendingChunks();
+          allocation_success = false;
+          break;
+        }
+
+        // If there are multiple uses, they can try using the memory allocation
+        // already at the alternate memory.
+        definition_time = instruction_schedule.at(use.instruction);
+      }
+
+      // If the use has been a sequential call (e.g. a while loop), the other
+      // colocated intervals must alias with this allocation.
+      if (is_sequential_call) {
+        MemorySpaceAssignment::Allocation* aliased_allocation =
+            GetLiveAllocationAt(*allocation_value.allocation_sequence(),
+                                use_time);
+        AddAliasedRequiredAssignmentsForSequentialCall(use, aliased_allocation);
+        // Remember the preferred offset to be used inside while loop body
+        // computations.
+        if (aliased_allocation->memory_space() == MemorySpace::kAlternate &&
+            use.instruction->opcode() == HloOpcode::kWhile) {
+          preferred_offset_for_computation[use.instruction->while_body()] =
+              aliased_allocation->chunk().offset;
+        }
+      }
+    }
+    if (!allocation_success) {
+      break;
+    }
+  }
+  if (allocation_success) {
+    for (AllocationValue& allocation_value : allocation_values) {
+      for (auto& allocation : *allocation_value.allocation_sequence()) {
+        AppendAllocationInfoDebugString(*colocated_intervals[0], *allocation,
+                                        &allocation_info_str_);
+        allocations_->push_back(std::move(allocation));
+      }
+    }
+  }
+
+  pending_chunks_.clear();
+  pending_async_copies_.clear();
+}
+
 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
   return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
          (a.start_time <= b.start_time && a.end_time < b.end_time);
 }
 
@@ -1395,7 +1393,7 @@ AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
   return required_assignment_at_time;
 }
 
-bool AlternateMemoryBestFitHeap::FindAllocation(
+bool AlternateMemoryBestFitHeap::AllocateSegment(
     const AllocationRequest& request) {
   auto allocation_sequence = request.allocation_value->allocation_sequence();
   // start_time == end_time is a special case where the value is consumed
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 7e839f5d25b..ec90b0f59e2 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -890,7 +890,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   void CreateAllocationValues(const HloValue* value,
                               std::vector<AllocationValue>* allocation_values);
 
-  // Finds an allocation for the given interval.
+  // Finds allocations for colocated intervals. Colocated intervals consist of
+  // one or more BufferIntervals, each with a different HloValue. All of the
+  // intervals within colocated intervals have a must-alias relationship with
+  // each other.
+  void AllocateColocatedIntervals(
+      const std::vector<const BufferInterval*>& colocated_intervals);
+
+  // Finds an allocation for an allocation request for a segment (see the
+  // documentation for AllocationRequest above how a segment is defined).
   //
   // It performs three things in the following order:
   // 1- Allocate the allocation request entirely in the alternate memory, if
@@ -904,7 +912,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // false. This means we could not find a suitable allocation, so all previous
   // allocations for this buffer must be removed and allocated in the default
   // memory. Otherwise, this method returns true.
-  bool FindAllocation(const AllocationRequest& request);
+  bool AllocateSegment(const AllocationRequest& request);
 
   // Try allocating in alternate memory without any copies. Returns true if
   // successful.
@@ -1000,8 +1008,7 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
       const BufferInterval& interval,
       const MemorySpaceAssignment::Allocation& allocation,
       std::string* debug_str) const;
-  void DumpIfEnabled(absl::string_view buffer_info_str,
-                     absl::string_view allocation_info_str) const;
+  void DumpDebugStringsIfEnabled() const;
 
   // Returns the available heap size in the alternate memory.
   int64 available_heap_size() const {
@@ -1025,6 +1032,9 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
       required_assignments_;
   // Number of bytes reserved in alternate memory space.
   int64 reserved_in_bytes_ = 0;
+  // Debug strings.
+  std::string buffer_info_str_;
+  std::string allocation_info_str_;
 };
 
 }  // namespace xla
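For reference, a minimal sketch of a dump callback matching the shape implied by the options_.dump_fn("bufferinfo", buffer_info_str_) and options_.dump_fn("allocinfo", allocation_info_str_) calls above. Only the dump_fn hook and the "bufferinfo"/"allocinfo" tags come from this change; the callback name, output paths, and wiring comment below are illustrative assumptions.

// Hypothetical example; not part of the change above.
#include <fstream>
#include <string>

#include "absl/strings/string_view.h"

// Receives each debug dump tag ("bufferinfo" or "allocinfo") together with its
// contents and writes it to a per-tag file.
void DumpMsaDebugString(absl::string_view kind, absl::string_view contents) {
  std::ofstream out("/tmp/msa_" + std::string(kind) + ".csv");
  out << contents;
}

// Assumed wiring, using the dump_fn field exercised by the diff:
//   options.dump_fn = DumpMsaDebugString;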