Refactors memory_space_assignment so that it is easier for subclasses to implement different algorithms.

PiperOrigin-RevId: 306492169
Change-Id: I6af9366706ec36d84c7a121f163cf2e55ffc8518
This commit is contained in:
A. Unique TensorFlower 2020-04-14 12:32:09 -07:00 committed by TensorFlower Gardener
parent ac10027777
commit 7d2ff41266
2 changed files with 32 additions and 38 deletions

View File

@ -1691,32 +1691,6 @@ FindCrossProgramPrefetchCandidate(
}
return *best_candidate;
}
// Finds an AllocationSequence for placing buffers in alternate memory using the
// AlternateMemoryBestFitHeap algorithm.
StatusOr<MemorySpaceAssignment::AllocationSequence> FindAllocationSequence(
HloModule* module, const HloLiveRange& hlo_live_range,
const HloAliasAnalysis& alias_analysis,
const MemorySpaceAssignment::Options& options) {
MemorySpaceAssignment::AllocationSequence allocations;
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
&allocations, options, alias_analysis, hlo_live_range);
if (options.enable_cross_program_prefetch) {
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
prefetch_candiate = FindCrossProgramPrefetchCandidate(
alias_analysis, hlo_live_range, options);
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
}
HeapSimulator::Options heap_simulator_options;
heap_simulator_options.may_reuse_operand_buffers = false;
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
module->schedule(), alias_analysis,
options.size_fn, heap_simulator_options)
.status());
return std::move(allocations);
}
} // namespace
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
@ -1730,11 +1704,9 @@ MemorySpaceAssignment::Run(HloModule* module,
VLOG(4) << "Schedule: " << module->schedule().ToString();
MemorySpaceAssignment memory_space_assignment(module, options,
hlo_live_range);
TF_ASSIGN_OR_RETURN(
AllocationSequence allocations,
FindAllocationSequence(module, hlo_live_range, alias_analysis, options));
memory_space_assignment.SetAllocationSequence(std::move(allocations));
TF_RETURN_IF_ERROR(memory_space_assignment.FindAllocationSequence(
hlo_live_range, alias_analysis));
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
memory_space_assignment.ScheduleAsynchronousCopies();
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
@ -1752,6 +1724,29 @@ MemorySpaceAssignment::Run(HloModule* module,
return std::move(memory_space_assignment.preset_assignments_);
}
Status MemorySpaceAssignment::FindAllocationSequence(
const HloLiveRange& hlo_live_range,
const HloAliasAnalysis& alias_analysis) {
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
&allocations_, options_, alias_analysis, hlo_live_range);
if (options_.enable_cross_program_prefetch) {
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
prefetch_candiate = FindCrossProgramPrefetchCandidate(
alias_analysis, hlo_live_range, options_);
algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate);
}
HeapSimulator::Options heap_simulator_options;
heap_simulator_options.may_reuse_operand_buffers = false;
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
module_->schedule(), alias_analysis,
options_.size_fn,
heap_simulator_options)
.status());
return Status::OK();
}
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
HloInstruction* operand =
use.instruction->mutable_operand(use.operand_number);

View File

@ -596,6 +596,13 @@ class MemorySpaceAssignment {
// export heap simulator trace to be used by buffer_assignment.
Status VerifyAndExportHeapSimulatorTrace();
protected:
// Finds an AllocationSequence for placing buffers in alternate memory using
// the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is
// called.
Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
const HloAliasAnalysis& alias_analysis);
private:
MemorySpaceAssignment(HloModule* module, Options options,
const HloLiveRange& hlo_live_range)
@ -615,14 +622,6 @@ class MemorySpaceAssignment {
}
}
// Sets allocations_. Must be set before Process() is called.
// Uses an rvalue reference so that the caller is forced to hand over
// ownership of the AllocationSequence, e.g.
// SetAllocationSequence(std::move(my_allocation)).
void SetAllocationSequence(AllocationSequence&& allocations) {
allocations_ = std::move(allocations);
}
// Process calls Process methods of the allocations after the allocations have
// been finalized.
Status Process();