[XLA] Use latest to earliest order in prefetch picker.
This will make more efficient use of alternate memory by trying to avoid prefetches that are unnecessarily early (and hence waste alternate memory). PiperOrigin-RevId: 315982838 Change-Id: I6080a48661a5f032c0478b6d230b5b482840f2d4
This commit is contained in:
parent
a65d79bb57
commit
580151fa26
@ -330,41 +330,48 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
|
||||
*use.instruction, use.operand_number);
|
||||
inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
|
||||
end_logical_time_ = end_time;
|
||||
// Find the earliest time we're allowed to start prefetching.
|
||||
for (current_logical_prefetch_time_ = start_time;
|
||||
current_logical_prefetch_time_ < end_logical_time_ &&
|
||||
max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
|
||||
async_copy_elapsed_ <
|
||||
GetLogicalIntervalElapsed(current_logical_prefetch_time_,
|
||||
end_logical_time_);
|
||||
++current_logical_prefetch_time_) {
|
||||
}
|
||||
// If the first prefetch interval violates the min overlap (this can happen if
|
||||
// there is an HLO that has a very long estimated execution time), we go
|
||||
// earlier in time until we no longer violate the min overlap (we start
|
||||
// prefetching before the HLO that has the very long estimated execution
|
||||
// time).
|
||||
while (Done() && current_logical_prefetch_time_ > start_time) {
|
||||
--current_logical_prefetch_time_;
|
||||
earliest_start_logical_time_ = start_time;
|
||||
int end_nest_level = while_nest_level_[end_time];
|
||||
// Find the latest time we're allowed to start prefetching. If the start and
|
||||
// end nest levels differe look for an earlier prefetch start.
|
||||
for (current_logical_prefetch_time_ = end_time - 1;
|
||||
current_logical_prefetch_time_ > start_time &&
|
||||
(while_nest_level_[current_logical_prefetch_time_] != end_nest_level ||
|
||||
min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
|
||||
GetLogicalIntervalElapsed(current_logical_prefetch_time_,
|
||||
end_logical_time_) +
|
||||
inst_elapsed_reduction_);
|
||||
--current_logical_prefetch_time_) {
|
||||
}
|
||||
}
|
||||
|
||||
int64 CostAnalysisPrefetchIntervalPicker::Next() {
|
||||
CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
|
||||
"Done() is false";
|
||||
return current_logical_prefetch_time_++;
|
||||
int64 prefetch_time = current_logical_prefetch_time_;
|
||||
if (!Done()) {
|
||||
--current_logical_prefetch_time_;
|
||||
}
|
||||
// If the prefetch start and end times differ, look for an earlier prefetch
|
||||
// start.
|
||||
while (!Done() && while_nest_level_[current_logical_prefetch_time_] !=
|
||||
while_nest_level_[end_logical_time_]) {
|
||||
--current_logical_prefetch_time_;
|
||||
}
|
||||
return prefetch_time;
|
||||
}
|
||||
|
||||
bool CostAnalysisPrefetchIntervalPicker::Done() const {
|
||||
// The end time is exclusive, so we're done if the prefetch time is greater
|
||||
// than or equal to the end time.
|
||||
if (current_logical_prefetch_time_ >= end_logical_time_) {
|
||||
if (current_logical_prefetch_time_ < earliest_start_logical_time_) {
|
||||
return true;
|
||||
}
|
||||
float logical_interval_elapsed = GetLogicalIntervalElapsed(
|
||||
current_logical_prefetch_time_, end_logical_time_);
|
||||
return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ >
|
||||
logical_interval_elapsed + inst_elapsed_reduction_;
|
||||
return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
|
||||
async_copy_elapsed_ <
|
||||
logical_interval_elapsed) ||
|
||||
(min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ >
|
||||
logical_interval_elapsed + inst_elapsed_reduction_);
|
||||
}
|
||||
|
||||
void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
|
||||
@ -390,6 +397,9 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
|
||||
if (start_time == end_time) {
|
||||
return 0.0;
|
||||
}
|
||||
if (start_time < 0) {
|
||||
start_time = 0;
|
||||
}
|
||||
// Since elapsed_time_cumsum_ is already weighed by the while loop nesting
|
||||
// level, normalize the elapsed time by dividing with the nesting factor of
|
||||
// the interval (start and end times).
|
||||
@ -1034,7 +1044,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals(
|
||||
const HloLiveRange::TimeBound& computation_span =
|
||||
hlo_live_range_.computation_span_times().at(called_computation);
|
||||
latest_prefetch_time =
|
||||
std::min(computation_span.start, latest_prefetch_time);
|
||||
std::min(computation_span.start - 1, latest_prefetch_time);
|
||||
}
|
||||
if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
|
||||
// Given an example while loop and flattened schedule (logical times
|
||||
|
@ -315,6 +315,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
|
||||
float async_copy_elapsed_;
|
||||
float inst_elapsed_reduction_;
|
||||
int64 end_logical_time_;
|
||||
int64 earliest_start_logical_time_;
|
||||
int64 current_logical_prefetch_time_;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user