[XLA] Implement an allocation retry mechanism.

Sometimes, memory space assignment can't prefetch due to limits on outstanding
async copies or due to async copy ordering. If that is the case, we can retry
allocating the value with larger max_async_copy_to_overlap_ratio limit.

PiperOrigin-RevId: 314826393
Change-Id: Id722976cb8ab82993c4a60407b77df9f8e7b080e
This commit is contained in:
Berkin Ilbeyi 2020-06-04 16:31:09 -07:00 committed by TensorFlower Gardener
parent f8f0e6bf3a
commit 7a7207f3b6
3 changed files with 126 additions and 33 deletions

View File

@ -59,6 +59,10 @@ class HeapSimulator {
int64 chunk_end() const { return offset + size; } int64 chunk_end() const { return offset + size; }
bool OverlapsWith(Chunk other_chunk) const; bool OverlapsWith(Chunk other_chunk) const;
bool operator==(const Chunk& other) const {
return offset == other.offset && size == other.size;
}
}; };
// Result represents the result of the heap simulation. // Result represents the result of the heap simulation.

View File

@ -297,7 +297,8 @@ bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
float logical_interval_elapsed = float logical_interval_elapsed =
GetLogicalIntervalElapsed(start_time, end_time); GetLogicalIntervalElapsed(start_time, end_time);
return max_async_copy_to_overlap_ratio_ * async_copy_elapsed > return max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
async_copy_elapsed >
logical_interval_elapsed; logical_interval_elapsed;
} }
@ -332,11 +333,20 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
// Find the earliest time we're allowed to start prefetching. // Find the earliest time we're allowed to start prefetching.
for (current_logical_prefetch_time_ = start_time; for (current_logical_prefetch_time_ = start_time;
current_logical_prefetch_time_ < end_logical_time_ && current_logical_prefetch_time_ < end_logical_time_ &&
max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ < max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
async_copy_elapsed_ <
GetLogicalIntervalElapsed(current_logical_prefetch_time_, GetLogicalIntervalElapsed(current_logical_prefetch_time_,
end_logical_time_); end_logical_time_);
++current_logical_prefetch_time_) { ++current_logical_prefetch_time_) {
} }
// If the first prefetch interval violates the min overlap (this can happen if
// there is an HLO that has a very long estimated execution time), we go
// earlier in time until we no longer violate the min overlap (we start
// prefetching before the HLO that has the very long estimated execution
// time).
while (Done() && current_logical_prefetch_time_ > start_time) {
--current_logical_prefetch_time_;
}
} }
int64 CostAnalysisPrefetchIntervalPicker::Next() { int64 CostAnalysisPrefetchIntervalPicker::Next() {
@ -357,6 +367,11 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
logical_interval_elapsed + inst_elapsed_reduction_; logical_interval_elapsed + inst_elapsed_reduction_;
} }
void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
// Use twice as large max overlap limit in each retry.
max_overlap_multiplier_ = 1 << retry_number;
}
int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel( int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
int64 start_time, int64 end_time) const { int64 start_time, int64 end_time) const {
int min_nest_level = int min_nest_level =
@ -391,7 +406,9 @@ std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
return absl::StrCat( return absl::StrCat(
"Async copy elapsed (s) = ", async_copy_elapsed_, "Async copy elapsed (s) = ", async_copy_elapsed_,
", inst elapsed reduction (s) = ", inst_elapsed_reduction_, ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
", logical interval elapsed (s) = ", logical_interval_elapsed); ", logical interval elapsed (s) = ", logical_interval_elapsed,
", interval = (", current_logical_prefetch_time_, ", ", end_logical_time_,
")");
} }
std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
@ -879,26 +896,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue; continue;
} }
// TODO(berkin): For now, place the phi values due to conditionals in AppendBufferInfoDebugString(interval, &buffer_info_str_);
// default memory.
for (const BufferInterval* colocated_interval : colocated_intervals) {
const HloValue* value = colocated_interval->buffer;
for (const auto& position : value->positions()) {
if (position.instruction->opcode() == HloOpcode::kConditional) {
VLOG(3) << "Adding required assignment for condition output: "
<< value->ToShortString();
AddRequiredAssignment(position.instruction, position.index,
MemorySpace::kDefault);
for (const HloComputation* called_computation :
position.instruction->called_computations()) {
AddRequiredAssignment(called_computation->root_instruction(),
position.index, MemorySpace::kDefault);
}
}
}
}
AllocateColocatedIntervals(colocated_intervals); // Retry allocating this value with larger limits if allocation fails.
for (int retry_number = 0; retry_number < options_.max_retries;
retry_number++) {
final_retry_ = (retry_number == options_.max_retries - 1);
options_.prefetch_interval_picker->SetRetryNumber(retry_number);
bool success = AllocateColocatedIntervals(colocated_intervals);
if (success) {
break;
}
VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
}
} }
VLOG(3) << "Debug buffer info: "; VLOG(3) << "Debug buffer info: ";
@ -910,9 +920,28 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
return result_; return result_;
} }
void AlternateMemoryBestFitHeap::AllocateColocatedIntervals( bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>& const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
colocated_intervals) { colocated_intervals) {
// TODO(berkin): For now, place the phi values due to conditionals in
// default memory.
for (const BufferInterval* colocated_interval : colocated_intervals) {
const HloValue* value = colocated_interval->buffer;
for (const auto& position : value->positions()) {
if (position.instruction->opcode() == HloOpcode::kConditional) {
VLOG(3) << "Adding required assignment for condition output: "
<< value->ToShortString();
AddRequiredAssignment(position.instruction, position.index,
MemorySpace::kDefault);
for (const HloComputation* called_computation :
position.instruction->called_computations()) {
AddRequiredAssignment(called_computation->root_instruction(),
position.index, MemorySpace::kDefault);
}
}
}
}
// Create AllocationValues for all the colocated intervals. // Create AllocationValues for all the colocated intervals.
std::vector<AllocationValue> allocation_values; std::vector<AllocationValue> allocation_values;
for (const auto& colocated_interval : colocated_intervals) { for (const auto& colocated_interval : colocated_intervals) {
@ -926,8 +955,6 @@ void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
absl::flat_hash_map<const HloComputation*, int64> absl::flat_hash_map<const HloComputation*, int64>
preferred_offset_for_computation; preferred_offset_for_computation;
AppendBufferInfoDebugString(*colocated_intervals[0], &buffer_info_str_);
bool allocation_success = true; bool allocation_success = true;
for (auto& allocation_value : allocation_values) { for (auto& allocation_value : allocation_values) {
int64 definition_time = int64 definition_time =
@ -1093,6 +1120,8 @@ void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
pending_chunks_.clear(); pending_chunks_.clear();
pending_async_copies_.clear(); pending_async_copies_.clear();
pending_required_assignments_.clear();
return allocation_success;
} }
bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) { bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
@ -1180,6 +1209,7 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
pending_chunks_.clear(); pending_chunks_.clear();
pending_async_copies_.clear(); pending_async_copies_.clear();
pending_required_assignments_.clear();
} }
void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall( void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
@ -1242,7 +1272,9 @@ void AlternateMemoryBestFitHeap::AddRequiredAssignment(
VLOG(3) << "Adding required assignment: " << value->ToShortString() VLOG(3) << "Adding required assignment: " << value->ToShortString()
<< " at " << time << " at " << " at " << time << " at "
<< (memory_space == MemorySpace::kDefault ? "def" : "alt"); << (memory_space == MemorySpace::kDefault ? "def" : "alt");
required_assignments_[value].push_back({memory_space, time, chunk}); RequiredMemoryAssignment required_assignment{memory_space, time, chunk};
required_assignments_[value].push_back(required_assignment);
pending_required_assignments_.push_back({value, required_assignment});
} }
} }
@ -1361,8 +1393,30 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() {
kDummyChunk); kDummyChunk);
} }
} }
for (const auto& value_and_required_assignment :
pending_required_assignments_) {
auto& required_assignment_vector =
required_assignments_[value_and_required_assignment.first];
const RequiredMemoryAssignment& required_assignment =
value_and_required_assignment.second;
VLOG(3) << "Removing required assignment: "
<< (required_assignment.memory_space == MemorySpace::kDefault
? "def"
: "alt")
<< " time = " << required_assignment.time << " off = "
<< (required_assignment.chunk ? required_assignment.chunk->offset
: -1);
for (auto it = required_assignment_vector.begin();
it != required_assignment_vector.end(); ++it) {
if (*it == value_and_required_assignment.second) {
required_assignment_vector.erase(it);
break;
}
}
}
pending_chunks_.clear(); pending_chunks_.clear();
pending_async_copies_.clear(); pending_async_copies_.clear();
pending_required_assignments_.clear();
} }
void AlternateMemoryBestFitHeap::AddToPendingChunks( void AlternateMemoryBestFitHeap::AddToPendingChunks(
@ -1507,6 +1561,12 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
if (Prefetch(request, **prev_allocation_in_default_mem_it)) { if (Prefetch(request, **prev_allocation_in_default_mem_it)) {
return true; return true;
} }
if (!final_retry_ && prefetch_failed_due_to_async_copy_) {
// If prefetching failed due to asynchronous copy and we're not in our final
// try, return false (failure) so that we can retry this interval with
// larger limits.
return false;
}
// If the end assignment was required to be in alternate memory but that // If the end assignment was required to be in alternate memory but that
// wasn't possible, then this allocation is invalid. // wasn't possible, then this allocation is invalid.
@ -1818,6 +1878,10 @@ bool AlternateMemoryBestFitHeap::Prefetch(
BufferInterval alternate_mem_interval; BufferInterval alternate_mem_interval;
alternate_mem_interval.buffer = request.allocation_value->value(); alternate_mem_interval.buffer = request.allocation_value->value();
alternate_mem_interval.size = request.size; alternate_mem_interval.size = request.size;
// If any of the prefetch intervals couldn't be used due to number of
// outstanding async copy limit or async copy ordering, set
// prefetch_failed_due_to_async_copy_.
prefetch_failed_due_to_async_copy_ = false;
while (!options_.prefetch_interval_picker->Done()) { while (!options_.prefetch_interval_picker->Done()) {
alternate_mem_interval.start = options_.prefetch_interval_picker->Next(); alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time); CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
@ -1825,15 +1889,17 @@ bool AlternateMemoryBestFitHeap::Prefetch(
<< alternate_mem_interval.start << ", " << request.end_time << ")"; << alternate_mem_interval.start << ", " << request.end_time << ")";
// If this additional asynchronous copy would violate the limit, try a // If this additional asynchronous copy would violate the limit, try a
// different interval. // different interval.
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
request.latest_prefetch_time)) {
VLOG(4) << "This would violate asynchronous copy ordering.";
prefetch_failed_due_to_async_copy_ = true;
continue;
}
if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start, if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
request.latest_prefetch_time, request.latest_prefetch_time,
/*is_prefetch=*/true)) { /*is_prefetch=*/true)) {
VLOG(4) << "This would violate the outstanding async copy limit."; VLOG(4) << "This would violate the outstanding async copy limit.";
continue; prefetch_failed_due_to_async_copy_ = true;
}
if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
request.latest_prefetch_time)) {
VLOG(4) << "This would violate asynchronous copy ordering.";
continue; continue;
} }
@ -1857,6 +1923,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
request.allocation_value->allocation_sequence()->back()->AddUse( request.allocation_value->allocation_sequence()->back()->AddUse(
request.use); request.use);
prefetch_failed_due_to_async_copy_ = false;
return true; return true;
} }
} }

