[XLA] Implement an allocation retry mechanism.

Sometimes, memory space assignment can't prefetch due to the limit on outstanding
async copies or due to async copy ordering. If that is the case, we can retry
allocating the value with a larger max_async_copy_to_overlap_ratio limit.

PiperOrigin-RevId: 314826393
Change-Id: Id722976cb8ab82993c4a60407b77df9f8e7b080e
Authored by Berkin Ilbeyi on 2020-06-04 16:31:09 -07:00; committed by TensorFlower Gardener
Parent: f8f0e6bf3a
Commit: 7a7207f3b6
3 changed files with 126 additions and 33 deletions
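
To make the new mechanism concrete, here is a minimal, self-contained sketch of the retry policy (illustrative only: RetryingAllocator, TryAllocate, and required_ratio are made-up stand-ins, not the real XLA types). The loop mirrors the retry driver added to AlternateMemoryBestFitHeap::Finish() below, and the multiplier update mirrors CostAnalysisPrefetchIntervalPicker::SetRetryNumber(), which doubles the effective max overlap limit on each retry:

    #include <cstdint>
    #include <iostream>

    // Illustrative stand-in for the allocator; not the real XLA class layout.
    struct RetryingAllocator {
      int64_t max_retries = 1;
      float max_async_copy_to_overlap_ratio = 10.0f;
      float max_overlap_multiplier = 1.0f;
      bool final_retry = false;

      // Stand-in for AllocateColocatedIntervals(): pretend allocation succeeds
      // once the effective overlap limit reaches the ratio this value needs.
      bool TryAllocate(float required_ratio) {
        return max_async_copy_to_overlap_ratio * max_overlap_multiplier >=
               required_ratio;
      }

      bool AllocateWithRetries(float required_ratio) {
        for (int retry_number = 0; retry_number < max_retries; ++retry_number) {
          final_retry = (retry_number == max_retries - 1);
          // Mirrors SetRetryNumber(): use a twice as large max overlap limit
          // in each retry (1, 2, 4, ...).
          max_overlap_multiplier = 1 << retry_number;
          if (TryAllocate(required_ratio)) {
            return true;
          }
          std::cout << "Couldn't allocate. Retry number " << retry_number << "\n";
        }
        return false;
      }
    };

    int main() {
      RetryingAllocator allocator;
      allocator.max_retries = 4;
      // Needs an effective ratio of 35: fails at multipliers 1 and 2, then
      // succeeds at retry_number == 2, where 10 * (1 << 2) >= 35.
      std::cout << (allocator.AllocateWithRetries(35.0f) ? "ok" : "failed") << "\n";
      return 0;
    }

In the actual change, a failed attempt also unwinds its partial state before the next retry, which is why UncommitPendingChunks() below gains the new pending_required_assignments_ bookkeeping.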


@@ -59,6 +59,10 @@ class HeapSimulator {
     int64 chunk_end() const { return offset + size; }
 
     bool OverlapsWith(Chunk other_chunk) const;
+
+    bool operator==(const Chunk& other) const {
+      return offset == other.offset && size == other.size;
+    }
   };
 
   // Result represents the result of the heap simulation.


@@ -297,7 +297,8 @@ bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
   float logical_interval_elapsed =
       GetLogicalIntervalElapsed(start_time, end_time);
-  return max_async_copy_to_overlap_ratio_ * async_copy_elapsed >
+  return max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
+             async_copy_elapsed >
          logical_interval_elapsed;
 }
@@ -332,11 +333,20 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
   // Find the earliest time we're allowed to start prefetching.
   for (current_logical_prefetch_time_ = start_time;
        current_logical_prefetch_time_ < end_logical_time_ &&
-       max_async_copy_to_overlap_ratio_ * async_copy_elapsed_ <
+       max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
+               async_copy_elapsed_ <
            GetLogicalIntervalElapsed(current_logical_prefetch_time_,
                                      end_logical_time_);
        ++current_logical_prefetch_time_) {
   }
+  // If the first prefetch interval violates the min overlap (this can happen if
+  // there is an HLO that has a very long estimated execution time), we go
+  // earlier in time until we no longer violate the min overlap (we start
+  // prefetching before the HLO that has the very long estimated execution
+  // time).
+  while (Done() && current_logical_prefetch_time_ > start_time) {
+    --current_logical_prefetch_time_;
+  }
 }
 
 int64 CostAnalysisPrefetchIntervalPicker::Next() {
@@ -357,6 +367,11 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
          logical_interval_elapsed + inst_elapsed_reduction_;
 }
 
+void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
+  // Use twice as large max overlap limit in each retry.
+  max_overlap_multiplier_ = 1 << retry_number;
+}
+
 int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
     int64 start_time, int64 end_time) const {
   int min_nest_level =
@@ -391,7 +406,9 @@ std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
   return absl::StrCat(
       "Async copy elapsed (s) = ", async_copy_elapsed_,
       ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
-      ", logical interval elapsed (s) = ", logical_interval_elapsed);
+      ", logical interval elapsed (s) = ", logical_interval_elapsed,
+      ", interval = (", current_logical_prefetch_time_, ", ", end_logical_time_,
+      ")");
 }
 
 std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
@@ -879,26 +896,19 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
       continue;
     }
 
-    // TODO(berkin): For now, place the phi values due to conditionals in
-    // default memory.
-    for (const BufferInterval* colocated_interval : colocated_intervals) {
-      const HloValue* value = colocated_interval->buffer;
-      for (const auto& position : value->positions()) {
-        if (position.instruction->opcode() == HloOpcode::kConditional) {
-          VLOG(3) << "Adding required assignment for condition output: "
-                  << value->ToShortString();
-          AddRequiredAssignment(position.instruction, position.index,
-                                MemorySpace::kDefault);
-          for (const HloComputation* called_computation :
-               position.instruction->called_computations()) {
-            AddRequiredAssignment(called_computation->root_instruction(),
-                                  position.index, MemorySpace::kDefault);
-          }
-        }
-      }
-    }
-
-    AllocateColocatedIntervals(colocated_intervals);
+    AppendBufferInfoDebugString(interval, &buffer_info_str_);
+
+    // Retry allocating this value with larger limits if allocation fails.
+    for (int retry_number = 0; retry_number < options_.max_retries;
+         retry_number++) {
+      final_retry_ = (retry_number == options_.max_retries - 1);
+      options_.prefetch_interval_picker->SetRetryNumber(retry_number);
+      bool success = AllocateColocatedIntervals(colocated_intervals);
+      if (success) {
+        break;
+      }
+      VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
+    }
   }
 
   VLOG(3) << "Debug buffer info: ";
@@ -910,9 +920,28 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
   return result_;
 }
 
-void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
+bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
     const std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>&
         colocated_intervals) {
+  // TODO(berkin): For now, place the phi values due to conditionals in
+  // default memory.
+  for (const BufferInterval* colocated_interval : colocated_intervals) {
+    const HloValue* value = colocated_interval->buffer;
+    for (const auto& position : value->positions()) {
+      if (position.instruction->opcode() == HloOpcode::kConditional) {
+        VLOG(3) << "Adding required assignment for condition output: "
+                << value->ToShortString();
+        AddRequiredAssignment(position.instruction, position.index,
+                              MemorySpace::kDefault);
+        for (const HloComputation* called_computation :
+             position.instruction->called_computations()) {
+          AddRequiredAssignment(called_computation->root_instruction(),
+                                position.index, MemorySpace::kDefault);
+        }
+      }
+    }
+  }
+
   // Create AllocationValues for all the colocated intervals.
   std::vector<AllocationValue> allocation_values;
   for (const auto& colocated_interval : colocated_intervals) {
@@ -926,8 +955,6 @@ void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
   absl::flat_hash_map<const HloComputation*, int64>
       preferred_offset_for_computation;
-  AppendBufferInfoDebugString(*colocated_intervals[0], &buffer_info_str_);
-
   bool allocation_success = true;
   for (auto& allocation_value : allocation_values) {
     int64 definition_time =
@@ -1093,6 +1120,8 @@ void AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
   pending_chunks_.clear();
   pending_async_copies_.clear();
+  pending_required_assignments_.clear();
+  return allocation_success;
 }
 
 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
@@ -1180,6 +1209,7 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
   pending_chunks_.clear();
   pending_async_copies_.clear();
+  pending_required_assignments_.clear();
 }
 
 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignmentsForSequentialCall(
@@ -1242,7 +1272,9 @@ void AlternateMemoryBestFitHeap::AddRequiredAssignment(
     VLOG(3) << "Adding required assignment: " << value->ToShortString()
             << " at " << time << " at "
             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
-    required_assignments_[value].push_back({memory_space, time, chunk});
+    RequiredMemoryAssignment required_assignment{memory_space, time, chunk};
+    required_assignments_[value].push_back(required_assignment);
+    pending_required_assignments_.push_back({value, required_assignment});
   }
 }
@@ -1361,8 +1393,30 @@ void AlternateMemoryBestFitHeap::UncommitPendingChunks() {
                                  kDummyChunk);
     }
   }
+  for (const auto& value_and_required_assignment :
+       pending_required_assignments_) {
+    auto& required_assignment_vector =
+        required_assignments_[value_and_required_assignment.first];
+    const RequiredMemoryAssignment& required_assignment =
+        value_and_required_assignment.second;
+    VLOG(3) << "Removing required assignment: "
+            << (required_assignment.memory_space == MemorySpace::kDefault
+                    ? "def"
+                    : "alt")
+            << " time = " << required_assignment.time << " off = "
+            << (required_assignment.chunk ? required_assignment.chunk->offset
+                                          : -1);
+    for (auto it = required_assignment_vector.begin();
+         it != required_assignment_vector.end(); ++it) {
+      if (*it == value_and_required_assignment.second) {
+        required_assignment_vector.erase(it);
+        break;
+      }
+    }
+  }
   pending_chunks_.clear();
   pending_async_copies_.clear();
+  pending_required_assignments_.clear();
 }
 
 void AlternateMemoryBestFitHeap::AddToPendingChunks(
@@ -1507,6 +1561,12 @@ bool AlternateMemoryBestFitHeap::AllocateSegment(
   if (Prefetch(request, **prev_allocation_in_default_mem_it)) {
     return true;
   }
+  if (!final_retry_ && prefetch_failed_due_to_async_copy_) {
+    // If prefetching failed due to asynchronous copy and we're not in our final
+    // try, return false (failure) so that we can retry this interval with
+    // larger limits.
+    return false;
+  }
 
   // If the end assignment was required to be in alternate memory but that
   // wasn't possible, then this allocation is invalid.
@@ -1818,6 +1878,10 @@ bool AlternateMemoryBestFitHeap::Prefetch(
   BufferInterval alternate_mem_interval;
   alternate_mem_interval.buffer = request.allocation_value->value();
   alternate_mem_interval.size = request.size;
+  // If any of the prefetch intervals couldn't be used due to number of
+  // outstanding async copy limit or async copy ordering, set
+  // prefetch_failed_due_to_async_copy_.
+  prefetch_failed_due_to_async_copy_ = false;
   while (!options_.prefetch_interval_picker->Done()) {
     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
     CHECK_LT(alternate_mem_interval.start, request.latest_prefetch_time);
@@ -1825,15 +1889,17 @@ bool AlternateMemoryBestFitHeap::Prefetch(
             << alternate_mem_interval.start << ", " << request.end_time << ")";
     // If this additional asynchronous copy would violate the limit, try a
     // different interval.
+    if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
+                                  request.latest_prefetch_time)) {
+      VLOG(4) << "This would violate asynchronous copy ordering.";
+      prefetch_failed_due_to_async_copy_ = true;
+      continue;
+    }
     if (ViolatesMaximumOutstandingAsyncCopies(alternate_mem_interval.start,
                                               request.latest_prefetch_time,
                                               /*is_prefetch=*/true)) {
       VLOG(4) << "This would violate the outstanding async copy limit.";
+      prefetch_failed_due_to_async_copy_ = true;
       continue;
     }
-    if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
-                                  request.latest_prefetch_time)) {
-      VLOG(4) << "This would violate asynchronous copy ordering.";
-      continue;
-    }
@@ -1857,6 +1923,7 @@ bool AlternateMemoryBestFitHeap::Prefetch(
       request.allocation_value->allocation_sequence()->back()->AddUse(
           request.use);
+      prefetch_failed_due_to_async_copy_ = false;
       return true;
     }
   }


@@ -190,6 +190,10 @@ class PrefetchIntervalPicker {
   // Returns true if the available prefetch intervals have been exhausted.
   virtual bool Done() const = 0;
 
+  // The retry number can be used to modify the interval picking policies. The
+  // first attempt will have a retry_number of 0, then 1, etc.
+  virtual void SetRetryNumber(int retry_number) {}
+
   // Returns a debug string for the current state of the prefetch interval
   // picker.
   virtual std::string ToDebugString() const = 0;
@@ -276,6 +280,8 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
   int64 Next() override;
   bool Done() const override;
 
+  void SetRetryNumber(int retry_number) override;
+
   std::string ToDebugString() const override;
   std::string ToNoCopyDebugString(const Shape& shape, int64 start_time,
                                   int64 end_time) const override;
@@ -304,6 +310,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
   float min_async_copy_to_overlap_ratio_;
   float max_async_copy_to_overlap_ratio_;
+  float max_overlap_multiplier_ = 1.0;
 
   float async_copy_elapsed_;
   float inst_elapsed_reduction_;
@@ -362,6 +369,11 @@ class MemorySpaceAssignment {
     int64 max_outstanding_prefetches = -1;
     int64 max_outstanding_evictions = -1;
 
+    // Specifies the maximum number of retries that will be performed for each
+    // value in case prefetching failed due to running out of asynchronous
+    // copies or asynchronous copy ordering.
+    int64 max_retries = 1;
+
     // If true, tries allocating buffers across (e.g., before and inside a while
     // loop body) sequential calls (kWhile, kCall, and kConditional).
     bool allocate_across_sequential_calls = false;
@@ -756,6 +768,11 @@ struct RequiredMemoryAssignment {
   MemorySpaceAssignment::MemorySpace memory_space;
   int64 time;
   absl::optional<HeapSimulator::Chunk> chunk;
+
+  bool operator==(const RequiredMemoryAssignment& other) const {
+    return memory_space == other.memory_space && time == other.time &&
+           chunk == other.chunk;
+  }
 };
 
 // A struct representing an asynchronous copy with its logical start and end
@@ -893,8 +910,8 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   // Finds allocations for colocated intervals. Colocated intervals consist of
   // one or more BufferIntervals, each with a different HloValue. All of the
   // intervals within colocated intervals have a must-alias relationship with
-  // each other.
-  void AllocateColocatedIntervals(
+  // each other. Returns true if allocation succeeded.
+  bool AllocateColocatedIntervals(
       const std::vector<const BufferInterval*>& colocated_intervals);
 
   // Finds an allocation for an allocation request for a segment (see the
@@ -1026,12 +1043,17 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
   AsynchronousCopyOrdering async_copy_ordering_;
   std::vector<std::pair<BufferInterval, ChunkCandidate>> pending_chunks_;
   std::vector<AsynchronousCopy> pending_async_copies_;
+  std::vector<std::pair<const HloValue*, RequiredMemoryAssignment>>
+      pending_required_assignments_;
   // This map contains required memory assignments for HloValues (e.g., input
   // and outputs).
   absl::flat_hash_map<const HloValue*, std::vector<RequiredMemoryAssignment>>
       required_assignments_;
   // Number of bytes reserved in alternate memory space.
   int64 reserved_in_bytes_ = 0;
+  // Variables to control allocation retries.
+  bool final_retry_;
+  bool prefetch_failed_due_to_async_copy_;
   // Debug strings.
   std::string buffer_info_str_;
   std::string allocation_info_str_;