[XLA] Use latest to earliest order in prefetch picker.

This will make more efficient use of alternate memory by trying to avoid
prefetches that are unnecessarily early (and hence waste alternate memory).

PiperOrigin-RevId: 315982838
Change-Id: I6080a48661a5f032c0478b6d230b5b482840f2d4
This commit is contained in:
Berkin Ilbeyi 2020-06-11 14:50:25 -07:00 committed by TensorFlower Gardener
parent a65d79bb57
commit 580151fa26
2 changed files with 34 additions and 23 deletions

View File

@ -330,41 +330,48 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
*use.instruction, use.operand_number);
inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
end_logical_time_ = end_time;
// Find the earliest time we're allowed to start prefetching.
for (current_logical_prefetch_time_ = start_time;
current_logical_prefetch_time_ < end_logical_time_ &&
max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
async_copy_elapsed_ <
GetLogicalIntervalElapsed(current_logical_prefetch_time_,
end_logical_time_);
++current_logical_prefetch_time_) {
}
// If the first prefetch interval violates the min overlap (this can happen if
// there is an HLO that has a very long estimated execution time), we go
// earlier in time until we no longer violate the min overlap (we start
// prefetching before the HLO that has the very long estimated execution
// time).
while (Done() && current_logical_prefetch_time_ > start_time) {
--current_logical_prefetch_time_;
earliest_start_logical_time_ = start_time;
int end_nest_level = while_nest_level_[end_time];
// Find the latest time we're allowed to start prefetching. If the start and
// end nest levels differe look for an earlier prefetch start.
for (current_logical_prefetch_time_ = end_time - 1;
current_logical_prefetch_time_ > start_time &&
(while_nest_level_[current_logical_prefetch_time_] != end_nest_level ||
min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
GetLogicalIntervalElapsed(current_logical_prefetch_time_,
end_logical_time_) +
inst_elapsed_reduction_);
--current_logical_prefetch_time_) {
}
}
int64 CostAnalysisPrefetchIntervalPicker::Next() {
CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
"Done() is false";
return current_logical_prefetch_time_++;
int64 prefetch_time = current_logical_prefetch_time_;
if (!Done()) {
--current_logical_prefetch_time_;
}
// If the prefetch start and end times differ, look for an earlier prefetch
// start.
while (!Done() && while_nest_level_[current_logical_prefetch_time_] !=
while_nest_level_[end_logical_time_]) {
--current_logical_prefetch_time_;
}
return prefetch_time;
}
bool CostAnalysisPrefetchIntervalPicker::Done() const {
// The end time is exclusive, so we're done if the prefetch time is greater
// than or equal to the end time.
if (current_logical_prefetch_time_ >= end_logical_time_) {
if (current_logical_prefetch_time_ < earliest_start_logical_time_) {
return true;
}
float logical_interval_elapsed = GetLogicalIntervalElapsed(
current_logical_prefetch_time_, end_logical_time_);
return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ >
logical_interval_elapsed + inst_elapsed_reduction_;
return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
async_copy_elapsed_ <
logical_interval_elapsed) ||
(min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
logical_interval_elapsed + inst_elapsed_reduction_);
}
void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
@ -390,6 +397,9 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
if (start_time == end_time) {
return 0.0;
}
if (start_time < 0) {
start_time = 0;
}
// Since elapsed_time_cumsum_ is already weighed by the while loop nesting
// level, normalize the elapsed time by dividing with the nesting factor of
// the interval (start and end times).
@ -1034,7 +1044,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
const HloLiveRange::TimeBound& computation_span =
hlo_live_range_.computation_span_times().at(called_computation);
latest_prefetch_time =
std::min(computation_span.start, latest_prefetch_time);
std::min(computation_span.start - 1, latest_prefetch_time);
}
if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
// Given an example while loop and flattened schedule (logical times

View File

@ -315,6 +315,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
float async_copy_elapsed_;
float inst_elapsed_reduction_;
int64 end_logical_time_;
int64 earliest_start_logical_time_;
int64 current_logical_prefetch_time_;
};