Refactors memory_space_assignment so that it is easier for subclasses to implement different algorithms.
PiperOrigin-RevId: 306492169 Change-Id: I6af9366706ec36d84c7a121f163cf2e55ffc8518
This commit is contained in:
parent
ac10027777
commit
7d2ff41266
@ -1691,32 +1691,6 @@ FindCrossProgramPrefetchCandidate(
|
|||||||
}
|
}
|
||||||
return *best_candidate;
|
return *best_candidate;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finds an AllocationSequence for placing buffers in alternate memory using the
|
|
||||||
// AlternateMemoryBestFitHeap algorithm.
|
|
||||||
StatusOr<MemorySpaceAssignment::AllocationSequence> FindAllocationSequence(
|
|
||||||
HloModule* module, const HloLiveRange& hlo_live_range,
|
|
||||||
const HloAliasAnalysis& alias_analysis,
|
|
||||||
const MemorySpaceAssignment::Options& options) {
|
|
||||||
MemorySpaceAssignment::AllocationSequence allocations;
|
|
||||||
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
|
||||||
&allocations, options, alias_analysis, hlo_live_range);
|
|
||||||
|
|
||||||
if (options.enable_cross_program_prefetch) {
|
|
||||||
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
|
|
||||||
prefetch_candiate = FindCrossProgramPrefetchCandidate(
|
|
||||||
alias_analysis, hlo_live_range, options);
|
|
||||||
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
|
||||||
}
|
|
||||||
|
|
||||||
HeapSimulator::Options heap_simulator_options;
|
|
||||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
|
||||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
|
||||||
module->schedule(), alias_analysis,
|
|
||||||
options.size_fn, heap_simulator_options)
|
|
||||||
.status());
|
|
||||||
return std::move(allocations);
|
|
||||||
}
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
||||||
@ -1730,11 +1704,9 @@ MemorySpaceAssignment::Run(HloModule* module,
|
|||||||
VLOG(4) << "Schedule: " << module->schedule().ToString();
|
VLOG(4) << "Schedule: " << module->schedule().ToString();
|
||||||
MemorySpaceAssignment memory_space_assignment(module, options,
|
MemorySpaceAssignment memory_space_assignment(module, options,
|
||||||
hlo_live_range);
|
hlo_live_range);
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
AllocationSequence allocations,
|
|
||||||
FindAllocationSequence(module, hlo_live_range, alias_analysis, options));
|
|
||||||
|
|
||||||
memory_space_assignment.SetAllocationSequence(std::move(allocations));
|
TF_RETURN_IF_ERROR(memory_space_assignment.FindAllocationSequence(
|
||||||
|
hlo_live_range, alias_analysis));
|
||||||
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
|
TF_RETURN_IF_ERROR(memory_space_assignment.Process());
|
||||||
memory_space_assignment.ScheduleAsynchronousCopies();
|
memory_space_assignment.ScheduleAsynchronousCopies();
|
||||||
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
|
TF_RETURN_IF_ERROR(memory_space_assignment.SimplifyGraph());
|
||||||
@ -1752,6 +1724,29 @@ MemorySpaceAssignment::Run(HloModule* module,
|
|||||||
return std::move(memory_space_assignment.preset_assignments_);
|
return std::move(memory_space_assignment.preset_assignments_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status MemorySpaceAssignment::FindAllocationSequence(
|
||||||
|
const HloLiveRange& hlo_live_range,
|
||||||
|
const HloAliasAnalysis& alias_analysis) {
|
||||||
|
auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
|
||||||
|
&allocations_, options_, alias_analysis, hlo_live_range);
|
||||||
|
|
||||||
|
if (options_.enable_cross_program_prefetch) {
|
||||||
|
absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
|
||||||
|
prefetch_candiate = FindCrossProgramPrefetchCandidate(
|
||||||
|
alias_analysis, hlo_live_range, options_);
|
||||||
|
algorithm->AllocateCrossProgramPrefetchBuffer(module_, prefetch_candiate);
|
||||||
|
}
|
||||||
|
|
||||||
|
HeapSimulator::Options heap_simulator_options;
|
||||||
|
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||||
|
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
|
||||||
|
module_->schedule(), alias_analysis,
|
||||||
|
options_.size_fn,
|
||||||
|
heap_simulator_options)
|
||||||
|
.status());
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
|
void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
|
||||||
HloInstruction* operand =
|
HloInstruction* operand =
|
||||||
use.instruction->mutable_operand(use.operand_number);
|
use.instruction->mutable_operand(use.operand_number);
|
||||||
|
@ -596,6 +596,13 @@ class MemorySpaceAssignment {
|
|||||||
// export heap simulator trace to be used by buffer_assignment.
|
// export heap simulator trace to be used by buffer_assignment.
|
||||||
Status VerifyAndExportHeapSimulatorTrace();
|
Status VerifyAndExportHeapSimulatorTrace();
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Finds an AllocationSequence for placing buffers in alternate memory using
|
||||||
|
// the AlternateMemoryBestFitHeap algorithm. Must be set before Process() is
|
||||||
|
// called.
|
||||||
|
Status FindAllocationSequence(const HloLiveRange& hlo_live_range,
|
||||||
|
const HloAliasAnalysis& alias_analysis);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
MemorySpaceAssignment(HloModule* module, Options options,
|
MemorySpaceAssignment(HloModule* module, Options options,
|
||||||
const HloLiveRange& hlo_live_range)
|
const HloLiveRange& hlo_live_range)
|
||||||
@ -615,14 +622,6 @@ class MemorySpaceAssignment {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets allocations_. Must be set before Process() is called.
|
|
||||||
// Uses an rvalue reference so that the caller is forced to hand over
|
|
||||||
// ownership of the AllocationSequence, e.g.
|
|
||||||
// SetAllocationSequence(std::move(my_allocation)).
|
|
||||||
void SetAllocationSequence(AllocationSequence&& allocations) {
|
|
||||||
allocations_ = std::move(allocations);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process calls Process methods of the allocations after the allocations have
|
// Process calls Process methods of the allocations after the allocations have
|
||||||
// been finalized.
|
// been finalized.
|
||||||
Status Process();
|
Status Process();
|
||||||
|
Loading…
Reference in New Issue
Block a user