Refactors memory_space_assignment so that it is easier for subclasses to implement different algorithms.
PiperOrigin-RevId: 306492169 Change-Id: I6af9366706ec36d84c7a121f163cf2e55ffc8518
This commit is contained in:
parent
ac10027777
commit
7d2ff41266
@ -1691,32 +1691,6 @@ FindCrossProgramPrefetchCandidate(
|
||||
}
|
||||
return *best_candidate;
|
||||
}
|
||||
|
||||
// Finds an AllocationSequence for placing buffers in alternate memory using the
|
||||
// AlternateMemoryBestFitHeap algorithm.
|
||||
StatusOr<MemorySpaceAssignment::AllocationSequence> FindAllocationSequence(
|
||||
HloModule* module, const HloLiveRange& hlo_live_range,
|
||||
const HloAliasAnalysis& alias_analysis,
|
||||
const MemorySpaceAssignment::Options& options) {
|
||||
MemorySpaceAssignment::AllocationSequence allocations;
|
||||
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
||||
&allocations, options, alias_analysis, hlo_live_range);
|
||||
|
||||
if (options.enable_cross_program_prefetch) {
|
||||
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
|
||||
prefetch_candiate = FindCrossProgramPrefetchCandidate(
|
||||
alias_analysis, hlo_live_range, options);
|
||||
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
||||
}
|
||||
|
||||
HeapSimulator::Options heap_simulator_options;
|
||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
||||
module->schedule(), alias_analysis,
|
||||
options.size_fn, heap_simulator_options)
|
||||
.status());
|
||||
return std::move(allocations);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
||||
@ -1730,11 +1704,9 @@ MemorySpaceAssignment::Run(HloModule* module,
|
||||
VLOG(4) << "Schedule: " << module->schedule().ToString();
|
||||
MemorySpaceAssignment memory_space_assignment(module, options,
|
||||
hlo_live_range);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
AllocationSequence allocations,
|
||||
FindAllocationSequence(module, hlo_live_range, alias_analysis, options));
|
||||
|
||||
memory_space_assignment.SetAllocationSequence(std::move(allocations));
|
||||
TF_RETURN_IF_ERROR(memory_space_assignment.FindAllocationSequence(
|
||||
hlo_live_range, alias_analysis));
|
||||
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
|
||||
memory_space_assignment.ScheduleAsynchronousCopies();
|
||||
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
|
||||
@ -1752,6 +1724,29 @@ MemorySpaceAssignment::Run(HloModule* module,
|
||||
return std::move(memory_space_assignment.preset_assignments_);
|
||||
}
|
||||
|
||||
Status MemorySpaceAssignment::FindAllocationSequence(
|
||||
const HloLiveRange& hlo_live_range,
|
||||
const HloAliasAnalysis& alias_analysis) {
|
||||
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
||||
&allocations_, options_, alias_analysis, hlo_live_range);
|
||||
|
||||
if (options_.enable_cross_program_prefetch) {
|
||||
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
|
||||
prefetch_candiate = FindCrossProgramPrefetchCandidate(
|
||||
alias_analysis, hlo_live_range, options_);
|
||||
algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate);
|
||||
}
|
||||
|
||||
HeapSimulator::Options heap_simulator_options;
|
||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
|
||||
module_->schedule(), alias_analysis,
|
||||
options_.size_fn,
|
||||
heap_simulator_options)
|
||||
.status());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
|
||||
HloInstruction* operand =
|
||||
use.instruction->mutable_operand(use.operand_number);
|
||||
|
@ -596,6 +596,13 @@ class MemorySpaceAssignment {
|
||||
// export heap simulator trace to be used by buffer_assignment.
|
||||
Status VerifyAndExportHeapSimulatorTrace();
|
||||
|
||||
protected:
|
||||
// Finds an AllocationSequence for placing buffers in alternate memory using
|
||||
// the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is
|
||||
// called.
|
||||
Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
|
||||
const HloAliasAnalysis& alias_analysis);
|
||||
|
||||
private:
|
||||
MemorySpaceAssignment(HloModule* module, Options options,
|
||||
const HloLiveRange& hlo_live_range)
|
||||
@ -615,14 +622,6 @@ class MemorySpaceAssignment {
|
||||
}
|
||||
}
|
||||
|
||||
// Sets allocations_. Must be set before Process() is called.
|
||||
// Uses an rvalue reference so that the caller is forced to hand over
|
||||
// ownership of the AllocationSequence, e.g.
|
||||
// SetAllocationSequence(std::move(my_allocation)).
|
||||
void SetAllocationSequence(AllocationSequence&& allocations) {
|
||||
allocations_ = std::move(allocations);
|
||||
}
|
||||
|
||||
// Process calls Process methods of the allocations after the allocations have
|
||||
// been finalized.
|
||||
Status Process();
|
||||
|
Loading…
Reference in New Issue
Block a user