diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index 0ed72f51754..df0e5bb4999 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -31,18 +31,31 @@ const int kWhileExecutionCount = 5; } // namespace float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( - const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem) const { + const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem, + MemorySpaceAssignmentCostAnalysis::Cache* cache) const { float elapsed_time_due_to_compute = GetInstructionElapsedDueToCompute(instruction); float elapsed_time_due_to_memory = GetInstructionElapsedDueToMemory(instruction); if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) { // Memory bound, return how much alternate memory is better. - int while_nest_level = CalculateWhileLoopNestLevel(&instruction); + float while_nest_multiplier; + if (cache) { + // If there is a cache provided, memoize the while nest multiplier. + auto it = cache->while_nest_multiplier.find(&instruction); + if (it != cache->while_nest_multiplier.end()) { + while_nest_multiplier = it->second; + } else { + while_nest_multiplier = tensorflow::MathUtil::IPow( + kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); + cache->while_nest_multiplier[&instruction] = while_nest_multiplier; + } + } else { + while_nest_multiplier = tensorflow::MathUtil::IPow( + kWhileExecutionCount, CalculateWhileLoopNestLevel(&instruction)); + } return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) * - tensorflow::MathUtil::IPow(kWhileExecutionCount, - while_nest_level); + while_nest_multiplier; } else { // Compute bound, return how far off are we to memory boundedness. return elapsed_time_due_to_memory - elapsed_time_due_to_compute; @@ -50,18 +63,21 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( } float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + MemorySpaceAssignmentCostAnalysis::Cache* cache) const { const HloInstruction& defining_instruction = *interval.buffer->defining_instruction(); float alternate_mem_benefit = GetAlternateMemoryBenefit( defining_instruction, GetInstructionElapsedDueToMemory(defining_instruction, /*operand_in_alternate_mem=*/{}, - /*output_in_alternate_mem=*/true)); + /*output_in_alternate_mem=*/true), + cache); for (const HloUse& use : interval.buffer->uses()) { float use_alternate_mem_benefit = GetAlternateMemoryBenefit( *use.instruction, - GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number)); + GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number), + cache); // If the benefit is positive (memory bound), add it to this buffer's // benefit. If the benefit is negative (compute bound), calculate the // maximum. @@ -221,9 +237,7 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( const MemorySpaceAssignmentCostAnalysis& cost_analysis, float min_async_copy_to_overlap_ratio, float max_async_copy_to_overlap_ratio) - : elapsed_time_( - cost_analysis.hlo_live_range().instruction_schedule().size(), 0.0), - while_nest_level_( + : while_nest_level_( cost_analysis.hlo_live_range().instruction_schedule().size(), 0), cost_analysis_(cost_analysis), min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio), @@ -232,19 +246,46 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker( &cost_analysis_.hlo_live_range().instruction_schedule(); // Create a vector of elapsed times and while nesting levels of HLO - // instructions. + // instructions. The elapsed times are multiplied by pow(kWhileExecutionCount, + // nest_level) to account for executing the HLOs multiple times in while + // loops. + std::vector instructions_elapsed_time(instruction_schedule_->size(), + 0.0); for (const auto& instruction_and_logical_time : *instruction_schedule_) { float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds( *instruction_and_logical_time.first); int64 logical_time = instruction_and_logical_time.second; - if (logical_time >= elapsed_time_.size()) { - elapsed_time_.resize(logical_time + 1, 0.0); + if (logical_time >= instructions_elapsed_time.size()) { + instructions_elapsed_time.resize(logical_time + 1, 0.0); while_nest_level_.resize(logical_time + 1, 0); } - elapsed_time_[logical_time] = elapsed_time; - while_nest_level_[logical_time] = - cost_analysis_.CalculateWhileLoopNestLevel( - instruction_and_logical_time.first); + int nest_level = cost_analysis_.CalculateWhileLoopNestLevel( + instruction_and_logical_time.first); + while_nest_level_[logical_time] = nest_level; + instructions_elapsed_time[logical_time] = + elapsed_time * + tensorflow::MathUtil::IPow(kWhileExecutionCount, nest_level); + } + // As an optimization, create a cumulative sum vector of elapsed time. + float cumsum = 0.0; + elapsed_time_cumsum_.reserve(instructions_elapsed_time.size()); + for (float elapsed_time : instructions_elapsed_time) { + cumsum += elapsed_time; + elapsed_time_cumsum_.push_back(cumsum); + } + // To be able to accurately determine the minimum nest level between a start + // time and an end time efficiently, populate a data structure that stores the + // closest nest level change index. + int prev_nest_level = 0; + int change_idx = -1; + while_nest_level_change_.reserve(instructions_elapsed_time.size()); + for (int i = 0; i < while_nest_level_.size(); ++i) { + int nest_level = while_nest_level_[i]; + if (nest_level != prev_nest_level) { + prev_nest_level = nest_level; + change_idx = i - 1; + } + while_nest_level_change_.push_back(change_idx); } } @@ -316,19 +357,32 @@ bool CostAnalysisPrefetchIntervalPicker::Done() const { logical_interval_elapsed + inst_elapsed_reduction_; } +int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel( + int64 start_time, int64 end_time) const { + int min_nest_level = + std::min(while_nest_level_[start_time], while_nest_level_[end_time]); + int change_idx = while_nest_level_change_[end_time]; + while (change_idx >= start_time) { + min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]); + change_idx = while_nest_level_change_[change_idx]; + } + return min_nest_level; +} + float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed( int64 start_time, int64 end_time) const { - int interval_nest_level = - std::min(while_nest_level_[start_time], while_nest_level_[end_time]); - float total_elapsed = 0; - for (int i = start_time + 1; i < end_time; ++i) { - total_elapsed += - elapsed_time_[i] * - tensorflow::MathUtil::IPow( - kWhileExecutionCount, - std::max(0, while_nest_level_[i] - interval_nest_level)); + CHECK_LE(start_time, end_time); + if (start_time == end_time) { + return 0.0; } - return total_elapsed; + // Since elapsed_time_cumsum_ is already weighed by the while loop nesting + // level, normalize the elapsed time by dividing with the nesting factor of + // the interval (start and end times). + int interval_nest_level = GetMinWhileNestLevel(start_time, end_time); + return (elapsed_time_cumsum_[end_time - 1] - + elapsed_time_cumsum_[start_time]) / + tensorflow::MathUtil::IPow(kWhileExecutionCount, + interval_nest_level); } std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const { @@ -557,11 +611,12 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( int64 root_time = instruction_schedule.at(while_body->root_instruction()); int64 min_use_time = root_time; for (const HloUse& parameter_use : parameter_value->uses()) { + int64 use_time = instruction_schedule.at(parameter_use.instruction); if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement && parameter_use.instruction->opcode() != HloOpcode::kTuple && - parameter_use.instruction->opcode() != HloOpcode::kBitcast) { - min_use_time = std::min( - min_use_time, instruction_schedule.at(parameter_use.instruction)); + parameter_use.instruction->opcode() != HloOpcode::kBitcast && + use_time > parameter_time) { + min_use_time = std::min(min_use_time, use_time); } } // If there is no use of this buffer inside the while loop, there is no need @@ -571,21 +626,13 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( << "use time = " << min_use_time << ", root time = " << root_time; return false; } - HloValue* root_value = - &alias_analysis_.dataflow_analysis().GetUniqueValueAt( - while_body->root_instruction(), use.operand_index); - int64 root_definition_time = - instruction_schedule.at(root_value->defining_instruction()); - const Shape& shape = root_value->shape(); + const Shape& shape = parameter_value->shape(); // Allow the buffer in alternate memory if the buffer has a short live range // either at the beginning or end of the while loop body. if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( - shape, parameter_time, min_use_time) && - !options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy( - shape, root_definition_time, root_time)) { + shape, parameter_time, min_use_time)) { VLOG(4) << "While allocation not allowed in alternate memory. " << "use time = " << min_use_time - << ", def time = " << root_definition_time << ", root time = " << root_time; return false; } @@ -1890,10 +1937,12 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const { /*static*/ MemorySpaceAssignment::BufferIntervalCompare MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( - const MemorySpaceAssignmentCostAnalysis& cost_analysis) { - return [&](const BufferInterval& x, const BufferInterval& y) { - float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x); - float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y); + const MemorySpaceAssignmentCostAnalysis& cost_analysis, + MemorySpaceAssignmentCostAnalysis::Cache* cache) { + return [cost_analysis, cache](const BufferInterval& x, + const BufferInterval& y) { + float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache); + float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache); if (x_memory_boundedness != y_memory_boundedness) { return x_memory_boundedness > y_memory_boundedness; } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index 3f59abfd28e..7e839f5d25b 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -78,6 +78,12 @@ class PresetAssignments { // bandwidths of different memory spaces. class MemorySpaceAssignmentCostAnalysis { public: + // An optional Cache object may be provided to some of the methods below to + // speed up the lookup. + struct Cache { + absl::flat_hash_map while_nest_multiplier; + }; + MemorySpaceAssignmentCostAnalysis( const HloCostAnalysis& cost_analysis, float async_copy_bandwidth_bytes_per_second, @@ -97,15 +103,16 @@ class MemorySpaceAssignmentCostAnalysis { // alternate memory would help if the op is memory bound, or otherwise how far // off is the op to memory boundedness. The larger this number, the higher // priority it will be placed in the alternate memory. - float GetAlternateMemoryBenefit( - const HloInstruction& instruction, - float elapsed_time_due_to_alternate_mem) const; + float GetAlternateMemoryBenefit(const HloInstruction& instruction, + float elapsed_time_due_to_alternate_mem, + Cache* cache = nullptr) const; // Returns a heuristic value of memory boundedness for the given // BufferInterval. The larger this number, the higher priority it will be // placed in the alternate memory. float GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const; + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + Cache* cache = nullptr) const; // Returns the elapsed time in seconds due to compute only. float GetInstructionElapsedDueToCompute( @@ -282,10 +289,17 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { // corresponds to the instruction schedule. float GetLogicalIntervalElapsed(int64 start_time, int64 end_time) const; + // Finds the minimum nest level in the given interval. + int GetMinWhileNestLevel(int64 start_time, int64 end_time) const; + // For each instruction in the flattened schedule, maintain their elapsed time - // and while nesting level. - std::vector elapsed_time_; + // (in cumulative sum) and while nesting level. + std::vector elapsed_time_cumsum_; std::vector while_nest_level_; + // Maintain the index of the most recent (before this instruction) nest level + // change in order to efficiently determine the minimum nest level in an + // interval. + std::vector while_nest_level_change_; const MemorySpaceAssignmentCostAnalysis& cost_analysis_; float min_async_copy_to_overlap_ratio_; @@ -645,7 +659,8 @@ class MemorySpaceAssignment { StatusOr CalculateAsyncCopyStats() const; static BufferIntervalCompare GetMemoryBoundednessBufferIntervalCompare( - const MemorySpaceAssignmentCostAnalysis& cost_analysis); + const MemorySpaceAssignmentCostAnalysis& cost_analysis, + MemorySpaceAssignmentCostAnalysis::Cache* cache = nullptr); // Verify that the memory space assignment is free of overlapping buffers and // export heap simulator trace to be used by buffer_assignment. diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 23b311730f8..9c6b42cac91 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -68,7 +68,7 @@ class MemorySpaceAssignmentTest : public HloTestBase, return AssignMemorySpace( module, /*max_outstanding_async_copies=*/-1, MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( - cost_analysis), + cost_analysis, &cache_), &prefetch_interval_picker); } @@ -285,6 +285,8 @@ class MemorySpaceAssignmentTest : public HloTestBase, TF_CHECK_OK(module->set_schedule(schedule)); return module; } + + MemorySpaceAssignmentCostAnalysis::Cache cache_; }; TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {