diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index cc15c7122fc..0f08e7c52b2 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -627,7 +627,7 @@ Status BufferAssignment::ComputeSummaryStats() { stats_.total_allocation_bytes += allocation.size(); } - // Only compute total fragmentation if all computations are sequential. + // Only compute total fragmentation if all computations have schedules. SequentialHloOrdering::HloModuleSequence module_sequence; for (const auto& computation : module_->computations()) { const std::vector* sequence = diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 93a922b9046..38c3982ebf1 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -144,7 +144,7 @@ Status HeapSimulator::RunComputation( } } else { // A GetTupleElement doesn't need to keep all of its operand's buffers - // alive. It only needs the buffers that relate to the element its + // alive. It only needs the buffers that relate to the element it's // extracting, and the tuple it's extracting from, but not the buffers // for the other elements. for (const BufferValue* buffer : points_to.element({})) { @@ -277,13 +277,13 @@ Status HeapSimulator::RunComputation( *memory_by_computation_); } - // If the whole module is sequential, we can save memory by running the - // heap-simulation for sub-computations inline. E.g. the buffers for the - // condition and body of a kWhile instruction are only live for the duration - // of the instruction itself. + // If all computations in the module have been scheduled, we can save memory + // by running the heap-simulation for sub-computations inline. E.g. 
the + // buffers for the condition and body of a kWhile instruction are only live + // for the duration of the instruction itself. // // The order that the sub-computations are simulated does not affect - // correctness; since the whole module is sequential, we know that the + // correctness; since the whole module has been scheduled, we know that the // sub-computations will never be run concurrently. if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || @@ -380,9 +380,10 @@ void HeapSimulator::Alloc(const BufferValue* buffer, allocated_buffers_.insert(buffer); const int64 size = size_fn_(*buffer); - algorithm_->Alloc(buffer, size); - no_fragmentation_stats_->Alloc(buffer, size); - + const HloInstruction* instruction_to_calc_aliasing = + memory_by_computation_ == nullptr ? nullptr : instruction; + algorithm_->Alloc(buffer, size, instruction_to_calc_aliasing); + no_fragmentation_stats_->Alloc(buffer, size, instruction_to_calc_aliasing); FillDebugTrace(HeapSimulatorTrace::Event::ALLOC, buffer, instruction, nullptr); } @@ -520,6 +521,18 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size) { } } +void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + // The output buffer of while/call/conditional is always aliased with the + // output buffer of the root instruction in the body. Don't double count. 
+ if (instruction == nullptr || + (instruction->opcode() != HloOpcode::kWhile && + instruction->opcode() != HloOpcode::kCall && + instruction->opcode() != HloOpcode::kConditional)) { + Alloc(buffer, size); + } +} + void NoFragmentationStatsHeap::AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index 811a6042df9..af05bedee72 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -36,6 +36,7 @@ namespace xla { // Forward declare classes defined below. class HeapAlgorithm; +class NoFragmentationStatsHeap; // HeapSimulator assigns buffer offsets by running a simulation of a regular // memory heap with Alloc and Free calls. It only works for completely @@ -161,7 +162,10 @@ class HeapSimulator { const HloInstruction* instruction, const BufferValue* shared_with_canonical); - const std::unique_ptr<HeapAlgorithm> no_fragmentation_stats_; + // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, + // in which case we are calculating the same allocs/frees twice in the + // simulation. + const std::unique_ptr<NoFragmentationStatsHeap> no_fragmentation_stats_; const std::unique_ptr<HeapAlgorithm> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; @@ -216,6 +220,21 @@ class HeapAlgorithm { // Alloc allocates a buffer of 'size' bytes. virtual void Alloc(const BufferValue* buffer, int64 size) = 0; + // NoFragmentationStatsHeap overrides this method. + virtual void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) { + Alloc(buffer, size); + } + + // Takes memory usage of subcomputations into account when calculating the + // memory usage of a computation. Currently, we don't handle buffer aliasing + // between computations entirely correctly. We are careful to not double count + // for the output buffers of whiles/conds/calls. 
But we don't take into + // account other aliases, such as for the while init. A more thorough solution + // would require something like BufferAssignment::BuildColocatedBufferSets. + // TODO(b/65835246): + // Since TuplePointsToAnalysis is being replaced with a module-aware alias + // analysis, it's not worth making major changes to HeapSimulator now. virtual void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& @@ -240,6 +259,9 @@ class NoFragmentationStatsHeap : public HeapAlgorithm { void Alloc(const BufferValue* buffer, int64 size) override; + void Alloc(const BufferValue* buffer, int64 size, + const HloInstruction* instruction) override; + void AccountForSubcomputationMemory( const HloInstruction* instruction, const tensorflow::gtl::FlatMap& diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc index 9ec983c2bc3..639c20ad8e1 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -244,9 +244,9 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) { *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations. The max mem doesn't change - // because the while body isn't live during the peak. - EXPECT_EQ(80, HeapSimulator::MinimumMemoryForComputation( + // HeapSimulator accounts for subcomputations. The output buffer is aliased, + // so we don't double count. 
+ EXPECT_EQ(64, HeapSimulator::MinimumMemoryForComputation( *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie()); @@ -350,7 +350,6 @@ TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) { TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { auto module = CreateNewModule(); const Shape r1f32 = ShapeUtil::MakeShape(F32, {4}); - const Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 4}); // param != 0 // Needs 17 bytes @@ -408,8 +407,9 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) { *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn) .ValueOrDie()); - // HeapSimulator accounts for subcomputations - EXPECT_EQ(33, HeapSimulator::MinimumMemoryForComputation( + // HeapSimulator accounts for subcomputations. Cond is the largest one. + // The output buffer of the while is aliased. + EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation( *entry_computation, sequence.at(entry_computation), *points_to_analysis, size_fn, &memory_by_computation) .ValueOrDie());