From fad7b3a33b476ed2b4d7b6a901b8dc5ab02f38b0 Mon Sep 17 00:00:00 2001 From: Berkin Ilbeyi Date: Tue, 16 Jun 2020 16:12:25 -0700 Subject: [PATCH] [XLA] More tweaks and fixes to memory space assignment: - When sorting the BufferIntervals, now we also correctly include aliased HloValues into the benefit computation. - To avoid double counting, we now skip over while and conditional HLOs when using cost analysis since we already look inside the called computations. - When trying to find a preferred alternate memory offset, we now find the latest use that can be contiguously allocated (determined by the max overlap ratio heuristic) instead of trying to find an allocation that is as long-living as possible. This should improve fragmentation slightly. PiperOrigin-RevId: 316777648 Change-Id: If9d40a79283b644db975ebe62b1bb6c545fea89d --- .../xla/service/memory_space_assignment.cc | 113 +++++++++++++----- .../xla/service/memory_space_assignment.h | 35 ++++-- .../service/memory_space_assignment_test.cc | 16 +-- 3 files changed, 109 insertions(+), 55 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 388a2e18f38..ea1438380a6 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -31,6 +31,22 @@ const int kWhileExecutionCount = 5; } // namespace +/*static*/ StatusOr> +MemorySpaceAssignmentCostAnalysis::Create( + const HloCostAnalysis& cost_analysis, + float async_copy_bandwidth_bytes_per_second, + float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) { + TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module)); + TF_ASSIGN_OR_RETURN(auto hlo_live_range, + HloLiveRange::Run(module.schedule(), *alias_analysis, + module.entry_computation())); + auto call_graph = CallGraph::Build(&module); + return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis( + cost_analysis, async_copy_bandwidth_bytes_per_second, + alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis), + std::move(hlo_live_range), std::move(call_graph))); +} + float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, MemorySpaceAssignmentCostAnalysis::Cache* cache) const { @@ -74,19 +90,32 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( /*operand_in_alternate_mem=*/{}, /*output_in_alternate_mem=*/true), cache); - for (const HloUse& use : interval.buffer->uses()) { - float use_alternate_mem_benefit = GetAlternateMemoryBenefit( - *use.instruction, - GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number), - cache); - // If the benefit is positive (memory bound), add it to this buffer's - // benefit. If the benefit is negative (compute bound), calculate the - // maximum. - if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { - alternate_mem_benefit += use_alternate_mem_benefit; - } else { - alternate_mem_benefit = - std::max(alternate_mem_benefit, use_alternate_mem_benefit); + for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt( + interval.buffer->defining_position().instruction, + interval.buffer->defining_position().index)) { + for (const HloValue* value : buffer->values()) { + for (const HloUse& use : value->uses()) { + // We look inside the called computations of while and conditional, so + // don't use the benefit of while and conditional directly. + if (use.instruction->opcode() == HloOpcode::kWhile || + use.instruction->opcode() == HloOpcode::kConditional) { + continue; + } + float use_alternate_mem_benefit = + GetAlternateMemoryBenefit(*use.instruction, + GetInstructionElapsedDueToMemory( + *use.instruction, use.operand_number), + cache); + // If the benefit is positive (memory bound), add it to this buffer's + // benefit. If the benefit is negative (compute bound), calculate the + // maximum. + if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) { + alternate_mem_benefit += use_alternate_mem_benefit; + } else { + alternate_mem_benefit = + std::max(alternate_mem_benefit, use_alternate_mem_benefit); + } + } } } @@ -95,17 +124,9 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( float alternate_mem_slowdown = GetInstructionElapsedDueToMemorySlowdown(interval.size); - // Scale the slowdown based on the time of this buffer. We would want earlier - // buffers have lower slowdown values, because they are less likely to overlap - // with other HLOs. - // TODO(yuemmawang): We may want a piecewise function, where a lower slowdown - // for early HLOs, and full slowdown for mid-to-late HLOs. - // TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with - // more HLOs have higher slowdown, and vice versa. - float scale = interval.start * 1.0 / GetScheduleEndTime(); - alternate_mem_slowdown *= scale; - - return alternate_mem_benefit - alternate_mem_slowdown; + // Divide by the size of the buffer to prioritize smaller buffers that will + // give the largest alternate memory benefit. + return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size; } int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( @@ -113,7 +134,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel( int nest_level = 0; const HloComputation* computation = instruction->parent(); while (!computation->IsEntryComputation()) { - auto node = call_graph_.GetNode(computation); + auto node = call_graph_->GetNode(computation); auto callsites = node.caller_callsites(); CHECK_EQ(callsites.size(), 1) << "The module is not flattened!"; auto callsite = callsites[0]; @@ -195,7 +216,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed( } int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const { - return hlo_live_range_.schedule_end_time(); + return hlo_live_range_->schedule_end_time(); } bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy( @@ -253,6 +274,13 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( std::vector instructions_elapsed_time(instruction_schedule_->size(), 0.0); for (const auto& instruction_and_logical_time : *instruction_schedule_) { + // To avoid double counting, don't include the elapsed time of while and + // conditional HLOs. + const HloInstruction* instruction = instruction_and_logical_time.first; + if (instruction->opcode() == HloOpcode::kWhile || + instruction->opcode() == HloOpcode::kConditional) { + continue; + } float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( *instruction_and_logical_time.first); int64 logical_time = instruction_and_logical_time.second; @@ -1937,17 +1965,38 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate( BufferInterval* alternate_mem_interval) const { int64 end_time = request.end_time; if (!preferred_offset) { + // First find the earliest use that is the same or later than the end time. + const auto& uses = request.allocation_value->uses(); + auto use_it = uses.begin(); + for (; use_it->time < end_time; ++use_it) { + } + CHECK(use_it != uses.end()); + int64 earliest_use = use_it->time; + + // Then find the latest use that can be allocated contiguously without + // copies. + const Shape& shape = request.allocation_value->defining_position().shape(); + for (; + (use_it + 1) != uses.end() && + options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( + shape, use_it->time, (use_it + 1)->time); + ++use_it) { + } + CHECK(use_it != uses.end()); + int64 latest_contiguous_use = use_it->time; + // Find a chunk that's as long living as possible iterating in reverse over // the use times. - for (auto use_it = request.allocation_value->uses().rbegin(); - use_it != request.allocation_value->uses().rend() && - use_it->time >= end_time; - ++use_it) { + for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) { alternate_mem_interval->end = use_it->time; ChunkCandidate chunk_candidate = FindChunkCandidate(*alternate_mem_interval); if (chunk_candidate.heap_size <= available_heap_size()) { alternate_mem_interval->end = end_time; + VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use + << ", latest contiguous use = " << latest_contiguous_use + << ", use with available mem = " << use_it->time + << ", offset = " << chunk_candidate.chunk.offset; return chunk_candidate; } } @@ -2005,8 +2054,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const { MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( const MemorySpaceAssignmentCostAnalysis& cost_analysis, MemorySpaceAssignmentCostAnalysis::Cache* cache) { - return [cost_analysis, cache](const BufferInterval& x, - const BufferInterval& y) { + return [&cost_analysis, cache](const BufferInterval& x, + const BufferInterval& y) { float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache); float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache); if (x_memory_boundedness != y_memory_boundedness) { diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index f9e5738d17e..5e34f755fe9 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -84,18 +84,10 @@ class MemorySpaceAssignmentCostAnalysis { absl::flat_hash_map while_nest_multiplier; }; - MemorySpaceAssignmentCostAnalysis( + static StatusOr> Create( const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, - float alternate_mem_bandwidth_bytes_per_second, - const HloLiveRange& hlo_live_range, const CallGraph& call_graph) - : cost_analysis_(cost_analysis), - async_copy_bandwidth_bytes_per_second_( - async_copy_bandwidth_bytes_per_second), - alternate_mem_bandwidth_bytes_per_second_( - alternate_mem_bandwidth_bytes_per_second), - hlo_live_range_(hlo_live_range), - call_graph_(call_graph) {} + float alternate_mem_bandwidth_bytes_per_second, const HloModule& module); const HloCostAnalysis& cost_analysis() const { return cost_analysis_; } @@ -153,14 +145,31 @@ class MemorySpaceAssignmentCostAnalysis { // 0 means it is not in a while loop. int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const; - const HloLiveRange& hlo_live_range() const { return hlo_live_range_; } + const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; } private: + MemorySpaceAssignmentCostAnalysis( + const HloCostAnalysis& cost_analysis, + float async_copy_bandwidth_bytes_per_second, + float alternate_mem_bandwidth_bytes_per_second, + std::unique_ptr alias_analysis, + std::unique_ptr hlo_live_range, + std::unique_ptr call_graph) + : cost_analysis_(cost_analysis), + async_copy_bandwidth_bytes_per_second_( + async_copy_bandwidth_bytes_per_second), + alternate_mem_bandwidth_bytes_per_second_( + alternate_mem_bandwidth_bytes_per_second), + alias_analysis_(std::move(alias_analysis)), + hlo_live_range_(std::move(hlo_live_range)), + call_graph_(std::move(call_graph)) {} + const HloCostAnalysis& cost_analysis_; float async_copy_bandwidth_bytes_per_second_; float alternate_mem_bandwidth_bytes_per_second_; - const HloLiveRange& hlo_live_range_; - const CallGraph& call_graph_; + std::unique_ptr alias_analysis_; + std::unique_ptr hlo_live_range_; + std::unique_ptr call_graph_; }; // Abstract base class that memory space assignment uses to pick prefetch diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 032a3f53479..398f07d4a40 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -53,22 +53,18 @@ class MemorySpaceAssignmentTest : public HloTestBase, TF_CHECK_OK(computation->Accept(&hlo_cost_analysis)); } auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); - std::unique_ptr hlo_live_range = - HloLiveRange::Run(module->schedule(), *alias_analysis, - module->entry_computation()) - .ValueOrDie(); - std::unique_ptr call_graph = CallGraph::Build(module); - MemorySpaceAssignmentCostAnalysis cost_analysis( - hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth, - *hlo_live_range, *call_graph); + auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create( + hlo_cost_analysis, kAsyncCopyBandwidth, + kAlternateMemBandwidth, *module) + .ValueOrDie(); CostAnalysisPrefetchIntervalPicker prefetch_interval_picker( CostAnalysisPrefetchIntervalPicker( - cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, + *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8, /*max_async_copy_to_overlap_ratio=*/10.0)); return AssignMemorySpace( module, /*max_outstanding_async_copies=*/-1, MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( - cost_analysis, &cache_), + *cost_analysis, &cache_), &prefetch_interval_picker); }