From 580151fa26419ae583ec42cc6cbb92777e214109 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Thu, 11 Jun 2020 14:50:25 -0700 Subject: [PATCH] [XLA] Use latest to earliest order in prefetch picker. This will make more efficient use of alternate memory by trying to avoid prefetches that are unnecessarily early (and hence waste alternate memory). PiperOrigin-RevId: 315982838 Change-Id: I6080a48661a5f032c0478b6d230b5b482840f2d4 --- .../xla/service/memory_space_assignment.cc | 56 +++++++++++-------- .../xla/service/memory_space_assignment.h | 1 + 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 5803e21b277..4a12800b594 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -330,41 +330,48 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, *use.instruction, use.operand_number); inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; end_logical_time_ = end_time; - // Find the earliest time we're allowed to start prefetching. - for (current_logical_prefetch_time_ = start_time; - current_logical_prefetch_time_ < end_logical_time_ && - max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ * - async_copy_elapsed_ < - GetLogicalIntervalElapsed(current_logical_prefetch_time_, - end_logical_time_); - ++current_logical_prefetch_time_) { - } - // If the first prefetch interval violates the min overlap (this can happen if - // there is an HLO that has a very long estimated execution time), we go - // earlier in time until we no longer violate the min overlap (we start - // prefetching before the HLO that has the very long estimated execution - // time). - while (Done() && current_logical_prefetch_time_ > start_time) { - --current_logical_prefetch_time_; + earliest_start_logical_time_ = start_time; + int end_nest_level = while_nest_level_[end_time]; + // Find the latest time we're allowed to start prefetching. If the start and + // end nest levels differe look for an earlier prefetch start. + for (current_logical_prefetch_time_ = end_time - 1; + current_logical_prefetch_time_ > start_time && + (while_nest_level_[current_logical_prefetch_time_] != end_nest_level || + min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ > + GetLogicalIntervalElapsed(current_logical_prefetch_time_, + end_logical_time_) + + inst_elapsed_reduction_); + --current_logical_prefetch_time_) { } } int64 CostAnalysisPrefetchIntervalPicker::Next() { CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " "Done() is false"; - return current_logical_prefetch_time_++; + int64 prefetch_time = current_logical_prefetch_time_; + if (!Done()) { + --current_logical_prefetch_time_; + } + // If the prefetch start and end times differ, look for an earlier prefetch + // start. + while (!Done() && while_nest_level_[current_logical_prefetch_time_] != + while_nest_level_[end_logical_time_]) { + --current_logical_prefetch_time_; + } + return prefetch_time; } bool CostAnalysisPrefetchIntervalPicker::Done() const { - // The end time is exclusive, so we're done if the prefetch time is greater - // than or equal to the end time. - if (current_logical_prefetch_time_ >= end_logical_time_) { + if (current_logical_prefetch_time_ < earliest_start_logical_time_) { return true; } float logical_interval_elapsed = GetLogicalIntervalElapsed( current_logical_prefetch_time_, end_logical_time_); - return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ > - logical_interval_elapsed + inst_elapsed_reduction_; + return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ * + async_copy_elapsed_ < + logical_interval_elapsed) || + (min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ > + logical_interval_elapsed + inst_elapsed_reduction_); } void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { @@ -390,6 +397,9 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( if (start_time == end_time) { return 0.0; } + if (start_time < 0) { + start_time = 0; + } // Since elapsed_time_cumsum_ is already weighed by the while loop nesting // level, normalize the elapsed time by dividing with the nesting factor of // the interval (start and end times). @@ -1034,7 +1044,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( const HloLiveRange::TimeBound& computation_span = hlo_live_range_.computation_span_times().at(called_computation); latest_prefetch_time = - std::min(computation_span.start, latest_prefetch_time); + std::min(computation_span.start - 1, latest_prefetch_time); } if (hlo_use.instruction->opcode() == HloOpcode::kWhile) { // Given an example while loop and flattened schedule (logical times diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index d87908a6270..b8f47e73b8c 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -315,6 +315,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { float async_copy_elapsed_; float inst_elapsed_reduction_; int64 end_logical_time_; + int64 earliest_start_logical_time_; int64 current_logical_prefetch_time_; };