From 7d2ff41266c4b94277e0454b88038fc035ad48c5 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 14 Apr 2020 12:32:09 -0700 Subject: [PATCH] Refactors memory_space_assignment so that it is easier for subclasses to implement different algorithms. PiperOrigin-RevId: 306492169 Change-Id: I6af9366706ec36d84c7a121f163cf2e55ffc8518 --- .../xla/service/memory_space_assignment.cc | 55 +++++++++---------- .../xla/service/memory_space_assignment.h | 15 +++-- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index fb608df5197..216c1a10bae 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -1691,32 +1691,6 @@ FindCrossProgramPrefetchCandidate( } return *best_candidate; } - -// Finds an AllocationSequence for placing buffers in alternate memory using the -// AlternateMemoryBestFitHeap algorithm. -StatusOr FindAllocationSequence( - HloModule* module, const HloLiveRange& hlo_live_range, - const HloAliasAnalysis& alias_analysis, - const MemorySpaceAssignment::Options& options) { - MemorySpaceAssignment::AllocationSequence allocations; - auto algorithm = absl::make_unique( - &allocations, options, alias_analysis, hlo_live_range); - - if (options.enable_cross_program_prefetch) { - absl::optional - prefetch_candiate = FindCrossProgramPrefetchCandidate( - alias_analysis, hlo_live_range, options); - algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate); - } - - HeapSimulator::Options heap_simulator_options; - heap_simulator_options.may_reuse_operand_buffers = false; - TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module, - module->schedule(), alias_analysis, - options.size_fn, heap_simulator_options) - .status()); - return std::move(allocations); -} } // namespace /*static*/ StatusOr> @@ -1730,11 +1704,9 @@ MemorySpaceAssignment::Run(HloModule* module, VLOG(4) << "Schedule: " << module->schedule().ToString(); MemorySpaceAssignment memory_space_assignment(module, options, hlo_live_range); - TF_ASSIGN_OR_RETURN( - AllocationSequence allocations, - FindAllocationSequence(module, hlo_live_range, alias_analysis, options)); - memory_space_assignment.SetAllocationSequence(std::move(allocations)); + TF_RETURN_IF_ERROR(memory_space_assignment.FindAllocationSequence( + hlo_live_range, alias_analysis)); TF_RETURN_IF_ERROR(memory_space_assignment.Process()); memory_space_assignment.ScheduleAsynchronousCopies(); TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph()); @@ -1752,6 +1724,29 @@ MemorySpaceAssignment::Run(HloModule* module, return std::move(memory_space_assignment.preset_assignments_); } +Status MemorySpaceAssignment::FindAllocationSequence( + const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis) { + auto algorithm = absl::make_unique( + &allocations_, options_, alias_analysis, hlo_live_range); + + if (options_.enable_cross_program_prefetch) { + absl::optional + prefetch_candiate = FindCrossProgramPrefetchCandidate( + alias_analysis, hlo_live_range, options_); + algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate); + } + + HeapSimulator::Options heap_simulator_options; + heap_simulator_options.may_reuse_operand_buffers = false; + TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_, + module_->schedule(), alias_analysis, + options_.size_fn, + heap_simulator_options) + .status()); + return Status::OK(); +} + void MemorySpaceAssignment::Allocation::AddUse(HloUse use) { HloInstruction* operand = use.instruction->mutable_operand(use.operand_number); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index aa5566b834f..fcb325fffc3 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -596,6 +596,13 @@ class MemorySpaceAssignment { // export heap simulator trace to be used by buffer_assignment. Status VerifyAndExportHeapSimulatorTrace(); + protected: + // Finds an AllocationSequence for placing buffers in alternate memory using + // the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is + // called. + Status FindAllocationSequence(const HloLiveRange& hlo_live_range, + const HloAliasAnalysis& alias_analysis); + private: MemorySpaceAssignment(HloModule* module, Options options, const HloLiveRange& hlo_live_range) @@ -615,14 +622,6 @@ class MemorySpaceAssignment { } } - // Sets allocations_. Must be set before Process() is called. - // Uses an rvalue reference so that the caller is forced to hand over - // ownership of the AllocationSequence, e.g. - // SetAllocationSequence(std::move(my_allocation)). - void SetAllocationSequence(AllocationSequence&& allocations) { - allocations_ = std::move(allocations); - } - // Process calls Process methods of the allocations after the allocations have // been finalized. Status Process();