diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 5803e21b277..4a12800b594 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -330,41 +330,48 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, *use.instruction, use.operand_number); inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem; end_logical_time_ = end_time; - // Find the earliest time we're allowed to start prefetching. - for (current_logical_prefetch_time_ = start_time; - current_logical_prefetch_time_ < end_logical_time_ && - max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ * - async_copy_elapsed_ < - GetLogicalIntervalElapsed(current_logical_prefetch_time_, - end_logical_time_); - ++current_logical_prefetch_time_) { - } - // If the first prefetch interval violates the min overlap (this can happen if - // there is an HLO that has a very long estimated execution time), we go - // earlier in time until we no longer violate the min overlap (we start - // prefetching before the HLO that has the very long estimated execution - // time). - while (Done() && current_logical_prefetch_time_ > start_time) { - --current_logical_prefetch_time_; + earliest_start_logical_time_ = start_time; + int end_nest_level = while_nest_level_[end_time]; + // Find the latest time we're allowed to start prefetching. If the start and + // end nest levels differ, look for an earlier prefetch start. 
+ for (current_logical_prefetch_time_ = end_time - 1; + current_logical_prefetch_time_ > start_time && + (while_nest_level_[current_logical_prefetch_time_] != end_nest_level || + min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ > + GetLogicalIntervalElapsed(current_logical_prefetch_time_, + end_logical_time_) + + inst_elapsed_reduction_); + --current_logical_prefetch_time_) { } } int64 CostAnalysisPrefetchIntervalPicker::Next() { CHECK(!Done()) << "Prefetch interval picker's Next() is called even though " "Done() is false"; - return current_logical_prefetch_time_++; + int64 prefetch_time = current_logical_prefetch_time_; + if (!Done()) { + --current_logical_prefetch_time_; + } + // If the prefetch start and end nest levels differ, look for an earlier + // prefetch start. + while (!Done() && while_nest_level_[current_logical_prefetch_time_] != + while_nest_level_[end_logical_time_]) { + --current_logical_prefetch_time_; + } + return prefetch_time; } bool CostAnalysisPrefetchIntervalPicker::Done() const { - // The end time is exclusive, so we're done if the prefetch time is greater - // than or equal to the end time. 
- if (current_logical_prefetch_time_ >= end_logical_time_) { + if (current_logical_prefetch_time_ < earliest_start_logical_time_) { return true; } float logical_interval_elapsed = GetLogicalIntervalElapsed( current_logical_prefetch_time_, end_logical_time_); - return async_copy_elapsed_ * min_async_copy_to_overlap_ratio_ > - logical_interval_elapsed + inst_elapsed_reduction_; + return (max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ * + async_copy_elapsed_ < + logical_interval_elapsed) || + (min_async_copy_to_overlap_ratio_ * async_copy_elapsed_ > + logical_interval_elapsed + inst_elapsed_reduction_); } void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) { @@ -390,6 +397,9 @@ float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( if (start_time == end_time) { return 0.0; } + if (start_time < 0) { + start_time = 0; + } // Since elapsed_time_cumsum_ is already weighed by the while loop nesting // level, normalize the elapsed time by dividing with the nesting factor of // the interval (start and end times). 
@@ -1034,7 +1044,7 @@ bool AlternateMemoryBestFitHeap::AllocateColocatedIntervals( const HloLiveRange::TimeBound& computation_span = hlo_live_range_.computation_span_times().at(called_computation); latest_prefetch_time = - std::min(computation_span.start, latest_prefetch_time); + std::min(computation_span.start - 1, latest_prefetch_time); } if (hlo_use.instruction->opcode() == HloOpcode::kWhile) { // Given an example while loop and flattened schedule (logical times diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index d87908a6270..b8f47e73b8c 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -315,6 +315,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { float async_copy_elapsed_; float inst_elapsed_reduction_; int64 end_logical_time_; + int64 earliest_start_logical_time_; int64 current_logical_prefetch_time_; };