Refactors MemorySpaceAssignment to separate out the creation of the AllocationSequence.

PiperOrigin-RevId: 306322847
Change-Id: Iecccef3441c38c08f5ec42a80af778803a1d898b
This commit is contained in:
A. Unique TensorFlower 2020-04-13 15:37:32 -07:00 committed by TensorFlower Gardener
parent 867673d63b
commit 4d847bd055
2 changed files with 38 additions and 18 deletions

View File

@ -16,7 +16,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
namespace xla {
namespace {
@ -1689,6 +1688,32 @@ FindCrossProgramPrefetchCandidate(
}
return *best_candidate;
}
// Finds an AllocationSequence for placing buffers in alternate memory using the
// AlternateMemoryBestFitHeap algorithm.
StatusOr<MemorySpaceAssignment::AllocationSequence> FindAllocationSequence(
HloModule* module, const HloLiveRange& hlo_live_range,
const HloAliasAnalysis& alias_analysis,
const MemorySpaceAssignment::Options& options) {
MemorySpaceAssignment::AllocationSequence allocations;
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
&allocations, options, alias_analysis, hlo_live_range);
if (options.enable_cross_program_prefetch) {
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
prefetch_candiate = FindCrossProgramPrefetchCandidate(
alias_analysis, hlo_live_range, options);
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
}
HeapSimulator::Options heap_simulator_options;
heap_simulator_options.may_reuse_operand_buffers = false;
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
module->schedule(), alias_analysis,
options.size_fn, heap_simulator_options)
.status());
return std::move(allocations);
}
} // namespace
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
@ -1702,24 +1727,11 @@ MemorySpaceAssignment::Run(HloModule* module,
VLOG(4) << "Schedule: " << module->schedule().ToString();
MemorySpaceAssignment memory_space_assignment(module, options,
hlo_live_range);
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
&memory_space_assignment.allocations_, options, alias_analysis,
hlo_live_range);
if (options.enable_cross_program_prefetch) {
absl::optional<BufferInterval> prefetch_candiate =
FindCrossProgramPrefetchCandidate(alias_analysis, hlo_live_range,
options);
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
}
HeapSimulator::Options heap_simulator_options;
heap_simulator_options.may_reuse_operand_buffers = false;
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
module->schedule(), alias_analysis,
options.size_fn, heap_simulator_options)
.status());
TF_ASSIGN_OR_RETURN(
AllocationSequence allocations,
FindAllocationSequence(module, hlo_live_range, alias_analysis, options));
memory_space_assignment.SetAllocationSequence(std::move(allocations));
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
memory_space_assignment.ScheduleAsynchronousCopies();
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());

View File

@ -615,6 +615,14 @@ class MemorySpaceAssignment {
}
}
// Sets allocations_. Must be set before Process() is called.
// Uses an rvalue reference so that the caller is forced to hand over
// ownership of the AllocationSequence, e.g.
// SetAllocationSequence(std::move(my_allocation)).
void SetAllocationSequence(AllocationSequence&& allocations) {
allocations_ = std::move(allocations);
}
// Process calls Process methods of the allocations after the allocations have
// been finalized.
Status Process();