diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 126b62a8eb2..a8f20827c6d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3284,6 +3284,7 @@ cc_library(
         ":heap_simulator",
         ":hlo_cost_analysis",
        "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/core/lib/math:math_util",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 81a8a102402..44509395b6f 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -16,53 +16,52 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
 
 #include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/core/lib/math/math_util.h"
 
 namespace xla {
 
 namespace {
 // Define a dummy chunk for chunks that will be allocated in the default memory
 // space and for keeping track of number of asynchronous copies.
 const HeapSimulator::Chunk kDummyChunk{-1, -1};
+// This variable is used by the cost analysis in estimating how many times each
+// while loop will execute. Nested loops will be assumed to have executed
+// pow(kWhileExecutionCount, nesting_level) times.
+const int kWhileExecutionCount = 5;
 
-// Returns a heuristic value that captures how much putting this tensor to
-// the alternate memory would help if the op is memory bound, or otherwise
-// how far off is the op to memory boundedness. The larger this number, the
-// higher priority it will be placed in the alternate memory.
-float GetAlternateMemoryBenefit(
-    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
+}  // namespace
+
+float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
     const HloInstruction& instruction,
-    float elapsed_time_due_to_alternate_mem) {
+    float elapsed_time_due_to_alternate_mem) const {
   float elapsed_time_due_to_compute =
-      cost_analysis.GetInstructionElapsedDueToCompute(instruction);
+      GetInstructionElapsedDueToCompute(instruction);
   float elapsed_time_due_to_memory =
-      cost_analysis.GetInstructionElapsedDueToMemory(instruction);
+      GetInstructionElapsedDueToMemory(instruction);
   if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
     // Memory bound, return how much alternate memory is better.
-    return elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem;
+    int while_nest_level = CalculateWhileLoopNestLevel(&instruction);
+    return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
+           tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
+                                             while_nest_level);
   } else {
     // Compute bound, return how far off are we to memory boundedness.
     return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
   }
 }
 
-// Returns a heuristic value of memory boundedness for the given BufferInterval.
-// The larger this number, the higher priority it will be placed in the
-// alternate memory.
-float GetMemoryBoundedness(
-    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
-    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) {
+float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
+    const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
   const HloInstruction& defining_instruction =
       *interval.buffer->defining_instruction();
-  float alternate_mem_benefit =
-      GetAlternateMemoryBenefit(cost_analysis, defining_instruction,
-                                cost_analysis.GetInstructionElapsedDueToMemory(
-                                    defining_instruction,
-                                    /*operand_in_alternate_mem=*/{},
-                                    /*output_in_alternate_mem=*/true));
+  float alternate_mem_benefit = GetAlternateMemoryBenefit(
+      defining_instruction,
+      GetInstructionElapsedDueToMemory(defining_instruction,
+                                       /*operand_in_alternate_mem=*/{},
+                                       /*output_in_alternate_mem=*/true));
   for (const HloUse& use : interval.buffer->uses()) {
     float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
-        cost_analysis, *use.instruction,
-        cost_analysis.GetInstructionElapsedDueToMemory(*use.instruction,
-                                                       use.operand_number));
+        *use.instruction,
+        GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number));
     // If the benefit is positive (memory bound), add it to this buffer's
     // benefit. If the benefit is negative (compute bound), calculate the
     // maximum.
@@ -77,7 +76,7 @@ float GetMemoryBoundedness(
   // Get performance slowdown in seconds of prefetching current BufferInterval
   // causing to other BufferIntervals.
   float alternate_mem_slowdown =
-      cost_analysis.GetInstructionElapsedDueToMemorySlowdown(interval.size);
+      GetInstructionElapsedDueToMemorySlowdown(interval.size);
 
   // Scale the slowdown based on the time of this buffer. We would want earlier
   // buffers have lower slowdown values, because they are less likely to overlap
@@ -86,13 +85,28 @@ float GetMemoryBoundedness(
   // for early HLOs, and full slowdown for mid-to-late HLOs.
   // TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
   // more HLOs have higher slowdown, and vice versa.
-  float scale = interval.start * 1.0 / cost_analysis.GetScheduleEndTime();
+  float scale = interval.start * 1.0 / GetScheduleEndTime();
   alternate_mem_slowdown *= scale;
 
   return alternate_mem_benefit - alternate_mem_slowdown;
 }
 
-}  // namespace
+int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
+    const HloInstruction* instruction) const {
+  int nest_level = 0;
+  const HloComputation* computation = instruction->parent();
+  while (!computation->IsEntryComputation()) {
+    auto node = call_graph_.GetNode(computation);
+    auto callsites = node.caller_callsites();
+    CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
+    auto callsite = callsites[0];
+    if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+      ++nest_level;
+    }
+    computation = callsite.instruction()->parent();
+  }
+  return nest_level;
+}
 
 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
     const HloInstruction& instruction) const {
@@ -207,29 +221,30 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
     float min_async_copy_to_overlap_ratio,
     float max_async_copy_to_overlap_ratio)
-    : cost_analysis_(cost_analysis),
+    : elapsed_time_(
+          cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0),
+      while_nest_level_(
+          cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
+      cost_analysis_(cost_analysis),
       min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
      max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio) {
   instruction_schedule_ =
       &cost_analysis_.hlo_live_range().instruction_schedule();
-  // First create a vector of elapsed times of HLO instructions.
-  std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
-                                               0.0);
+  // Create a vector of elapsed times and while nesting levels of HLO
+  // instructions.
   for (const auto& instruction_and_logical_time : *instruction_schedule_) {
     float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
         *instruction_and_logical_time.first);
     int64 logical_time = instruction_and_logical_time.second;
-    if (logical_time >= instructions_elapsed_time.size()) {
-      instructions_elapsed_time.resize(logical_time + 1, 0.0);
+    if (logical_time >= elapsed_time_.size()) {
+      elapsed_time_.resize(logical_time + 1, 0.0);
+      while_nest_level_.resize(logical_time + 1, 0);
     }
-    instructions_elapsed_time[logical_time] = elapsed_time;
-  }
-  // As an optimization, create a cumulative sum vector of elapsed time.
-  float cumsum = 0.0;
-  for (float elapsed_time : instructions_elapsed_time) {
-    cumsum += elapsed_time;
-    elapsed_time_cumsum_.push_back(cumsum);
+    elapsed_time_[logical_time] = elapsed_time;
+    while_nest_level_[logical_time] =
+        cost_analysis_.CalculateWhileLoopNestLevel(
+            instruction_and_logical_time.first);
   }
 }
 
@@ -303,7 +318,17 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const {
 
 float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
     int64 start_time, int64 end_time) const {
-  return elapsed_time_cumsum_[end_time - 1] - elapsed_time_cumsum_[start_time];
+  int interval_nest_level =
+      std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
+  float total_elapsed = 0;
+  for (int i = start_time + 1; i < end_time; ++i) {
+    total_elapsed +=
+        elapsed_time_[i] *
+        tensorflow::MathUtil::IPow<float>(
+            kWhileExecutionCount,
+            std::max(0, while_nest_level_[i] - interval_nest_level));
+  }
+  return total_elapsed;
 }
 
 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
@@ -328,7 +353,7 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
 absl::optional<float>
 CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
     const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const {
-  return GetMemoryBoundedness(cost_analysis_, interval);
+  return cost_analysis_.GetMemoryBoundedness(interval);
 }
 
 std::string MemorySpaceAssignment::AllocationValue::ToString() const {
@@ -805,8 +830,6 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
   }
 
   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
-  global_max_time_ = instruction_schedule.at(
-      module->entry_computation()->root_instruction());
 
   // TODO(berkin): For now, place the phi values due to conditionals in
   // default memory.
@@ -1609,6 +1632,9 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
           request.allocation_value->defining_position().shape(),
           eviction_start_time, request.end_time),
       eviction_end_time);
+  // Evictions must complete by the time of this use.
+  preferred_eviction_end_time =
+      std::min(preferred_eviction_end_time, request.latest_prefetch_time);
 
   BufferInterval eviction_mem_interval;
   eviction_mem_interval.buffer = request.allocation_value->value();
@@ -1616,8 +1642,7 @@ bool AlternateMemoryBestFitHeap::Evict(const AllocationRequest& request) {
   // Try to reserve a buffer from the end of the previous allocation to the
   // preferred eviction end time.
   eviction_mem_interval.start = eviction_end_time + 1;
-  eviction_mem_interval.end =
-      std::min(preferred_eviction_end_time, global_max_time_);
+  eviction_mem_interval.end = preferred_eviction_end_time;
   int64 preferred_offset = prev_allocation->chunk().offset;
   VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
           << ") preferred end time = " << eviction_mem_interval.end;
@@ -1834,8 +1859,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
     const MemorySpaceAssignmentCostAnalysis& cost_analysis) {
   return [&](const BufferInterval& x, const BufferInterval& y) {
-    float x_memory_boundedness = GetMemoryBoundedness(cost_analysis, x);
-    float y_memory_boundedness = GetMemoryBoundedness(cost_analysis, y);
+    float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x);
+    float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y);
     if (x_memory_boundedness != y_memory_boundedness) {
       return x_memory_boundedness > y_memory_boundedness;
     }
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index 340446d21dd..cf23c792c21 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -82,16 +82,31 @@ class MemorySpaceAssignmentCostAnalysis {
       const HloCostAnalysis& cost_analysis,
       float async_copy_bandwidth_bytes_per_second,
       float alternate_mem_bandwidth_bytes_per_second,
-      const HloLiveRange& hlo_live_range)
+      const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
       : cost_analysis_(cost_analysis),
         async_copy_bandwidth_bytes_per_second_(
             async_copy_bandwidth_bytes_per_second),
         alternate_mem_bandwidth_bytes_per_second_(
             alternate_mem_bandwidth_bytes_per_second),
-        hlo_live_range_(hlo_live_range) {}
+        hlo_live_range_(hlo_live_range),
+        call_graph_(call_graph) {}
 
   const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
 
+  // Returns a heuristic value that captures how much putting this tensor to the
+  // alternate memory would help if the op is memory bound, or otherwise how far
+  // off is the op to memory boundedness. The larger this number, the higher
+  // priority it will be placed in the alternate memory.
+  float GetAlternateMemoryBenefit(
+      const HloInstruction& instruction,
+      float elapsed_time_due_to_alternate_mem) const;
+
+  // Returns a heuristic value of memory boundedness for the given
+  // BufferInterval. The larger this number, the higher priority it will be
+  // placed in the alternate memory.
+  float GetMemoryBoundedness(
+      const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const;
+
   // Returns the elapsed time in seconds due to compute only.
   float GetInstructionElapsedDueToCompute(
       const HloInstruction& instruction) const;
@@ -127,6 +142,10 @@ class MemorySpaceAssignmentCostAnalysis {
 
   int64 GetScheduleEndTime() const;
 
+  // Returns the number of nested while loop levels this instruction resides in.
+  // 0 means it is not in a while loop.
+  int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
+
   const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
 
  private:
@@ -134,6 +153,7 @@ class MemorySpaceAssignmentCostAnalysis {
   float async_copy_bandwidth_bytes_per_second_;
   float alternate_mem_bandwidth_bytes_per_second_;
   const HloLiveRange& hlo_live_range_;
+  const CallGraph& call_graph_;
 };
 
 // Abstract base class that memory space assignment uses to pick prefetch
@@ -262,10 +282,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker {
   // corresponds to the instruction schedule.
   float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const;
 
-  // For performance reasons, we calculate the prefix sum of the elapsed time so
-  // that it's efficient to find the elapsed time in seconds in any logical
-  // interval.
-  std::vector<float> elapsed_time_cumsum_;
+  // For each instruction in the flattened schedule, maintain their elapsed time
+  // and while nesting level.
+  std::vector<float> elapsed_time_;
+  std::vector<int> while_nest_level_;
   const MemorySpaceAssignmentCostAnalysis& cost_analysis_;
   float min_async_copy_to_overlap_ratio_;
   float max_async_copy_to_overlap_ratio_;
@@ -988,7 +1008,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
       required_assignments_;
   // Number of bytes reserved in alternate memory space.
   int64 reserved_in_bytes_ = 0;
-  int64 global_max_time_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index a9be3850d89..61843b2e765 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -57,9 +57,10 @@ class MemorySpaceAssignmentTest : public HloTestBase,
         HloLiveRange::Run(module->schedule(), *alias_analysis,
                           module->entry_computation())
             .ValueOrDie();
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
     MemorySpaceAssignmentCostAnalysis cost_analysis(
         hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth,
-        *hlo_live_range);
+        *hlo_live_range, *call_graph);
     CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
         CostAnalysisPrefetchIntervalPicker(
             cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
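For reviewers who want to sanity-check the new cost model outside of XLA, below is a small standalone C++ sketch. It is not part of the patch: the schedule data is hypothetical, and a local IPow helper stands in for tensorflow::MathUtil::IPow and the real elapsed_time_/while_nest_level_ members. It mirrors the patched GetLogicalIntervalElapsed: an instruction that sits in a deeper while-loop nest than the interval's endpoints contributes its elapsed time scaled by kWhileExecutionCount (5) per extra nesting level, since each loop body is assumed to run about that many times.

// Standalone sketch, not part of the patch. The schedule below is
// hypothetical; IPow stands in for tensorflow::MathUtil::IPow.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

namespace {

constexpr int kWhileExecutionCount = 5;  // Assumed trip count per while loop.

// Minimal integer power helper (the patch uses tensorflow::MathUtil::IPow).
float IPow(int base, int exp) {
  float result = 1.0f;
  for (int i = 0; i < exp; ++i) result *= base;
  return result;
}

// Mirrors the new GetLogicalIntervalElapsed: an instruction contributes its
// elapsed time multiplied by kWhileExecutionCount for every while-nest level
// it is deeper than the interval's endpoints.
float GetLogicalIntervalElapsed(const std::vector<float>& elapsed_time,
                                const std::vector<int>& while_nest_level,
                                int64_t start_time, int64_t end_time) {
  int interval_nest_level =
      std::min(while_nest_level[start_time], while_nest_level[end_time]);
  float total_elapsed = 0;
  for (int64_t i = start_time + 1; i < end_time; ++i) {
    total_elapsed +=
        elapsed_time[i] *
        IPow(kWhileExecutionCount,
             std::max(0, while_nest_level[i] - interval_nest_level));
  }
  return total_elapsed;
}

}  // namespace

int main() {
  // Hypothetical flattened schedule: instructions 2 and 3 live inside a while
  // loop (nest level 1); the rest are in the entry computation (level 0).
  std::vector<float> elapsed_time = {0.1f, 0.2f, 0.3f, 0.3f, 0.2f, 0.1f};
  std::vector<int> while_nest_level = {0, 0, 1, 1, 0, 0};
  // For a prefetch spanning (0, 5), the loop body counts 5x:
  // 0.2 + 5 * 0.3 + 5 * 0.3 + 0.2 = 3.4 (seconds).
  std::cout << GetLogicalIntervalElapsed(elapsed_time, while_nest_level,
                                         /*start_time=*/0, /*end_time=*/5)
            << std::endl;
  return 0;
}

This is also why the old elapsed_time_cumsum_ prefix sum is removed: the per-instruction weight now depends on the nest level of the query interval's endpoints, so a single precomputed cumulative sum can no longer answer arbitrary interval queries.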