View File

@ -190,6 +190,10 @@ class PrefetchIntervalPicker {
// Returns true if the available prefetch intervals have been exhausted. // Returns true if the available prefetch intervals have been exhausted.
virtual bool Done() const = 0; virtual bool Done() const = 0;
// The retry number can be used to modify the interval picking policies. The
// first attempt will have a retry_number of 0, then 1, etc.
virtual void SetRetryNumber(int retry_number) {}
// Returns a debug string for the current state of the prefetch interval // Returns a debug string for the current state of the prefetch interval
// picker. // picker.
virtual std::string ToDebugString() const = 0; virtual std::string ToDebugString() const = 0;
@ -276,6 +280,8 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
int64 Next() override; int64 Next() override;
bool Done() const override; bool Done() const override;
void SetRetryNumber(int retry_number) override;
std::string ToDebugString() const override; std::string ToDebugString() const override;
std::string ToNoCopyDebugString(const Shape& shape, int64 start_time, std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
int64 end_time) const override; int64 end_time) const override;
@ -304,6 +310,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
const MemorySpaceAssignmentCostAnalysis& cost_analysis_; const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
float min_async_copy_to_overlap_ratio_; float min_async_copy_to_overlap_ratio_;
float max_async_copy_to_overlap_ratio_; float max_async_copy_to_overlap_ratio_;
float max_overlap_multiplier_ = 1.0;
float async_copy_elapsed_; float async_copy_elapsed_;
float inst_elapsed_reduction_; float inst_elapsed_reduction_;
@ -362,6 +369,11 @@ class MemorySpaceAssignment {
int64 max_outstanding_prefetches = -1; int64 max_outstanding_prefetches = -1;
int64 max_outstanding_evictions = -1; int64 max_outstanding_evictions = -1;
// Specifies the maximum number of retries that will be performed for each
// value in case prefetching failed due to running out of asynchronous
// copies or asynchronous copy ordering.
int64 max_retries = 1;
// If true, tries allocating buffers across (e.g., before and inside a while // If true, tries allocating buffers across (e.g., before and inside a while
// loop body) sequential calls (kWhile, kCall, and kConditional). // loop body) sequential calls (kWhile, kCall, and kConditional).
bool allocate_across_sequential_calls = false; bool allocate_across_sequential_calls = false;
@ -756,6 +768,11 @@ struct RequiredMemoryAssignment {
MemorySpaceAssignment::MemorySpace memory_space; MemorySpaceAssignment::MemorySpace memory_space;
int64 time; int64 time;
absl::optional<HeapSimulator::Chunk> chunk; absl::optional<HeapSimulator::Chunk> chunk;
bool operator==(const RequiredMemoryAssignment& other) const {
return memory_space == other.memory_space && time == other.time &&
chunk == other.chunk;
}
}; };
// A struct representing an asynchronous copy with its logical start and end // A struct representing an asynchronous copy with its logical start and end
@ -893,8 +910,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
// Finds allocations for colocated intervals. Colocated intervals consist of // Finds allocations for colocated intervals. Colocated intervals consist of
// one or more BufferIntervals, each with a different HloValue. All of the // one or more BufferIntervals, each with a different HloValue. All of the
// intervals within colocated intervals have a must-alias relationship with // intervals within colocated intervals have a must-alias relationship with
// each other. // each other. Returns true if allocation succeeded.
void AllocateColocatedIntervals( bool AllocateColocatedIntervals(
const std::vector<const BufferInterval*>& colocated_intervals); const std::vector<const BufferInterval*>& colocated_intervals);
// Finds an allocation for an allocation request for a segment (see the // Finds an allocation for an allocation request for a segment (see the
@ -1026,12 +1043,17 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
AsynchronousCopyOrdering async_copy_ordering_; AsynchronousCopyOrdering async_copy_ordering_;
std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_; std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
std::vector<AsynchronousCopy> pending_async_copies_; std::vector<AsynchronousCopy> pending_async_copies_;
std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
pending_required_assignments_;
// This map contains required memory assignments for HloValues (e.g., input // This map contains required memory assignments for HloValues (e.g., input
// and outputs). // and outputs).
absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>> absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
required_assignments_; required_assignments_;
// Number of bytes reserved in alternate memory space. // Number of bytes reserved in alternate memory space.
int64 reserved_in_bytes_ = 0; int64 reserved_in_bytes_ = 0;
// Variables to control allocation retries.
bool final_retry_;
bool prefetch_failed_due_to_async_copy_;
// Debug strings. // Debug strings.
std::string buffer_info_str_; std::string buffer_info_str_;
std::string allocation_info_str_; std::string allocation_info_str_;