diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc
index 388a2e18f38..ea1438380a6 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc
@@ -31,6 +31,22 @@ const int kWhileExecutionCount = 5;
 
 }  // namespace
 
+/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
+MemorySpaceAssignmentCostAnalysis::Create(
+    const HloCostAnalysis& cost_analysis,
+    float async_copy_bandwidth_bytes_per_second,
+    float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
+  TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
+  TF_ASSIGN_OR_RETURN(auto hlo_live_range,
+                      HloLiveRange::Run(module.schedule(), *alias_analysis,
+                                        module.entry_computation()));
+  auto call_graph = CallGraph::Build(&module);
+  return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
+      cost_analysis, async_copy_bandwidth_bytes_per_second,
+      alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
+      std::move(hlo_live_range), std::move(call_graph)));
+}
+
 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
     const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
@@ -74,19 +90,32 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
           /*operand_in_alternate_mem=*/{},
           /*output_in_alternate_mem=*/true),
       cache);
-  for (const HloUse& use : interval.buffer->uses()) {
-    float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
-        *use.instruction,
-        GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
-        cache);
-    // If the benefit is positive (memory bound), add it to this buffer's
-    // benefit. If the benefit is negative (compute bound), calculate the
-    // maximum.
-    if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
-      alternate_mem_benefit += use_alternate_mem_benefit;
-    } else {
-      alternate_mem_benefit =
-          std::max(alternate_mem_benefit, use_alternate_mem_benefit);
+  for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
+           interval.buffer->defining_position().instruction,
+           interval.buffer->defining_position().index)) {
+    for (const HloValue* value : buffer->values()) {
+      for (const HloUse& use : value->uses()) {
+        // We look inside the called computations of while and conditional, so
+        // don't use the benefit of while and conditional directly.
+        if (use.instruction->opcode() == HloOpcode::kWhile ||
+            use.instruction->opcode() == HloOpcode::kConditional) {
+          continue;
+        }
+        float use_alternate_mem_benefit =
+            GetAlternateMemoryBenefit(*use.instruction,
+                                      GetInstructionElapsedDueToMemory(
+                                          *use.instruction, use.operand_number),
+                                      cache);
+        // If the benefit is positive (memory bound), add it to this buffer's
+        // benefit. If the benefit is negative (compute bound), calculate the
+        // maximum.
+        if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
+          alternate_mem_benefit += use_alternate_mem_benefit;
+        } else {
+          alternate_mem_benefit =
+              std::max(alternate_mem_benefit, use_alternate_mem_benefit);
+        }
+      }
     }
   }
 
@@ -95,17 +124,9 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
   float alternate_mem_slowdown =
       GetInstructionElapsedDueToMemorySlowdown(interval.size);
 
-  // Scale the slowdown based on the time of this buffer. We would want earlier
-  // buffers have lower slowdown values, because they are less likely to overlap
-  // with other HLOs.
-  // TODO(yuemmawang): We may want a piecewise function, where a lower slowdown
-  // for early HLOs, and full slowdown for mid-to-late HLOs.
-  // TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
-  // more HLOs have higher slowdown, and vice versa.
-  float scale = interval.start * 1.0 / GetScheduleEndTime();
-  alternate_mem_slowdown *= scale;
-
-  return alternate_mem_benefit - alternate_mem_slowdown;
+  // Divide by the size of the buffer to prioritize smaller buffers that will
+  // give the largest alternate memory benefit.
+  return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size;
 }
 
 int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
@@ -113,7 +134,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
   int nest_level = 0;
   const HloComputation* computation = instruction->parent();
   while (!computation->IsEntryComputation()) {
-    auto node = call_graph_.GetNode(computation);
+    auto node = call_graph_->GetNode(computation);
     auto callsites = node.caller_callsites();
     CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
     auto callsite = callsites[0];
@@ -195,7 +216,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
 }
 
 int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
-  return hlo_live_range_.schedule_end_time();
+  return hlo_live_range_->schedule_end_time();
 }
 
 bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
@@ -253,6 +274,13 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
   std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
                                                0.0);
   for (const auto& instruction_and_logical_time : *instruction_schedule_) {
+    // To avoid double counting, don't include the elapsed time of while and
+    // conditional HLOs.
+    const HloInstruction* instruction = instruction_and_logical_time.first;
+    if (instruction->opcode() == HloOpcode::kWhile ||
+        instruction->opcode() == HloOpcode::kConditional) {
+      continue;
+    }
     float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
         *instruction_and_logical_time.first);
     int64 logical_time = instruction_and_logical_time.second;
@@ -1937,17 +1965,38 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
     BufferInterval* alternate_mem_interval) const {
   int64 end_time = request.end_time;
   if (!preferred_offset) {
+    // First find the earliest use that is the same or later than the end time.
+    const auto& uses = request.allocation_value->uses();
+    auto use_it = uses.begin();
+    for (; use_it->time < end_time; ++use_it) {
+    }
+    CHECK(use_it != uses.end());
+    int64 earliest_use = use_it->time;
+
+    // Then find the latest use that can be allocated contiguously without
+    // copies.
+    const Shape& shape = request.allocation_value->defining_position().shape();
+    for (;
+         (use_it + 1) != uses.end() &&
+         options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
+             shape, use_it->time, (use_it + 1)->time);
+         ++use_it) {
+    }
+    CHECK(use_it != uses.end());
+    int64 latest_contiguous_use = use_it->time;
+
     // Find a chunk that's as long living as possible iterating in reverse over
     // the use times.
-    for (auto use_it = request.allocation_value->uses().rbegin();
-         use_it != request.allocation_value->uses().rend() &&
-         use_it->time >= end_time;
-         ++use_it) {
+    for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
       alternate_mem_interval->end = use_it->time;
       ChunkCandidate chunk_candidate =
          FindChunkCandidate(*alternate_mem_interval);
       if (chunk_candidate.heap_size <= available_heap_size()) {
         alternate_mem_interval->end = end_time;
+        VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
+                << ", latest contiguous use = " << latest_contiguous_use
+                << ", use with available mem = " << use_it->time
+                << ", offset = " << chunk_candidate.chunk.offset;
         return chunk_candidate;
       }
     }
@@ -2005,8 +2054,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
     MemorySpaceAssignmentCostAnalysis::Cache* cache) {
-  return [cost_analysis, cache](const BufferInterval& x,
-                                const BufferInterval& y) {
+  return [&cost_analysis, cache](const BufferInterval& x,
+                                 const BufferInterval& y) {
    float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
    float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
    if (x_memory_boundedness != y_memory_boundedness) {
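
Note on the priority change in GetMemoryBoundedness: the old heuristic scaled
the slowdown by the buffer's start time; the new return value instead divides
the net benefit by the buffer size. A minimal standalone sketch of the
resulting sort key (hypothetical struct and field names, not XLA code):

    #include <cstdint>

    // Stand-ins for the fields GetMemoryBoundedness reads off a
    // BufferInterval; the real type carries more state.
    struct IntervalSketch {
      int64_t size;    // buffer size in bytes
      float benefit;   // accumulated alternate_mem_benefit over all uses
      float slowdown;  // GetInstructionElapsedDueToMemorySlowdown(size)
    };

    // Mirrors the patched return statement. Normalizing by size means a
    // 1024-byte buffer with net benefit 8.0 scores 8.0 / 1024, beating a
    // 4096-byte buffer with the same net benefit (8.0 / 4096), so smaller
    // buffers win when alternate memory is contended.
    float MemoryBoundednessKey(const IntervalSketch& interval) {
      return (interval.benefit - interval.slowdown) / interval.size;
    }
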
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h
index f9e5738d17e..5e34f755fe9 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment.h
+++ b/tensorflow/compiler/xla/service/memory_space_assignment.h
@@ -84,18 +84,10 @@ class MemorySpaceAssignmentCostAnalysis {
     absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
   };
 
-  MemorySpaceAssignmentCostAnalysis(
+  static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
       const HloCostAnalysis& cost_analysis,
       float async_copy_bandwidth_bytes_per_second,
-      float alternate_mem_bandwidth_bytes_per_second,
-      const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
-      : cost_analysis_(cost_analysis),
-        async_copy_bandwidth_bytes_per_second_(
-            async_copy_bandwidth_bytes_per_second),
-        alternate_mem_bandwidth_bytes_per_second_(
-            alternate_mem_bandwidth_bytes_per_second),
-        hlo_live_range_(hlo_live_range),
-        call_graph_(call_graph) {}
+      float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);
 
   const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
 
@@ -153,14 +145,31 @@ class MemorySpaceAssignmentCostAnalysis {
   // 0 means it is not in a while loop.
   int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
 
-  const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
+  const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
 
  private:
+  MemorySpaceAssignmentCostAnalysis(
+      const HloCostAnalysis& cost_analysis,
+      float async_copy_bandwidth_bytes_per_second,
+      float alternate_mem_bandwidth_bytes_per_second,
+      std::unique_ptr<HloAliasAnalysis> alias_analysis,
+      std::unique_ptr<HloLiveRange> hlo_live_range,
+      std::unique_ptr<CallGraph> call_graph)
+      : cost_analysis_(cost_analysis),
+        async_copy_bandwidth_bytes_per_second_(
+            async_copy_bandwidth_bytes_per_second),
+        alternate_mem_bandwidth_bytes_per_second_(
+            alternate_mem_bandwidth_bytes_per_second),
+        alias_analysis_(std::move(alias_analysis)),
+        hlo_live_range_(std::move(hlo_live_range)),
+        call_graph_(std::move(call_graph)) {}
+
   const HloCostAnalysis& cost_analysis_;
   float async_copy_bandwidth_bytes_per_second_;
   float alternate_mem_bandwidth_bytes_per_second_;
-  const HloLiveRange& hlo_live_range_;
-  const CallGraph& call_graph_;
+  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
+  std::unique_ptr<HloLiveRange> hlo_live_range_;
+  std::unique_ptr<CallGraph> call_graph_;
 };
 
 // Abstract base class that memory space assignment uses to pick prefetch
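
For callers migrating off the deleted public constructor, a sketch of the new
entry point (the bandwidth numbers are placeholders, and TF_ASSIGN_OR_RETURN
assumes the enclosing function returns a Status or StatusOr):

    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<MemorySpaceAssignmentCostAnalysis> cost_analysis,
        MemorySpaceAssignmentCostAnalysis::Create(
            hlo_cost_analysis,
            /*async_copy_bandwidth_bytes_per_second=*/1e9,
            /*alternate_mem_bandwidth_bytes_per_second=*/1e9, module));

Create runs HloAliasAnalysis, HloLiveRange, and CallGraph::Build internally,
and the returned object owns all three, so call sites no longer build or keep
those analyses alive themselves.
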
diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
index 032a3f53479..398f07d4a40 100644
--- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc
@@ -53,22 +53,18 @@ class MemorySpaceAssignmentTest : public HloTestBase,
       TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
     }
     auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
-    std::unique_ptr<HloLiveRange> hlo_live_range =
-        HloLiveRange::Run(module->schedule(), *alias_analysis,
-                          module->entry_computation())
-            .ValueOrDie();
-    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
-    MemorySpaceAssignmentCostAnalysis cost_analysis(
-        hlo_cost_analysis, kAsyncCopyBandwidth, kAlternateMemBandwidth,
-        *hlo_live_range, *call_graph);
+    auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
+                             hlo_cost_analysis, kAsyncCopyBandwidth,
+                             kAlternateMemBandwidth, *module)
+                             .ValueOrDie();
     CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
         CostAnalysisPrefetchIntervalPicker(
-            cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
+            *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
             /*max_async_copy_to_overlap_ratio=*/10.0));
     return AssignMemorySpace(
         module, /*max_outstanding_async_copies=*/-1,
         MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
-            cost_analysis, &cache_),
+            *cost_analysis, &cache_),
         &prefetch_interval_picker);
   }
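
One related consequence: since the class now holds unique_ptr members it is
no longer copyable, which is why GetMemoryBoundednessBufferIntervalCompare
switched to capturing the cost analysis by reference ([&cost_analysis]
instead of [cost_analysis]). The returned comparator therefore must not
outlive the MemorySpaceAssignmentCostAnalysis it was created from.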